From 8b6fc070199585431cf3e2fe6dcb53cdb1046460 Mon Sep 17 00:00:00 2001 From: Coding On Star <447357187@qq.com> Date: Mon, 23 Mar 2026 20:16:59 +0800 Subject: [PATCH 01/34] test(workflow): improve dataset item tests with edit and remove functionality (#33937) --- .../__tests__/code-block.spec.tsx | 8 ++ .../base/markdown-blocks/code-block.tsx | 46 ++++++++--- .../__tests__/integration.spec.tsx | 79 +++++++++++++------ .../components/dataset-item.tsx | 4 + 4 files changed, 102 insertions(+), 35 deletions(-) diff --git a/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx b/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx index 308232fd0f..745b7657d7 100644 --- a/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx +++ b/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx @@ -21,6 +21,8 @@ let clientWidthSpy: { mockRestore: () => void } | null = null let clientHeightSpy: { mockRestore: () => void } | null = null let offsetWidthSpy: { mockRestore: () => void } | null = null let offsetHeightSpy: { mockRestore: () => void } | null = null +let consoleErrorSpy: ReturnType | null = null +let consoleWarnSpy: ReturnType | null = null type AudioContextCtor = new () => unknown type WindowWithLegacyAudio = Window & { @@ -83,6 +85,8 @@ describe('CodeBlock', () => { beforeEach(() => { vi.clearAllMocks() mockUseTheme.mockReturnValue({ theme: Theme.light }) + consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) clientWidthSpy = vi.spyOn(HTMLElement.prototype, 'clientWidth', 'get').mockReturnValue(900) clientHeightSpy = vi.spyOn(HTMLElement.prototype, 'clientHeight', 'get').mockReturnValue(400) offsetWidthSpy = vi.spyOn(HTMLElement.prototype, 'offsetWidth', 'get').mockReturnValue(900) @@ -98,6 +102,10 @@ describe('CodeBlock', () => { afterEach(() => { vi.useRealTimers() + consoleErrorSpy?.mockRestore() + consoleWarnSpy?.mockRestore() + consoleErrorSpy = null + consoleWarnSpy = null clientWidthSpy?.mockRestore() clientHeightSpy?.mockRestore() offsetWidthSpy?.mockRestore() diff --git a/web/app/components/base/markdown-blocks/code-block.tsx b/web/app/components/base/markdown-blocks/code-block.tsx index b36d8d7788..412c61d52d 100644 --- a/web/app/components/base/markdown-blocks/code-block.tsx +++ b/web/app/components/base/markdown-blocks/code-block.tsx @@ -85,13 +85,30 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any const processedRef = useRef(false) // Track if content was successfully processed const isInitialRenderRef = useRef(true) // Track if this is initial render const chartInstanceRef = useRef(null) // Direct reference to ECharts instance - const resizeTimerRef = useRef(null) // For debounce handling + const resizeTimerRef = useRef | null>(null) // For debounce handling + const chartReadyTimerRef = useRef | null>(null) const finishedEventCountRef = useRef(0) // Track finished event trigger count const match = /language-(\w+)/.exec(className || '') const language = match?.[1] const languageShowName = getCorrectCapitalizationLanguageName(language || '') const isDarkMode = theme === Theme.dark + const clearResizeTimer = useCallback(() => { + if (!resizeTimerRef.current) + return + + clearTimeout(resizeTimerRef.current) + resizeTimerRef.current = null + }, []) + + const clearChartReadyTimer = useCallback(() => { + if (!chartReadyTimerRef.current) + return + + clearTimeout(chartReadyTimerRef.current) + chartReadyTimerRef.current = null + }, []) + const echartsStyle = useMemo(() => ({ height: '350px', width: '100%', @@ -104,26 +121,27 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any // Debounce resize operations const debouncedResize = useCallback(() => { - if (resizeTimerRef.current) - clearTimeout(resizeTimerRef.current) + clearResizeTimer() resizeTimerRef.current = setTimeout(() => { if (chartInstanceRef.current) chartInstanceRef.current.resize() resizeTimerRef.current = null }, 200) - }, []) + }, [clearResizeTimer]) // Handle ECharts instance initialization const handleChartReady = useCallback((instance: any) => { chartInstanceRef.current = instance // Force resize to ensure timeline displays correctly - setTimeout(() => { + clearChartReadyTimer() + chartReadyTimerRef.current = setTimeout(() => { if (chartInstanceRef.current) chartInstanceRef.current.resize() + chartReadyTimerRef.current = null }, 200) - }, []) + }, [clearChartReadyTimer]) // Store event handlers in useMemo to avoid recreating them const echartsEvents = useMemo(() => ({ @@ -157,10 +175,20 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any return () => { window.removeEventListener('resize', handleResize) - if (resizeTimerRef.current) - clearTimeout(resizeTimerRef.current) + clearResizeTimer() + clearChartReadyTimer() + chartInstanceRef.current = null } - }, [language, debouncedResize]) + }, [language, debouncedResize, clearResizeTimer, clearChartReadyTimer]) + + useEffect(() => { + return () => { + clearResizeTimer() + clearChartReadyTimer() + chartInstanceRef.current = null + echartsRef.current = null + } + }, [clearResizeTimer, clearChartReadyTimer]) // Process chart data when content changes useEffect(() => { // Only process echarts content diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/__tests__/integration.spec.tsx b/web/app/components/workflow/nodes/knowledge-retrieval/__tests__/integration.spec.tsx index dbf201670b..b9f2b17bb2 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/__tests__/integration.spec.tsx +++ b/web/app/components/workflow/nodes/knowledge-retrieval/__tests__/integration.spec.tsx @@ -4,8 +4,9 @@ import type { MetadataShape, } from '../types' import type { DataSet, MetadataInDoc } from '@/models/datasets' -import { fireEvent, render, screen } from '@testing-library/react' +import { fireEvent, render, screen, waitFor, within } from '@testing-library/react' import userEvent from '@testing-library/user-event' +import { useEffect, useRef } from 'react' import { ChunkingMode, DatasetPermission, @@ -173,17 +174,26 @@ vi.mock('@/app/components/app/configuration/dataset-config/select-dataset', () = vi.mock('@/app/components/app/configuration/dataset-config/settings-modal', () => ({ __esModule: true, - default: ({ currentDataset, onSave, onCancel }: { currentDataset: DataSet, onSave: (dataset: DataSet) => void, onCancel: () => void }) => ( -
-
{currentDataset.name}
- - -
- ), + default: function MockSettingsModal({ currentDataset, onSave, onCancel }: { currentDataset: DataSet, onSave: (dataset: DataSet) => void, onCancel: () => void }) { + const hasSavedRef = useRef(false) + + useEffect(() => { + if (hasSavedRef.current) + return + + hasSavedRef.current = true + onSave(createDataset({ ...currentDataset, name: 'Updated Dataset' })) + }, [currentDataset, onSave]) + + return ( +
+
{currentDataset.name}
+ +
+ ) + }, })) vi.mock('@/app/components/app/configuration/dataset-config/params-config/config-content', () => ({ @@ -265,6 +275,13 @@ vi.mock('../components/metadata/metadata-panel', () => ({ })) describe('knowledge-retrieval path', () => { + const getDatasetItem = () => { + const datasetItem = screen.getByText('Dataset Name').closest('.group\\/dataset-item') + if (!(datasetItem instanceof HTMLElement)) + throw new Error('Dataset item container not found') + return datasetItem + } + beforeEach(() => { vi.clearAllMocks() mockHasEditPermissionForDataset.mockReturnValue(true) @@ -293,33 +310,43 @@ describe('knowledge-retrieval path', () => { ]) }) - it('should support editing and removing a dataset item', async () => { - const user = userEvent.setup() + it('should support editing a dataset item', async () => { const onChange = vi.fn() - const onRemove = vi.fn() render( , ) expect(screen.getByText('Dataset Name')).toBeInTheDocument() - fireEvent.mouseOver(screen.getByText('Dataset Name').closest('.group\\/dataset-item')!) + const datasetItem = getDatasetItem() + fireEvent.click(within(datasetItem).getByRole('button', { name: 'common.operation.edit' })) - const buttons = screen.getAllByRole('button') - await user.click(buttons[0]!) - await user.click(screen.getByText('save-settings')) - await user.click(buttons[1]!) + await waitFor(() => { + expect(onChange).toHaveBeenCalledWith(expect.objectContaining({ name: 'Updated Dataset' })) + }) + }) - expect(onChange).toHaveBeenCalledWith(expect.objectContaining({ name: 'Updated Dataset' })) + it('should support removing a dataset item', () => { + const onRemove = vi.fn() + + render( + , + ) + + const datasetItem = getDatasetItem() + fireEvent.click(within(datasetItem).getByRole('button', { name: 'common.operation.remove' })) expect(onRemove).toHaveBeenCalled() }) - it('should render empty and populated dataset lists', async () => { - const user = userEvent.setup() + it('should render empty and populated dataset lists', () => { const onChange = vi.fn() const { rerender } = render( @@ -338,8 +365,8 @@ describe('knowledge-retrieval path', () => { />, ) - fireEvent.mouseOver(screen.getByText('Dataset Name').closest('.group\\/dataset-item')!) - await user.click(screen.getAllByRole('button')[1]!) + const datasetItem = getDatasetItem() + fireEvent.click(within(datasetItem).getByRole('button', { name: 'common.operation.remove' })) expect(onChange).toHaveBeenCalledWith([]) }) diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/components/dataset-item.tsx b/web/app/components/workflow/nodes/knowledge-retrieval/components/dataset-item.tsx index c865a49ba9..f0f0d3191a 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/components/dataset-item.tsx +++ b/web/app/components/workflow/nodes/knowledge-retrieval/components/dataset-item.tsx @@ -85,6 +85,8 @@ const DatasetItem: FC = ({ { editable && ( { e.stopPropagation() showSettingsModal() @@ -95,6 +97,8 @@ const DatasetItem: FC = ({ ) } setIsDeleteHovered(true)} From d956b919a0797fbd3d1a59300de2e712462b2e02 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 23 Mar 2026 21:27:14 +0900 Subject: [PATCH 02/34] ci: fix AttributeError: 'Flask' object has no attribute 'login_manager' FAILED #33891 (#33896) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../controllers/console/app/test_message.py | 342 ++++++++++++++ .../controllers/console/app/test_statistic.py | 334 ++++++++++++++ .../app/test_workflow_draft_variable.py | 415 +++++++++++++++++ .../auth/test_data_source_bearer_auth.py | 131 ++++++ .../console/auth/test_data_source_oauth.py | 120 +++++ .../console/auth/test_oauth_server.py | 365 +++++++++++++++ .../controllers/console/helpers.py | 85 ++++ .../controllers/console/app/test_message.py | 320 -------------- .../controllers/console/app/test_statistic.py | 275 ------------ .../app/test_workflow_draft_variable.py | 313 ------------- .../auth/test_data_source_bearer_auth.py | 209 --------- .../console/auth/test_data_source_oauth.py | 192 -------- .../console/auth/test_oauth_server.py | 417 ------------------ 13 files changed, 1792 insertions(+), 1726 deletions(-) create mode 100644 api/tests/test_containers_integration_tests/controllers/console/app/test_message.py create mode 100644 api/tests/test_containers_integration_tests/controllers/console/app/test_statistic.py create mode 100644 api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py create mode 100644 api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py create mode 100644 api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_oauth.py create mode 100644 api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth_server.py create mode 100644 api/tests/test_containers_integration_tests/controllers/console/helpers.py delete mode 100644 api/tests/unit_tests/controllers/console/app/test_message.py delete mode 100644 api/tests/unit_tests/controllers/console/app/test_statistic.py delete mode 100644 api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py delete mode 100644 api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py delete mode 100644 api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py delete mode 100644 api/tests/unit_tests/controllers/console/auth/test_oauth_server.py diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py new file mode 100644 index 0000000000..6b51ec98bc --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py @@ -0,0 +1,342 @@ +"""Authenticated controller integration tests for console message APIs.""" + +from datetime import timedelta +from decimal import Decimal +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from controllers.console.app.message import ChatMessagesQuery, FeedbackExportQuery, MessageFeedbackPayload +from controllers.console.app.message import attach_message_extra_contents as _attach_message_extra_contents +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from libs.datetime_utils import naive_utc_now +from models.enums import ConversationFromSource, FeedbackRating +from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback +from services.errors.conversation import ConversationNotExistsError +from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + create_console_app, +) + + +def _create_conversation(db_session: Session, app_id: str, account_id: str, mode: AppMode) -> Conversation: + conversation = Conversation( + app_id=app_id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=mode, + name="Test Conversation", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + from_source=ConversationFromSource.CONSOLE, + from_account_id=account_id, + ) + db_session.add(conversation) + db_session.commit() + return conversation + + +def _create_message( + db_session: Session, + app_id: str, + conversation_id: str, + account_id: str, + *, + created_at_offset_seconds: int = 0, +) -> Message: + created_at = naive_utc_now() + timedelta(seconds=created_at_offset_seconds) + message = Message( + app_id=app_id, + model_provider=None, + model_id="", + override_model_configs=None, + conversation_id=conversation_id, + inputs={}, + query="Hello", + message={"type": "text", "content": "Hello"}, + message_tokens=1, + message_unit_price=Decimal("0.0001"), + message_price_unit=Decimal("0.001"), + answer="Hi there", + answer_tokens=1, + answer_unit_price=Decimal("0.0001"), + answer_price_unit=Decimal("0.001"), + parent_message_id=None, + provider_response_latency=0, + total_price=Decimal("0.0002"), + currency="USD", + from_source=ConversationFromSource.CONSOLE, + from_account_id=account_id, + created_at=created_at, + updated_at=created_at, + app_mode=AppMode.CHAT, + ) + db_session.add(message) + db_session.commit() + return message + + +class TestMessageValidators: + def test_chat_messages_query_validators(self) -> None: + assert ChatMessagesQuery.empty_to_none("") is None + assert ChatMessagesQuery.empty_to_none("val") == "val" + assert ChatMessagesQuery.validate_uuid(None) is None + assert ( + ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000") + == "123e4567-e89b-12d3-a456-426614174000" + ) + + def test_message_feedback_validators(self) -> None: + assert ( + MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000") + == "123e4567-e89b-12d3-a456-426614174000" + ) + + def test_feedback_export_validators(self) -> None: + assert FeedbackExportQuery.parse_bool(None) is None + assert FeedbackExportQuery.parse_bool(True) is True + assert FeedbackExportQuery.parse_bool("1") is True + assert FeedbackExportQuery.parse_bool("0") is False + assert FeedbackExportQuery.parse_bool("off") is False + + with pytest.raises(ValueError): + FeedbackExportQuery.parse_bool("invalid") + + +def test_chat_message_list_not_found( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-messages", + query_string={"conversation_id": str(uuid4())}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "not_found" + + +def test_chat_message_list_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, account.id, created_at_offset_seconds=0) + second = _create_message( + db_session_with_containers, + app.id, + conversation.id, + account.id, + created_at_offset_seconds=1, + ) + + with patch( + "controllers.console.app.message.attach_message_extra_contents", + side_effect=_attach_message_extra_contents, + ): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-messages", + query_string={"conversation_id": conversation.id, "limit": 1}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["limit"] == 1 + assert payload["has_more"] is True + assert len(payload["data"]) == 1 + assert payload["data"][0]["id"] == second.id + + +def test_message_feedback_not_found( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + response = test_client_with_containers.post( + f"/console/api/apps/{app.id}/feedbacks", + json={"message_id": str(uuid4()), "rating": "like"}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "not_found" + + +def test_message_feedback_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode) + message = _create_message(db_session_with_containers, app.id, conversation.id, account.id) + + response = test_client_with_containers.post( + f"/console/api/apps/{app.id}/feedbacks", + json={"message_id": message.id, "rating": "like"}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + + feedback = db_session_with_containers.scalar( + select(MessageFeedback).where(MessageFeedback.message_id == message.id) + ) + assert feedback is not None + assert feedback.rating == FeedbackRating.LIKE + assert feedback.from_account_id == account.id + + +def test_message_annotation_count( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode) + message = _create_message(db_session_with_containers, app.id, conversation.id, account.id) + db_session_with_containers.add( + MessageAnnotation( + app_id=app.id, + conversation_id=conversation.id, + message_id=message.id, + question="Q", + content="A", + account_id=account.id, + ) + ) + db_session_with_containers.commit() + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/annotations/count", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"count": 1} + + +def test_message_suggested_questions_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + message_id = str(uuid4()) + + with patch( + "controllers.console.app.message.MessageService.get_suggested_questions_after_answer", + return_value=["q1", "q2"], + ): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-messages/{message_id}/suggested-questions", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"data": ["q1", "q2"]} + + +@pytest.mark.parametrize( + ("exc", "expected_status", "expected_code"), + [ + (MessageNotExistsError(), 404, "not_found"), + (ConversationNotExistsError(), 404, "not_found"), + (ProviderTokenNotInitError(), 400, "provider_not_initialize"), + (QuotaExceededError(), 400, "provider_quota_exceeded"), + (ModelCurrentlyNotSupportError(), 400, "model_currently_not_support"), + (SuggestedQuestionsAfterAnswerDisabledError(), 403, "app_suggested_questions_after_answer_disabled"), + (Exception(), 500, "internal_server_error"), + ], +) +def test_message_suggested_questions_errors( + exc: Exception, + expected_status: int, + expected_code: str, + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + message_id = str(uuid4()) + + with patch( + "controllers.console.app.message.MessageService.get_suggested_questions_after_answer", + side_effect=exc, + ): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-messages/{message_id}/suggested-questions", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == expected_status + payload = response.get_json() + assert payload is not None + assert payload["code"] == expected_code + + +def test_message_feedback_export_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + with patch("services.feedback_service.FeedbackService.export_feedbacks", return_value={"exported": True}): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/feedbacks/export", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"exported": True} + + +def test_message_api_get_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode) + message = _create_message(db_session_with_containers, app.id, conversation.id, account.id) + + with patch( + "controllers.console.app.message.attach_message_extra_contents", + side_effect=_attach_message_extra_contents, + ): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/messages/{message.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["id"] == message.id diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_statistic.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_statistic.py new file mode 100644 index 0000000000..963cfe53e5 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_statistic.py @@ -0,0 +1,334 @@ +"""Controller integration tests for console statistic routes.""" + +from datetime import timedelta +from decimal import Decimal +from unittest.mock import patch +from uuid import uuid4 + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import InvokeFrom +from libs.datetime_utils import naive_utc_now +from models.enums import ConversationFromSource, FeedbackFromSource, FeedbackRating +from models.model import AppMode, Conversation, Message, MessageFeedback +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + create_console_app, +) + + +def _create_conversation( + db_session: Session, + app_id: str, + account_id: str, + *, + mode: AppMode, + created_at_offset_days: int = 0, +) -> Conversation: + created_at = naive_utc_now() + timedelta(days=created_at_offset_days) + conversation = Conversation( + app_id=app_id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=mode, + name="Stats Conversation", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + from_source=ConversationFromSource.CONSOLE, + from_account_id=account_id, + created_at=created_at, + updated_at=created_at, + ) + db_session.add(conversation) + db_session.commit() + return conversation + + +def _create_message( + db_session: Session, + app_id: str, + conversation_id: str, + *, + from_account_id: str | None, + from_end_user_id: str | None = None, + message_tokens: int = 1, + answer_tokens: int = 1, + total_price: Decimal = Decimal("0.01"), + provider_response_latency: float = 1.0, + created_at_offset_days: int = 0, +) -> Message: + created_at = naive_utc_now() + timedelta(days=created_at_offset_days) + message = Message( + app_id=app_id, + model_provider=None, + model_id="", + override_model_configs=None, + conversation_id=conversation_id, + inputs={}, + query="Hello", + message={"type": "text", "content": "Hello"}, + message_tokens=message_tokens, + message_unit_price=Decimal("0.001"), + message_price_unit=Decimal("0.001"), + answer="Hi there", + answer_tokens=answer_tokens, + answer_unit_price=Decimal("0.001"), + answer_price_unit=Decimal("0.001"), + parent_message_id=None, + provider_response_latency=provider_response_latency, + total_price=total_price, + currency="USD", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, + from_end_user_id=from_end_user_id, + from_account_id=from_account_id, + created_at=created_at, + updated_at=created_at, + app_mode=AppMode.CHAT, + ) + db_session.add(message) + db_session.commit() + return message + + +def _create_like_feedback( + db_session: Session, + app_id: str, + conversation_id: str, + message_id: str, + account_id: str, +) -> None: + db_session.add( + MessageFeedback( + app_id=app_id, + conversation_id=conversation_id, + message_id=message_id, + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.ADMIN, + from_account_id=account_id, + ) + ) + db_session.commit() + + +def test_daily_message_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-messages", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["message_count"] == 1 + + +def test_daily_conversation_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-conversations", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["conversation_count"] == 1 + + +def test_daily_terminals_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=None, + from_end_user_id=str(uuid4()), + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-end-users", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["terminal_count"] == 1 + + +def test_daily_token_cost_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=account.id, + message_tokens=40, + answer_tokens=60, + total_price=Decimal("0.02"), + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/token-costs", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["data"][0]["token_count"] == 100 + assert Decimal(payload["data"][0]["total_price"]) == Decimal("0.02") + + +def test_average_session_interaction_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/average-session-interactions", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["interactions"] == 2.0 + + +def test_user_satisfaction_rate_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + first = _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + for _ in range(9): + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + _create_like_feedback(db_session_with_containers, app.id, conversation.id, first.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/user-satisfaction-rate", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["rate"] == 100.0 + + +def test_average_response_time_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.COMPLETION) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=account.id, + provider_response_latency=1.234, + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/average-response-time", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["latency"] == 1234.0 + + +def test_tokens_per_second_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=account.id, + answer_tokens=31, + provider_response_latency=2.0, + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/tokens-per-second", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["tps"] == 15.5 + + +def test_invalid_time_range( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + with patch("controllers.console.app.statistic.parse_time_range", side_effect=ValueError("Invalid time")): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-messages?start=invalid&end=invalid", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "Invalid time" + + +def test_time_range_params_passed( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + import datetime + + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + start = datetime.datetime.now() + end = datetime.datetime.now() + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(start, end)) as mock_parse: + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-messages?start=something&end=something", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + mock_parse.assert_called_once_with("something", "something", "UTC") diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py new file mode 100644 index 0000000000..f037ad77c0 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py @@ -0,0 +1,415 @@ +"""Authenticated controller integration tests for workflow draft variable APIs.""" + +import uuid + +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID +from dify_graph.variables.segments import StringSegment +from factories.variable_factory import segment_to_variable +from models import Workflow +from models.model import AppMode +from models.workflow import WorkflowDraftVariable +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + create_console_app, +) + + +def _create_draft_workflow( + db_session: Session, + app_id: str, + tenant_id: str, + account_id: str, + *, + environment_variables: list | None = None, + conversation_variables: list | None = None, +) -> Workflow: + workflow = Workflow.new( + tenant_id=tenant_id, + app_id=app_id, + type="workflow", + version=Workflow.VERSION_DRAFT, + graph='{"nodes": [], "edges": []}', + features="{}", + created_by=account_id, + environment_variables=environment_variables or [], + conversation_variables=conversation_variables or [], + rag_pipeline_variables=[], + ) + db_session.add(workflow) + db_session.commit() + return workflow + + +def _create_node_variable( + db_session: Session, + app_id: str, + user_id: str, + *, + node_id: str = "node_1", + name: str = "test_var", +) -> WorkflowDraftVariable: + variable = WorkflowDraftVariable.new_node_variable( + app_id=app_id, + user_id=user_id, + node_id=node_id, + name=name, + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + visible=True, + editable=True, + ) + db_session.add(variable) + db_session.commit() + return variable + + +def _create_system_variable( + db_session: Session, app_id: str, user_id: str, name: str = "query" +) -> WorkflowDraftVariable: + variable = WorkflowDraftVariable.new_sys_variable( + app_id=app_id, + user_id=user_id, + name=name, + value=StringSegment(value="system-value"), + node_execution_id=str(uuid.uuid4()), + editable=True, + ) + db_session.add(variable) + db_session.commit() + return variable + + +def _build_environment_variable(name: str, value: str): + return segment_to_variable( + segment=StringSegment(value=value), + selector=[ENVIRONMENT_VARIABLE_NODE_ID, name], + name=name, + description=f"Environment variable {name}", + ) + + +def _build_conversation_variable(name: str, value: str): + return segment_to_variable( + segment=StringSegment(value=value), + selector=[CONVERSATION_VARIABLE_NODE_ID, name], + name=name, + description=f"Conversation variable {name}", + ) + + +def test_workflow_variable_collection_get_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/variables?page=1&limit=20", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"items": [], "total": 0} + + +def test_workflow_variable_collection_get_not_exist( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "draft_workflow_not_exist" + + +def test_workflow_variable_collection_delete( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_node_variable(db_session_with_containers, app.id, account.id) + _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_2", name="other_var") + + response = test_client_with_containers.delete( + f"/console/api/apps/{app.id}/workflows/draft/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + remaining = db_session_with_containers.scalars( + select(WorkflowDraftVariable).where( + WorkflowDraftVariable.app_id == app.id, + WorkflowDraftVariable.user_id == account.id, + ) + ).all() + assert remaining == [] + + +def test_node_variable_collection_get_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + node_variable = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_123") + _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_456", name="other") + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/nodes/node_123/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert [item["id"] for item in payload["items"]] == [node_variable.id] + + +def test_node_variable_collection_get_invalid_node_id( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/nodes/sys/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 400 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "invalid_param" + + +def test_node_variable_collection_delete( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + target = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_123") + untouched = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_456") + target_id = target.id + untouched_id = untouched.id + + response = test_client_with_containers.delete( + f"/console/api/apps/{app.id}/workflows/draft/nodes/node_123/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == target_id)) + is None + ) + assert ( + db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == untouched_id)) + is not None + ) + + +def test_variable_api_get_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + variable = _create_node_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["id"] == variable.id + assert payload["name"] == "test_var" + + +def test_variable_api_get_not_found( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/variables/{uuid.uuid4()}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "not_found" + + +def test_variable_api_patch_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + variable = _create_node_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.patch( + f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}", + headers=authenticate_console_client(test_client_with_containers, account), + json={"name": "renamed_var"}, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["id"] == variable.id + assert payload["name"] == "renamed_var" + + refreshed = db_session_with_containers.scalar( + select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id) + ) + assert refreshed is not None + assert refreshed.name == "renamed_var" + + +def test_variable_api_delete_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + variable = _create_node_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.delete( + f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id)) + is None + ) + + +def test_variable_reset_api_put_success_returns_no_content_without_execution( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + variable = _create_node_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.put( + f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}/reset", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id)) + is None + ) + + +def test_conversation_variable_collection_get( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow( + db_session_with_containers, + app.id, + tenant.id, + account.id, + conversation_variables=[_build_conversation_variable("session_name", "Alice")], + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/conversation-variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert [item["name"] for item in payload["items"]] == ["session_name"] + + created = db_session_with_containers.scalars( + select(WorkflowDraftVariable).where( + WorkflowDraftVariable.app_id == app.id, + WorkflowDraftVariable.user_id == account.id, + WorkflowDraftVariable.node_id == CONVERSATION_VARIABLE_NODE_ID, + ) + ).all() + assert len(created) == 1 + + +def test_system_variable_collection_get( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + variable = _create_system_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/system-variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert [item["id"] for item in payload["items"]] == [variable.id] + + +def test_environment_variable_collection_get( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow( + db_session_with_containers, + app.id, + tenant.id, + account.id, + environment_variables=[_build_environment_variable("api_key", "secret-value")], + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/environment-variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["items"][0]["name"] == "api_key" + assert payload["items"][0]["value"] == "secret-value" diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py new file mode 100644 index 0000000000..00309c25d6 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py @@ -0,0 +1,131 @@ +"""Controller integration tests for API key data source auth routes.""" + +import json +from unittest.mock import patch + +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from models.source import DataSourceApiKeyAuthBinding +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, +) + + +def test_get_api_key_auth_data_source( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + binding = DataSourceApiKeyAuthBinding( + tenant_id=tenant.id, + category="api_key", + provider="custom_provider", + credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}), + disabled=False, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + + response = test_client_with_containers.get( + "/console/api/api-key-auth/data-source", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert len(payload["sources"]) == 1 + assert payload["sources"][0]["provider"] == "custom_provider" + + +def test_get_api_key_auth_data_source_empty( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + response = test_client_with_containers.get( + "/console/api/api-key-auth/data-source", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"sources": []} + + +def test_create_binding_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with ( + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"), + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth"), + ): + response = test_client_with_containers.post( + "/console/api/api-key-auth/data-source/binding", + json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + + +def test_create_binding_failure( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with ( + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"), + patch( + "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth", + side_effect=ValueError("Invalid structure"), + ), + ): + response = test_client_with_containers.post( + "/console/api/api-key-auth/data-source/binding", + json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 500 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "auth_failed" + assert payload["message"] == "Invalid structure" + + +def test_delete_binding_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + binding = DataSourceApiKeyAuthBinding( + tenant_id=tenant.id, + category="api_key", + provider="custom_provider", + credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}), + disabled=False, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + + response = test_client_with_containers.delete( + f"/console/api/api-key-auth/data-source/{binding.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar( + select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.id == binding.id) + ) + is None + ) diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_oauth.py new file mode 100644 index 0000000000..81b5423261 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_oauth.py @@ -0,0 +1,120 @@ +"""Controller integration tests for console OAuth data source routes.""" + +from unittest.mock import MagicMock, patch + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from models.source import DataSourceOauthBinding +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, +) + + +def test_get_oauth_url_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + provider = MagicMock() + provider.get_authorization_url.return_value = "http://oauth.provider/auth" + + with ( + patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}), + patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None), + ): + response = test_client_with_containers.get( + "/console/api/oauth/data-source/notion", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert tenant.id == account.current_tenant_id + assert response.status_code == 200 + assert response.get_json() == {"data": "http://oauth.provider/auth"} + provider.get_authorization_url.assert_called_once() + + +def test_get_oauth_url_invalid_provider( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get( + "/console/api/oauth/data-source/unknown_provider", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 400 + assert response.get_json() == {"error": "Invalid provider"} + + +def test_oauth_callback_successful(test_client_with_containers: FlaskClient) -> None: + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get("/console/api/oauth/data-source/callback/notion?code=mock_code") + + assert response.status_code == 302 + assert "code=mock_code" in response.location + + +def test_oauth_callback_missing_code(test_client_with_containers: FlaskClient) -> None: + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get("/console/api/oauth/data-source/callback/notion") + + assert response.status_code == 302 + assert "error=Access%20denied" in response.location + + +def test_oauth_callback_invalid_provider(test_client_with_containers: FlaskClient) -> None: + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get("/console/api/oauth/data-source/callback/invalid?code=mock_code") + + assert response.status_code == 400 + assert response.get_json() == {"error": "Invalid provider"} + + +def test_get_binding_successful(test_client_with_containers: FlaskClient) -> None: + provider = MagicMock() + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}): + response = test_client_with_containers.get("/console/api/oauth/data-source/binding/notion?code=auth_code_123") + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + provider.get_access_token.assert_called_once_with("auth_code_123") + + +def test_get_binding_missing_code(test_client_with_containers: FlaskClient) -> None: + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get("/console/api/oauth/data-source/binding/notion?code=") + + assert response.status_code == 400 + assert response.get_json() == {"error": "Invalid code"} + + +def test_sync_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + binding = DataSourceOauthBinding( + tenant_id=tenant.id, + access_token="test-access-token", + provider="notion", + source_info={"workspace_name": "Workspace", "workspace_icon": None, "workspace_id": tenant.id, "pages": []}, + disabled=False, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + + provider = MagicMock() + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}): + response = test_client_with_containers.get( + f"/console/api/oauth/data-source/notion/{binding.id}/sync", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + provider.sync_data_source.assert_called_once_with(binding.id) diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth_server.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth_server.py new file mode 100644 index 0000000000..2ef27133d8 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth_server.py @@ -0,0 +1,365 @@ +"""Controller integration tests for console OAuth server routes.""" + +from unittest.mock import patch + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from models.model import OAuthProviderApp +from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + ensure_dify_setup, +) + + +def _build_oauth_provider_app() -> OAuthProviderApp: + return OAuthProviderApp( + app_icon="icon_url", + client_id="test_client_id", + client_secret="test_secret", + app_label={"en-US": "Test App"}, + redirect_uris=["http://localhost/callback"], + scope="read,write", + ) + + +def test_oauth_provider_successful_post( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider", + json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"}, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["app_icon"] == "icon_url" + assert payload["app_label"] == {"en-US": "Test App"} + assert payload["scope"] == "read,write" + + +def test_oauth_provider_invalid_redirect_uri( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider", + json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"}, + ) + + assert response.status_code == 400 + payload = response.get_json() + assert payload is not None + assert "redirect_uri is invalid" in payload["message"] + + +def test_oauth_provider_invalid_client_id( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + response = test_client_with_containers.post( + "/console/api/oauth/provider", + json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"}, + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert "client_id is invalid" in payload["message"] + + +def test_oauth_authorize_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with ( + patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ), + patch( + "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code", + return_value="auth_code_123", + ) as mock_sign, + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/authorize", + json={"client_id": "test_client_id"}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"code": "auth_code_123"} + mock_sign.assert_called_once_with("test_client_id", account.id) + + +def test_oauth_token_authorization_code_grant( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with ( + patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ), + patch( + "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token", + return_value=("access_123", "refresh_123"), + ), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "test_secret", + "redirect_uri": "http://localhost/callback", + }, + ) + + assert response.status_code == 200 + assert response.get_json() == { + "access_token": "access_123", + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": "refresh_123", + } + + +def test_oauth_token_authorization_code_grant_missing_code( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "client_secret": "test_secret", + "redirect_uri": "http://localhost/callback", + }, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "code is required" + + +def test_oauth_token_authorization_code_grant_invalid_secret( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "invalid_secret", + "redirect_uri": "http://localhost/callback", + }, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "client_secret is invalid" + + +def test_oauth_token_authorization_code_grant_invalid_redirect_uri( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "test_secret", + "redirect_uri": "http://invalid/callback", + }, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "redirect_uri is invalid" + + +def test_oauth_token_refresh_token_grant( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with ( + patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ), + patch( + "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token", + return_value=("new_access", "new_refresh"), + ), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"}, + ) + + assert response.status_code == 200 + assert response.get_json() == { + "access_token": "new_access", + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": "new_refresh", + } + + +def test_oauth_token_refresh_token_grant_missing_token( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={"client_id": "test_client_id", "grant_type": "refresh_token"}, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "refresh_token is required" + + +def test_oauth_token_invalid_grant_type( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={"client_id": "test_client_id", "grant_type": "invalid_grant"}, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "invalid grant_type" + + +def test_oauth_account_successful_retrieval( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + account.avatar = "avatar_url" + db_session_with_containers.commit() + + with ( + patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ), + patch( + "controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token", + return_value=account, + ), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/account", + json={"client_id": "test_client_id"}, + headers={"Authorization": "Bearer valid_access_token"}, + ) + + assert response.status_code == 200 + assert response.get_json() == { + "name": "Test User", + "email": account.email, + "avatar": "avatar_url", + "interface_language": "en-US", + "timezone": "UTC", + } + + +def test_oauth_account_missing_authorization_header( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/account", + json={"client_id": "test_client_id"}, + ) + + assert response.status_code == 401 + assert response.get_json() == {"error": "Authorization header is required"} + + +def test_oauth_account_invalid_authorization_header_format( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/account", + json={"client_id": "test_client_id"}, + headers={"Authorization": "InvalidFormat"}, + ) + + assert response.status_code == 401 + assert response.get_json() == {"error": "Invalid Authorization header format"} diff --git a/api/tests/test_containers_integration_tests/controllers/console/helpers.py b/api/tests/test_containers_integration_tests/controllers/console/helpers.py new file mode 100644 index 0000000000..9e2084f393 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/helpers.py @@ -0,0 +1,85 @@ +"""Shared helpers for authenticated console controller integration tests.""" + +import uuid + +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from configs import dify_config +from constants import HEADER_NAME_CSRF_TOKEN +from libs.datetime_utils import naive_utc_now +from libs.token import _real_cookie_name, generate_csrf_token +from models import Account, DifySetup, Tenant, TenantAccountJoin +from models.account import AccountStatus, TenantAccountRole +from models.model import App, AppMode +from services.account_service import AccountService + + +def ensure_dify_setup(db_session: Session) -> None: + """Create a setup marker once so setup-protected console routes can be exercised.""" + if db_session.scalar(select(DifySetup).limit(1)) is not None: + return + + db_session.add(DifySetup(version=dify_config.project.version)) + db_session.commit() + + +def create_console_account_and_tenant(db_session: Session) -> tuple[Account, Tenant]: + """Create an initialized owner account with a current tenant.""" + account = Account( + email=f"test-{uuid.uuid4()}@example.com", + name="Test User", + interface_language="en-US", + status=AccountStatus.ACTIVE, + ) + account.initialized_at = naive_utc_now() + db_session.add(account) + db_session.commit() + + tenant = Tenant(name="Test Tenant", status="normal") + db_session.add(tenant) + db_session.commit() + + db_session.add( + TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + ) + db_session.commit() + + account.set_tenant_id(tenant.id) + account.timezone = "UTC" + db_session.commit() + + ensure_dify_setup(db_session) + return account, tenant + + +def create_console_app(db_session: Session, tenant_id: str, account_id: str, mode: AppMode) -> App: + """Create a minimal app row that can be loaded by get_app_model.""" + app = App( + tenant_id=tenant_id, + name="Test App", + mode=mode, + enable_site=True, + enable_api=True, + created_by=account_id, + ) + db_session.add(app) + db_session.commit() + return app + + +def authenticate_console_client(test_client: FlaskClient, account: Account) -> dict[str, str]: + """Attach console auth cookies/headers for endpoints guarded by login_required.""" + access_token = AccountService.get_account_jwt_token(account) + csrf_token = generate_csrf_token(account.id) + test_client.set_cookie(_real_cookie_name("csrf_token"), csrf_token, domain="localhost") + return { + "Authorization": f"Bearer {access_token}", + HEADER_NAME_CSRF_TOKEN: csrf_token, + } diff --git a/api/tests/unit_tests/controllers/console/app/test_message.py b/api/tests/unit_tests/controllers/console/app/test_message.py deleted file mode 100644 index e6dfc0d3bd..0000000000 --- a/api/tests/unit_tests/controllers/console/app/test_message.py +++ /dev/null @@ -1,320 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask, request -from werkzeug.exceptions import InternalServerError, NotFound -from werkzeug.local import LocalProxy - -from controllers.console.app.error import ( - ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, - ProviderQuotaExceededError, -) -from controllers.console.app.message import ( - ChatMessageListApi, - ChatMessagesQuery, - FeedbackExportQuery, - MessageAnnotationCountApi, - MessageApi, - MessageFeedbackApi, - MessageFeedbackExportApi, - MessageFeedbackPayload, - MessageSuggestedQuestionApi, -) -from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from models import App, AppMode -from services.errors.conversation import ConversationNotExistsError -from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError - - -@pytest.fixture -def app(): - flask_app = Flask(__name__) - flask_app.config["TESTING"] = True - flask_app.config["RESTX_MASK_HEADER"] = "X-Fields" - return flask_app - - -@pytest.fixture -def mock_account(): - from models.account import Account, AccountStatus - - account = MagicMock(spec=Account) - account.id = "user_123" - account.timezone = "UTC" - account.status = AccountStatus.ACTIVE - account.is_admin_or_owner = True - account.current_tenant.current_role = "owner" - account.has_edit_permission = True - return account - - -@pytest.fixture -def mock_app_model(): - app_model = MagicMock(spec=App) - app_model.id = "app_123" - app_model.mode = AppMode.CHAT - app_model.tenant_id = "tenant_123" - return app_model - - -@pytest.fixture(autouse=True) -def mock_csrf(): - with patch("libs.login.check_csrf_token") as mock: - yield mock - - -import contextlib - - -@contextlib.contextmanager -def setup_test_context( - test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None, qs=None -): - with ( - patch("extensions.ext_database.db") as mock_db, - patch("controllers.console.app.wraps.db", mock_db), - patch("controllers.console.wraps.db", mock_db), - patch("controllers.console.app.message.db", mock_db), - patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch("controllers.console.app.message.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - ): - # Set up a generic query mock that usually returns mock_app_model when getting app - app_query_mock = MagicMock() - app_query_mock.filter.return_value.first.return_value = mock_app_model - app_query_mock.filter.return_value.filter.return_value.first.return_value = mock_app_model - app_query_mock.where.return_value.first.return_value = mock_app_model - app_query_mock.where.return_value.where.return_value.first.return_value = mock_app_model - - data_query_mock = MagicMock() - - def query_side_effect(*args, **kwargs): - if args and hasattr(args[0], "__name__") and args[0].__name__ == "App": - return app_query_mock - return data_query_mock - - mock_db.session.query.side_effect = query_side_effect - mock_db.data_query = data_query_mock - - # Let the caller override the stat db query logic - proxy_mock = LocalProxy(lambda: mock_account) - - query_string = "&".join([f"{k}={v}" for k, v in (qs or {}).items()]) - full_path = f"{route_path}?{query_string}" if qs else route_path - - with ( - patch("libs.login.current_user", proxy_mock), - patch("flask_login.current_user", proxy_mock), - patch("controllers.console.app.message.attach_message_extra_contents", return_value=None), - ): - with test_app.test_request_context(full_path, method=method, json=payload): - request.view_args = {"app_id": "app_123"} - - if "suggested-questions" in route_path: - # simplistic extraction for message_id - parts = route_path.split("chat-messages/") - if len(parts) > 1: - request.view_args["message_id"] = parts[1].split("/")[0] - elif "messages/" in route_path and "chat-messages" not in route_path: - parts = route_path.split("messages/") - if len(parts) > 1: - request.view_args["message_id"] = parts[1].split("/")[0] - - api_instance = endpoint_class() - - # Check if it has a dispatch_request or method - if hasattr(api_instance, method.lower()): - yield api_instance, mock_db, request.view_args - - -class TestMessageValidators: - def test_chat_messages_query_validators(self): - # Test empty_to_none - assert ChatMessagesQuery.empty_to_none("") is None - assert ChatMessagesQuery.empty_to_none("val") == "val" - - # Test validate_uuid - assert ChatMessagesQuery.validate_uuid(None) is None - assert ( - ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000") - == "123e4567-e89b-12d3-a456-426614174000" - ) - - def test_message_feedback_validators(self): - assert ( - MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000") - == "123e4567-e89b-12d3-a456-426614174000" - ) - - def test_feedback_export_validators(self): - assert FeedbackExportQuery.parse_bool(None) is None - assert FeedbackExportQuery.parse_bool(True) is True - assert FeedbackExportQuery.parse_bool("1") is True - assert FeedbackExportQuery.parse_bool("0") is False - assert FeedbackExportQuery.parse_bool("off") is False - - with pytest.raises(ValueError): - FeedbackExportQuery.parse_bool("invalid") - - -class TestMessageEndpoints: - def test_chat_message_list_not_found(self, app, mock_account, mock_app_model): - with setup_test_context( - app, - ChatMessageListApi, - "/apps/app_123/chat-messages", - "GET", - mock_account, - mock_app_model, - qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000"}, - ) as (api, mock_db, v_args): - mock_db.session.scalar.return_value = None - - with pytest.raises(NotFound): - api.get(**v_args) - - def test_chat_message_list_success(self, app, mock_account, mock_app_model): - with setup_test_context( - app, - ChatMessageListApi, - "/apps/app_123/chat-messages", - "GET", - mock_account, - mock_app_model, - qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000", "limit": 1}, - ) as (api, mock_db, v_args): - mock_conv = MagicMock() - mock_conv.id = "123e4567-e89b-12d3-a456-426614174000" - mock_msg = MagicMock() - mock_msg.id = "msg_123" - mock_msg.feedbacks = [] - mock_msg.annotation = None - mock_msg.annotation_hit_history = None - mock_msg.agent_thoughts = [] - mock_msg.message_files = [] - mock_msg.extra_contents = [] - mock_msg.message = {} - mock_msg.message_metadata_dict = {} - - # scalar() is called twice: first for conversation lookup, second for has_more check - mock_db.session.scalar.side_effect = [mock_conv, False] - scalars_result = MagicMock() - scalars_result.all.return_value = [mock_msg] - mock_db.session.scalars.return_value = scalars_result - - resp = api.get(**v_args) - assert resp["limit"] == 1 - assert resp["has_more"] is False - assert len(resp["data"]) == 1 - - def test_message_feedback_not_found(self, app, mock_account, mock_app_model): - with setup_test_context( - app, - MessageFeedbackApi, - "/apps/app_123/feedbacks", - "POST", - mock_account, - mock_app_model, - payload={"message_id": "123e4567-e89b-12d3-a456-426614174000"}, - ) as (api, mock_db, v_args): - mock_db.session.scalar.return_value = None - - with pytest.raises(NotFound): - api.post(**v_args) - - def test_message_feedback_success(self, app, mock_account, mock_app_model): - payload = {"message_id": "123e4567-e89b-12d3-a456-426614174000", "rating": "like"} - with setup_test_context( - app, MessageFeedbackApi, "/apps/app_123/feedbacks", "POST", mock_account, mock_app_model, payload=payload - ) as (api, mock_db, v_args): - mock_msg = MagicMock() - mock_msg.admin_feedback = None - mock_db.session.scalar.return_value = mock_msg - - resp = api.post(**v_args) - assert resp == {"result": "success"} - - def test_message_annotation_count(self, app, mock_account, mock_app_model): - with setup_test_context( - app, MessageAnnotationCountApi, "/apps/app_123/annotations/count", "GET", mock_account, mock_app_model - ) as (api, mock_db, v_args): - mock_db.session.scalar.return_value = 5 - - resp = api.get(**v_args) - assert resp == {"count": 5} - - @patch("controllers.console.app.message.MessageService") - def test_message_suggested_questions_success(self, mock_msg_srv, app, mock_account, mock_app_model): - mock_msg_srv.get_suggested_questions_after_answer.return_value = ["q1", "q2"] - - with setup_test_context( - app, - MessageSuggestedQuestionApi, - "/apps/app_123/chat-messages/msg_123/suggested-questions", - "GET", - mock_account, - mock_app_model, - ) as (api, mock_db, v_args): - resp = api.get(**v_args) - assert resp == {"data": ["q1", "q2"]} - - @pytest.mark.parametrize( - ("exc", "expected_exc"), - [ - (MessageNotExistsError, NotFound), - (ConversationNotExistsError, NotFound), - (ProviderTokenNotInitError, ProviderNotInitializeError), - (QuotaExceededError, ProviderQuotaExceededError), - (ModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError), - (SuggestedQuestionsAfterAnswerDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError), - (Exception, InternalServerError), - ], - ) - @patch("controllers.console.app.message.MessageService") - def test_message_suggested_questions_errors( - self, mock_msg_srv, exc, expected_exc, app, mock_account, mock_app_model - ): - mock_msg_srv.get_suggested_questions_after_answer.side_effect = exc() - - with setup_test_context( - app, - MessageSuggestedQuestionApi, - "/apps/app_123/chat-messages/msg_123/suggested-questions", - "GET", - mock_account, - mock_app_model, - ) as (api, mock_db, v_args): - with pytest.raises(expected_exc): - api.get(**v_args) - - @patch("services.feedback_service.FeedbackService.export_feedbacks") - def test_message_feedback_export_success(self, mock_export, app, mock_account, mock_app_model): - mock_export.return_value = {"exported": True} - - with setup_test_context( - app, MessageFeedbackExportApi, "/apps/app_123/feedbacks/export", "GET", mock_account, mock_app_model - ) as (api, mock_db, v_args): - resp = api.get(**v_args) - assert resp == {"exported": True} - - def test_message_api_get_success(self, app, mock_account, mock_app_model): - with setup_test_context( - app, MessageApi, "/apps/app_123/messages/msg_123", "GET", mock_account, mock_app_model - ) as (api, mock_db, v_args): - mock_msg = MagicMock() - mock_msg.id = "msg_123" - mock_msg.feedbacks = [] - mock_msg.annotation = None - mock_msg.annotation_hit_history = None - mock_msg.agent_thoughts = [] - mock_msg.message_files = [] - mock_msg.extra_contents = [] - mock_msg.message = {} - mock_msg.message_metadata_dict = {} - - mock_db.session.scalar.return_value = mock_msg - - resp = api.get(**v_args) - assert resp["id"] == "msg_123" diff --git a/api/tests/unit_tests/controllers/console/app/test_statistic.py b/api/tests/unit_tests/controllers/console/app/test_statistic.py deleted file mode 100644 index beba23385d..0000000000 --- a/api/tests/unit_tests/controllers/console/app/test_statistic.py +++ /dev/null @@ -1,275 +0,0 @@ -from decimal import Decimal -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask, request -from werkzeug.local import LocalProxy - -from controllers.console.app.statistic import ( - AverageResponseTimeStatistic, - AverageSessionInteractionStatistic, - DailyConversationStatistic, - DailyMessageStatistic, - DailyTerminalsStatistic, - DailyTokenCostStatistic, - TokensPerSecondStatistic, - UserSatisfactionRateStatistic, -) -from models import App, AppMode - - -@pytest.fixture -def app(): - flask_app = Flask(__name__) - flask_app.config["TESTING"] = True - return flask_app - - -@pytest.fixture -def mock_account(): - from models.account import Account, AccountStatus - - account = MagicMock(spec=Account) - account.id = "user_123" - account.timezone = "UTC" - account.status = AccountStatus.ACTIVE - account.is_admin_or_owner = True - account.current_tenant.current_role = "owner" - account.has_edit_permission = True - return account - - -@pytest.fixture -def mock_app_model(): - app_model = MagicMock(spec=App) - app_model.id = "app_123" - app_model.mode = AppMode.CHAT - app_model.tenant_id = "tenant_123" - return app_model - - -@pytest.fixture(autouse=True) -def mock_csrf(): - with patch("libs.login.check_csrf_token") as mock: - yield mock - - -def setup_test_context( - test_app, endpoint_class, route_path, mock_account, mock_app_model, mock_rs, mock_parse_ret=(None, None) -): - with ( - patch("controllers.console.app.statistic.db") as mock_db_stat, - patch("controllers.console.app.wraps.db") as mock_db_wraps, - patch("controllers.console.wraps.db", mock_db_wraps), - patch( - "controllers.console.app.statistic.current_account_with_tenant", return_value=(mock_account, "tenant_123") - ), - patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - ): - mock_conn = MagicMock() - mock_conn.execute.return_value = mock_rs - - mock_begin = MagicMock() - mock_begin.__enter__.return_value = mock_conn - mock_db_stat.engine.begin.return_value = mock_begin - - mock_query = MagicMock() - mock_query.filter.return_value.first.return_value = mock_app_model - mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model - mock_query.where.return_value.first.return_value = mock_app_model - mock_query.where.return_value.where.return_value.first.return_value = mock_app_model - mock_db_wraps.session.query.return_value = mock_query - - proxy_mock = LocalProxy(lambda: mock_account) - - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): - with test_app.test_request_context(route_path, method="GET"): - request.view_args = {"app_id": "app_123"} - api_instance = endpoint_class() - response = api_instance.get(app_id="app_123") - return response - - -class TestStatisticEndpoints: - def test_daily_message_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.message_count = 10 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - DailyMessageStatistic, - "/apps/app_123/statistics/daily-messages?start=2023-01-01 00:00&end=2023-01-02 00:00", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["message_count"] == 10 - - def test_daily_conversation_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.conversation_count = 5 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - DailyConversationStatistic, - "/apps/app_123/statistics/daily-conversations", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["conversation_count"] == 5 - - def test_daily_terminals_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.terminal_count = 2 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - DailyTerminalsStatistic, - "/apps/app_123/statistics/daily-end-users", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["terminal_count"] == 2 - - def test_daily_token_cost_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.token_count = 100 - mock_row.total_price = Decimal("0.02") - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - DailyTokenCostStatistic, - "/apps/app_123/statistics/token-costs", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["token_count"] == 100 - assert response.json["data"][0]["total_price"] == "0.02" - - def test_average_session_interaction_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.interactions = Decimal("3.523") - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - AverageSessionInteractionStatistic, - "/apps/app_123/statistics/average-session-interactions", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["interactions"] == 3.52 - - def test_user_satisfaction_rate_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.message_count = 100 - mock_row.feedback_count = 10 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - UserSatisfactionRateStatistic, - "/apps/app_123/statistics/user-satisfaction-rate", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["rate"] == 100.0 - - def test_average_response_time_statistic(self, app, mock_account, mock_app_model): - mock_app_model.mode = AppMode.COMPLETION - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.latency = 1.234 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - AverageResponseTimeStatistic, - "/apps/app_123/statistics/average-response-time", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["latency"] == 1234.0 - - def test_tokens_per_second_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.tokens_per_second = 15.5 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - TokensPerSecondStatistic, - "/apps/app_123/statistics/tokens-per-second", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["tps"] == 15.5 - - @patch("controllers.console.app.statistic.parse_time_range") - def test_invalid_time_range(self, mock_parse, app, mock_account, mock_app_model): - mock_parse.side_effect = ValueError("Invalid time") - - from werkzeug.exceptions import BadRequest - - with pytest.raises(BadRequest): - setup_test_context( - app, - DailyMessageStatistic, - "/apps/app_123/statistics/daily-messages?start=invalid&end=invalid", - mock_account, - mock_app_model, - [], - ) - - @patch("controllers.console.app.statistic.parse_time_range") - def test_time_range_params_passed(self, mock_parse, app, mock_account, mock_app_model): - import datetime - - start = datetime.datetime.now() - end = datetime.datetime.now() - mock_parse.return_value = (start, end) - - response = setup_test_context( - app, - DailyMessageStatistic, - "/apps/app_123/statistics/daily-messages?start=something&end=something", - mock_account, - mock_app_model, - [], - ) - assert response.status_code == 200 - mock_parse.assert_called_once() diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py deleted file mode 100644 index 9b5d47c208..0000000000 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py +++ /dev/null @@ -1,313 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask, request -from werkzeug.local import LocalProxy - -from controllers.console.app.error import DraftWorkflowNotExist -from controllers.console.app.workflow_draft_variable import ( - ConversationVariableCollectionApi, - EnvironmentVariableCollectionApi, - NodeVariableCollectionApi, - SystemVariableCollectionApi, - VariableApi, - VariableResetApi, - WorkflowVariableCollectionApi, -) -from controllers.web.error import InvalidArgumentError, NotFoundError -from models import App, AppMode -from models.enums import DraftVariableType - - -@pytest.fixture -def app(): - flask_app = Flask(__name__) - flask_app.config["TESTING"] = True - flask_app.config["RESTX_MASK_HEADER"] = "X-Fields" - return flask_app - - -@pytest.fixture -def mock_account(): - from models.account import Account, AccountStatus - - account = MagicMock(spec=Account) - account.id = "user_123" - account.timezone = "UTC" - account.status = AccountStatus.ACTIVE - account.is_admin_or_owner = True - account.current_tenant.current_role = "owner" - account.has_edit_permission = True - return account - - -@pytest.fixture -def mock_app_model(): - app_model = MagicMock(spec=App) - app_model.id = "app_123" - app_model.mode = AppMode.WORKFLOW - app_model.tenant_id = "tenant_123" - return app_model - - -@pytest.fixture(autouse=True) -def mock_csrf(): - with patch("libs.login.check_csrf_token") as mock: - yield mock - - -def setup_test_context(test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None): - with ( - patch("controllers.console.app.wraps.db") as mock_db_wraps, - patch("controllers.console.wraps.db", mock_db_wraps), - patch("controllers.console.app.workflow_draft_variable.db"), - patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - ): - mock_query = MagicMock() - mock_query.filter.return_value.first.return_value = mock_app_model - mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model - mock_query.where.return_value.first.return_value = mock_app_model - mock_query.where.return_value.where.return_value.first.return_value = mock_app_model - mock_db_wraps.session.query.return_value = mock_query - - proxy_mock = LocalProxy(lambda: mock_account) - - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): - with test_app.test_request_context(route_path, method=method, json=payload): - request.view_args = {"app_id": "app_123"} - # extract node_id or variable_id from path manually since view_args overrides - if "nodes/" in route_path: - request.view_args["node_id"] = route_path.split("nodes/")[1].split("/")[0] - if "variables/" in route_path: - # simplistic extraction - parts = route_path.split("variables/") - if len(parts) > 1 and parts[1] and parts[1] != "reset": - request.view_args["variable_id"] = parts[1].split("/")[0] - - api_instance = endpoint_class() - # we just call dispatch_request to avoid manual argument passing - if hasattr(api_instance, method.lower()): - func = getattr(api_instance, method.lower()) - return func(**request.view_args) - - -class TestWorkflowDraftVariableEndpoints: - @staticmethod - def _mock_workflow_variable(variable_type: DraftVariableType = DraftVariableType.NODE) -> MagicMock: - class DummyValueType: - def exposed_type(self): - return DraftVariableType.NODE - - mock_var = MagicMock() - mock_var.app_id = "app_123" - mock_var.id = "var_123" - mock_var.name = "test_var" - mock_var.description = "" - mock_var.get_variable_type.return_value = variable_type - mock_var.get_selector.return_value = [] - mock_var.value_type = DummyValueType() - mock_var.edited = False - mock_var.visible = True - mock_var.file_id = None - mock_var.variable_file = None - mock_var.is_truncated.return_value = False - mock_var.get_value.return_value.model_copy.return_value.value = "test_value" - return mock_var - - @patch("controllers.console.app.workflow_draft_variable.WorkflowService") - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_workflow_variable_collection_get_success( - self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model - ): - mock_wf_srv.return_value.is_workflow_exist.return_value = True - from services.workflow_draft_variable_service import WorkflowDraftVariableList - - mock_draft_srv.return_value.list_variables_without_values.return_value = WorkflowDraftVariableList( - variables=[], total=0 - ) - - resp = setup_test_context( - app, - WorkflowVariableCollectionApi, - "/apps/app_123/workflows/draft/variables?page=1&limit=20", - "GET", - mock_account, - mock_app_model, - ) - assert resp == {"items": [], "total": 0} - - @patch("controllers.console.app.workflow_draft_variable.WorkflowService") - def test_workflow_variable_collection_get_not_exist(self, mock_wf_srv, app, mock_account, mock_app_model): - mock_wf_srv.return_value.is_workflow_exist.return_value = False - - with pytest.raises(DraftWorkflowNotExist): - setup_test_context( - app, - WorkflowVariableCollectionApi, - "/apps/app_123/workflows/draft/variables", - "GET", - mock_account, - mock_app_model, - ) - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_workflow_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model): - resp = setup_test_context( - app, - WorkflowVariableCollectionApi, - "/apps/app_123/workflows/draft/variables", - "DELETE", - mock_account, - mock_app_model, - ) - assert resp.status_code == 204 - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_node_variable_collection_get_success(self, mock_draft_srv, app, mock_account, mock_app_model): - from services.workflow_draft_variable_service import WorkflowDraftVariableList - - mock_draft_srv.return_value.list_node_variables.return_value = WorkflowDraftVariableList(variables=[]) - resp = setup_test_context( - app, - NodeVariableCollectionApi, - "/apps/app_123/workflows/draft/nodes/node_123/variables", - "GET", - mock_account, - mock_app_model, - ) - assert resp == {"items": []} - - def test_node_variable_collection_get_invalid_node_id(self, app, mock_account, mock_app_model): - with pytest.raises(InvalidArgumentError): - setup_test_context( - app, - NodeVariableCollectionApi, - "/apps/app_123/workflows/draft/nodes/sys/variables", - "GET", - mock_account, - mock_app_model, - ) - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_node_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model): - resp = setup_test_context( - app, - NodeVariableCollectionApi, - "/apps/app_123/workflows/draft/nodes/node_123/variables", - "DELETE", - mock_account, - mock_app_model, - ) - assert resp.status_code == 204 - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_variable_api_get_success(self, mock_draft_srv, app, mock_account, mock_app_model): - mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable() - - resp = setup_test_context( - app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model - ) - assert resp["id"] == "var_123" - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_variable_api_get_not_found(self, mock_draft_srv, app, mock_account, mock_app_model): - mock_draft_srv.return_value.get_variable.return_value = None - - with pytest.raises(NotFoundError): - setup_test_context( - app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model - ) - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_variable_api_patch_success(self, mock_draft_srv, app, mock_account, mock_app_model): - mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable() - - resp = setup_test_context( - app, - VariableApi, - "/apps/app_123/workflows/draft/variables/var_123", - "PATCH", - mock_account, - mock_app_model, - payload={"name": "new_name"}, - ) - assert resp["id"] == "var_123" - mock_draft_srv.return_value.update_variable.assert_called_once() - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_variable_api_delete_success(self, mock_draft_srv, app, mock_account, mock_app_model): - mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable() - - resp = setup_test_context( - app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "DELETE", mock_account, mock_app_model - ) - assert resp.status_code == 204 - mock_draft_srv.return_value.delete_variable.assert_called_once() - - @patch("controllers.console.app.workflow_draft_variable.WorkflowService") - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_variable_reset_api_put_success(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model): - mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock() - mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable() - mock_draft_srv.return_value.reset_variable.return_value = None # means no content - - resp = setup_test_context( - app, - VariableResetApi, - "/apps/app_123/workflows/draft/variables/var_123/reset", - "PUT", - mock_account, - mock_app_model, - ) - assert resp.status_code == 204 - - @patch("controllers.console.app.workflow_draft_variable.WorkflowService") - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_conversation_variable_collection_get(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model): - mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock() - from services.workflow_draft_variable_service import WorkflowDraftVariableList - - mock_draft_srv.return_value.list_conversation_variables.return_value = WorkflowDraftVariableList(variables=[]) - - resp = setup_test_context( - app, - ConversationVariableCollectionApi, - "/apps/app_123/workflows/draft/conversation-variables", - "GET", - mock_account, - mock_app_model, - ) - assert resp == {"items": []} - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_system_variable_collection_get(self, mock_draft_srv, app, mock_account, mock_app_model): - from services.workflow_draft_variable_service import WorkflowDraftVariableList - - mock_draft_srv.return_value.list_system_variables.return_value = WorkflowDraftVariableList(variables=[]) - - resp = setup_test_context( - app, - SystemVariableCollectionApi, - "/apps/app_123/workflows/draft/system-variables", - "GET", - mock_account, - mock_app_model, - ) - assert resp == {"items": []} - - @patch("controllers.console.app.workflow_draft_variable.WorkflowService") - def test_environment_variable_collection_get(self, mock_wf_srv, app, mock_account, mock_app_model): - mock_wf = MagicMock() - mock_wf.environment_variables = [] - mock_wf_srv.return_value.get_draft_workflow.return_value = mock_wf - - resp = setup_test_context( - app, - EnvironmentVariableCollectionApi, - "/apps/app_123/workflows/draft/environment-variables", - "GET", - mock_account, - mock_app_model, - ) - assert resp == {"items": []} diff --git a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py deleted file mode 100644 index bc4c7e0993..0000000000 --- a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py +++ /dev/null @@ -1,209 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask - -from controllers.console.auth.data_source_bearer_auth import ( - ApiKeyAuthDataSource, - ApiKeyAuthDataSourceBinding, - ApiKeyAuthDataSourceBindingDelete, -) -from controllers.console.auth.error import ApiKeyAuthFailedError - - -class TestApiKeyAuthDataSource: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - app.config["WTF_CSRF_ENABLED"] = False - return app - - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list") - def test_get_api_key_auth_data_source(self, mock_get_list, mock_db, mock_csrf, app): - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - mock_binding = MagicMock() - mock_binding.id = "bind_123" - mock_binding.category = "api_key" - mock_binding.provider = "custom_provider" - mock_binding.disabled = False - mock_binding.created_at.timestamp.return_value = 1620000000 - mock_binding.updated_at.timestamp.return_value = 1620000001 - - mock_get_list.return_value = [mock_binding] - - with ( - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch( - "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", - return_value=(mock_account, "tenant_123"), - ), - ): - with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock): - api_instance = ApiKeyAuthDataSource() - response = api_instance.get() - - assert "sources" in response - assert len(response["sources"]) == 1 - assert response["sources"][0]["provider"] == "custom_provider" - - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list") - def test_get_api_key_auth_data_source_empty(self, mock_get_list, mock_db, mock_csrf, app): - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - mock_get_list.return_value = None - - with ( - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch( - "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", - return_value=(mock_account, "tenant_123"), - ), - ): - with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock): - api_instance = ApiKeyAuthDataSource() - response = api_instance.get() - - assert "sources" in response - assert len(response["sources"]) == 0 - - -class TestApiKeyAuthDataSourceBinding: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - app.config["WTF_CSRF_ENABLED"] = False - return app - - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args") - def test_create_binding_successful(self, mock_validate, mock_create, mock_db, mock_csrf, app): - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - with ( - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch( - "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", - return_value=(mock_account, "tenant_123"), - ), - ): - with app.test_request_context( - "/console/api/api-key-auth/data-source/binding", - method="POST", - json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, - ): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): - api_instance = ApiKeyAuthDataSourceBinding() - response = api_instance.post() - - assert response[0]["result"] == "success" - assert response[1] == 200 - mock_validate.assert_called_once() - mock_create.assert_called_once() - - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args") - def test_create_binding_failure(self, mock_validate, mock_create, mock_db, mock_csrf, app): - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - mock_create.side_effect = ValueError("Invalid structure") - - with ( - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch( - "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", - return_value=(mock_account, "tenant_123"), - ), - ): - with app.test_request_context( - "/console/api/api-key-auth/data-source/binding", - method="POST", - json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, - ): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): - api_instance = ApiKeyAuthDataSourceBinding() - with pytest.raises(ApiKeyAuthFailedError, match="Invalid structure"): - api_instance.post() - - -class TestApiKeyAuthDataSourceBindingDelete: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - app.config["WTF_CSRF_ENABLED"] = False - return app - - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth") - def test_delete_binding_successful(self, mock_delete, mock_db, mock_csrf, app): - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - with ( - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch( - "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", - return_value=(mock_account, "tenant_123"), - ), - ): - with app.test_request_context("/console/api/api-key-auth/data-source/binding_123", method="DELETE"): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): - api_instance = ApiKeyAuthDataSourceBindingDelete() - response = api_instance.delete("binding_123") - - assert response[0]["result"] == "success" - assert response[1] == 204 - mock_delete.assert_called_once_with("tenant_123", "binding_123") diff --git a/api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py deleted file mode 100644 index f369565946..0000000000 --- a/api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py +++ /dev/null @@ -1,192 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask -from werkzeug.local import LocalProxy - -from controllers.console.auth.data_source_oauth import ( - OAuthDataSource, - OAuthDataSourceBinding, - OAuthDataSourceCallback, - OAuthDataSourceSync, -) - - -class TestOAuthDataSource: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - @patch("flask_login.current_user") - @patch("libs.login.current_user") - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None) - def test_get_oauth_url_successful( - self, mock_db, mock_csrf, mock_libs_user, mock_flask_user, mock_get_providers, app - ): - mock_oauth_provider = MagicMock() - mock_oauth_provider.get_authorization_url.return_value = "http://oauth.provider/auth" - mock_get_providers.return_value = {"notion": mock_oauth_provider} - - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - mock_libs_user.return_value = mock_account - mock_flask_user.return_value = mock_account - - # also patch current_account_with_tenant - with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): - with app.test_request_context("/console/api/oauth/data-source/notion", method="GET"): - proxy_mock = LocalProxy(lambda: mock_account) - with patch("libs.login.current_user", proxy_mock): - api_instance = OAuthDataSource() - response = api_instance.get("notion") - - assert response[0]["data"] == "http://oauth.provider/auth" - assert response[1] == 200 - mock_oauth_provider.get_authorization_url.assert_called_once() - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - @patch("flask_login.current_user") - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - def test_get_oauth_url_invalid_provider(self, mock_db, mock_csrf, mock_flask_user, mock_get_providers, app): - mock_get_providers.return_value = {"notion": MagicMock()} - - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): - with app.test_request_context("/console/api/oauth/data-source/unknown_provider", method="GET"): - proxy_mock = LocalProxy(lambda: mock_account) - with patch("libs.login.current_user", proxy_mock): - api_instance = OAuthDataSource() - response = api_instance.get("unknown_provider") - - assert response[0]["error"] == "Invalid provider" - assert response[1] == 400 - - -class TestOAuthDataSourceCallback: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - def test_oauth_callback_successful(self, mock_get_providers, app): - provider_mock = MagicMock() - mock_get_providers.return_value = {"notion": provider_mock} - - with app.test_request_context("/console/api/oauth/data-source/notion/callback?code=mock_code", method="GET"): - api_instance = OAuthDataSourceCallback() - response = api_instance.get("notion") - - assert response.status_code == 302 - assert "code=mock_code" in response.location - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - def test_oauth_callback_missing_code(self, mock_get_providers, app): - provider_mock = MagicMock() - mock_get_providers.return_value = {"notion": provider_mock} - - with app.test_request_context("/console/api/oauth/data-source/notion/callback", method="GET"): - api_instance = OAuthDataSourceCallback() - response = api_instance.get("notion") - - assert response.status_code == 302 - assert "error=Access denied" in response.location - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - def test_oauth_callback_invalid_provider(self, mock_get_providers, app): - mock_get_providers.return_value = {"notion": MagicMock()} - - with app.test_request_context("/console/api/oauth/data-source/invalid/callback?code=mock_code", method="GET"): - api_instance = OAuthDataSourceCallback() - response = api_instance.get("invalid") - - assert response[0]["error"] == "Invalid provider" - assert response[1] == 400 - - -class TestOAuthDataSourceBinding: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - def test_get_binding_successful(self, mock_get_providers, app): - mock_provider = MagicMock() - mock_provider.get_access_token.return_value = None - mock_get_providers.return_value = {"notion": mock_provider} - - with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=auth_code_123", method="GET"): - api_instance = OAuthDataSourceBinding() - response = api_instance.get("notion") - - assert response[0]["result"] == "success" - assert response[1] == 200 - mock_provider.get_access_token.assert_called_once_with("auth_code_123") - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - def test_get_binding_missing_code(self, mock_get_providers, app): - mock_get_providers.return_value = {"notion": MagicMock()} - - with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=", method="GET"): - api_instance = OAuthDataSourceBinding() - response = api_instance.get("notion") - - assert response[0]["error"] == "Invalid code" - assert response[1] == 400 - - -class TestOAuthDataSourceSync: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - def test_sync_successful(self, mock_db, mock_csrf, mock_get_providers, app): - mock_provider = MagicMock() - mock_provider.sync_data_source.return_value = None - mock_get_providers.return_value = {"notion": mock_provider} - - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): - with app.test_request_context("/console/api/oauth/data-source/notion/binding_123/sync", method="GET"): - proxy_mock = LocalProxy(lambda: mock_account) - with patch("libs.login.current_user", proxy_mock): - api_instance = OAuthDataSourceSync() - # The route pattern uses , so we just pass a string for unit testing - response = api_instance.get("notion", "binding_123") - - assert response[0]["result"] == "success" - assert response[1] == 200 - mock_provider.sync_data_source.assert_called_once_with("binding_123") diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py b/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py deleted file mode 100644 index fc5663e72d..0000000000 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py +++ /dev/null @@ -1,417 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask -from werkzeug.exceptions import BadRequest, NotFound - -from controllers.console.auth.oauth_server import ( - OAuthServerAppApi, - OAuthServerUserAccountApi, - OAuthServerUserAuthorizeApi, - OAuthServerUserTokenApi, -) - - -class TestOAuthServerAppApi: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @pytest.fixture - def mock_oauth_provider_app(self): - from models.model import OAuthProviderApp - - oauth_app = MagicMock(spec=OAuthProviderApp) - oauth_app.client_id = "test_client_id" - oauth_app.redirect_uris = ["http://localhost/callback"] - oauth_app.app_icon = "icon_url" - oauth_app.app_label = "Test App" - oauth_app.scope = "read,write" - return oauth_app - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_successful_post(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider", - method="POST", - json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"}, - ): - api_instance = OAuthServerAppApi() - response = api_instance.post() - - assert response["app_icon"] == "icon_url" - assert response["app_label"] == "Test App" - assert response["scope"] == "read,write" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider", - method="POST", - json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"}, - ): - api_instance = OAuthServerAppApi() - with pytest.raises(BadRequest, match="redirect_uri is invalid"): - api_instance.post() - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_invalid_client_id(self, mock_get_app, mock_db, app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = None - - with app.test_request_context( - "/oauth/provider", - method="POST", - json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"}, - ): - api_instance = OAuthServerAppApi() - with pytest.raises(NotFound, match="client_id is invalid"): - api_instance.post() - - -class TestOAuthServerUserAuthorizeApi: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @pytest.fixture - def mock_oauth_provider_app(self): - oauth_app = MagicMock() - oauth_app.client_id = "test_client_id" - return oauth_app - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - @patch("controllers.console.auth.oauth_server.current_account_with_tenant") - @patch("controllers.console.wraps.current_account_with_tenant") - @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code") - @patch("libs.login.check_csrf_token") - def test_successful_authorize( - self, mock_csrf, mock_sign, mock_wrap_current, mock_current, mock_get_app, mock_db, app, mock_oauth_provider_app - ): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - mock_account = MagicMock() - mock_account.id = "user_123" - from models.account import AccountStatus - - mock_account.status = AccountStatus.ACTIVE - - mock_current.return_value = (mock_account, MagicMock()) - mock_wrap_current.return_value = (mock_account, MagicMock()) - - mock_sign.return_value = "auth_code_123" - - with app.test_request_context("/oauth/provider/authorize", method="POST", json={"client_id": "test_client_id"}): - with patch("libs.login.current_user", mock_account): - api_instance = OAuthServerUserAuthorizeApi() - response = api_instance.post() - - assert response["code"] == "auth_code_123" - mock_sign.assert_called_once_with("test_client_id", "user_123") - - -class TestOAuthServerUserTokenApi: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @pytest.fixture - def mock_oauth_provider_app(self): - from models.model import OAuthProviderApp - - oauth_app = MagicMock(spec=OAuthProviderApp) - oauth_app.client_id = "test_client_id" - oauth_app.client_secret = "test_secret" - oauth_app.redirect_uris = ["http://localhost/callback"] - return oauth_app - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token") - def test_authorization_code_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - mock_sign.return_value = ("access_123", "refresh_123") - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "authorization_code", - "code": "auth_code", - "client_secret": "test_secret", - "redirect_uri": "http://localhost/callback", - }, - ): - api_instance = OAuthServerUserTokenApi() - response = api_instance.post() - - assert response["access_token"] == "access_123" - assert response["refresh_token"] == "refresh_123" - assert response["token_type"] == "Bearer" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_authorization_code_grant_missing_code(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "authorization_code", - "client_secret": "test_secret", - "redirect_uri": "http://localhost/callback", - }, - ): - api_instance = OAuthServerUserTokenApi() - with pytest.raises(BadRequest, match="code is required"): - api_instance.post() - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_authorization_code_grant_invalid_secret(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "authorization_code", - "code": "auth_code", - "client_secret": "invalid_secret", - "redirect_uri": "http://localhost/callback", - }, - ): - api_instance = OAuthServerUserTokenApi() - with pytest.raises(BadRequest, match="client_secret is invalid"): - api_instance.post() - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_authorization_code_grant_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "authorization_code", - "code": "auth_code", - "client_secret": "test_secret", - "redirect_uri": "http://invalid/callback", - }, - ): - api_instance = OAuthServerUserTokenApi() - with pytest.raises(BadRequest, match="redirect_uri is invalid"): - api_instance.post() - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token") - def test_refresh_token_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - mock_sign.return_value = ("new_access", "new_refresh") - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"}, - ): - api_instance = OAuthServerUserTokenApi() - response = api_instance.post() - - assert response["access_token"] == "new_access" - assert response["refresh_token"] == "new_refresh" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_refresh_token_grant_missing_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "refresh_token", - }, - ): - api_instance = OAuthServerUserTokenApi() - with pytest.raises(BadRequest, match="refresh_token is required"): - api_instance.post() - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_invalid_grant_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "invalid_grant", - }, - ): - api_instance = OAuthServerUserTokenApi() - with pytest.raises(BadRequest, match="invalid grant_type"): - api_instance.post() - - -class TestOAuthServerUserAccountApi: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @pytest.fixture - def mock_oauth_provider_app(self): - from models.model import OAuthProviderApp - - oauth_app = MagicMock(spec=OAuthProviderApp) - oauth_app.client_id = "test_client_id" - return oauth_app - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - @patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token") - def test_successful_account_retrieval(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - mock_account = MagicMock() - mock_account.name = "Test User" - mock_account.email = "test@example.com" - mock_account.avatar = "avatar_url" - mock_account.interface_language = "en-US" - mock_account.timezone = "UTC" - mock_validate.return_value = mock_account - - with app.test_request_context( - "/oauth/provider/account", - method="POST", - json={"client_id": "test_client_id"}, - headers={"Authorization": "Bearer valid_access_token"}, - ): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response["name"] == "Test User" - assert response["email"] == "test@example.com" - assert response["avatar"] == "avatar_url" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_missing_authorization_header(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context("/oauth/provider/account", method="POST", json={"client_id": "test_client_id"}): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response.status_code == 401 - assert response.json["error"] == "Authorization header is required" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_invalid_authorization_header_format(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/account", - method="POST", - json={"client_id": "test_client_id"}, - headers={"Authorization": "InvalidFormat"}, - ): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response.status_code == 401 - assert response.json["error"] == "Invalid Authorization header format" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_invalid_token_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/account", - method="POST", - json={"client_id": "test_client_id"}, - headers={"Authorization": "Basic something"}, - ): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response.status_code == 401 - assert response.json["error"] == "token_type is invalid" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_missing_access_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/account", - method="POST", - json={"client_id": "test_client_id"}, - headers={"Authorization": "Bearer "}, - ): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response.status_code == 401 - assert response.json["error"] == "Invalid Authorization header format" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - @patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token") - def test_invalid_access_token(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - mock_validate.return_value = None - - with app.test_request_context( - "/oauth/provider/account", - method="POST", - json={"client_id": "test_client_id"}, - headers={"Authorization": "Bearer invalid_token"}, - ): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response.status_code == 401 - assert response.json["error"] == "access_token or client_id is invalid" From 56e0907548b024a2e29bd8bfd29ab72c53d85333 Mon Sep 17 00:00:00 2001 From: letterbeezps Date: Mon, 23 Mar 2026 20:42:57 +0800 Subject: [PATCH 03/34] fix: do not block upsert for baidu vdb (#33280) Co-authored-by: zhangping24 Co-authored-by: Crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/.env.example | 3 ++ .../middleware/vdb/baidu_vector_config.py | 15 ++++++ .../rag/datasource/vdb/baidu/baidu_vector.py | 50 +++++++++++++------ docker/.env.example | 3 ++ docker/docker-compose.yaml | 3 ++ 5 files changed, 59 insertions(+), 15 deletions(-) diff --git a/api/.env.example b/api/.env.example index 40e1c2dfdf..9672a99d55 100644 --- a/api/.env.example +++ b/api/.env.example @@ -353,6 +353,9 @@ BAIDU_VECTOR_DB_SHARD=1 BAIDU_VECTOR_DB_REPLICAS=3 BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE +BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT=500 +BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO=0.05 +BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS=300 # Upstash configuration UPSTASH_VECTOR_URL=your-server-url diff --git a/api/configs/middleware/vdb/baidu_vector_config.py b/api/configs/middleware/vdb/baidu_vector_config.py index 8f956745b1..c8e4f7309f 100644 --- a/api/configs/middleware/vdb/baidu_vector_config.py +++ b/api/configs/middleware/vdb/baidu_vector_config.py @@ -51,3 +51,18 @@ class BaiduVectorDBConfig(BaseSettings): description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)", default="COARSE_MODE", ) + + BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT: int = Field( + description="Auto build row count increment threshold (default is 500)", + default=500, + ) + + BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO: float = Field( + description="Auto build row count increment ratio threshold (default is 0.05)", + default=0.05, + ) + + BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS: int = Field( + description="Timeout in seconds for rebuilding the index in Baidu Vector Database (default is 3600 seconds)", + default=300, + ) diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index 144d834495..9f5842e449 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -13,6 +13,7 @@ from pymochow.exception import ServerError # type: ignore from pymochow.model.database import Database from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore from pymochow.model.schema import ( + AutoBuildRowCountIncrement, Field, FilteringIndex, HNSWParams, @@ -51,6 +52,9 @@ class BaiduConfig(BaseModel): replicas: int = 3 inverted_index_analyzer: str = "DEFAULT_ANALYZER" inverted_index_parser_mode: str = "COARSE_MODE" + auto_build_row_count_increment: int = 500 + auto_build_row_count_increment_ratio: float = 0.05 + rebuild_index_timeout_in_seconds: int = 300 @model_validator(mode="before") @classmethod @@ -107,18 +111,6 @@ class BaiduVector(BaseVector): rows.append(row) table.upsert(rows=rows) - # rebuild vector index after upsert finished - table.rebuild_index(self.vector_index) - timeout = 3600 # 1 hour timeout - start_time = time.time() - while True: - time.sleep(1) - index = table.describe_index(self.vector_index) - if index.state == IndexState.NORMAL: - break - if time.time() - start_time > timeout: - raise TimeoutError(f"Index rebuild timeout after {timeout} seconds") - def text_exists(self, id: str) -> bool: res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id}) if res and res.code == 0: @@ -232,8 +224,14 @@ class BaiduVector(BaseVector): return self._client.database(self._client_config.database) def _table_existed(self) -> bool: - tables = self._db.list_table() - return any(table.table_name == self._collection_name for table in tables) + try: + table = self._db.table(self._collection_name) + except ServerError as e: + if e.code == ServerErrCode.TABLE_NOT_EXIST: + return False + else: + raise + return True def _create_table(self, dimension: int): # Try to grab distributed lock and create table @@ -287,6 +285,11 @@ class BaiduVector(BaseVector): field=VDBField.VECTOR, metric_type=metric_type, params=HNSWParams(m=16, efconstruction=200), + auto_build=True, + auto_build_index_policy=AutoBuildRowCountIncrement( + row_count_increment=self._client_config.auto_build_row_count_increment, + row_count_increment_ratio=self._client_config.auto_build_row_count_increment_ratio, + ), ) ) @@ -335,7 +338,7 @@ class BaiduVector(BaseVector): ) # Wait for table created - timeout = 300 # 5 minutes timeout + timeout = self._client_config.rebuild_index_timeout_in_seconds # default 5 minutes timeout start_time = time.time() while True: time.sleep(1) @@ -345,6 +348,20 @@ class BaiduVector(BaseVector): if time.time() - start_time > timeout: raise TimeoutError(f"Table creation timeout after {timeout} seconds") redis_client.set(table_exist_cache_key, 1, ex=3600) + # rebuild vector index immediately after table created, make sure index is ready + table.rebuild_index(self.vector_index) + timeout = 3600 # 1 hour timeout + self._wait_for_index_ready(table, timeout) + + def _wait_for_index_ready(self, table, timeout: int = 3600): + start_time = time.time() + while True: + time.sleep(1) + index = table.describe_index(self.vector_index) + if index.state == IndexState.NORMAL: + break + if time.time() - start_time > timeout: + raise TimeoutError(f"Index rebuild timeout after {timeout} seconds") class BaiduVectorFactory(AbstractVectorFactory): @@ -369,5 +386,8 @@ class BaiduVectorFactory(AbstractVectorFactory): replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS, inverted_index_analyzer=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER, inverted_index_parser_mode=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE, + auto_build_row_count_increment=dify_config.BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT, + auto_build_row_count_increment_ratio=dify_config.BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO, + rebuild_index_timeout_in_seconds=dify_config.BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS, ), ) diff --git a/docker/.env.example b/docker/.env.example index 9d6cd65318..8cf77cf56b 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -771,6 +771,9 @@ BAIDU_VECTOR_DB_SHARD=1 BAIDU_VECTOR_DB_REPLICAS=3 BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE +BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT=500 +BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO=0.05 +BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS=300 # VikingDB configurations, only available when VECTOR_STORE is `vikingdb` VIKINGDB_ACCESS_KEY=your-ak diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index bf72a0f623..6e11cac678 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -345,6 +345,9 @@ x-shared-env: &shared-api-worker-env BAIDU_VECTOR_DB_REPLICAS: ${BAIDU_VECTOR_DB_REPLICAS:-3} BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER: ${BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER:-DEFAULT_ANALYZER} BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE: ${BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE:-COARSE_MODE} + BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT: ${BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT:-500} + BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO: ${BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO:-0.05} + BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS: ${BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS:-300} VIKINGDB_ACCESS_KEY: ${VIKINGDB_ACCESS_KEY:-your-ak} VIKINGDB_SECRET_KEY: ${VIKINGDB_SECRET_KEY:-your-sk} VIKINGDB_REGION: ${VIKINGDB_REGION:-cn-shanghai} From 4b4a5c058e7d3706a9c4a6af95824e553f88bd5c Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 23 Mar 2026 07:52:31 -0500 Subject: [PATCH 04/34] test: migrate file service zip and lookup tests to testcontainers (#33944) --- .../test_file_service_zip_and_lookup.py | 96 ++++++++++++++++++ .../test_file_service_zip_and_lookup.py | 99 ------------------- 2 files changed, 96 insertions(+), 99 deletions(-) create mode 100644 api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py delete mode 100644 api/tests/unit_tests/services/test_file_service_zip_and_lookup.py diff --git a/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py b/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py new file mode 100644 index 0000000000..4e0a726cc7 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py @@ -0,0 +1,96 @@ +""" +Testcontainers integration tests for FileService helpers. + +Covers: +- ZIP tempfile building (sanitization + deduplication + content writes) +- tenant-scoped batch lookup behavior (get_upload_files_by_ids) +""" + +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any +from uuid import uuid4 +from zipfile import ZipFile + +import pytest + +import services.file_service as file_service_module +from extensions.storage.storage_type import StorageType +from models.enums import CreatorUserRole +from models.model import UploadFile +from services.file_service import FileService + + +def _create_upload_file(db_session, *, tenant_id: str, key: str, name: str) -> UploadFile: + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type=StorageType.OPENDAL, + key=key, + name=name, + size=100, + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid4()), + created_at=datetime.now(UTC), + used=False, + ) + db_session.add(upload_file) + db_session.commit() + return upload_file + + +def test_build_upload_files_zip_tempfile_sanitizes_and_dedupes_names(monkeypatch: pytest.MonkeyPatch) -> None: + """Ensure ZIP entry names are safe and unique while preserving extensions.""" + upload_files: list[Any] = [ + SimpleNamespace(name="a/b.txt", key="k1"), + SimpleNamespace(name="c/b.txt", key="k2"), + SimpleNamespace(name="../b.txt", key="k3"), + ] + + data_by_key: dict[str, list[bytes]] = {"k1": [b"one"], "k2": [b"two"], "k3": [b"three"]} + + def _load(key: str, stream: bool = True) -> list[bytes]: + assert stream is True + return data_by_key[key] + + monkeypatch.setattr(file_service_module.storage, "load", _load) + + with FileService.build_upload_files_zip_tempfile(upload_files=upload_files) as tmp: + with ZipFile(tmp, mode="r") as zf: + assert zf.namelist() == ["b.txt", "b (1).txt", "b (2).txt"] + assert zf.read("b.txt") == b"one" + assert zf.read("b (1).txt") == b"two" + assert zf.read("b (2).txt") == b"three" + + +def test_get_upload_files_by_ids_returns_empty_when_no_ids(db_session_with_containers) -> None: + """Ensure empty input returns an empty mapping without hitting the database.""" + assert FileService.get_upload_files_by_ids(str(uuid4()), []) == {} + + +def test_get_upload_files_by_ids_returns_id_keyed_mapping(db_session_with_containers) -> None: + """Ensure batch lookup returns a dict keyed by stringified UploadFile ids.""" + tenant_id = str(uuid4()) + file1 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k1", name="file1.txt") + file2 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k2", name="file2.txt") + + result = FileService.get_upload_files_by_ids(tenant_id, [file1.id, file1.id, file2.id]) + + assert set(result.keys()) == {file1.id, file2.id} + assert result[file1.id].id == file1.id + assert result[file2.id].id == file2.id + + +def test_get_upload_files_by_ids_filters_by_tenant(db_session_with_containers) -> None: + """Ensure files from other tenants are not returned.""" + tenant_a = str(uuid4()) + tenant_b = str(uuid4()) + file_a = _create_upload_file(db_session_with_containers, tenant_id=tenant_a, key="ka", name="a.txt") + _create_upload_file(db_session_with_containers, tenant_id=tenant_b, key="kb", name="b.txt") + + result = FileService.get_upload_files_by_ids(tenant_a, [file_a.id]) + + assert set(result.keys()) == {file_a.id} diff --git a/api/tests/unit_tests/services/test_file_service_zip_and_lookup.py b/api/tests/unit_tests/services/test_file_service_zip_and_lookup.py deleted file mode 100644 index 7b4d349e33..0000000000 --- a/api/tests/unit_tests/services/test_file_service_zip_and_lookup.py +++ /dev/null @@ -1,99 +0,0 @@ -""" -Unit tests for `services.file_service.FileService` helpers. - -We keep these tests focused on: -- ZIP tempfile building (sanitization + deduplication + content writes) -- tenant-scoped batch lookup behavior (`get_upload_files_by_ids`) -""" - -from __future__ import annotations - -from types import SimpleNamespace -from typing import Any -from zipfile import ZipFile - -import pytest - -import services.file_service as file_service_module -from services.file_service import FileService - - -def test_build_upload_files_zip_tempfile_sanitizes_and_dedupes_names(monkeypatch: pytest.MonkeyPatch) -> None: - """Ensure ZIP entry names are safe and unique while preserving extensions.""" - - # Arrange: three upload files that all sanitize down to the same basename ("b.txt"). - upload_files: list[Any] = [ - SimpleNamespace(name="a/b.txt", key="k1"), - SimpleNamespace(name="c/b.txt", key="k2"), - SimpleNamespace(name="../b.txt", key="k3"), - ] - - # Stream distinct bytes per key so we can verify content is written to the right entry. - data_by_key: dict[str, list[bytes]] = {"k1": [b"one"], "k2": [b"two"], "k3": [b"three"]} - - def _load(key: str, stream: bool = True) -> list[bytes]: - # Return the corresponding chunks for this key (the production code iterates chunks). - assert stream is True - return data_by_key[key] - - monkeypatch.setattr(file_service_module.storage, "load", _load) - - # Act: build zip in a tempfile. - with FileService.build_upload_files_zip_tempfile(upload_files=upload_files) as tmp: - with ZipFile(tmp, mode="r") as zf: - # Assert: names are sanitized (no directory components) and deduped with suffixes. - assert zf.namelist() == ["b.txt", "b (1).txt", "b (2).txt"] - - # Assert: each entry contains the correct bytes from storage. - assert zf.read("b.txt") == b"one" - assert zf.read("b (1).txt") == b"two" - assert zf.read("b (2).txt") == b"three" - - -def test_get_upload_files_by_ids_returns_empty_when_no_ids(monkeypatch: pytest.MonkeyPatch) -> None: - """Ensure empty input returns an empty mapping without hitting the database.""" - - class _Session: - def scalars(self, _stmt): # type: ignore[no-untyped-def] - raise AssertionError("db.session.scalars should not be called for empty id lists") - - monkeypatch.setattr(file_service_module, "db", SimpleNamespace(session=_Session())) - - assert FileService.get_upload_files_by_ids("tenant-1", []) == {} - - -def test_get_upload_files_by_ids_returns_id_keyed_mapping(monkeypatch: pytest.MonkeyPatch) -> None: - """Ensure batch lookup returns a dict keyed by stringified UploadFile ids.""" - - upload_files: list[Any] = [ - SimpleNamespace(id="file-1", tenant_id="tenant-1"), - SimpleNamespace(id="file-2", tenant_id="tenant-1"), - ] - - class _ScalarResult: - def __init__(self, items: list[Any]) -> None: - self._items = items - - def all(self) -> list[Any]: - return self._items - - class _Session: - def __init__(self, items: list[Any]) -> None: - self._items = items - self.calls: list[object] = [] - - def scalars(self, stmt): # type: ignore[no-untyped-def] - # Capture the statement so we can at least assert the query path is taken. - self.calls.append(stmt) - return _ScalarResult(self._items) - - session = _Session(upload_files) - monkeypatch.setattr(file_service_module, "db", SimpleNamespace(session=session)) - - # Provide duplicates to ensure callers can safely pass repeated ids. - result = FileService.get_upload_files_by_ids("tenant-1", ["file-1", "file-1", "file-2"]) - - assert set(result.keys()) == {"file-1", "file-2"} - assert result["file-1"].id == "file-1" - assert result["file-2"].id == "file-2" - assert len(session.calls) == 1 From 72e3fcd25fb18a34bc335c338df524e1caa13d12 Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 23 Mar 2026 07:54:37 -0500 Subject: [PATCH 05/34] test: migrate end user service batch tests to testcontainers (#33947) --- .../services/test_end_user_service.py | 141 +++ .../services/test_end_user_service.py | 841 ------------------ 2 files changed, 141 insertions(+), 841 deletions(-) delete mode 100644 api/tests/unit_tests/services/test_end_user_service.py diff --git a/api/tests/test_containers_integration_tests/services/test_end_user_service.py b/api/tests/test_containers_integration_tests/services/test_end_user_service.py index ae811db768..cafabc939b 100644 --- a/api/tests/test_containers_integration_tests/services/test_end_user_service.py +++ b/api/tests/test_containers_integration_tests/services/test_end_user_service.py @@ -414,3 +414,144 @@ class TestEndUserServiceGetEndUserById: ) assert result is None + + +class TestEndUserServiceCreateBatch: + """Integration tests for EndUserService.create_end_user_batch.""" + + @pytest.fixture + def factory(self): + return TestEndUserServiceFactory() + + def _create_multiple_apps(self, db_session_with_containers, factory, count: int = 3): + """Create multiple apps under the same tenant.""" + first_app = factory.create_app_and_account(db_session_with_containers) + tenant_id = first_app.tenant_id + apps = [first_app] + for _ in range(count - 1): + app = App( + tenant_id=tenant_id, + name=f"App {uuid4()}", + description="", + mode="chat", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=first_app.created_by, + updated_by=first_app.updated_by, + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + all_apps = db_session_with_containers.query(App).filter(App.tenant_id == tenant_id).all() + return tenant_id, all_apps + + def test_create_batch_empty_app_ids(self, db_session_with_containers): + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=str(uuid4()), app_ids=[], user_id="user-1" + ) + assert result == {} + + def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3) + app_ids = [a.id for a in apps] + user_id = f"user-{uuid4()}" + + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + assert len(result) == 3 + for app_id in app_ids: + assert app_id in result + assert result[app_id].session_id == user_id + assert result[app_id].type == InvokeFrom.SERVICE_API + + def test_create_batch_default_session_id(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) + app_ids = [a.id for a in apps] + + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id="" + ) + + assert len(result) == 2 + for end_user in result.values(): + assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID + assert end_user._is_anonymous is True + + def test_create_batch_deduplicate_app_ids(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) + app_ids = [apps[0].id, apps[1].id, apps[0].id, apps[1].id] + user_id = f"user-{uuid4()}" + + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + assert len(result) == 2 + + def test_create_batch_returns_existing_users(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) + app_ids = [a.id for a in apps] + user_id = f"user-{uuid4()}" + + # Create batch first time + first_result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Create batch second time — should return existing users + second_result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + assert len(second_result) == 2 + for app_id in app_ids: + assert first_result[app_id].id == second_result[app_id].id + + def test_create_batch_partial_existing_users(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3) + user_id = f"user-{uuid4()}" + + # Create for first 2 apps + first_result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, + tenant_id=tenant_id, + app_ids=[apps[0].id, apps[1].id], + user_id=user_id, + ) + + # Create for all 3 apps — should reuse first 2, create 3rd + all_result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, + tenant_id=tenant_id, + app_ids=[a.id for a in apps], + user_id=user_id, + ) + + assert len(all_result) == 3 + assert all_result[apps[0].id].id == first_result[apps[0].id].id + assert all_result[apps[1].id].id == first_result[apps[1].id].id + assert all_result[apps[2].id].session_id == user_id + + @pytest.mark.parametrize( + "invoke_type", + [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP, InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER], + ) + def test_create_batch_all_invoke_types(self, db_session_with_containers, invoke_type, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=1) + user_id = f"user-{uuid4()}" + + result = EndUserService.create_end_user_batch( + type=invoke_type, tenant_id=tenant_id, app_ids=[apps[0].id], user_id=user_id + ) + + assert len(result) == 1 + assert result[apps[0].id].type == invoke_type diff --git a/api/tests/unit_tests/services/test_end_user_service.py b/api/tests/unit_tests/services/test_end_user_service.py deleted file mode 100644 index a3b1f46436..0000000000 --- a/api/tests/unit_tests/services/test_end_user_service.py +++ /dev/null @@ -1,841 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom -from models.model import App, DefaultEndUserSessionID, EndUser -from services.end_user_service import EndUserService - - -class TestEndUserServiceFactory: - """Factory class for creating test data and mock objects for end user service tests.""" - - @staticmethod - def create_app_mock( - app_id: str = "app-123", - tenant_id: str = "tenant-456", - name: str = "Test App", - ) -> MagicMock: - """Create a mock App object.""" - app = MagicMock(spec=App) - app.id = app_id - app.tenant_id = tenant_id - app.name = name - return app - - @staticmethod - def create_end_user_mock( - user_id: str = "user-789", - tenant_id: str = "tenant-456", - app_id: str = "app-123", - session_id: str = "session-001", - type: InvokeFrom = InvokeFrom.SERVICE_API, - is_anonymous: bool = False, - ) -> MagicMock: - """Create a mock EndUser object.""" - end_user = MagicMock(spec=EndUser) - end_user.id = user_id - end_user.tenant_id = tenant_id - end_user.app_id = app_id - end_user.session_id = session_id - end_user.type = type - end_user.is_anonymous = is_anonymous - end_user.external_user_id = session_id - return end_user - - -class TestEndUserServiceGetEndUserById: - """Unit tests for EndUserService.get_end_user_by_id method.""" - - @pytest.fixture - def factory(self): - """Provide test data factory.""" - return TestEndUserServiceFactory() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_get_end_user_by_id_success(self, mock_db, mock_session_class, factory): - """Test successful retrieval of end user by ID.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - end_user_id = "user-789" - - mock_end_user = factory.create_end_user_mock(user_id=end_user_id, tenant_id=tenant_id, app_id=app_id) - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = mock_end_user - - # Act - result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) - - # Assert - assert result == mock_end_user - mock_session.query.assert_called_once_with(EndUser) - mock_query.where.assert_called_once() - mock_query.first.assert_called_once() - mock_context.__enter__.assert_called_once() - mock_context.__exit__.assert_called_once() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_get_end_user_by_id_not_found(self, mock_db, mock_session_class): - """Test retrieval of non-existent end user returns None.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - end_user_id = "user-789" - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Act - result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) - - # Assert - assert result is None - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_get_end_user_by_id_query_parameters(self, mock_db, mock_session_class): - """Test that query parameters are correctly applied.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - end_user_id = "user-789" - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Act - EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) - - # Assert - # Verify the where clause was called with the correct conditions - call_args = mock_query.where.call_args[0] - assert len(call_args) == 3 - # Check that the conditions match the expected filters - # (We can't easily test the exact conditions without importing SQLAlchemy) - - -class TestEndUserServiceGetOrCreateEndUser: - """Unit tests for EndUserService.get_or_create_end_user method.""" - - @pytest.fixture - def factory(self): - """Provide test data factory.""" - return TestEndUserServiceFactory() - - @patch("services.end_user_service.EndUserService.get_or_create_end_user_by_type") - def test_get_or_create_end_user_with_user_id(self, mock_get_or_create_by_type, factory): - """Test get_or_create_end_user with specific user_id.""" - # Arrange - app_mock = factory.create_app_mock() - user_id = "user-123" - expected_end_user = factory.create_end_user_mock() - mock_get_or_create_by_type.return_value = expected_end_user - - # Act - result = EndUserService.get_or_create_end_user(app_mock, user_id) - - # Assert - assert result == expected_end_user - mock_get_or_create_by_type.assert_called_once_with( - InvokeFrom.SERVICE_API, app_mock.tenant_id, app_mock.id, user_id - ) - - @patch("services.end_user_service.EndUserService.get_or_create_end_user_by_type") - def test_get_or_create_end_user_without_user_id(self, mock_get_or_create_by_type, factory): - """Test get_or_create_end_user without user_id (None).""" - # Arrange - app_mock = factory.create_app_mock() - expected_end_user = factory.create_end_user_mock() - mock_get_or_create_by_type.return_value = expected_end_user - - # Act - result = EndUserService.get_or_create_end_user(app_mock, None) - - # Assert - assert result == expected_end_user - mock_get_or_create_by_type.assert_called_once_with( - InvokeFrom.SERVICE_API, app_mock.tenant_id, app_mock.id, None - ) - - -class TestEndUserServiceGetOrCreateEndUserByType: - """ - Unit tests for EndUserService.get_or_create_end_user_by_type method. - - This test suite covers: - - Creating end users with different InvokeFrom types - - Type migration for legacy users - - Query ordering and prioritization - - Session management - """ - - @pytest.fixture - def factory(self): - """Provide test data factory.""" - return TestEndUserServiceFactory() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_new_end_user_with_user_id(self, mock_db, mock_session_class, factory): - """Test creating a new end user with specific user_id.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = None # No existing user - - # Act - result = EndUserService.get_or_create_end_user_by_type( - type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - # Verify new EndUser was created with correct parameters - mock_session.add.assert_called_once() - mock_session.commit.assert_called_once() - added_user = mock_session.add.call_args[0][0] - assert added_user.tenant_id == tenant_id - assert added_user.app_id == app_id - assert added_user.type == type_enum - assert added_user.session_id == user_id - assert added_user.external_user_id == user_id - assert added_user._is_anonymous is False - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_new_end_user_default_session(self, mock_db, mock_session_class, factory): - """Test creating a new end user with default session ID.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = None - type_enum = InvokeFrom.WEB_APP - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = None # No existing user - - # Act - result = EndUserService.get_or_create_end_user_by_type( - type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - added_user = mock_session.add.call_args[0][0] - assert added_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID - assert added_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID - assert added_user._is_anonymous is True - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - @patch("services.end_user_service.logger") - def test_existing_user_same_type(self, mock_logger, mock_db, mock_session_class, factory): - """Test retrieving existing user with same type.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - existing_user = factory.create_end_user_mock( - tenant_id=tenant_id, app_id=app_id, session_id=user_id, type=type_enum - ) - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = existing_user - - # Act - result = EndUserService.get_or_create_end_user_by_type( - type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - assert result == existing_user - mock_session.add.assert_not_called() - mock_session.commit.assert_not_called() - mock_logger.info.assert_not_called() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - @patch("services.end_user_service.logger") - def test_existing_user_different_type_upgrade(self, mock_logger, mock_db, mock_session_class, factory): - """Test upgrading existing user with different type.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - old_type = InvokeFrom.WEB_APP - new_type = InvokeFrom.SERVICE_API - - existing_user = factory.create_end_user_mock( - tenant_id=tenant_id, app_id=app_id, session_id=user_id, type=old_type - ) - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = existing_user - - # Act - result = EndUserService.get_or_create_end_user_by_type( - type=new_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - assert result == existing_user - assert existing_user.type == new_type - mock_session.commit.assert_called_once() - mock_logger.info.assert_called_once() - logger_call_args = mock_logger.info.call_args[0] - assert "Upgrading legacy EndUser" in logger_call_args[0] - # The old and new types are passed as separate arguments - assert mock_logger.info.call_args[0][1] == existing_user.id - assert mock_logger.info.call_args[0][2] == old_type - assert mock_logger.info.call_args[0][3] == new_type - assert mock_logger.info.call_args[0][4] == user_id - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_query_ordering_prioritizes_exact_type_match(self, mock_db, mock_session_class, factory): - """Test that query ordering prioritizes exact type matches.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - target_type = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = None - - # Act - EndUserService.get_or_create_end_user_by_type( - type=target_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - mock_query.order_by.assert_called_once() - # Verify that case statement is used for ordering - order_by_call = mock_query.order_by.call_args[0][0] - # The exact structure depends on SQLAlchemy's case implementation - # but we can verify it was called - - # Test 10: Session context manager properly closes - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_session_context_manager_closes(self, mock_db, mock_session_class, factory): - """Test that Session context manager is properly used.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = None - - # Act - EndUserService.get_or_create_end_user_by_type( - type=InvokeFrom.SERVICE_API, - tenant_id=tenant_id, - app_id=app_id, - user_id=user_id, - ) - - # Assert - # Verify context manager was entered and exited - mock_context.__enter__.assert_called_once() - mock_context.__exit__.assert_called_once() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_all_invokefrom_types_supported(self, mock_db, mock_session_class): - """Test that all InvokeFrom enum values are supported.""" - # Arrange - tenant_id = "tenant-123" - app_id = "app-456" - user_id = "user-789" - - for invoke_type in InvokeFrom: - with patch("services.end_user_service.Session") as mock_session_class: - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.first.return_value = None - - # Act - result = EndUserService.get_or_create_end_user_by_type( - type=invoke_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id - ) - - # Assert - added_user = mock_session.add.call_args[0][0] - assert added_user.type == invoke_type - - -class TestEndUserServiceCreateEndUserBatch: - """Unit tests for EndUserService.create_end_user_batch method.""" - - @pytest.fixture - def factory(self): - """Provide test data factory.""" - return TestEndUserServiceFactory() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_empty_app_ids(self, mock_db, mock_session_class): - """Test batch creation with empty app_ids list.""" - # Arrange - tenant_id = "tenant-123" - app_ids: list[str] = [] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert result == {} - mock_session_class.assert_not_called() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_default_session_id(self, mock_db, mock_session_class): - """Test batch creation with empty user_id (uses default session).""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789"] - user_id = "" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert len(result) == 2 - for app_id, end_user in result.items(): - assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID - assert end_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID - assert end_user._is_anonymous is True - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_deduplicate_app_ids(self, mock_db, mock_session_class): - """Test that duplicate app_ids are deduplicated while preserving order.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789", "app-456", "app-123", "app-789"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - # Should have 3 unique app_ids in original order - assert len(result) == 3 - assert "app-456" in result - assert "app-789" in result - assert "app-123" in result - - # Verify the order is preserved - added_users = mock_session.add_all.call_args[0][0] - assert len(added_users) == 3 - assert added_users[0].app_id == "app-456" - assert added_users[1].app_id == "app-789" - assert added_users[2].app_id == "app-123" - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_all_existing_users(self, mock_db, mock_session_class, factory): - """Test batch creation when all users already exist.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - existing_user1 = factory.create_end_user_mock( - tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum - ) - existing_user2 = factory.create_end_user_mock( - tenant_id=tenant_id, app_id="app-789", session_id=user_id, type=type_enum - ) - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [existing_user1, existing_user2] - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert len(result) == 2 - assert result["app-456"] == existing_user1 - assert result["app-789"] == existing_user2 - mock_session.add_all.assert_not_called() - mock_session.commit.assert_not_called() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_partial_existing_users(self, mock_db, mock_session_class, factory): - """Test batch creation with some existing and some new users.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789", "app-123"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - existing_user1 = factory.create_end_user_mock( - tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum - ) - # app-789 and app-123 don't exist - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [existing_user1] - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert len(result) == 3 - assert result["app-456"] == existing_user1 - assert "app-789" in result - assert "app-123" in result - - # Should create 2 new users - mock_session.add_all.assert_called_once() - added_users = mock_session.add_all.call_args[0][0] - assert len(added_users) == 2 - - mock_session.commit.assert_called_once() - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_handles_duplicates_in_existing(self, mock_db, mock_session_class, factory): - """Test batch creation handles duplicates in existing users gracefully.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - # Simulate duplicate records in database - existing_user1 = factory.create_end_user_mock( - user_id="user-1", tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum - ) - existing_user2 = factory.create_end_user_mock( - user_id="user-2", tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum - ) - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [existing_user1, existing_user2] - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert len(result) == 1 - # Should prefer the first one found - assert result["app-456"] == existing_user1 - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_all_invokefrom_types(self, mock_db, mock_session_class): - """Test batch creation with all InvokeFrom types.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456"] - user_id = "user-789" - - for invoke_type in InvokeFrom: - with patch("services.end_user_service.Session") as mock_session_class: - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - result = EndUserService.create_end_user_batch( - type=invoke_type, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - added_user = mock_session.add_all.call_args[0][0][0] - assert added_user.type == invoke_type - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_single_app_id(self, mock_db, mock_session_class, factory): - """Test batch creation with single app_id.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - result = EndUserService.create_end_user_batch( - type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id - ) - - # Assert - assert len(result) == 1 - assert "app-456" in result - mock_session.add_all.assert_called_once() - added_users = mock_session.add_all.call_args[0][0] - assert len(added_users) == 1 - assert added_users[0].app_id == "app-456" - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_anonymous_vs_authenticated(self, mock_db, mock_session_class): - """Test batch creation correctly sets anonymous flag.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789"] - - # Test with regular user ID - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - authenticated user - result = EndUserService.create_end_user_batch( - type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id="user-789" - ) - - # Assert - added_users = mock_session.add_all.call_args[0][0] - for user in added_users: - assert user._is_anonymous is False - - # Test with default session ID - mock_session.reset_mock() - mock_query.reset_mock() - mock_query.all.return_value = [] - - # Act - anonymous user - result = EndUserService.create_end_user_batch( - type=InvokeFrom.SERVICE_API, - tenant_id=tenant_id, - app_ids=app_ids, - user_id=DefaultEndUserSessionID.DEFAULT_SESSION_ID, - ) - - # Assert - added_users = mock_session.add_all.call_args[0][0] - for user in added_users: - assert user._is_anonymous is True - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_efficient_single_query(self, mock_db, mock_session_class): - """Test that batch creation uses efficient single query for existing users.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456", "app-789", "app-123"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - EndUserService.create_end_user_batch(type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id) - - # Assert - # Should make exactly one query to check for existing users - mock_session.query.assert_called_once_with(EndUser) - mock_query.where.assert_called_once() - mock_query.all.assert_called_once() - - # Verify the where clause uses .in_() for app_ids - where_call = mock_query.where.call_args[0] - # The exact structure depends on SQLAlchemy implementation - # but we can verify it was called with the right parameters - - @patch("services.end_user_service.Session") - @patch("services.end_user_service.db") - def test_create_batch_session_context_manager(self, mock_db, mock_session_class): - """Test that batch creation properly uses session context manager.""" - # Arrange - tenant_id = "tenant-123" - app_ids = ["app-456"] - user_id = "user-789" - type_enum = InvokeFrom.SERVICE_API - - mock_session = MagicMock() - mock_context = MagicMock() - mock_context.__enter__.return_value = mock_session - mock_session_class.return_value = mock_context - - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.all.return_value = [] # No existing users - - # Act - EndUserService.create_end_user_batch(type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id) - - # Assert - mock_context.__enter__.assert_called_once() - mock_context.__exit__.assert_called_once() - mock_session.commit.assert_called_once() From 65223c80925d3b3ca9a19b082d5321c84ad9c6a4 Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 23 Mar 2026 07:55:50 -0500 Subject: [PATCH 06/34] test: remove mock-based tests superseded by testcontainers (#33946) --- .../test_delete_archived_workflow_run.py | 57 ------------------- ...kflow_node_execution_service_repository.py | 30 ---------- 2 files changed, 87 deletions(-) delete mode 100644 api/tests/unit_tests/services/test_delete_archived_workflow_run.py delete mode 100644 api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py diff --git a/api/tests/unit_tests/services/test_delete_archived_workflow_run.py b/api/tests/unit_tests/services/test_delete_archived_workflow_run.py deleted file mode 100644 index a7e1a011f6..0000000000 --- a/api/tests/unit_tests/services/test_delete_archived_workflow_run.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -Unit tests for archived workflow run deletion service. -""" - -from unittest.mock import MagicMock, patch - - -class TestArchivedWorkflowRunDeletion: - def test_delete_by_run_id_calls_delete_run(self): - from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion - - deleter = ArchivedWorkflowRunDeletion() - repo = MagicMock() - repo.get_archived_run_ids.return_value = {"run-1"} - run = MagicMock() - run.id = "run-1" - run.tenant_id = "tenant-1" - - session = MagicMock() - session.get.return_value = run - - session_maker = MagicMock() - session_maker.return_value.__enter__.return_value = session - session_maker.return_value.__exit__.return_value = None - mock_db = MagicMock() - mock_db.engine = MagicMock() - - with ( - patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), - patch( - "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", - return_value=session_maker, - autospec=True, - ), - patch.object(deleter, "_get_workflow_run_repo", return_value=repo, autospec=True), - patch.object( - deleter, "_delete_run", return_value=MagicMock(success=True), autospec=True - ) as mock_delete_run, - ): - result = deleter.delete_by_run_id("run-1") - - assert result.success is True - mock_delete_run.assert_called_once_with(run) - - def test_delete_run_dry_run(self): - from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion - - deleter = ArchivedWorkflowRunDeletion(dry_run=True) - run = MagicMock() - run.id = "run-1" - run.tenant_id = "tenant-1" - - with patch.object(deleter, "_get_workflow_run_repo", autospec=True) as mock_get_repo: - result = deleter._delete_run(run) - - assert result.success is True - mock_get_repo.assert_not_called() diff --git a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py deleted file mode 100644 index 79bf5e94c2..0000000000 --- a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ /dev/null @@ -1,30 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from repositories.sqlalchemy_api_workflow_node_execution_repository import ( - DifyAPISQLAlchemyWorkflowNodeExecutionRepository, -) - - -class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: - @pytest.fixture - def repository(self): - mock_session_maker = MagicMock() - return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker) - - def test_repository_implements_protocol(self, repository): - """Test that the repository implements the required protocol methods.""" - # Verify all protocol methods are implemented - assert hasattr(repository, "get_node_last_execution") - assert hasattr(repository, "get_executions_by_workflow_run") - assert hasattr(repository, "get_execution_by_id") - - # Verify methods are callable - assert callable(repository.get_node_last_execution) - assert callable(repository.get_executions_by_workflow_run) - assert callable(repository.get_execution_by_id) - assert callable(repository.delete_expired_executions) - assert callable(repository.delete_executions_by_app) - assert callable(repository.get_expired_executions_batch) - assert callable(repository.delete_executions_by_ids) From 30dd36505ca0f7d6fb2564dbe6f36cff90a2f9a8 Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 23 Mar 2026 07:57:01 -0500 Subject: [PATCH 07/34] test: migrate batch update document status tests to testcontainers (#33951) --- ...et_service_batch_update_document_status.py | 16 +++ ...et_service_batch_update_document_status.py | 100 ------------------ 2 files changed, 16 insertions(+), 100 deletions(-) delete mode 100644 api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py index 7983b1cd93..ab7e2a3f50 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py @@ -694,3 +694,19 @@ class TestDatasetServiceBatchUpdateDocumentStatus: patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{doc1.id}_indexing", 600, 1) patched_dependencies["add_task"].delay.assert_called_once_with(doc1.id) + + def test_batch_update_invalid_action_raises_value_error( + self, db_session_with_containers: Session, patched_dependencies + ): + """Test that an invalid action raises ValueError.""" + factory = DocumentBatchUpdateIntegrationDataFactory + dataset = factory.create_dataset(db_session_with_containers) + doc = factory.create_document(db_session_with_containers, dataset) + user = UserDouble(id=str(uuid4())) + + patched_dependencies["redis_client"].get.return_value = None + + with pytest.raises(ValueError, match="Invalid action"): + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=[doc.id], action="invalid_action", user=user + ) diff --git a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py deleted file mode 100644 index abff48347e..0000000000 --- a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py +++ /dev/null @@ -1,100 +0,0 @@ -import datetime -from unittest.mock import Mock, patch - -import pytest - -from models.dataset import Dataset, Document -from services.dataset_service import DocumentService -from tests.unit_tests.conftest import redis_mock - - -class DocumentBatchUpdateTestDataFactory: - """Factory class for creating test data and mock objects for document batch update tests.""" - - @staticmethod - def create_dataset_mock(dataset_id: str = "dataset-123", tenant_id: str = "tenant-456") -> Mock: - """Create a mock dataset with specified attributes.""" - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - dataset.tenant_id = tenant_id - return dataset - - @staticmethod - def create_user_mock(user_id: str = "user-789") -> Mock: - """Create a mock user.""" - user = Mock() - user.id = user_id - return user - - @staticmethod - def create_document_mock( - document_id: str = "doc-1", - name: str = "test_document.pdf", - enabled: bool = True, - archived: bool = False, - indexing_status: str = "completed", - completed_at: datetime.datetime | None = None, - **kwargs, - ) -> Mock: - """Create a mock document with specified attributes.""" - document = Mock(spec=Document) - document.id = document_id - document.name = name - document.enabled = enabled - document.archived = archived - document.indexing_status = indexing_status - document.completed_at = completed_at or datetime.datetime.now() - - document.disabled_at = None - document.disabled_by = None - document.archived_at = None - document.archived_by = None - document.updated_at = None - - for key, value in kwargs.items(): - setattr(document, key, value) - return document - - -class TestDatasetServiceBatchUpdateDocumentStatus: - """Unit tests for non-SQL path in DocumentService.batch_update_document_status.""" - - @pytest.fixture - def mock_document_service_dependencies(self): - """Common mock setup for document service dependencies.""" - with ( - patch("services.dataset_service.DocumentService.get_document") as mock_get_doc, - patch("extensions.ext_database.db.session") as mock_db, - patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now, - ): - current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) - mock_naive_utc_now.return_value = current_time - - yield { - "get_document": mock_get_doc, - "db_session": mock_db, - "naive_utc_now": mock_naive_utc_now, - "current_time": current_time, - } - - def test_batch_update_invalid_action_error(self, mock_document_service_dependencies): - """Test that ValueError is raised when an invalid action is provided.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True) - mock_document_service_dependencies["get_document"].return_value = doc - - redis_mock.reset_mock() - redis_mock.get.return_value = None - - invalid_action = "invalid_action" - with pytest.raises(ValueError) as exc_info: - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action=invalid_action, user=user - ) - - assert invalid_action in str(exc_info.value) - assert "Invalid action" in str(exc_info.value) - - redis_mock.setex.assert_not_called() From 30deeb6f1c81f793f9e723318880c1e57a5762f2 Mon Sep 17 00:00:00 2001 From: kurokobo Date: Mon, 23 Mar 2026 22:19:32 +0900 Subject: [PATCH 08/34] feat(firecrawl): follow pagination when crawl status is completed (#33864) Co-authored-by: Crazywoola <100913391+crazywoola@users.noreply.github.com> --- .../rag/extractor/firecrawl/firecrawl_app.py | 42 ++++++++-- .../rag/extractor/firecrawl/test_firecrawl.py | 78 +++++++++++++++++++ 2 files changed, 112 insertions(+), 8 deletions(-) diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 371f7b0865..e1ddd2dd96 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -95,15 +95,11 @@ class FirecrawlApp: if response.status_code == 200: crawl_status_response = response.json() if crawl_status_response.get("status") == "completed": - total = crawl_status_response.get("total", 0) - if total == 0: + # Normalize to avoid None bypassing the zero-guard when the API returns null. + total = crawl_status_response.get("total") or 0 + if total <= 0: raise Exception("Failed to check crawl status. Error: No page found") - data = crawl_status_response.get("data", []) - url_data_list: list[FirecrawlDocumentData] = [] - for item in data: - if isinstance(item, dict) and "metadata" in item and "markdown" in item: - url_data = self._extract_common_fields(item) - url_data_list.append(url_data) + url_data_list = self._collect_all_crawl_pages(crawl_status_response, headers) if url_data_list: file_key = "website_files/" + job_id + ".txt" try: @@ -120,6 +116,36 @@ class FirecrawlApp: self._handle_error(response, "check crawl status") raise RuntimeError("unreachable: _handle_error always raises") + def _collect_all_crawl_pages( + self, first_page: dict[str, Any], headers: dict[str, str] + ) -> list[FirecrawlDocumentData]: + """Collect all crawl result pages by following pagination links. + + Raises an exception if any paginated request fails, to avoid returning + partial data that is inconsistent with the reported total. + + The number of pages processed is capped at ``total`` (the + server-reported page count) to guard against infinite loops caused by + a misbehaving server that keeps returning a ``next`` URL. + """ + total: int = first_page.get("total") or 0 + url_data_list: list[FirecrawlDocumentData] = [] + current_page = first_page + pages_processed = 0 + while True: + for item in current_page.get("data", []): + if isinstance(item, dict) and "metadata" in item and "markdown" in item: + url_data_list.append(self._extract_common_fields(item)) + next_url: str | None = current_page.get("next") + pages_processed += 1 + if not next_url or pages_processed >= total: + break + response = self._get_request(next_url, headers) + if response.status_code != 200: + self._handle_error(response, "fetch next crawl page") + current_page = response.json() + return url_data_list + def _format_crawl_status_response( self, status: str, diff --git a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py index 2add12fd09..db49221583 100644 --- a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py +++ b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py @@ -164,6 +164,13 @@ class TestFirecrawlApp: with pytest.raises(Exception, match="No page found"): app.check_crawl_status("job-1") + def test_check_crawl_status_completed_with_null_total_raises(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.get", return_value=_response(200, {"status": "completed", "total": None, "data": []})) + + with pytest.raises(Exception, match="No page found"): + app.check_crawl_status("job-1") + def test_check_crawl_status_non_completed(self, mocker: MockerFixture): app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") payload = {"status": "processing", "total": 5, "completed": 1, "data": []} @@ -203,6 +210,77 @@ class TestFirecrawlApp: with pytest.raises(Exception, match="Error saving crawl data"): app.check_crawl_status("job-err") + def test_check_crawl_status_follows_pagination(self, mocker: MockerFixture): + """When status is completed and next is present, follow pagination to collect all pages.""" + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + page1 = { + "status": "completed", + "total": 3, + "completed": 3, + "next": "https://custom.firecrawl.dev/v2/crawl/job-42?skip=1", + "data": [{"metadata": {"title": "p1", "description": "", "sourceURL": "https://p1"}, "markdown": "m1"}], + } + page2 = { + "status": "completed", + "total": 3, + "completed": 3, + "next": "https://custom.firecrawl.dev/v2/crawl/job-42?skip=2", + "data": [{"metadata": {"title": "p2", "description": "", "sourceURL": "https://p2"}, "markdown": "m2"}], + } + page3 = { + "status": "completed", + "total": 3, + "completed": 3, + "data": [{"metadata": {"title": "p3", "description": "", "sourceURL": "https://p3"}, "markdown": "m3"}], + } + mocker.patch("httpx.get", side_effect=[_response(200, page1), _response(200, page2), _response(200, page3)]) + mock_storage = MagicMock() + mock_storage.exists.return_value = False + mocker.patch.object(firecrawl_module, "storage", mock_storage) + + result = app.check_crawl_status("job-42") + + assert result["status"] == "completed" + assert result["total"] == 3 + assert len(result["data"]) == 3 + assert [d["title"] for d in result["data"]] == ["p1", "p2", "p3"] + + def test_check_crawl_status_pagination_error_raises(self, mocker: MockerFixture): + """An error while fetching a paginated page raises an exception; no partial data is returned.""" + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + page1 = { + "status": "completed", + "total": 2, + "completed": 2, + "next": "https://custom.firecrawl.dev/v2/crawl/job-99?skip=1", + "data": [{"metadata": {"title": "p1", "description": "", "sourceURL": "https://p1"}, "markdown": "m1"}], + } + mocker.patch("httpx.get", side_effect=[_response(200, page1), _response(500, {"error": "server error"})]) + + with pytest.raises(Exception, match="fetch next crawl page"): + app.check_crawl_status("job-99") + + def test_check_crawl_status_pagination_capped_at_total(self, mocker: MockerFixture): + """Pagination stops once pages_processed reaches total, even if next is present.""" + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + # total=1: only the first page should be processed; next must not be followed + page1 = { + "status": "completed", + "total": 1, + "completed": 1, + "next": "https://custom.firecrawl.dev/v2/crawl/job-cap?skip=1", + "data": [{"metadata": {"title": "p1", "description": "", "sourceURL": "https://p1"}, "markdown": "m1"}], + } + mock_get = mocker.patch("httpx.get", return_value=_response(200, page1)) + mock_storage = MagicMock() + mock_storage.exists.return_value = False + mocker.patch.object(firecrawl_module, "storage", mock_storage) + + result = app.check_crawl_status("job-cap") + + assert len(result["data"]) == 1 + mock_get.assert_called_once() # initial fetch only; next URL is not followed due to cap + def test_extract_common_fields_and_status_formatter(self): app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") From 29cff809b9a6726337a81a43c520b410581e2490 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Baki=20Burak=20=C3=96=C4=9F=C3=BCn?= <63836730+bakiburakogun@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:19:53 +0300 Subject: [PATCH 09/34] fix(i18n): comprehensive Turkish (tr-TR) translation fixes and missing keys (#33885) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: bakiburakogun Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Baki Burak Öğün --- web/i18n/tr-TR/app-annotation.json | 2 +- web/i18n/tr-TR/app-api.json | 10 +-- web/i18n/tr-TR/app-debug.json | 35 +++++---- web/i18n/tr-TR/app-log.json | 6 +- web/i18n/tr-TR/app-overview.json | 6 +- web/i18n/tr-TR/app.json | 40 +++++----- web/i18n/tr-TR/billing.json | 8 +- web/i18n/tr-TR/common.json | 94 +++++++++++++++------- web/i18n/tr-TR/dataset-creation.json | 4 +- web/i18n/tr-TR/dataset-documents.json | 6 +- web/i18n/tr-TR/dataset-hit-testing.json | 2 +- web/i18n/tr-TR/dataset-pipeline.json | 6 +- web/i18n/tr-TR/dataset.json | 38 ++++----- web/i18n/tr-TR/login.json | 6 +- web/i18n/tr-TR/pipeline.json | 8 +- web/i18n/tr-TR/plugin-tags.json | 4 +- web/i18n/tr-TR/plugin.json | 100 +++++++++++++----------- web/i18n/tr-TR/time.json | 24 +++--- web/i18n/tr-TR/tools.json | 8 +- web/i18n/tr-TR/workflow.json | 65 ++++++++------- 20 files changed, 267 insertions(+), 205 deletions(-) diff --git a/web/i18n/tr-TR/app-annotation.json b/web/i18n/tr-TR/app-annotation.json index a4b5a869d2..a370fae561 100644 --- a/web/i18n/tr-TR/app-annotation.json +++ b/web/i18n/tr-TR/app-annotation.json @@ -25,7 +25,7 @@ "batchModal.tip": "CSV dosyası aşağıdaki yapıya uygun olmalıdır:", "batchModal.title": "Toplu İçe Aktarma", "editBy": "{{author}} tarafından düzenlendi", - "editModal.answerName": "Storyteller Bot", + "editModal.answerName": "Hikaye Anlatıcı Bot", "editModal.answerPlaceholder": "Cevabınızı buraya yazın", "editModal.createdAt": "Oluşturulma Tarihi", "editModal.queryName": "Kullanıcı Sorgusu", diff --git a/web/i18n/tr-TR/app-api.json b/web/i18n/tr-TR/app-api.json index cfdf56268c..c381ecd54e 100644 --- a/web/i18n/tr-TR/app-api.json +++ b/web/i18n/tr-TR/app-api.json @@ -55,10 +55,10 @@ "copied": "Kopyalandı", "copy": "Kopyala", "develop.noContent": "İçerik yok", - "develop.pathParams": "Path Params", - "develop.query": "Query", - "develop.requestBody": "Request Body", - "develop.toc": "Içeriği", + "develop.pathParams": "Yol Parametreleri", + "develop.query": "Sorgu", + "develop.requestBody": "İstek Gövdesi", + "develop.toc": "İçindekiler", "disabled": "Devre Dışı", "loading": "Yükleniyor", "merMaid.rerender": "Yeniden İşleme", @@ -67,6 +67,6 @@ "pause": "Duraklat", "play": "Oynat", "playing": "Oynatılıyor", - "regenerate": "Yenilemek", + "regenerate": "Yeniden Oluştur", "status": "Durum" } diff --git a/web/i18n/tr-TR/app-debug.json b/web/i18n/tr-TR/app-debug.json index 1ae01dca6c..1eb5f83d9b 100644 --- a/web/i18n/tr-TR/app-debug.json +++ b/web/i18n/tr-TR/app-debug.json @@ -1,33 +1,33 @@ { - "agent.agentMode": "Agent Modu", - "agent.agentModeDes": "Agent için çıkarım modunu ayarlayın", + "agent.agentMode": "Ajan Modu", + "agent.agentModeDes": "Ajan için çıkarım modunu ayarlayın", "agent.agentModeType.ReACT": "ReAct", "agent.agentModeType.functionCall": "Fonksiyon Çağrısı", "agent.buildInPrompt": "Yerleşik Prompt", "agent.firstPrompt": "İlk Prompt", "agent.nextIteration": "Sonraki Yineleme", "agent.promptPlaceholder": "Promptunuzu buraya yazın", - "agent.setting.description": "Agent Asistanı ayarları, Agent modunu ve yerleşik promptlar gibi gelişmiş özellikleri ayarlamanıza olanak tanır. Sadece Agent türünde kullanılabilir.", - "agent.setting.maximumIterations.description": "Bir Agent asistanının gerçekleştirebileceği yineleme sayısını sınırlayın", + "agent.setting.description": "Ajan Asistanı ayarları, Ajan modunu ve yerleşik promptlar gibi gelişmiş özellikleri ayarlamanıza olanak tanır. Sadece Ajan türünde kullanılabilir.", + "agent.setting.maximumIterations.description": "Bir Ajan asistanının gerçekleştirebileceği yineleme sayısını sınırlayın", "agent.setting.maximumIterations.name": "Maksimum Yineleme", - "agent.setting.name": "Agent Ayarları", + "agent.setting.name": "Ajan Ayarları", "agent.tools.description": "Araçlar kullanmak, internette arama yapmak veya bilimsel hesaplamalar yapmak gibi LLM yeteneklerini genişletebilir", "agent.tools.enabled": "Etkinleştirildi", "agent.tools.name": "Araçlar", - "assistantType.agentAssistant.description": "Görevleri tamamlamak için araçları özerk bir şekilde seçebilen bir zeki Agent oluşturun", - "assistantType.agentAssistant.name": "Agent Asistanı", + "assistantType.agentAssistant.description": "Görevleri tamamlamak için araçları özerk bir şekilde seçebilen bir zeki Ajan oluşturun", + "assistantType.agentAssistant.name": "Ajan Asistanı", "assistantType.chatAssistant.description": "Büyük Dil Modeli kullanarak sohbet tabanlı bir asistan oluşturun", "assistantType.chatAssistant.name": "Temel Asistan", "assistantType.name": "Asistan Türü", "autoAddVar": "Ön promptta referans verilen tanımlanmamış değişkenler, kullanıcı giriş formunda eklemek istiyor musunuz?", "chatSubTitle": "Talimatlar", "code.instruction": "Talimat", - "codegen.apply": "Uygulamak", + "codegen.apply": "Uygula", "codegen.applyChanges": "Değişiklikleri Uygula", "codegen.description": "Kod Oluşturucu, talimatlarınıza göre yüksek kaliteli kod oluşturmak için yapılandırılmış modelleri kullanır. Lütfen açık ve ayrıntılı talimatlar verin.", - "codegen.generate": "Oluşturmak", + "codegen.generate": "Oluştur", "codegen.generatedCodeTitle": "Oluşturulan Kod", - "codegen.instruction": "Talimat -ları", + "codegen.instruction": "Talimatlar", "codegen.instructionPlaceholder": "Oluşturmak istediğiniz kodun ayrıntılı açıklamasını girin.", "codegen.loading": "Kod oluşturuluyor...", "codegen.noDataLine1": "Solda kullanım durumunuzu açıklayın,", @@ -40,11 +40,11 @@ "datasetConfig.embeddingModelRequired": "Yapılandırılmış bir Gömme Modeli gereklidir", "datasetConfig.knowledgeTip": "Bilgi eklemek için “+” düğmesine tıklayın", "datasetConfig.params": "Parametreler", - "datasetConfig.rerankModelRequired": "Rerank modeli gereklidir", + "datasetConfig.rerankModelRequired": "Yeniden Sıralama modeli gereklidir", "datasetConfig.retrieveChangeTip": "Dizin modunu ve geri alım modunu değiştirmek, bu Bilgi ile ilişkili uygulamaları etkileyebilir.", "datasetConfig.retrieveMultiWay.description": "Kullanıcı niyetine dayanarak, tüm Bilgilerde sorgular, çoklu kaynaklardan ilgili metni alır ve yeniden sıraladıktan sonra kullanıcı sorgusuyla eşleşen en iyi sonuçları seçer.", "datasetConfig.retrieveMultiWay.title": "Çoklu yol geri alım", - "datasetConfig.retrieveOneWay.description": "Kullanıcı niyetine ve Bilgi tanımına dayanarak, Agent en iyi Bilgi'yi sorgulamak için özerk bir şekilde seçer. Belirgin, sınırlı Bilgi bulunan uygulamalar için en iyisidir.", + "datasetConfig.retrieveOneWay.description": "Kullanıcı niyetine ve Bilgi tanımına dayanarak, Ajan en iyi Bilgi'yi sorgulamak için özerk bir şekilde seçer. Belirgin, sınırlı Bilgi bulunan uygulamalar için en iyisidir.", "datasetConfig.retrieveOneWay.title": "N-to-1 geri alım", "datasetConfig.score_threshold": "Skor Eşiği", "datasetConfig.score_thresholdTip": "Parça filtreleme için benzerlik eşiğini ayarlamak için kullanılır.", @@ -235,11 +235,16 @@ "inputs.run": "ÇALIŞTIR", "inputs.title": "Hata ayıklama ve Önizleme", "inputs.userInputField": "Kullanıcı Giriş Alanı", + "manageModels": "Modelleri yönet", "modelConfig.modeType.chat": "Sohbet", "modelConfig.modeType.completion": "Tamamlama", "modelConfig.model": "Model", "modelConfig.setTone": "Yanıtların tonunu ayarla", "modelConfig.title": "Model ve Parametreler", + "noModelProviderConfigured": "Yapılandırılmış model sağlayıcı yok", + "noModelProviderConfiguredTip": "Başlamak için bir model sağlayıcı yükleyin veya yapılandırın.", + "noModelSelected": "Model seçilmedi", + "noModelSelectedTip": "Devam etmek için yukarıdan bir model yapılandırın.", "noResult": "Çıktı burada görüntülenecektir.", "notSetAPIKey.description": "LLM sağlayıcı anahtarı ayarlanmadı, hata ayıklamadan önce ayarlanması gerekiyor.", "notSetAPIKey.settingBtn": "Ayarlar'a git", @@ -267,12 +272,12 @@ "operation.resetConfig": "Sıfırla", "operation.stopResponding": "Yanıtlamayı Durdur", "operation.userAction": "Kullanıcı", - "orchestrate": "Orchestrate", + "orchestrate": "Düzenle", "otherError.historyNoBeEmpty": "Konuşma geçmişi prompt'ta ayarlanmalıdır", "otherError.promptNoBeEmpty": "Prompt boş olamaz", "otherError.queryNoBeEmpty": "Sorgu prompt'ta ayarlanmalıdır", "pageTitle.line1": "PROMPT", - "pageTitle.line2": "Engineering", + "pageTitle.line2": "Mühendisliği", "promptMode.advanced": "Uzman Modu", "promptMode.advancedWarning.description": "Uzman Modunda, tüm PROMPT'u düzenleyebilirsiniz.", "promptMode.advancedWarning.learnMore": "Daha Fazla Bilgi", @@ -320,7 +325,7 @@ "variableConfig.file.image.name": "Resim", "variableConfig.file.supportFileTypes": "Destek Dosya Türleri", "variableConfig.file.video.name": "Video", - "variableConfig.hide": "Gizlemek", + "variableConfig.hide": "Gizle", "variableConfig.inputPlaceholder": "Lütfen girin", "variableConfig.json": "JSON Kodu", "variableConfig.jsonSchema": "JSON Şeması", diff --git a/web/i18n/tr-TR/app-log.json b/web/i18n/tr-TR/app-log.json index 6596db8be8..da615534f4 100644 --- a/web/i18n/tr-TR/app-log.json +++ b/web/i18n/tr-TR/app-log.json @@ -1,6 +1,6 @@ { - "agentLog": "Agent Günlüğü", - "agentLogDetail.agentMode": "Agent Modu", + "agentLog": "Ajan Günlüğü", + "agentLogDetail.agentMode": "Ajan Modu", "agentLogDetail.finalProcessing": "Son İşleme", "agentLogDetail.iteration": "Yineleme", "agentLogDetail.iterations": "Yinelemeler", @@ -80,5 +80,5 @@ "triggerBy.webhook": "Webhook", "viewLog": "Günlüğü Görüntüle", "workflowSubtitle": "Günlük, Automate'in çalışmasını kaydetmiştir.", - "workflowTitle": "Workflow Günlükleri" + "workflowTitle": "İş Akışı Günlükleri" } diff --git a/web/i18n/tr-TR/app-overview.json b/web/i18n/tr-TR/app-overview.json index 50d9a27f9f..e48c6c6fe7 100644 --- a/web/i18n/tr-TR/app-overview.json +++ b/web/i18n/tr-TR/app-overview.json @@ -66,7 +66,7 @@ "overview.appInfo.preUseReminder": "Devam etmeden önce web app'i etkinleştirin.", "overview.appInfo.preview": "Önizleme", "overview.appInfo.qrcode.download": "QR Kodu İndir", - "overview.appInfo.qrcode.scan": "Paylaşmak İçin Taramak", + "overview.appInfo.qrcode.scan": "Paylaşmak İçin Tara", "overview.appInfo.qrcode.title": "Bağlantı QR Kodu", "overview.appInfo.regenerate": "Yeniden Oluştur", "overview.appInfo.regenerateNotice": "Genel URL'yi yeniden oluşturmak istiyor musunuz?", @@ -102,11 +102,11 @@ "overview.appInfo.settings.workflow.show": "Göster", "overview.appInfo.settings.workflow.showDesc": "web app'te iş akışı ayrıntılarını gösterme veya gizleme", "overview.appInfo.settings.workflow.subTitle": "İş Akışı Detayları", - "overview.appInfo.settings.workflow.title": "Workflow Adımları", + "overview.appInfo.settings.workflow.title": "İş Akışı Adımları", "overview.appInfo.title": "Web Uygulaması", "overview.disableTooltip.triggerMode": "Trigger Düğümü modunda {{feature}} özelliği desteklenmiyor.", "overview.status.disable": "Devre Dışı", - "overview.status.running": "Hizmette", + "overview.status.running": "Çalışıyor", "overview.title": "Genel Bakış", "overview.triggerInfo.explanation": "İş akışı tetikleyici yönetimi", "overview.triggerInfo.learnAboutTriggers": "Tetikleyiciler hakkında bilgi edinin", diff --git a/web/i18n/tr-TR/app.json b/web/i18n/tr-TR/app.json index af6c5bdcd9..bf17583c47 100644 --- a/web/i18n/tr-TR/app.json +++ b/web/i18n/tr-TR/app.json @@ -77,8 +77,8 @@ "gotoAnything.actions.themeLightDesc": "Aydınlık görünüm kullan", "gotoAnything.actions.themeSystem": "Sistem Teması", "gotoAnything.actions.themeSystemDesc": "İşletim sisteminizin görünümünü takip edin", - "gotoAnything.actions.zenDesc": "Toggle canvas focus mode", - "gotoAnything.actions.zenTitle": "Zen Mode", + "gotoAnything.actions.zenDesc": "Tuval odak modunu aç/kapat", + "gotoAnything.actions.zenTitle": "Zen Modu", "gotoAnything.clearToSearchAll": "Tümünü aramak için @ işaretini kaldırın", "gotoAnything.commandHint": "Kategoriye göre göz atmak için @ yazın", "gotoAnything.emptyState.noAppsFound": "Uygulama bulunamadı", @@ -129,11 +129,11 @@ "mermaid.classic": "Klasik", "mermaid.handDrawn": "Elle çizilmiş", "newApp.Cancel": "İptal", - "newApp.Confirm": "Onaylamak", + "newApp.Confirm": "Onayla", "newApp.Create": "Oluştur", "newApp.advancedShortDescription": "Çok turlu sohbetler için geliştirilmiş iş akışı", "newApp.advancedUserDescription": "Ek bellek özellikleri ve sohbet robotu arayüzü ile iş akışı.", - "newApp.agentAssistant": "Yeni Agent Asistanı", + "newApp.agentAssistant": "Yeni Ajan Asistanı", "newApp.agentShortDescription": "Akıl yürütme ve otonom araç kullanımına sahip akıllı ajan", "newApp.agentUserDescription": "Görev hedeflerine ulaşmak için yinelemeli akıl yürütme ve otonom araç kullanımı yeteneğine sahip akıllı bir ajan.", "newApp.appCreateDSLErrorPart1": "DSL sürümlerinde önemli bir fark tespit edildi. İçe aktarmayı zorlamak, uygulamanın hatalı çalışmasına neden olabilir.", @@ -161,12 +161,12 @@ "newApp.completionShortDescription": "Metin oluşturma görevleri için yapay zeka asistanı", "newApp.completionUserDescription": "Basit yapılandırmayla metin oluşturma görevleri için hızlı bir şekilde bir yapay zeka asistanı oluşturun.", "newApp.dropDSLToCreateApp": "Uygulama oluşturmak için DSL dosyasını buraya bırakın", - "newApp.forAdvanced": "İLERI DÜZEY KULLANICILAR IÇIN", + "newApp.forAdvanced": "İLERİ DÜZEY KULLANICILAR İÇİN", "newApp.forBeginners": "Daha temel uygulama türleri", "newApp.foundResult": "{{count}} Sonuç", - "newApp.foundResults": "{{count}} Sonuç -ları", + "newApp.foundResults": "{{count}} Sonuç", "newApp.hideTemplates": "Mod seçim ekranına geri dön", - "newApp.import": "Ithalat", + "newApp.import": "İçe Aktar", "newApp.learnMore": "Daha fazla bilgi edinin", "newApp.nameNotEmpty": "İsim boş olamaz", "newApp.noAppsFound": "Uygulama bulunamadı", @@ -182,7 +182,7 @@ "newApp.workflowShortDescription": "Akıllı otomasyonlar için ajantik akış", "newApp.workflowUserDescription": "Sürükle-bırak kolaylığıyla görsel olarak otonom yapay zeka iş akışları oluşturun.", "newApp.workflowWarning": "Şu anda beta aşamasında", - "newAppFromTemplate.byCategories": "KATEGORILERE GÖRE", + "newAppFromTemplate.byCategories": "KATEGORİLERE GÖRE", "newAppFromTemplate.searchAllTemplate": "Tüm şablonlarda ara...", "newAppFromTemplate.sidebar.Agent": "Aracı", "newAppFromTemplate.sidebar.Assistant": "Asistan", @@ -210,12 +210,12 @@ "structOutput.required": "Gerekli", "structOutput.structured": "Yapılandırılmış", "structOutput.structuredTip": "Yapılandırılmış Çıktılar, modelin sağladığınız JSON Şemasına uyacak şekilde her zaman yanıtlar üretmesini sağlayan bir özelliktir.", - "switch": "Workflow Orkestrasyonuna Geç", + "switch": "İş Akışı Orkestrasyonuna Geç", "switchLabel": "Oluşturulacak uygulama kopyası", "switchStart": "Geçişi Başlat", "switchTip": "izin vermeyecek", "switchTipEnd": " Temel Orkestrasyona geri dönmek.", - "switchTipStart": "Sizin için yeni bir uygulama kopyası oluşturulacak ve yeni kopya Workflow Orkestrasyonuna geçecektir. Yeni kopya ", + "switchTipStart": "Sizin için yeni bir uygulama kopyası oluşturulacak ve yeni kopya İş Akışı Orkestrasyonuna geçecektir. Yeni kopya ", "theme.switchDark": "Koyu tema geçiş yap", "theme.switchLight": "Aydınlık tema'ya geç", "tracing.aliyun.description": "Alibaba Cloud tarafından sağlanan tamamen yönetilen ve bakım gerektirmeyen gözlemleme platformu, Dify uygulamalarının kutudan çıkar çıkmaz izlenmesi, takip edilmesi ve değerlendirilmesine olanak tanır.", @@ -248,7 +248,7 @@ "tracing.description": "Üçüncü taraf LLMOps sağlayıcısını yapılandırma ve uygulama performansını izleme.", "tracing.disabled": "Devre Dışı", "tracing.disabledTip": "Lütfen önce sağlayıcıyı yapılandırın", - "tracing.enabled": "Hizmette", + "tracing.enabled": "Etkin", "tracing.expand": "Genişlet", "tracing.inUse": "Kullanımda", "tracing.langfuse.description": "LLM uygulamanızı hata ayıklamak ve geliştirmek için izlemeler, değerlendirmeler, prompt yönetimi ve metrikler.", @@ -258,7 +258,7 @@ "tracing.mlflow.description": "Deney takibi, gözlemlenebilirlik ve değerlendirme için açık kaynaklı LLMOps platformu, AI/LLM uygulamalarını güvenle oluşturmak için.", "tracing.mlflow.title": "MLflow", "tracing.opik.description": "Opik, LLM uygulamalarını değerlendirmek, test etmek ve izlemek için açık kaynaklı bir platformdur.", - "tracing.opik.title": "Opik Belediyesi", + "tracing.opik.title": "Opik", "tracing.phoenix.description": "LLM iş akışlarınız ve ajanlarınız için açık kaynaklı ve OpenTelemetry tabanlı gözlemlenebilirlik, değerlendirme, istem mühendisliği ve deney platformu.", "tracing.phoenix.title": "Phoenix", "tracing.tencent.description": "Tencent Uygulama Performans İzleme, LLM uygulamaları için kapsamlı izleme ve çok boyutlu analiz sağlar.", @@ -266,20 +266,20 @@ "tracing.title": "Uygulama performansını izleme", "tracing.tracing": "İzleme", "tracing.tracingDescription": "Uygulama yürütmesinin tam bağlamını, LLM çağrıları, bağlam, promptlar, HTTP istekleri ve daha fazlası dahil olmak üzere üçüncü taraf izleme platformuna yakalama.", - "tracing.view": "Görünüm", + "tracing.view": "Görüntüle", "tracing.weave.description": "Weave, LLM uygulamalarını değerlendirmek, test etmek ve izlemek için açık kaynaklı bir platformdur.", - "tracing.weave.title": "Dokuma", + "tracing.weave.title": "Weave", "typeSelector.advanced": "Sohbet akışı", - "typeSelector.agent": "Agent", - "typeSelector.all": "All Types", + "typeSelector.agent": "Ajan", + "typeSelector.all": "Tüm Türler", "typeSelector.chatbot": "Chatbot", - "typeSelector.completion": "Completion", - "typeSelector.workflow": "Workflow", + "typeSelector.completion": "Tamamlama", + "typeSelector.workflow": "İş Akışı", "types.advanced": "Sohbet akışı", - "types.agent": "Agent", + "types.agent": "Ajan", "types.all": "Hepsi", "types.basic": "Temel", "types.chatbot": "Chatbot", "types.completion": "Tamamlama", - "types.workflow": "Workflow" + "types.workflow": "İş Akışı" } diff --git a/web/i18n/tr-TR/billing.json b/web/i18n/tr-TR/billing.json index b780045768..2d85bed3b3 100644 --- a/web/i18n/tr-TR/billing.json +++ b/web/i18n/tr-TR/billing.json @@ -88,7 +88,7 @@ "plansCommon.documentsTooltip": "Bilgi Veri Kaynağından ithal edilen belge sayısına kota.", "plansCommon.free": "Ücretsiz", "plansCommon.freeTrialTip": "200 OpenAI çağrısının ücretsiz denemesi.", - "plansCommon.freeTrialTipPrefix": "Kaydolun ve bir", + "plansCommon.freeTrialTipPrefix": "Kaydolun ve bir ", "plansCommon.freeTrialTipSuffix": "Kredi kartı gerekmez", "plansCommon.getStarted": "Başlayın", "plansCommon.logsHistory": "{{days}} günlük geçmişi", @@ -97,7 +97,7 @@ "plansCommon.messageRequest.title": "{{count,number}} mesaj kredisi", "plansCommon.messageRequest.titlePerMonth": "{{count,number}} mesaj/ay", "plansCommon.messageRequest.tooltip": "OpenAI modellerini (gpt4 hariç) kullanarak çeşitli planlar için mesaj çağrı kotaları. Limitin üzerindeki mesajlar OpenAI API Anahtarınızı kullanır.", - "plansCommon.modelProviders": "Model Sağlayıcılar", + "plansCommon.modelProviders": "OpenAI/Anthropic/Llama2/Azure OpenAI/Hugging Face/Replicate Desteği", "plansCommon.month": "ay", "plansCommon.mostPopular": "En Popüler", "plansCommon.planRange.monthly": "Aylık", @@ -116,7 +116,7 @@ "plansCommon.startNodes.unlimited": "Sınırsız Tetikleyiciler/iş akışı", "plansCommon.support": "Destek", "plansCommon.supportItems.SSOAuthentication": "SSO kimlik doğrulama", - "plansCommon.supportItems.agentMode": "Agent Modu", + "plansCommon.supportItems.agentMode": "Ajan Modu", "plansCommon.supportItems.bulkUpload": "Toplu doküman yükleme", "plansCommon.supportItems.communityForums": "Topluluk forumları", "plansCommon.supportItems.customIntegration": "Özel entegrasyon ve destek", @@ -128,7 +128,7 @@ "plansCommon.supportItems.personalizedSupport": "Kişiselleştirilmiş destek", "plansCommon.supportItems.priorityEmail": "Öncelikli e-posta ve sohbet desteği", "plansCommon.supportItems.ragAPIRequest": "RAG API Talepleri", - "plansCommon.supportItems.workflow": "Workflow", + "plansCommon.supportItems.workflow": "İş Akışı", "plansCommon.talkToSales": "Satışlarla Konuşun", "plansCommon.taxTip": "Tüm abonelik fiyatları (aylık/yıllık) geçerli vergiler (ör. KDV, satış vergisi) hariçtir.", "plansCommon.taxTipSecond": "Bölgenizde geçerli vergi gereksinimleri yoksa, ödeme sayfanızda herhangi bir vergi görünmeyecek ve tüm abonelik süresi boyunca ek bir ücret tahsil edilmeyecektir.", diff --git a/web/i18n/tr-TR/common.json b/web/i18n/tr-TR/common.json index 9aeb24cd1e..66e895fd2b 100644 --- a/web/i18n/tr-TR/common.json +++ b/web/i18n/tr-TR/common.json @@ -96,7 +96,7 @@ "appMenus.logAndAnn": "Günlükler & Anlamlandırmalar", "appMenus.logs": "Günlükler", "appMenus.overview": "İzleme", - "appMenus.promptEng": "Orchestrate", + "appMenus.promptEng": "Düzenle", "appModes.chatApp": "Sohbet Uygulaması", "appModes.completionApp": "Metin Üreteci", "avatar.deleteDescription": "Profil resminizi kaldırmak istediğinize emin misiniz? Hesabınız varsayılan başlangıç avatarını kullanacaktır.", @@ -114,7 +114,7 @@ "chat.inputPlaceholder": "{{botName}} ile konuş", "chat.renameConversation": "Konuşmayı Yeniden Adlandır", "chat.resend": "Yeniden gönder", - "chat.thinking": "Düşünü...", + "chat.thinking": "Düşünüyor...", "chat.thought": "Düşünce", "compliance.gdpr": "GDPR DPA", "compliance.iso27001": "ISO 27001:2022 Sertifikası", @@ -178,7 +178,7 @@ "fileUploader.uploadFromComputerLimit": "{{type}} yüklemesi {{size}}'ı aşamaz", "fileUploader.uploadFromComputerReadError": "Dosya okuma başarısız oldu, lütfen tekrar deneyin.", "fileUploader.uploadFromComputerUploadError": "Dosya yükleme başarısız oldu, lütfen tekrar yükleyin.", - "imageInput.browse": "tarayıcı", + "imageInput.browse": "göz atın", "imageInput.dropImageHere": "Görüntünüzü buraya bırakın veya", "imageInput.supportedFormats": "PNG, JPG, JPEG, WEBP ve GIF'i destekler", "imageUploader.imageUpload": "Görüntü Yükleme", @@ -265,7 +265,7 @@ "menus.datasets": "Bilgi", "menus.datasetsTips": "YAKINDA: Kendi metin verilerinizi içe aktarın veya LLM bağlamını geliştirmek için Webhook aracılığıyla gerçek zamanlı veri yazın.", "menus.explore": "Keşfet", - "menus.exploreMarketplace": "Marketplace'i Keşfedin", + "menus.exploreMarketplace": "Pazar Yeri'ni Keşfedin", "menus.newApp": "Yeni Uygulama", "menus.newDataset": "Bilgi Oluştur", "menus.plugins": "Eklentiler", @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Yetkisiz", "modelProvider.buyQuota": "Kota Satın Al", "modelProvider.callTimes": "Çağrı Süreleri", + "modelProvider.card.aiCreditsInUse": "Yapay zeka kredileri kullanımda", + "modelProvider.card.aiCreditsOption": "Yapay zeka kredileri", + "modelProvider.card.apiKeyOption": "API Anahtarı", + "modelProvider.card.apiKeyRequired": "API anahtarı gerekli", + "modelProvider.card.apiKeyUnavailableFallback": "API Anahtarı kullanılamıyor, şimdi yapay zeka kredileri kullanılıyor", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Geri dönmek için API anahtarı yapılandırmanızı kontrol edin", "modelProvider.card.buyQuota": "Kota Satın Al", "modelProvider.card.callTimes": "Çağrı Süreleri", + "modelProvider.card.creditsExhaustedDescription": "Lütfen planınızı yükseltin veya bir API anahtarı yapılandırın", + "modelProvider.card.creditsExhaustedFallback": "Yapay zeka kredileri tükendi, şimdi API anahtarı kullanılıyor", + "modelProvider.card.creditsExhaustedFallbackDescription": "Yapay zeka kredisi önceliğini sürdürmek için planınızı yükseltin.", + "modelProvider.card.creditsExhaustedMessage": "Yapay zeka kredileri tükendi", "modelProvider.card.modelAPI": "{{modelName}} modelleri API Anahtarını kullanıyor.", "modelProvider.card.modelNotSupported": "{{modelName}} modelleri kurulu değil.", "modelProvider.card.modelSupported": "{{modelName}} modelleri bu kotayı kullanıyor.", + "modelProvider.card.noApiKeysDescription": "Kendi model kimlik bilgilerinizi kullanmaya başlamak için bir API anahtarı ekleyin.", + "modelProvider.card.noApiKeysFallback": "API anahtarı yok, bunun yerine yapay zeka kredileri kullanılıyor", + "modelProvider.card.noApiKeysTitle": "Henüz API anahtarı yapılandırılmadı", + "modelProvider.card.noAvailableUsage": "Kullanılabilir kullanım yok", "modelProvider.card.onTrial": "Deneme Sürümünde", "modelProvider.card.paid": "Ücretli", "modelProvider.card.priorityUse": "Öncelikli Kullan", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "API Anahtarını Kaldır", "modelProvider.card.tip": "Mesaj kredileri {{modelNames}}'den modelleri destekler. Öncelik ücretli kotaya verilecektir. Ücretsiz kota, ücretli kota tükendiğinde kullanılacaktır.", "modelProvider.card.tokens": "Tokenler", + "modelProvider.card.unavailable": "Kullanılamaz", + "modelProvider.card.upgradePlan": "planınızı yükseltin", + "modelProvider.card.usageLabel": "Kullanım", + "modelProvider.card.usagePriority": "Kullanım Önceliği", + "modelProvider.card.usagePriorityTip": "Modelleri çalıştırırken önce hangi kaynağın kullanılacağını belirleyin.", "modelProvider.collapse": "Daralt", "modelProvider.config": "Yapılandır", "modelProvider.configLoadBalancing": "Yük Dengelemeyi Yapılandır", @@ -387,9 +406,11 @@ "modelProvider.model": "Model", "modelProvider.modelAndParameters": "Model ve Parametreler", "modelProvider.modelHasBeenDeprecated": "Bu model kullanım dışıdır", + "modelProvider.modelSettings": "Model Ayarları", "modelProvider.models": "Modeller", "modelProvider.modelsNum": "{{num}} Model", "modelProvider.noModelFound": "{{model}} için model bulunamadı", + "modelProvider.noneConfigured": "Uygulamaları çalıştırmak için varsayılan bir sistem modeli yapılandırın", "modelProvider.notConfigured": "Sistem modeli henüz tam olarak yapılandırılmadı ve bazı işlevler kullanılamayabilir.", "modelProvider.parameters": "PARAMETRELER", "modelProvider.parametersInvalidRemoved": "Bazı parametreler geçersizdir ve kaldırılmıştır.", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "{{date}} tarihinde sıfırla", "modelProvider.searchModel": "Model ara", "modelProvider.selectModel": "Modelinizi seçin", + "modelProvider.selector.aiCredits": "Yapay zeka kredileri", + "modelProvider.selector.apiKeyUnavailable": "API Anahtarı kullanılamıyor", + "modelProvider.selector.apiKeyUnavailableTip": "API anahtarı kaldırıldı. Lütfen yeni bir API anahtarı yapılandırın.", + "modelProvider.selector.configure": "Yapılandır", + "modelProvider.selector.configureRequired": "Yapılandırma gerekli", + "modelProvider.selector.creditsExhausted": "Krediler tükendi", + "modelProvider.selector.creditsExhaustedTip": "Yapay zeka kredileriniz tükendi. Lütfen planınızı yükseltin veya bir API anahtarı ekleyin.", + "modelProvider.selector.disabled": "Devre Dışı", + "modelProvider.selector.discoverMoreInMarketplace": "Pazar Yeri'nde daha fazlasını keşfedin", "modelProvider.selector.emptySetting": "Lütfen ayarlara gidip yapılandırın", "modelProvider.selector.emptyTip": "Kullanılabilir model yok", + "modelProvider.selector.fromMarketplace": "Pazar Yeri'nden", + "modelProvider.selector.incompatible": "Uyumsuz", + "modelProvider.selector.incompatibleTip": "Bu model mevcut sürümde kullanılamıyor. Lütfen başka bir kullanılabilir model seçin.", + "modelProvider.selector.install": "Yükle", + "modelProvider.selector.modelProviderSettings": "Model Sağlayıcı Ayarları", + "modelProvider.selector.noProviderConfigured": "Yapılandırılmış model sağlayıcı yok", + "modelProvider.selector.noProviderConfiguredDesc": "Yüklemek için Pazar Yeri'ne göz atın veya ayarlardan sağlayıcıları yapılandırın.", + "modelProvider.selector.onlyCompatibleModelsShown": "Yalnızca uyumlu modeller gösterilir", "modelProvider.selector.rerankTip": "Lütfen Yeniden Sıralama modelini ayarlayın", "modelProvider.selector.tip": "Bu model kaldırıldı. Lütfen bir model ekleyin veya başka bir model seçin.", "modelProvider.setupModelFirst": "Lütfen önce modelinizi ayarlayın", @@ -427,11 +465,11 @@ "operation.cancel": "İptal", "operation.change": "Değiştir", "operation.clear": "Temizle", - "operation.close": "Kapatmak", - "operation.config": "Konfigürasyon", + "operation.close": "Kapat", + "operation.config": "Yapılandırma", "operation.confirm": "Onayla", "operation.confirmAction": "Lütfen işleminizi onaylayın.", - "operation.copied": "Kopya -lanan", + "operation.copied": "Kopyalandı", "operation.copy": "Kopyala", "operation.copyImage": "Resmi Kopyala", "operation.create": "Oluştur", @@ -463,7 +501,7 @@ "operation.openInNewTab": "Yeni sekmede aç", "operation.params": "Parametreler", "operation.refresh": "Yeniden Başlat", - "operation.regenerate": "Yenilemek", + "operation.regenerate": "Yeniden Oluştur", "operation.reload": "Yeniden Yükle", "operation.remove": "Kaldır", "operation.rename": "Yeniden Adlandır", @@ -480,10 +518,10 @@ "operation.send": "Gönder", "operation.settings": "Ayarlar", "operation.setup": "Kurulum", - "operation.skip": "Gemi", + "operation.skip": "Atla", "operation.submit": "Gönder", "operation.sure": "Eminim", - "operation.view": "Görünüm", + "operation.view": "Görüntüle", "operation.viewDetails": "Detayları Görüntüle", "operation.viewMore": "DAHA FAZLA GÖSTER", "operation.yes": "Evet", @@ -500,7 +538,7 @@ "promptEditor.context.item.title": "Bağlam", "promptEditor.context.modal.add": "Bağlam Ekle", "promptEditor.context.modal.footer": "Bağlamları aşağıdaki Bağlam bölümünde yönetebilirsiniz.", - "promptEditor.context.modal.title": "Bağlamda {{num}} Knowledge", + "promptEditor.context.modal.title": "Bağlamda {{num}} Bilgi", "promptEditor.existed": "Zaten prompt içinde mevcut", "promptEditor.history.item.desc": "Tarihi mesaj şablonunu ekle", "promptEditor.history.item.title": "Konuşma Geçmişi", @@ -585,7 +623,7 @@ "tag.selectorPlaceholder": "Aramak veya oluşturmak için yazın", "theme.auto": "sistem", "theme.dark": "koyu", - "theme.light": "ışık", + "theme.light": "açık", "theme.theme": "Tema", "toast.close": "Bildirimi kapat", "toast.notifications": "Bildirimler", @@ -605,27 +643,27 @@ "userProfile.support": "Destek", "userProfile.workspace": "Çalışma Alanı", "voice.language.arTN": "Tunus Arapçası", - "voice.language.deDE": "German", - "voice.language.enUS": "English", - "voice.language.esES": "Spanish", + "voice.language.deDE": "Almanca", + "voice.language.enUS": "İngilizce", + "voice.language.esES": "İspanyolca", "voice.language.faIR": "Farsça", - "voice.language.frFR": "French", + "voice.language.frFR": "Fransızca", "voice.language.hiIN": "Hintçe", - "voice.language.idID": "Indonesian", - "voice.language.itIT": "Italian", - "voice.language.jaJP": "Japanese", - "voice.language.koKR": "Korean", - "voice.language.plPL": "Polish", - "voice.language.ptBR": "Portuguese", + "voice.language.idID": "Endonezyaca", + "voice.language.itIT": "İtalyanca", + "voice.language.jaJP": "Japonca", + "voice.language.koKR": "Korece", + "voice.language.plPL": "Lehçe", + "voice.language.ptBR": "Portekizce", "voice.language.roRO": "Romence", - "voice.language.ruRU": "Russian", + "voice.language.ruRU": "Rusça", "voice.language.slSI": "Slovence", - "voice.language.thTH": "Thai", + "voice.language.thTH": "Tayca", "voice.language.trTR": "Türkçe", - "voice.language.ukUA": "Ukrainian", - "voice.language.viVN": "Vietnamese", - "voice.language.zhHans": "Chinese", - "voice.language.zhHant": "Traditional Chinese", + "voice.language.ukUA": "Ukraynaca", + "voice.language.viVN": "Vietnamca", + "voice.language.zhHans": "Çince", + "voice.language.zhHant": "Geleneksel Çince", "voiceInput.converting": "Metne dönüştürülüyor...", "voiceInput.notAllow": "mikrofon yetkilendirilmedi", "voiceInput.speaking": "Şimdi konuş...", diff --git a/web/i18n/tr-TR/dataset-creation.json b/web/i18n/tr-TR/dataset-creation.json index ab409664a2..81f09945c2 100644 --- a/web/i18n/tr-TR/dataset-creation.json +++ b/web/i18n/tr-TR/dataset-creation.json @@ -142,8 +142,8 @@ "stepTwo.previewChunk": "Önizleme Parçası", "stepTwo.previewChunkCount": "{{count}} Tahmini parçalar", "stepTwo.previewChunkTip": "Önizlemeyi yüklemek için soldaki 'Önizleme Parçası' düğmesini tıklayın", - "stepTwo.previewSwitchTipEnd": "token", - "stepTwo.previewSwitchTipStart": "Geçerli parça önizlemesi metin formatındadır, soru ve yanıt formatına geçiş ek tüketir", + "stepTwo.previewSwitchTipEnd": " token tüketecektir", + "stepTwo.previewSwitchTipStart": "Geçerli parça önizlemesi metin formatındadır, soru ve yanıt formatı önizlemesine geçiş ek", "stepTwo.previewTitle": "Önizleme", "stepTwo.previewTitleButton": "Önizleme", "stepTwo.previousStep": "Önceki adım", diff --git a/web/i18n/tr-TR/dataset-documents.json b/web/i18n/tr-TR/dataset-documents.json index 461ebe6d6d..c4516eaf22 100644 --- a/web/i18n/tr-TR/dataset-documents.json +++ b/web/i18n/tr-TR/dataset-documents.json @@ -300,12 +300,12 @@ "segment.collapseChunks": "Parçaları daraltma", "segment.contentEmpty": "İçerik boş olamaz", "segment.contentPlaceholder": "içeriği buraya ekleyin", - "segment.dateTimeFormat": "MM/DD/YYYY HH:mm", + "segment.dateTimeFormat": "DD/MM/YYYY HH:mm", "segment.delete": "Bu parçayı silmek istiyor musunuz?", "segment.editChildChunk": "Alt Parçayı Düzenle", "segment.editChunk": "Yığını Düzenle", "segment.editParentChunk": "Üst Parçayı Düzenle", - "segment.edited": "DÜZENLEN -MİŞ", + "segment.edited": "DÜZENLENMİŞ", "segment.editedAt": "Şurada düzenlendi:", "segment.empty": "Yığın bulunamadı", "segment.expandChunks": "Parçaları genişletme", @@ -331,7 +331,7 @@ "segment.regenerationSuccessMessage": "Bu pencereyi kapatabilirsiniz.", "segment.regenerationSuccessTitle": "Rejenerasyon tamamlandı", "segment.searchResults_one": "SONUÇ", - "segment.searchResults_other": "SONUÇ -LARI", + "segment.searchResults_other": "SONUÇLAR", "segment.searchResults_zero": "SONUÇ", "segment.summary": "ÖZET", "segment.summaryPlaceholder": "Daha iyi arama için kısa bir özet yazın…", diff --git a/web/i18n/tr-TR/dataset-hit-testing.json b/web/i18n/tr-TR/dataset-hit-testing.json index da09ffb03c..c7687a1ace 100644 --- a/web/i18n/tr-TR/dataset-hit-testing.json +++ b/web/i18n/tr-TR/dataset-hit-testing.json @@ -14,7 +14,7 @@ "input.placeholder": "Bir metin girin, kısa bir bildirim cümlesi önerilir.", "input.testing": "Test Ediliyor", "input.title": "Kaynak metin", - "keyword": "Anahtar kelime -ler", + "keyword": "Anahtar Kelimeler", "noRecentTip": "Burada son sorgu sonuçları yok", "open": "Açık", "records": "Kayıt", diff --git a/web/i18n/tr-TR/dataset-pipeline.json b/web/i18n/tr-TR/dataset-pipeline.json index fe48dcd7bb..3c617f3a27 100644 --- a/web/i18n/tr-TR/dataset-pipeline.json +++ b/web/i18n/tr-TR/dataset-pipeline.json @@ -67,8 +67,8 @@ "onlineDrive.notSupportedFileType": "Bu dosya türü desteklenmiyor", "onlineDrive.resetKeywords": "Anahtar kelimeleri sıfırlama", "operations.backToDataSource": "Veri Kaynağına Geri Dön", - "operations.choose": "Seçmek", - "operations.convert": "Dönüştürmek", + "operations.choose": "Seç", + "operations.convert": "Dönüştür", "operations.dataSource": "Veri Kaynağı", "operations.details": "Şey", "operations.editInfo": "Bilgileri düzenle", @@ -85,7 +85,7 @@ "publishTemplate.success.learnMore": "Daha fazla bilgi edinin", "publishTemplate.success.message": "İşlem hattı şablonu yayımlandı", "publishTemplate.success.tip": "Bu şablonu oluşturma sayfasında kullanabilirsiniz.", - "templates.customized": "Özel -leştirilmiş", + "templates.customized": "Özelleştirilmiş", "testRun.dataSource.localFiles": "Yerel Dosyalar", "testRun.notion.docTitle": "Kavram belgeleri", "testRun.notion.title": "Notion Sayfalarını Seçin", diff --git a/web/i18n/tr-TR/dataset.json b/web/i18n/tr-TR/dataset.json index 842fb7491b..f11fc9387c 100644 --- a/web/i18n/tr-TR/dataset.json +++ b/web/i18n/tr-TR/dataset.json @@ -1,14 +1,14 @@ { - "allExternalTip": "Yalnızca harici bilgileri kullanırken, kullanıcı Rerank modelinin etkinleştirilip etkinleştirilmeyeceğini seçebilir. Etkinleştirilmezse, alınan parçalar puanlara göre sıralanır. Farklı bilgi tabanlarının erişim stratejileri tutarsız olduğunda, yanlış olacaktır.", + "allExternalTip": "Yalnızca harici bilgileri kullanırken, kullanıcı Yeniden Sıralama modelinin etkinleştirilip etkinleştirilmeyeceğini seçebilir. Etkinleştirilmezse, alınan parçalar puanlara göre sıralanır. Farklı bilgi tabanlarının erişim stratejileri tutarsız olduğunda, yanlış olacaktır.", "allKnowledge": "Tüm Bilgiler", "allKnowledgeDescription": "Bu çalışma alanındaki tüm bilgileri görüntülemek için seçin. Yalnızca Çalışma Alanı Sahibi tüm bilgileri yönetebilir.", "appCount": " bağlı uygulamalar", "batchAction.archive": "Arşiv", "batchAction.cancel": "İptal", - "batchAction.delete": "Silmek", - "batchAction.disable": "Devre dışı bırakmak", + "batchAction.delete": "Sil", + "batchAction.disable": "Devre Dışı Bırak", "batchAction.download": "İndir", - "batchAction.enable": "Etkinleştirmek", + "batchAction.enable": "Etkinleştir", "batchAction.reIndex": "Yeniden dizinle", "batchAction.selected": "Seçilmiş", "chunkingMode.general": "Genel", @@ -32,7 +32,7 @@ "createDatasetIntro": "Kendi metin verilerinizi içe aktarın veya Webhook aracılığıyla gerçek zamanlı olarak veri yazın, LLM bağlamını geliştirin.", "createExternalAPI": "Harici bilgi API'si ekleme", "createFromPipeline": "Bilgi İşlem Hattından Oluşturun", - "createNewExternalAPI": "Yeni bir External Knowledge API oluşturma", + "createNewExternalAPI": "Yeni bir Harici Bilgi API'si oluşturma", "datasetDeleteFailed": "Bilgi silinemedi", "datasetDeleted": "Bilgi silindi", "datasetUsedByApp": "Bilgi bazı uygulamalar tarafından kullanılıyor. Uygulamalar artık bu Bilgiyi kullanamayacak ve tüm prompt yapılandırmaları ve günlükler kalıcı olarak silinecektir.", @@ -45,7 +45,7 @@ "deleteExternalAPIConfirmWarningContent.content.front": "Bu Harici Bilgi API'si aşağıdakilerle bağlantılıdır", "deleteExternalAPIConfirmWarningContent.noConnectionContent": "Bu API'yi sildiğinizden emin misiniz?", "deleteExternalAPIConfirmWarningContent.title.end": "?", - "deleteExternalAPIConfirmWarningContent.title.front": "Silmek", + "deleteExternalAPIConfirmWarningContent.title.front": "Sil", "didYouKnow": "Biliyor muydunuz?", "docAllEnabled_one": "{{count}} belgesi etkinleştirildi", "docAllEnabled_other": "Tüm {{count}} belgeleri etkinleştirildi", @@ -54,29 +54,29 @@ "documentsDisabled": "{{num}} belge devre dışı - 30 günden uzun süre etkin değil", "editExternalAPIConfirmWarningContent.end": "Dışsal bilgi ve bu değişiklik hepsine uygulanacaktır. Bu değişikliği kaydetmek istediğinizden emin misiniz?", "editExternalAPIConfirmWarningContent.front": "Bu Harici Bilgi API'si aşağıdakilerle bağlantılıdır", - "editExternalAPIFormTitle": "External Knowledge API'yi düzenleme", + "editExternalAPIFormTitle": "Harici Bilgi API'sini düzenleme", "editExternalAPIFormWarning.end": "Dış bilgi", "editExternalAPIFormWarning.front": "Bu Harici API aşağıdakilere bağlıdır:", - "editExternalAPITooltipTitle": "BAĞLANTILI BILGI", + "editExternalAPITooltipTitle": "BAĞLANTILI BİLGİ", "embeddingModelNotAvailable": "Gömme modeli mevcut değil.", - "enable": "Etkinleştirmek", + "enable": "Etkinleştir", "externalAPI": "Harici API", "externalAPIForm.apiKey": "API Anahtarı", "externalAPIForm.cancel": "İptal", - "externalAPIForm.edit": "Düzenlemek", + "externalAPIForm.edit": "Düzenle", "externalAPIForm.encrypted.end": "Teknoloji.", "externalAPIForm.encrypted.front": "API Token'ınız kullanılarak şifrelenecek ve saklanacaktır.", "externalAPIForm.endpoint": "API Uç Noktası", "externalAPIForm.name": "Ad", - "externalAPIForm.save": "Kurtarmak", + "externalAPIForm.save": "Kaydet", "externalAPIPanelDescription": "Harici bilgi API'si, Dify dışındaki bir bilgi bankasına bağlanmak ve bu bilgi bankasından bilgi almak için kullanılır.", - "externalAPIPanelDocumentation": "External Knowledge API'nin nasıl oluşturulacağını öğrenin", + "externalAPIPanelDocumentation": "Harici Bilgi API'sinin nasıl oluşturulacağını öğrenin", "externalAPIPanelTitle": "Harici Bilgi API'si", "externalKnowledgeBase": "Harici Bilgi Bankası", "externalKnowledgeDescription": "Bilgi Açıklaması", "externalKnowledgeDescriptionPlaceholder": "Bu Bilgi Bankası'nda neler olduğunu açıklayın (isteğe bağlı)", "externalKnowledgeForm.cancel": "İptal", - "externalKnowledgeForm.connect": "Bağlamak", + "externalKnowledgeForm.connect": "Bağla", "externalKnowledgeForm.connectedFailed": "Harici Bilgi Tabanına bağlanılamadı", "externalKnowledgeForm.connectedSuccess": "Harici Bilgi Tabanı başarıyla bağlandı", "externalKnowledgeId": "Harici Bilgi Kimliği", @@ -126,7 +126,7 @@ "metadata.datasetMetadata.deleteContent": "Bu {{name}} meta verisini silmek istediğinizden emin misiniz?", "metadata.datasetMetadata.deleteTitle": "Silmek için onayla", "metadata.datasetMetadata.description": "Bu bilgideki tüm meta verileri yönetebilirsiniz. Değişiklikler her belgeye senkronize edilecektir.", - "metadata.datasetMetadata.disabled": "Devre dışı bırakıldı.", + "metadata.datasetMetadata.disabled": "Devre Dışı", "metadata.datasetMetadata.name": "İsim", "metadata.datasetMetadata.namePlaceholder": "Meta veri adı", "metadata.datasetMetadata.rename": "Yeniden Adlandır", @@ -140,8 +140,8 @@ "metadata.selectMetadata.newAction": "Yeni Veriler", "metadata.selectMetadata.search": "Arama meta verileri", "mixtureHighQualityAndEconomicTip": "Yüksek kaliteli ve ekonomik bilgi tabanlarının karışımı için Yeniden Sıralama modeli gereklidir.", - "mixtureInternalAndExternalTip": "Rerank modeli, iç ve dış bilgilerin karışımı için gereklidir.", - "multimodal": "Multimodal", + "mixtureInternalAndExternalTip": "Yeniden Sıralama modeli, iç ve dış bilgilerin karışımı için gereklidir.", + "multimodal": "Çok Modlu", "nTo1RetrievalLegacy": "Geri alım stratejisinin optimizasyonu ve yükseltilmesi nedeniyle, N-to-1 geri alımı Eylül ayında resmi olarak kullanım dışı kalacaktır. O zamana kadar normal şekilde kullanabilirsiniz.", "nTo1RetrievalLegacyLink": "Daha fazla bilgi edin", "nTo1RetrievalLegacyLinkText": "N-1 geri alma Eylül ayında resmi olarak kullanımdan kaldırılacaktır.", @@ -172,12 +172,12 @@ "serviceApi.card.apiReference": "API Referansı", "serviceApi.card.endpoint": "Hizmet API Uç Noktası", "serviceApi.card.title": "Backend servis api", - "serviceApi.disabled": "Engelli", - "serviceApi.enabled": "Hizmette", + "serviceApi.disabled": "Devre Dışı", + "serviceApi.enabled": "Etkin", "serviceApi.title": "Servis API'si", "unavailable": "Kullanılamıyor", "unknownError": "Bilinmeyen hata", - "updated": "Güncel -leştirilmiş", + "updated": "Güncellendi", "weightedScore.customized": "Özelleştirilmiş", "weightedScore.description": "Verilen ağırlıkları ayarlayarak bu yeniden sıralama stratejisi, anlamsal mı yoksa anahtar kelime eşleştirmesini mi önceliklendireceğini belirler.", "weightedScore.keyword": "Anahtar Kelime", diff --git a/web/i18n/tr-TR/login.json b/web/i18n/tr-TR/login.json index 94b08bc971..b30b7b9240 100644 --- a/web/i18n/tr-TR/login.json +++ b/web/i18n/tr-TR/login.json @@ -21,7 +21,7 @@ "checkCode.validTime": "Kodun 5 dakika boyunca geçerli olduğunu unutmayın", "checkCode.verificationCode": "Doğrulama kodu", "checkCode.verificationCodePlaceholder": "6 haneli kodu girin", - "checkCode.verify": "Doğrulamak", + "checkCode.verify": "Doğrula", "checkEmailForResetLink": "Şifrenizi sıfırlamak için bir bağlantı içeren e-postayı kontrol edin. Birkaç dakika içinde görünmezse, spam klasörünüzü kontrol ettiğinizden emin olun.", "confirmPassword": "Şifreyi Onayla", "confirmPasswordPlaceholder": "Yeni şifrenizi onaylayın", @@ -57,8 +57,8 @@ "invitationCode": "Davet Kodu", "invitationCodePlaceholder": "Davet kodunuz", "join": "Katıl", - "joinTipEnd": "takımına davet ediyor", - "joinTipStart": "Sizi", + "joinTipEnd": " takımına Dify'de davet ediyor", + "joinTipStart": "Sizi ", "license.link": "Açık Kaynak Lisansını", "license.tip": "Dify Community Edition'ı başlatmadan önce GitHub'daki", "licenseExpired": "Lisansın Süresi Doldu", diff --git a/web/i18n/tr-TR/pipeline.json b/web/i18n/tr-TR/pipeline.json index 371bc7973b..fbbc2300dc 100644 --- a/web/i18n/tr-TR/pipeline.json +++ b/web/i18n/tr-TR/pipeline.json @@ -3,7 +3,7 @@ "common.confirmPublishContent": "Bilgi işlem hattı başarıyla yayımlandıktan sonra, bu bilgi bankasının öbek yapısı değiştirilemez. Yayınlamak istediğinizden emin misiniz?", "common.goToAddDocuments": "Belge eklemeye git", "common.preparingDataSource": "Veri Kaynağını Hazırlama", - "common.processing": "Işleme", + "common.processing": "İşleme", "common.publishAs": "Bilgi İşlem Hattı Olarak Yayımlama", "common.publishAsPipeline.description": "Bilgi açıklaması", "common.publishAsPipeline.descriptionPlaceholder": "Lütfen bu Bilgi İşlem Hattının açıklamasını girin. (İsteğe bağlı)", @@ -12,13 +12,13 @@ "common.reRun": "Yeniden çalıştır", "common.testRun": "Test Çalıştırması", "inputField.create": "Kullanıcı giriş alanı oluştur", - "inputField.manage": "Yönetmek", + "inputField.manage": "Yönet", "publishToast.desc": "İşlem hattı yayımlanmadığında, bilgi bankası düğümündeki öbek yapısını değiştirebilirsiniz ve işlem hattı düzenlemesi ve değişiklikleri otomatik olarak taslak olarak kaydedilir.", "publishToast.title": "Bu işlem hattı henüz yayımlanmadı", - "ragToolSuggestions.noRecommendationPlugins": "Önerilen eklenti yok, daha fazlasını Marketplace içinde bulabilirsiniz", + "ragToolSuggestions.noRecommendationPlugins": "Önerilen eklenti yok, daha fazlasını Pazar Yeri içinde bulabilirsiniz", "ragToolSuggestions.title": "RAG için Öneriler", "result.resultPreview.error": "Yürütme sırasında hata oluştu", "result.resultPreview.footerTip": "Test çalıştırma modunda, {{count}} parçaya kadar önizleme yapabilirsiniz", - "result.resultPreview.loading": "Işleme... Lütfen bekleyin", + "result.resultPreview.loading": "İşleniyor... Lütfen bekleyin", "result.resultPreview.viewDetails": "Ayrıntıları görüntüleme" } diff --git a/web/i18n/tr-TR/plugin-tags.json b/web/i18n/tr-TR/plugin-tags.json index fb8a504393..3105bde585 100644 --- a/web/i18n/tr-TR/plugin-tags.json +++ b/web/i18n/tr-TR/plugin-tags.json @@ -11,9 +11,9 @@ "tags.medical": "Tıbbi", "tags.news": "Haberler", "tags.other": "Diğer", - "tags.productivity": "Verimli -lik", + "tags.productivity": "Verimlilik", "tags.rag": "PAÇAVRA", - "tags.search": "Aramak", + "tags.search": "Ara", "tags.social": "Sosyal", "tags.travel": "Seyahat", "tags.utilities": "Yardımcı program", diff --git a/web/i18n/tr-TR/plugin.json b/web/i18n/tr-TR/plugin.json index 005e354586..7f0bac81d3 100644 --- a/web/i18n/tr-TR/plugin.json +++ b/web/i18n/tr-TR/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Eklentiyi kaldır", "action.deleteContentLeft": "Kaldırmak ister misiniz", "action.deleteContentRight": "eklenti?", + "action.deleteSuccess": "Eklenti başarıyla kaldırıldı", "action.pluginInfo": "Eklenti bilgisi", "action.usedInApps": "Bu eklenti {{num}} uygulamalarında kullanılıyor.", "allCategories": "Tüm Kategoriler", @@ -48,12 +49,12 @@ "autoUpdate.pluginDowngradeWarning.title": "Eklenti Düşürme", "autoUpdate.specifyPluginsToUpdate": "Güncellemek için eklentileri belirtin", "autoUpdate.strategy.disabled.description": "Eklentiler otomatik olarak güncellenmeyecek", - "autoUpdate.strategy.disabled.name": "Engelli", + "autoUpdate.strategy.disabled.name": "Devre Dışı", "autoUpdate.strategy.fixOnly.description": "Yalnızca yamanın sürüm güncellemeleri için otomatik güncelleme (örneğin, 1.0.1 → 1.0.2). Küçük sürüm değişiklikleri güncellemeleri tetiklemez.", "autoUpdate.strategy.fixOnly.name": "Sadece Düzelt", "autoUpdate.strategy.fixOnly.selectedDescription": "Sadece yamanın versiyonları için otomatik güncelleme", "autoUpdate.strategy.latest.description": "Her zaman en son sürüme güncelle", - "autoUpdate.strategy.latest.name": "Son", + "autoUpdate.strategy.latest.name": "En Son", "autoUpdate.strategy.latest.selectedDescription": "Her zaman en son sürüme güncelle", "autoUpdate.updateSettings": "Ayarları Güncelle", "autoUpdate.updateTime": "Güncelleme zamanı", @@ -67,25 +68,25 @@ "category.all": "Tüm", "category.bundles": "Paketler", "category.datasources": "Veri Kaynakları", - "category.extensions": "Uzantı -ları", - "category.models": "Model", - "category.tools": "Araçları", + "category.extensions": "Uzantılar", + "category.models": "Modeller", + "category.tools": "Araçlar", "category.triggers": "Tetikleyiciler", - "categorySingle.agent": "Temsilci Stratejisi", - "categorySingle.bundle": "Bohça", + "categorySingle.agent": "Ajan Stratejisi", + "categorySingle.bundle": "Paket", "categorySingle.datasource": "Veri Kaynağı", "categorySingle.extension": "Uzantı", "categorySingle.model": "Model", - "categorySingle.tool": "Alet", - "categorySingle.trigger": "Tetik", + "categorySingle.tool": "Araç", + "categorySingle.trigger": "Tetikleyici", "debugInfo.title": "Hata ayıklama", "debugInfo.viewDocs": "Belgeleri Görüntüle", "deprecated": "Kaldırılmış", "detailPanel.actionNum": "{{num}} {{action}} DAHİL", "detailPanel.categoryTip.debugging": "Hata Ayıklama Eklentisi", - "detailPanel.categoryTip.github": "Github'dan yüklendi", + "detailPanel.categoryTip.github": "GitHub'dan yüklendi", "detailPanel.categoryTip.local": "Yerel Eklenti", - "detailPanel.categoryTip.marketplace": "Marketplace'ten yüklendi", + "detailPanel.categoryTip.marketplace": "Pazar Yeri'nden yüklendi", "detailPanel.configureApp": "Uygulamayı Yapılandır", "detailPanel.configureModel": "Modeli yapılandırma", "detailPanel.configureTool": "Aracı yapılandır", @@ -95,7 +96,7 @@ "detailPanel.deprecation.reason.businessAdjustments": "iş ayarlamaları", "detailPanel.deprecation.reason.noMaintainer": "bakımcı yok", "detailPanel.deprecation.reason.ownershipTransferred": "mülkiyet devredildi", - "detailPanel.disabled": "Sakat", + "detailPanel.disabled": "Devre Dışı", "detailPanel.endpointDeleteContent": "{{name}} öğesini kaldırmak ister misiniz?", "detailPanel.endpointDeleteTip": "Uç Noktayı Kaldır", "detailPanel.endpointDisableContent": "{{name}} öğesini devre dışı bırakmak ister misiniz?", @@ -109,18 +110,19 @@ "detailPanel.modelNum": "{{num}} DAHİL OLAN MODELLER", "detailPanel.operation.back": "Geri", "detailPanel.operation.checkUpdate": "Güncellemeyi Kontrol Et", - "detailPanel.operation.detail": "Şey", + "detailPanel.operation.detail": "Detay", "detailPanel.operation.info": "Eklenti Bilgileri", - "detailPanel.operation.install": "Yüklemek", - "detailPanel.operation.remove": "Kaldırmak", - "detailPanel.operation.update": "Güncelleştirmek", - "detailPanel.operation.viewDetail": "ayrıntılara bakın", + "detailPanel.operation.install": "Yükle", + "detailPanel.operation.remove": "Kaldır", + "detailPanel.operation.update": "Güncelle", + "detailPanel.operation.updateTooltip": "En son modellere erişmek için güncelleyin.", + "detailPanel.operation.viewDetail": "Pazar Yeri'nde görüntüle", "detailPanel.serviceOk": "Servis Tamam", "detailPanel.strategyNum": "{{num}} {{strategy}} DAHİL", "detailPanel.switchVersion": "Sürümü Değiştir", "detailPanel.toolSelector.auto": "Otomatik", "detailPanel.toolSelector.descriptionLabel": "Araç açıklaması", - "detailPanel.toolSelector.descriptionPlaceholder": "Aletin amacının kısa açıklaması, örneğin belirli bir konum için sıcaklığı elde edin.", + "detailPanel.toolSelector.descriptionPlaceholder": "Aracın amacının kısa açıklaması, örneğin belirli bir konum için sıcaklığı elde edin.", "detailPanel.toolSelector.empty": "Araç eklemek için '+' düğmesini tıklayın. Birden fazla araç ekleyebilirsiniz.", "detailPanel.toolSelector.params": "AKIL YÜRÜTME YAPILANDIRMASI", "detailPanel.toolSelector.paramsTip1": "LLM çıkarım parametrelerini kontrol eder.", @@ -128,7 +130,7 @@ "detailPanel.toolSelector.placeholder": "Bir araç seçin...", "detailPanel.toolSelector.settings": "KULLANICI AYARLARI", "detailPanel.toolSelector.title": "Araç ekle", - "detailPanel.toolSelector.toolLabel": "Alet", + "detailPanel.toolSelector.toolLabel": "Araç", "detailPanel.toolSelector.toolSetting": "Araç Ayarları", "detailPanel.toolSelector.uninstalledContent": "Bu eklenti yerel/GitHub deposundan yüklenir. Lütfen kurulumdan sonra kullanın.", "detailPanel.toolSelector.uninstalledLink": "Eklentilerde Yönet", @@ -142,11 +144,11 @@ "error.fetchReleasesError": "Sürümler alınamıyor. Lütfen daha sonra tekrar deneyin.", "error.inValidGitHubUrl": "Geçersiz GitHub URL'si. Lütfen şu biçimde geçerli bir URL girin: https://github.com/owner/repo", "error.noReleasesFound": "Yayın bulunamadı. Lütfen GitHub deposunu veya giriş URL'sini kontrol edin.", - "findMoreInMarketplace": "Marketplace'te daha fazla bilgi edinin", + "findMoreInMarketplace": "Pazar Yeri'nde daha fazla bilgi edinin", "from": "Kaynak", - "fromMarketplace": "Pazar Yerinden", + "fromMarketplace": "Pazar Yeri'nden", "install": "{{num}} yükleme", - "installAction": "Yüklemek", + "installAction": "Yükle", "installFrom": "ŞURADAN YÜKLE", "installFromGitHub.gitHubRepo": "GitHub deposu", "installFromGitHub.installFailed": "Yükleme başarısız oldu", @@ -161,10 +163,10 @@ "installFromGitHub.uploadFailed": "Karşıya yükleme başarısız oldu", "installModal.back": "Geri", "installModal.cancel": "İptal", - "installModal.close": "Kapatmak", + "installModal.close": "Kapat", "installModal.dropPluginToInstall": "Yüklemek için eklenti paketini buraya bırakın", "installModal.fromTrustSource": "Lütfen eklentileri yalnızca güvenilir bir kaynaktan yüklediğinizden emin olun.", - "installModal.install": "Yüklemek", + "installModal.install": "Yükle", "installModal.installComplete": "Kurulum tamamlandı", "installModal.installFailed": "Yükleme başarısız oldu", "installModal.installFailedDesc": "Eklenti yüklenemedi, başarısız oldu.", @@ -176,7 +178,7 @@ "installModal.labels.package": "Paket", "installModal.labels.repository": "Depo", "installModal.labels.version": "Sürüm", - "installModal.next": "Önümüzdeki", + "installModal.next": "İleri", "installModal.pluginLoadError": "Eklenti yükleme hatası", "installModal.pluginLoadErrorDesc": "Bu eklenti yüklenmeyecek", "installModal.readyToInstall": "Aşağıdaki eklentiyi yüklemek üzere", @@ -189,16 +191,16 @@ "list.notFound": "Eklenti bulunamadı", "list.source.github": "GitHub'dan yükleyin", "list.source.local": "Yerel Paket Dosyasından Yükle", - "list.source.marketplace": "Marketten Yükleme", + "list.source.marketplace": "Pazar Yeri'nden Yükleme", "marketplace.and": "ve", "marketplace.difyMarketplace": "Dify Pazar Yeri", - "marketplace.discover": "Keşfetmek", + "marketplace.discover": "Keşfet", "marketplace.empower": "Yapay zeka geliştirmenizi güçlendirin", - "marketplace.moreFrom": "Marketplace'ten daha fazlası", + "marketplace.moreFrom": "Pazar Yeri'nden daha fazlası", "marketplace.noPluginFound": "Eklenti bulunamadı", "marketplace.partnerTip": "Dify partner'ı tarafından doğrulandı", "marketplace.pluginsResult": "{{num}} sonuç", - "marketplace.sortBy": "Kara şehir", + "marketplace.sortBy": "Sırala", "marketplace.sortOption.firstReleased": "İlk Çıkanlar", "marketplace.sortOption.mostPopular": "En popüler", "marketplace.sortOption.newlyReleased": "Yeni Çıkanlar", @@ -207,7 +209,7 @@ "marketplace.viewMore": "Daha fazla göster", "metadata.title": "Eklentiler", "pluginInfoModal.packageName": "Paket", - "pluginInfoModal.release": "Serbest bırakma", + "pluginInfoModal.release": "Sürüm", "pluginInfoModal.repository": "Depo", "pluginInfoModal.title": "Eklenti bilgisi", "privilege.admins": "Yöneticiler", @@ -220,32 +222,38 @@ "readmeInfo.failedToFetch": "README alınamadı", "readmeInfo.needHelpCheckReadme": "Yardıma mı ihtiyacınız var? README dosyasına bakın.", "readmeInfo.noReadmeAvailable": "README mevcut değil", - "readmeInfo.title": "OKUMA MESELESİ", + "readmeInfo.title": "BENİOKU", "requestAPlugin": "Bir eklenti iste", - "search": "Aramak", + "search": "Ara", "searchCategories": "Arama Kategorileri", - "searchInMarketplace": "Marketplace'te arama yapma", + "searchInMarketplace": "Pazar Yeri'nde arama yapın", "searchPlugins": "Eklentileri ara", "searchTools": "Arama araçları...", - "source.github": "GitHub (İngilizce)", + "source.github": "GitHub", "source.local": "Yerel Paket Dosyası", - "source.marketplace": "Pazar", + "source.marketplace": "Pazar Yeri", "task.clearAll": "Tümünü temizle", - "task.errorPlugins": "Failed to Install Plugins", + "task.errorMsg.github": "Bu eklenti otomatik olarak yüklenemedi.\nLütfen GitHub'dan yükleyin.", + "task.errorMsg.marketplace": "Bu eklenti otomatik olarak yüklenemedi.\nLütfen Pazar Yeri'nden yükleyin.", + "task.errorMsg.unknown": "Bu eklenti yüklenemedi.\nEklenti kaynağı belirlenemedi.", + "task.errorPlugins": "Eklentiler Yüklenemedi", "task.installError": "{{errorLength}} eklentileri yüklenemedi, görüntülemek için tıklayın", - "task.installSuccess": "{{successLength}} plugins installed successfully", - "task.installed": "Installed", + "task.installFromGithub": "GitHub'dan yükle", + "task.installFromMarketplace": "Pazar Yeri'nden yükle", + "task.installSuccess": "{{successLength}} eklenti başarıyla yüklendi", + "task.installed": "Yüklendi", "task.installedError": "{{errorLength}} eklentileri yüklenemedi", "task.installing": "Eklentiler yükleniyor.", + "task.installingHint": "Yükleniyor... Bu işlem birkaç dakika sürebilir.", "task.installingWithError": "{{installingLength}} eklentileri yükleniyor, {{successLength}} başarılı, {{errorLength}} başarısız oldu", "task.installingWithSuccess": "{{installingLength}} eklentileri yükleniyor, {{successLength}} başarılı.", - "task.runningPlugins": "Installing Plugins", - "task.successPlugins": "Successfully Installed Plugins", - "upgrade.close": "Kapatmak", - "upgrade.description": "Aşağıdaki eklentiyi yüklemek üzere", - "upgrade.successfulTitle": "Yükleme başarılı", - "upgrade.title": "Eklentiyi Yükle", - "upgrade.upgrade": "Yüklemek", - "upgrade.upgrading": "Yükleme...", + "task.runningPlugins": "Eklentiler Yükleniyor", + "task.successPlugins": "Başarıyla Yüklenen Eklentiler", + "upgrade.close": "Kapat", + "upgrade.description": "Aşağıdaki eklentiyi güncellemek üzeresiniz", + "upgrade.successfulTitle": "Güncelleme başarılı", + "upgrade.title": "Eklentiyi Güncelle", + "upgrade.upgrade": "Güncelle", + "upgrade.upgrading": "Güncelleniyor...", "upgrade.usedInApps": "{{num}} uygulamalarında kullanılır" } diff --git a/web/i18n/tr-TR/time.json b/web/i18n/tr-TR/time.json index c51f3d361b..81f51fde81 100644 --- a/web/i18n/tr-TR/time.json +++ b/web/i18n/tr-TR/time.json @@ -1,16 +1,16 @@ { - "dateFormats.display": "MMMM D, YYYY", - "dateFormats.displayWithTime": "MMMM D, YYYY hh:mm A", - "dateFormats.input": "YYYY-AA-GG", - "dateFormats.output": "YYYY-AA-GG", - "dateFormats.outputWithTime": "YYYY-AA-GGSS:DD:DDS.SSSZ", - "daysInWeek.Fri": "Cuma", - "daysInWeek.Mon": "Mon", - "daysInWeek.Sat": "Sat", - "daysInWeek.Sun": "Güneş", - "daysInWeek.Thu": "Perşembe", - "daysInWeek.Tue": "Salı", - "daysInWeek.Wed": "Çarşamba", + "dateFormats.display": "D MMMM YYYY", + "dateFormats.displayWithTime": "D MMMM YYYY hh:mm A", + "dateFormats.input": "YYYY-MM-DD", + "dateFormats.output": "YYYY-MM-DD", + "dateFormats.outputWithTime": "YYYY-MM-DDTHH:mm:ss.SSSZ", + "daysInWeek.Fri": "Cum", + "daysInWeek.Mon": "Pzt", + "daysInWeek.Sat": "Cmt", + "daysInWeek.Sun": "Paz", + "daysInWeek.Thu": "Per", + "daysInWeek.Tue": "Sal", + "daysInWeek.Wed": "Çar", "defaultPlaceholder": "Bir zaman seç...", "months.April": "Nisan", "months.August": "Ağustos", diff --git a/web/i18n/tr-TR/tools.json b/web/i18n/tr-TR/tools.json index ca6e9dc85f..f2b64fe481 100644 --- a/web/i18n/tr-TR/tools.json +++ b/web/i18n/tr-TR/tools.json @@ -21,7 +21,7 @@ "auth.setupModalTitleDescription": "Kimlik bilgilerini yapılandırdıktan sonra, çalışma alanındaki tüm üyeler uygulamaları düzenlerken bu aracı kullanabilir.", "auth.unauthorized": "Yetkisiz", "author": "Tarafından", - "builtInPromptTitle": "Prompt", + "builtInPromptTitle": "İstem", "contribute.line1": "Dify'ye ", "contribute.line2": "araçlar eklemekle ilgileniyorum.", "contribute.viewGuide": "Rehberi Görüntüle", @@ -192,7 +192,7 @@ "setBuiltInTools.parameters": "parametreler", "setBuiltInTools.required": "Gerekli", "setBuiltInTools.setting": "Ayar", - "setBuiltInTools.string": "string", + "setBuiltInTools.string": "metin", "setBuiltInTools.toolDescription": "Araç açıklaması", "test.parameters": "Parametreler", "test.parametersValue": "Parametreler ve Değer", @@ -205,9 +205,9 @@ "thought.used": "Kullanıldı", "thought.using": "Kullanılıyor", "title": "Araçlar", - "toolNameUsageTip": "Agent akıl yürütme ve prompt için araç çağrı adı", + "toolNameUsageTip": "Ajan akıl yürütme ve istem için araç çağrı adı", "toolRemoved": "Araç kaldırıldı", "type.builtIn": "Yerleşik", "type.custom": "Özel", - "type.workflow": "Workflow" + "type.workflow": "İş Akışı" } diff --git a/web/i18n/tr-TR/workflow.json b/web/i18n/tr-TR/workflow.json index 51a957518d..a9437ec3df 100644 --- a/web/i18n/tr-TR/workflow.json +++ b/web/i18n/tr-TR/workflow.json @@ -67,7 +67,7 @@ "changeHistory.hintText": "Düzenleme işlemleriniz, bu oturum süresince cihazınızda saklanan bir değişiklik geçmişinde izlenir. Bu tarihçesi düzenleyiciden çıktığınızda temizlenir.", "changeHistory.nodeAdd": "Düğüm eklendi", "changeHistory.nodeChange": "Düğüm değişti", - "changeHistory.nodeConnect": "Node bağlandı", + "changeHistory.nodeConnect": "Düğüm bağlandı", "changeHistory.nodeDelete": "Düğüm silindi", "changeHistory.nodeDescriptionChange": "Düğüm açıklaması değiştirildi", "changeHistory.nodeDragStop": "Düğüm taşındı", @@ -129,7 +129,7 @@ "common.currentView": "Geçerli Görünüm", "common.currentWorkflow": "Mevcut İş Akışı", "common.debugAndPreview": "Önizleme", - "common.disconnect": "Ayırmak", + "common.disconnect": "Bağlantıyı Kes", "common.duplicate": "Çoğalt", "common.editing": "Düzenleme", "common.effectVarConfirm.content": "Değişken diğer düğümlerde kullanılıyor. Yine de kaldırmak istiyor musunuz?", @@ -151,7 +151,7 @@ "common.humanInputEmailTipInDebugMode": "E-posta (Teslimat Yöntemi) {{email}} adresine gönderildi", "common.humanInputWebappTip": "Yalnızca hata ayıklama önizlemesi, kullanıcı bunu web uygulamasında görmeyecek.", "common.importDSL": "DSL İçe Aktar", - "common.importDSLTip": "Geçerli taslak üzerine yazılacak. İçe aktarmadan önce workflow yedekleyin.", + "common.importDSLTip": "Geçerli taslak üzerine yazılacak. İçe aktarmadan önce iş akışını yedekleyin.", "common.importFailure": "İçe Aktarma Başarısız", "common.importSuccess": "İçe Aktarma Başarılı", "common.importWarning": "Dikkat", @@ -220,10 +220,10 @@ "common.viewDetailInTracingPanel": "Ayrıntıları görüntüle", "common.viewOnly": "Sadece Görüntüleme", "common.viewRunHistory": "Çalıştırma geçmişini görüntüle", - "common.workflowAsTool": "Araç Olarak Workflow", + "common.workflowAsTool": "Araç Olarak İş Akışı", "common.workflowAsToolDisabledHint": "En son iş akışını yayınlayın ve bunu bir araç olarak yapılandırmadan önce bağlı bir Kullanıcı Girdisi düğümünün olduğundan emin olun.", - "common.workflowAsToolTip": "Workflow güncellemesinden sonra araç yeniden yapılandırması gereklidir.", - "common.workflowProcess": "Workflow Süreci", + "common.workflowAsToolTip": "İş Akışı güncellemesinden sonra araç yeniden yapılandırması gereklidir.", + "common.workflowProcess": "İş Akışı Süreci", "customWebhook": "Özel Webhook", "debug.copyLastRun": "Son Çalışmayı Kopyala", "debug.copyLastRunError": "Son çalışma girdilerini kopyalamak başarısız oldu.", @@ -240,7 +240,7 @@ "debug.relations.dependentsDescription": "Bu düğüme dayanan düğümler", "debug.relations.noDependencies": "Bağımlılık yok", "debug.relations.noDependents": "Bakmakla yükümlü olunan kişi yok", - "debug.relationsTab": "Ilişkiler", + "debug.relationsTab": "İlişkiler", "debug.settingsTab": "Ayarlar", "debug.variableInspect.chatNode": "Konuşma", "debug.variableInspect.clearAll": "Hepsini sıfırla", @@ -249,7 +249,7 @@ "debug.variableInspect.emptyLink": "Daha fazla öğrenin", "debug.variableInspect.emptyTip": "Bir düğümü kanvas üzerinde geçtikten veya bir düğümü adım adım çalıştırdıktan sonra, Düğüm Değişkeni'ndeki mevcut değeri Değişken İncele'de görüntüleyebilirsiniz.", "debug.variableInspect.envNode": "Çevre", - "debug.variableInspect.export": "Ihracat", + "debug.variableInspect.export": "Dışa Aktar", "debug.variableInspect.exportToolTip": "Değişkeni Dosya Olarak Dışa Aktar", "debug.variableInspect.largeData": "Büyük veri, salt okunur önizleme. Tümünü görüntülemek için dışa aktarın.", "debug.variableInspect.largeDataNoExport": "Büyük veri - yalnızca kısmi önizleme", @@ -294,11 +294,12 @@ "env.modal.value": "Değer", "env.modal.valuePlaceholder": "env değeri", "error.operations.addingNodes": "düğüm ekleme", - "error.operations.connectingNodes": "düğümleri bağlamak", + "error.operations.connectingNodes": "düğümleri bağlama", "error.operations.modifyingWorkflow": "iş akışını değiştirme", "error.operations.updatingWorkflow": "iş akışını güncelleme", "error.startNodeRequired": "Lütfen {{operation}} işleminden önce önce bir başlangıç düğümü ekleyin", "errorMsg.authRequired": "Yetkilendirme gereklidir", + "errorMsg.configureModel": "Bir model yapılandırın", "errorMsg.fieldRequired": "{{field}} gereklidir", "errorMsg.fields.code": "Kod", "errorMsg.fields.model": "Model", @@ -308,6 +309,7 @@ "errorMsg.fields.visionVariable": "Vizyon Değişkeni", "errorMsg.invalidJson": "{{field}} geçersiz JSON", "errorMsg.invalidVariable": "Geçersiz değişken", + "errorMsg.modelPluginNotInstalled": "Geçersiz değişken. Bu değişkeni etkinleştirmek için bir model yapılandırın.", "errorMsg.noValidTool": "{{field}} geçerli bir araç seçilmedi", "errorMsg.rerankModelRequired": "Yeniden Sıralama Modelini açmadan önce, lütfen ayarlarda modelin başarıyla yapılandırıldığını onaylayın.", "errorMsg.startNodeRequired": "Lütfen {{operation}} işleminden önce önce bir başlangıç düğümü ekleyin", @@ -327,7 +329,7 @@ "nodes.agent.installPlugin.cancel": "İptal", "nodes.agent.installPlugin.changelog": "Değişiklik günlüğü", "nodes.agent.installPlugin.desc": "Aşağıdaki eklentiyi yüklemek üzere", - "nodes.agent.installPlugin.install": "Yüklemek", + "nodes.agent.installPlugin.install": "Yükle", "nodes.agent.installPlugin.title": "Eklentiyi Yükle", "nodes.agent.learnMore": "Daha fazla bilgi edinin", "nodes.agent.linkToPlugin": "Eklentilere Bağlantı", @@ -352,7 +354,7 @@ "nodes.agent.outputVars.text": "Temsilci Tarafından Oluşturulan İçerik", "nodes.agent.outputVars.usage": "Model Kullanım Bilgileri", "nodes.agent.parameterSchema": "Parametre Şeması", - "nodes.agent.pluginInstaller.install": "Yüklemek", + "nodes.agent.pluginInstaller.install": "Yükle", "nodes.agent.pluginInstaller.installing": "Yükleme", "nodes.agent.pluginNotFoundDesc": "Bu eklenti GitHub'dan yüklenmiştir. Lütfen şuraya gidin: Eklentiler yeniden yüklemek için", "nodes.agent.pluginNotInstalled": "Bu eklenti yüklü değil", @@ -363,7 +365,7 @@ "nodes.agent.strategy.searchPlaceholder": "Arama aracısı stratejisi", "nodes.agent.strategy.selectTip": "Ajan stratejisi seçin", "nodes.agent.strategy.shortLabel": "Strateji", - "nodes.agent.strategy.tooltip": "Farklı Agentic stratejileri, sistemin çok adımlı araç çağrılarını nasıl planladığını ve yürüttüğünü belirler", + "nodes.agent.strategy.tooltip": "Farklı Ajan stratejileri, sistemin çok adımlı araç çağrılarını nasıl planladığını ve yürüttüğünü belirler", "nodes.agent.strategyNotFoundDesc": "Yüklenen eklenti sürümü bu stratejiyi sağlamaz.", "nodes.agent.strategyNotFoundDescAndSwitchVersion": "Yüklenen eklenti sürümü bu stratejiyi sağlamaz. Sürümü değiştirmek için tıklayın.", "nodes.agent.strategyNotInstallTooltip": "{{strategy}} yüklü değil", @@ -387,12 +389,12 @@ "nodes.assigner.operations./=": "/=", "nodes.assigner.operations.append": "Ekleme", "nodes.assigner.operations.clear": "Berrak", - "nodes.assigner.operations.extend": "Uzatmak", + "nodes.assigner.operations.extend": "Genişlet", "nodes.assigner.operations.over-write": "Üzerine", "nodes.assigner.operations.overwrite": "Üzerine", "nodes.assigner.operations.remove-first": "İlkini kaldır", "nodes.assigner.operations.remove-last": "Sonuncuyu Kaldır", - "nodes.assigner.operations.set": "Ayarlamak", + "nodes.assigner.operations.set": "Ayarla", "nodes.assigner.operations.title": "İşlem", "nodes.assigner.over-write": "Üzerine Yaz", "nodes.assigner.plus": "Artı", @@ -438,10 +440,11 @@ "nodes.common.memory.windowSize": "Pencere Boyutu", "nodes.common.outputVars": "Çıktı Değişkenleri", "nodes.common.pluginNotInstalled": "Eklenti yüklü değil", + "nodes.common.pluginsNotInstalled": "{{count}} eklenti yüklenmedi", "nodes.common.retry.maxRetries": "En fazla yeniden deneme", "nodes.common.retry.ms": "Ms", - "nodes.common.retry.retries": "{{num}} Yeni -den deneme", - "nodes.common.retry.retry": "Yeni -den deneme", + "nodes.common.retry.retries": "{{num}} Yeniden deneme", + "nodes.common.retry.retry": "Yeniden deneme", "nodes.common.retry.retryFailed": "Yeniden deneme başarısız oldu", "nodes.common.retry.retryFailedTimes": "{{times}} yeniden denemeleri başarısız oldu", "nodes.common.retry.retryInterval": "Yeniden deneme aralığı", @@ -639,7 +642,7 @@ "nodes.ifElse.optionName.url": "URL", "nodes.ifElse.optionName.video": "Video", "nodes.ifElse.or": "veya", - "nodes.ifElse.select": "Seçmek", + "nodes.ifElse.select": "Seç", "nodes.ifElse.selectVariable": "Değişken seçin...", "nodes.iteration.ErrorMethod.continueOnError": "Hata Üzerine Devam Et", "nodes.iteration.ErrorMethod.operationTerminated": "Sonlandırıldı", @@ -676,9 +679,14 @@ "nodes.knowledgeBase.chunksInput": "Parçalar", "nodes.knowledgeBase.chunksInputTip": "Bilgi tabanı düğümünün girdi değişkeni 'Chunks'tır. Değişkenin tipi, seçilen parça yapısıyla tutarlı olması gereken belirli bir JSON Şemasına sahip bir nesnedir.", "nodes.knowledgeBase.chunksVariableIsRequired": "Chunks değişkeni gereklidir", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API anahtarı kullanılamıyor", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Krediler tükendi", + "nodes.knowledgeBase.embeddingModelIncompatible": "Uyumsuz", "nodes.knowledgeBase.embeddingModelIsInvalid": "Gömme modeli geçersiz", "nodes.knowledgeBase.embeddingModelIsRequired": "Gömme modeli gereklidir", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Gömme modeli yapılandırılmadı", "nodes.knowledgeBase.indexMethodIsRequired": "İndeks yöntemi gereklidir", + "nodes.knowledgeBase.notConfigured": "Yapılandırılmadı", "nodes.knowledgeBase.rerankingModelIsInvalid": "Yeniden sıralama modeli geçersiz", "nodes.knowledgeBase.rerankingModelIsRequired": "Yeniden sıralama modeli gereklidir", "nodes.knowledgeBase.retrievalSettingIsRequired": "Alma ayarı gereklidir", @@ -687,7 +695,7 @@ "nodes.knowledgeRetrieval.metadata.options.automatic.subTitle": "Kullanıcı sorgusuna dayalı olarak otomatik olarak meta veri filtreleme koşulları oluşturun.", "nodes.knowledgeRetrieval.metadata.options.automatic.title": "Otomatik", "nodes.knowledgeRetrieval.metadata.options.disabled.subTitle": "Meta veri filtreleme özelliğini devre dışı bırakma", - "nodes.knowledgeRetrieval.metadata.options.disabled.title": "Devre dışı bırakıldı.", + "nodes.knowledgeRetrieval.metadata.options.disabled.title": "Devre Dışı", "nodes.knowledgeRetrieval.metadata.options.manual.subTitle": "Manuel olarak meta veri filtreleme koşulları ekleyin", "nodes.knowledgeRetrieval.metadata.options.manual.title": "Kılavuz", "nodes.knowledgeRetrieval.metadata.panel.add": "Koşul Ekle", @@ -861,7 +869,8 @@ "nodes.templateTransform.codeSupportTip": "Sadece Jinja2 destekler", "nodes.templateTransform.inputVars": "Giriş Değişkenleri", "nodes.templateTransform.outputVars.output": "Dönüştürülmüş içerik", - "nodes.tool.authorize": "Yetkilendirmek", + "nodes.tool.authorizationRequired": "Yetkilendirme gerekli", + "nodes.tool.authorize": "Yetkilendir", "nodes.tool.inputVars": "Giriş Değişkenleri", "nodes.tool.insertPlaceholder1": "Yazın veya basın", "nodes.tool.insertPlaceholder2": "değişken ekle", @@ -961,7 +970,7 @@ "nodes.triggerSchedule.title": "Program", "nodes.triggerSchedule.useCronExpression": "Cron ifadesi kullan", "nodes.triggerSchedule.useVisualPicker": "Görsel seçici kullan", - "nodes.triggerSchedule.visualConfig": "Görsel Konfigürasyon", + "nodes.triggerSchedule.visualConfig": "Görsel Yapılandırma", "nodes.triggerSchedule.weekdays": "Hafta günleri", "nodes.triggerWebhook.addHeader": "Ekle", "nodes.triggerWebhook.addParameter": "Ekle", @@ -1021,8 +1030,8 @@ "onboarding.back": "Geri", "onboarding.description": "Farklı başlangıç düğümlerinin farklı yetenekleri vardır. Endişelenmeyin, bunları her zaman daha sonra değiştirebilirsiniz.", "onboarding.escTip.key": "esc", - "onboarding.escTip.press": "Basın", - "onboarding.escTip.toDismiss": "reddetmek", + "onboarding.escTip.press": "Kapatmak için", + "onboarding.escTip.toDismiss": "tuşuna basın", "onboarding.learnMore": "Daha fazla bilgi edin", "onboarding.title": "Başlamak için bir başlangıç düğümü seçin", "onboarding.trigger": "Tetik", @@ -1051,10 +1060,12 @@ "panel.change": "Değiştir", "panel.changeBlock": "Düğümü Değiştir", "panel.checklist": "Kontrol Listesi", + "panel.checklistDescription": "Yayınlamadan önce aşağıdaki sorunları çözün", "panel.checklistResolved": "Tüm sorunlar çözüldü", "panel.checklistTip": "Yayınlamadan önce tüm sorunların çözüldüğünden emin olun", "panel.createdBy": "Oluşturan: ", "panel.goTo": "Git", + "panel.goToFix": "Düzeltmeye git", "panel.helpLink": "Yardım", "panel.maximize": "Kanvası Maksimize Et", "panel.minimize": "Tam Ekrandan Çık", @@ -1069,8 +1080,8 @@ "panel.startNode": "Başlangıç Düğümü", "panel.userInputField": "Kullanıcı Giriş Alanı", "publishLimit.startNodeDesc": "Bu plan için bir iş akışında 2 tetikleyici sınırına ulaştınız. Bu iş akışını yayınlamak için yükseltme yapın.", - "publishLimit.startNodeTitlePrefix": "Yükselt", - "publishLimit.startNodeTitleSuffix": "her iş akışı için sınırsız tetikleyici aç", + "publishLimit.startNodeTitlePrefix": "Yükseltme: ", + "publishLimit.startNodeTitleSuffix": "her iş akışı için sınırsız tetikleyicinin kilidini açın", "sidebar.exportWarning": "Mevcut Kaydedilmiş Versiyonu Dışa Aktar", "sidebar.exportWarningDesc": "Bu, çalışma akışınızın mevcut kaydedilmiş sürümünü dışa aktaracaktır. Editörde kaydedilmemiş değişiklikleriniz varsa, lütfen önce bunları çalışma akışı alanındaki dışa aktarma seçeneğini kullanarak kaydedin.", "singleRun.back": "Geri", @@ -1095,8 +1106,8 @@ "tabs.hideActions": "Araçları gizle", "tabs.installed": "Yüklendi", "tabs.logic": "Mantık", - "tabs.noFeaturedPlugins": "Marketplace'te daha fazla araç keşfedin", - "tabs.noFeaturedTriggers": "Marketplace'te daha fazla tetikleyici keşfedin", + "tabs.noFeaturedPlugins": "Pazar Yeri'nde daha fazla araç keşfedin", + "tabs.noFeaturedTriggers": "Pazar Yeri'nde daha fazla tetikleyici keşfedin", "tabs.noPluginsFound": "Hiç eklenti bulunamadı", "tabs.noResult": "Eşleşen bulunamadı", "tabs.plugin": "Eklenti", @@ -1116,7 +1127,7 @@ "tabs.transform": "Dönüştür", "tabs.usePlugin": "Araç seç", "tabs.utilities": "Yardımcı Araçlar", - "tabs.workflowTool": "Workflow", + "tabs.workflowTool": "İş Akışı", "tracing.stopBy": "{{user}} tarafından durduruldu", "triggerStatus.disabled": "TETİKLEYİCİ • DEVRE DIŞI", "triggerStatus.enabled": "TETİK", From 848a041c2527ef2cf9746f84bae63a5f694bed3d Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 23 Mar 2026 08:20:25 -0500 Subject: [PATCH 10/34] test: migrate dataset service create dataset tests to testcontainers (#33945) --- .../test_dataset_service_create_dataset.py | 60 +++++++++++++++++++ .../test_dataset_service_create_dataset.py | 50 ---------------- 2 files changed, 60 insertions(+), 50 deletions(-) create mode 100644 api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py delete mode 100644 api/tests/unit_tests/services/test_dataset_service_create_dataset.py diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py new file mode 100644 index 0000000000..c486ff5613 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py @@ -0,0 +1,60 @@ +"""Testcontainers integration tests for DatasetService.create_empty_rag_pipeline_dataset.""" + +from __future__ import annotations + +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest + +from models.account import Account, Tenant, TenantAccountJoin +from services.dataset_service import DatasetService +from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity + + +class TestDatasetServiceCreateRagPipelineDataset: + def _create_tenant_and_account(self, db_session_with_containers) -> tuple[Tenant, Account]: + tenant = Tenant(name=f"Tenant {uuid4()}") + db_session_with_containers.add(tenant) + db_session_with_containers.flush() + + account = Account( + name=f"Account {uuid4()}", + email=f"ds_create_{uuid4()}@example.com", + password="hashed", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session_with_containers.add(account) + db_session_with_containers.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role="owner", + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + return tenant, account + + def _build_entity(self, name: str = "Test Dataset") -> RagPipelineDatasetCreateEntity: + icon_info = IconInfo(icon="\U0001f4d9", icon_background="#FFF4ED", icon_type="emoji") + return RagPipelineDatasetCreateEntity( + name=name, + description="", + icon_info=icon_info, + permission="only_me", + ) + + def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers): + tenant, _ = self._create_tenant_and_account(db_session_with_containers) + + mock_user = Mock(id=None) + with patch("services.dataset_service.current_user", mock_user): + with pytest.raises(ValueError, match="Current user or current user id not found"): + DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=tenant.id, + rag_pipeline_dataset_create_entity=self._build_entity(), + ) diff --git a/api/tests/unit_tests/services/test_dataset_service_create_dataset.py b/api/tests/unit_tests/services/test_dataset_service_create_dataset.py deleted file mode 100644 index f8c5270656..0000000000 --- a/api/tests/unit_tests/services/test_dataset_service_create_dataset.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Unit tests for non-SQL validation paths in DatasetService dataset creation.""" - -from unittest.mock import Mock, patch -from uuid import uuid4 - -import pytest - -from services.dataset_service import DatasetService -from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity - - -class TestDatasetServiceCreateRagPipelineDatasetNonSQL: - """Unit coverage for non-SQL validation in create_empty_rag_pipeline_dataset.""" - - @pytest.fixture - def mock_rag_pipeline_dependencies(self): - """Patch database session and current_user for validation-only unit coverage.""" - with ( - patch("services.dataset_service.db.session") as mock_db, - patch("services.dataset_service.current_user") as mock_current_user, - ): - yield { - "db_session": mock_db, - "current_user_mock": mock_current_user, - } - - def test_create_rag_pipeline_dataset_missing_current_user_error(self, mock_rag_pipeline_dependencies): - """Raise ValueError when current_user.id is unavailable before SQL persistence.""" - # Arrange - tenant_id = str(uuid4()) - mock_rag_pipeline_dependencies["current_user_mock"].id = None - - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - - icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") - entity = RagPipelineDatasetCreateEntity( - name="Test Dataset", - description="", - icon_info=icon_info, - permission="only_me", - ) - - # Act / Assert - with pytest.raises(ValueError, match="Current user or current user id not found"): - DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, - rag_pipeline_dataset_create_entity=entity, - ) From 6698b42f97bb130af74c4630ec9f49ea2fca2194 Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 23 Mar 2026 08:20:53 -0500 Subject: [PATCH 11/34] test: migrate api based extension service tests to testcontainers (#33952) --- .../test_api_based_extension_service.py | 144 ++++++ .../test_api_based_extension_service.py | 421 ------------------ 2 files changed, 144 insertions(+), 421 deletions(-) delete mode 100644 api/tests/unit_tests/services/test_api_based_extension_service.py diff --git a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py index 7ce7357b41..b8e022503f 100644 --- a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py @@ -525,3 +525,147 @@ class TestAPIBasedExtensionService: # Try to get extension with wrong tenant ID with pytest.raises(ValueError, match="API based extension is not found"): APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id) + + def test_save_extension_api_key_exactly_four_chars_rejected( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """API key with exactly 4 characters should be rejected (boundary).""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + extension_data = APIBasedExtension( + tenant_id=tenant.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key="1234", + ) + + with pytest.raises(ValueError, match="api_key must be at least 5 characters"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_api_key_exactly_five_chars_accepted( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """API key with exactly 5 characters should be accepted (boundary).""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + extension_data = APIBasedExtension( + tenant_id=tenant.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key="12345", + ) + + saved = APIBasedExtensionService.save(extension_data) + assert saved.id is not None + + def test_save_extension_requestor_constructor_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Exception raised by requestor constructor is wrapped in ValueError.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + mock_external_service_dependencies["requestor"].side_effect = RuntimeError("bad config") + + extension_data = APIBasedExtension( + tenant_id=tenant.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + + with pytest.raises(ValueError, match="connection error: bad config"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_network_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Network exceptions during ping are wrapped in ValueError.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + mock_external_service_dependencies["requestor_instance"].request.side_effect = ConnectionError( + "network failure" + ) + + extension_data = APIBasedExtension( + tenant_id=tenant.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + + with pytest.raises(ValueError, match="connection error: network failure"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_update_duplicate_name_rejected( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Updating an existing extension to use another extension's name should fail.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + ext1 = APIBasedExtensionService.save( + APIBasedExtension( + tenant_id=tenant.id, + name="Extension Alpha", + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + ) + ext2 = APIBasedExtensionService.save( + APIBasedExtension( + tenant_id=tenant.id, + name="Extension Beta", + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + ) + + # Try to rename ext2 to ext1's name + ext2.name = "Extension Alpha" + with pytest.raises(ValueError, match="name must be unique, it is already existed"): + APIBasedExtensionService.save(ext2) + + def test_get_all_returns_empty_for_different_tenant( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Extensions from one tenant should not be visible to another.""" + fake = Faker() + _, tenant1 = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + _, tenant2 = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant1 is not None + + APIBasedExtensionService.save( + APIBasedExtension( + tenant_id=tenant1.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + ) + + assert tenant2 is not None + result = APIBasedExtensionService.get_all_by_tenant_id(tenant2.id) + assert result == [] diff --git a/api/tests/unit_tests/services/test_api_based_extension_service.py b/api/tests/unit_tests/services/test_api_based_extension_service.py deleted file mode 100644 index 7f4b5fdaa3..0000000000 --- a/api/tests/unit_tests/services/test_api_based_extension_service.py +++ /dev/null @@ -1,421 +0,0 @@ -""" -Comprehensive unit tests for services/api_based_extension_service.py - -Covers: -- APIBasedExtensionService.get_all_by_tenant_id -- APIBasedExtensionService.save -- APIBasedExtensionService.delete -- APIBasedExtensionService.get_with_tenant_id -- APIBasedExtensionService._validation (new record & existing record branches) -- APIBasedExtensionService._ping_connection (pong success, wrong response, exception) -""" - -from unittest.mock import MagicMock, patch - -import pytest - -from services.api_based_extension_service import APIBasedExtensionService - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _make_extension( - *, - id_: str | None = None, - tenant_id: str = "tenant-001", - name: str = "my-ext", - api_endpoint: str = "https://example.com/hook", - api_key: str = "secret-key-123", -) -> MagicMock: - """Return a lightweight mock that mimics APIBasedExtension.""" - ext = MagicMock() - ext.id = id_ - ext.tenant_id = tenant_id - ext.name = name - ext.api_endpoint = api_endpoint - ext.api_key = api_key - return ext - - -# --------------------------------------------------------------------------- -# Tests: get_all_by_tenant_id -# --------------------------------------------------------------------------- - - -class TestGetAllByTenantId: - """Tests for APIBasedExtensionService.get_all_by_tenant_id.""" - - @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") - @patch("services.api_based_extension_service.db") - def test_returns_extensions_with_decrypted_keys(self, mock_db, mock_decrypt): - """Each api_key is decrypted and the list is returned.""" - ext1 = _make_extension(id_="id-1", api_key="enc-key-1") - ext2 = _make_extension(id_="id-2", api_key="enc-key-2") - - mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [ - ext1, - ext2, - ] - - result = APIBasedExtensionService.get_all_by_tenant_id("tenant-001") - - assert result == [ext1, ext2] - assert ext1.api_key == "decrypted-key" - assert ext2.api_key == "decrypted-key" - assert mock_decrypt.call_count == 2 - - @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") - @patch("services.api_based_extension_service.db") - def test_returns_empty_list_when_no_extensions(self, mock_db, mock_decrypt): - """Returns an empty list gracefully when no records exist.""" - mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [] - - result = APIBasedExtensionService.get_all_by_tenant_id("tenant-001") - - assert result == [] - mock_decrypt.assert_not_called() - - @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") - @patch("services.api_based_extension_service.db") - def test_calls_query_with_correct_tenant_id(self, mock_db, mock_decrypt): - """Verifies the DB is queried with the supplied tenant_id.""" - mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [] - - APIBasedExtensionService.get_all_by_tenant_id("tenant-xyz") - - mock_db.session.query.return_value.filter_by.assert_called_once_with(tenant_id="tenant-xyz") - - -# --------------------------------------------------------------------------- -# Tests: save -# --------------------------------------------------------------------------- - - -class TestSave: - """Tests for APIBasedExtensionService.save.""" - - @patch("services.api_based_extension_service.encrypt_token", return_value="encrypted-key") - @patch("services.api_based_extension_service.db") - @patch.object(APIBasedExtensionService, "_validation") - def test_save_new_record_encrypts_key_and_commits(self, mock_validation, mock_db, mock_encrypt): - """Happy path: validation passes, key is encrypted, record is added and committed.""" - ext = _make_extension(id_=None, api_key="plain-key-123") - - result = APIBasedExtensionService.save(ext) - - mock_validation.assert_called_once_with(ext) - mock_encrypt.assert_called_once_with(ext.tenant_id, "plain-key-123") - assert ext.api_key == "encrypted-key" - mock_db.session.add.assert_called_once_with(ext) - mock_db.session.commit.assert_called_once() - assert result is ext - - @patch("services.api_based_extension_service.encrypt_token", return_value="encrypted-key") - @patch("services.api_based_extension_service.db") - @patch.object(APIBasedExtensionService, "_validation", side_effect=ValueError("name must not be empty")) - def test_save_raises_when_validation_fails(self, mock_validation, mock_db, mock_encrypt): - """If _validation raises, save should propagate the error without touching the DB.""" - ext = _make_extension(name="") - - with pytest.raises(ValueError, match="name must not be empty"): - APIBasedExtensionService.save(ext) - - mock_db.session.add.assert_not_called() - mock_db.session.commit.assert_not_called() - - -# --------------------------------------------------------------------------- -# Tests: delete -# --------------------------------------------------------------------------- - - -class TestDelete: - """Tests for APIBasedExtensionService.delete.""" - - @patch("services.api_based_extension_service.db") - def test_delete_removes_record_and_commits(self, mock_db): - """delete() must call session.delete with the extension and then commit.""" - ext = _make_extension(id_="delete-me") - - APIBasedExtensionService.delete(ext) - - mock_db.session.delete.assert_called_once_with(ext) - mock_db.session.commit.assert_called_once() - - -# --------------------------------------------------------------------------- -# Tests: get_with_tenant_id -# --------------------------------------------------------------------------- - - -class TestGetWithTenantId: - """Tests for APIBasedExtensionService.get_with_tenant_id.""" - - @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") - @patch("services.api_based_extension_service.db") - def test_returns_extension_with_decrypted_key(self, mock_db, mock_decrypt): - """Found extension has its api_key decrypted before being returned.""" - ext = _make_extension(id_="ext-123", api_key="enc-key") - - (mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value) = ext - - result = APIBasedExtensionService.get_with_tenant_id("tenant-001", "ext-123") - - assert result is ext - assert ext.api_key == "decrypted-key" - mock_decrypt.assert_called_once_with(ext.tenant_id, "enc-key") - - @patch("services.api_based_extension_service.db") - def test_raises_value_error_when_not_found(self, mock_db): - """Raises ValueError when no matching extension exists.""" - (mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value) = None - - with pytest.raises(ValueError, match="API based extension is not found"): - APIBasedExtensionService.get_with_tenant_id("tenant-001", "non-existent") - - @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") - @patch("services.api_based_extension_service.db") - def test_queries_with_correct_tenant_and_extension_id(self, mock_db, mock_decrypt): - """Verifies both tenant_id and extension id are used in the query.""" - ext = _make_extension(id_="ext-abc") - chain = mock_db.session.query.return_value - chain.filter_by.return_value.filter_by.return_value.first.return_value = ext - - APIBasedExtensionService.get_with_tenant_id("tenant-002", "ext-abc") - - # First filter_by call uses tenant_id - chain.filter_by.assert_called_once_with(tenant_id="tenant-002") - # Second filter_by call uses id - chain.filter_by.return_value.filter_by.assert_called_once_with(id="ext-abc") - - -# --------------------------------------------------------------------------- -# Tests: _validation (new record — id is falsy) -# --------------------------------------------------------------------------- - - -class TestValidationNewRecord: - """Tests for _validation() with a brand-new record (no id).""" - - def _build_mock_db(self, name_exists: bool = False): - mock_db = MagicMock() - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = ( - MagicMock() if name_exists else None - ) - return mock_db - - @patch.object(APIBasedExtensionService, "_ping_connection") - @patch("services.api_based_extension_service.db") - def test_valid_new_extension_passes(self, mock_db, mock_ping): - """A new record with all valid fields should pass without exceptions.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, name="valid-ext", api_key="longenoughkey") - - # Should not raise - APIBasedExtensionService._validation(ext) - mock_ping.assert_called_once_with(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_name_is_empty(self, mock_db): - """Empty name raises ValueError.""" - ext = _make_extension(id_=None, name="") - with pytest.raises(ValueError, match="name must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_name_is_none(self, mock_db): - """None name raises ValueError.""" - ext = _make_extension(id_=None, name=None) - with pytest.raises(ValueError, match="name must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_name_already_exists_for_new_record(self, mock_db): - """A new record whose name already exists raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = ( - MagicMock() - ) - ext = _make_extension(id_=None, name="duplicate-name") - - with pytest.raises(ValueError, match="name must be unique, it is already existed"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_endpoint_is_empty(self, mock_db): - """Empty api_endpoint raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_endpoint="") - - with pytest.raises(ValueError, match="api_endpoint must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_endpoint_is_none(self, mock_db): - """None api_endpoint raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_endpoint=None) - - with pytest.raises(ValueError, match="api_endpoint must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_key_is_empty(self, mock_db): - """Empty api_key raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_key="") - - with pytest.raises(ValueError, match="api_key must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_key_is_none(self, mock_db): - """None api_key raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_key=None) - - with pytest.raises(ValueError, match="api_key must not be empty"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_key_too_short(self, mock_db): - """api_key shorter than 5 characters raises ValueError.""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_key="abc") - - with pytest.raises(ValueError, match="api_key must be at least 5 characters"): - APIBasedExtensionService._validation(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_api_key_exactly_four_chars(self, mock_db): - """api_key with exactly 4 characters raises ValueError (boundary condition).""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_key="1234") - - with pytest.raises(ValueError, match="api_key must be at least 5 characters"): - APIBasedExtensionService._validation(ext) - - @patch.object(APIBasedExtensionService, "_ping_connection") - @patch("services.api_based_extension_service.db") - def test_api_key_exactly_five_chars_is_accepted(self, mock_db, mock_ping): - """api_key with exactly 5 characters should pass (boundary condition).""" - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None - ext = _make_extension(id_=None, api_key="12345") - - # Should not raise - APIBasedExtensionService._validation(ext) - - -# --------------------------------------------------------------------------- -# Tests: _validation (existing record — id is truthy) -# --------------------------------------------------------------------------- - - -class TestValidationExistingRecord: - """Tests for _validation() with an existing record (id is set).""" - - @patch.object(APIBasedExtensionService, "_ping_connection") - @patch("services.api_based_extension_service.db") - def test_valid_existing_extension_passes(self, mock_db, mock_ping): - """An existing record whose name is unique (excluding self) should pass.""" - # .where(...).first() → None means no *other* record has that name - ( - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.where.return_value.first.return_value - ) = None - ext = _make_extension(id_="existing-id", name="unique-name", api_key="longenoughkey") - - # Should not raise - APIBasedExtensionService._validation(ext) - mock_ping.assert_called_once_with(ext) - - @patch("services.api_based_extension_service.db") - def test_raises_if_existing_record_name_conflicts_with_another(self, mock_db): - """Existing record cannot use a name already owned by a different record.""" - ( - mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.where.return_value.first.return_value - ) = MagicMock() - ext = _make_extension(id_="existing-id", name="taken-name") - - with pytest.raises(ValueError, match="name must be unique, it is already existed"): - APIBasedExtensionService._validation(ext) - - -# --------------------------------------------------------------------------- -# Tests: _ping_connection -# --------------------------------------------------------------------------- - - -class TestPingConnection: - """Tests for APIBasedExtensionService._ping_connection.""" - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_successful_ping_returns_pong(self, mock_requestor_class): - """When the endpoint returns {"result": "pong"}, no exception is raised.""" - mock_client = MagicMock() - mock_client.request.return_value = {"result": "pong"} - mock_requestor_class.return_value = mock_client - - ext = _make_extension(api_endpoint="https://ok.example.com", api_key="secret-key") - # Should not raise - APIBasedExtensionService._ping_connection(ext) - - mock_requestor_class.assert_called_once_with(ext.api_endpoint, ext.api_key) - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_wrong_ping_response_raises_value_error(self, mock_requestor_class): - """When the response is not {"result": "pong"}, a ValueError is raised.""" - mock_client = MagicMock() - mock_client.request.return_value = {"result": "error"} - mock_requestor_class.return_value = mock_client - - ext = _make_extension() - with pytest.raises(ValueError, match="connection error"): - APIBasedExtensionService._ping_connection(ext) - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_network_exception_wraps_in_value_error(self, mock_requestor_class): - """Any exception raised during request is wrapped in a ValueError.""" - mock_client = MagicMock() - mock_client.request.side_effect = ConnectionError("network failure") - mock_requestor_class.return_value = mock_client - - ext = _make_extension() - with pytest.raises(ValueError, match="connection error: network failure"): - APIBasedExtensionService._ping_connection(ext) - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_requestor_constructor_exception_wraps_in_value_error(self, mock_requestor_class): - """Exception raised by the requestor constructor itself is wrapped.""" - mock_requestor_class.side_effect = RuntimeError("bad config") - - ext = _make_extension() - with pytest.raises(ValueError, match="connection error: bad config"): - APIBasedExtensionService._ping_connection(ext) - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_missing_result_key_raises_value_error(self, mock_requestor_class): - """A response dict without a 'result' key does not equal 'pong' → raises.""" - mock_client = MagicMock() - mock_client.request.return_value = {} # no 'result' key - mock_requestor_class.return_value = mock_client - - ext = _make_extension() - with pytest.raises(ValueError, match="connection error"): - APIBasedExtensionService._ping_connection(ext) - - @patch("services.api_based_extension_service.APIBasedExtensionRequestor") - def test_uses_ping_extension_point(self, mock_requestor_class): - """The PING extension point is passed to the client.request call.""" - from models.api_based_extension import APIBasedExtensionPoint - - mock_client = MagicMock() - mock_client.request.return_value = {"result": "pong"} - mock_requestor_class.return_value = mock_client - - ext = _make_extension() - APIBasedExtensionService._ping_connection(ext) - - call_kwargs = mock_client.request.call_args - assert call_kwargs.kwargs["point"] == APIBasedExtensionPoint.PING - assert call_kwargs.kwargs["params"] == {} From f5cc1c8b75cf4bf863bc49072c36dd228da4c2ec Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 23 Mar 2026 08:26:31 -0500 Subject: [PATCH 12/34] test: migrate saved message service tests to testcontainers (#33949) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../services/test_saved_message_service.py | 201 +++--- .../services/test_saved_message_service.py | 626 ------------------ 2 files changed, 106 insertions(+), 721 deletions(-) delete mode 100644 api/tests/unit_tests/services/test_saved_message_service.py diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index 94a4e62560..d256c0d90b 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -396,11 +396,6 @@ class TestSavedMessageService: assert "User is required" in str(exc_info.value) - # Verify no database operations were performed - - saved_messages = db_session_with_containers.query(SavedMessage).all() - assert len(saved_messages) == 0 - def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test error handling when saving message with no user. @@ -497,124 +492,140 @@ class TestSavedMessageService: # The message should still exist, only the saved_message should be deleted assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None - def test_pagination_by_last_id_error_no_user( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): - """ - Test error handling when no user is provided. - - This test verifies: - - Proper error handling for missing user - - ValueError is raised when user is None - - No database operations are performed - """ - # Arrange: Create test data - fake = Faker() + def test_save_for_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies): + """Test saving a message for an EndUser.""" app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + end_user = self._create_test_end_user(db_session_with_containers, app) + message = self._create_test_message(db_session_with_containers, app, end_user) - # Act & Assert: Verify proper error handling - with pytest.raises(ValueError) as exc_info: - SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10) + mock_external_service_dependencies["message_service"].get_message.return_value = message - assert "User is required" in str(exc_info.value) + SavedMessageService.save(app_model=app, user=end_user, message_id=message.id) - # Verify no database operations were performed for this specific test - # Note: We don't check total count as other tests may have created data - # Instead, we verify that the error was properly raised - pass - - def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): - """ - Test error handling when saving message with no user. - - This test verifies: - - Method returns early when user is None - - No database operations are performed - - No exceptions are raised - """ - # Arrange: Create test data - fake = Faker() - app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - message = self._create_test_message(db_session_with_containers, app, account) - - # Act: Execute the method under test with None user - result = SavedMessageService.save(app_model=app, user=None, message_id=message.id) - - # Assert: Verify the expected outcomes - assert result is None - - # Verify no saved message was created - - saved_message = ( + saved = ( db_session_with_containers.query(SavedMessage) - .where( - SavedMessage.app_id == app.id, - SavedMessage.message_id == message.id, - ) + .where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id) .first() ) + assert saved is not None + assert saved.created_by == end_user.id + assert saved.created_by_role == "end_user" - assert saved_message is None - - def test_delete_success_existing_message( + def test_save_duplicate_is_idempotent( self, db_session_with_containers: Session, mock_external_service_dependencies ): - """ - Test successful deletion of an existing saved message. - - This test verifies: - - Proper deletion of existing saved message - - Correct database state after deletion - - No errors during deletion process - """ - # Arrange: Create test data - fake = Faker() + """Test that saving an already-saved message does not create a duplicate.""" app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) message = self._create_test_message(db_session_with_containers, app, account) - # Create a saved message first - saved_message = SavedMessage( - app_id=app.id, - message_id=message.id, - created_by_role="account", - created_by=account.id, - ) + mock_external_service_dependencies["message_service"].get_message.return_value = message - db_session_with_containers.add(saved_message) + # Save once + SavedMessageService.save(app_model=app, user=account, message_id=message.id) + # Save again + SavedMessageService.save(app_model=app, user=account, message_id=message.id) + + count = ( + db_session_with_containers.query(SavedMessage) + .where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id) + .count() + ) + assert count == 1 + + def test_delete_without_user_does_nothing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that deleting without a user is a no-op.""" + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + message = self._create_test_message(db_session_with_containers, app, account) + + # Pre-create a saved message + saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="account", created_by=account.id) + db_session_with_containers.add(saved) db_session_with_containers.commit() - # Verify saved message exists + SavedMessageService.delete(app_model=app, user=None, message_id=message.id) + + # Should still exist + assert ( + db_session_with_containers.query(SavedMessage) + .where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id) + .first() + is not None + ) + + def test_delete_non_existent_does_nothing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that deleting a non-existent saved message is a no-op.""" + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Should not raise — use a valid UUID that doesn't exist in DB + from uuid import uuid4 + + SavedMessageService.delete(app_model=app, user=account, message_id=str(uuid4())) + + def test_delete_for_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies): + """Test deleting a saved message for an EndUser.""" + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + end_user = self._create_test_end_user(db_session_with_containers, app) + message = self._create_test_message(db_session_with_containers, app, end_user) + + saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id) + db_session_with_containers.add(saved) + db_session_with_containers.commit() + + SavedMessageService.delete(app_model=app, user=end_user, message_id=message.id) + + assert ( + db_session_with_containers.query(SavedMessage) + .where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id) + .first() + is None + ) + + def test_delete_only_affects_own_saved_messages( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that delete only removes the requesting user's saved message.""" + app, account1 = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + end_user = self._create_test_end_user(db_session_with_containers, app) + message = self._create_test_message(db_session_with_containers, app, account1) + + # Both users save the same message + saved_account = SavedMessage( + app_id=app.id, message_id=message.id, created_by_role="account", created_by=account1.id + ) + saved_end_user = SavedMessage( + app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id + ) + db_session_with_containers.add_all([saved_account, saved_end_user]) + db_session_with_containers.commit() + + # Delete only account1's saved message + SavedMessageService.delete(app_model=app, user=account1, message_id=message.id) + + # Account's saved message should be gone assert ( db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, - SavedMessage.created_by_role == "account", - SavedMessage.created_by == account.id, + SavedMessage.created_by == account1.id, ) .first() - is not None + is None ) - - # Act: Execute the method under test - SavedMessageService.delete(app_model=app, user=account, message_id=message.id) - - # Assert: Verify the expected outcomes - # Check if saved message was deleted from database - deleted_saved_message = ( + # End user's saved message should still exist + assert ( db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, - SavedMessage.created_by_role == "account", - SavedMessage.created_by == account.id, + SavedMessage.created_by == end_user.id, ) .first() + is not None ) - - assert deleted_saved_message is None - - # Verify database state - db_session_with_containers.commit() - # The message should still exist, only the saved_message should be deleted - assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None diff --git a/api/tests/unit_tests/services/test_saved_message_service.py b/api/tests/unit_tests/services/test_saved_message_service.py deleted file mode 100644 index 87b946fe46..0000000000 --- a/api/tests/unit_tests/services/test_saved_message_service.py +++ /dev/null @@ -1,626 +0,0 @@ -""" -Comprehensive unit tests for SavedMessageService. - -This test suite provides complete coverage of saved message operations in Dify, -following TDD principles with the Arrange-Act-Assert pattern. - -## Test Coverage - -### 1. Pagination (TestSavedMessageServicePagination) -Tests saved message listing and pagination: -- Pagination with valid user (Account and EndUser) -- Pagination without user raises ValueError -- Pagination with last_id parameter -- Empty results when no saved messages exist -- Integration with MessageService pagination - -### 2. Save Operations (TestSavedMessageServiceSave) -Tests saving messages: -- Save message for Account user -- Save message for EndUser -- Save without user (no-op) -- Prevent duplicate saves (idempotent) -- Message validation through MessageService - -### 3. Delete Operations (TestSavedMessageServiceDelete) -Tests deleting saved messages: -- Delete saved message for Account user -- Delete saved message for EndUser -- Delete without user (no-op) -- Delete non-existent saved message (no-op) -- Proper database cleanup - -## Testing Approach - -- **Mocking Strategy**: All external dependencies (database, MessageService) are mocked - for fast, isolated unit tests -- **Factory Pattern**: SavedMessageServiceTestDataFactory provides consistent test data -- **Fixtures**: Mock objects are configured per test method -- **Assertions**: Each test verifies return values and side effects - (database operations, method calls) - -## Key Concepts - -**User Types:** -- Account: Workspace members (console users) -- EndUser: API users (end users) - -**Saved Messages:** -- Users can save messages for later reference -- Each user has their own saved message list -- Saving is idempotent (duplicate saves ignored) -- Deletion is safe (non-existent deletes ignored) -""" - -from datetime import UTC, datetime -from unittest.mock import MagicMock, Mock, create_autospec, patch - -import pytest - -from libs.infinite_scroll_pagination import InfiniteScrollPagination -from models import Account -from models.model import App, EndUser, Message -from models.web import SavedMessage -from services.saved_message_service import SavedMessageService - - -class SavedMessageServiceTestDataFactory: - """ - Factory for creating test data and mock objects. - - Provides reusable methods to create consistent mock objects for testing - saved message operations. - """ - - @staticmethod - def create_account_mock(account_id: str = "account-123", **kwargs) -> Mock: - """ - Create a mock Account object. - - Args: - account_id: Unique identifier for the account - **kwargs: Additional attributes to set on the mock - - Returns: - Mock Account object with specified attributes - """ - account = create_autospec(Account, instance=True) - account.id = account_id - for key, value in kwargs.items(): - setattr(account, key, value) - return account - - @staticmethod - def create_end_user_mock(user_id: str = "user-123", **kwargs) -> Mock: - """ - Create a mock EndUser object. - - Args: - user_id: Unique identifier for the end user - **kwargs: Additional attributes to set on the mock - - Returns: - Mock EndUser object with specified attributes - """ - user = create_autospec(EndUser, instance=True) - user.id = user_id - for key, value in kwargs.items(): - setattr(user, key, value) - return user - - @staticmethod - def create_app_mock(app_id: str = "app-123", tenant_id: str = "tenant-123", **kwargs) -> Mock: - """ - Create a mock App object. - - Args: - app_id: Unique identifier for the app - tenant_id: Tenant/workspace identifier - **kwargs: Additional attributes to set on the mock - - Returns: - Mock App object with specified attributes - """ - app = create_autospec(App, instance=True) - app.id = app_id - app.tenant_id = tenant_id - app.name = kwargs.get("name", "Test App") - app.mode = kwargs.get("mode", "chat") - for key, value in kwargs.items(): - setattr(app, key, value) - return app - - @staticmethod - def create_message_mock( - message_id: str = "msg-123", - app_id: str = "app-123", - **kwargs, - ) -> Mock: - """ - Create a mock Message object. - - Args: - message_id: Unique identifier for the message - app_id: Associated app identifier - **kwargs: Additional attributes to set on the mock - - Returns: - Mock Message object with specified attributes - """ - message = create_autospec(Message, instance=True) - message.id = message_id - message.app_id = app_id - message.query = kwargs.get("query", "Test query") - message.answer = kwargs.get("answer", "Test answer") - message.created_at = kwargs.get("created_at", datetime.now(UTC)) - for key, value in kwargs.items(): - setattr(message, key, value) - return message - - @staticmethod - def create_saved_message_mock( - saved_message_id: str = "saved-123", - app_id: str = "app-123", - message_id: str = "msg-123", - created_by: str = "user-123", - created_by_role: str = "account", - **kwargs, - ) -> Mock: - """ - Create a mock SavedMessage object. - - Args: - saved_message_id: Unique identifier for the saved message - app_id: Associated app identifier - message_id: Associated message identifier - created_by: User who saved the message - created_by_role: Role of the user ('account' or 'end_user') - **kwargs: Additional attributes to set on the mock - - Returns: - Mock SavedMessage object with specified attributes - """ - saved_message = create_autospec(SavedMessage, instance=True) - saved_message.id = saved_message_id - saved_message.app_id = app_id - saved_message.message_id = message_id - saved_message.created_by = created_by - saved_message.created_by_role = created_by_role - saved_message.created_at = kwargs.get("created_at", datetime.now(UTC)) - for key, value in kwargs.items(): - setattr(saved_message, key, value) - return saved_message - - -@pytest.fixture -def factory(): - """Provide the test data factory to all tests.""" - return SavedMessageServiceTestDataFactory - - -class TestSavedMessageServicePagination: - """Test saved message pagination operations.""" - - @patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_pagination_with_account_user(self, mock_db_session, mock_message_pagination, factory): - """Test pagination with an Account user.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - - # Create saved messages for this user - saved_messages = [ - factory.create_saved_message_mock( - saved_message_id=f"saved-{i}", - app_id=app.id, - message_id=f"msg-{i}", - created_by=user.id, - created_by_role="account", - ) - for i in range(3) - ] - - # Mock database query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = saved_messages - - # Mock MessageService pagination response - expected_pagination = InfiniteScrollPagination(data=[], limit=20, has_more=False) - mock_message_pagination.return_value = expected_pagination - - # Act - result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=20) - - # Assert - assert result == expected_pagination - mock_db_session.query.assert_called_once_with(SavedMessage) - # Verify MessageService was called with correct message IDs - mock_message_pagination.assert_called_once_with( - app_model=app, - user=user, - last_id=None, - limit=20, - include_ids=["msg-0", "msg-1", "msg-2"], - ) - - @patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_pagination_with_end_user(self, mock_db_session, mock_message_pagination, factory): - """Test pagination with an EndUser.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_end_user_mock() - - # Create saved messages for this end user - saved_messages = [ - factory.create_saved_message_mock( - saved_message_id=f"saved-{i}", - app_id=app.id, - message_id=f"msg-{i}", - created_by=user.id, - created_by_role="end_user", - ) - for i in range(2) - ] - - # Mock database query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = saved_messages - - # Mock MessageService pagination response - expected_pagination = InfiniteScrollPagination(data=[], limit=10, has_more=False) - mock_message_pagination.return_value = expected_pagination - - # Act - result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=10) - - # Assert - assert result == expected_pagination - # Verify correct role was used in query - mock_message_pagination.assert_called_once_with( - app_model=app, - user=user, - last_id=None, - limit=10, - include_ids=["msg-0", "msg-1"], - ) - - def test_pagination_without_user_raises_error(self, factory): - """Test that pagination without user raises ValueError.""" - # Arrange - app = factory.create_app_mock() - - # Act & Assert - with pytest.raises(ValueError, match="User is required"): - SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=20) - - @patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_pagination_with_last_id(self, mock_db_session, mock_message_pagination, factory): - """Test pagination with last_id parameter.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - last_id = "msg-last" - - saved_messages = [ - factory.create_saved_message_mock( - message_id=f"msg-{i}", - app_id=app.id, - created_by=user.id, - ) - for i in range(5) - ] - - # Mock database query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = saved_messages - - # Mock MessageService pagination response - expected_pagination = InfiniteScrollPagination(data=[], limit=10, has_more=True) - mock_message_pagination.return_value = expected_pagination - - # Act - result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=last_id, limit=10) - - # Assert - assert result == expected_pagination - # Verify last_id was passed to MessageService - mock_message_pagination.assert_called_once() - call_args = mock_message_pagination.call_args - assert call_args.kwargs["last_id"] == last_id - - @patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_pagination_with_empty_saved_messages(self, mock_db_session, mock_message_pagination, factory): - """Test pagination when user has no saved messages.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - - # Mock database query returning empty list - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Mock MessageService pagination response - expected_pagination = InfiniteScrollPagination(data=[], limit=20, has_more=False) - mock_message_pagination.return_value = expected_pagination - - # Act - result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=20) - - # Assert - assert result == expected_pagination - # Verify MessageService was called with empty include_ids - mock_message_pagination.assert_called_once_with( - app_model=app, - user=user, - last_id=None, - limit=20, - include_ids=[], - ) - - -class TestSavedMessageServiceSave: - """Test save message operations.""" - - @patch("services.saved_message_service.MessageService.get_message", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_save_message_for_account(self, mock_db_session, mock_get_message, factory): - """Test saving a message for an Account user.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - message = factory.create_message_mock(message_id="msg-123", app_id=app.id) - - # Mock database query - no existing saved message - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Mock MessageService.get_message - mock_get_message.return_value = message - - # Act - SavedMessageService.save(app_model=app, user=user, message_id=message.id) - - # Assert - mock_db_session.add.assert_called_once() - saved_message = mock_db_session.add.call_args[0][0] - assert saved_message.app_id == app.id - assert saved_message.message_id == message.id - assert saved_message.created_by == user.id - assert saved_message.created_by_role == "account" - mock_db_session.commit.assert_called_once() - - @patch("services.saved_message_service.MessageService.get_message", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_save_message_for_end_user(self, mock_db_session, mock_get_message, factory): - """Test saving a message for an EndUser.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_end_user_mock() - message = factory.create_message_mock(message_id="msg-456", app_id=app.id) - - # Mock database query - no existing saved message - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Mock MessageService.get_message - mock_get_message.return_value = message - - # Act - SavedMessageService.save(app_model=app, user=user, message_id=message.id) - - # Assert - mock_db_session.add.assert_called_once() - saved_message = mock_db_session.add.call_args[0][0] - assert saved_message.app_id == app.id - assert saved_message.message_id == message.id - assert saved_message.created_by == user.id - assert saved_message.created_by_role == "end_user" - mock_db_session.commit.assert_called_once() - - @patch("services.saved_message_service.db.session", autospec=True) - def test_save_without_user_does_nothing(self, mock_db_session, factory): - """Test that saving without user is a no-op.""" - # Arrange - app = factory.create_app_mock() - - # Act - SavedMessageService.save(app_model=app, user=None, message_id="msg-123") - - # Assert - mock_db_session.query.assert_not_called() - mock_db_session.add.assert_not_called() - mock_db_session.commit.assert_not_called() - - @patch("services.saved_message_service.MessageService.get_message", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_save_duplicate_message_is_idempotent(self, mock_db_session, mock_get_message, factory): - """Test that saving an already saved message is idempotent.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - message_id = "msg-789" - - # Mock database query - existing saved message found - existing_saved = factory.create_saved_message_mock( - app_id=app.id, - message_id=message_id, - created_by=user.id, - created_by_role="account", - ) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = existing_saved - - # Act - SavedMessageService.save(app_model=app, user=user, message_id=message_id) - - # Assert - no new saved message created - mock_db_session.add.assert_not_called() - mock_db_session.commit.assert_not_called() - mock_get_message.assert_not_called() - - @patch("services.saved_message_service.MessageService.get_message", autospec=True) - @patch("services.saved_message_service.db.session", autospec=True) - def test_save_validates_message_exists(self, mock_db_session, mock_get_message, factory): - """Test that save validates message exists through MessageService.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - message = factory.create_message_mock() - - # Mock database query - no existing saved message - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Mock MessageService.get_message - mock_get_message.return_value = message - - # Act - SavedMessageService.save(app_model=app, user=user, message_id=message.id) - - # Assert - MessageService.get_message was called for validation - mock_get_message.assert_called_once_with(app_model=app, user=user, message_id=message.id) - - -class TestSavedMessageServiceDelete: - """Test delete saved message operations.""" - - @patch("services.saved_message_service.db.session", autospec=True) - def test_delete_saved_message_for_account(self, mock_db_session, factory): - """Test deleting a saved message for an Account user.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - message_id = "msg-123" - - # Mock database query - existing saved message found - saved_message = factory.create_saved_message_mock( - app_id=app.id, - message_id=message_id, - created_by=user.id, - created_by_role="account", - ) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = saved_message - - # Act - SavedMessageService.delete(app_model=app, user=user, message_id=message_id) - - # Assert - mock_db_session.delete.assert_called_once_with(saved_message) - mock_db_session.commit.assert_called_once() - - @patch("services.saved_message_service.db.session", autospec=True) - def test_delete_saved_message_for_end_user(self, mock_db_session, factory): - """Test deleting a saved message for an EndUser.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_end_user_mock() - message_id = "msg-456" - - # Mock database query - existing saved message found - saved_message = factory.create_saved_message_mock( - app_id=app.id, - message_id=message_id, - created_by=user.id, - created_by_role="end_user", - ) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = saved_message - - # Act - SavedMessageService.delete(app_model=app, user=user, message_id=message_id) - - # Assert - mock_db_session.delete.assert_called_once_with(saved_message) - mock_db_session.commit.assert_called_once() - - @patch("services.saved_message_service.db.session", autospec=True) - def test_delete_without_user_does_nothing(self, mock_db_session, factory): - """Test that deleting without user is a no-op.""" - # Arrange - app = factory.create_app_mock() - - # Act - SavedMessageService.delete(app_model=app, user=None, message_id="msg-123") - - # Assert - mock_db_session.query.assert_not_called() - mock_db_session.delete.assert_not_called() - mock_db_session.commit.assert_not_called() - - @patch("services.saved_message_service.db.session", autospec=True) - def test_delete_non_existent_saved_message_does_nothing(self, mock_db_session, factory): - """Test that deleting a non-existent saved message is a no-op.""" - # Arrange - app = factory.create_app_mock() - user = factory.create_account_mock() - message_id = "msg-nonexistent" - - # Mock database query - no saved message found - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Act - SavedMessageService.delete(app_model=app, user=user, message_id=message_id) - - # Assert - no deletion occurred - mock_db_session.delete.assert_not_called() - mock_db_session.commit.assert_not_called() - - @patch("services.saved_message_service.db.session", autospec=True) - def test_delete_only_affects_user_own_saved_messages(self, mock_db_session, factory): - """Test that delete only removes the user's own saved message.""" - # Arrange - app = factory.create_app_mock() - user1 = factory.create_account_mock(account_id="user-1") - message_id = "msg-shared" - - # Mock database query - finds user1's saved message - saved_message = factory.create_saved_message_mock( - app_id=app.id, - message_id=message_id, - created_by=user1.id, - created_by_role="account", - ) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = saved_message - - # Act - SavedMessageService.delete(app_model=app, user=user1, message_id=message_id) - - # Assert - only user1's saved message is deleted - mock_db_session.delete.assert_called_once_with(saved_message) - # Verify the query filters by user - assert mock_query.where.called From 20fc69ae7fe9f5d3b0d57ae231d4b0476f31411d Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Mon, 23 Mar 2026 18:44:46 +0100 Subject: [PATCH 13/34] refactor: use EnumText for WorkflowAppLog.created_from and WorkflowArchiveLog columns (#33954) --- .../apps/workflow/generate_task_pipeline.py | 2 +- api/models/workflow.py | 12 ++++++-- api/tasks/trigger_processing_tasks.py | 2 +- ..._sqlalchemy_api_workflow_run_repository.py | 6 ++-- .../services/test_workflow_app_service.py | 29 ++++++++++--------- 5 files changed, 29 insertions(+), 22 deletions(-) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 96dd8c5445..bd6e2a0302 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -705,7 +705,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): app_id=self._application_generate_entity.app_config.app_id, workflow_id=self._workflow.id, workflow_run_id=workflow_run_id, - created_from=created_from.value, + created_from=created_from, created_by_role=self._created_by_role, created_by=self._user_id, ) diff --git a/api/models/workflow.py b/api/models/workflow.py index 6e8dda429d..334ec42058 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1221,7 +1221,9 @@ class WorkflowAppLog(TypeBase): app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID) - created_from: Mapped[str] = mapped_column(String(255), nullable=False) + created_from: Mapped[WorkflowAppLogCreatedFrom] = mapped_column( + EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=False + ) created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( @@ -1301,10 +1303,14 @@ class WorkflowArchiveLog(TypeBase): log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True) + log_created_from: Mapped[WorkflowAppLogCreatedFrom | None] = mapped_column( + EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=True + ) run_version: Mapped[str] = mapped_column(String(255), nullable=False) - run_status: Mapped[str] = mapped_column(String(255), nullable=False) + run_status: Mapped[WorkflowExecutionStatus] = mapped_column( + EnumText(WorkflowExecutionStatus, length=255), nullable=False + ) run_triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column( EnumText(WorkflowRunTriggeredFrom, length=255), nullable=False ) diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index 75ae1f6316..f8c7964805 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -179,7 +179,7 @@ def _record_trigger_failure_log( app_id=workflow.app_id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value, + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=created_by_role, created_by=created_by, ) diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index c3ed79656f..49b370990a 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -27,7 +27,7 @@ from models.human_input import ( HumanInputFormRecipient, RecipientType, ) -from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun +from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.sqlalchemy_api_workflow_run_repository import ( DifyAPISQLAlchemyWorkflowRunRepository, @@ -218,7 +218,7 @@ class TestDeleteRunsWithRelated: app_id=test_scope.app_id, workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=test_scope.user_id, ) @@ -278,7 +278,7 @@ class TestCountRunsWithRelated: app_id=test_scope.app_id, workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=test_scope.user_id, ) diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 8ab8df2a5a..84ce6364df 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Session from dify_graph.entities.workflow_execution import WorkflowExecutionStatus from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun from models.enums import CreatorUserRole +from models.workflow import WorkflowAppLogCreatedFrom from services.account_service import AccountService, TenantService # Delay import of AppService to avoid circular dependency @@ -221,7 +222,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -357,7 +358,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run_1.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -399,7 +400,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run_2.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -441,7 +442,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run_4.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -521,7 +522,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -627,7 +628,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -732,7 +733,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -860,7 +861,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -902,7 +903,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="web-app", + created_from=WorkflowAppLogCreatedFrom.WEB_APP, created_by_role=CreatorUserRole.END_USER, created_by=end_user.id, ) @@ -1037,7 +1038,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1125,7 +1126,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1279,7 +1280,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1379,7 +1380,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1481,7 +1482,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) From 4a2e9633db85acde1a4b42cfd0dfad0672628d5f Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Mon, 23 Mar 2026 18:46:06 +0100 Subject: [PATCH 14/34] refactor: use EnumText for ApiToken.type (#33961) --- api/controllers/console/apikey.py | 14 ++++++++------ api/controllers/console/datasets/datasets.py | 6 +++--- api/models/enums.py | 7 +++++++ api/models/model.py | 3 ++- .../libs/test_api_token_cache_integration.py | 3 ++- .../unit_tests/controllers/console/test_apikey.py | 5 +++-- 6 files changed, 25 insertions(+), 13 deletions(-) diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 6c54be84a8..783cb5c444 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -9,6 +9,7 @@ from extensions.ext_database import db from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required from models.dataset import Dataset +from models.enums import ApiTokenType from models.model import ApiToken, App from services.api_token_service import ApiTokenCache @@ -47,7 +48,7 @@ def _get_resource(resource_id, tenant_id, resource_model): class BaseApiKeyListResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type: str | None = None + resource_type: ApiTokenType | None = None resource_model: type | None = None resource_id_field: str | None = None token_prefix: str | None = None @@ -91,6 +92,7 @@ class BaseApiKeyListResource(Resource): ) key = ApiToken.generate_api_key(self.token_prefix or "", 24) + assert self.resource_type is not None, "resource_type must be set" api_token = ApiToken() setattr(api_token, self.resource_id_field, resource_id) api_token.tenant_id = current_tenant_id @@ -104,7 +106,7 @@ class BaseApiKeyListResource(Resource): class BaseApiKeyResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type: str | None = None + resource_type: ApiTokenType | None = None resource_model: type | None = None resource_id_field: str | None = None @@ -159,7 +161,7 @@ class AppApiKeyListResource(BaseApiKeyListResource): """Create a new API key for an app""" return super().post(resource_id) - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = App resource_id_field = "app_id" token_prefix = "app-" @@ -175,7 +177,7 @@ class AppApiKeyResource(BaseApiKeyResource): """Delete an API key for an app""" return super().delete(resource_id, api_key_id) - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = App resource_id_field = "app_id" @@ -199,7 +201,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): """Create a new API key for a dataset""" return super().post(resource_id) - resource_type = "dataset" + resource_type = ApiTokenType.DATASET resource_model = Dataset resource_id_field = "dataset_id" token_prefix = "ds-" @@ -215,6 +217,6 @@ class DatasetApiKeyResource(BaseApiKeyResource): """Delete an API key for a dataset""" return super().delete(resource_id, api_key_id) - resource_type = "dataset" + resource_type = ApiTokenType.DATASET resource_model = Dataset resource_id_field = "dataset_id" diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 725a8380cd..fb98932269 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -54,7 +54,7 @@ from fields.document_fields import document_status_fields from libs.login import current_account_with_tenant, login_required from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile from models.dataset import DatasetPermission, DatasetPermissionEnum -from models.enums import SegmentStatus +from models.enums import ApiTokenType, SegmentStatus from models.provider_ids import ModelProviderID from services.api_token_service import ApiTokenCache from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService @@ -777,7 +777,7 @@ class DatasetIndexingStatusApi(Resource): class DatasetApiKeyApi(Resource): max_keys = 10 token_prefix = "dataset-" - resource_type = "dataset" + resource_type = ApiTokenType.DATASET @console_ns.doc("get_dataset_api_keys") @console_ns.doc(description="Get dataset API keys") @@ -826,7 +826,7 @@ class DatasetApiKeyApi(Resource): @console_ns.route("/datasets/api-keys/") class DatasetApiDeleteApi(Resource): - resource_type = "dataset" + resource_type = ApiTokenType.DATASET @console_ns.doc("delete_dataset_api_key") @console_ns.doc(description="Delete dataset API key") diff --git a/api/models/enums.py b/api/models/enums.py index 4849099d30..8aca1df2b4 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -323,3 +323,10 @@ class ProviderQuotaType(StrEnum): if member.value == value: return member raise ValueError(f"No matching enum found for value '{value}'") + + +class ApiTokenType(StrEnum): + """API Token type""" + + APP = "app" + DATASET = "dataset" diff --git a/api/models/model.py b/api/models/model.py index b098966052..331a5b7d8c 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -31,6 +31,7 @@ from .account import Account, Tenant from .base import Base, TypeBase, gen_uuidv4_string from .engine import db from .enums import ( + ApiTokenType, AppMCPServerStatus, AppStatus, BannerStatus, @@ -2095,7 +2096,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field. id = mapped_column(StringUUID, default=lambda: str(uuid4())) app_id = mapped_column(StringUUID, nullable=True) tenant_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(String(16), nullable=False) + type: Mapped[ApiTokenType] = mapped_column(EnumText(ApiTokenType, length=16), nullable=False) token: Mapped[str] = mapped_column(String(255), nullable=False) last_used_at = mapped_column(sa.DateTime, nullable=True) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/tests/integration_tests/libs/test_api_token_cache_integration.py b/api/tests/integration_tests/libs/test_api_token_cache_integration.py index 1d7b835fd2..a942690cbd 100644 --- a/api/tests/integration_tests/libs/test_api_token_cache_integration.py +++ b/api/tests/integration_tests/libs/test_api_token_cache_integration.py @@ -13,6 +13,7 @@ from unittest.mock import patch import pytest from extensions.ext_redis import redis_client +from models.enums import ApiTokenType from models.model import ApiToken from services.api_token_service import ApiTokenCache, CachedApiToken @@ -279,7 +280,7 @@ class TestEndToEndCacheFlow: test_token = ApiToken() test_token.id = "test-e2e-id" test_token.token = test_token_value - test_token.type = test_scope + test_token.type = ApiTokenType.APP test_token.app_id = "test-app" test_token.tenant_id = "test-tenant" test_token.last_used_at = None diff --git a/api/tests/unit_tests/controllers/console/test_apikey.py b/api/tests/unit_tests/controllers/console/test_apikey.py index c18dd044a7..2dff9c4037 100644 --- a/api/tests/unit_tests/controllers/console/test_apikey.py +++ b/api/tests/unit_tests/controllers/console/test_apikey.py @@ -8,6 +8,7 @@ from controllers.console.apikey import ( BaseApiKeyResource, _get_resource, ) +from models.enums import ApiTokenType @pytest.fixture @@ -45,14 +46,14 @@ def bypass_permissions(): class DummyApiKeyListResource(BaseApiKeyListResource): - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = MagicMock() resource_id_field = "app_id" token_prefix = "app-" class DummyApiKeyResource(BaseApiKeyResource): - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = MagicMock() resource_id_field = "app_id" From 3f086b97b6a0aad1d78d367b901aa4e374f3feaf Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 23 Mar 2026 12:46:54 -0500 Subject: [PATCH 15/34] test: remove mock tests superseded by testcontainers (#33957) --- .../test_document_service_display_status.py | 8 + .../test_document_service_display_status.py | 8 - .../services/test_web_conversation_service.py | 259 ------------------ 3 files changed, 8 insertions(+), 267 deletions(-) delete mode 100644 api/tests/unit_tests/services/test_document_service_display_status.py delete mode 100644 api/tests/unit_tests/services/test_web_conversation_service.py diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py index c6aa89c733..47d259d8a0 100644 --- a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py +++ b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py @@ -142,3 +142,11 @@ def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_c rows = db_session_with_containers.scalars(filtered).all() assert {row.id for row in rows} == {doc1.id, doc2.id} + + +def test_normalize_display_status_alias_mapping(): + """Test that normalize_display_status maps aliases correctly.""" + assert DocumentService.normalize_display_status("ACTIVE") == "available" + assert DocumentService.normalize_display_status("enabled") == "available" + assert DocumentService.normalize_display_status("archived") == "archived" + assert DocumentService.normalize_display_status("unknown") is None diff --git a/api/tests/unit_tests/services/test_document_service_display_status.py b/api/tests/unit_tests/services/test_document_service_display_status.py deleted file mode 100644 index cb2e2940c8..0000000000 --- a/api/tests/unit_tests/services/test_document_service_display_status.py +++ /dev/null @@ -1,8 +0,0 @@ -from services.dataset_service import DocumentService - - -def test_normalize_display_status_alias_mapping(): - assert DocumentService.normalize_display_status("ACTIVE") == "available" - assert DocumentService.normalize_display_status("enabled") == "available" - assert DocumentService.normalize_display_status("archived") == "archived" - assert DocumentService.normalize_display_status("unknown") is None diff --git a/api/tests/unit_tests/services/test_web_conversation_service.py b/api/tests/unit_tests/services/test_web_conversation_service.py deleted file mode 100644 index 7687d355e9..0000000000 --- a/api/tests/unit_tests/services/test_web_conversation_service.py +++ /dev/null @@ -1,259 +0,0 @@ -from __future__ import annotations - -from types import SimpleNamespace -from typing import Any, cast -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture - -from core.app.entities.app_invoke_entities import InvokeFrom -from models import Account -from models.model import App, EndUser -from services.web_conversation_service import WebConversationService - - -@pytest.fixture -def app_model() -> App: - return cast(App, SimpleNamespace(id="app-1")) - - -def _account(**kwargs: Any) -> Account: - return cast(Account, SimpleNamespace(**kwargs)) - - -def _end_user(**kwargs: Any) -> EndUser: - return cast(EndUser, SimpleNamespace(**kwargs)) - - -def test_pagination_by_last_id_should_raise_error_when_user_is_none( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - session = MagicMock() - mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id") - - # Act + Assert - with pytest.raises(ValueError, match="User is required"): - WebConversationService.pagination_by_last_id( - session=session, - app_model=app_model, - user=None, - last_id=None, - limit=20, - invoke_from=InvokeFrom.WEB_APP, - ) - - -def test_pagination_by_last_id_should_forward_without_pin_filter_when_pinned_is_none( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - session = MagicMock() - fake_user = _account(id="user-1") - mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id") - mock_pagination.return_value = MagicMock() - - # Act - WebConversationService.pagination_by_last_id( - session=session, - app_model=app_model, - user=fake_user, - last_id="conv-9", - limit=10, - invoke_from=InvokeFrom.WEB_APP, - pinned=None, - ) - - # Assert - call_kwargs = mock_pagination.call_args.kwargs - assert call_kwargs["include_ids"] is None - assert call_kwargs["exclude_ids"] is None - assert call_kwargs["last_id"] == "conv-9" - assert call_kwargs["sort_by"] == "-updated_at" - - -def test_pagination_by_last_id_should_include_only_pinned_ids_when_pinned_true( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - session = MagicMock() - fake_account_cls = type("FakeAccount", (), {}) - fake_user = cast(Account, fake_account_cls()) - fake_user.id = "account-1" - mocker.patch("services.web_conversation_service.Account", fake_account_cls) - mocker.patch("services.web_conversation_service.EndUser", type("FakeEndUser", (), {})) - session.scalars.return_value.all.return_value = ["conv-1", "conv-2"] - mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id") - mock_pagination.return_value = MagicMock() - - # Act - WebConversationService.pagination_by_last_id( - session=session, - app_model=app_model, - user=fake_user, - last_id=None, - limit=20, - invoke_from=InvokeFrom.WEB_APP, - pinned=True, - ) - - # Assert - call_kwargs = mock_pagination.call_args.kwargs - assert call_kwargs["include_ids"] == ["conv-1", "conv-2"] - assert call_kwargs["exclude_ids"] is None - - -def test_pagination_by_last_id_should_exclude_pinned_ids_when_pinned_false( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - session = MagicMock() - fake_end_user_cls = type("FakeEndUser", (), {}) - fake_user = cast(EndUser, fake_end_user_cls()) - fake_user.id = "end-user-1" - mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {})) - mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls) - session.scalars.return_value.all.return_value = ["conv-3"] - mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id") - mock_pagination.return_value = MagicMock() - - # Act - WebConversationService.pagination_by_last_id( - session=session, - app_model=app_model, - user=fake_user, - last_id=None, - limit=20, - invoke_from=InvokeFrom.WEB_APP, - pinned=False, - ) - - # Assert - call_kwargs = mock_pagination.call_args.kwargs - assert call_kwargs["include_ids"] is None - assert call_kwargs["exclude_ids"] == ["conv-3"] - - -def test_pin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None: - # Arrange - mock_db = mocker.patch("services.web_conversation_service.db") - mocker.patch("services.web_conversation_service.ConversationService.get_conversation") - - # Act - WebConversationService.pin(app_model, "conv-1", None) - - # Assert - mock_db.session.add.assert_not_called() - mock_db.session.commit.assert_not_called() - - -def test_pin_should_return_early_when_conversation_is_already_pinned( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - fake_account_cls = type("FakeAccount", (), {}) - fake_user = cast(Account, fake_account_cls()) - fake_user.id = "account-1" - mocker.patch("services.web_conversation_service.Account", fake_account_cls) - mock_db = mocker.patch("services.web_conversation_service.db") - mock_db.session.query.return_value.where.return_value.first.return_value = object() - mock_get_conversation = mocker.patch("services.web_conversation_service.ConversationService.get_conversation") - - # Act - WebConversationService.pin(app_model, "conv-1", fake_user) - - # Assert - mock_get_conversation.assert_not_called() - mock_db.session.add.assert_not_called() - mock_db.session.commit.assert_not_called() - - -def test_pin_should_create_pinned_conversation_when_not_already_pinned( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - fake_account_cls = type("FakeAccount", (), {}) - fake_user = cast(Account, fake_account_cls()) - fake_user.id = "account-2" - mocker.patch("services.web_conversation_service.Account", fake_account_cls) - mock_db = mocker.patch("services.web_conversation_service.db") - mock_db.session.query.return_value.where.return_value.first.return_value = None - mock_conversation = SimpleNamespace(id="conv-2") - mock_get_conversation = mocker.patch( - "services.web_conversation_service.ConversationService.get_conversation", - return_value=mock_conversation, - ) - - # Act - WebConversationService.pin(app_model, "conv-2", fake_user) - - # Assert - mock_get_conversation.assert_called_once_with(app_model=app_model, conversation_id="conv-2", user=fake_user) - added_obj = mock_db.session.add.call_args.args[0] - assert added_obj.app_id == "app-1" - assert added_obj.conversation_id == "conv-2" - assert added_obj.created_by_role == "account" - assert added_obj.created_by == "account-2" - mock_db.session.commit.assert_called_once() - - -def test_unpin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None: - # Arrange - mock_db = mocker.patch("services.web_conversation_service.db") - - # Act - WebConversationService.unpin(app_model, "conv-1", None) - - # Assert - mock_db.session.delete.assert_not_called() - mock_db.session.commit.assert_not_called() - - -def test_unpin_should_return_early_when_conversation_is_not_pinned( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - fake_end_user_cls = type("FakeEndUser", (), {}) - fake_user = cast(EndUser, fake_end_user_cls()) - fake_user.id = "end-user-3" - mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {})) - mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls) - mock_db = mocker.patch("services.web_conversation_service.db") - mock_db.session.query.return_value.where.return_value.first.return_value = None - - # Act - WebConversationService.unpin(app_model, "conv-7", fake_user) - - # Assert - mock_db.session.delete.assert_not_called() - mock_db.session.commit.assert_not_called() - - -def test_unpin_should_delete_pinned_conversation_when_exists( - app_model: App, - mocker: MockerFixture, -) -> None: - # Arrange - fake_end_user_cls = type("FakeEndUser", (), {}) - fake_user = cast(EndUser, fake_end_user_cls()) - fake_user.id = "end-user-4" - mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {})) - mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls) - mock_db = mocker.patch("services.web_conversation_service.db") - pinned_obj = SimpleNamespace(id="pin-1") - mock_db.session.query.return_value.where.return_value.first.return_value = pinned_obj - - # Act - WebConversationService.unpin(app_model, "conv-8", fake_user) - - # Assert - mock_db.session.delete.assert_called_once_with(pinned_obj) - mock_db.session.commit.assert_called_once() From 8ca1ebb96d905c3ff00a916596669c9cd2768a6c Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 23 Mar 2026 12:50:10 -0500 Subject: [PATCH 16/34] test: migrate workflow tools manage service tests to testcontainers (#33955) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../test_workflow_tools_manage_service.py | 109 ++ .../test_workflow_tools_manage_service.py | 955 ------------------ 2 files changed, 109 insertions(+), 955 deletions(-) delete mode 100644 api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 34906a4e54..e3c0749494 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -1043,3 +1043,112 @@ class TestWorkflowToolManageService: # After the fix, this should always be 0 # For now, we document that the record may exist, demonstrating the bug # assert tool_count == 0 # Expected after fix + + def test_delete_workflow_tool_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test successful deletion of a workflow tool.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + tool_name = fake.unique.word() + + WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=tool_name, + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=self._create_test_workflow_tool_parameters(), + ) + + tool = ( + db_session_with_containers.query(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.name == tool_name) + .first() + ) + assert tool is not None + + result = WorkflowToolManageService.delete_workflow_tool(account.id, account.current_tenant.id, tool.id) + + assert result == {"result": "success"} + deleted = ( + db_session_with_containers.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool.id).first() + ) + assert deleted is None + + def test_list_tenant_workflow_tools_empty( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test listing workflow tools when none exist returns empty list.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + result = WorkflowToolManageService.list_tenant_workflow_tools(account.id, account.current_tenant.id) + + assert result == [] + + def test_get_workflow_tool_by_tool_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that get_workflow_tool_by_tool_id raises ValueError when tool not found.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="Tool not found"): + WorkflowToolManageService.get_workflow_tool_by_tool_id(account.id, account.current_tenant.id, fake.uuid4()) + + def test_get_workflow_tool_by_app_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that get_workflow_tool_by_app_id raises ValueError when tool not found.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="Tool not found"): + WorkflowToolManageService.get_workflow_tool_by_app_id(account.id, account.current_tenant.id, fake.uuid4()) + + def test_list_single_workflow_tools_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that list_single_workflow_tools raises ValueError when tool not found.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="not found"): + WorkflowToolManageService.list_single_workflow_tools(account.id, account.current_tenant.id, fake.uuid4()) + + def test_create_workflow_tool_with_labels( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that labels are forwarded to ToolLabelManager when provided.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + result = WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=fake.unique.word(), + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=self._create_test_workflow_tool_parameters(), + labels=["label-1", "label-2"], + ) + + assert result == {"result": "success"} + mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once() diff --git a/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py deleted file mode 100644 index e9bcc89445..0000000000 --- a/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py +++ /dev/null @@ -1,955 +0,0 @@ -""" -Unit tests for services.tools.workflow_tools_manage_service - -Covers WorkflowToolManageService: create, update, list, delete, get, list_single. -""" - -import json -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration -from core.tools.errors import WorkflowToolHumanInputNotSupportedError -from models.model import App -from models.tools import WorkflowToolProvider -from services.tools import workflow_tools_manage_service -from services.tools.workflow_tools_manage_service import WorkflowToolManageService - -# --------------------------------------------------------------------------- -# Shared helpers / fake infrastructure -# --------------------------------------------------------------------------- - - -class DummyWorkflow: - """Minimal in-memory Workflow substitute.""" - - def __init__(self, graph_dict: dict, version: str = "1.0.0") -> None: - self._graph_dict = graph_dict - self.version = version - - @property - def graph_dict(self) -> dict: - return self._graph_dict - - -class FakeQuery: - """Chainable query object that always returns a fixed result.""" - - def __init__(self, result: object) -> None: - self._result = result - - def where(self, *args: object, **kwargs: object) -> "FakeQuery": - return self - - def first(self) -> object: - return self._result - - def delete(self) -> int: - return 1 - - -class DummySession: - """Minimal SQLAlchemy session substitute.""" - - def __init__(self) -> None: - self.added: list[WorkflowToolProvider] = [] - self.committed: bool = False - - def __enter__(self) -> "DummySession": - return self - - def __exit__(self, exc_type: object, exc: object, tb: object) -> bool: - return False - - def add(self, obj: WorkflowToolProvider) -> None: - self.added.append(obj) - - def begin(self) -> "DummySession": - return self - - def commit(self) -> None: - self.committed = True - - -def _build_parameters() -> list[WorkflowToolParameterConfiguration]: - return [ - WorkflowToolParameterConfiguration(name="input", description="input", form=ToolParameter.ToolParameterForm.LLM), - ] - - -def _build_fake_db( - *, - existing_tool: WorkflowToolProvider | None = None, - app: object | None = None, - tool_by_id: WorkflowToolProvider | None = None, -) -> tuple[MagicMock, DummySession]: - """ - Build a fake db object plus a DummySession for Session context-manager. - - query(WorkflowToolProvider) returns existing_tool on first call, - then tool_by_id on subsequent calls (or None if not provided). - query(App) returns app. - """ - call_counts: dict[str, int] = {"wftp": 0} - - def query(model: type) -> FakeQuery: - if model is WorkflowToolProvider: - call_counts["wftp"] += 1 - if call_counts["wftp"] == 1: - return FakeQuery(existing_tool) - return FakeQuery(tool_by_id) - if model is App: - return FakeQuery(app) - return FakeQuery(None) - - fake_db = MagicMock() - fake_db.session = SimpleNamespace(query=query, commit=MagicMock()) - dummy_session = DummySession() - return fake_db, dummy_session - - -# --------------------------------------------------------------------------- -# TestCreateWorkflowTool -# --------------------------------------------------------------------------- - - -class TestCreateWorkflowTool: - """Tests for WorkflowToolManageService.create_workflow_tool.""" - - def test_should_raise_when_human_input_nodes_present(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Human-input nodes must be rejected before any provider is created.""" - # Arrange - workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "n1", "data": {"type": "human-input"}}]}) - app = SimpleNamespace(workflow=workflow) - fake_session = SimpleNamespace(query=lambda m: FakeQuery(None) if m is WorkflowToolProvider else FakeQuery(app)) - monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session) - mock_from_db = MagicMock() - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) - - # Act + Assert - with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: - WorkflowToolManageService.create_workflow_tool( - user_id="user-id", - tenant_id="tenant-id", - workflow_app_id="app-id", - name="tool_name", - label="Tool", - icon={"type": "emoji", "emoji": "🔧"}, - description="desc", - parameters=_build_parameters(), - ) - - assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - mock_from_db.assert_not_called() - - def test_should_raise_when_duplicate_name_or_app_id(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Existing provider with same name or app_id raises ValueError.""" - # Arrange - existing = MagicMock(spec=WorkflowToolProvider) - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(existing)), - ) - - # Act + Assert - with pytest.raises(ValueError, match="already exists"): - WorkflowToolManageService.create_workflow_tool( - user_id="u", - tenant_id="t", - workflow_app_id="app-1", - name="dup", - label="Dup", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the referenced App does not exist.""" - # Arrange - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - if m is WorkflowToolProvider: - return FakeQuery(None) - return FakeQuery(None) # App returns None - - monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) - - # Act + Assert - with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService.create_workflow_tool( - user_id="u", - tenant_id="t", - workflow_app_id="missing-app", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the App has no attached Workflow.""" - # Arrange - app_no_workflow = SimpleNamespace(workflow=None) - - def query(m: type) -> FakeQuery: - if m is WorkflowToolProvider: - return FakeQuery(None) - return FakeQuery(app_no_workflow) - - monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) - - # Act + Assert - with pytest.raises(ValueError, match="Workflow not found"): - WorkflowToolManageService.create_workflow_tool( - user_id="u", - tenant_id="t", - workflow_app_id="app-id", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_from_db_fails(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Exceptions from WorkflowToolProviderController.from_db are wrapped as ValueError.""" - # Arrange - workflow = DummyWorkflow(graph_dict={"nodes": []}) - app = SimpleNamespace(workflow=workflow) - - def query(m: type) -> FakeQuery: - if m is WorkflowToolProvider: - return FakeQuery(None) - return FakeQuery(app) - - fake_db = MagicMock() - fake_db.session = SimpleNamespace(query=query) - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - dummy_session = DummySession() - monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) - monkeypatch.setattr( - workflow_tools_manage_service.WorkflowToolProviderController, - "from_db", - MagicMock(side_effect=RuntimeError("bad config")), - ) - - # Act + Assert - with pytest.raises(ValueError, match="bad config"): - WorkflowToolManageService.create_workflow_tool( - user_id="u", - tenant_id="t", - workflow_app_id="app-id", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_succeed_and_persist_provider(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Happy path: provider is added to session and success dict is returned.""" - # Arrange - workflow = DummyWorkflow(graph_dict={"nodes": []}, version="2.0.0") - app = SimpleNamespace(workflow=workflow) - - def query(m: type) -> FakeQuery: - if m is WorkflowToolProvider: - return FakeQuery(None) - return FakeQuery(app) - - fake_db = MagicMock() - fake_db.session = SimpleNamespace(query=query) - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - dummy_session = DummySession() - monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock()) - - icon = {"type": "emoji", "emoji": "🔧"} - - # Act - result = WorkflowToolManageService.create_workflow_tool( - user_id="user-id", - tenant_id="tenant-id", - workflow_app_id="app-id", - name="tool_name", - label="Tool", - icon=icon, - description="desc", - parameters=_build_parameters(), - ) - - # Assert - assert result == {"result": "success"} - assert len(dummy_session.added) == 1 - created: WorkflowToolProvider = dummy_session.added[0] - assert created.name == "tool_name" - assert created.label == "Tool" - assert created.icon == json.dumps(icon) - assert created.version == "2.0.0" - - def test_should_call_label_manager_when_labels_provided(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Labels are forwarded to ToolLabelManager when provided.""" - # Arrange - workflow = DummyWorkflow(graph_dict={"nodes": []}) - app = SimpleNamespace(workflow=workflow) - - def query(m: type) -> FakeQuery: - if m is WorkflowToolProvider: - return FakeQuery(None) - return FakeQuery(app) - - fake_db = MagicMock() - fake_db.session = SimpleNamespace(query=query) - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - dummy_session = DummySession() - monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock()) - mock_label_mgr = MagicMock() - monkeypatch.setattr(workflow_tools_manage_service.ToolLabelManager, "update_tool_labels", mock_label_mgr) - mock_to_ctrl = MagicMock() - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_controller", mock_to_ctrl - ) - - # Act - WorkflowToolManageService.create_workflow_tool( - user_id="u", - tenant_id="t", - workflow_app_id="app-id", - name="n", - label="L", - icon={}, - description="", - parameters=[], - labels=["tag1", "tag2"], - ) - - # Assert - mock_label_mgr.assert_called_once() - - -# --------------------------------------------------------------------------- -# TestUpdateWorkflowTool -# --------------------------------------------------------------------------- - - -class TestUpdateWorkflowTool: - """Tests for WorkflowToolManageService.update_workflow_tool.""" - - def _make_provider(self) -> WorkflowToolProvider: - p = MagicMock(spec=WorkflowToolProvider) - p.app_id = "app-id" - p.tenant_id = "tenant-id" - return p - - def test_should_raise_when_name_duplicated(self, monkeypatch: pytest.MonkeyPatch) -> None: - """If another tool with the given name already exists, raise ValueError.""" - # Arrange - existing = MagicMock(spec=WorkflowToolProvider) - - def query(m: type) -> FakeQuery: - return FakeQuery(existing) - - monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) - - # Act + Assert - with pytest.raises(ValueError, match="already exists"): - WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="tool-1", - name="dup", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_tool_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the workflow tool to update does not exist.""" - # Arrange - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - # 1st call: name uniqueness check → None (no duplicate) - # 2nd call: fetch tool by id → None (not found) - return FakeQuery(None) - - monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) - - # Act + Assert - with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="missing", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the tool's referenced App has been removed.""" - # Arrange - provider = self._make_provider() - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - if m is WorkflowToolProvider: - # 1st: duplicate name check (None), 2nd: fetch provider - return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) - return FakeQuery(None) # App not found - - monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) - - # Act + Assert - with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="tool-1", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the App exists but has no Workflow.""" - # Arrange - provider = self._make_provider() - app_no_wf = SimpleNamespace(workflow=None) - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - if m is WorkflowToolProvider: - return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) - return FakeQuery(app_no_wf) - - monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) - - # Act + Assert - with pytest.raises(ValueError, match="Workflow not found"): - WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="tool-1", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_raise_when_from_db_fails(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Exceptions from from_db are re-raised as ValueError.""" - # Arrange - provider = self._make_provider() - workflow = DummyWorkflow(graph_dict={"nodes": []}) - app = SimpleNamespace(workflow=workflow) - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - if m is WorkflowToolProvider: - return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) - return FakeQuery(app) - - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=query, commit=MagicMock()), - ) - monkeypatch.setattr( - workflow_tools_manage_service.WorkflowToolProviderController, - "from_db", - MagicMock(side_effect=RuntimeError("from_db error")), - ) - - # Act + Assert - with pytest.raises(ValueError, match="from_db error"): - WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="tool-1", - name="n", - label="L", - icon={}, - description="", - parameters=[], - ) - - def test_should_succeed_and_call_commit(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Happy path: provider fields are updated and session committed.""" - # Arrange - provider = self._make_provider() - workflow = DummyWorkflow(graph_dict={"nodes": []}, version="3.0.0") - app = SimpleNamespace(workflow=workflow) - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - if m is WorkflowToolProvider: - return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) - return FakeQuery(app) - - mock_commit = MagicMock() - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=query, commit=mock_commit), - ) - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock()) - - icon = {"type": "emoji", "emoji": "🛠"} - - # Act - result = WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="tool-1", - name="new_name", - label="New Label", - icon=icon, - description="new desc", - parameters=_build_parameters(), - ) - - # Assert - assert result == {"result": "success"} - mock_commit.assert_called_once() - assert provider.name == "new_name" - assert provider.label == "New Label" - assert provider.icon == json.dumps(icon) - assert provider.version == "3.0.0" - - def test_should_call_label_manager_when_labels_provided(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Labels are forwarded to ToolLabelManager during update.""" - # Arrange - provider = self._make_provider() - workflow = DummyWorkflow(graph_dict={"nodes": []}) - app = SimpleNamespace(workflow=workflow) - call_count = {"n": 0} - - def query(m: type) -> FakeQuery: - call_count["n"] += 1 - if m is WorkflowToolProvider: - return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) - return FakeQuery(app) - - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=query, commit=MagicMock()), - ) - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock()) - mock_label_mgr = MagicMock() - monkeypatch.setattr(workflow_tools_manage_service.ToolLabelManager, "update_tool_labels", mock_label_mgr) - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_controller", MagicMock() - ) - - # Act - WorkflowToolManageService.update_workflow_tool( - user_id="u", - tenant_id="t", - workflow_tool_id="tool-1", - name="n", - label="L", - icon={}, - description="", - parameters=[], - labels=["a"], - ) - - # Assert - mock_label_mgr.assert_called_once() - - -# --------------------------------------------------------------------------- -# TestListTenantWorkflowTools -# --------------------------------------------------------------------------- - - -class TestListTenantWorkflowTools: - """Tests for WorkflowToolManageService.list_tenant_workflow_tools.""" - - def test_should_return_empty_list_when_no_tools(self, monkeypatch: pytest.MonkeyPatch) -> None: - """An empty database yields an empty result list.""" - # Arrange - fake_scalars = MagicMock() - fake_scalars.all.return_value = [] - fake_db = MagicMock() - fake_db.session.scalars.return_value = fake_scalars - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - - # Act - result = WorkflowToolManageService.list_tenant_workflow_tools("u", "t") - - # Assert - assert result == [] - - def test_should_skip_broken_providers_and_log(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Providers that fail to load are logged and skipped.""" - # Arrange - good_provider = MagicMock(spec=WorkflowToolProvider) - good_provider.id = "good-id" - good_provider.app_id = "app-good" - bad_provider = MagicMock(spec=WorkflowToolProvider) - bad_provider.id = "bad-id" - bad_provider.app_id = "app-bad" - - fake_scalars = MagicMock() - fake_scalars.all.return_value = [good_provider, bad_provider] - fake_db = MagicMock() - fake_db.session.scalars.return_value = fake_scalars - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - - good_ctrl = MagicMock() - good_ctrl.provider_id = "good-id" - - def to_controller(provider: WorkflowToolProvider) -> MagicMock: - if provider is bad_provider: - raise RuntimeError("broken provider") - return good_ctrl - - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_controller", to_controller - ) - mock_get_labels = MagicMock(return_value={}) - monkeypatch.setattr(workflow_tools_manage_service.ToolLabelManager, "get_tools_labels", mock_get_labels) - mock_to_user = MagicMock() - mock_to_user.return_value.tools = [] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_user_provider", mock_to_user - ) - monkeypatch.setattr(workflow_tools_manage_service.ToolTransformService, "repack_provider", MagicMock()) - mock_get_tools = MagicMock(return_value=[MagicMock()]) - good_ctrl.get_tools = mock_get_tools - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, "convert_tool_entity_to_api_entity", MagicMock() - ) - - # Act - result = WorkflowToolManageService.list_tenant_workflow_tools("u", "t") - - # Assert - only good provider contributed - assert len(result) == 1 - - def test_should_return_tools_for_all_providers(self, monkeypatch: pytest.MonkeyPatch) -> None: - """All successfully loaded providers appear in the result.""" - # Arrange - provider = MagicMock(spec=WorkflowToolProvider) - provider.id = "p-1" - provider.app_id = "app-1" - - fake_scalars = MagicMock() - fake_scalars.all.return_value = [provider] - fake_db = MagicMock() - fake_db.session.scalars.return_value = fake_scalars - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - - ctrl = MagicMock() - ctrl.provider_id = "p-1" - ctrl.get_tools.return_value = [MagicMock()] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "workflow_provider_to_controller", - MagicMock(return_value=ctrl), - ) - monkeypatch.setattr( - workflow_tools_manage_service.ToolLabelManager, "get_tools_labels", MagicMock(return_value={"p-1": []}) - ) - user_provider = MagicMock() - user_provider.tools = [] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "workflow_provider_to_user_provider", - MagicMock(return_value=user_provider), - ) - monkeypatch.setattr(workflow_tools_manage_service.ToolTransformService, "repack_provider", MagicMock()) - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, "convert_tool_entity_to_api_entity", MagicMock() - ) - - # Act - result = WorkflowToolManageService.list_tenant_workflow_tools("u", "t") - - # Assert - assert len(result) == 1 - assert result[0] is user_provider - - -# --------------------------------------------------------------------------- -# TestDeleteWorkflowTool -# --------------------------------------------------------------------------- - - -class TestDeleteWorkflowTool: - """Tests for WorkflowToolManageService.delete_workflow_tool.""" - - def test_should_delete_and_commit(self, monkeypatch: pytest.MonkeyPatch) -> None: - """delete_workflow_tool queries, deletes, commits, and returns success.""" - # Arrange - mock_query = MagicMock() - mock_query.where.return_value.delete.return_value = 1 - mock_commit = MagicMock() - fake_session = SimpleNamespace(query=lambda m: mock_query, commit=mock_commit) - monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session) - - # Act - result = WorkflowToolManageService.delete_workflow_tool("u", "t", "tool-1") - - # Assert - assert result == {"result": "success"} - mock_commit.assert_called_once() - - -# --------------------------------------------------------------------------- -# TestGetWorkflowToolByToolId / ByAppId -# --------------------------------------------------------------------------- - - -class TestGetWorkflowToolByToolIdAndAppId: - """Tests for get_workflow_tool_by_tool_id and get_workflow_tool_by_app_id.""" - - def test_get_by_tool_id_should_raise_when_db_tool_is_none(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Raises ValueError when no WorkflowToolProvider found by tool id.""" - # Arrange - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(None)), - ) - - # Act + Assert - with pytest.raises(ValueError, match="Tool not found"): - WorkflowToolManageService.get_workflow_tool_by_tool_id("u", "t", "missing") - - def test_get_by_app_id_should_raise_when_db_tool_is_none(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Raises ValueError when no WorkflowToolProvider found by app id.""" - # Arrange - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(None)), - ) - - # Act + Assert - with pytest.raises(ValueError, match="Tool not found"): - WorkflowToolManageService.get_workflow_tool_by_app_id("u", "t", "missing-app") - - -# --------------------------------------------------------------------------- -# TestGetWorkflowTool (private _get_workflow_tool) -# --------------------------------------------------------------------------- - - -class TestGetWorkflowTool: - """Tests for the internal _get_workflow_tool helper.""" - - def test_should_raise_when_db_tool_none(self) -> None: - """_get_workflow_tool raises ValueError when db_tool is None.""" - with pytest.raises(ValueError, match="Tool not found"): - WorkflowToolManageService._get_workflow_tool("t", None) - - def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the corresponding App row is missing.""" - # Arrange - db_tool = MagicMock(spec=WorkflowToolProvider) - db_tool.app_id = "app-1" - db_tool.tenant_id = "t" - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(None)), - ) - - # Act + Assert - with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService._get_workflow_tool("t", db_tool) - - def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when App has no attached Workflow.""" - # Arrange - db_tool = MagicMock(spec=WorkflowToolProvider) - db_tool.app_id = "app-1" - db_tool.tenant_id = "t" - app = SimpleNamespace(workflow=None) - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(app)), - ) - - # Act + Assert - with pytest.raises(ValueError, match="Workflow not found"): - WorkflowToolManageService._get_workflow_tool("t", db_tool) - - def test_should_raise_when_no_workflow_tools(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the controller returns no WorkflowTool instances.""" - # Arrange - db_tool = MagicMock(spec=WorkflowToolProvider) - db_tool.app_id = "app-1" - db_tool.tenant_id = "t" - db_tool.id = "tool-1" - workflow = DummyWorkflow(graph_dict={"nodes": []}) - app = SimpleNamespace(workflow=workflow) - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(app)), - ) - ctrl = MagicMock() - ctrl.get_tools.return_value = [] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "workflow_provider_to_controller", - MagicMock(return_value=ctrl), - ) - - # Act + Assert - with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService._get_workflow_tool("t", db_tool) - - def test_should_return_dict_on_success(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Happy path: returns a dict with name, label, icon, synced, etc.""" - # Arrange - db_tool = MagicMock(spec=WorkflowToolProvider) - db_tool.app_id = "app-1" - db_tool.tenant_id = "t" - db_tool.id = "tool-1" - db_tool.name = "my_tool" - db_tool.label = "My Tool" - db_tool.icon = json.dumps({"emoji": "🔧"}) - db_tool.description = "some desc" - db_tool.privacy_policy = "" - db_tool.version = "1.0" - db_tool.parameter_configurations = [] - workflow = DummyWorkflow(graph_dict={"nodes": []}, version="1.0") - app = SimpleNamespace(workflow=workflow) - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(app)), - ) - - workflow_tool = MagicMock() - workflow_tool.entity.output_schema = {"type": "object"} - ctrl = MagicMock() - ctrl.get_tools.return_value = [workflow_tool] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "workflow_provider_to_controller", - MagicMock(return_value=ctrl), - ) - mock_convert = MagicMock(return_value={"tool": "api_entity"}) - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, "convert_tool_entity_to_api_entity", mock_convert - ) - monkeypatch.setattr( - workflow_tools_manage_service.ToolLabelManager, "get_tool_labels", MagicMock(return_value=[]) - ) - - # Act - result = WorkflowToolManageService._get_workflow_tool("t", db_tool) - - # Assert - assert result["name"] == "my_tool" - assert result["label"] == "My Tool" - assert result["synced"] is True - assert "icon" in result - assert "output_schema" in result - - -# --------------------------------------------------------------------------- -# TestListSingleWorkflowTools -# --------------------------------------------------------------------------- - - -class TestListSingleWorkflowTools: - """Tests for WorkflowToolManageService.list_single_workflow_tools.""" - - def test_should_raise_when_tool_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the specified tool does not exist in DB.""" - # Arrange - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(None)), - ) - - # Act + Assert - with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1") - - def test_should_raise_when_no_workflow_tools(self, monkeypatch: pytest.MonkeyPatch) -> None: - """ValueError when the controller yields no tools for the provider.""" - # Arrange - db_tool = MagicMock(spec=WorkflowToolProvider) - db_tool.id = "tool-1" - db_tool.tenant_id = "t" - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(db_tool)), - ) - ctrl = MagicMock() - ctrl.get_tools.return_value = [] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "workflow_provider_to_controller", - MagicMock(return_value=ctrl), - ) - - # Act + Assert - with pytest.raises(ValueError, match="not found"): - WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1") - - def test_should_return_api_entity_list(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Happy path: returns list with one ToolApiEntity.""" - # Arrange - db_tool = MagicMock(spec=WorkflowToolProvider) - db_tool.id = "tool-1" - db_tool.tenant_id = "t" - monkeypatch.setattr( - workflow_tools_manage_service.db, - "session", - SimpleNamespace(query=lambda m: FakeQuery(db_tool)), - ) - workflow_tool = MagicMock() - ctrl = MagicMock() - ctrl.get_tools.return_value = [workflow_tool] - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "workflow_provider_to_controller", - MagicMock(return_value=ctrl), - ) - api_entity = MagicMock() - monkeypatch.setattr( - workflow_tools_manage_service.ToolTransformService, - "convert_tool_entity_to_api_entity", - MagicMock(return_value=api_entity), - ) - monkeypatch.setattr( - workflow_tools_manage_service.ToolLabelManager, "get_tool_labels", MagicMock(return_value=[]) - ) - - # Act - result = WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1") - - # Assert - assert result == [api_entity] From 75c3ef82d99b71ad30a5eed2fb6182185c229dc6 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Mon, 23 Mar 2026 18:51:10 +0100 Subject: [PATCH 17/34] refactor: use EnumText for TenantCreditPool.pool_type (#33959) --- api/core/provider_manager.py | 4 ++-- api/models/model.py | 5 ++++- api/services/credit_pool_service.py | 6 +++++- .../services/test_credit_pool_service.py | 9 +++++---- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 3c3fbd6dd2..6d2be0ab7a 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -918,11 +918,11 @@ class ProviderManager: trail_pool = CreditPoolService.get_pool( tenant_id=tenant_id, - pool_type=ProviderQuotaType.TRIAL.value, + pool_type=ProviderQuotaType.TRIAL, ) paid_pool = CreditPoolService.get_pool( tenant_id=tenant_id, - pool_type=ProviderQuotaType.PAID.value, + pool_type=ProviderQuotaType.PAID, ) else: trail_pool = None diff --git a/api/models/model.py b/api/models/model.py index 331a5b7d8c..4541a3b23a 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -44,6 +44,7 @@ from .enums import ( MessageChainType, MessageFileBelongsTo, MessageStatus, + ProviderQuotaType, TagType, ) from .provider_ids import GenericProviderID @@ -2491,7 +2492,9 @@ class TenantCreditPool(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial") + pool_type: Mapped[ProviderQuotaType] = mapped_column( + EnumText(ProviderQuotaType, length=40), nullable=False, default=ProviderQuotaType.TRIAL, server_default="trial" + ) quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) created_at: Mapped[datetime] = mapped_column( diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py index 1954602571..2894826935 100644 --- a/api/services/credit_pool_service.py +++ b/api/services/credit_pool_service.py @@ -7,6 +7,7 @@ from configs import dify_config from core.errors.error import QuotaExceededError from extensions.ext_database import db from models import TenantCreditPool +from models.enums import ProviderQuotaType logger = logging.getLogger(__name__) @@ -16,7 +17,10 @@ class CreditPoolService: def create_default_pool(cls, tenant_id: str) -> TenantCreditPool: """create default credit pool for new tenant""" credit_pool = TenantCreditPool( - tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial" + tenant_id=tenant_id, + quota_limit=dify_config.HOSTED_POOL_CREDITS, + quota_used=0, + pool_type=ProviderQuotaType.TRIAL, ) db.session.add(credit_pool) db.session.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py index 25de0588fa..0f63d98642 100644 --- a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py +++ b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py @@ -6,6 +6,7 @@ import pytest from core.errors.error import QuotaExceededError from models import TenantCreditPool +from models.enums import ProviderQuotaType from services.credit_pool_service import CreditPoolService @@ -20,7 +21,7 @@ class TestCreditPoolService: assert isinstance(pool, TenantCreditPool) assert pool.tenant_id == tenant_id - assert pool.pool_type == "trial" + assert pool.pool_type == ProviderQuotaType.TRIAL assert pool.quota_used == 0 assert pool.quota_limit > 0 @@ -28,14 +29,14 @@ class TestCreditPoolService: tenant_id = self._create_tenant_id() CreditPoolService.create_default_pool(tenant_id) - result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type="trial") + result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL) assert result is not None assert result.tenant_id == tenant_id - assert result.pool_type == "trial" + assert result.pool_type == ProviderQuotaType.TRIAL def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers): - result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type="trial") + result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL) assert result is None From dd4f504b39d1afbfeaf51199789da779e006f654 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Mon, 23 Mar 2026 18:53:05 +0100 Subject: [PATCH 18/34] refactor: select in remaining console app controllers (#33969) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/conversation.py | 10 +++------- api/controllers/console/app/generator.py | 2 +- api/controllers/console/app/mcp_server.py | 14 +++++++------- api/controllers/console/app/model_config.py | 4 +--- api/controllers/console/app/site.py | 5 +++-- api/controllers/console/app/wraps.py | 10 +++++----- .../controllers/console/app/test_app_apis.py | 8 ++------ .../console/app/test_conversation_api.py | 12 ++---------- .../app/test_conversation_read_timestamp.py | 2 +- .../controllers/console/app/test_generator_api.py | 12 ++++-------- .../console/app/test_model_config_api.py | 5 +---- .../controllers/console/app/test_wraps.py | 8 ++------ 12 files changed, 32 insertions(+), 60 deletions(-) diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 74750981dd..d329d22309 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -458,9 +458,7 @@ class ChatConversationApi(Resource): args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore subquery = ( - db.session.query( - Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id") - ) + sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")) .outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id) .subquery() ) @@ -595,10 +593,8 @@ class ChatConversationDetailApi(Resource): def _get_conversation(app_model, conversation_id): current_user, _ = current_account_with_tenant() - conversation = ( - db.session.query(Conversation) - .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) - .first() + conversation = db.session.scalar( + sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1) ) if not conversation: diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index af4ac450bb..442d0d2324 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -168,7 +168,7 @@ class InstructionGenerateApi(Resource): try: # Generate from nothing for a workflow node if (args.current in (code_template, "")) and args.node_id != "": - app = db.session.query(App).where(App.id == args.flow_id).first() + app = db.session.get(App, args.flow_id) if not app: return {"error": f"app {args.flow_id} not found"}, 400 workflow = WorkflowService().get_draft_workflow(app_model=app) diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 4b20418b53..412fc8795a 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -2,6 +2,7 @@ import json from flask_restx import Resource, marshal_with from pydantic import BaseModel, Field +from sqlalchemy import select from werkzeug.exceptions import NotFound from controllers.console import console_ns @@ -47,7 +48,7 @@ class AppMCPServerController(Resource): @get_app_model @marshal_with(app_server_model) def get(self, app_model): - server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first() + server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1)) return server @console_ns.doc("create_app_mcp_server") @@ -98,7 +99,7 @@ class AppMCPServerController(Resource): @edit_permission_required def put(self, app_model): payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {}) - server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first() + server = db.session.get(AppMCPServer, payload.id) if not server: raise NotFound() @@ -135,11 +136,10 @@ class AppMCPServerRefreshController(Resource): @edit_permission_required def get(self, server_id): _, current_tenant_id = current_account_with_tenant() - server = ( - db.session.query(AppMCPServer) - .where(AppMCPServer.id == server_id) - .where(AppMCPServer.tenant_id == current_tenant_id) - .first() + server = db.session.scalar( + select(AppMCPServer) + .where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id) + .limit(1) ) if not server: raise NotFound() diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index a85e54fb51..e9bd30ba7e 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -69,9 +69,7 @@ class ModelConfigResource(Resource): if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent: # get original app model config - original_app_model_config = ( - db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first() - ) + original_app_model_config = db.session.get(AppModelConfig, app_model.app_model_config_id) if original_app_model_config is None: raise ValueError("Original app model config not found") agent_mode = original_app_model_config.agent_mode_dict diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index db218d8b81..7f44a99ff1 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -2,6 +2,7 @@ from typing import Literal from flask_restx import Resource, marshal_with from pydantic import BaseModel, Field, field_validator +from sqlalchemy import select from werkzeug.exceptions import NotFound from constants.languages import supported_language @@ -75,7 +76,7 @@ class AppSite(Resource): def post(self, app_model): args = AppSiteUpdatePayload.model_validate(console_ns.payload or {}) current_user, _ = current_account_with_tenant() - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise NotFound @@ -124,7 +125,7 @@ class AppSiteAccessTokenReset(Resource): @marshal_with(app_site_model) def post(self, app_model): current_user, _ = current_account_with_tenant() - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise NotFound diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index e687d980fa..493022ffea 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -2,6 +2,8 @@ from collections.abc import Callable from functools import wraps from typing import ParamSpec, TypeVar, Union +from sqlalchemy import select + from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from libs.login import current_account_with_tenant @@ -15,16 +17,14 @@ R1 = TypeVar("R1") def _load_app_model(app_id: str) -> App | None: _, current_tenant_id = current_account_with_tenant() - app_model = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app_model = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) return app_model def _load_app_model_with_trial(app_id: str) -> App | None: - app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first() + app_model = db.session.scalar(select(App).where(App.id == app_id, App.status == "normal").limit(1)) return app_model diff --git a/api/tests/unit_tests/controllers/console/app/test_app_apis.py b/api/tests/unit_tests/controllers/console/app/test_app_apis.py index 60b8ee96fe..beb8ff55a5 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_apis.py @@ -281,12 +281,10 @@ class TestSiteEndpoints: method = _unwrap(api.post) site = MagicMock() - query = MagicMock() - query.where.return_value.first.return_value = site monkeypatch.setattr( site_module.db, "session", - MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None), + MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None), ) monkeypatch.setattr( site_module, @@ -305,12 +303,10 @@ class TestSiteEndpoints: method = _unwrap(api.post) site = MagicMock() - query = MagicMock() - query.where.return_value.first.return_value = site monkeypatch.setattr( site_module.db, "session", - MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None), + MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None), ) monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code") monkeypatch.setattr( diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py index 5db8e5c332..11b3b3470d 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py @@ -82,12 +82,8 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None: conversation = SimpleNamespace(id="c1", app_id="app-1") - query = MagicMock() - query.where.return_value = query - query.first.return_value = conversation - session = MagicMock() - session.query.return_value = query + session.scalar.return_value = conversation monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) monkeypatch.setattr(conversation_module.db, "session", session) @@ -101,12 +97,8 @@ def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> No def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None: - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - session = MagicMock() - session.query.return_value = query + session.scalar.return_value = None monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) monkeypatch.setattr(conversation_module.db, "session", session) diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py index 460da06ecc..f588ab261d 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py @@ -24,7 +24,7 @@ def test_get_conversation_mark_read_keeps_updated_at_unchanged(): ), patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session, ): - mock_session.query.return_value.where.return_value.first.return_value = conversation + mock_session.scalar.return_value = conversation _get_conversation(app_model, "conversation-id") diff --git a/api/tests/unit_tests/controllers/console/app/test_generator_api.py b/api/tests/unit_tests/controllers/console/app/test_generator_api.py index f83bc18da3..e64c508b82 100644 --- a/api/tests/unit_tests/controllers/console/app/test_generator_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_generator_api.py @@ -73,8 +73,7 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: None)) with app.test_request_context( "/console/api/instruction-generate", @@ -99,8 +98,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) _install_workflow_service(monkeypatch, workflow=None) with app.test_request_context( @@ -126,8 +124,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) workflow = SimpleNamespace(graph_dict={"nodes": []}) _install_workflow_service(monkeypatch, workflow=workflow) @@ -155,8 +152,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) workflow = SimpleNamespace( graph_dict={ diff --git a/api/tests/unit_tests/controllers/console/app/test_model_config_api.py b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py index 61d92bb5c7..a0e2edb8cf 100644 --- a/api/tests/unit_tests/controllers/console/app/test_model_config_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py @@ -92,10 +92,7 @@ def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatc ) session = MagicMock() - query = MagicMock() - query.where.return_value = query - query.first.return_value = original_config - session.query.return_value = query + session.get.return_value = original_config monkeypatch.setattr(model_config_module.db, "session", session) monkeypatch.setattr( diff --git a/api/tests/unit_tests/controllers/console/app/test_wraps.py b/api/tests/unit_tests/controllers/console/app/test_wraps.py index 7664e492da..b5f751f5a5 100644 --- a/api/tests/unit_tests/controllers/console/app/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/app/test_wraps.py @@ -11,10 +11,8 @@ from models.model import AppMode def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None: app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1")) - monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model)) @wraps_module.get_app_model def handler(app_model): @@ -25,10 +23,8 @@ def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None: def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None: app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1")) - monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model)) @wraps_module.get_app_model(mode=[AppMode.COMPLETION]) def handler(app_model): From 0492ed703457f186b5b7d29d4d8e813d088539c4 Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 23 Mar 2026 12:54:33 -0500 Subject: [PATCH 19/34] test: migrate api tools manage service tests to testcontainers (#33956) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../tools/test_api_tools_manage_service.py | 148 ++++ .../tools/test_api_tools_manage_service.py | 643 ------------------ 2 files changed, 148 insertions(+), 643 deletions(-) delete mode 100644 api/tests/unit_tests/services/tools/test_api_tools_manage_service.py diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index bffdca623a..d3e765055a 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -536,3 +536,151 @@ class TestApiToolManageService: # Verify mock interactions mock_external_service_dependencies["encrypter"].assert_called_once() mock_external_service_dependencies["provider_controller"].from_db.assert_called_once() + + def test_delete_api_tool_provider_success( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test successful deletion of an API tool provider.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + schema = self._create_test_openapi_schema() + provider_name = fake.unique.word() + + ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + icon={"content": "🔧", "background": "#FFF"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy="", + custom_disclaimer="", + labels=[], + ) + + provider = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) + .first() + ) + assert provider is not None + + result = ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, provider_name) + + assert result == {"result": "success"} + deleted = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) + .first() + ) + assert deleted is None + + def test_delete_api_tool_provider_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test deletion raises ValueError when provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="you have not added provider"): + ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent") + + def test_update_api_tool_provider_not_found( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test update raises ValueError when original provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="does not exists"): + ApiToolManageService.update_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name="new-name", + original_provider="nonexistent", + icon={}, + credentials={"auth_type": "none"}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema=self._create_test_openapi_schema(), + privacy_policy=None, + custom_disclaimer="", + labels=[], + ) + + def test_update_api_tool_provider_missing_auth_type( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test update raises ValueError when auth_type is missing from credentials.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + schema = self._create_test_openapi_schema() + provider_name = fake.unique.word() + + ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + icon={"content": "🔧", "background": "#FFF"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy="", + custom_disclaimer="", + labels=[], + ) + + with pytest.raises(ValueError, match="auth_type is required"): + ApiToolManageService.update_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + original_provider=provider_name, + icon={}, + credentials={}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy=None, + custom_disclaimer="", + labels=[], + ) + + def test_list_api_tool_provider_tools_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test listing tools raises ValueError when provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="you have not added provider"): + ApiToolManageService.list_api_tool_provider_tools(account.id, tenant.id, "nonexistent") + + def test_test_api_tool_preview_invalid_schema_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test preview raises ValueError for invalid schema type.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="invalid schema type"): + ApiToolManageService.test_api_tool_preview( + tenant_id=tenant.id, + provider_name="provider-a", + tool_name="tool-a", + credentials={"auth_type": "none"}, + parameters={}, + schema_type="bad-schema-type", + schema="schema", + ) diff --git a/api/tests/unit_tests/services/tools/test_api_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_api_tools_manage_service.py deleted file mode 100644 index ce44818886..0000000000 --- a/api/tests/unit_tests/services/tools/test_api_tools_manage_service.py +++ /dev/null @@ -1,643 +0,0 @@ -from __future__ import annotations - -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture - -from core.tools.entities.tool_entities import ApiProviderSchemaType -from services.tools.api_tools_manage_service import ApiToolManageService - - -@pytest.fixture -def mock_db(mocker: MockerFixture) -> MagicMock: - # Arrange - mocked_db = mocker.patch("services.tools.api_tools_manage_service.db") - mocked_db.session = MagicMock() - return mocked_db - - -def _tool_bundle(operation_id: str = "tool-1") -> SimpleNamespace: - return SimpleNamespace(operation_id=operation_id) - - -def test_parser_api_schema_should_return_schema_payload_when_schema_is_valid(mocker: MockerFixture) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI.value), - ) - - # Act - result = ApiToolManageService.parser_api_schema("valid-schema") - - # Assert - assert result["schema_type"] == ApiProviderSchemaType.OPENAPI.value - assert len(result["credentials_schema"]) == 3 - assert "warning" in result - - -def test_parser_api_schema_should_raise_value_error_when_parser_raises(mocker: MockerFixture) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - side_effect=RuntimeError("bad schema"), - ) - - # Act + Assert - with pytest.raises(ValueError, match="invalid schema: invalid schema: bad schema"): - ApiToolManageService.parser_api_schema("invalid") - - -def test_convert_schema_to_tool_bundles_should_return_tool_bundles_when_valid(mocker: MockerFixture) -> None: - # Arrange - expected = ([_tool_bundle("a"), _tool_bundle("b")], ApiProviderSchemaType.SWAGGER) - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=expected, - ) - extra_info: dict[str, str] = {} - - # Act - result = ApiToolManageService.convert_schema_to_tool_bundles("schema", extra_info=extra_info) - - # Assert - assert result == expected - - -def test_convert_schema_to_tool_bundles_should_raise_value_error_when_parser_fails(mocker: MockerFixture) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - side_effect=ValueError("parse failed"), - ) - - # Act + Assert - with pytest.raises(ValueError, match="invalid schema: parse failed"): - ApiToolManageService.convert_schema_to_tool_bundles("schema") - - -def test_create_api_tool_provider_should_raise_error_when_provider_already_exists( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = object() - - # Act + Assert - with pytest.raises(ValueError, match="provider provider-a already exists"): - ApiToolManageService.create_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name=" provider-a ", - icon={"emoji": "X"}, - credentials={"auth_type": "none"}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy="privacy", - custom_disclaimer="custom", - labels=[], - ) - - -def test_create_api_tool_provider_should_raise_error_when_tool_count_exceeds_limit( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - many_tools = [_tool_bundle(str(i)) for i in range(101)] - mocker.patch.object( - ApiToolManageService, - "convert_schema_to_tool_bundles", - return_value=(many_tools, ApiProviderSchemaType.OPENAPI), - ) - - # Act + Assert - with pytest.raises(ValueError, match="the number of apis should be less than 100"): - ApiToolManageService.create_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-a", - icon={"emoji": "X"}, - credentials={"auth_type": "none"}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy="privacy", - custom_disclaimer="custom", - labels=[], - ) - - -def test_create_api_tool_provider_should_raise_error_when_auth_type_is_missing( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - mocker.patch.object( - ApiToolManageService, - "convert_schema_to_tool_bundles", - return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), - ) - - # Act + Assert - with pytest.raises(ValueError, match="auth_type is required"): - ApiToolManageService.create_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-a", - icon={"emoji": "X"}, - credentials={}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy="privacy", - custom_disclaimer="custom", - labels=[], - ) - - -def test_create_api_tool_provider_should_create_provider_when_input_is_valid( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - mocker.patch.object( - ApiToolManageService, - "convert_schema_to_tool_bundles", - return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), - ) - mock_controller = MagicMock() - mocker.patch( - "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", - return_value=mock_controller, - ) - mock_encrypter = MagicMock() - mock_encrypter.encrypt.return_value = {"auth_type": "none"} - mocker.patch( - "services.tools.api_tools_manage_service.create_tool_provider_encrypter", - return_value=(mock_encrypter, MagicMock()), - ) - mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels") - - # Act - result = ApiToolManageService.create_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-a", - icon={"emoji": "X"}, - credentials={"auth_type": "none"}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy="privacy", - custom_disclaimer="custom", - labels=["news"], - ) - - # Assert - assert result == {"result": "success"} - mock_controller.load_bundled_tools.assert_called_once() - mock_db.session.add.assert_called_once() - mock_db.session.commit.assert_called_once() - - -def test_get_api_tool_provider_remote_schema_should_return_schema_when_response_is_valid( - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.get", - return_value=SimpleNamespace(status_code=200, text="schema-content"), - ) - mocker.patch.object(ApiToolManageService, "parser_api_schema", return_value={"ok": True}) - - # Act - result = ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema") - - # Assert - assert result == {"schema": "schema-content"} - - -@pytest.mark.parametrize("status_code", [400, 404, 500]) -def test_get_api_tool_provider_remote_schema_should_raise_error_when_remote_fetch_is_invalid( - status_code: int, - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.get", - return_value=SimpleNamespace(status_code=status_code, text="schema-content"), - ) - mock_logger = mocker.patch("services.tools.api_tools_manage_service.logger") - - # Act + Assert - with pytest.raises(ValueError, match="invalid schema, please check the url you provided"): - ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema") - mock_logger.exception.assert_called_once() - - -def test_list_api_tool_provider_tools_should_raise_error_when_provider_not_found( - mock_db: MagicMock, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="you have not added provider provider-a"): - ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a") - - -def test_list_api_tool_provider_tools_should_return_converted_tools_when_provider_exists( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = SimpleNamespace(tools=[_tool_bundle("tool-a"), _tool_bundle("tool-b")]) - mock_db.session.query.return_value.where.return_value.first.return_value = provider - controller = MagicMock() - mocker.patch( - "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller", - return_value=controller, - ) - mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["search"]) - mock_convert = mocker.patch( - "services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity", - side_effect=[{"name": "tool-a"}, {"name": "tool-b"}], - ) - - # Act - result = ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a") - - # Assert - assert result == [{"name": "tool-a"}, {"name": "tool-b"}] - assert mock_convert.call_count == 2 - - -def test_update_api_tool_provider_should_raise_error_when_original_provider_not_found( - mock_db: MagicMock, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="api provider provider-a does not exists"): - ApiToolManageService.update_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-a", - original_provider="provider-a", - icon={}, - credentials={"auth_type": "none"}, - _schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy=None, - custom_disclaimer="custom", - labels=[], - ) - - -def test_update_api_tool_provider_should_raise_error_when_auth_type_missing( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = SimpleNamespace(credentials={}, name="old") - mock_db.session.query.return_value.where.return_value.first.return_value = provider - mocker.patch.object( - ApiToolManageService, - "convert_schema_to_tool_bundles", - return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), - ) - - # Act + Assert - with pytest.raises(ValueError, match="auth_type is required"): - ApiToolManageService.update_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-a", - original_provider="provider-a", - icon={}, - credentials={}, - _schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy=None, - custom_disclaimer="custom", - labels=[], - ) - - -def test_update_api_tool_provider_should_update_provider_and_preserve_masked_credentials( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = SimpleNamespace( - credentials={"auth_type": "none", "api_key_value": "encrypted-old"}, - name="old", - icon="", - schema="", - description="", - schema_type_str="", - tools_str="", - privacy_policy="", - custom_disclaimer="", - credentials_str="", - ) - mock_db.session.query.return_value.where.return_value.first.return_value = provider - mocker.patch.object( - ApiToolManageService, - "convert_schema_to_tool_bundles", - return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), - ) - controller = MagicMock() - mocker.patch( - "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", - return_value=controller, - ) - cache = MagicMock() - encrypter = MagicMock() - encrypter.decrypt.return_value = {"auth_type": "none", "api_key_value": "plain-old"} - encrypter.mask_plugin_credentials.return_value = {"api_key_value": "***"} - encrypter.encrypt.return_value = {"auth_type": "none", "api_key_value": "encrypted-new"} - mocker.patch( - "services.tools.api_tools_manage_service.create_tool_provider_encrypter", - return_value=(encrypter, cache), - ) - mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels") - - # Act - result = ApiToolManageService.update_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-new", - original_provider="provider-old", - icon={"emoji": "E"}, - credentials={"auth_type": "none", "api_key_value": "***"}, - _schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy="privacy", - custom_disclaimer="custom", - labels=["news"], - ) - - # Assert - assert result == {"result": "success"} - assert provider.name == "provider-new" - assert provider.privacy_policy == "privacy" - assert provider.credentials_str != "" - cache.delete.assert_called_once() - mock_db.session.commit.assert_called_once() - - -def test_delete_api_tool_provider_should_raise_error_when_provider_missing(mock_db: MagicMock) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="you have not added provider provider-a"): - ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a") - - -def test_delete_api_tool_provider_should_delete_provider_when_exists(mock_db: MagicMock) -> None: - # Arrange - provider = object() - mock_db.session.query.return_value.where.return_value.first.return_value = provider - - # Act - result = ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a") - - # Assert - assert result == {"result": "success"} - mock_db.session.delete.assert_called_once_with(provider) - mock_db.session.commit.assert_called_once() - - -def test_get_api_tool_provider_should_delegate_to_tool_manager(mocker: MockerFixture) -> None: - # Arrange - expected = {"provider": "value"} - mock_get = mocker.patch( - "services.tools.api_tools_manage_service.ToolManager.user_get_api_provider", - return_value=expected, - ) - - # Act - result = ApiToolManageService.get_api_tool_provider("user-1", "tenant-1", "provider-a") - - # Assert - assert result == expected - mock_get.assert_called_once_with(provider="provider-a", tenant_id="tenant-1") - - -def test_test_api_tool_preview_should_raise_error_for_invalid_schema_type() -> None: - # Arrange - schema_type = "bad-schema-type" - - # Act + Assert - with pytest.raises(ValueError, match="invalid schema type"): - ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-a", - credentials={"auth_type": "none"}, - parameters={}, - schema_type=schema_type, # type: ignore[arg-type] - schema="schema", - ) - - -def test_test_api_tool_preview_should_raise_error_when_schema_parser_fails(mocker: MockerFixture) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - side_effect=RuntimeError("invalid"), - ) - - # Act + Assert - with pytest.raises(ValueError, match="invalid schema"): - ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-a", - credentials={"auth_type": "none"}, - parameters={}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - ) - - -def test_test_api_tool_preview_should_raise_error_when_tool_name_is_invalid( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), - ) - mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id") - - # Act + Assert - with pytest.raises(ValueError, match="invalid tool name tool-b"): - ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-b", - credentials={"auth_type": "none"}, - parameters={}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - ) - - -def test_test_api_tool_preview_should_raise_error_when_auth_type_missing( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), - ) - mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id") - - # Act + Assert - with pytest.raises(ValueError, match="auth_type is required"): - ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-a", - credentials={}, - parameters={}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - ) - - -def test_test_api_tool_preview_should_return_error_payload_when_tool_validation_raises( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"}) - mock_db.session.query.return_value.where.return_value.first.return_value = db_provider - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), - ) - provider_controller = MagicMock() - tool_obj = MagicMock() - tool_obj.fork_tool_runtime.return_value = tool_obj - tool_obj.validate_credentials.side_effect = ValueError("validation failed") - provider_controller.get_tool.return_value = tool_obj - mocker.patch( - "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", - return_value=provider_controller, - ) - mock_encrypter = MagicMock() - mock_encrypter.decrypt.return_value = {"auth_type": "none"} - mock_encrypter.mask_plugin_credentials.return_value = {} - mocker.patch( - "services.tools.api_tools_manage_service.create_tool_provider_encrypter", - return_value=(mock_encrypter, MagicMock()), - ) - - # Act - result = ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-a", - credentials={"auth_type": "none"}, - parameters={}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - ) - - # Assert - assert result == {"error": "validation failed"} - - -def test_test_api_tool_preview_should_return_result_payload_when_validation_succeeds( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"}) - mock_db.session.query.return_value.where.return_value.first.return_value = db_provider - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), - ) - provider_controller = MagicMock() - tool_obj = MagicMock() - tool_obj.fork_tool_runtime.return_value = tool_obj - tool_obj.validate_credentials.return_value = {"ok": True} - provider_controller.get_tool.return_value = tool_obj - mocker.patch( - "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", - return_value=provider_controller, - ) - mock_encrypter = MagicMock() - mock_encrypter.decrypt.return_value = {"auth_type": "none"} - mock_encrypter.mask_plugin_credentials.return_value = {} - mocker.patch( - "services.tools.api_tools_manage_service.create_tool_provider_encrypter", - return_value=(mock_encrypter, MagicMock()), - ) - - # Act - result = ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-a", - credentials={"auth_type": "none"}, - parameters={"x": "1"}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - ) - - # Assert - assert result == {"result": {"ok": True}} - - -def test_list_api_tools_should_return_all_user_providers_with_converted_tools( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider_one = SimpleNamespace(name="p1") - provider_two = SimpleNamespace(name="p2") - mock_db.session.scalars.return_value.all.return_value = [provider_one, provider_two] - - controller_one = MagicMock() - controller_one.get_tools.return_value = ["tool-a"] - controller_two = MagicMock() - controller_two.get_tools.return_value = ["tool-b", "tool-c"] - - user_provider_one = SimpleNamespace(labels=[], tools=[]) - user_provider_two = SimpleNamespace(labels=[], tools=[]) - - mocker.patch( - "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller", - side_effect=[controller_one, controller_two], - ) - mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["news"]) - mocker.patch( - "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_user_provider", - side_effect=[user_provider_one, user_provider_two], - ) - mocker.patch("services.tools.api_tools_manage_service.ToolTransformService.repack_provider") - mock_convert = mocker.patch( - "services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity", - side_effect=[{"name": "tool-a"}, {"name": "tool-b"}, {"name": "tool-c"}], - ) - - # Act - result = ApiToolManageService.list_api_tools("tenant-1") - - # Assert - assert len(result) == 2 - assert user_provider_one.tools == [{"name": "tool-a"}] - assert user_provider_two.tools == [{"name": "tool-b"}, {"name": "tool-c"}] - assert mock_convert.call_count == 3 From f2c71f3668227097f8d3770d1c28020bf17d51a5 Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 23 Mar 2026 13:15:22 -0500 Subject: [PATCH 20/34] test: migrate oauth server service tests to testcontainers (#33958) --- .../services/test_oauth_server_service.py | 174 ++++++++++++++ .../services/test_oauth_server_service.py | 224 ------------------ 2 files changed, 174 insertions(+), 224 deletions(-) create mode 100644 api/tests/test_containers_integration_tests/services/test_oauth_server_service.py delete mode 100644 api/tests/unit_tests/services/test_oauth_server_service.py diff --git a/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py new file mode 100644 index 0000000000..c146a5924b --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py @@ -0,0 +1,174 @@ +"""Testcontainers integration tests for OAuthServerService.""" + +from __future__ import annotations + +import uuid +from typing import cast +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from werkzeug.exceptions import BadRequest + +from models.model import OAuthProviderApp +from services.oauth_server import ( + OAUTH_ACCESS_TOKEN_EXPIRES_IN, + OAUTH_ACCESS_TOKEN_REDIS_KEY, + OAUTH_AUTHORIZATION_CODE_REDIS_KEY, + OAUTH_REFRESH_TOKEN_EXPIRES_IN, + OAUTH_REFRESH_TOKEN_REDIS_KEY, + OAuthGrantType, + OAuthServerService, +) + + +class TestOAuthServerServiceGetProviderApp: + """DB-backed tests for get_oauth_provider_app.""" + + def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp: + app = OAuthProviderApp( + app_icon="icon.png", + client_id=client_id, + client_secret=str(uuid4()), + app_label={"en-US": "Test OAuth App"}, + redirect_uris=["https://example.com/callback"], + scope="read", + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + return app + + def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers): + client_id = f"client-{uuid4()}" + created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id) + + result = OAuthServerService.get_oauth_provider_app(client_id) + + assert result is not None + assert result.client_id == client_id + assert result.id == created.id + + def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers): + result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}") + + assert result is None + + +class TestOAuthServerServiceTokenOperations: + """Redis-backed tests for token sign/validate operations.""" + + @pytest.fixture + def mock_redis(self): + with patch("services.oauth_server.redis_client") as mock: + yield mock + + def test_sign_authorization_code_stores_and_returns_code(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1") + + assert code == str(deterministic_uuid) + mock_redis.set.assert_called_once_with( + OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=code), + "user-1", + ex=600, + ) + + def test_sign_access_token_raises_bad_request_for_invalid_code(self, mock_redis): + mock_redis.get.return_value = None + + with pytest.raises(BadRequest, match="invalid code"): + OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + code="bad-code", + client_id="client-1", + ) + + def test_sign_access_token_issues_tokens_for_valid_code(self, mock_redis): + token_uuids = [ + uuid.UUID("00000000-0000-0000-0000-000000000201"), + uuid.UUID("00000000-0000-0000-0000-000000000202"), + ] + with patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids): + mock_redis.get.return_value = b"user-1" + + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + code="code-1", + client_id="client-1", + ) + + assert access_token == str(token_uuids[0]) + assert refresh_token == str(token_uuids[1]) + code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1") + mock_redis.delete.assert_called_once_with(code_key) + mock_redis.set.assert_any_call( + OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token), + b"user-1", + ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN, + ) + mock_redis.set.assert_any_call( + OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token), + b"user-1", + ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, + ) + + def test_sign_access_token_raises_bad_request_for_invalid_refresh_token(self, mock_redis): + mock_redis.get.return_value = None + + with pytest.raises(BadRequest, match="invalid refresh token"): + OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.REFRESH_TOKEN, + refresh_token="stale-token", + client_id="client-1", + ) + + def test_sign_access_token_issues_new_token_for_valid_refresh(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + mock_redis.get.return_value = b"user-1" + + access_token, returned_refresh = OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.REFRESH_TOKEN, + refresh_token="refresh-1", + client_id="client-1", + ) + + assert access_token == str(deterministic_uuid) + assert returned_refresh == "refresh-1" + + def test_sign_access_token_returns_none_for_unknown_grant_type(self, mock_redis): + grant_type = cast(OAuthGrantType, "invalid-grant-type") + + result = OAuthServerService.sign_oauth_access_token(grant_type=grant_type, client_id="client-1") + + assert result is None + + def test_sign_refresh_token_stores_with_expected_expiry(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2") + + assert refresh_token == str(deterministic_uuid) + mock_redis.set.assert_called_once_with( + OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token), + "user-2", + ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, + ) + + def test_validate_access_token_returns_none_when_not_found(self, mock_redis): + mock_redis.get.return_value = None + + result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token") + + assert result is None + + def test_validate_access_token_loads_user_when_exists(self, mock_redis): + mock_redis.get.return_value = b"user-88" + expected_user = MagicMock() + + with patch("services.oauth_server.AccountService.load_user", return_value=expected_user) as mock_load: + result = OAuthServerService.validate_oauth_access_token("client-1", "access-token") + + assert result is expected_user + mock_load.assert_called_once_with("user-88") diff --git a/api/tests/unit_tests/services/test_oauth_server_service.py b/api/tests/unit_tests/services/test_oauth_server_service.py deleted file mode 100644 index 231ceb74dc..0000000000 --- a/api/tests/unit_tests/services/test_oauth_server_service.py +++ /dev/null @@ -1,224 +0,0 @@ -from __future__ import annotations - -import uuid -from types import SimpleNamespace -from typing import cast -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture -from werkzeug.exceptions import BadRequest - -from services.oauth_server import ( - OAUTH_ACCESS_TOKEN_EXPIRES_IN, - OAUTH_ACCESS_TOKEN_REDIS_KEY, - OAUTH_AUTHORIZATION_CODE_REDIS_KEY, - OAUTH_REFRESH_TOKEN_EXPIRES_IN, - OAUTH_REFRESH_TOKEN_REDIS_KEY, - OAuthGrantType, - OAuthServerService, -) - - -@pytest.fixture -def mock_redis_client(mocker: MockerFixture) -> MagicMock: - return mocker.patch("services.oauth_server.redis_client") - - -@pytest.fixture -def mock_session(mocker: MockerFixture) -> MagicMock: - """Mock the OAuth server Session context manager.""" - mocker.patch("services.oauth_server.db", SimpleNamespace(engine=object())) - session = MagicMock() - session_cm = MagicMock() - session_cm.__enter__.return_value = session - mocker.patch("services.oauth_server.Session", return_value=session_cm) - return session - - -def test_get_oauth_provider_app_should_return_app_when_record_exists(mock_session: MagicMock) -> None: - # Arrange - mock_execute_result = MagicMock() - expected_app = MagicMock() - mock_execute_result.scalar_one_or_none.return_value = expected_app - mock_session.execute.return_value = mock_execute_result - - # Act - result = OAuthServerService.get_oauth_provider_app("client-1") - - # Assert - assert result is expected_app - mock_session.execute.assert_called_once() - mock_execute_result.scalar_one_or_none.assert_called_once() - - -def test_sign_oauth_authorization_code_should_store_code_and_return_value( - mocker: MockerFixture, mock_redis_client: MagicMock -) -> None: - # Arrange - deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111") - mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid) - - # Act - code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1") - - # Assert - expected_code = str(deterministic_uuid) - assert code == expected_code - mock_redis_client.set.assert_called_once_with( - OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=expected_code), - "user-1", - ex=600, - ) - - -def test_sign_oauth_access_token_should_raise_bad_request_when_authorization_code_is_invalid( - mock_redis_client: MagicMock, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - - # Act + Assert - with pytest.raises(BadRequest, match="invalid code"): - OAuthServerService.sign_oauth_access_token( - grant_type=OAuthGrantType.AUTHORIZATION_CODE, - code="bad-code", - client_id="client-1", - ) - - -def test_sign_oauth_access_token_should_issue_access_and_refresh_token_when_authorization_code_is_valid( - mocker: MockerFixture, mock_redis_client: MagicMock -) -> None: - # Arrange - token_uuids = [ - uuid.UUID("00000000-0000-0000-0000-000000000201"), - uuid.UUID("00000000-0000-0000-0000-000000000202"), - ] - mocker.patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids) - mock_redis_client.get.return_value = b"user-1" - code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1") - - # Act - access_token, refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type=OAuthGrantType.AUTHORIZATION_CODE, - code="code-1", - client_id="client-1", - ) - - # Assert - assert access_token == str(token_uuids[0]) - assert refresh_token == str(token_uuids[1]) - mock_redis_client.delete.assert_called_once_with(code_key) - mock_redis_client.set.assert_any_call( - OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token), - b"user-1", - ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN, - ) - mock_redis_client.set.assert_any_call( - OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token), - b"user-1", - ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, - ) - - -def test_sign_oauth_access_token_should_raise_bad_request_when_refresh_token_is_invalid( - mock_redis_client: MagicMock, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - - # Act + Assert - with pytest.raises(BadRequest, match="invalid refresh token"): - OAuthServerService.sign_oauth_access_token( - grant_type=OAuthGrantType.REFRESH_TOKEN, - refresh_token="stale-token", - client_id="client-1", - ) - - -def test_sign_oauth_access_token_should_issue_new_access_token_when_refresh_token_is_valid( - mocker: MockerFixture, mock_redis_client: MagicMock -) -> None: - # Arrange - deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301") - mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid) - mock_redis_client.get.return_value = b"user-1" - - # Act - access_token, returned_refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type=OAuthGrantType.REFRESH_TOKEN, - refresh_token="refresh-1", - client_id="client-1", - ) - - # Assert - assert access_token == str(deterministic_uuid) - assert returned_refresh_token == "refresh-1" - mock_redis_client.set.assert_called_once_with( - OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token), - b"user-1", - ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN, - ) - - -def test_sign_oauth_access_token_with_unknown_grant_type_should_return_none() -> None: - # Arrange - grant_type = cast(OAuthGrantType, "invalid-grant-type") - - # Act - result = OAuthServerService.sign_oauth_access_token( - grant_type=grant_type, - client_id="client-1", - ) - - # Assert - assert result is None - - -def test_sign_oauth_refresh_token_should_store_token_with_expected_expiry( - mocker: MockerFixture, mock_redis_client: MagicMock -) -> None: - # Arrange - deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401") - mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid) - - # Act - refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2") - - # Assert - assert refresh_token == str(deterministic_uuid) - mock_redis_client.set.assert_called_once_with( - OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token), - "user-2", - ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, - ) - - -def test_validate_oauth_access_token_should_return_none_when_token_not_found( - mock_redis_client: MagicMock, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - - # Act - result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token") - - # Assert - assert result is None - - -def test_validate_oauth_access_token_should_load_user_when_token_exists( - mocker: MockerFixture, mock_redis_client: MagicMock -) -> None: - # Arrange - mock_redis_client.get.return_value = b"user-88" - expected_user = MagicMock() - mock_load_user = mocker.patch("services.oauth_server.AccountService.load_user", return_value=expected_user) - - # Act - result = OAuthServerService.validate_oauth_access_token("client-1", "access-token") - - # Assert - assert result is expected_user - mock_load_user.assert_called_once_with("user-88") From 5d2cb3cd803c0ef152c4c90a1e2752297538f175 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Mon, 23 Mar 2026 21:37:51 +0100 Subject: [PATCH 21/34] refactor: use EnumText for DocumentSegment.type (#33979) --- api/models/dataset.py | 5 ++++- api/models/enums.py | 7 +++++++ api/services/dataset_service.py | 5 +++-- api/tests/unit_tests/services/segment_service.py | 3 ++- 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/api/models/dataset.py b/api/models/dataset.py index d0163e6984..e3cbbf9cb9 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -43,6 +43,7 @@ from .enums import ( IndexingStatus, ProcessRuleMode, SegmentStatus, + SegmentType, SummaryStatus, ) from .model import App, Tag, TagBinding, UploadFile @@ -998,7 +999,9 @@ class ChildChunk(Base): # indexing fields index_node_id = mapped_column(String(255), nullable=True) index_node_hash = mapped_column(String(255), nullable=True) - type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'")) + type: Mapped[SegmentType] = mapped_column( + EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'") + ) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) diff --git a/api/models/enums.py b/api/models/enums.py index 8aca1df2b4..cdec7b2f12 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -222,6 +222,13 @@ class DatasetMetadataType(StrEnum): TIME = "time" +class SegmentType(StrEnum): + """Document segment type""" + + AUTOMATIC = "automatic" + CUSTOMIZED = "customized" + + class SegmentStatus(StrEnum): """Document segment status""" diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index cdab90a3dc..ba4ab6757f 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -58,6 +58,7 @@ from models.enums import ( IndexingStatus, ProcessRuleMode, SegmentStatus, + SegmentType, ) from models.model import UploadFile from models.provider_ids import ModelProviderID @@ -3786,7 +3787,7 @@ class SegmentService: child_chunk.word_count = len(child_chunk.content) child_chunk.updated_by = current_user.id child_chunk.updated_at = naive_utc_now() - child_chunk.type = "customized" + child_chunk.type = SegmentType.CUSTOMIZED update_child_chunks.append(child_chunk) else: new_child_chunks_args.append(child_chunk_update_args) @@ -3845,7 +3846,7 @@ class SegmentService: child_chunk.word_count = len(content) child_chunk.updated_by = current_user.id child_chunk.updated_at = naive_utc_now() - child_chunk.type = "customized" + child_chunk.type = SegmentType.CUSTOMIZED db.session.add(child_chunk) VectorService.update_child_chunk_vector([], [child_chunk], [], dataset) db.session.commit() diff --git a/api/tests/unit_tests/services/segment_service.py b/api/tests/unit_tests/services/segment_service.py index affbc8d0b5..cc2c0a8032 100644 --- a/api/tests/unit_tests/services/segment_service.py +++ b/api/tests/unit_tests/services/segment_service.py @@ -4,6 +4,7 @@ import pytest from models.account import Account from models.dataset import ChildChunk, Dataset, Document, DocumentSegment +from models.enums import SegmentType from services.dataset_service import SegmentService from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError @@ -77,7 +78,7 @@ class SegmentTestDataFactory: chunk.word_count = word_count chunk.index_node_id = f"node-{chunk_id}" chunk.index_node_hash = "hash-123" - chunk.type = "automatic" + chunk.type = SegmentType.AUTOMATIC chunk.created_by = "user-123" chunk.updated_by = None chunk.updated_at = None From cc17c8e883accea6a07a8be059b61da0fa4e2455 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Mon, 23 Mar 2026 21:38:29 +0100 Subject: [PATCH 22/34] refactor: use EnumText for TidbAuthBinding.status and MessageFile.type (#33975) --- .../vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py | 3 ++- .../rag/datasource/vdb/tidb_on_qdrant/tidb_service.py | 3 ++- api/models/dataset.py | 5 ++++- api/models/model.py | 4 ++-- api/schedule/create_tidb_serverless_task.py | 3 ++- api/schedule/update_tidb_serverless_status_task.py | 6 +++++- .../services/test_messages_clean_service.py | 3 ++- .../app/task_pipeline/test_easy_ui_message_end_files.py | 8 ++++---- 8 files changed, 23 insertions(+), 12 deletions(-) diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index 71b6fa0a9b..3c1d5e015f 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -33,6 +33,7 @@ from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset, TidbAuthBinding +from models.enums import TidbAuthBindingStatus if TYPE_CHECKING: from qdrant_client import grpc # noqa @@ -452,7 +453,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): password=new_cluster["password"], tenant_id=dataset.tenant_id, active=True, - status="ACTIVE", + status=TidbAuthBindingStatus.ACTIVE, ) db.session.add(new_tidb_auth_binding) db.session.commit() diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index 754c149241..06b17b9e62 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -9,6 +9,7 @@ from configs import dify_config from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus class TidbService: @@ -170,7 +171,7 @@ class TidbService: userPrefix = item["userPrefix"] if state == "ACTIVE" and len(userPrefix) > 0: cluster_info = tidb_serverless_list_map[item["clusterId"]] - cluster_info.status = "ACTIVE" + cluster_info.status = TidbAuthBindingStatus.ACTIVE cluster_info.account = f"{userPrefix}.root" db.session.add(cluster_info) db.session.commit() diff --git a/api/models/dataset.py b/api/models/dataset.py index e3cbbf9cb9..4c6152ed3f 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -45,6 +45,7 @@ from .enums import ( SegmentStatus, SegmentType, SummaryStatus, + TidbAuthBindingStatus, ) from .model import App, Tag, TagBinding, UploadFile from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index @@ -1242,7 +1243,9 @@ class TidbAuthBinding(TypeBase): cluster_id: Mapped[str] = mapped_column(String(255), nullable=False) cluster_name: Mapped[str] = mapped_column(String(255), nullable=False) active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) - status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'")) + status: Mapped[TidbAuthBindingStatus] = mapped_column( + EnumText(TidbAuthBindingStatus, length=255), nullable=False, server_default=sa.text("'CREATING'") + ) account: Mapped[str] = mapped_column(String(255), nullable=False) password: Mapped[str] = mapped_column(String(255), nullable=False) created_at: Mapped[datetime] = mapped_column( diff --git a/api/models/model.py b/api/models/model.py index 4541a3b23a..05233f8711 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -21,7 +21,7 @@ from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.tools.signature import sign_tool_file from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from dify_graph.file import helpers as file_helpers from extensions.storage.storage_type import StorageType from libs.helper import generate_string # type: ignore[import-not-found] @@ -1785,7 +1785,7 @@ class MessageFile(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(String(255), nullable=False) + type: Mapped[FileType] = mapped_column(EnumText(FileType, length=255), nullable=False) transfer_method: Mapped[FileTransferMethod] = mapped_column( EnumText(FileTransferMethod, length=255), nullable=False ) diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index 8b9d973d6d..6ceb3ef856 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -8,6 +8,7 @@ from configs import dify_config from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from extensions.ext_database import db from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus @app.celery.task(queue="dataset") @@ -57,7 +58,7 @@ def create_clusters(batch_size): account=new_cluster["account"], password=new_cluster["password"], active=False, - status="CREATING", + status=TidbAuthBindingStatus.CREATING, ) db.session.add(tidb_auth_binding) db.session.commit() diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index 1befa0e8b5..10003b1b97 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -9,6 +9,7 @@ from configs import dify_config from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from extensions.ext_database import db from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus @app.celery.task(queue="dataset") @@ -18,7 +19,10 @@ def update_tidb_serverless_status_task(): try: # check the number of idle tidb serverless tidb_serverless_list = db.session.scalars( - select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") + select(TidbAuthBinding).where( + TidbAuthBinding.active == False, + TidbAuthBinding.status == TidbAuthBindingStatus.CREATING, + ) ).all() if len(tidb_serverless_list) == 0: return diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index 8707f2e827..57bbc73b50 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -8,6 +8,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from dify_graph.file.enums import FileType from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -253,7 +254,7 @@ class TestMessagesCleanServiceIntegration: # MessageFile file = MessageFile( message_id=message.id, - type="image", + type=FileType.IMAGE, transfer_method="local_file", url="http://example.com/test.jpg", belongs_to=MessageFileBelongsTo.USER, diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py index 582990c88a..37dd116470 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py @@ -21,7 +21,7 @@ from sqlalchemy.orm import Session from core.app.entities.task_entities import MessageEndStreamResponse from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline -from dify_graph.file.enums import FileTransferMethod +from dify_graph.file.enums import FileTransferMethod, FileType from models.model import MessageFile, UploadFile @@ -51,7 +51,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.LOCAL_FILE message_file.upload_file_id = str(uuid.uuid4()) message_file.url = None - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture @@ -63,7 +63,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.REMOTE_URL message_file.upload_file_id = None message_file.url = "https://example.com/image.jpg" - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture @@ -75,7 +75,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.TOOL_FILE message_file.upload_file_id = None message_file.url = "tool_file_123.png" - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture From 49a1fae55561d8924df19f7085bef4a5a944264d Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 23 Mar 2026 16:04:34 -0500 Subject: [PATCH 23/34] test: migrate password reset tests to testcontainers (#33974) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../console/auth/test_password_reset.py | 109 +++--------------- 1 file changed, 17 insertions(+), 92 deletions(-) rename api/tests/{unit_tests => test_containers_integration_tests}/controllers/console/auth/test_password_reset.py (81%) diff --git a/api/tests/unit_tests/controllers/console/auth/test_password_reset.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py similarity index 81% rename from api/tests/unit_tests/controllers/console/auth/test_password_reset.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py index 9488cf528e..8f9db287e3 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py @@ -1,17 +1,10 @@ -""" -Test suite for password reset authentication flows. +"""Testcontainers integration tests for password reset authentication flows.""" -This module tests the password reset mechanism including: -- Password reset email sending -- Verification code validation -- Password reset with token -- Rate limiting and security checks -""" +from __future__ import annotations from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.error import ( EmailCodeError, @@ -28,31 +21,12 @@ from controllers.console.auth.forgot_password import ( from controllers.console.error import AccountNotFound, EmailSendIpLimitError -@pytest.fixture(autouse=True) -def _mock_forgot_password_session(): - with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls: - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session - mock_session_cls.return_value.__exit__.return_value = None - yield mock_session - - -@pytest.fixture(autouse=True) -def _mock_forgot_password_db(): - with patch("controllers.console.auth.forgot_password.db") as mock_db: - mock_db.engine = MagicMock() - yield mock_db - - class TestForgotPasswordSendEmailApi: """Test cases for sending password reset emails.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def mock_account(self): @@ -62,7 +36,6 @@ class TestForgotPasswordSendEmailApi: account.name = "Test User" return account - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @@ -73,20 +46,10 @@ class TestForgotPasswordSendEmailApi: mock_send_email, mock_get_account, mock_is_ip_limit, - mock_wraps_db, app, mock_account, ): - """ - Test successful password reset email sending. - - Verifies that: - - Email is sent to valid account - - Reset token is generated and returned - - IP rate limiting is checked - """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_account.return_value = mock_account mock_send_email.return_value = "reset_token_123" @@ -104,9 +67,8 @@ class TestForgotPasswordSendEmailApi: assert response["data"] == "reset_token_123" mock_send_email.assert_called_once() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") - def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, mock_db, app): + def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app): """ Test password reset email blocked by IP rate limit. @@ -115,7 +77,6 @@ class TestForgotPasswordSendEmailApi: - No email is sent when rate limited """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = True # Act & Assert @@ -133,7 +94,6 @@ class TestForgotPasswordSendEmailApi: (None, "en-US"), # Defaults to en-US when not provided ], ) - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @@ -144,7 +104,6 @@ class TestForgotPasswordSendEmailApi: mock_send_email, mock_get_account, mock_is_ip_limit, - mock_wraps_db, app, mock_account, language_input, @@ -158,7 +117,6 @@ class TestForgotPasswordSendEmailApi: - Unsupported languages default to en-US """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_account.return_value = mock_account mock_send_email.return_value = "token" @@ -180,13 +138,9 @@ class TestForgotPasswordCheckApi: """Test cases for verifying password reset codes.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @@ -199,7 +153,6 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_get_data, mock_is_rate_limit, - mock_db, app, ): """ @@ -212,7 +165,6 @@ class TestForgotPasswordCheckApi: - Rate limit is reset on success """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_generate_token.return_value = (None, "new_token") @@ -236,7 +188,6 @@ class TestForgotPasswordCheckApi: ) mock_reset_rate_limit.assert_called_once_with("test@example.com") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @@ -249,10 +200,8 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_get_data, mock_is_rate_limit, - mock_db, app, ): - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"} mock_generate_token.return_value = (None, "fresh-token") @@ -271,9 +220,8 @@ class TestForgotPasswordCheckApi: mock_revoke_token.assert_called_once_with("upper_token") mock_reset_rate_limit.assert_called_once_with("user@example.com") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") - def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app): + def test_verify_code_rate_limited(self, mock_is_rate_limit, app): """ Test code verification blocked by rate limit. @@ -282,7 +230,6 @@ class TestForgotPasswordCheckApi: - Prevents brute force attacks on verification codes """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = True # Act & Assert @@ -295,10 +242,9 @@ class TestForgotPasswordCheckApi: with pytest.raises(EmailPasswordResetLimitError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app): """ Test code verification with invalid token. @@ -306,7 +252,6 @@ class TestForgotPasswordCheckApi: - InvalidTokenError is raised for invalid/expired tokens """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = None @@ -320,10 +265,9 @@ class TestForgotPasswordCheckApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app): """ Test code verification with mismatched email. @@ -332,7 +276,6 @@ class TestForgotPasswordCheckApi: - Prevents token abuse """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "original@example.com", "code": "123456"} @@ -346,11 +289,10 @@ class TestForgotPasswordCheckApi: with pytest.raises(InvalidEmailError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit") - def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app): """ Test code verification with incorrect code. @@ -359,7 +301,6 @@ class TestForgotPasswordCheckApi: - Rate limit counter is incremented """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} @@ -380,11 +321,8 @@ class TestForgotPasswordResetApi: """Test cases for resetting password with verified token.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def mock_account(self): @@ -394,7 +332,6 @@ class TestForgotPasswordResetApi: account.name = "Test User" return account - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @@ -405,7 +342,6 @@ class TestForgotPasswordResetApi: mock_get_account, mock_revoke_token, mock_get_data, - mock_wraps_db, app, mock_account, ): @@ -418,7 +354,6 @@ class TestForgotPasswordResetApi: - Success response is returned """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} mock_get_account.return_value = mock_account mock_get_tenants.return_value = [MagicMock()] @@ -436,9 +371,8 @@ class TestForgotPasswordResetApi: assert response["result"] == "success" mock_revoke_token.assert_called_once_with("valid_token") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_mismatch(self, mock_get_data, mock_db, app): + def test_reset_password_mismatch(self, mock_get_data, app): """ Test password reset with mismatched passwords. @@ -447,7 +381,6 @@ class TestForgotPasswordResetApi: - No password update occurs """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} # Act & Assert @@ -460,9 +393,8 @@ class TestForgotPasswordResetApi: with pytest.raises(PasswordMismatchError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_invalid_token(self, mock_get_data, mock_db, app): + def test_reset_password_invalid_token(self, mock_get_data, app): """ Test password reset with invalid token. @@ -470,7 +402,6 @@ class TestForgotPasswordResetApi: - InvalidTokenError is raised for invalid/expired tokens """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = None # Act & Assert @@ -483,9 +414,8 @@ class TestForgotPasswordResetApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_wrong_phase(self, mock_get_data, mock_db, app): + def test_reset_password_wrong_phase(self, mock_get_data, app): """ Test password reset with token not in reset phase. @@ -494,7 +424,6 @@ class TestForgotPasswordResetApi: - Prevents use of verification-phase tokens for reset """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "verify"} # Act & Assert @@ -507,13 +436,10 @@ class TestForgotPasswordResetApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") - def test_reset_password_account_not_found( - self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app - ): + def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app): """ Test password reset for non-existent account. @@ -521,7 +447,6 @@ class TestForgotPasswordResetApi: - AccountNotFound is raised when account doesn't exist """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"} mock_get_account.return_value = None From 075b8bf1aeac34195c9e72049b04cb80613a59bc Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Tue, 24 Mar 2026 10:04:08 +0800 Subject: [PATCH 24/34] fix(web): update account settings header (#33965) --- .../account-setting/__tests__/index.spec.tsx | 4 +-- .../header/account-setting/index.tsx | 31 +++++++------------ 2 files changed, 14 insertions(+), 21 deletions(-) diff --git a/web/app/components/header/account-setting/__tests__/index.spec.tsx b/web/app/components/header/account-setting/__tests__/index.spec.tsx index 2aa9db4771..279af0b114 100644 --- a/web/app/components/header/account-setting/__tests__/index.spec.tsx +++ b/web/app/components/header/account-setting/__tests__/index.spec.tsx @@ -315,14 +315,14 @@ describe('AccountSetting', () => { it('should handle scroll event in panel', () => { // Act renderAccountSetting() - const scrollContainer = screen.getByRole('dialog').querySelector('.overflow-y-auto') + const scrollContainer = screen.getByRole('dialog').querySelector('.overscroll-contain') // Assert expect(scrollContainer).toBeInTheDocument() if (scrollContainer) { // Scroll down fireEvent.scroll(scrollContainer, { target: { scrollTop: 100 } }) - expect(scrollContainer).toHaveClass('overflow-y-auto') + expect(scrollContainer).toHaveClass('overscroll-contain') // Scroll back up fireEvent.scroll(scrollContainer, { target: { scrollTop: 0 } }) diff --git a/web/app/components/header/account-setting/index.tsx b/web/app/components/header/account-setting/index.tsx index 7e77af2e5f..bfceaeb059 100644 --- a/web/app/components/header/account-setting/index.tsx +++ b/web/app/components/header/account-setting/index.tsx @@ -1,8 +1,9 @@ 'use client' import type { AccountSettingTab } from '@/app/components/header/account-setting/constants' -import { useCallback, useEffect, useRef, useState } from 'react' +import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import SearchInput from '@/app/components/base/search-input' +import { ScrollArea } from '@/app/components/base/ui/scroll-area' import BillingPage from '@/app/components/billing/billing-page' import CustomPage from '@/app/components/custom/custom-page' import { @@ -129,20 +130,6 @@ export default function AccountSetting({ ], }, ] - const scrollRef = useRef(null) - const [scrolled, setScrolled] = useState(false) - useEffect(() => { - const targetElement = scrollRef.current - const scrollHandle = (e: Event) => { - const userScrolled = (e.target as HTMLDivElement).scrollTop > 0 - setScrolled(userScrolled) - } - targetElement?.addEventListener('scroll', scrollHandle) - return () => { - targetElement?.removeEventListener('scroll', scrollHandle) - } - }, []) - const activeItem = [...menuItems[0].items, ...menuItems[1].items].find(item => item.key === activeMenu) const [searchValue, setSearchValue] = useState('') @@ -201,7 +188,7 @@ export default function AccountSetting({ } -
+
-
-
+ +
{activeItem?.name} {activeItem?.description && ( @@ -241,7 +234,7 @@ export default function AccountSetting({ {activeMenu === ACCOUNT_SETTING_TAB.CUSTOM && } {activeMenu === ACCOUNT_SETTING_TAB.LANGUAGE && }
-
+
From fbd558762dc7dc0304c4e9efaa169505732c6098 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Tue, 24 Mar 2026 10:36:48 +0800 Subject: [PATCH 25/34] fix: fix chunk not display in indexed document (#33942) --- .../__tests__/document-settings.spec.tsx | 59 +++++++++++++++++++ .../detail/settings/document-settings.tsx | 26 +++++++- web/models/datasets.ts | 6 +- web/service/knowledge/use-create-dataset.ts | 6 +- 4 files changed, 92 insertions(+), 5 deletions(-) diff --git a/web/app/components/datasets/documents/detail/settings/__tests__/document-settings.spec.tsx b/web/app/components/datasets/documents/detail/settings/__tests__/document-settings.spec.tsx index 4ac30289e1..bf516d432b 100644 --- a/web/app/components/datasets/documents/detail/settings/__tests__/document-settings.spec.tsx +++ b/web/app/components/datasets/documents/detail/settings/__tests__/document-settings.spec.tsx @@ -224,6 +224,20 @@ describe('DocumentSettings', () => { // Data source types describe('Data Source Types', () => { + it('should handle upload_file_id data source format', () => { + mockDocumentDetail = { + name: 'test-document', + data_source_type: 'upload_file', + data_source_info: { + upload_file_id: '4a807f05-45d6-4fc4-b7a8-b009a4568b36', + }, + } + + render() + + expect(screen.getByTestId('files-count')).toHaveTextContent('1') + }) + it('should handle legacy upload_file data source', () => { mockDocumentDetail = { name: 'test-document', @@ -307,6 +321,18 @@ describe('DocumentSettings', () => { expect(screen.getByTestId('files-count')).toHaveTextContent('0') }) + it('should handle empty data_source_info object', () => { + mockDocumentDetail = { + name: 'test-document', + data_source_type: 'upload_file', + data_source_info: {}, + } + + render() + + expect(screen.getByTestId('files-count')).toHaveTextContent('0') + }) + it('should maintain structure when rerendered', () => { const { rerender } = render( , @@ -317,4 +343,37 @@ describe('DocumentSettings', () => { expect(screen.getByTestId('step-two')).toBeInTheDocument() }) }) + + describe('Files Extraction Regression Tests', () => { + it('should correctly extract file ID from upload_file_id format', () => { + const fileId = '4a807f05-45d6-4fc4-b7a8-b009a4568b36' + mockDocumentDetail = { + name: 'test-document.pdf', + data_source_type: 'upload_file', + data_source_info: { + upload_file_id: fileId, + }, + } + + render() + + // Verify files array is populated with correct file ID + expect(screen.getByTestId('files-count')).toHaveTextContent('1') + }) + + it('should preserve document name when using upload_file_id format', () => { + const documentName = 'my-uploaded-document.txt' + mockDocumentDetail = { + name: documentName, + data_source_type: 'upload_file', + data_source_info: { + upload_file_id: 'some-file-id', + }, + } + + render() + + expect(screen.getByTestId('files-count')).toHaveTextContent('1') + }) + }) }) diff --git a/web/app/components/datasets/documents/detail/settings/document-settings.tsx b/web/app/components/datasets/documents/detail/settings/document-settings.tsx index bcbc149231..2b6cc77683 100644 --- a/web/app/components/datasets/documents/detail/settings/document-settings.tsx +++ b/web/app/components/datasets/documents/detail/settings/document-settings.tsx @@ -8,6 +8,7 @@ import type { LegacyDataSourceInfo, LocalFileInfo, OnlineDocumentInfo, + UploadFileIdInfo, WebsiteCrawlInfo, } from '@/models/datasets' import { useBoolean } from 'ahooks' @@ -61,6 +62,7 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => { const dataSourceInfo = documentDetail?.data_source_info + // Type guards for DataSourceInfo union const isLegacyDataSourceInfo = (info: DataSourceInfo | undefined): info is LegacyDataSourceInfo => { return !!info && 'upload_file' in info } @@ -73,10 +75,15 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => { const isLocalFileInfo = (info: DataSourceInfo | undefined): info is LocalFileInfo => { return !!info && 'related_id' in info && 'transfer_method' in info } + const isUploadFileIdInfo = (info: DataSourceInfo | undefined): info is UploadFileIdInfo => { + return !!info && 'upload_file_id' in info + } + const legacyInfo = isLegacyDataSourceInfo(dataSourceInfo) ? dataSourceInfo : undefined const websiteInfo = isWebsiteCrawlInfo(dataSourceInfo) ? dataSourceInfo : undefined const onlineDocumentInfo = isOnlineDocumentInfo(dataSourceInfo) ? dataSourceInfo : undefined const localFileInfo = isLocalFileInfo(dataSourceInfo) ? dataSourceInfo : undefined + const uploadFileIdInfo = isUploadFileIdInfo(dataSourceInfo) ? dataSourceInfo : undefined const currentPage = useMemo(() => { if (legacyInfo) { @@ -101,8 +108,20 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => { }, [documentDetail?.data_source_type, documentDetail?.name, legacyInfo, onlineDocumentInfo]) const files = useMemo(() => { - if (legacyInfo?.upload_file) - return [legacyInfo.upload_file as CustomFile] + // Handle upload_file_id format + if (uploadFileIdInfo) { + return [{ + id: uploadFileIdInfo.upload_file_id, + name: documentDetail?.name || '', + } as unknown as CustomFile] + } + + // Handle legacy upload_file format + if (legacyInfo?.upload_file) { + return [legacyInfo.upload_file as unknown as CustomFile] + } + + // Handle local file info format if (localFileInfo) { const { related_id, name, extension } = localFileInfo return [{ @@ -111,8 +130,9 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => { extension, } as unknown as CustomFile] } + return [] - }, [legacyInfo?.upload_file, localFileInfo]) + }, [uploadFileIdInfo, legacyInfo?.upload_file, localFileInfo, documentDetail?.name]) const websitePages = useMemo(() => { if (!websiteInfo) diff --git a/web/models/datasets.ts b/web/models/datasets.ts index ed16e1a67c..e4793357f4 100644 --- a/web/models/datasets.ts +++ b/web/models/datasets.ts @@ -381,7 +381,11 @@ export type OnlineDriveInfo = { type: 'file' | 'folder' } -export type DataSourceInfo = LegacyDataSourceInfo | LocalFileInfo | OnlineDocumentInfo | WebsiteCrawlInfo +export type UploadFileIdInfo = { + upload_file_id: string +} + +export type DataSourceInfo = LegacyDataSourceInfo | LocalFileInfo | OnlineDocumentInfo | WebsiteCrawlInfo | UploadFileIdInfo export type InitialDocumentDetail = { id: string diff --git a/web/service/knowledge/use-create-dataset.ts b/web/service/knowledge/use-create-dataset.ts index a0d55eeb99..297bb44827 100644 --- a/web/service/knowledge/use-create-dataset.ts +++ b/web/service/knowledge/use-create-dataset.ts @@ -91,11 +91,15 @@ const getFileIndexingEstimateParamsForFile = ({ processRule, dataset_id, }: GetFileIndexingEstimateParamsOptionFile): IndexingEstimateParams => { + const fileIds = files + .map(file => file.id) + .filter((id): id is string => Boolean(id)) + return { info_list: { data_source_type: dataSourceType, file_info_list: { - file_ids: files.map(file => file.id) as string[], + file_ids: fileIds, }, }, indexing_technique: indexingTechnique, From 27c4faad4f717624564d0f537ca865eaf8559099 Mon Sep 17 00:00:00 2001 From: Stephen Zhou Date: Tue, 24 Mar 2026 10:52:27 +0800 Subject: [PATCH 26/34] ci: update actions version, fix cache (#33950) --- .github/actions/setup-web/action.yml | 9 +++---- .github/workflows/style.yml | 29 +++++++++++++-------- .github/workflows/translate-i18n-claude.yml | 2 +- web/package.json | 6 ++--- 4 files changed, 26 insertions(+), 20 deletions(-) diff --git a/.github/actions/setup-web/action.yml b/.github/actions/setup-web/action.yml index 6f3b3c08b4..24af948732 100644 --- a/.github/actions/setup-web/action.yml +++ b/.github/actions/setup-web/action.yml @@ -4,10 +4,9 @@ runs: using: composite steps: - name: Setup Vite+ - uses: voidzero-dev/setup-vp@4a524139920f87f9f7080d3b8545acac019e1852 # v1.0.0 + uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0 with: - node-version-file: web/.nvmrc + working-directory: web + node-version-file: .nvmrc cache: true - cache-dependency-path: web/pnpm-lock.yaml - run-install: | - cwd: ./web + run-install: true diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 657a481f74..23ae36f7b1 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -84,20 +84,20 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' uses: ./.github/actions/setup-web + - name: Restore ESLint cache + if: steps.changed-files.outputs.any_changed == 'true' + id: eslint-cache-restore + uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + with: + path: web/.eslintcache + key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }} + restore-keys: | + ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}- + - name: Web style check if: steps.changed-files.outputs.any_changed == 'true' working-directory: ./web - run: | - vp run lint:ci - # pnpm run lint:report - # continue-on-error: true - - # - name: Annotate Code - # if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request' - # uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae - # with: - # eslint-report: web/eslint_report.json - # github-token: ${{ secrets.GITHUB_TOKEN }} + run: vp run lint:ci - name: Web tsslint if: steps.changed-files.outputs.any_changed == 'true' @@ -114,6 +114,13 @@ jobs: working-directory: ./web run: vp run knip + - name: Save ESLint cache + if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true' + uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + with: + path: web/.eslintcache + key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }} + superlinter: name: SuperLinter runs-on: ubuntu-latest diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml index 84f8000a01..1869254295 100644 --- a/.github/workflows/translate-i18n-claude.yml +++ b/.github/workflows/translate-i18n-claude.yml @@ -120,7 +120,7 @@ jobs: - name: Run Claude Code for Translation Sync if: steps.detect_changes.outputs.CHANGED_FILES != '' - uses: anthropics/claude-code-action@6062f3709600659be5e47fcddf2cf76993c235c2 # v1.0.76 + uses: anthropics/claude-code-action@ff9acae5886d41a99ed4ec14b7dc147d55834722 # v1.0.77 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} github_token: ${{ secrets.GITHUB_TOKEN }} diff --git a/web/package.json b/web/package.json index 7d82b6afde..ee8dbb466e 100644 --- a/web/package.json +++ b/web/package.json @@ -39,9 +39,9 @@ "i18n:check": "tsx ./scripts/check-i18n.js", "knip": "knip", "lint": "eslint --cache --concurrency=auto", - "lint:ci": "eslint --cache --concurrency 2", - "lint:fix": "pnpm lint --fix", - "lint:quiet": "pnpm lint --quiet", + "lint:ci": "eslint --cache --cache-strategy content --concurrency 2", + "lint:fix": "vp run lint --fix", + "lint:quiet": "vp run lint --quiet", "lint:tss": "tsslint --project tsconfig.json", "preinstall": "npx only-allow pnpm", "prepare": "cd ../ && node -e \"if (process.env.NODE_ENV !== 'production'){process.exit(1)} \" || husky ./web/.husky", From 0589fa423bdee5a4e7f907fa2459bb956c6af48f Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Tue, 24 Mar 2026 11:24:31 +0800 Subject: [PATCH 27/34] fix(sdk): patch flatted vulnerability in nodejs client lockfile (#33996) --- sdks/nodejs-client/package.json | 1 + sdks/nodejs-client/pnpm-lock.yaml | 22 ++++++++++++++++++---- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/sdks/nodejs-client/package.json b/sdks/nodejs-client/package.json index 7c8a293446..728aa0d054 100644 --- a/sdks/nodejs-client/package.json +++ b/sdks/nodejs-client/package.json @@ -69,6 +69,7 @@ }, "pnpm": { "overrides": { + "flatted@<=3.4.1": "3.4.2", "rollup@>=4.0.0,<4.59.0": "4.59.0" } } diff --git a/sdks/nodejs-client/pnpm-lock.yaml b/sdks/nodejs-client/pnpm-lock.yaml index c4b299cd73..c9081420f5 100644 --- a/sdks/nodejs-client/pnpm-lock.yaml +++ b/sdks/nodejs-client/pnpm-lock.yaml @@ -5,6 +5,7 @@ settings: excludeLinksFromLockfile: false overrides: + flatted@<=3.4.1: 3.4.2 rollup@>=4.0.0,<4.59.0: 4.59.0 importers: @@ -324,66 +325,79 @@ packages: resolution: {integrity: sha512-t4ONHboXi/3E0rT6OZl1pKbl2Vgxf9vJfWgmUoCEVQVxhW6Cw/c8I6hbbu7DAvgp82RKiH7TpLwxnJeKv2pbsw==} cpu: [arm] os: [linux] + libc: [glibc] '@rollup/rollup-linux-arm-musleabihf@4.59.0': resolution: {integrity: sha512-CikFT7aYPA2ufMD086cVORBYGHffBo4K8MQ4uPS/ZnY54GKj36i196u8U+aDVT2LX4eSMbyHtyOh7D7Zvk2VvA==} cpu: [arm] os: [linux] + libc: [musl] '@rollup/rollup-linux-arm64-gnu@4.59.0': resolution: {integrity: sha512-jYgUGk5aLd1nUb1CtQ8E+t5JhLc9x5WdBKew9ZgAXg7DBk0ZHErLHdXM24rfX+bKrFe+Xp5YuJo54I5HFjGDAA==} cpu: [arm64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-arm64-musl@4.59.0': resolution: {integrity: sha512-peZRVEdnFWZ5Bh2KeumKG9ty7aCXzzEsHShOZEFiCQlDEepP1dpUl/SrUNXNg13UmZl+gzVDPsiCwnV1uI0RUA==} cpu: [arm64] os: [linux] + libc: [musl] '@rollup/rollup-linux-loong64-gnu@4.59.0': resolution: {integrity: sha512-gbUSW/97f7+r4gHy3Jlup8zDG190AuodsWnNiXErp9mT90iCy9NKKU0Xwx5k8VlRAIV2uU9CsMnEFg/xXaOfXg==} cpu: [loong64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-loong64-musl@4.59.0': resolution: {integrity: sha512-yTRONe79E+o0FWFijasoTjtzG9EBedFXJMl888NBEDCDV9I2wGbFFfJQQe63OijbFCUZqxpHz1GzpbtSFikJ4Q==} cpu: [loong64] os: [linux] + libc: [musl] '@rollup/rollup-linux-ppc64-gnu@4.59.0': resolution: {integrity: sha512-sw1o3tfyk12k3OEpRddF68a1unZ5VCN7zoTNtSn2KndUE+ea3m3ROOKRCZxEpmT9nsGnogpFP9x6mnLTCaoLkA==} cpu: [ppc64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-ppc64-musl@4.59.0': resolution: {integrity: sha512-+2kLtQ4xT3AiIxkzFVFXfsmlZiG5FXYW7ZyIIvGA7Bdeuh9Z0aN4hVyXS/G1E9bTP/vqszNIN/pUKCk/BTHsKA==} cpu: [ppc64] os: [linux] + libc: [musl] '@rollup/rollup-linux-riscv64-gnu@4.59.0': resolution: {integrity: sha512-NDYMpsXYJJaj+I7UdwIuHHNxXZ/b/N2hR15NyH3m2qAtb/hHPA4g4SuuvrdxetTdndfj9b1WOmy73kcPRoERUg==} cpu: [riscv64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-riscv64-musl@4.59.0': resolution: {integrity: sha512-nLckB8WOqHIf1bhymk+oHxvM9D3tyPndZH8i8+35p/1YiVoVswPid2yLzgX7ZJP0KQvnkhM4H6QZ5m0LzbyIAg==} cpu: [riscv64] os: [linux] + libc: [musl] '@rollup/rollup-linux-s390x-gnu@4.59.0': resolution: {integrity: sha512-oF87Ie3uAIvORFBpwnCvUzdeYUqi2wY6jRFWJAy1qus/udHFYIkplYRW+wo+GRUP4sKzYdmE1Y3+rY5Gc4ZO+w==} cpu: [s390x] os: [linux] + libc: [glibc] '@rollup/rollup-linux-x64-gnu@4.59.0': resolution: {integrity: sha512-3AHmtQq/ppNuUspKAlvA8HtLybkDflkMuLK4DPo77DfthRb71V84/c4MlWJXixZz4uruIH4uaa07IqoAkG64fg==} cpu: [x64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-x64-musl@4.59.0': resolution: {integrity: sha512-2UdiwS/9cTAx7qIUZB/fWtToJwvt0Vbo0zmnYt7ED35KPg13Q0ym1g442THLC7VyI6JfYTP4PiSOWyoMdV2/xg==} cpu: [x64] os: [linux] + libc: [musl] '@rollup/rollup-openbsd-x64@4.59.0': resolution: {integrity: sha512-M3bLRAVk6GOwFlPTIxVBSYKUaqfLrn8l0psKinkCFxl4lQvOSz8ZrKDz2gxcBwHFpci0B6rttydI4IpS4IS/jQ==} @@ -741,8 +755,8 @@ packages: resolution: {integrity: sha512-f7ccFPK3SXFHpx15UIGyRJ/FJQctuKZ0zVuN3frBo4HnK3cay9VEW0R6yPYFHC0AgqhukPzKjq22t5DmAyqGyw==} engines: {node: '>=16'} - flatted@3.4.1: - resolution: {integrity: sha512-IxfVbRFVlV8V/yRaGzk0UVIcsKKHMSfYw66T/u4nTwlWteQePsxe//LjudR1AMX4tZW3WFCh3Zqa/sjlqpbURQ==} + flatted@3.4.2: + resolution: {integrity: sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==} follow-redirects@1.15.11: resolution: {integrity: sha512-deG2P0JfjrTxl50XGCDyfI97ZGVCxIpfKYmfyrQ54n5FO/0gfIES8C/Psl6kWVDolizcaaxZJnTS0QSMxvnsBQ==} @@ -1836,10 +1850,10 @@ snapshots: flat-cache@4.0.1: dependencies: - flatted: 3.4.1 + flatted: 3.4.2 keyv: 4.5.4 - flatted@3.4.1: {} + flatted@3.4.2: {} follow-redirects@1.15.11: {} From ecd3a964c1f5599642c25c9c8d8e753e8763d07f Mon Sep 17 00:00:00 2001 From: BitToby <218712309+bittoby@users.noreply.github.com> Date: Tue, 24 Mar 2026 06:22:17 +0200 Subject: [PATCH 28/34] refactor(api): type auth service credentials with TypedDict (#33867) --- api/services/auth/api_key_auth_base.py | 10 +++++++++- api/services/auth/api_key_auth_factory.py | 4 ++-- api/services/auth/firecrawl/firecrawl.py | 4 ++-- api/services/auth/jina.py | 4 ++-- api/services/auth/jina/jina.py | 4 ++-- api/services/auth/watercrawl/watercrawl.py | 4 ++-- .../unit_tests/services/auth/test_api_key_auth_base.py | 6 +++--- .../services/auth/test_api_key_auth_factory.py | 4 ++-- 8 files changed, 24 insertions(+), 16 deletions(-) diff --git a/api/services/auth/api_key_auth_base.py b/api/services/auth/api_key_auth_base.py index dd74a8f1b5..2e1b723e82 100644 --- a/api/services/auth/api_key_auth_base.py +++ b/api/services/auth/api_key_auth_base.py @@ -1,8 +1,16 @@ from abc import ABC, abstractmethod +from typing import Any + +from typing_extensions import TypedDict + + +class AuthCredentials(TypedDict): + auth_type: str + config: dict[str, Any] class ApiKeyAuthBase(ABC): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): self.credentials = credentials @abstractmethod diff --git a/api/services/auth/api_key_auth_factory.py b/api/services/auth/api_key_auth_factory.py index 7ae31b0768..6e183b70e3 100644 --- a/api/services/auth/api_key_auth_factory.py +++ b/api/services/auth/api_key_auth_factory.py @@ -1,9 +1,9 @@ -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials from services.auth.auth_type import AuthType class ApiKeyAuthFactory: - def __init__(self, provider: str, credentials: dict): + def __init__(self, provider: str, credentials: AuthCredentials): auth_factory = self.get_apikey_auth_factory(provider) self.auth = auth_factory(credentials) diff --git a/api/services/auth/firecrawl/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py index b002706931..c9e5610aea 100644 --- a/api/services/auth/firecrawl/firecrawl.py +++ b/api/services/auth/firecrawl/firecrawl.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class FirecrawlAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py index afaed28ac9..e5e2319ce1 100644 --- a/api/services/auth/jina.py +++ b/api/services/auth/jina.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class JinaAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/jina/jina.py b/api/services/auth/jina/jina.py index afaed28ac9..e5e2319ce1 100644 --- a/api/services/auth/jina/jina.py +++ b/api/services/auth/jina/jina.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class JinaAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/watercrawl/watercrawl.py b/api/services/auth/watercrawl/watercrawl.py index b2d28a83d1..cbdc908690 100644 --- a/api/services/auth/watercrawl/watercrawl.py +++ b/api/services/auth/watercrawl/watercrawl.py @@ -3,11 +3,11 @@ from urllib.parse import urljoin import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class WatercrawlAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "x-api-key": diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_base.py b/api/tests/unit_tests/services/auth/test_api_key_auth_base.py index b5d91ef3fb..388504c07f 100644 --- a/api/tests/unit_tests/services/auth/test_api_key_auth_base.py +++ b/api/tests/unit_tests/services/auth/test_api_key_auth_base.py @@ -13,13 +13,13 @@ class ConcreteApiKeyAuth(ApiKeyAuthBase): class TestApiKeyAuthBase: def test_should_store_credentials_on_init(self): """Test that credentials are properly stored during initialization""" - credentials = {"api_key": "test_key", "auth_type": "bearer"} + credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}} auth = ConcreteApiKeyAuth(credentials) assert auth.credentials == credentials def test_should_not_instantiate_abstract_class(self): """Test that ApiKeyAuthBase cannot be instantiated directly""" - credentials = {"api_key": "test_key"} + credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}} with pytest.raises(TypeError) as exc_info: ApiKeyAuthBase(credentials) @@ -29,7 +29,7 @@ class TestApiKeyAuthBase: def test_should_allow_subclass_implementation(self): """Test that subclasses can properly implement the abstract method""" - credentials = {"api_key": "test_key", "auth_type": "bearer"} + credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}} auth = ConcreteApiKeyAuth(credentials) # Should not raise any exception diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py b/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py index 60af6e20c2..b1f7cf24f3 100644 --- a/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py +++ b/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py @@ -58,7 +58,7 @@ class TestApiKeyAuthFactory: mock_get_factory.return_value = mock_auth_class # Act - factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"}) + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "test_key"}}) result = factory.validate_credentials() # Assert @@ -75,7 +75,7 @@ class TestApiKeyAuthFactory: mock_get_factory.return_value = mock_auth_class # Act & Assert - factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"}) + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "test_key"}}) with pytest.raises(Exception) as exc_info: factory.validate_credentials() assert str(exc_info.value) == "Authentication error" From 8b634a9bee3a3238f45540082459031dd155f0f2 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Tue, 24 Mar 2026 05:27:50 +0100 Subject: [PATCH 29/34] =?UTF-8?q?refactor:=20use=20EnumText=20for=20ApiToo?= =?UTF-8?q?lProvider.schema=5Ftype=5Fstr=20and=20Docume=E2=80=A6=20(#33983?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/commands/vector.py | 3 +- api/models/dataset.py | 4 +- api/models/tools.py | 4 +- api/services/dataset_service.py | 20 ++++---- .../rag_pipeline_transform_service.py | 9 ++-- .../batch_create_segment_to_index_task.py | 5 +- api/tasks/document_indexing_task.py | 3 +- api/tasks/regenerate_summary_index_task.py | 5 +- .../test_dataset_retrieval_integration.py | 15 +++--- .../services/document_service_status.py | 3 +- .../services/test_dataset_service.py | 3 +- ...et_service_batch_update_document_status.py | 3 +- .../test_dataset_service_delete_dataset.py | 5 +- .../test_document_service_display_status.py | 3 +- .../test_document_service_rename_document.py | 3 +- .../services/test_metadata_service.py | 3 +- .../tools/test_tools_transform_service.py | 10 ++-- .../tasks/test_batch_clean_document_task.py | 16 ++++-- ...test_batch_create_segment_to_index_task.py | 19 +++---- .../tasks/test_clean_dataset_task.py | 3 +- .../tasks/test_clean_notion_document_task.py | 5 +- .../test_create_segment_to_index_task.py | 11 ++-- .../test_deal_dataset_vector_index_task.py | 51 ++++++++++--------- .../test_disable_segment_from_index_task.py | 9 +++- .../test_disable_segments_from_index_task.py | 9 +++- .../tasks/test_document_indexing_sync_task.py | 3 +- .../test_document_indexing_update_task.py | 3 +- .../test_duplicate_document_indexing_task.py | 7 +-- .../console/datasets/test_data_source.py | 3 +- .../console/datasets/test_datasets.py | 3 +- .../datasets/test_datasets_document.py | 19 +++---- .../datasets/test_datasets_segments.py | 5 +- .../controllers/service_api/conftest.py | 3 +- .../dataset/test_dataset_segment.py | 11 ++-- .../service_api/dataset/test_document.py | 33 ++++++++---- .../rag/retrieval/test_dataset_retrieval.py | 4 +- .../unit_tests/models/test_tool_models.py | 30 +++++------ .../services/document_service_validation.py | 15 +++--- .../unit_tests/services/segment_service.py | 7 +-- .../test_dataset_service_lock_not_owned.py | 7 +-- .../services/test_summary_index_service.py | 15 +++--- .../services/test_vector_service.py | 13 ++--- .../unit_tests/services/vector_service.py | 9 ++-- .../tasks/test_clean_dataset_task.py | 15 +++--- .../tasks/test_dataset_indexing_task.py | 3 +- .../tasks/test_document_indexing_sync_task.py | 3 +- 46 files changed, 255 insertions(+), 180 deletions(-) diff --git a/api/commands/vector.py b/api/commands/vector.py index 4cf11c9ad1..bef18bf73b 100644 --- a/api/commands/vector.py +++ b/api/commands/vector.py @@ -10,6 +10,7 @@ from configs import dify_config from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.index_processor.constant.built_in_field import BuiltInField +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment @@ -269,7 +270,7 @@ def migrate_knowledge_vector_database(): "dataset_id": segment.dataset_id, }, ) - if dataset_document.doc_form == "hierarchical_model": + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] diff --git a/api/models/dataset.py b/api/models/dataset.py index 4c6152ed3f..b4fb03a7f4 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -496,7 +496,9 @@ class Document(Base): ) doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True) doc_metadata = mapped_column(AdjustedJSON, nullable=True) - doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'")) + doc_form: Mapped[IndexStructureType] = mapped_column( + EnumText(IndexStructureType, length=255), nullable=False, server_default=sa.text("'text_model'") + ) doc_language = mapped_column(String(255), nullable=True) need_summary: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) diff --git a/api/models/tools.py b/api/models/tools.py index 01182af867..63b27b9413 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -145,7 +145,9 @@ class ApiToolProvider(TypeBase): icon: Mapped[str] = mapped_column(String(255), nullable=False) # original schema schema: Mapped[str] = mapped_column(LongText, nullable=False) - schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False) + schema_type_str: Mapped[ApiProviderSchemaType] = mapped_column( + EnumText(ApiProviderSchemaType, length=40), nullable=False + ) # who created this tool user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # tenant id diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index ba4ab6757f..65e112f1e9 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1440,7 +1440,7 @@ class DocumentService: .filter( Document.id.in_(document_id_list), Document.dataset_id == dataset_id, - Document.doc_form != "qa_model", # Skip qa_model documents + Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents ) .update({Document.need_summary: need_summary}, synchronize_session=False) ) @@ -2040,7 +2040,7 @@ class DocumentService: document.dataset_process_rule_id = dataset_process_rule.id document.updated_at = naive_utc_now() document.created_from = created_from - document.doc_form = knowledge_config.doc_form + document.doc_form = IndexStructureType(knowledge_config.doc_form) document.doc_language = knowledge_config.doc_language document.data_source_info = json.dumps(data_source_info) document.batch = batch @@ -2640,7 +2640,7 @@ class DocumentService: document.splitting_completed_at = None document.updated_at = naive_utc_now() document.created_from = created_from - document.doc_form = document_data.doc_form + document.doc_form = IndexStructureType(document_data.doc_form) db.session.add(document) db.session.commit() # update document segment @@ -3101,7 +3101,7 @@ class DocumentService: class SegmentService: @classmethod def segment_create_args_validate(cls, args: dict, document: Document): - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: if "answer" not in args or not args["answer"]: raise ValueError("Answer is required") if not args["answer"].strip(): @@ -3158,7 +3158,7 @@ class SegmentService: completed_at=naive_utc_now(), created_by=current_user.id, ) - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment_document.word_count += len(args["answer"]) segment_document.answer = args["answer"] @@ -3232,7 +3232,7 @@ class SegmentService: tokens = 0 if dataset.indexing_technique == "high_quality" and embedding_model: # calc embedding use tokens - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: tokens = embedding_model.get_text_embedding_num_tokens( texts=[content + segment_item["answer"]] )[0] @@ -3255,7 +3255,7 @@ class SegmentService: completed_at=naive_utc_now(), created_by=current_user.id, ) - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment_document.answer = segment_item["answer"] segment_document.word_count += len(segment_item["answer"]) increment_word_count += segment_document.word_count @@ -3322,7 +3322,7 @@ class SegmentService: content = args.content or segment.content if segment.content == content: segment.word_count = len(content) - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment.answer = args.answer segment.word_count += len(args.answer) if args.answer else 0 word_count_change = segment.word_count - word_count_change @@ -3419,7 +3419,7 @@ class SegmentService: ) # calc embedding use tokens - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment.answer = args.answer tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0] # type: ignore else: @@ -3436,7 +3436,7 @@ class SegmentService: segment.enabled = True segment.disabled_at = None segment.disabled_by = None - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment.answer = args.answer segment.word_count += len(args.answer) if args.answer else 0 word_count_change = segment.word_count - word_count_change diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 1d0aafd5fd..7dcfecdd1d 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -9,6 +9,7 @@ from flask_login import current_user from constants import DOCUMENT_EXTENSIONS from core.plugin.impl.plugin import PluginInstaller +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from factories import variable_factory @@ -79,9 +80,9 @@ class RagPipelineTransformService: pipeline = self._create_pipeline(pipeline_yaml) # save chunk structure to dataset - if doc_form == "hierarchical_model": + if doc_form == IndexStructureType.PARENT_CHILD_INDEX: dataset.chunk_structure = "hierarchical_model" - elif doc_form == "text_model": + elif doc_form == IndexStructureType.PARAGRAPH_INDEX: dataset.chunk_structure = "text_model" else: raise ValueError("Unsupported doc form") @@ -101,7 +102,7 @@ class RagPipelineTransformService: def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None): pipeline_yaml = {} - if doc_form == "text_model": + if doc_form == IndexStructureType.PARAGRAPH_INDEX: match datasource_type: case DataSourceType.UPLOAD_FILE: if indexing_technique == "high_quality": @@ -132,7 +133,7 @@ class RagPipelineTransformService: pipeline_yaml = yaml.safe_load(f) case _: raise ValueError("Unsupported datasource type") - elif doc_form == "hierarchical_model": + elif doc_form == IndexStructureType.PARENT_CHILD_INDEX: match datasource_type: case DataSourceType.UPLOAD_FILE: # get graph from transform.file-parentchild.yml diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 49dee00919..7f810129ef 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -11,6 +11,7 @@ from sqlalchemy import func from core.db.session_factory import session_factory from core.model_manager import ModelManager +from core.rag.index_processor.constant.index_type import IndexStructureType from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_redis import redis_client from extensions.ext_storage import storage @@ -109,7 +110,7 @@ def batch_create_segment_to_index_task( df = pd.read_csv(file_path) content = [] for _, row in df.iterrows(): - if document_config["doc_form"] == "qa_model": + if document_config["doc_form"] == IndexStructureType.QA_INDEX: data = {"content": row.iloc[0], "answer": row.iloc[1]} else: data = {"content": row.iloc[0]} @@ -159,7 +160,7 @@ def batch_create_segment_to_index_task( status="completed", completed_at=naive_utc_now(), ) - if document_config["doc_form"] == "qa_model": + if document_config["doc_form"] == IndexStructureType.QA_INDEX: segment_document.answer = segment["answer"] segment_document.word_count += len(segment["answer"]) word_count_change += segment_document.word_count diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index e05d63426c..b5794e33e2 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -10,6 +10,7 @@ from configs import dify_config from core.db.session_factory import session_factory from core.entities.document_task import DocumentTask from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now @@ -150,7 +151,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): ) if ( document.indexing_status == IndexingStatus.COMPLETED - and document.doc_form != "qa_model" + and document.doc_form != IndexStructureType.QA_INDEX and document.need_summary is True ): try: diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py index 39c2f4103e..ac5d23408a 100644 --- a/api/tasks/regenerate_summary_index_task.py +++ b/api/tasks/regenerate_summary_index_task.py @@ -9,6 +9,7 @@ from celery import shared_task from sqlalchemy import or_, select from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument from services.summary_index_service import SummaryIndexService @@ -106,7 +107,7 @@ def regenerate_summary_index_task( ), DatasetDocument.enabled == True, # Document must be enabled DatasetDocument.archived == False, # Document must not be archived - DatasetDocument.doc_form != "qa_model", # Skip qa_model documents + DatasetDocument.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents ) .order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc()) .all() @@ -209,7 +210,7 @@ def regenerate_summary_index_task( for dataset_document in dataset_documents: # Skip qa_model documents - if dataset_document.doc_form == "qa_model": + if dataset_document.doc_form == IndexStructureType.QA_INDEX: continue try: diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py index 781e297fa4..ea8d04502a 100644 --- a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest from models.dataset import Dataset, Document @@ -55,7 +56,7 @@ class TestGetAvailableDatasetsIntegration: name=f"Document {i}", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -112,7 +113,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Archived Document {i}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # Archived @@ -165,7 +166,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Disabled Document {i}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=False, # Disabled archived=False, @@ -218,7 +219,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Document {status}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=status, # Not completed enabled=True, archived=False, @@ -336,7 +337,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Document for {dataset.name}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, @@ -416,7 +417,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Document {i}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, @@ -476,7 +477,7 @@ class TestKnowledgeRetrievalIntegration: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/services/document_service_status.py b/api/tests/test_containers_integration_tests/services/document_service_status.py index f995ac7bef..42d587b7f7 100644 --- a/api/tests/test_containers_integration_tests/services/document_service_status.py +++ b/api/tests/test_containers_integration_tests/services/document_service_status.py @@ -13,6 +13,7 @@ from uuid import uuid4 import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from models import Account from models.dataset import Dataset, Document @@ -91,7 +92,7 @@ class DocumentStatusTestDataFactory: name=name, created_from=DocumentCreatedFrom.WEB, created_by=created_by, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) document.id = document_id document.indexing_status = indexing_status diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index ac3d9f9604..a484c7be87 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -11,6 +11,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -106,7 +107,7 @@ class DatasetServiceIntegrationDataFactory: created_from=DocumentCreatedFrom.WEB, created_by=created_by, indexing_status=IndexingStatus.COMPLETED, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.flush() diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py index ab7e2a3f50..c1d088755c 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py @@ -13,6 +13,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from services.dataset_service import DocumentService @@ -79,7 +80,7 @@ class DocumentBatchUpdateIntegrationDataFactory: name=name, created_from=DocumentCreatedFrom.WEB, created_by=created_by or str(uuid4()), - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) document.id = document_id or str(uuid4()) document.enabled = enabled diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py index ed070527c9..807d18322c 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py @@ -3,6 +3,7 @@ from unittest.mock import patch from uuid import uuid4 +from core.rag.index_processor.constant.index_type import IndexStructureType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom @@ -78,7 +79,7 @@ class DatasetDeleteIntegrationDataFactory: tenant_id: str, dataset_id: str, created_by: str, - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, ) -> Document: """Persist a document so dataset.doc_form resolves through the real document path.""" document = Document( @@ -119,7 +120,7 @@ class TestDatasetServiceDeleteDataset: tenant_id=tenant.id, dataset_id=dataset.id, created_by=owner.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Act diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py index 47d259d8a0..c0047df810 100644 --- a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py +++ b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py @@ -3,6 +3,7 @@ from uuid import uuid4 from sqlalchemy import select +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from services.dataset_service import DocumentService @@ -42,7 +43,7 @@ def _create_document( name=f"doc-{uuid4()}", created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) document.id = str(uuid4()) document.indexing_status = indexing_status diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py index bffa520ce6..34532ed7f8 100644 --- a/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py +++ b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py @@ -7,6 +7,7 @@ from uuid import uuid4 import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from models import Account from models.dataset import Dataset, Document @@ -69,7 +70,7 @@ def make_document( name=name, created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) doc.id = document_id doc.indexing_status = "completed" diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index e847329c5b..8b1349be9a 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -5,6 +5,7 @@ from faker import Faker from sqlalchemy.orm import Session from core.rag.index_processor.constant.built_in_field import BuiltInField +from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document from models.enums import DatasetMetadataType, DataSourceType, DocumentCreatedFrom @@ -139,7 +140,7 @@ class TestMetadataService: name=fake.file_name(), created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", ) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py index 0f38218c51..7ab059bb75 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderType +from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolProviderType from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from services.plugin.plugin_service import PluginService from services.tools.tools_transform_service import ToolTransformService @@ -52,7 +52,7 @@ class TestToolTransformService: user_id="test_user_id", credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) elif provider_type == "builtin": @@ -659,7 +659,7 @@ class TestToolTransformService: user_id=fake.uuid4(), credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) @@ -695,7 +695,7 @@ class TestToolTransformService: user_id=fake.uuid4(), credentials_str='{"auth_type": "api_key_query", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) @@ -731,7 +731,7 @@ class TestToolTransformService: user_id=fake.uuid4(), credentials_str='{"auth_type": "api_key", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index 210d9eb39e..6cbbe43137 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -13,6 +13,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from libs.datetime_utils import naive_utc_now from models import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -152,7 +153,7 @@ class TestBatchCleanDocumentTask: created_from=DocumentCreatedFrom.WEB, created_by=account.id, indexing_status=IndexingStatus.COMPLETED, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) @@ -392,7 +393,12 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Execute the task with non-existent dataset - batch_clean_document_task(document_ids=[document_id], dataset_id=dataset_id, doc_form="text_model", file_ids=[]) + batch_clean_document_task( + document_ids=[document_id], + dataset_id=dataset_id, + doc_form=IndexStructureType.PARAGRAPH_INDEX, + file_ids=[], + ) # Verify that no index processing occurred mock_external_service_dependencies["index_processor"].clean.assert_not_called() @@ -525,7 +531,11 @@ class TestBatchCleanDocumentTask: account = self._create_test_account(db_session_with_containers) # Test different doc_form types - doc_forms = ["text_model", "qa_model", "hierarchical_model"] + doc_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in doc_forms: dataset = self._create_test_dataset(db_session_with_containers, account) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 202ccb0098..5ebf141828 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -19,6 +19,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -179,7 +180,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ) @@ -221,17 +222,17 @@ class TestBatchCreateSegmentToIndexTask: return upload_file - def _create_test_csv_content(self, content_type="text_model"): + def _create_test_csv_content(self, content_type=IndexStructureType.PARAGRAPH_INDEX): """ Helper method to create test CSV content. Args: - content_type: Type of content to create ("text_model" or "qa_model") + content_type: Type of content to create (IndexStructureType.PARAGRAPH_INDEX or IndexStructureType.QA_INDEX) Returns: str: CSV content as string """ - if content_type == "qa_model": + if content_type == IndexStructureType.QA_INDEX: csv_content = "content,answer\n" csv_content += "This is the first segment content,This is the first answer\n" csv_content += "This is the second segment content,This is the second answer\n" @@ -264,7 +265,7 @@ class TestBatchCreateSegmentToIndexTask: upload_file = self._create_test_upload_file(db_session_with_containers, account, tenant) # Create CSV content - csv_content = self._create_test_csv_content("text_model") + csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX) # Mock storage to return our CSV content mock_storage = mock_external_service_dependencies["storage"] @@ -451,7 +452,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.COMPLETED, enabled=False, # Document is disabled archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ), # Archived document @@ -467,7 +468,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # Document is archived - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ), # Document with incomplete indexing @@ -483,7 +484,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.INDEXING, # Not completed enabled=True, archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ), ] @@ -655,7 +656,7 @@ class TestBatchCreateSegmentToIndexTask: db_session_with_containers.commit() # Create CSV content - csv_content = self._create_test_csv_content("text_model") + csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX) # Mock storage to return our CSV content mock_storage = mock_external_service_dependencies["storage"] diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 1cd698b870..9449fee0af 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -18,6 +18,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( @@ -192,7 +193,7 @@ class TestCleanDatasetTask: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=100, created_at=datetime.now(), updated_at=datetime.now(), diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index a2a190fd69..926c839c8b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -12,6 +12,7 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from services.account_service import AccountService, TenantService @@ -114,7 +115,7 @@ class TestCleanNotionDocumentTask: name=f"Notion Page {i}", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", # Set doc_form to ensure dataset.doc_form works + doc_form=IndexStructureType.PARAGRAPH_INDEX, # Set doc_form to ensure dataset.doc_form works doc_language="en", indexing_status=IndexingStatus.COMPLETED, ) @@ -261,7 +262,7 @@ class TestCleanNotionDocumentTask: # Test different index types # Note: Only testing text_model to avoid dependency on external services - index_types = ["text_model"] + index_types = [IndexStructureType.PARAGRAPH_INDEX] for index_type in index_types: # Create dataset (doc_form will be set via document creation) diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 132f43c320..979435282b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -12,6 +12,7 @@ from uuid import uuid4 import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -141,7 +142,7 @@ class TestCreateSegmentToIndexTask: enabled=True, archived=False, indexing_status=IndexingStatus.COMPLETED, - doc_form="qa_model", + doc_form=IndexStructureType.QA_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() @@ -301,7 +302,7 @@ class TestCreateSegmentToIndexTask: enabled=True, archived=False, indexing_status=IndexingStatus.COMPLETED, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() @@ -552,7 +553,11 @@ class TestCreateSegmentToIndexTask: - Processing completes successfully for different forms """ # Arrange: Test different doc_forms - doc_forms = ["qa_model", "text_model", "web_model"] + doc_forms = [ + IndexStructureType.QA_INDEX, + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.PARAGRAPH_INDEX, + ] for doc_form in doc_forms: # Create fresh test data for each form diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py index e80b37ac1b..d457b59d58 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -12,6 +12,7 @@ from unittest.mock import ANY, Mock, patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from services.account_service import AccountService, TenantService @@ -107,7 +108,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -167,7 +168,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -187,7 +188,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -268,7 +269,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="parent_child_index", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -288,7 +289,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="parent_child_index", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -416,7 +417,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -505,7 +506,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -525,7 +526,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -601,7 +602,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="qa_index", + doc_form=IndexStructureType.QA_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -638,7 +639,7 @@ class TestDealDatasetVectorIndexTask: assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with custom index type - mock_index_processor_factory.assert_called_once_with("qa_index") + mock_index_processor_factory.assert_called_once_with(IndexStructureType.QA_INDEX) mock_factory = mock_index_processor_factory.return_value mock_processor = mock_factory.init_index_processor.return_value mock_processor.load.assert_called_once() @@ -677,7 +678,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -714,7 +715,7 @@ class TestDealDatasetVectorIndexTask: assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with the document's index type - mock_index_processor_factory.assert_called_once_with("text_model") + mock_index_processor_factory.assert_called_once_with(IndexStructureType.PARAGRAPH_INDEX) mock_factory = mock_index_processor_factory.return_value mock_processor = mock_factory.init_index_processor.return_value mock_processor.load.assert_called_once() @@ -753,7 +754,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -775,7 +776,7 @@ class TestDealDatasetVectorIndexTask: name=f"Test Document {i}", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -856,7 +857,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -876,7 +877,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -953,7 +954,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -973,7 +974,7 @@ class TestDealDatasetVectorIndexTask: name="Enabled Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -992,7 +993,7 @@ class TestDealDatasetVectorIndexTask: name="Disabled Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=False, # This document should be skipped @@ -1074,7 +1075,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1094,7 +1095,7 @@ class TestDealDatasetVectorIndexTask: name="Active Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1113,7 +1114,7 @@ class TestDealDatasetVectorIndexTask: name="Archived Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1195,7 +1196,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1215,7 +1216,7 @@ class TestDealDatasetVectorIndexTask: name="Completed Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1234,7 +1235,7 @@ class TestDealDatasetVectorIndexTask: name="Incomplete Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.INDEXING, # This document should be skipped enabled=True, diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py index da42fc7167..d21f1daf23 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py @@ -15,6 +15,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -113,7 +114,7 @@ class TestDisableSegmentFromIndexTask: dataset: Dataset, tenant: Tenant, account: Account, - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, ) -> Document: """ Helper method to create a test document. @@ -476,7 +477,11 @@ class TestDisableSegmentFromIndexTask: - Index processor clean method is called correctly """ # Test different document forms - doc_forms = ["text_model", "qa_model", "table_model"] + doc_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in doc_forms: # Arrange: Create test data for each form diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 4bc9bb4749..fbcb7b5264 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock, patch from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, Dataset, DocumentSegment from models import Document as DatasetDocument from models.dataset import DatasetProcessRule @@ -153,7 +154,7 @@ class TestDisableSegmentsFromIndexTask: document.indexing_status = "completed" document.enabled = True document.archived = False - document.doc_form = "text_model" # Use text_model form for testing + document.doc_form = IndexStructureType.PARAGRAPH_INDEX # Use text_model form for testing document.doc_language = "en" db_session_with_containers.add(document) db_session_with_containers.commit() @@ -500,7 +501,11 @@ class TestDisableSegmentsFromIndexTask: segment_ids = [segment.id for segment in segments] # Test different document forms - doc_forms = ["text_model", "qa_model", "hierarchical_model"] + doc_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in doc_forms: # Update document form diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py index 6a17a19a54..10d97919fb 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -14,6 +14,7 @@ from uuid import uuid4 import pytest from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus @@ -85,7 +86,7 @@ class DocumentIndexingSyncTaskTestDataFactory: created_by=created_by, indexing_status=indexing_status, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", ) db_session_with_containers.add(document) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py index 2fbea1388c..c650d56091 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus @@ -80,7 +81,7 @@ class TestDocumentIndexingUpdateTask: created_by=account.id, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index f1f5a4b105..76b6a8ae73 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -4,6 +4,7 @@ import pytest from faker import Faker from core.indexing_runner import DocumentIsPausedError +from core.rag.index_processor.constant.index_type import IndexStructureType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -130,7 +131,7 @@ class TestDuplicateDocumentIndexingTasks: created_by=account.id, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) documents.append(document) @@ -265,7 +266,7 @@ class TestDuplicateDocumentIndexingTasks: created_by=account.id, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) documents.append(document) @@ -524,7 +525,7 @@ class TestDuplicateDocumentIndexingTasks: created_by=dataset.created_by, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) extra_documents.append(document) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_data_source.py b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py index 3060062adf..d841f67f9b 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_data_source.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py @@ -11,6 +11,7 @@ from controllers.console.datasets.data_source import ( DataSourceNotionDocumentSyncApi, DataSourceNotionListApi, ) +from core.rag.index_processor.constant.index_type import IndexStructureType def unwrap(func): @@ -343,7 +344,7 @@ class TestDataSourceNotionApi: } ], "process_rule": {"rules": {}}, - "doc_form": "text_model", + "doc_form": IndexStructureType.PARAGRAPH_INDEX, "doc_language": "English", } diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index 0ee76e504b..68a7b30b9e 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -28,6 +28,7 @@ from controllers.console.datasets.datasets import ( from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.provider_manager import ProviderManager +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from models.enums import CreatorUserRole from models.model import ApiToken, UploadFile @@ -1146,7 +1147,7 @@ class TestDatasetIndexingEstimateApi: }, "process_rule": {"chunk_size": 100}, "indexing_technique": "high_quality", - "doc_form": "text_model", + "doc_form": IndexStructureType.PARAGRAPH_INDEX, "doc_language": "English", "dataset_id": None, } diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py index f23dd5b44a..f08f21ee14 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -30,6 +30,7 @@ from controllers.console.datasets.error import ( InvalidActionError, InvalidMetadataError, ) +from core.rag.index_processor.constant.index_type import IndexStructureType from models.enums import DataSourceType, IndexingStatus @@ -66,7 +67,7 @@ def document(): indexing_status=IndexingStatus.INDEXING, data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, archived=False, is_paused=False, dataset_process_rule=None, @@ -765,8 +766,8 @@ class TestDocumentGenerateSummaryApi: summary_index_setting={"enable": True}, ) - doc1 = MagicMock(id="doc-1", doc_form="qa_model") - doc2 = MagicMock(id="doc-2", doc_form="text") + doc1 = MagicMock(id="doc-1", doc_form=IndexStructureType.QA_INDEX) + doc2 = MagicMock(id="doc-2", doc_form=IndexStructureType.PARAGRAPH_INDEX) payload = {"document_list": ["doc-1", "doc-2"]} @@ -822,7 +823,7 @@ class TestDocumentIndexingEstimateApi: data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, dataset_process_rule=None, ) @@ -849,7 +850,7 @@ class TestDocumentIndexingEstimateApi: data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, dataset_process_rule=None, ) @@ -973,7 +974,7 @@ class TestDocumentBatchIndexingEstimateApi: "mode": "single", "only_main_content": True, }, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) with ( @@ -1001,7 +1002,7 @@ class TestDocumentBatchIndexingEstimateApi: "notion_page_id": "p1", "type": "page", }, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) with ( @@ -1024,7 +1025,7 @@ class TestDocumentBatchIndexingEstimateApi: indexing_status=IndexingStatus.INDEXING, data_source_type="unknown", data_source_info_dict={}, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) with app.test_request_context("/"), patch.object(api, "get_batch_documents", return_value=[document]): @@ -1353,7 +1354,7 @@ class TestDocumentIndexingEdgeCases: data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, dataset_process_rule=None, ) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py index e67e4daad9..1482499c41 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py @@ -24,6 +24,7 @@ from controllers.console.datasets.error import ( InvalidActionError, ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import ChildChunk, DocumentSegment from models.model import UploadFile @@ -366,7 +367,7 @@ class TestDatasetDocumentSegmentAddApi: dataset.indexing_technique = "economy" document = MagicMock() - document.doc_form = "text" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX segment = MagicMock() segment.id = "seg-1" @@ -505,7 +506,7 @@ class TestDatasetDocumentSegmentUpdateApi: dataset.indexing_technique = "economy" document = MagicMock() - document.doc_form = "text" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX segment = MagicMock() diff --git a/api/tests/unit_tests/controllers/service_api/conftest.py b/api/tests/unit_tests/controllers/service_api/conftest.py index 4337a0c8c0..01d2d1e7c0 100644 --- a/api/tests/unit_tests/controllers/service_api/conftest.py +++ b/api/tests/unit_tests/controllers/service_api/conftest.py @@ -12,6 +12,7 @@ from unittest.mock import Mock import pytest from flask import Flask +from core.rag.index_processor.constant.index_type import IndexStructureType from models.account import TenantStatus from models.model import App, AppMode, EndUser from tests.unit_tests.conftest import setup_mock_tenant_account_query @@ -175,7 +176,7 @@ def mock_document(): document.name = "test_document.txt" document.indexing_status = "completed" document.enabled = True - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX return document diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py index 5c48ef1804..73a87761d5 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py @@ -31,6 +31,7 @@ from controllers.service_api.dataset.segment import ( SegmentCreatePayload, SegmentListQuery, ) +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import ChildChunk, Dataset, Document, DocumentSegment from models.enums import IndexingStatus from services.dataset_service import DocumentService, SegmentService @@ -788,7 +789,7 @@ class TestSegmentApiGet: # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset - mock_doc_svc.get_document.return_value = Mock(doc_form="text_model") + mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) mock_seg_svc.get_segments.return_value = ([mock_segment], 1) mock_marshal.return_value = [{"id": mock_segment.id}] @@ -903,7 +904,7 @@ class TestSegmentApiPost: mock_doc = Mock() mock_doc.indexing_status = "completed" mock_doc.enabled = True - mock_doc.doc_form = "text_model" + mock_doc.doc_form = IndexStructureType.PARAGRAPH_INDEX mock_doc_svc.get_document.return_value = mock_doc mock_seg_svc.segment_create_args_validate.return_value = None @@ -1091,7 +1092,7 @@ class TestDatasetSegmentApiDelete: mock_doc = Mock() mock_doc.indexing_status = "completed" mock_doc.enabled = True - mock_doc.doc_form = "text_model" + mock_doc.doc_form = IndexStructureType.PARAGRAPH_INDEX mock_doc_svc.get_document.return_value = mock_doc mock_seg_svc.get_segment_by_id.return_value = None # Segment not found @@ -1371,7 +1372,7 @@ class TestDatasetSegmentApiGetSingle: mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None - mock_doc = Mock(doc_form="text_model") + mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) mock_doc_svc.get_document.return_value = mock_doc mock_seg_svc.get_segment_by_id.return_value = mock_segment mock_marshal.return_value = {"id": mock_segment.id} @@ -1390,7 +1391,7 @@ class TestDatasetSegmentApiGetSingle: assert status == 200 assert "data" in response - assert response["doc_form"] == "text_model" + assert response["doc_form"] == IndexStructureType.PARAGRAPH_INDEX @patch("controllers.service_api.dataset.segment.current_account_with_tenant") @patch("controllers.service_api.dataset.segment.db") diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index e6e841be19..7f77e61ee4 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -35,6 +35,7 @@ from controllers.service_api.dataset.document import ( InvalidMetadataError, ) from controllers.service_api.dataset.error import ArchivedDocumentImmutableError +from core.rag.index_processor.constant.index_type import IndexStructureType from models.enums import IndexingStatus from services.dataset_service import DocumentService from services.entities.knowledge_entities.knowledge_entities import ProcessRule, RetrievalModel @@ -52,7 +53,7 @@ class TestDocumentTextCreatePayload: def test_payload_with_defaults(self): """Test payload default values.""" payload = DocumentTextCreatePayload(name="Doc", text="Content") - assert payload.doc_form == "text_model" + assert payload.doc_form == IndexStructureType.PARAGRAPH_INDEX assert payload.doc_language == "English" assert payload.process_rule is None assert payload.indexing_technique is None @@ -62,14 +63,14 @@ class TestDocumentTextCreatePayload: payload = DocumentTextCreatePayload( name="Full Document", text="Complete document content here", - doc_form="qa_model", + doc_form=IndexStructureType.QA_INDEX, doc_language="Chinese", indexing_technique="high_quality", embedding_model="text-embedding-ada-002", embedding_model_provider="openai", ) assert payload.name == "Full Document" - assert payload.doc_form == "qa_model" + assert payload.doc_form == IndexStructureType.QA_INDEX assert payload.doc_language == "Chinese" assert payload.indexing_technique == "high_quality" assert payload.embedding_model == "text-embedding-ada-002" @@ -147,8 +148,8 @@ class TestDocumentTextUpdate: def test_payload_with_doc_form_update(self): """Test payload with doc_form update.""" - payload = DocumentTextUpdate(doc_form="qa_model") - assert payload.doc_form == "qa_model" + payload = DocumentTextUpdate(doc_form=IndexStructureType.QA_INDEX) + assert payload.doc_form == IndexStructureType.QA_INDEX def test_payload_with_language_update(self): """Test payload with doc_language update.""" @@ -158,7 +159,7 @@ class TestDocumentTextUpdate: def test_payload_default_values(self): """Test payload default values.""" payload = DocumentTextUpdate() - assert payload.doc_form == "text_model" + assert payload.doc_form == IndexStructureType.PARAGRAPH_INDEX assert payload.doc_language == "English" @@ -272,14 +273,24 @@ class TestDocumentDocForm: def test_text_model_form(self): """Test text_model form.""" - doc_form = "text_model" - valid_forms = ["text_model", "qa_model", "hierarchical_model", "parent_child_model"] + doc_form = IndexStructureType.PARAGRAPH_INDEX + valid_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + "parent_child_model", + ] assert doc_form in valid_forms def test_qa_model_form(self): """Test qa_model form.""" - doc_form = "qa_model" - valid_forms = ["text_model", "qa_model", "hierarchical_model", "parent_child_model"] + doc_form = IndexStructureType.QA_INDEX + valid_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + "parent_child_model", + ] assert doc_form in valid_forms @@ -504,7 +515,7 @@ class TestDocumentApiGet: doc.name = "test_document.txt" doc.indexing_status = "completed" doc.enabled = True - doc.doc_form = "text_model" + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX doc.doc_language = "English" doc.doc_type = "book" doc.doc_metadata_details = {"source": "upload"} diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index 665e98bd9c..a34ca330ca 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -4800,8 +4800,8 @@ class TestInternalHooksCoverage: dataset_docs = [ SimpleNamespace(id="doc-a", doc_form=IndexStructureType.PARENT_CHILD_INDEX), SimpleNamespace(id="doc-b", doc_form=IndexStructureType.PARENT_CHILD_INDEX), - SimpleNamespace(id="doc-c", doc_form="qa_model"), - SimpleNamespace(id="doc-d", doc_form="qa_model"), + SimpleNamespace(id="doc-c", doc_form=IndexStructureType.QA_INDEX), + SimpleNamespace(id="doc-d", doc_form=IndexStructureType.QA_INDEX), ] child_chunks = [SimpleNamespace(index_node_id="idx-a", segment_id="seg-a")] segments = [SimpleNamespace(index_node_id="idx-c", id="seg-c")] diff --git a/api/tests/unit_tests/models/test_tool_models.py b/api/tests/unit_tests/models/test_tool_models.py index a6c2eae2c0..8e3c4da904 100644 --- a/api/tests/unit_tests/models/test_tool_models.py +++ b/api/tests/unit_tests/models/test_tool_models.py @@ -238,7 +238,7 @@ class TestApiToolProviderValidation: name=provider_name, icon='{"type": "emoji", "value": "🔧"}', schema=schema, - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Custom API for testing", tools_str=json.dumps(tools), credentials_str=json.dumps(credentials), @@ -249,7 +249,7 @@ class TestApiToolProviderValidation: assert api_provider.user_id == user_id assert api_provider.name == provider_name assert api_provider.schema == schema - assert api_provider.schema_type_str == "openapi" + assert api_provider.schema_type_str == ApiProviderSchemaType.OPENAPI assert api_provider.description == "Custom API for testing" def test_api_tool_provider_schema_type_property(self): @@ -261,7 +261,7 @@ class TestApiToolProviderValidation: name="Test API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Test", tools_str="[]", credentials_str="{}", @@ -314,7 +314,7 @@ class TestApiToolProviderValidation: name="Weather API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Weather API", tools_str=json.dumps(tools_data), credentials_str="{}", @@ -343,7 +343,7 @@ class TestApiToolProviderValidation: name="Secure API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Secure API", tools_str="[]", credentials_str=json.dumps(credentials_data), @@ -369,7 +369,7 @@ class TestApiToolProviderValidation: name="Privacy API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="API with privacy policy", tools_str="[]", credentials_str="{}", @@ -391,7 +391,7 @@ class TestApiToolProviderValidation: name="Disclaimer API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="API with disclaimer", tools_str="[]", credentials_str="{}", @@ -410,7 +410,7 @@ class TestApiToolProviderValidation: name="Default API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="API", tools_str="[]", credentials_str="{}", @@ -432,7 +432,7 @@ class TestApiToolProviderValidation: name=provider_name, icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Unique API", tools_str="[]", credentials_str="{}", @@ -454,7 +454,7 @@ class TestApiToolProviderValidation: name="Public API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Public API with no auth", tools_str="[]", credentials_str=json.dumps(credentials), @@ -479,7 +479,7 @@ class TestApiToolProviderValidation: name="Query Auth API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="API with query auth", tools_str="[]", credentials_str=json.dumps(credentials), @@ -741,7 +741,7 @@ class TestCredentialStorage: name="Test API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Test", tools_str="[]", credentials_str=json.dumps(credentials), @@ -788,7 +788,7 @@ class TestCredentialStorage: name="Update Test", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Test", tools_str="[]", credentials_str=json.dumps(original_credentials), @@ -897,7 +897,7 @@ class TestToolProviderRelationships: name="User API", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Test", tools_str="[]", credentials_str="{}", @@ -931,7 +931,7 @@ class TestToolProviderRelationships: name="Custom API 1", icon="{}", schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, description="Test", tools_str="[]", credentials_str="{}", diff --git a/api/tests/unit_tests/services/document_service_validation.py b/api/tests/unit_tests/services/document_service_validation.py index 6829691507..1f68ff6b3d 100644 --- a/api/tests/unit_tests/services/document_service_validation.py +++ b/api/tests/unit_tests/services/document_service_validation.py @@ -111,6 +111,7 @@ from unittest.mock import Mock, patch import pytest from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.rag.index_processor.constant.index_type import IndexStructureType from dify_graph.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, DatasetProcessRule, Document from services.dataset_service import DatasetService, DocumentService @@ -188,7 +189,7 @@ class DocumentValidationTestDataFactory: def create_knowledge_config_mock( data_source: DataSource | None = None, process_rule: ProcessRule | None = None, - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, indexing_technique: str = "high_quality", **kwargs, ) -> Mock: @@ -326,8 +327,8 @@ class TestDatasetServiceCheckDocForm: - Validation logic works correctly """ # Arrange - dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="text_model") - doc_form = "text_model" + dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) + doc_form = IndexStructureType.PARAGRAPH_INDEX # Act (should not raise) DatasetService.check_doc_form(dataset, doc_form) @@ -349,7 +350,7 @@ class TestDatasetServiceCheckDocForm: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=None) - doc_form = "text_model" + doc_form = IndexStructureType.PARAGRAPH_INDEX # Act (should not raise) DatasetService.check_doc_form(dataset, doc_form) @@ -370,8 +371,8 @@ class TestDatasetServiceCheckDocForm: - Error type is correct """ # Arrange - dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="text_model") - doc_form = "table_model" # Different form + dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) + doc_form = IndexStructureType.PARENT_CHILD_INDEX # Different form # Act & Assert with pytest.raises(ValueError, match="doc_form is different from the dataset doc_form"): @@ -390,7 +391,7 @@ class TestDatasetServiceCheckDocForm: """ # Arrange dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="knowledge_card") - doc_form = "text_model" # Different form + doc_form = IndexStructureType.PARAGRAPH_INDEX # Different form # Act & Assert with pytest.raises(ValueError, match="doc_form is different from the dataset doc_form"): diff --git a/api/tests/unit_tests/services/segment_service.py b/api/tests/unit_tests/services/segment_service.py index cc2c0a8032..5e625fa0cd 100644 --- a/api/tests/unit_tests/services/segment_service.py +++ b/api/tests/unit_tests/services/segment_service.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType from models.account import Account from models.dataset import ChildChunk, Dataset, Document, DocumentSegment from models.enums import SegmentType @@ -91,7 +92,7 @@ class SegmentTestDataFactory: document_id: str = "doc-123", dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, word_count: int = 100, **kwargs, ) -> Mock: @@ -210,7 +211,7 @@ class TestSegmentServiceCreateSegment: def test_create_segment_with_qa_model(self, mock_db_session, mock_current_user): """Test creation of segment with QA model (requires answer).""" # Arrange - document = SegmentTestDataFactory.create_document_mock(doc_form="qa_model", word_count=100) + document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100) dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") args = {"content": "What is AI?", "answer": "AI is Artificial Intelligence", "keywords": ["ai"]} @@ -429,7 +430,7 @@ class TestSegmentServiceUpdateSegment: """Test update segment with QA model (includes answer).""" # Arrange segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10) - document = SegmentTestDataFactory.create_document_mock(doc_form="qa_model", word_count=100) + document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100) dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy") args = SegmentUpdateArgs(content="Updated question", answer="Updated answer", keywords=["qa"]) diff --git a/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py index bd226f7536..d2287e8982 100644 --- a/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py +++ b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py @@ -4,6 +4,7 @@ from unittest.mock import Mock, create_autospec import pytest from redis.exceptions import LockNotOwnedError +from core.rag.index_processor.constant.index_type import IndexStructureType from models.account import Account from models.dataset import Dataset, Document from services.dataset_service import DocumentService, SegmentService @@ -76,7 +77,7 @@ def test_save_document_with_dataset_id_ignores_lock_not_owned( info_list = types.SimpleNamespace(data_source_type="upload_file") data_source = types.SimpleNamespace(info_list=info_list) knowledge_config = types.SimpleNamespace( - doc_form="qa_model", + doc_form=IndexStructureType.QA_INDEX, original_document_id=None, # go into "new document" branch data_source=data_source, indexing_technique="high_quality", @@ -131,7 +132,7 @@ def test_add_segment_ignores_lock_not_owned( document.id = "doc-1" document.dataset_id = dataset.id document.word_count = 0 - document.doc_form = "qa_model" + document.doc_form = IndexStructureType.QA_INDEX # Minimal args required by add_segment args = { @@ -174,4 +175,4 @@ def test_multi_create_segment_ignores_lock_not_owned( document.id = "doc-1" document.dataset_id = dataset.id document.word_count = 0 - document.doc_form = "qa_model" + document.doc_form = IndexStructureType.QA_INDEX diff --git a/api/tests/unit_tests/services/test_summary_index_service.py b/api/tests/unit_tests/services/test_summary_index_service.py index be64e431ba..c4285c73a0 100644 --- a/api/tests/unit_tests/services/test_summary_index_service.py +++ b/api/tests/unit_tests/services/test_summary_index_service.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock import pytest import services.summary_index_service as summary_module +from core.rag.index_processor.constant.index_type import IndexStructureType from models.enums import SegmentStatus, SummaryStatus from services.summary_index_service import SummaryIndexService @@ -48,7 +49,7 @@ def _segment(*, has_document: bool = True) -> MagicMock: if has_document: doc = MagicMock(name="document") doc.doc_language = "en" - doc.doc_form = "text_model" + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX segment.document = doc else: segment.document = None @@ -623,13 +624,13 @@ def test_generate_summaries_for_document_skip_conditions(monkeypatch: pytest.Mon dataset = _dataset(indexing_technique="economy") document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] dataset = _dataset() assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": False}) == [] - document.doc_form = "qa_model" + document.doc_form = IndexStructureType.QA_INDEX assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] @@ -637,7 +638,7 @@ def test_generate_summaries_for_document_runs_and_handles_errors(monkeypatch: py dataset = _dataset() document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX seg1 = _segment() seg2 = _segment() @@ -673,7 +674,7 @@ def test_generate_summaries_for_document_no_segments_returns_empty(monkeypatch: dataset = _dataset() document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX session = MagicMock() query = MagicMock() @@ -696,7 +697,7 @@ def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chu dataset = _dataset() document = MagicMock(spec=summary_module.DatasetDocument) document.id = "doc-1" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX seg = _segment() session = MagicMock() @@ -935,7 +936,7 @@ def test_update_summary_for_segment_skip_conditions() -> None: SummaryIndexService.update_summary_for_segment(_segment(), _dataset(indexing_technique="economy"), "x") is None ) seg = _segment(has_document=True) - seg.document.doc_form = "qa_model" + seg.document.doc_form = IndexStructureType.QA_INDEX assert SummaryIndexService.update_summary_for_segment(seg, _dataset(), "x") is None diff --git a/api/tests/unit_tests/services/test_vector_service.py b/api/tests/unit_tests/services/test_vector_service.py index 7b0103a2a1..d3a98dd4bb 100644 --- a/api/tests/unit_tests/services/test_vector_service.py +++ b/api/tests/unit_tests/services/test_vector_service.py @@ -9,6 +9,7 @@ from unittest.mock import MagicMock import pytest import services.vector_service as vector_service_module +from core.rag.index_processor.constant.index_type import IndexStructureType from services.vector_service import VectorService @@ -32,7 +33,7 @@ class _ParentDocStub: def _make_dataset( *, indexing_technique: str = "high_quality", - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, tenant_id: str = "tenant-1", dataset_id: str = "dataset-1", is_multimodal: bool = False, @@ -106,7 +107,7 @@ def test_create_segments_vector_regular_indexing_loads_documents_and_keywords(mo factory_instance.init_index_processor.return_value = index_processor monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) - VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model") + VectorService.create_segments_vector([["k1"]], [segment], dataset, IndexStructureType.PARAGRAPH_INDEX) index_processor.load.assert_called_once() args, kwargs = index_processor.load.call_args @@ -131,7 +132,7 @@ def test_create_segments_vector_regular_indexing_loads_multimodal_documents(monk factory_instance.init_index_processor.return_value = index_processor monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) - VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model") + VectorService.create_segments_vector([["k1"]], [segment], dataset, IndexStructureType.PARAGRAPH_INDEX) assert index_processor.load.call_count == 2 first_args, first_kwargs = index_processor.load.call_args_list[0] @@ -153,7 +154,7 @@ def test_create_segments_vector_with_no_segments_does_not_load(monkeypatch: pyte factory_instance.init_index_processor.return_value = index_processor monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) - VectorService.create_segments_vector(None, [], dataset, "text_model") + VectorService.create_segments_vector(None, [], dataset, IndexStructureType.PARAGRAPH_INDEX) index_processor.load.assert_not_called() @@ -392,7 +393,7 @@ def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkey def test_generate_child_chunks_regenerate_cleans_then_saves_children(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(doc_form="text_model", tenant_id="tenant-1", dataset_id="dataset-1") + dataset = _make_dataset(doc_form=IndexStructureType.PARAGRAPH_INDEX, tenant_id="tenant-1", dataset_id="dataset-1") segment = _make_segment(segment_id="seg-1") dataset_document = MagicMock() @@ -439,7 +440,7 @@ def test_generate_child_chunks_regenerate_cleans_then_saves_children(monkeypatch def test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest.MonkeyPatch) -> None: - dataset = _make_dataset(doc_form="text_model") + dataset = _make_dataset(doc_form=IndexStructureType.PARAGRAPH_INDEX) segment = _make_segment() dataset_document = MagicMock() dataset_document.doc_language = "en" diff --git a/api/tests/unit_tests/services/vector_service.py b/api/tests/unit_tests/services/vector_service.py index c99275c6b2..e180063041 100644 --- a/api/tests/unit_tests/services/vector_service.py +++ b/api/tests/unit_tests/services/vector_service.py @@ -121,6 +121,7 @@ import pytest from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import Document from models.dataset import ChildChunk, Dataset, DatasetDocument, DatasetProcessRule, DocumentSegment from services.vector_service import VectorService @@ -151,7 +152,7 @@ class VectorServiceTestDataFactory: def create_dataset_mock( dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, indexing_technique: str = "high_quality", embedding_model_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", @@ -493,7 +494,7 @@ class TestVectorService: """ # Arrange dataset = VectorServiceTestDataFactory.create_dataset_mock( - doc_form="text_model", indexing_technique="high_quality" + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_technique="high_quality" ) segment = VectorServiceTestDataFactory.create_document_segment_mock() @@ -505,7 +506,7 @@ class TestVectorService: mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor # Act - VectorService.create_segments_vector(keywords_list, [segment], dataset, "text_model") + VectorService.create_segments_vector(keywords_list, [segment], dataset, IndexStructureType.PARAGRAPH_INDEX) # Assert mock_index_processor.load.assert_called_once() @@ -649,7 +650,7 @@ class TestVectorService: mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor # Act - VectorService.create_segments_vector(None, [], dataset, "text_model") + VectorService.create_segments_vector(None, [], dataset, IndexStructureType.PARAGRAPH_INDEX) # Assert mock_index_processor.load.assert_not_called() diff --git a/api/tests/unit_tests/tasks/test_clean_dataset_task.py b/api/tests/unit_tests/tasks/test_clean_dataset_task.py index 74ba7f9c34..c0a4d2f113 100644 --- a/api/tests/unit_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/unit_tests/tasks/test_clean_dataset_task.py @@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType from models.enums import DataSourceType from tasks.clean_dataset_task import clean_dataset_task @@ -186,7 +187,7 @@ class TestErrorHandling: indexing_technique="high_quality", index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert @@ -231,7 +232,7 @@ class TestPipelineAndWorkflowDeletion: indexing_technique="high_quality", index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, pipeline_id=pipeline_id, ) @@ -267,7 +268,7 @@ class TestPipelineAndWorkflowDeletion: indexing_technique="high_quality", index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, pipeline_id=None, ) @@ -323,7 +324,7 @@ class TestSegmentAttachmentCleanup: indexing_technique="high_quality", index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert @@ -368,7 +369,7 @@ class TestSegmentAttachmentCleanup: indexing_technique="high_quality", index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert - storage delete was attempted @@ -410,7 +411,7 @@ class TestEdgeCases: indexing_technique="high_quality", index_struct='{"type": "paragraph"}', collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert @@ -454,7 +455,7 @@ class TestIndexProcessorParameters: indexing_technique=indexing_technique, index_struct=index_struct, collection_binding_id=collection_binding_id, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Assert diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index 8a721124d6..6804ade5aa 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -15,6 +15,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest from core.indexing_runner import DocumentIsPausedError +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client @@ -222,7 +223,7 @@ def mock_documents(document_ids, dataset_id): doc.stopped_at = None doc.processing_started_at = None # optional attribute used in some code paths - doc.doc_form = "text_model" + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX documents.append(doc) return documents diff --git a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py index 3668416e36..f49f4535af 100644 --- a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document from tasks.document_indexing_sync_task import document_indexing_sync_task @@ -62,7 +63,7 @@ def mock_document(document_id, dataset_id, notion_workspace_id, notion_page_id, document.tenant_id = str(uuid.uuid4()) document.data_source_type = "notion_import" document.indexing_status = "completed" - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX document.data_source_info_dict = { "notion_workspace_id": notion_workspace_id, "notion_page_id": notion_page_id, From b0920ecd17743a55638bded693e3d93237ebfcc5 Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Tue, 24 Mar 2026 14:02:52 +0800 Subject: [PATCH 30/34] refactor(web): migrate plugin toast usage to new UI toast API and update tests (#34001) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../plugins/plugin-install-flow.test.ts | 12 +++- .../install-plugin/__tests__/hooks.spec.ts | 20 ++++--- .../plugins/install-plugin/hooks.ts | 20 ++----- .../__tests__/index.spec.tsx | 14 +++-- .../install-from-github/index.tsx | 21 ++----- .../__tests__/detail-header.spec.tsx | 19 ++++++- .../__tests__/endpoint-card.spec.tsx | 19 ++++++- .../__tests__/endpoint-modal.spec.tsx | 20 ++++++- .../__tests__/use-plugin-operations.spec.ts | 21 ++++++- .../hooks/use-plugin-operations.ts | 14 ++--- .../plugin-detail-panel/endpoint-card.tsx | 20 +++---- .../plugin-detail-panel/endpoint-list.tsx | 12 ++-- .../plugin-detail-panel/endpoint-modal.tsx | 13 +++-- .../model-selector/__tests__/index.spec.tsx | 32 +++++++---- .../model-selector/index.tsx | 9 +-- .../__tests__/log-viewer.spec.tsx | 20 +++++-- .../__tests__/selector-entry.spec.tsx | 14 +++-- .../__tests__/selector-view.spec.tsx | 14 ++++- .../__tests__/subscription-card.spec.tsx | 14 ++++- .../create/__tests__/common-modal.spec.tsx | 14 +++-- .../create/__tests__/index.spec.tsx | 21 ++++--- .../create/__tests__/oauth-client.spec.tsx | 17 ++++-- .../__tests__/use-oauth-client-state.spec.ts | 17 ++++-- .../create/hooks/use-common-modal-state.ts | 42 +++----------- .../create/hooks/use-oauth-client-state.ts | 27 ++------- .../subscription-list/create/index.tsx | 12 +--- .../subscription-list/create/oauth-client.tsx | 7 +-- .../edit/__tests__/apikey-edit-modal.spec.tsx | 16 ++++-- .../edit/__tests__/index.spec.tsx | 12 +++- .../edit/__tests__/manual-edit-modal.spec.tsx | 16 ++++-- .../edit/__tests__/oauth-edit-modal.spec.tsx | 16 ++++-- .../edit/apikey-edit-modal.tsx | 24 ++------ .../edit/manual-edit-modal.tsx | 12 +--- .../edit/oauth-edit-modal.tsx | 12 +--- .../subscription-list/log-viewer.tsx | 7 +-- .../tool-selector/__tests__/index.spec.tsx | 14 ++++- .../__tests__/tool-credentials-form.spec.tsx | 16 ++++-- .../components/tool-credentials-form.tsx | 7 ++- .../plugin-item/__tests__/action.spec.tsx | 31 ++++++---- .../components/plugins/plugin-item/action.tsx | 4 +- web/eslint-suppressions.json | 57 +++---------------- 41 files changed, 390 insertions(+), 339 deletions(-) diff --git a/web/__tests__/plugins/plugin-install-flow.test.ts b/web/__tests__/plugins/plugin-install-flow.test.ts index 8edb6705d4..8fa2246198 100644 --- a/web/__tests__/plugins/plugin-install-flow.test.ts +++ b/web/__tests__/plugins/plugin-install-flow.test.ts @@ -12,8 +12,16 @@ vi.mock('@/config', () => ({ })) const mockToastNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: (...args: unknown[]) => mockToastNotify(...args) }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign((message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) const mockUploadGitHub = vi.fn() diff --git a/web/app/components/plugins/install-plugin/__tests__/hooks.spec.ts b/web/app/components/plugins/install-plugin/__tests__/hooks.spec.ts index 918a9b36e3..6b0fc27adf 100644 --- a/web/app/components/plugins/install-plugin/__tests__/hooks.spec.ts +++ b/web/app/components/plugins/install-plugin/__tests__/hooks.spec.ts @@ -3,8 +3,16 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import { useGitHubReleases, useGitHubUpload } from '../hooks' const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: (...args: unknown[]) => mockNotify(...args) }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign((...args: unknown[]) => mockNotify(...args), { + success: (...args: unknown[]) => mockNotify(...args), + error: (...args: unknown[]) => mockNotify(...args), + warning: (...args: unknown[]) => mockNotify(...args), + info: (...args: unknown[]) => mockNotify(...args), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) vi.mock('@/config', () => ({ @@ -56,9 +64,7 @@ describe('install-plugin/hooks', () => { const releases = await result.current.fetchReleases('owner', 'repo') expect(releases).toEqual([]) - expect(mockNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(mockNotify).toHaveBeenCalledWith('Failed to fetch repository releases') }) }) @@ -130,9 +136,7 @@ describe('install-plugin/hooks', () => { await expect( result.current.handleUpload('url', 'v1', 'pkg'), ).rejects.toThrow('Upload failed') - expect(mockNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error', message: 'Error uploading package' }), - ) + expect(mockNotify).toHaveBeenCalledWith('Error uploading package') }) }) }) diff --git a/web/app/components/plugins/install-plugin/hooks.ts b/web/app/components/plugins/install-plugin/hooks.ts index 2addba4a04..cc7148cc17 100644 --- a/web/app/components/plugins/install-plugin/hooks.ts +++ b/web/app/components/plugins/install-plugin/hooks.ts @@ -1,6 +1,5 @@ import type { GitHubRepoReleaseResponse } from '../types' -import type { IToastProps } from '@/app/components/base/toast' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { GITHUB_ACCESS_TOKEN } from '@/config' import { uploadGitHub } from '@/service/plugins' import { compareVersion, getLatestVersion } from '@/utils/semver' @@ -37,16 +36,10 @@ export const useGitHubReleases = () => { } catch (error) { if (error instanceof Error) { - Toast.notify({ - type: 'error', - message: error.message, - }) + toast.error(error.message) } else { - Toast.notify({ - type: 'error', - message: 'Failed to fetch repository releases', - }) + toast.error('Failed to fetch repository releases') } return [] } @@ -54,7 +47,7 @@ export const useGitHubReleases = () => { const checkForUpdates = (fetchedReleases: GitHubRepoReleaseResponse[], currentVersion: string) => { let needUpdate = false - const toastProps: IToastProps = { + const toastProps: { type?: 'success' | 'error' | 'info' | 'warning', message: string } = { type: 'info', message: 'No new version available', } @@ -99,10 +92,7 @@ export const useGitHubUpload = () => { return GitHubPackage } catch (error) { - Toast.notify({ - type: 'error', - message: 'Error uploading package', - }) + toast.error('Error uploading package') throw error } } diff --git a/web/app/components/plugins/install-plugin/install-from-github/__tests__/index.spec.tsx b/web/app/components/plugins/install-plugin/install-from-github/__tests__/index.spec.tsx index 0fe6b88ed8..8abec7817b 100644 --- a/web/app/components/plugins/install-plugin/install-from-github/__tests__/index.spec.tsx +++ b/web/app/components/plugins/install-plugin/install-from-github/__tests__/index.spec.tsx @@ -57,10 +57,16 @@ const createUpdatePayload = (overrides: Partial = {}): // Mock external dependencies const mockNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: (props: { type: string, message: string }) => mockNotify(props), - }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign((props: { type: string, message: string }) => mockNotify(props), { + success: (message: string) => mockNotify({ type: 'success', message }), + error: (message: string) => mockNotify({ type: 'error', message }), + warning: (message: string) => mockNotify({ type: 'warning', message }), + info: (message: string) => mockNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) const mockGetIconUrl = vi.fn() diff --git a/web/app/components/plugins/install-plugin/install-from-github/index.tsx b/web/app/components/plugins/install-plugin/install-from-github/index.tsx index 91425031cf..ff51698478 100644 --- a/web/app/components/plugins/install-plugin/install-from-github/index.tsx +++ b/web/app/components/plugins/install-plugin/install-from-github/index.tsx @@ -7,7 +7,7 @@ import * as React from 'react' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Modal from '@/app/components/base/modal' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import useGetIcon from '@/app/components/plugins/install-plugin/base/use-get-icon' import { cn } from '@/utils/classnames' import { InstallStepFromGitHub } from '../../types' @@ -81,10 +81,7 @@ const InstallFromGitHub: React.FC = ({ updatePayload, on const handleUrlSubmit = async () => { const { isValid, owner, repo } = parseGitHubUrl(state.repoUrl) if (!isValid || !owner || !repo) { - Toast.notify({ - type: 'error', - message: t('error.inValidGitHubUrl', { ns: 'plugin' }), - }) + toast.error(t('error.inValidGitHubUrl', { ns: 'plugin' })) return } try { @@ -97,17 +94,11 @@ const InstallFromGitHub: React.FC = ({ updatePayload, on })) } else { - Toast.notify({ - type: 'error', - message: t('error.noReleasesFound', { ns: 'plugin' }), - }) + toast.error(t('error.noReleasesFound', { ns: 'plugin' })) } } catch { - Toast.notify({ - type: 'error', - message: t('error.fetchReleasesError', { ns: 'plugin' }), - }) + toast.error(t('error.fetchReleasesError', { ns: 'plugin' })) } } @@ -175,10 +166,10 @@ const InstallFromGitHub: React.FC = ({ updatePayload, on >
-
+
{getTitle()}
-
+
{!([InstallStepFromGitHub.uploadFailed, InstallStepFromGitHub.installed, InstallStepFromGitHub.installFailed].includes(state.step)) && t('installFromGitHub.installNote', { ns: 'plugin' })}
diff --git a/web/app/components/plugins/plugin-detail-panel/__tests__/detail-header.spec.tsx b/web/app/components/plugins/plugin-detail-panel/__tests__/detail-header.spec.tsx index f0ec5b6c83..f8d6488128 100644 --- a/web/app/components/plugins/plugin-detail-panel/__tests__/detail-header.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/__tests__/detail-header.spec.tsx @@ -2,10 +2,25 @@ import type { PluginDetail } from '../../types' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import * as amplitude from '@/app/components/base/amplitude' -import Toast from '@/app/components/base/toast' import { PluginSource } from '../../types' import DetailHeader from '../detail-header' +const { mockToast } = vi.hoisted(() => ({ + mockToast: Object.assign(vi.fn(), { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), +})) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: mockToast, +})) + const { mockSetShowUpdatePluginModal, mockRefreshModelProviders, @@ -272,7 +287,7 @@ describe('DetailHeader', () => { vi.clearAllMocks() mockAutoUpgradeInfo = null mockEnableMarketplace = true - vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() })) + vi.clearAllMocks() vi.spyOn(amplitude, 'trackEvent').mockImplementation(() => {}) }) diff --git a/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-card.spec.tsx b/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-card.spec.tsx index 480f399c91..237c72adf0 100644 --- a/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-card.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-card.spec.tsx @@ -1,7 +1,6 @@ import type { EndpointListItem, PluginDetail } from '../../types' import { act, fireEvent, render, screen } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import Toast from '@/app/components/base/toast' import EndpointCard from '../endpoint-card' const mockHandleChange = vi.fn() @@ -9,6 +8,22 @@ const mockEnableEndpoint = vi.fn() const mockDisableEndpoint = vi.fn() const mockDeleteEndpoint = vi.fn() const mockUpdateEndpoint = vi.fn() +const mockToastNotify = vi.fn() + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign( + (message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), + { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }, + ), +})) // Flags to control whether operations should fail const failureFlags = { @@ -127,8 +142,6 @@ describe('EndpointCard', () => { failureFlags.disable = false failureFlags.delete = false failureFlags.update = false - // Mock Toast.notify to prevent toast elements from accumulating in DOM - vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() })) // Polyfill document.execCommand for copy-to-clipboard in jsdom if (typeof document.execCommand !== 'function') { document.execCommand = vi.fn().mockReturnValue(true) diff --git a/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-modal.spec.tsx index 1dfe31c6b1..a467de7142 100644 --- a/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-modal.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-modal.spec.tsx @@ -2,9 +2,25 @@ import type { FormSchema } from '../../../base/form/types' import type { PluginDetail } from '../../types' import { fireEvent, render, screen } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import Toast from '@/app/components/base/toast' import EndpointModal from '../endpoint-modal' +const mockToastNotify = vi.fn() + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign( + (message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), + { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }, + ), +})) + vi.mock('@/hooks/use-i18n', () => ({ useRenderI18nObject: () => (obj: Record | string) => typeof obj === 'string' ? obj : obj?.en_US || '', @@ -69,11 +85,9 @@ const mockPluginDetail: PluginDetail = { describe('EndpointModal', () => { const mockOnCancel = vi.fn() const mockOnSaved = vi.fn() - let mockToastNotify: ReturnType beforeEach(() => { vi.clearAllMocks() - mockToastNotify = vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() })) }) describe('Rendering', () => { diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/__tests__/use-plugin-operations.spec.ts b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/__tests__/use-plugin-operations.spec.ts index 0fcec7f16b..77d41c5bce 100644 --- a/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/__tests__/use-plugin-operations.spec.ts +++ b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/__tests__/use-plugin-operations.spec.ts @@ -3,7 +3,6 @@ import type { ModalStates, VersionTarget } from '../use-detail-header-state' import { act, renderHook } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import * as amplitude from '@/app/components/base/amplitude' -import Toast from '@/app/components/base/toast' import { PluginSource } from '../../../../types' import { usePluginOperations } from '../use-plugin-operations' @@ -20,6 +19,7 @@ const { mockUninstallPlugin, mockFetchReleases, mockCheckForUpdates, + mockToastNotify, } = vi.hoisted(() => { return { mockSetShowUpdatePluginModal: vi.fn(), @@ -29,9 +29,25 @@ const { mockUninstallPlugin: vi.fn(() => Promise.resolve({ success: true })), mockFetchReleases: vi.fn(() => Promise.resolve([{ tag_name: 'v2.0.0' }])), mockCheckForUpdates: vi.fn(() => ({ needUpdate: true, toastProps: { type: 'success', message: 'Update available' } })), + mockToastNotify: vi.fn(), } }) +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign( + (message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), + { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }, + ), +})) + vi.mock('@/context/modal-context', () => ({ useModalContext: () => ({ setShowUpdatePluginModal: mockSetShowUpdatePluginModal, @@ -124,7 +140,6 @@ describe('usePluginOperations', () => { modalStates = createModalStatesMock() versionPicker = createVersionPickerMock() mockOnUpdate = vi.fn() - vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() })) vi.spyOn(amplitude, 'trackEvent').mockImplementation(() => {}) }) @@ -233,7 +248,7 @@ describe('usePluginOperations', () => { }) expect(mockCheckForUpdates).toHaveBeenCalled() - expect(Toast.notify).toHaveBeenCalled() + expect(mockToastNotify).toHaveBeenCalledWith({ type: 'success', message: 'Update available' }) }) it('should show update plugin modal when update is needed', async () => { diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.ts b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.ts index bf6bb4aae6..ade47cec5f 100644 --- a/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.ts +++ b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.ts @@ -5,7 +5,7 @@ import type { ModalStates, VersionTarget } from './use-detail-header-state' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import { trackEvent } from '@/app/components/base/amplitude' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useModalContext } from '@/context/modal-context' import { useProviderContext } from '@/context/provider-context' import { uninstallPlugin } from '@/service/plugins' @@ -60,10 +60,7 @@ export const usePluginOperations = ({ } if (!meta?.repo || !meta?.version || !meta?.package) { - Toast.notify({ - type: 'error', - message: 'Missing plugin metadata for GitHub update', - }) + toast.error('Missing plugin metadata for GitHub update') return } @@ -74,7 +71,7 @@ export const usePluginOperations = ({ return const { needUpdate, toastProps } = checkForUpdates(fetchedReleases, meta.version) - Toast.notify(toastProps) + toast(toastProps.message, { type: toastProps.type }) if (needUpdate) { setShowUpdatePluginModal({ @@ -122,10 +119,7 @@ export const usePluginOperations = ({ if (res.success) { modalStates.hideDeleteConfirm() - Toast.notify({ - type: 'success', - message: t('action.deleteSuccess', { ns: 'plugin' }), - }) + toast.success(t('action.deleteSuccess', { ns: 'plugin' })) handlePluginUpdated(true) if (PluginCategoryEnum.model.includes(category)) diff --git a/web/app/components/plugins/plugin-detail-panel/endpoint-card.tsx b/web/app/components/plugins/plugin-detail-panel/endpoint-card.tsx index 164bab0f04..9f95d9c7e1 100644 --- a/web/app/components/plugins/plugin-detail-panel/endpoint-card.tsx +++ b/web/app/components/plugins/plugin-detail-panel/endpoint-card.tsx @@ -9,8 +9,8 @@ import ActionButton from '@/app/components/base/action-button' import Confirm from '@/app/components/base/confirm' import { CopyCheck } from '@/app/components/base/icons/src/vender/line/files' import Switch from '@/app/components/base/switch' -import Toast from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' +import { toast } from '@/app/components/base/ui/toast' import Indicator from '@/app/components/header/indicator' import { addDefaultValue, toolCredentialToFormSchemas } from '@/app/components/tools/utils/to-form-schema' import { @@ -47,7 +47,7 @@ const EndpointCard = ({ await handleChange() }, onError: () => { - Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) setActive(false) }, }) @@ -57,7 +57,7 @@ const EndpointCard = ({ hideDisableConfirm() }, onError: () => { - Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) setActive(false) }, }) @@ -83,7 +83,7 @@ const EndpointCard = ({ hideDeleteConfirm() }, onError: () => { - Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) }, }) @@ -108,7 +108,7 @@ const EndpointCard = ({ hideEndpointModalConfirm() }, onError: () => { - Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) }, }) const handleUpdate = (state: Record) => updateEndpoint({ @@ -139,7 +139,7 @@ const EndpointCard = ({
-
+
{data.name}
@@ -154,8 +154,8 @@ const EndpointCard = ({
{data.declaration.endpoints.filter(endpoint => !endpoint.hidden).map((endpoint, index) => (
-
{endpoint.method}
-
+
{endpoint.method}
+
{`${data.url}${endpoint.path}`}
handleCopy(`${data.url}${endpoint.path}`)}> @@ -168,13 +168,13 @@ const EndpointCard = ({
{active && ( -
+
{t('detailPanel.serviceOk', { ns: 'plugin' })}
)} {!active && ( -
+
{t('detailPanel.disabled', { ns: 'plugin' })}
diff --git a/web/app/components/plugins/plugin-detail-panel/endpoint-list.tsx b/web/app/components/plugins/plugin-detail-panel/endpoint-list.tsx index 357e714ba2..366139d12d 100644 --- a/web/app/components/plugins/plugin-detail-panel/endpoint-list.tsx +++ b/web/app/components/plugins/plugin-detail-panel/endpoint-list.tsx @@ -9,8 +9,8 @@ import * as React from 'react' import { useMemo } from 'react' import { useTranslation } from 'react-i18next' import ActionButton from '@/app/components/base/action-button' -import Toast from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' +import { toast } from '@/app/components/base/ui/toast' import { toolCredentialToFormSchemas } from '@/app/components/tools/utils/to-form-schema' import { useDocLink } from '@/context/i18n' import { @@ -50,7 +50,7 @@ const EndpointList = ({ detail }: Props) => { hideEndpointModal() }, onError: () => { - Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) }) + toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) }, }) @@ -64,7 +64,7 @@ const EndpointList = ({ detail }: Props) => { return (
-
+
{t('detailPanel.endpoints', { ns: 'plugin' })} {
-
{t('detailPanel.endpointsTip', { ns: 'plugin' })}
+
{t('detailPanel.endpointsTip', { ns: 'plugin' })}
-
+
{t('detailPanel.endpointsDocLink', { ns: 'plugin' })}
@@ -95,7 +95,7 @@ const EndpointList = ({ detail }: Props) => {
{data.endpoints.length === 0 && ( -
{t('detailPanel.endpointsEmpty', { ns: 'plugin' })}
+
{t('detailPanel.endpointsEmpty', { ns: 'plugin' })}
)}
{data.endpoints.map((item, index) => ( diff --git a/web/app/components/plugins/plugin-detail-panel/endpoint-modal.tsx b/web/app/components/plugins/plugin-detail-panel/endpoint-modal.tsx index 929e990f90..4d93e14c8b 100644 --- a/web/app/components/plugins/plugin-detail-panel/endpoint-modal.tsx +++ b/web/app/components/plugins/plugin-detail-panel/endpoint-modal.tsx @@ -8,7 +8,7 @@ import { useTranslation } from 'react-i18next' import ActionButton from '@/app/components/base/action-button' import Button from '@/app/components/base/button' import Drawer from '@/app/components/base/drawer' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form' import { useRenderI18nObject } from '@/hooks/use-i18n' import { cn } from '@/utils/classnames' @@ -48,7 +48,10 @@ const EndpointModal: FC = ({ const handleSave = () => { for (const field of formSchemas) { if (field.required && !tempCredential[field.name]) { - Toast.notify({ type: 'error', message: t('errorMsg.fieldRequired', { ns: 'common', field: typeof field.label === 'string' ? field.label : getValueFromI18nObject(field.label as Record) }) }) + toast.error(t('errorMsg.fieldRequired', { + ns: 'common', + field: typeof field.label === 'string' ? field.label : getValueFromI18nObject(field.label as Record), + })) return } } @@ -83,12 +86,12 @@ const EndpointModal: FC = ({ <>
-
{t('detailPanel.endpointModalTitle', { ns: 'plugin' })}
+
{t('detailPanel.endpointModalTitle', { ns: 'plugin' })}
-
{t('detailPanel.endpointModalDesc', { ns: 'plugin' })}
+
{t('detailPanel.endpointModalDesc', { ns: 'plugin' })}
@@ -109,7 +112,7 @@ const EndpointModal: FC = ({ href={item.url} target="_blank" rel="noopener noreferrer" - className="body-xs-regular inline-flex items-center text-text-accent-secondary" + className="inline-flex items-center text-text-accent-secondary body-xs-regular" > {t('howToGet', { ns: 'tools' })} diff --git a/web/app/components/plugins/plugin-detail-panel/model-selector/__tests__/index.spec.tsx b/web/app/components/plugins/plugin-detail-panel/model-selector/__tests__/index.spec.tsx index 9b04a710e0..107d42ada2 100644 --- a/web/app/components/plugins/plugin-detail-panel/model-selector/__tests__/index.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/model-selector/__tests__/index.spec.tsx @@ -1,14 +1,29 @@ import type { Model, ModelItem } from '@/app/components/header/account-setting/model-provider-page/declarations' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -// Import component after mocks -import Toast from '@/app/components/base/toast' - import { ConfigurationMethodEnum, ModelStatusEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' + +// Import component after mocks import ModelParameterModal from '../index' // ==================== Mock Setup ==================== +const mockToastNotify = vi.fn() +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign( + (message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), + { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }, + ), +})) + // Mock provider context const mockProviderContextValue = { isAPIKeySet: true, @@ -53,8 +68,6 @@ vi.mock('@/utils/completion-params', () => ({ fetchAndMergeValidCompletionParams: (...args: unknown[]) => mockFetchAndMergeValidCompletionParams(...args), })) -const mockToastNotify = vi.spyOn(Toast, 'notify') - // Mock child components vi.mock('@/app/components/header/account-setting/model-provider-page/model-selector', () => ({ default: ({ defaultModel, modelList, scopeFeatures, onSelect }: { @@ -244,7 +257,6 @@ const setupModelLists = (config: { describe('ModelParameterModal', () => { beforeEach(() => { vi.clearAllMocks() - mockToastNotify.mockReturnValue({}) mockProviderContextValue.isAPIKeySet = true mockProviderContextValue.modelProviders = [] setupModelLists() @@ -865,9 +877,7 @@ describe('ModelParameterModal', () => { // Assert await waitFor(() => { - expect(Toast.notify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'warning' }), - ) + expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({ type: 'warning' })) }) }) @@ -892,9 +902,7 @@ describe('ModelParameterModal', () => { // Assert await waitFor(() => { - expect(Toast.notify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({ type: 'error' })) }) }) }) diff --git a/web/app/components/plugins/plugin-detail-panel/model-selector/index.tsx b/web/app/components/plugins/plugin-detail-panel/model-selector/index.tsx index 04b78f98b7..6838bcec43 100644 --- a/web/app/components/plugins/plugin-detail-panel/model-selector/index.tsx +++ b/web/app/components/plugins/plugin-detail-panel/model-selector/index.tsx @@ -10,12 +10,12 @@ import type { import type { TriggerProps } from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger' import { useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' import { Popover, PopoverContent, PopoverTrigger, } from '@/app/components/base/ui/popover' +import { toast } from '@/app/components/base/ui/toast' import { ModelStatusEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useModelList, @@ -134,14 +134,11 @@ const ModelParameterModal: FC = ({ const keys = Object.keys(removedDetails || {}) if (keys.length) { - Toast.notify({ - type: 'warning', - message: `${t('modelProvider.parametersInvalidRemoved', { ns: 'common' })}: ${keys.map(k => `${k} (${removedDetails[k]})`).join(', ')}`, - }) + toast.warning(`${t('modelProvider.parametersInvalidRemoved', { ns: 'common' })}: ${keys.map(k => `${k} (${removedDetails[k]})`).join(', ')}`) } } catch { - Toast.notify({ type: 'error', message: t('error', { ns: 'common' }) }) + toast.error(t('error', { ns: 'common' })) } } diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/log-viewer.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/log-viewer.spec.tsx index c6fb42faab..351c1f9d2d 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/log-viewer.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/log-viewer.spec.tsx @@ -1,12 +1,26 @@ import type { TriggerLogEntity } from '@/app/components/workflow/block-selector/types' import { cleanup, fireEvent, render, screen } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import Toast from '@/app/components/base/toast' import LogViewer from '../log-viewer' const mockToastNotify = vi.fn() const mockWriteText = vi.fn() +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign( + (message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), + { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }, + ), +})) + vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', () => ({ default: ({ value }: { value: unknown }) => (
{JSON.stringify(value)}
@@ -57,10 +71,6 @@ beforeEach(() => { }, configurable: true, }) - vi.spyOn(Toast, 'notify').mockImplementation((args) => { - mockToastNotify(args) - return { clear: vi.fn() } - }) }) describe('LogViewer', () => { diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/selector-entry.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/selector-entry.spec.tsx index d8d41ff9b2..3c4ff83fc8 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/selector-entry.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/selector-entry.spec.tsx @@ -26,10 +26,16 @@ vi.mock('@/service/use-triggers', () => ({ useDeleteTriggerSubscription: () => ({ mutate: vi.fn(), isPending: false }), })) -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: vi.fn(), - }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign(vi.fn(), { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) const createSubscription = (overrides: Partial = {}): TriggerSubscription => ({ diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/selector-view.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/selector-view.spec.tsx index 83d0cdd89d..44cec53e28 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/selector-view.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/selector-view.spec.tsx @@ -1,7 +1,6 @@ import type { TriggerSubscription } from '@/app/components/workflow/block-selector/types' import { fireEvent, render, screen } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import Toast from '@/app/components/base/toast' import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types' import { SubscriptionSelectorView } from '../selector-view' @@ -26,6 +25,18 @@ vi.mock('@/service/use-triggers', () => ({ useDeleteTriggerSubscription: () => ({ mutate: mockDelete, isPending: false }), })) +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign(vi.fn(), { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), +})) + const createSubscription = (overrides: Partial = {}): TriggerSubscription => ({ id: 'sub-1', name: 'Subscription One', @@ -42,7 +53,6 @@ const createSubscription = (overrides: Partial = {}): Trigg beforeEach(() => { vi.clearAllMocks() mockSubscriptions = [createSubscription()] - vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() })) }) describe('SubscriptionSelectorView', () => { diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/subscription-card.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/subscription-card.spec.tsx index a51bc2954f..4665c921ca 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/subscription-card.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/__tests__/subscription-card.spec.tsx @@ -1,7 +1,6 @@ import type { TriggerSubscription } from '@/app/components/workflow/block-selector/types' import { fireEvent, render, screen } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import Toast from '@/app/components/base/toast' import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types' import SubscriptionCard from '../subscription-card' @@ -30,6 +29,18 @@ vi.mock('@/service/use-triggers', () => ({ useDeleteTriggerSubscription: () => ({ mutate: vi.fn(), isPending: false }), })) +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign(vi.fn(), { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), +})) + const createSubscription = (overrides: Partial = {}): TriggerSubscription => ({ id: 'sub-1', name: 'Subscription One', @@ -45,7 +56,6 @@ const createSubscription = (overrides: Partial = {}): Trigg beforeEach(() => { vi.clearAllMocks() - vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() })) }) describe('SubscriptionCard', () => { diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/common-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/common-modal.spec.tsx index 21a4c3defa..72532ea38d 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/common-modal.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/common-modal.spec.tsx @@ -122,10 +122,16 @@ vi.mock('@/utils/urlValidation', () => ({ })) const mockToastNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: (params: unknown) => mockToastNotify(params), - }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign((params: unknown) => mockToastNotify(params), { + success: (message: unknown) => mockToastNotify({ type: 'success', message }), + error: (message: unknown) => mockToastNotify({ type: 'error', message }), + warning: (message: unknown) => mockToastNotify({ type: 'warning', message }), + info: (message: unknown) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) vi.mock('@/app/components/base/modal/modal', () => ({ diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/index.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/index.spec.tsx index 3fe9884b92..a36c108160 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/index.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/index.spec.tsx @@ -2,6 +2,7 @@ import type { SimpleDetail } from '../../../store' import type { TriggerOAuthConfig, TriggerProviderApiEntity, TriggerSubscription, TriggerSubscriptionBuilder } from '@/app/components/workflow/block-selector/types' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' +import { toast } from '@/app/components/base/ui/toast' import { SupportedCreationMethods } from '@/app/components/plugins/types' import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types' import { CreateButtonType, CreateSubscriptionButton, DEFAULT_METHOD } from '../index' @@ -33,10 +34,16 @@ vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ }, })) -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: vi.fn(), - }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign(vi.fn(), { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) let mockStoreDetail: SimpleDetail | undefined @@ -908,8 +915,6 @@ describe('CreateSubscriptionButton', () => { it('should handle OAuth initiation error', async () => { // Arrange - const Toast = await import('@/app/components/base/toast') - mockInitiateOAuth.mockImplementation((_provider: string, callbacks: { onError: () => void }) => { callbacks.onError() }) @@ -932,9 +937,7 @@ describe('CreateSubscriptionButton', () => { // Assert await waitFor(() => { - expect(Toast.default.notify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(toast.error).toHaveBeenCalled() }) }) }) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/oauth-client.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/oauth-client.spec.tsx index 12419a9bf3..ce53bf5b9a 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/oauth-client.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/oauth-client.spec.tsx @@ -86,10 +86,19 @@ vi.mock('@/hooks/use-oauth', () => ({ })) const mockToastNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: (params: unknown) => mockToastNotify(params), - }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign( + (message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), + { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }, + ), })) const mockClipboardWriteText = vi.fn() diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/__tests__/use-oauth-client-state.spec.ts b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/__tests__/use-oauth-client-state.spec.ts index 89566f3af7..68864b0b80 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/__tests__/use-oauth-client-state.spec.ts +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/__tests__/use-oauth-client-state.spec.ts @@ -77,10 +77,19 @@ vi.mock('@/hooks/use-oauth', () => ({ })) const mockToastNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { - notify: (params: unknown) => mockToastNotify(params), - }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign( + (message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), + { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }, + ), })) // ============================================================================ diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.ts b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.ts index b01312d3d1..99c42f6fc5 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.ts +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.ts @@ -7,7 +7,7 @@ import type { BuildTriggerSubscriptionPayload } from '@/service/use-triggers' import { debounce } from 'es-toolkit/compat' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { SupportedCreationMethods } from '@/app/components/plugins/types' import { TriggerCredentialTypeEnum } from '@/app/components/workflow/block-selector/types' import { @@ -154,10 +154,7 @@ export const useCommonModalState = ({ onError: async (error: unknown) => { const errorMessage = await parsePluginErrorMessage(error) || t('modal.errors.updateFailed', { ns: 'pluginTrigger' }) console.error('Failed to update subscription builder:', error) - Toast.notify({ - type: 'error', - message: errorMessage, - }) + toast.error(errorMessage) }, }, ) @@ -178,10 +175,7 @@ export const useCommonModalState = ({ } catch (error) { console.error('createBuilder error:', error) - Toast.notify({ - type: 'error', - message: t('modal.errors.createFailed', { ns: 'pluginTrigger' }), - }) + toast.error(t('modal.errors.createFailed', { ns: 'pluginTrigger' })) } } if (!isInitializedRef.current && !subscriptionBuilder && detail?.provider) @@ -239,10 +233,7 @@ export const useCommonModalState = ({ const handleVerify = useCallback(() => { // Guard against uninitialized state if (!detail?.provider || !subscriptionBuilder?.id) { - Toast.notify({ - type: 'error', - message: 'Subscription builder not initialized', - }) + toast.error('Subscription builder not initialized') return } @@ -250,10 +241,7 @@ export const useCommonModalState = ({ const credentials = apiKeyCredentialsFormValues.values if (!Object.keys(credentials).length) { - Toast.notify({ - type: 'error', - message: 'Please fill in all required credentials', - }) + toast.error('Please fill in all required credentials') return } @@ -270,10 +258,7 @@ export const useCommonModalState = ({ }, { onSuccess: () => { - Toast.notify({ - type: 'success', - message: t('modal.apiKey.verify.success', { ns: 'pluginTrigger' }), - }) + toast.success(t('modal.apiKey.verify.success', { ns: 'pluginTrigger' })) setCurrentStep(ApiKeyStep.Configuration) }, onError: async (error: unknown) => { @@ -290,10 +275,7 @@ export const useCommonModalState = ({ // Handle create const handleCreate = useCallback(() => { if (!subscriptionBuilder) { - Toast.notify({ - type: 'error', - message: 'Subscription builder not found', - }) + toast.error('Subscription builder not found') return } @@ -327,19 +309,13 @@ export const useCommonModalState = ({ params, { onSuccess: () => { - Toast.notify({ - type: 'success', - message: t('subscription.createSuccess', { ns: 'pluginTrigger' }), - }) + toast.success(t('subscription.createSuccess', { ns: 'pluginTrigger' })) onClose() refetch?.() }, onError: async (error: unknown) => { const errorMessage = await parsePluginErrorMessage(error) || t('subscription.createFailed', { ns: 'pluginTrigger' }) - Toast.notify({ - type: 'error', - message: errorMessage, - }) + toast.error(errorMessage) }, }, ) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.ts b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.ts index 6a551051e2..e5a5ded9df 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.ts +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.ts @@ -4,7 +4,7 @@ import type { TriggerOAuthClientParams, TriggerOAuthConfig, TriggerSubscriptionB import type { ConfigureTriggerOAuthPayload } from '@/service/use-triggers' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { openOAuthPopup } from '@/hooks/use-oauth' import { useConfigureTriggerOAuth, @@ -118,20 +118,14 @@ export const useOAuthClientState = ({ openOAuthPopup(response.authorization_url, (callbackData) => { if (!callbackData) return - Toast.notify({ - type: 'success', - message: t('modal.oauth.authorization.authSuccess', { ns: 'pluginTrigger' }), - }) + toast.success(t('modal.oauth.authorization.authSuccess', { ns: 'pluginTrigger' })) onClose() showOAuthCreateModal(response.subscription_builder) }) }, onError: () => { setAuthorizationStatus(AuthorizationStatusEnum.Failed) - Toast.notify({ - type: 'error', - message: t('modal.oauth.authorization.authFailed', { ns: 'pluginTrigger' }), - }) + toast.error(t('modal.oauth.authorization.authFailed', { ns: 'pluginTrigger' })) }, }) }, [providerName, initiateOAuth, onClose, showOAuthCreateModal, t]) @@ -141,16 +135,10 @@ export const useOAuthClientState = ({ deleteOAuth(providerName, { onSuccess: () => { onClose() - Toast.notify({ - type: 'success', - message: t('modal.oauth.remove.success', { ns: 'pluginTrigger' }), - }) + toast.success(t('modal.oauth.remove.success', { ns: 'pluginTrigger' })) }, onError: (error: unknown) => { - Toast.notify({ - type: 'error', - message: getErrorMessage(error, t('modal.oauth.remove.failed', { ns: 'pluginTrigger' })), - }) + toast.error(getErrorMessage(error, t('modal.oauth.remove.failed', { ns: 'pluginTrigger' }))) }, }) }, [providerName, deleteOAuth, onClose, t]) @@ -187,10 +175,7 @@ export const useOAuthClientState = ({ return } onClose() - Toast.notify({ - type: 'success', - message: t('modal.oauth.save.success', { ns: 'pluginTrigger' }), - }) + toast.success(t('modal.oauth.save.success', { ns: 'pluginTrigger' })) }, }) }, [clientType, providerName, oauthClientSchema, oauthConfig?.params, configureOAuth, handleAuthorization, onClose, t]) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx index eecaf165fb..bd0846c15e 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx @@ -8,8 +8,8 @@ import { ActionButton, ActionButtonState } from '@/app/components/base/action-bu import Badge from '@/app/components/base/badge' import { Button } from '@/app/components/base/button' import CustomSelect from '@/app/components/base/select/custom' -import Toast from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' +import { toast } from '@/app/components/base/ui/toast' import { openOAuthPopup } from '@/hooks/use-oauth' import { useInitiateTriggerOAuth, useTriggerOAuthConfig, useTriggerProviderInfo } from '@/service/use-triggers' import { cn } from '@/utils/classnames' @@ -107,19 +107,13 @@ export const CreateSubscriptionButton = ({ buttonType = CreateButtonType.FULL_BU onSuccess: (response) => { openOAuthPopup(response.authorization_url, (callbackData) => { if (callbackData) { - Toast.notify({ - type: 'success', - message: t('modal.oauth.authorization.authSuccess', { ns: 'pluginTrigger' }), - }) + toast.success(t('modal.oauth.authorization.authSuccess', { ns: 'pluginTrigger' })) setSelectedCreateInfo({ type: SupportedCreationMethods.OAUTH, builder: response.subscription_builder }) } }) }, onError: () => { - Toast.notify({ - type: 'error', - message: t('modal.oauth.authorization.authFailed', { ns: 'pluginTrigger' }), - }) + toast.error(t('modal.oauth.authorization.authFailed', { ns: 'pluginTrigger' })) }, }) } diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.tsx index b7f9b8ebec..d4bc92169c 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.tsx @@ -8,7 +8,7 @@ import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import { BaseForm } from '@/app/components/base/form/components/base' import Modal from '@/app/components/base/modal/modal' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import OptionCard from '@/app/components/workflow/nodes/_base/components/option-card' import { usePluginStore } from '../../store' import { ClientTypeEnum, useOAuthClientState } from './hooks/use-oauth-client-state' @@ -48,10 +48,7 @@ export const OAuthClientSettingsModal = ({ oauthConfig, onClose, showOAuthCreate const handleCopyRedirectUri = () => { navigator.clipboard.writeText(oauthConfig?.redirect_uri || '') - Toast.notify({ - type: 'success', - message: t('actionMsg.copySuccessfully', { ns: 'common' }), - }) + toast.success(t('actionMsg.copySuccessfully', { ns: 'common' })) } return ( diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/apikey-edit-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/apikey-edit-modal.spec.tsx index e0fb7455ce..7f22a06295 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/apikey-edit-modal.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/apikey-edit-modal.spec.tsx @@ -47,13 +47,19 @@ vi.mock('@/service/use-triggers', () => ({ useTriggerPluginDynamicOptions: () => ({ data: [], isLoading: false }), })) -vi.mock('@/app/components/base/toast', async (importOriginal) => { - const actual = await importOriginal() +vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { + const actual = await importOriginal() return { ...actual, - default: { - notify: (args: { type: string, message: string }) => mockToast(args), - }, + toast: Object.assign((args: { type: string, message: string }) => mockToast(args), { + success: (message: string) => mockToast({ type: 'success', message }), + error: (message: string) => mockToast({ type: 'error', message }), + warning: (message: string) => mockToast({ type: 'warning', message }), + info: (message: string) => mockToast({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), } }) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/index.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/index.spec.tsx index 7d188a3f6d..126d8e366d 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/index.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/index.spec.tsx @@ -13,8 +13,16 @@ import { OAuthEditModal } from '../oauth-edit-modal' // ==================== Mock Setup ==================== const mockToastNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: (params: unknown) => mockToastNotify(params) }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign((message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) const mockParsePluginErrorMessage = vi.fn() diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/manual-edit-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/manual-edit-modal.spec.tsx index 60a8428287..52572b3560 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/manual-edit-modal.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/manual-edit-modal.spec.tsx @@ -30,13 +30,19 @@ vi.mock('@/service/use-triggers', () => ({ useTriggerPluginDynamicOptions: () => ({ data: [], isLoading: false }), })) -vi.mock('@/app/components/base/toast', async (importOriginal) => { - const actual = await importOriginal() +vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { + const actual = await importOriginal() return { ...actual, - default: { - notify: (args: { type: string, message: string }) => mockToast(args), - }, + toast: Object.assign((args: { type: string, message: string }) => mockToast(args), { + success: (message: string) => mockToast({ type: 'success', message }), + error: (message: string) => mockToast({ type: 'error', message }), + warning: (message: string) => mockToast({ type: 'warning', message }), + info: (message: string) => mockToast({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), } }) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/oauth-edit-modal.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/oauth-edit-modal.spec.tsx index 8835b46695..95b2cca6af 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/oauth-edit-modal.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/__tests__/oauth-edit-modal.spec.tsx @@ -30,13 +30,19 @@ vi.mock('@/service/use-triggers', () => ({ useTriggerPluginDynamicOptions: () => ({ data: [], isLoading: false }), })) -vi.mock('@/app/components/base/toast', async (importOriginal) => { - const actual = await importOriginal() +vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { + const actual = await importOriginal() return { ...actual, - default: { - notify: (args: { type: string, message: string }) => mockToast(args), - }, + toast: Object.assign((args: { type: string, message: string }) => mockToast(args), { + success: (message: string) => mockToast({ type: 'success', message }), + error: (message: string) => mockToast({ type: 'error', message }), + warning: (message: string) => mockToast({ type: 'warning', message }), + info: (message: string) => mockToast({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), } }) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/apikey-edit-modal.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/apikey-edit-modal.tsx index a4093ed00b..247beaa626 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/apikey-edit-modal.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/apikey-edit-modal.tsx @@ -9,7 +9,7 @@ import { EncryptedBottom } from '@/app/components/base/encrypted-bottom' import { BaseForm } from '@/app/components/base/form/components/base' import { FormTypeEnum } from '@/app/components/base/form/types' import Modal from '@/app/components/base/modal/modal' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { ReadmeEntrance } from '@/app/components/plugins/readme-panel/entrance' import { useUpdateTriggerSubscription, useVerifyTriggerSubscription } from '@/service/use-triggers' import { parsePluginErrorMessage } from '@/utils/error-parser' @@ -65,7 +65,7 @@ const StatusStep = ({ isActive, text, onClick, clickable }: { }) => { return (
{ - Toast.notify({ - type: 'success', - message: t('modal.apiKey.verify.success', { ns: 'pluginTrigger' }), - }) + toast.success(t('modal.apiKey.verify.success', { ns: 'pluginTrigger' })) // Only save credentials if any field was modified (not all hidden) setVerifiedCredentials(areAllCredentialsHidden(credentials) ? null : credentials) setCurrentStep(EditStep.EditConfiguration) }, onError: async (error: unknown) => { const errorMessage = await parsePluginErrorMessage(error) || t('modal.apiKey.verify.error', { ns: 'pluginTrigger' }) - Toast.notify({ - type: 'error', - message: errorMessage, - }) + toast.error(errorMessage) }, }, ) @@ -192,19 +186,13 @@ export const ApiKeyEditModal = ({ onClose, subscription, pluginDetail }: Props) }, { onSuccess: () => { - Toast.notify({ - type: 'success', - message: t('subscription.list.item.actions.edit.success', { ns: 'pluginTrigger' }), - }) + toast.success(t('subscription.list.item.actions.edit.success', { ns: 'pluginTrigger' })) refetch?.() onClose() }, onError: async (error: unknown) => { const errorMessage = await parsePluginErrorMessage(error) || t('subscription.list.item.actions.edit.error', { ns: 'pluginTrigger' }) - Toast.notify({ - type: 'error', - message: errorMessage, - }) + toast.error(errorMessage) }, }, ) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/manual-edit-modal.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/manual-edit-modal.tsx index 262235e6ed..e1741da8e7 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/manual-edit-modal.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/manual-edit-modal.tsx @@ -8,7 +8,7 @@ import { useTranslation } from 'react-i18next' import { BaseForm } from '@/app/components/base/form/components/base' import { FormTypeEnum } from '@/app/components/base/form/types' import Modal from '@/app/components/base/modal/modal' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { ReadmeEntrance } from '@/app/components/plugins/readme-panel/entrance' import { useUpdateTriggerSubscription } from '@/service/use-triggers' import { ReadmeShowType } from '../../../readme-panel/store' @@ -94,18 +94,12 @@ export const ManualEditModal = ({ onClose, subscription, pluginDetail }: Props) }, { onSuccess: () => { - Toast.notify({ - type: 'success', - message: t('subscription.list.item.actions.edit.success', { ns: 'pluginTrigger' }), - }) + toast.success(t('subscription.list.item.actions.edit.success', { ns: 'pluginTrigger' })) refetch?.() onClose() }, onError: (error: unknown) => { - Toast.notify({ - type: 'error', - message: getErrorMessage(error, t('subscription.list.item.actions.edit.error', { ns: 'pluginTrigger' })), - }) + toast.error(getErrorMessage(error, t('subscription.list.item.actions.edit.error', { ns: 'pluginTrigger' }))) }, }, ) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/oauth-edit-modal.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/oauth-edit-modal.tsx index e57b9c0151..c43a00a322 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/oauth-edit-modal.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/edit/oauth-edit-modal.tsx @@ -8,7 +8,7 @@ import { useTranslation } from 'react-i18next' import { BaseForm } from '@/app/components/base/form/components/base' import { FormTypeEnum } from '@/app/components/base/form/types' import Modal from '@/app/components/base/modal/modal' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { ReadmeEntrance } from '@/app/components/plugins/readme-panel/entrance' import { useUpdateTriggerSubscription } from '@/service/use-triggers' import { ReadmeShowType } from '../../../readme-panel/store' @@ -94,18 +94,12 @@ export const OAuthEditModal = ({ onClose, subscription, pluginDetail }: Props) = }, { onSuccess: () => { - Toast.notify({ - type: 'success', - message: t('subscription.list.item.actions.edit.success', { ns: 'pluginTrigger' }), - }) + toast.success(t('subscription.list.item.actions.edit.success', { ns: 'pluginTrigger' })) refetch?.() onClose() }, onError: (error: unknown) => { - Toast.notify({ - type: 'error', - message: getErrorMessage(error, t('subscription.list.item.actions.edit.error', { ns: 'pluginTrigger' })), - }) + toast.error(getErrorMessage(error, t('subscription.list.item.actions.edit.error', { ns: 'pluginTrigger' }))) }, }, ) diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/log-viewer.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/log-viewer.tsx index 3b4edd1b85..6ccab000c4 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/log-viewer.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/log-viewer.tsx @@ -11,7 +11,7 @@ import dayjs from 'dayjs' import * as React from 'react' import { useState } from 'react' import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' import { cn } from '@/utils/classnames' @@ -89,10 +89,7 @@ const LogViewer = ({ logs, className }: Props) => { onClick={(e) => { e.stopPropagation() navigator.clipboard.writeText(String(parsedData)) - Toast.notify({ - type: 'success', - message: t('actionMsg.copySuccessfully', { ns: 'common' }), - }) + toast.success(t('actionMsg.copySuccessfully', { ns: 'common' })) }} className="rounded-md p-0.5 hover:bg-components-panel-border" > diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/__tests__/index.spec.tsx b/web/app/components/plugins/plugin-detail-panel/tool-selector/__tests__/index.spec.tsx index 26e4de0fd7..537e99d733 100644 --- a/web/app/components/plugins/plugin-detail-panel/tool-selector/__tests__/index.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/__tests__/index.spec.tsx @@ -298,8 +298,16 @@ vi.mock('@/app/components/header/account-setting/model-provider-page/model-modal // Mock Toast - need to track notify calls for assertions const mockToastNotify = vi.fn() -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: (...args: unknown[]) => mockToastNotify(...args) }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign((message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) // ==================== Test Utilities ==================== @@ -1943,7 +1951,7 @@ describe('ToolCredentialsForm Component', () => { const saveBtn = screen.getByText(/save/i) fireEvent.click(saveBtn) - // Toast.notify should have been called with error (lines 49-50) + // notifyToast should have been called with error (lines 49-50) expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({ type: 'error' })) // onSaved should not be called because validation fails expect(onSaved).not.toHaveBeenCalled() diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-credentials-form.spec.tsx b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-credentials-form.spec.tsx index cb5b929d29..e1b8ca86fe 100644 --- a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-credentials-form.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/__tests__/tool-credentials-form.spec.tsx @@ -10,12 +10,16 @@ vi.mock('@/utils/classnames', () => ({ cn: (...args: unknown[]) => args.filter(Boolean).join(' '), })) -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: vi.fn() }, -})) - -vi.mock('@/app/components/base/toast/context', () => ({ - useToastContext: () => ({ notify: vi.fn() }), +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign(vi.fn(), { + success: vi.fn(), + error: vi.fn(), + warning: vi.fn(), + info: vi.fn(), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }), })) const mockFormSchemas = [ diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/tool-credentials-form.tsx b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/tool-credentials-form.tsx index 0207f65336..f4d39c9ec1 100644 --- a/web/app/components/plugins/plugin-detail-panel/tool-selector/components/tool-credentials-form.tsx +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/components/tool-credentials-form.tsx @@ -10,7 +10,7 @@ import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Loading from '@/app/components/base/loading' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form' import { addDefaultValue, toolCredentialToFormSchemas } from '@/app/components/tools/utils/to-form-schema' import { useRenderI18nObject } from '@/hooks/use-i18n' @@ -49,7 +49,10 @@ const ToolCredentialForm: FC = ({ return for (const field of credentialSchema) { if (field.required && !tempCredential[field.name]) { - Toast.notify({ type: 'error', message: t('errorMsg.fieldRequired', { ns: 'common', field: getValueFromI18nObject(field.label) }) }) + toast.error(t('errorMsg.fieldRequired', { + ns: 'common', + field: getValueFromI18nObject(field.label), + })) return } } diff --git a/web/app/components/plugins/plugin-item/__tests__/action.spec.tsx b/web/app/components/plugins/plugin-item/__tests__/action.spec.tsx index 8467c983d8..b4d21c9403 100644 --- a/web/app/components/plugins/plugin-item/__tests__/action.spec.tsx +++ b/web/app/components/plugins/plugin-item/__tests__/action.spec.tsx @@ -1,7 +1,6 @@ import type { MetaData, PluginCategoryEnum } from '../../types' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import Toast from '@/app/components/base/toast' // ==================== Imports (after mocks) ==================== @@ -17,12 +16,29 @@ const { mockCheckForUpdates, mockSetShowUpdatePluginModal, mockInvalidateInstalledPluginList, + mockToastNotify, } = vi.hoisted(() => ({ mockUninstallPlugin: vi.fn(), mockFetchReleases: vi.fn(), mockCheckForUpdates: vi.fn(), mockSetShowUpdatePluginModal: vi.fn(), mockInvalidateInstalledPluginList: vi.fn(), + mockToastNotify: vi.fn(), +})) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: Object.assign( + (message: string, options?: { type?: string }) => mockToastNotify({ type: options?.type, message }), + { + success: (message: string) => mockToastNotify({ type: 'success', message }), + error: (message: string) => mockToastNotify({ type: 'error', message }), + warning: (message: string) => mockToastNotify({ type: 'warning', message }), + info: (message: string) => mockToastNotify({ type: 'info', message }), + dismiss: vi.fn(), + update: vi.fn(), + promise: vi.fn(), + }, + ), })) // Mock uninstall plugin service @@ -140,13 +156,8 @@ const getActionButtons = () => screen.getAllByRole('button') const queryActionButtons = () => screen.queryAllByRole('button') describe('Action Component', () => { - // Spy on Toast.notify - real component but we track calls - let toastNotifySpy: ReturnType - beforeEach(() => { vi.clearAllMocks() - // Spy on Toast.notify and mock implementation to avoid DOM side effects - toastNotifySpy = vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() })) mockUninstallPlugin.mockResolvedValue({ success: true }) mockFetchReleases.mockResolvedValue([]) mockCheckForUpdates.mockReturnValue({ @@ -155,10 +166,6 @@ describe('Action Component', () => { }) }) - afterEach(() => { - toastNotifySpy.mockRestore() - }) - // ==================== Rendering Tests ==================== describe('Rendering', () => { it('should render delete button when isShowDelete is true', () => { @@ -563,9 +570,9 @@ describe('Action Component', () => { render() fireEvent.click(getActionButtons()[0]) - // Assert - Toast.notify is called with the toast props + // Assert - toast is called with the translated payload await waitFor(() => { - expect(toastNotifySpy).toHaveBeenCalledWith({ type: 'success', message: 'Already up to date' }) + expect(mockToastNotify).toHaveBeenCalledWith({ type: 'success', message: 'Already up to date' }) }) }) diff --git a/web/app/components/plugins/plugin-item/action.tsx b/web/app/components/plugins/plugin-item/action.tsx index 171e54acab..413b41e895 100644 --- a/web/app/components/plugins/plugin-item/action.tsx +++ b/web/app/components/plugins/plugin-item/action.tsx @@ -7,7 +7,7 @@ import { useBoolean } from 'ahooks' import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useModalContext } from '@/context/modal-context' import { uninstallPlugin } from '@/service/plugins' import { useInvalidateInstalledPluginList } from '@/service/use-plugins' @@ -65,7 +65,7 @@ const Action: FC = ({ if (fetchedReleases.length === 0) return const { needUpdate, toastProps } = checkForUpdates(fetchedReleases, meta!.version) - Toast.notify(toastProps) + toast(toastProps.message, { type: toastProps.type }) if (needUpdate) { setShowUpdatePluginModal({ onSaveCallback: () => { diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json index 04e8be3afd..0f69e6ce33 100644 --- a/web/eslint-suppressions.json +++ b/web/eslint-suppressions.json @@ -5062,9 +5062,6 @@ } }, "app/components/plugins/install-plugin/hooks.ts": { - "no-restricted-imports": { - "count": 2 - }, "ts/no-explicit-any": { "count": 4 } @@ -5100,9 +5097,6 @@ }, "app/components/plugins/install-plugin/install-from-github/index.tsx": { "no-restricted-imports": { - "count": 3 - }, - "tailwindcss/enforce-consistent-class-order": { "count": 2 }, "ts/no-explicit-any": { @@ -5367,17 +5361,9 @@ "count": 1 } }, - "app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.ts": { - "no-restricted-imports": { - "count": 1 - } - }, "app/components/plugins/plugin-detail-panel/endpoint-card.tsx": { "no-restricted-imports": { - "count": 3 - }, - "tailwindcss/enforce-consistent-class-order": { - "count": 5 + "count": 2 }, "ts/no-explicit-any": { "count": 2 @@ -5385,22 +5371,13 @@ }, "app/components/plugins/plugin-detail-panel/endpoint-list.tsx": { "no-restricted-imports": { - "count": 2 - }, - "tailwindcss/enforce-consistent-class-order": { - "count": 4 + "count": 1 }, "ts/no-explicit-any": { "count": 2 } }, "app/components/plugins/plugin-detail-panel/endpoint-modal.tsx": { - "no-restricted-imports": { - "count": 1 - }, - "tailwindcss/enforce-consistent-class-order": { - "count": 3 - }, "ts/no-explicit-any": { "count": 7 } @@ -5414,9 +5391,6 @@ } }, "app/components/plugins/plugin-detail-panel/model-selector/index.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 3 } @@ -5471,27 +5445,21 @@ "app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-common-modal-state.ts": { "erasable-syntax-only/enums": { "count": 1 - }, - "no-restricted-imports": { - "count": 1 } }, "app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.ts": { "erasable-syntax-only/enums": { "count": 2 - }, - "no-restricted-imports": { - "count": 1 } }, "app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx": { "no-restricted-imports": { - "count": 4 + "count": 3 } }, "app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.tsx": { "no-restricted-imports": { - "count": 2 + "count": 1 }, "tailwindcss/enforce-consistent-class-order": { "count": 3 @@ -5507,20 +5475,17 @@ "count": 1 }, "no-restricted-imports": { - "count": 2 - }, - "tailwindcss/enforce-consistent-class-order": { "count": 1 } }, "app/components/plugins/plugin-detail-panel/subscription-list/edit/manual-edit-modal.tsx": { "no-restricted-imports": { - "count": 2 + "count": 1 } }, "app/components/plugins/plugin-detail-panel/subscription-list/edit/oauth-edit-modal.tsx": { "no-restricted-imports": { - "count": 2 + "count": 1 } }, "app/components/plugins/plugin-detail-panel/subscription-list/index.tsx": { @@ -5540,9 +5505,6 @@ "erasable-syntax-only/enums": { "count": 1 }, - "no-restricted-imports": { - "count": 1 - }, "tailwindcss/enforce-consistent-class-order": { "count": 5 }, @@ -5600,11 +5562,6 @@ "count": 2 } }, - "app/components/plugins/plugin-detail-panel/tool-selector/components/tool-credentials-form.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "app/components/plugins/plugin-detail-panel/tool-selector/components/tool-item.tsx": { "no-restricted-imports": { "count": 2 @@ -5643,7 +5600,7 @@ }, "app/components/plugins/plugin-item/action.tsx": { "no-restricted-imports": { - "count": 3 + "count": 2 } }, "app/components/plugins/plugin-item/index.tsx": { From 508350ec6ac89c7764a145ef61cad7e3d1f39ad6 Mon Sep 17 00:00:00 2001 From: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Date: Tue, 24 Mar 2026 14:19:32 +0800 Subject: [PATCH 31/34] test: enhance useChat hook tests with additional scenarios (#33928) Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../__tests__/hooks/handle-resume.spec.ts | 1090 +++++++++++++++++ .../__tests__/hooks/handle-send.spec.ts | 194 +++ .../hooks/handle-stop-restart.spec.ts | 199 +++ .../__tests__/hooks/misc.spec.ts | 380 ++++++ .../opening-statement.spec.ts} | 69 +- .../__tests__/hooks/sse-callbacks.spec.ts | 914 ++++++++++++++ 6 files changed, 2831 insertions(+), 15 deletions(-) create mode 100644 web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-resume.spec.ts create mode 100644 web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-send.spec.ts create mode 100644 web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-stop-restart.spec.ts create mode 100644 web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/misc.spec.ts rename web/app/components/workflow/panel/debug-and-preview/__tests__/{hooks.spec.ts => hooks/opening-statement.spec.ts} (63%) create mode 100644 web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/sse-callbacks.spec.ts diff --git a/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-resume.spec.ts b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-resume.spec.ts new file mode 100644 index 0000000000..3367bd4766 --- /dev/null +++ b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-resume.spec.ts @@ -0,0 +1,1090 @@ +/* eslint-disable ts/no-explicit-any */ +import type { ChatItemInTree } from '@/app/components/base/chat/types' +import { act, renderHook } from '@testing-library/react' +import { useChat } from '../../hooks' + +const mockHandleRun = vi.fn() +const mockNotify = vi.fn() +const mockFetchInspectVars = vi.fn() +const mockInvalidAllLastRun = vi.fn() +const mockSetIterTimes = vi.fn() +const mockSetLoopTimes = vi.fn() +const mockSubmitHumanInputForm = vi.fn() +const mockSseGet = vi.fn() +const mockGetNodes = vi.fn((): any[] => []) + +let mockWorkflowRunningData: any = null + +vi.mock('@/service/base', () => ({ + sseGet: (...args: any[]) => mockSseGet(...args), +})) + +vi.mock('@/service/use-workflow', () => ({ + useInvalidAllLastRun: () => mockInvalidAllLastRun, +})) + +vi.mock('@/service/workflow', () => ({ + submitHumanInputForm: (...args: any[]) => mockSubmitHumanInputForm(...args), +})) + +vi.mock('@/app/components/base/toast/context', () => ({ + useToastContext: () => ({ notify: mockNotify }), +})) + +vi.mock('reactflow', () => ({ + useStoreApi: () => ({ + getState: () => ({ + getNodes: mockGetNodes, + }), + }), +})) + +vi.mock('../../../../hooks', () => ({ + useWorkflowRun: () => ({ handleRun: mockHandleRun }), + useSetWorkflowVarsWithValue: () => ({ fetchInspectVars: mockFetchInspectVars }), +})) + +vi.mock('../../../../hooks-store', () => ({ + useHooksStore: () => null, +})) + +vi.mock('../../../../store', () => ({ + useWorkflowStore: () => ({ + getState: () => ({ + setIterTimes: mockSetIterTimes, + setLoopTimes: mockSetLoopTimes, + inputs: {}, + workflowRunningData: mockWorkflowRunningData, + }), + }), + useStore: () => vi.fn(), +})) + +const resetMocksAndWorkflowState = () => { + vi.clearAllMocks() + mockWorkflowRunningData = null +} + +describe('useChat – handleResume', () => { + let capturedResumeOptions: any + + beforeEach(() => { + resetMocksAndWorkflowState() + mockHandleRun.mockReset() + mockSseGet.mockReset() + }) + + async function setupResumeWithTree() { + let sendCallbacks: any + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + sendCallbacks = callbacks + }) + mockSseGet.mockImplementation((_url: any, _opts: any, options: any) => { + capturedResumeOptions = options + }) + + const hook = renderHook(() => useChat({})) + + act(() => { + hook.result.current.handleSend({ query: 'test' }, {}) + }) + + act(() => { + sendCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-1', + task_id: 'task-1', + conversation_id: null, + message_id: 'msg-resume', + }) + }) + + await act(async () => { + await sendCallbacks.onCompleted(false) + }) + + act(() => { + hook.result.current.handleResume('msg-resume', 'wfr-1', {}) + }) + + return hook + } + + it('should call sseGet with the correct URL', () => { + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleResume('msg-1', 'wfr-1', {}) + }) + + expect(mockSseGet).toHaveBeenCalledWith( + '/workflow/wfr-1/events?include_state_snapshot=true', + {}, + expect.any(Object), + ) + }) + + it('should abort previous SSE connection when handleResume is called again', () => { + const mockAbortCtrl = new AbortController() + mockSseGet.mockImplementation((_url: any, _opts: any, options: any) => { + options.getAbortController(mockAbortCtrl) + }) + + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleResume('msg-1', 'wfr-1', {}) + }) + + const mockAbort2 = vi.fn() + mockSseGet.mockImplementation((_url: any, _opts: any, options: any) => { + options.getAbortController({ abort: mockAbort2 }) + }) + + act(() => { + result.current.handleResume('msg-1', 'wfr-2', {}) + }) + + expect(mockAbortCtrl.signal.aborted).toBe(true) + }) + + it('should abort previous workflowEventsAbortController before sseGet', () => { + const mockAbort = vi.fn() + let sendCallbacks: any + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + sendCallbacks = callbacks + }) + + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend({ query: 'test' }, {}) + }) + + act(() => { + sendCallbacks.getAbortController({ abort: mockAbort } as any) + }) + + mockSseGet.mockImplementation(() => {}) + + act(() => { + result.current.handleResume('msg-1', 'wfr-1', {}) + }) + + expect(mockAbort).toHaveBeenCalledTimes(1) + }) + + describe('onWorkflowStarted', () => { + it('should set isResponding and update workflow process', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + expect(result.current.isResponding).toBe(true) + }) + + it('should resume existing workflow when tracing exists', async () => { + let sendCallbacks: any + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + sendCallbacks = callbacks + }) + mockSseGet.mockImplementation((_url: any, _opts: any, options: any) => { + capturedResumeOptions = options + }) + + const hook = renderHook(() => useChat({})) + + act(() => { + hook.result.current.handleSend({ query: 'test' }, {}) + }) + + act(() => { + sendCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-1', + task_id: 'task-1', + conversation_id: null, + message_id: 'msg-resume', + }) + }) + + act(() => { + sendCallbacks.onNodeStarted({ + data: { node_id: 'n1', id: 'trace-1' }, + }) + }) + + await act(async () => { + await sendCallbacks.onCompleted(false) + }) + + act(() => { + hook.result.current.handleResume('msg-resume', 'wfr-1', {}) + }) + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + const answer = hook.result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.workflowProcess!.status).toBe('running') + }) + }) + + describe('onWorkflowFinished', () => { + it('should update workflow status', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + act(() => { + capturedResumeOptions.onWorkflowFinished({ + data: { status: 'succeeded' }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.workflowProcess!.status).toBe('succeeded') + }) + }) + + describe('onData', () => { + it('should append message content', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onData('resumed', false, { + conversationId: 'conv-2', + messageId: 'msg-resume', + taskId: 'task-2', + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.content).toContain('resumed') + }) + + it('should update conversationId when provided', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onData('msg', false, { + conversationId: 'new-conv-resume', + messageId: null, + taskId: 'task-2', + }) + }) + + expect(result.current.conversationId).toBe('new-conv-resume') + }) + }) + + describe('onCompleted', () => { + it('should set isResponding to false', async () => { + const { result } = await setupResumeWithTree() + await act(async () => { + await capturedResumeOptions.onCompleted(false) + }) + expect(result.current.isResponding).toBe(false) + }) + + it('should not call fetchInspectVars when paused', async () => { + mockWorkflowRunningData = { result: { status: 'paused' } } + await setupResumeWithTree() + mockFetchInspectVars.mockClear() + await act(async () => { + await capturedResumeOptions.onCompleted(false) + }) + expect(mockFetchInspectVars).not.toHaveBeenCalled() + }) + + it('should still call fetchInspectVars on error but skip suggested questions', async () => { + const mockGetSuggested = vi.fn().mockResolvedValue({ data: ['s1'] }) + let sendCallbacks: any + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + sendCallbacks = callbacks + }) + mockSseGet.mockImplementation((_url: any, _opts: any, options: any) => { + capturedResumeOptions = options + }) + + const hook = renderHook(() => + useChat({ suggested_questions_after_answer: { enabled: true } }), + ) + + act(() => { + hook.result.current.handleSend({ query: 'test' }, {}) + }) + act(() => { + sendCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-1', + task_id: 'task-1', + conversation_id: null, + message_id: 'msg-resume', + }) + }) + await act(async () => { + await sendCallbacks.onCompleted(false) + }) + mockFetchInspectVars.mockClear() + mockInvalidAllLastRun.mockClear() + + act(() => { + hook.result.current.handleResume('msg-resume', 'wfr-1', { + onGetSuggestedQuestions: mockGetSuggested, + }) + }) + await act(async () => { + await capturedResumeOptions.onCompleted(true) + }) + + expect(mockFetchInspectVars).toHaveBeenCalledWith({}) + expect(mockInvalidAllLastRun).toHaveBeenCalled() + expect(mockGetSuggested).not.toHaveBeenCalled() + }) + + it('should fetch suggested questions when enabled', async () => { + const mockGetSuggested = vi.fn().mockImplementation((_id: string, getAbortCtrl: any) => { + getAbortCtrl(new AbortController()) + return Promise.resolve({ data: ['s1'] }) + }) + let sendCallbacks: any + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + sendCallbacks = callbacks + }) + mockSseGet.mockImplementation((_url: any, _opts: any, options: any) => { + capturedResumeOptions = options + }) + + const hook = renderHook(() => + useChat({ suggested_questions_after_answer: { enabled: true } }), + ) + + act(() => { + hook.result.current.handleSend({ query: 'test' }, {}) + }) + + act(() => { + sendCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-1', + task_id: 'task-1', + conversation_id: null, + message_id: 'msg-resume', + }) + }) + + await act(async () => { + await sendCallbacks.onCompleted(false) + }) + + act(() => { + hook.result.current.handleResume('msg-resume', 'wfr-1', { + onGetSuggestedQuestions: mockGetSuggested, + }) + }) + + await act(async () => { + await capturedResumeOptions.onCompleted(false) + }) + + expect(mockGetSuggested).toHaveBeenCalled() + }) + + it('should set suggestedQuestions to empty on fetch error', async () => { + const mockGetSuggested = vi.fn().mockRejectedValue(new Error('fail')) + let sendCallbacks: any + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + sendCallbacks = callbacks + }) + mockSseGet.mockImplementation((_url: any, _opts: any, options: any) => { + capturedResumeOptions = options + }) + + const hook = renderHook(() => + useChat({ suggested_questions_after_answer: { enabled: true } }), + ) + + act(() => { + hook.result.current.handleSend({ query: 'test' }, {}) + }) + + act(() => { + sendCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-1', + task_id: 'task-1', + conversation_id: null, + message_id: 'msg-resume', + }) + }) + + await act(async () => { + await sendCallbacks.onCompleted(false) + }) + + act(() => { + hook.result.current.handleResume('msg-resume', 'wfr-1', { + onGetSuggestedQuestions: mockGetSuggested, + }) + }) + + await act(async () => { + await capturedResumeOptions.onCompleted(false) + }) + + expect(hook.result.current.suggestedQuestions).toEqual([]) + }) + }) + + describe('onError', () => { + it('should set isResponding to false', async () => { + const { result } = await setupResumeWithTree() + act(() => { + capturedResumeOptions.onError() + }) + expect(result.current.isResponding).toBe(false) + }) + }) + + describe('onMessageEnd', () => { + it('should update citation and files', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onMessageEnd({ + metadata: { retriever_resources: [{ id: 'cite-1' }] }, + files: [], + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.citation).toEqual([{ id: 'cite-1' }]) + }) + }) + + describe('onMessageReplace', () => { + it('should replace content', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onMessageReplace({ answer: 'replaced' }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.content).toBe('replaced') + }) + }) + + describe('onIterationStart / onIterationFinish', () => { + it('should push and update iteration tracing entries', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onIterationStart({ + data: { id: 'iter-r1', node_id: 'n-iter-r' }, + }) + }) + + let answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + expect(answer!.workflowProcess!.tracing[0].id).toBe('iter-r1') + expect(answer!.workflowProcess!.tracing[0].status).toBe('running') + + act(() => { + capturedResumeOptions.onIterationFinish({ + data: { id: 'iter-r1', node_id: 'n-iter-r', execution_metadata: {} }, + }) + }) + + answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + expect(answer!.workflowProcess!.tracing[0].status).toBe('succeeded') + }) + + it('should handle iteration finish when no match found', async () => { + await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onIterationFinish({ + data: { id: 'no-match', node_id: 'no-match', execution_metadata: {} }, + }) + }) + }) + }) + + describe('onLoopStart / onLoopFinish', () => { + it('should push and update loop tracing entries', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onLoopStart({ + data: { id: 'loop-r1', node_id: 'n-loop-r' }, + }) + }) + + let answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + expect(answer!.workflowProcess!.tracing[0].id).toBe('loop-r1') + expect(answer!.workflowProcess!.tracing[0].status).toBe('running') + + act(() => { + capturedResumeOptions.onLoopFinish({ + data: { id: 'loop-r1', node_id: 'n-loop-r', execution_metadata: {} }, + }) + }) + + answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + expect(answer!.workflowProcess!.tracing[0].status).toBe('succeeded') + }) + + it('should handle loop finish when no match found', async () => { + await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onLoopFinish({ + data: { id: 'no-match', node_id: 'no-match', execution_metadata: {} }, + }) + }) + }) + }) + + describe('onNodeStarted / onNodeFinished', () => { + it('should add and update node tracing entries', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + act(() => { + capturedResumeOptions.onNodeStarted({ + data: { node_id: 'rn-1', id: 'rtrace-1' }, + }) + }) + + let answer = result.current.chatList.find(item => item.id === 'msg-resume') + const startedTrace = answer!.workflowProcess!.tracing.find((t: any) => t.node_id === 'rn-1') + expect(startedTrace).toBeDefined() + expect(startedTrace!.id).toBe('rtrace-1') + expect(startedTrace!.status).toBe('running') + + act(() => { + capturedResumeOptions.onNodeFinished({ + data: { node_id: 'rn-1', id: 'rtrace-1', status: 'succeeded' }, + }) + }) + + answer = result.current.chatList.find(item => item.id === 'msg-resume') + const finishedTrace = answer!.workflowProcess!.tracing.find((t: any) => t.node_id === 'rn-1') + expect(finishedTrace).toBeDefined() + expect((finishedTrace as any).status).toBe('succeeded') + }) + + it('should skip onNodeStarted when iteration_id is present', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + act(() => { + capturedResumeOptions.onNodeStarted({ + data: { node_id: 'rn-child', id: 'rtrace-child', iteration_id: 'iter-parent' }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.workflowProcess!.tracing.some((t: any) => t.node_id === 'rn-child')).toBe(false) + }) + + it('should skip onNodeFinished when iteration_id is present', async () => { + await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + act(() => { + capturedResumeOptions.onNodeFinished({ + data: { node_id: 'rn-1', id: 'rtrace-1', iteration_id: 'iter-parent' }, + }) + }) + }) + + it('should update existing node in tracing on onNodeStarted', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + act(() => { + capturedResumeOptions.onNodeStarted({ + data: { node_id: 'rn-1', id: 'rtrace-1' }, + }) + }) + + act(() => { + capturedResumeOptions.onNodeStarted({ + data: { node_id: 'rn-1', id: 'rtrace-1-v2' }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + const matchingTraces = answer!.workflowProcess!.tracing.filter((t: any) => t.node_id === 'rn-1') + expect(matchingTraces).toHaveLength(1) + expect(matchingTraces[0].id).toBe('rtrace-1-v2') + expect(matchingTraces[0].status).toBe('running') + }) + + it('should match nodeFinished with parallel_id', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + act(() => { + capturedResumeOptions.onNodeStarted({ + data: { node_id: 'rn-1', id: 'rtrace-1', execution_metadata: { parallel_id: 'p1' } }, + }) + }) + + act(() => { + capturedResumeOptions.onNodeFinished({ + data: { + node_id: 'rn-1', + id: 'rtrace-1', + status: 'succeeded', + execution_metadata: { parallel_id: 'p1' }, + }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + const trace = answer!.workflowProcess!.tracing.find((t: any) => t.id === 'rtrace-1') + expect(trace).toBeDefined() + expect((trace as any).status).toBe('succeeded') + expect((trace as any).execution_metadata.parallel_id).toBe('p1') + }) + }) + + describe('onHumanInputRequired', () => { + it('should initialize humanInputFormDataList', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onHumanInputRequired({ + data: { node_id: 'rn-human', form_token: 'rt-1' }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.humanInputFormDataList).toHaveLength(1) + }) + + it('should update existing form for same node and push for different node', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onHumanInputRequired({ + data: { node_id: 'rn-human', form_token: 'rt-1' }, + }) + }) + + act(() => { + capturedResumeOptions.onHumanInputRequired({ + data: { node_id: 'rn-human', form_token: 'rt-2' }, + }) + }) + + let answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.humanInputFormDataList).toHaveLength(1) + + act(() => { + capturedResumeOptions.onHumanInputRequired({ + data: { node_id: 'rn-human-2', form_token: 'rt-3' }, + }) + }) + + answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.humanInputFormDataList).toHaveLength(2) + }) + + it('should set tracing node to Paused when tracing match is found', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + act(() => { + capturedResumeOptions.onNodeStarted({ + data: { node_id: 'rn-human', id: 'trace-human' }, + }) + }) + + act(() => { + capturedResumeOptions.onHumanInputRequired({ + data: { node_id: 'rn-human', form_token: 'rt-1' }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + const trace = answer!.workflowProcess!.tracing.find((t: any) => t.node_id === 'rn-human') + expect(trace!.status).toBe('paused') + }) + }) + + describe('onHumanInputFormFilled', () => { + it('should move form from pending to filled list', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onHumanInputRequired({ + data: { node_id: 'rn-human', form_token: 'rt-1' }, + }) + }) + + act(() => { + capturedResumeOptions.onHumanInputFormFilled({ + data: { node_id: 'rn-human', form_data: { a: 1 } }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.humanInputFormDataList).toHaveLength(0) + expect(answer!.humanInputFilledFormDataList).toHaveLength(1) + }) + + it('should initialize humanInputFilledFormDataList when not present', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onHumanInputFormFilled({ + data: { node_id: 'rn-human', form_data: { b: 2 } }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.humanInputFilledFormDataList).toHaveLength(1) + }) + }) + + describe('onHumanInputFormTimeout', () => { + it('should set expiration_time on the form entry', async () => { + const { result } = await setupResumeWithTree() + + act(() => { + capturedResumeOptions.onHumanInputRequired({ + data: { node_id: 'rn-human', form_token: 'rt-1' }, + }) + }) + + act(() => { + capturedResumeOptions.onHumanInputFormTimeout({ + data: { node_id: 'rn-human', expiration_time: '2025-06-01' }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + const form = answer!.humanInputFormDataList!.find((f: any) => f.node_id === 'rn-human') + expect(form!.expiration_time).toBe('2025-06-01') + }) + }) + + describe('onWorkflowPaused', () => { + it('should re-subscribe via sseGet and set status to Paused', async () => { + const { result } = await setupResumeWithTree() + const sseGetCallsBefore = mockSseGet.mock.calls.length + + act(() => { + capturedResumeOptions.onWorkflowPaused({ + data: { workflow_run_id: 'wfr-paused' }, + }) + }) + + expect(mockSseGet.mock.calls.length).toBeGreaterThan(sseGetCallsBefore) + const answer = result.current.chatList.find(item => item.id === 'msg-resume') + expect(answer!.workflowProcess!.status).toBe('paused') + }) + }) +}) + +describe('useChat – handleResume with bare prevChatTree (no humanInputFormDataList / no tracing)', () => { + let capturedResumeOptions: any + + beforeEach(() => { + resetMocksAndWorkflowState() + mockHandleRun.mockReset() + mockSseGet.mockReset() + mockSseGet.mockImplementation((_url: any, _opts: any, options: any) => { + capturedResumeOptions = options + }) + }) + + function setupWithBareTree() { + const prevChatTree: ChatItemInTree[] = [ + { + id: 'q1', + content: 'question', + isAnswer: false, + children: [ + { + id: 'bare-msg', + content: '', + isAnswer: true, + workflow_run_id: 'wfr-bare', + workflowProcess: { + status: 'running' as any, + tracing: [], + }, + children: [], + }, + ], + }, + ] + + const hook = renderHook(() => useChat({}, undefined, prevChatTree)) + + act(() => { + hook.result.current.handleResume('bare-msg', 'wfr-bare', {}) + }) + + return hook + } + + function setupWithBareTreeNoTracing() { + const prevChatTree: ChatItemInTree[] = [ + { + id: 'q1', + content: 'question', + isAnswer: false, + children: [ + { + id: 'bare-msg-nt', + content: '', + isAnswer: true, + workflow_run_id: 'wfr-bare-nt', + workflowProcess: { + status: 'running' as any, + tracing: undefined as any, + }, + children: [], + }, + ], + }, + ] + + const hook = renderHook(() => useChat({}, undefined, prevChatTree)) + + act(() => { + hook.result.current.handleResume('bare-msg-nt', 'wfr-bare-nt', {}) + }) + + return hook + } + + it('onHumanInputRequired should initialize humanInputFormDataList when null', () => { + const { result } = setupWithBareTree() + + act(() => { + capturedResumeOptions.onHumanInputRequired({ + data: { node_id: 'hn-bare', form_token: 'ft-bare' }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'bare-msg') + expect(answer!.humanInputFormDataList).toHaveLength(1) + }) + + it('onHumanInputFormFilled should initialize humanInputFilledFormDataList when null', () => { + const { result } = setupWithBareTree() + + act(() => { + capturedResumeOptions.onHumanInputFormFilled({ + data: { node_id: 'hn-bare', form_data: { x: 1 } }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'bare-msg') + expect(answer!.humanInputFilledFormDataList).toHaveLength(1) + }) + + it('onLoopStart should initialize tracing array when not present', () => { + const { result } = setupWithBareTreeNoTracing() + + act(() => { + capturedResumeOptions.onLoopStart({ + data: { id: 'loop-bare', node_id: 'n-loop-bare' }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'bare-msg-nt') + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + expect(answer!.workflowProcess!.tracing[0].id).toBe('loop-bare') + expect(answer!.workflowProcess!.tracing[0].node_id).toBe('n-loop-bare') + expect(answer!.workflowProcess!.tracing[0].status).toBe('running') + }) + + it('onLoopFinish should return early when no tracing', () => { + setupWithBareTreeNoTracing() + + act(() => { + capturedResumeOptions.onLoopFinish({ + data: { id: 'loop-bare', node_id: 'n-loop-bare', execution_metadata: {} }, + }) + }) + }) + + it('onIterationStart should initialize tracing when not present', () => { + const { result } = setupWithBareTreeNoTracing() + + act(() => { + capturedResumeOptions.onIterationStart({ + data: { id: 'iter-bare', node_id: 'n-iter-bare' }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'bare-msg-nt') + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + expect(answer!.workflowProcess!.tracing[0].id).toBe('iter-bare') + expect(answer!.workflowProcess!.tracing[0].node_id).toBe('n-iter-bare') + expect(answer!.workflowProcess!.tracing[0].status).toBe('running') + }) + + it('onIterationFinish should return early when no tracing', () => { + setupWithBareTreeNoTracing() + + act(() => { + capturedResumeOptions.onIterationFinish({ + data: { id: 'iter-bare', node_id: 'n-iter-bare', execution_metadata: {} }, + }) + }) + }) + + it('onNodeStarted should initialize tracing when not present', () => { + const { result } = setupWithBareTreeNoTracing() + + act(() => { + capturedResumeOptions.onNodeStarted({ + data: { node_id: 'rn-bare', id: 'rtrace-bare' }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'bare-msg-nt') + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + expect(answer!.workflowProcess!.tracing[0].id).toBe('rtrace-bare') + expect(answer!.workflowProcess!.tracing[0].node_id).toBe('rn-bare') + expect(answer!.workflowProcess!.tracing[0].status).toBe('running') + }) + + it('onNodeFinished should return early when no tracing', () => { + setupWithBareTreeNoTracing() + + act(() => { + capturedResumeOptions.onNodeFinished({ + data: { node_id: 'rn-bare', id: 'rtrace-bare', status: 'succeeded' }, + }) + }) + }) + + it('onIterationStart/onNodeStarted/onLoopStart should return early when no workflowProcess', () => { + const prevChatTreeNoWP: ChatItemInTree[] = [ + { + id: 'q-nowp', + content: 'question', + isAnswer: false, + children: [ + { + id: 'bare-nowp', + content: '', + isAnswer: true, + children: [], + }, + ], + }, + ] + + const hook = renderHook(() => useChat({}, undefined, prevChatTreeNoWP)) + let opts: any + mockSseGet.mockImplementation((_url: any, _opts: any, options: any) => { + opts = options + }) + + act(() => { + hook.result.current.handleResume('bare-nowp', 'wfr-x', {}) + }) + + act(() => { + opts.onIterationStart({ data: { id: 'i1', node_id: 'ni1' } }) + }) + + act(() => { + opts.onNodeStarted({ data: { node_id: 'ns1', id: 'ts1' } }) + }) + + act(() => { + opts.onLoopStart({ data: { id: 'l1', node_id: 'nl1' } }) + }) + + const answer = hook.result.current.chatList.find(item => item.id === 'bare-nowp') + expect(answer!.workflowProcess).toBeUndefined() + }) + + it('onHumanInputRequired should set Paused on tracing node when found', () => { + const { result } = setupWithBareTree() + + act(() => { + capturedResumeOptions.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + }) + }) + + act(() => { + capturedResumeOptions.onNodeStarted({ + data: { node_id: 'hn-with-trace', id: 'trace-hn' }, + }) + }) + + act(() => { + capturedResumeOptions.onHumanInputRequired({ + data: { node_id: 'hn-with-trace', form_token: 'ft-tr' }, + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'bare-msg') + const trace = answer!.workflowProcess!.tracing.find((t: any) => t.node_id === 'hn-with-trace') + expect(trace!.status).toBe('paused') + }) +}) diff --git a/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-send.spec.ts b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-send.spec.ts new file mode 100644 index 0000000000..0a12aac582 --- /dev/null +++ b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-send.spec.ts @@ -0,0 +1,194 @@ +/* eslint-disable ts/no-explicit-any */ +import { act, renderHook } from '@testing-library/react' +import { useChat } from '../../hooks' + +const mockHandleRun = vi.fn() +const mockNotify = vi.fn() +const mockFetchInspectVars = vi.fn() +const mockInvalidAllLastRun = vi.fn() +const mockSetIterTimes = vi.fn() +const mockSetLoopTimes = vi.fn() +const mockSubmitHumanInputForm = vi.fn() +const mockSseGet = vi.fn() +const mockGetNodes = vi.fn((): any[] => []) + +let mockWorkflowRunningData: any = null + +vi.mock('@/service/base', () => ({ + sseGet: (...args: any[]) => mockSseGet(...args), +})) + +vi.mock('@/service/use-workflow', () => ({ + useInvalidAllLastRun: () => mockInvalidAllLastRun, +})) + +vi.mock('@/service/workflow', () => ({ + submitHumanInputForm: (...args: any[]) => mockSubmitHumanInputForm(...args), +})) + +vi.mock('@/app/components/base/toast/context', () => ({ + useToastContext: () => ({ notify: mockNotify }), +})) + +vi.mock('reactflow', () => ({ + useStoreApi: () => ({ + getState: () => ({ + getNodes: mockGetNodes, + }), + }), +})) + +vi.mock('../../../../hooks', () => ({ + useWorkflowRun: () => ({ handleRun: mockHandleRun }), + useSetWorkflowVarsWithValue: () => ({ fetchInspectVars: mockFetchInspectVars }), +})) + +vi.mock('../../../../hooks-store', () => ({ + useHooksStore: () => null, +})) + +vi.mock('../../../../store', () => ({ + useWorkflowStore: () => ({ + getState: () => ({ + setIterTimes: mockSetIterTimes, + setLoopTimes: mockSetLoopTimes, + inputs: {}, + workflowRunningData: mockWorkflowRunningData, + }), + }), + useStore: () => vi.fn(), +})) + +const resetMocksAndWorkflowState = () => { + vi.clearAllMocks() + mockWorkflowRunningData = null +} + +describe('useChat – handleSend', () => { + beforeEach(() => { + resetMocksAndWorkflowState() + mockHandleRun.mockReset() + }) + + it('should call handleRun with processed params', () => { + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend({ query: 'hello', inputs: {} }, {}) + }) + + expect(mockHandleRun).toHaveBeenCalledTimes(1) + const [bodyParams] = mockHandleRun.mock.calls[0] + expect(bodyParams.query).toBe('hello') + }) + + it('should show notification and return false when already responding', () => { + mockHandleRun.mockImplementation(() => {}) + + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend({ query: 'first' }, {}) + }) + + act(() => { + const returned = result.current.handleSend({ query: 'second' }, {}) + expect(returned).toBe(false) + }) + + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ type: 'info' })) + }) + + it('should set isResponding to true after sending', () => { + const { result } = renderHook(() => useChat({})) + act(() => { + result.current.handleSend({ query: 'hello' }, {}) + }) + expect(result.current.isResponding).toBe(true) + }) + + it('should add placeholder question and answer to chatList', () => { + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend({ query: 'test question' }, {}) + }) + + const questionItem = result.current.chatList.find(item => item.content === 'test question') + expect(questionItem).toBeDefined() + expect(questionItem!.isAnswer).toBe(false) + + const answerPlaceholder = result.current.chatList.find( + item => item.isAnswer && !item.isOpeningStatement && item.content === '', + ) + expect(answerPlaceholder).toBeDefined() + }) + + it('should strip url from local_file transfer method files', () => { + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend( + { + query: 'hello', + files: [ + { + id: 'f1', + name: 'test.png', + size: 1024, + type: 'image/png', + progress: 100, + transferMethod: 'local_file', + supportFileType: 'image', + url: 'blob://local', + uploadedId: 'up1', + }, + { + id: 'f2', + name: 'remote.png', + size: 2048, + type: 'image/png', + progress: 100, + transferMethod: 'remote_url', + supportFileType: 'image', + url: 'https://example.com/img.png', + uploadedId: '', + }, + ] as any, + }, + {}, + ) + }) + + expect(mockHandleRun).toHaveBeenCalledTimes(1) + const [bodyParams] = mockHandleRun.mock.calls[0] + const localFile = bodyParams.files.find((f: any) => f.transfer_method === 'local_file') + const remoteFile = bodyParams.files.find((f: any) => f.transfer_method === 'remote_url') + expect(localFile.url).toBe('') + expect(remoteFile.url).toBe('https://example.com/img.png') + }) + + it('should abort previous workflowEventsAbortController before sending', () => { + const mockAbort = vi.fn() + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + callbacks.getAbortController({ abort: mockAbort } as any) + callbacks.onCompleted(false) + }) + + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend({ query: 'first' }, {}) + }) + + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + callbacks.getAbortController({ abort: vi.fn() } as any) + }) + + act(() => { + result.current.handleSend({ query: 'second' }, {}) + }) + + expect(mockAbort).toHaveBeenCalledTimes(1) + }) +}) diff --git a/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-stop-restart.spec.ts b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-stop-restart.spec.ts new file mode 100644 index 0000000000..7127634a10 --- /dev/null +++ b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/handle-stop-restart.spec.ts @@ -0,0 +1,199 @@ +/* eslint-disable ts/no-explicit-any */ +import { act, renderHook } from '@testing-library/react' +import { useChat } from '../../hooks' + +const mockHandleRun = vi.fn() +const mockNotify = vi.fn() +const mockFetchInspectVars = vi.fn() +const mockInvalidAllLastRun = vi.fn() +const mockSetIterTimes = vi.fn() +const mockSetLoopTimes = vi.fn() +const mockSubmitHumanInputForm = vi.fn() +const mockSseGet = vi.fn() +const mockStopChat = vi.fn() +const mockGetNodes = vi.fn((): any[] => []) + +let mockWorkflowRunningData: any = null + +vi.mock('@/service/base', () => ({ + sseGet: (...args: any[]) => mockSseGet(...args), +})) + +vi.mock('@/service/use-workflow', () => ({ + useInvalidAllLastRun: () => mockInvalidAllLastRun, +})) + +vi.mock('@/service/workflow', () => ({ + submitHumanInputForm: (...args: any[]) => mockSubmitHumanInputForm(...args), +})) + +vi.mock('@/app/components/base/toast/context', () => ({ + useToastContext: () => ({ notify: mockNotify }), +})) + +vi.mock('reactflow', () => ({ + useStoreApi: () => ({ + getState: () => ({ + getNodes: mockGetNodes, + }), + }), +})) + +vi.mock('../../../../hooks', () => ({ + useWorkflowRun: () => ({ handleRun: mockHandleRun }), + useSetWorkflowVarsWithValue: () => ({ fetchInspectVars: mockFetchInspectVars }), +})) + +vi.mock('../../../../hooks-store', () => ({ + useHooksStore: () => null, +})) + +vi.mock('../../../../store', () => ({ + useWorkflowStore: () => ({ + getState: () => ({ + setIterTimes: mockSetIterTimes, + setLoopTimes: mockSetLoopTimes, + inputs: {}, + workflowRunningData: mockWorkflowRunningData, + }), + }), + useStore: () => vi.fn(), +})) + +const resetMocksAndWorkflowState = () => { + vi.clearAllMocks() + mockWorkflowRunningData = null +} + +describe('useChat – handleStop', () => { + beforeEach(() => { + resetMocksAndWorkflowState() + }) + + it('should set isResponding to false', () => { + const { result } = renderHook(() => useChat({})) + act(() => { + result.current.handleStop() + }) + expect(result.current.isResponding).toBe(false) + }) + + it('should not call stopChat when taskId is empty even if stopChat is provided', () => { + const { result } = renderHook(() => useChat({}, undefined, undefined, mockStopChat)) + act(() => { + result.current.handleStop() + }) + expect(mockStopChat).not.toHaveBeenCalled() + }) + + it('should reset iter/loop times to defaults', () => { + const { result } = renderHook(() => useChat({})) + act(() => { + result.current.handleStop() + }) + expect(mockSetIterTimes).toHaveBeenCalledWith(1) + expect(mockSetLoopTimes).toHaveBeenCalledWith(1) + }) + + it('should abort workflowEventsAbortController when set', () => { + const mockWfAbort = vi.fn() + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + callbacks.getAbortController({ abort: mockWfAbort } as any) + }) + + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend({ query: 'test' }, {}) + }) + + act(() => { + result.current.handleStop() + }) + + expect(mockWfAbort).toHaveBeenCalledTimes(1) + }) + + it('should abort suggestedQuestionsAbortController when set', async () => { + const mockSqAbort = vi.fn() + let capturedCb: any + + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + capturedCb = callbacks + }) + + const mockGetSuggested = vi.fn().mockImplementation((_id: string, getAbortCtrl: any) => { + getAbortCtrl({ abort: mockSqAbort } as any) + return Promise.resolve({ data: ['s'] }) + }) + + const { result } = renderHook(() => + useChat({ suggested_questions_after_answer: { enabled: true } }), + ) + + act(() => { + result.current.handleSend({ query: 'test' }, { + onGetSuggestedQuestions: mockGetSuggested, + }) + }) + + await act(async () => { + await capturedCb.onCompleted(false) + }) + + act(() => { + result.current.handleStop() + }) + + expect(mockSqAbort).toHaveBeenCalledTimes(1) + }) + + it('should call stopChat with taskId when both are available', () => { + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + callbacks.onData('msg', true, { + conversationId: 'c1', + messageId: 'msg-1', + taskId: 'task-stop', + }) + }) + + const { result } = renderHook(() => useChat({}, undefined, undefined, mockStopChat)) + + act(() => { + result.current.handleSend({ query: 'test' }, {}) + }) + + act(() => { + result.current.handleStop() + }) + + expect(mockStopChat).toHaveBeenCalledWith('task-stop') + }) +}) + +describe('useChat – handleRestart', () => { + beforeEach(() => { + resetMocksAndWorkflowState() + }) + + it('should clear suggestedQuestions and set isResponding to false', () => { + const config = { opening_statement: 'Hello' } + const { result } = renderHook(() => useChat(config)) + + act(() => { + result.current.handleRestart() + }) + + expect(result.current.suggestedQuestions).toEqual([]) + expect(result.current.isResponding).toBe(false) + }) + + it('should reset iter/loop times to defaults', () => { + const { result } = renderHook(() => useChat({})) + act(() => { + result.current.handleRestart() + }) + expect(mockSetIterTimes).toHaveBeenCalledWith(1) + expect(mockSetLoopTimes).toHaveBeenCalledWith(1) + }) +}) diff --git a/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/misc.spec.ts b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/misc.spec.ts new file mode 100644 index 0000000000..fe2a800561 --- /dev/null +++ b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/misc.spec.ts @@ -0,0 +1,380 @@ +/* eslint-disable ts/no-explicit-any */ +import type { ChatItemInTree } from '@/app/components/base/chat/types' +import { act, renderHook } from '@testing-library/react' +import { useChat } from '../../hooks' + +const mockHandleRun = vi.fn() +const mockNotify = vi.fn() +const mockFetchInspectVars = vi.fn() +const mockInvalidAllLastRun = vi.fn() +const mockSetIterTimes = vi.fn() +const mockSetLoopTimes = vi.fn() +const mockSubmitHumanInputForm = vi.fn() +const mockSseGet = vi.fn() +const mockGetNodes = vi.fn((): any[] => []) + +let mockWorkflowRunningData: any = null + +vi.mock('@/service/base', () => ({ + sseGet: (...args: any[]) => mockSseGet(...args), +})) + +vi.mock('@/service/use-workflow', () => ({ + useInvalidAllLastRun: () => mockInvalidAllLastRun, +})) + +vi.mock('@/service/workflow', () => ({ + submitHumanInputForm: (...args: any[]) => mockSubmitHumanInputForm(...args), +})) + +vi.mock('@/app/components/base/toast/context', () => ({ + useToastContext: () => ({ notify: mockNotify }), +})) + +vi.mock('reactflow', () => ({ + useStoreApi: () => ({ + getState: () => ({ + getNodes: mockGetNodes, + }), + }), +})) + +vi.mock('../../../../hooks', () => ({ + useWorkflowRun: () => ({ handleRun: mockHandleRun }), + useSetWorkflowVarsWithValue: () => ({ fetchInspectVars: mockFetchInspectVars }), +})) + +vi.mock('../../../../hooks-store', () => ({ + useHooksStore: () => null, +})) + +vi.mock('../../../../store', () => ({ + useWorkflowStore: () => ({ + getState: () => ({ + setIterTimes: mockSetIterTimes, + setLoopTimes: mockSetLoopTimes, + inputs: {}, + workflowRunningData: mockWorkflowRunningData, + }), + }), + useStore: () => vi.fn(), +})) + +const resetMocksAndWorkflowState = () => { + vi.clearAllMocks() + mockWorkflowRunningData = null +} + +describe('useChat – handleSwitchSibling', () => { + beforeEach(() => { + resetMocksAndWorkflowState() + mockHandleRun.mockReset() + mockSseGet.mockReset() + }) + + it('should call handleResume when target has workflow_run_id and pending humanInputFormData', async () => { + let sendCallbacks: any + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + sendCallbacks = callbacks + }) + mockSseGet.mockImplementation(() => {}) + + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend({ query: 'test' }, {}) + }) + + act(() => { + sendCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-switch', + task_id: 'task-1', + conversation_id: null, + message_id: 'msg-switch', + }) + }) + + act(() => { + sendCallbacks.onHumanInputRequired({ + data: { node_id: 'human-n', form_token: 'ft-1' }, + }) + }) + + await act(async () => { + await sendCallbacks.onCompleted(false) + }) + + act(() => { + result.current.handleSwitchSibling('msg-switch', {}) + }) + + expect(mockSseGet).toHaveBeenCalled() + }) + + it('should not call handleResume when target has no humanInputFormDataList', async () => { + let sendCallbacks: any + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + sendCallbacks = callbacks + }) + + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend({ query: 'test' }, {}) + }) + + act(() => { + sendCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-switch', + task_id: 'task-1', + conversation_id: null, + message_id: 'msg-switch', + }) + }) + + await act(async () => { + await sendCallbacks.onCompleted(false) + }) + + act(() => { + result.current.handleSwitchSibling('msg-switch', {}) + }) + + expect(mockSseGet).not.toHaveBeenCalled() + }) + + it('should return undefined from findMessageInTree when not found', () => { + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSwitchSibling('nonexistent-id', {}) + }) + + expect(mockSseGet).not.toHaveBeenCalled() + }) + + it('should search children recursively in findMessageInTree', async () => { + let sendCallbacks: any + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + sendCallbacks = callbacks + }) + mockSseGet.mockImplementation(() => {}) + + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend({ query: 'parent' }, {}) + }) + + act(() => { + sendCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-1', + task_id: 'task-1', + conversation_id: null, + message_id: 'msg-parent', + }) + }) + + await act(async () => { + await sendCallbacks.onCompleted(false) + }) + + act(() => { + result.current.handleSend({ + query: 'child', + parent_message_id: 'msg-parent', + }, {}) + }) + + act(() => { + sendCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-2', + task_id: 'task-2', + conversation_id: null, + message_id: 'msg-child', + }) + }) + + act(() => { + sendCallbacks.onHumanInputRequired({ + data: { node_id: 'h-child', form_token: 'ft-c' }, + }) + }) + + await act(async () => { + await sendCallbacks.onCompleted(false) + }) + + act(() => { + result.current.handleSwitchSibling('msg-child', {}) + }) + + expect(mockSseGet).toHaveBeenCalled() + }) +}) + +describe('useChat – handleSubmitHumanInputForm', () => { + beforeEach(() => { + resetMocksAndWorkflowState() + mockSubmitHumanInputForm.mockResolvedValue({}) + }) + + it('should call submitHumanInputForm with token and data', async () => { + const { result } = renderHook(() => useChat({})) + + await act(async () => { + await result.current.handleSubmitHumanInputForm('token-123', { field: 'value' }) + }) + + expect(mockSubmitHumanInputForm).toHaveBeenCalledWith('token-123', { field: 'value' }) + }) +}) + +describe('useChat – getHumanInputNodeData', () => { + beforeEach(() => { + resetMocksAndWorkflowState() + mockGetNodes.mockReturnValue([]) + }) + + it('should return the custom node matching the given nodeID', () => { + const mockNode = { id: 'node-1', type: 'custom', data: { title: 'Human Input' } } + mockGetNodes.mockReturnValue([ + mockNode, + { id: 'node-2', type: 'custom', data: { title: 'Other' } }, + ]) + + const { result } = renderHook(() => useChat({})) + const node = result.current.getHumanInputNodeData('node-1') + expect(node).toEqual(mockNode) + }) + + it('should return undefined when no matching node', () => { + mockGetNodes.mockReturnValue([{ id: 'node-2', type: 'custom', data: {} }]) + + const { result } = renderHook(() => useChat({})) + const node = result.current.getHumanInputNodeData('nonexistent') + expect(node).toBeUndefined() + }) + + it('should filter out non-custom nodes', () => { + mockGetNodes.mockReturnValue([ + { id: 'node-1', type: 'default', data: {} }, + { id: 'node-1', type: 'custom', data: { found: true } }, + ]) + + const { result } = renderHook(() => useChat({})) + const node = result.current.getHumanInputNodeData('node-1') + expect(node).toEqual({ id: 'node-1', type: 'custom', data: { found: true } }) + }) +}) + +describe('useChat – conversationId and setTargetMessageId', () => { + beforeEach(() => { + resetMocksAndWorkflowState() + }) + + it('should initially be an empty string', () => { + const { result } = renderHook(() => useChat({})) + expect(result.current.conversationId).toBe('') + }) + + it('setTargetMessageId should change chatList thread path', () => { + const prevChatTree: ChatItemInTree[] = [ + { + id: 'q1', + content: 'question 1', + isAnswer: false, + children: [ + { + id: 'a1', + content: 'answer 1', + isAnswer: true, + children: [ + { + id: 'q2-branch-a', + content: 'branch A question', + isAnswer: false, + children: [ + { id: 'a2-branch-a', content: 'branch A answer', isAnswer: true, children: [] }, + ], + }, + { + id: 'q2-branch-b', + content: 'branch B question', + isAnswer: false, + children: [ + { id: 'a2-branch-b', content: 'branch B answer', isAnswer: true, children: [] }, + ], + }, + ], + }, + ], + }, + ] + + const { result } = renderHook(() => useChat({}, undefined, prevChatTree)) + + const defaultList = result.current.chatList + expect(defaultList.some(item => item.id === 'a1')).toBe(true) + + act(() => { + result.current.setTargetMessageId('a2-branch-a') + }) + + const listA = result.current.chatList + expect(listA.some(item => item.id === 'a2-branch-a')).toBe(true) + expect(listA.some(item => item.id === 'a2-branch-b')).toBe(false) + + act(() => { + result.current.setTargetMessageId('a2-branch-b') + }) + + const listB = result.current.chatList + expect(listB.some(item => item.id === 'a2-branch-b')).toBe(true) + expect(listB.some(item => item.id === 'a2-branch-a')).toBe(false) + }) +}) + +describe('useChat – updateCurrentQAOnTree with parent_message_id', () => { + let capturedCallbacks: any + + beforeEach(() => { + resetMocksAndWorkflowState() + mockHandleRun.mockReset() + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + capturedCallbacks = callbacks + }) + }) + + it('should handle follow-up message with parent_message_id', async () => { + const { result } = renderHook(() => useChat({})) + + act(() => { + result.current.handleSend({ query: 'first' }, {}) + }) + + const firstCallbacks = capturedCallbacks + + act(() => { + firstCallbacks.onData('answer1', true, { + conversationId: 'c1', + messageId: 'msg-1', + taskId: 't1', + }) + }) + + await act(async () => { + await firstCallbacks.onCompleted(false) + }) + + act(() => { + result.current.handleSend({ + query: 'follow up', + parent_message_id: 'msg-1', + }, {}) + }) + + expect(mockHandleRun).toHaveBeenCalledTimes(2) + expect(result.current.chatList.length).toBeGreaterThan(0) + }) +}) diff --git a/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks.spec.ts b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/opening-statement.spec.ts similarity index 63% rename from web/app/components/workflow/panel/debug-and-preview/__tests__/hooks.spec.ts rename to web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/opening-statement.spec.ts index 397e1c22d6..e985d6790d 100644 --- a/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks.spec.ts +++ b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/opening-statement.spec.ts @@ -1,50 +1,73 @@ +/* eslint-disable ts/no-explicit-any */ import type { ChatItemInTree } from '@/app/components/base/chat/types' import { renderHook } from '@testing-library/react' -import { useChat } from '../hooks' +import { useChat } from '../../hooks' + +const mockHandleRun = vi.fn() +const mockNotify = vi.fn() +const mockFetchInspectVars = vi.fn() +const mockInvalidAllLastRun = vi.fn() +const mockSetIterTimes = vi.fn() +const mockSetLoopTimes = vi.fn() +const mockSubmitHumanInputForm = vi.fn() +const mockSseGet = vi.fn() +const mockGetNodes = vi.fn((): any[] => []) + +let mockWorkflowRunningData: any = null vi.mock('@/service/base', () => ({ - sseGet: vi.fn(), + sseGet: (...args: any[]) => mockSseGet(...args), })) vi.mock('@/service/use-workflow', () => ({ - useInvalidAllLastRun: () => vi.fn(), + useInvalidAllLastRun: () => mockInvalidAllLastRun, })) vi.mock('@/service/workflow', () => ({ - submitHumanInputForm: vi.fn(), + submitHumanInputForm: (...args: any[]) => mockSubmitHumanInputForm(...args), })) vi.mock('@/app/components/base/toast/context', () => ({ - useToastContext: () => ({ notify: vi.fn() }), + useToastContext: () => ({ notify: mockNotify }), })) vi.mock('reactflow', () => ({ - useStoreApi: () => ({ getState: () => ({}) }), + useStoreApi: () => ({ + getState: () => ({ + getNodes: mockGetNodes, + }), + }), })) -vi.mock('../../../hooks', () => ({ - useWorkflowRun: () => ({ handleRun: vi.fn() }), - useSetWorkflowVarsWithValue: () => ({ fetchInspectVars: vi.fn() }), +vi.mock('../../../../hooks', () => ({ + useWorkflowRun: () => ({ handleRun: mockHandleRun }), + useSetWorkflowVarsWithValue: () => ({ fetchInspectVars: mockFetchInspectVars }), })) -vi.mock('../../../hooks-store', () => ({ +vi.mock('../../../../hooks-store', () => ({ useHooksStore: () => null, })) -vi.mock('../../../store', () => ({ +vi.mock('../../../../store', () => ({ useWorkflowStore: () => ({ getState: () => ({ - setIterTimes: vi.fn(), - setLoopTimes: vi.fn(), + setIterTimes: mockSetIterTimes, + setLoopTimes: mockSetLoopTimes, inputs: {}, + workflowRunningData: mockWorkflowRunningData, }), }), useStore: () => vi.fn(), })) +const resetMocksAndWorkflowState = () => { + vi.clearAllMocks() + mockWorkflowRunningData = null +} + describe('workflow debug useChat – opening statement stability', () => { beforeEach(() => { - vi.clearAllMocks() + resetMocksAndWorkflowState() }) it('should return empty chatList when config has no opening_statement', () => { @@ -59,7 +82,6 @@ describe('workflow debug useChat – opening statement stability', () => { it('should use stable id "opening-statement" instead of Date.now()', () => { const config = { opening_statement: 'Welcome!' } - const { result } = renderHook(() => useChat(config)) expect(result.current.chatList[0].id).toBe('opening-statement') }) @@ -132,4 +154,21 @@ describe('workflow debug useChat – opening statement stability', () => { const openerAfter = result.current.chatList[0] expect(openerAfter).toBe(openerBefore) }) + + it('should include suggestedQuestions in opening statement when config has them', () => { + const config = { + opening_statement: 'Welcome!', + suggested_questions: ['How are you?', 'What can you do?'], + } + const { result } = renderHook(() => useChat(config)) + const opener = result.current.chatList[0] + expect(opener.suggestedQuestions).toEqual(['How are you?', 'What can you do?']) + }) + + it('should not include suggestedQuestions when config has none', () => { + const config = { opening_statement: 'Welcome!' } + const { result } = renderHook(() => useChat(config)) + const opener = result.current.chatList[0] + expect(opener.suggestedQuestions).toBeUndefined() + }) }) diff --git a/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/sse-callbacks.spec.ts b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/sse-callbacks.spec.ts new file mode 100644 index 0000000000..073adc59de --- /dev/null +++ b/web/app/components/workflow/panel/debug-and-preview/__tests__/hooks/sse-callbacks.spec.ts @@ -0,0 +1,914 @@ +/* eslint-disable ts/no-explicit-any */ +import { act, renderHook } from '@testing-library/react' +import { useChat } from '../../hooks' + +const mockHandleRun = vi.fn() +const mockNotify = vi.fn() +const mockFetchInspectVars = vi.fn() +const mockInvalidAllLastRun = vi.fn() +const mockSetIterTimes = vi.fn() +const mockSetLoopTimes = vi.fn() +const mockSubmitHumanInputForm = vi.fn() +const mockSseGet = vi.fn() +const mockGetNodes = vi.fn((): any[] => []) + +let mockWorkflowRunningData: any = null + +vi.mock('@/service/base', () => ({ + sseGet: (...args: any[]) => mockSseGet(...args), +})) + +vi.mock('@/service/use-workflow', () => ({ + useInvalidAllLastRun: () => mockInvalidAllLastRun, +})) + +vi.mock('@/service/workflow', () => ({ + submitHumanInputForm: (...args: any[]) => mockSubmitHumanInputForm(...args), +})) + +vi.mock('@/app/components/base/toast/context', () => ({ + useToastContext: () => ({ notify: mockNotify }), +})) + +vi.mock('reactflow', () => ({ + useStoreApi: () => ({ + getState: () => ({ + getNodes: mockGetNodes, + }), + }), +})) + +vi.mock('../../../../hooks', () => ({ + useWorkflowRun: () => ({ handleRun: mockHandleRun }), + useSetWorkflowVarsWithValue: () => ({ fetchInspectVars: mockFetchInspectVars }), +})) + +vi.mock('../../../../hooks-store', () => ({ + useHooksStore: () => null, +})) + +vi.mock('../../../../store', () => ({ + useWorkflowStore: () => ({ + getState: () => ({ + setIterTimes: mockSetIterTimes, + setLoopTimes: mockSetLoopTimes, + inputs: {}, + workflowRunningData: mockWorkflowRunningData, + }), + }), + useStore: () => vi.fn(), +})) + +const resetMocksAndWorkflowState = () => { + vi.clearAllMocks() + mockWorkflowRunningData = null +} + +describe('useChat – handleSend SSE callbacks', () => { + let capturedCallbacks: any + + beforeEach(() => { + resetMocksAndWorkflowState() + mockHandleRun.mockReset() + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + capturedCallbacks = callbacks + }) + }) + + function setupAndSend(config: any = {}) { + const hook = renderHook(() => useChat(config)) + act(() => { + hook.result.current.handleSend({ query: 'test' }, { + onGetSuggestedQuestions: vi.fn().mockResolvedValue({ data: ['q1'] }), + }) + }) + return hook + } + + function startWorkflow(overrides: Record = {}) { + act(() => { + capturedCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-1', + task_id: 'task-1', + conversation_id: null, + message_id: null, + ...overrides, + }) + }) + } + + function startNode(nodeId: string, traceId: string, extra: Record = {}) { + act(() => { + capturedCallbacks.onNodeStarted({ + data: { node_id: nodeId, id: traceId, ...extra }, + }) + }) + } + + describe('onData', () => { + it('should append message content', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onData('Hello', true, { + conversationId: 'conv-1', + messageId: 'msg-1', + taskId: 'task-1', + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.content).toContain('Hello') + }) + + it('should set response id from messageId on first call', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onData('Hi', true, { + conversationId: 'conv-1', + messageId: 'msg-123', + taskId: 'task-1', + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-123') + expect(answer).toBeDefined() + }) + + it('should set conversationId on first message with newConversationId', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onData('Hi', true, { + conversationId: 'new-conv-id', + messageId: 'msg-1', + taskId: 'task-1', + }) + }) + + expect(result.current.conversationId).toBe('new-conv-id') + }) + + it('should not set conversationId when isFirstMessage is false', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onData('Hi', false, { + conversationId: 'conv-should-not-set', + messageId: 'msg-1', + taskId: 'task-1', + }) + }) + + expect(result.current.conversationId).toBe('') + }) + + it('should not update hasSetResponseId when messageId is empty', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onData('msg1', true, { + conversationId: '', + messageId: '', + taskId: 'task-1', + }) + }) + + act(() => { + capturedCallbacks.onData('msg2', false, { + conversationId: '', + messageId: 'late-id', + taskId: 'task-1', + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'late-id') + expect(answer).toBeDefined() + }) + + it('should only set hasSetResponseId once', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onData('msg1', true, { + conversationId: 'c1', + messageId: 'msg-first', + taskId: 'task-1', + }) + }) + + act(() => { + capturedCallbacks.onData('msg2', false, { + conversationId: 'c1', + messageId: 'msg-second', + taskId: 'task-1', + }) + }) + + const question = result.current.chatList.find(item => !item.isAnswer) + expect(question!.id).toBe('question-msg-first') + }) + }) + + describe('onCompleted', () => { + it('should set isResponding to false', async () => { + const { result } = setupAndSend() + await act(async () => { + await capturedCallbacks.onCompleted(false) + }) + expect(result.current.isResponding).toBe(false) + }) + + it('should call fetchInspectVars and invalidAllLastRun when not paused', async () => { + setupAndSend() + await act(async () => { + await capturedCallbacks.onCompleted(false) + }) + expect(mockFetchInspectVars).toHaveBeenCalledWith({}) + expect(mockInvalidAllLastRun).toHaveBeenCalled() + }) + + it('should not call fetchInspectVars when workflow is paused', async () => { + mockWorkflowRunningData = { result: { status: 'paused' } } + setupAndSend() + await act(async () => { + await capturedCallbacks.onCompleted(false) + }) + expect(mockFetchInspectVars).not.toHaveBeenCalled() + }) + + it('should set error content on response item when hasError with errorMessage', async () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onData('partial', true, { + conversationId: 'c1', + messageId: 'msg-err', + taskId: 't1', + }) + }) + + await act(async () => { + await capturedCallbacks.onCompleted(true, 'Something went wrong') + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-err') + expect(answer!.content).toBe('Something went wrong') + expect(answer!.isError).toBe(true) + }) + + it('should not set error content when hasError is true but errorMessage is empty', async () => { + const { result } = setupAndSend() + await act(async () => { + await capturedCallbacks.onCompleted(true) + }) + expect(result.current.isResponding).toBe(false) + }) + + it('should fetch suggested questions when enabled and invoke abort controller callback', async () => { + const mockGetSuggested = vi.fn().mockImplementation((_id: string, getAbortCtrl: any) => { + getAbortCtrl(new AbortController()) + return Promise.resolve({ data: ['suggestion1'] }) + }) + const hook = renderHook(() => + useChat({ suggested_questions_after_answer: { enabled: true } }), + ) + + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + capturedCallbacks = callbacks + }) + + act(() => { + hook.result.current.handleSend({ query: 'test' }, { + onGetSuggestedQuestions: mockGetSuggested, + }) + }) + + await act(async () => { + await capturedCallbacks.onCompleted(false) + }) + + expect(mockGetSuggested).toHaveBeenCalled() + }) + + it('should set suggestedQuestions to empty array when fetch fails', async () => { + const mockGetSuggested = vi.fn().mockRejectedValue(new Error('fail')) + const hook = renderHook(() => + useChat({ suggested_questions_after_answer: { enabled: true } }), + ) + + mockHandleRun.mockImplementation((_params: any, callbacks: any) => { + capturedCallbacks = callbacks + }) + + act(() => { + hook.result.current.handleSend({ query: 'test' }, { + onGetSuggestedQuestions: mockGetSuggested, + }) + }) + + await act(async () => { + await capturedCallbacks.onCompleted(false) + }) + + expect(hook.result.current.suggestedQuestions).toEqual([]) + }) + }) + + describe('onError', () => { + it('should set isResponding to false', () => { + const { result } = setupAndSend() + act(() => { + capturedCallbacks.onError() + }) + expect(result.current.isResponding).toBe(false) + }) + }) + + describe('onMessageEnd', () => { + it('should update citation and files', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onData('response', true, { + conversationId: 'c1', + messageId: 'msg-1', + taskId: 't1', + }) + }) + + act(() => { + capturedCallbacks.onMessageEnd({ + metadata: { retriever_resources: [{ id: 'r1' }] }, + files: [], + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-1') + expect(answer!.citation).toEqual([{ id: 'r1' }]) + }) + + it('should default citation to empty array when no retriever_resources', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onData('response', true, { + conversationId: 'c1', + messageId: 'msg-1', + taskId: 't1', + }) + }) + + act(() => { + capturedCallbacks.onMessageEnd({ metadata: {}, files: [] }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-1') + expect(answer!.citation).toEqual([]) + }) + }) + + describe('onMessageReplace', () => { + it('should replace answer content on responseItem', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onData('old', true, { + conversationId: 'c1', + messageId: 'msg-1', + taskId: 't1', + }) + }) + + act(() => { + capturedCallbacks.onMessageReplace({ answer: 'replaced' }) + }) + + act(() => { + capturedCallbacks.onMessageEnd({ metadata: {}, files: [] }) + }) + + const answer = result.current.chatList.find(item => item.id === 'msg-1') + expect(answer!.content).toBe('replaced') + }) + }) + + describe('onWorkflowStarted', () => { + it('should create workflow process with Running status', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-1', + task_id: 'task-1', + conversation_id: 'conv-1', + message_id: 'msg-1', + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.status).toBe('running') + expect(answer!.workflowProcess!.tracing).toEqual([]) + }) + + it('should set conversationId when provided', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-1', + task_id: 'task-1', + conversation_id: 'from-workflow', + message_id: null, + }) + }) + + expect(result.current.conversationId).toBe('from-workflow') + }) + + it('should not override existing conversationId when conversation_id is null', () => { + const { result } = setupAndSend() + startWorkflow() + expect(result.current.conversationId).toBe('') + }) + + it('should resume existing workflow process when tracing exists', () => { + const { result } = setupAndSend() + startWorkflow() + startNode('n1', 'trace-1') + startWorkflow({ workflow_run_id: 'wfr-2', task_id: 'task-2' }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.status).toBe('running') + expect(answer!.workflowProcess!.tracing.length).toBe(1) + }) + + it('should replace placeholder answer id with real message_id from server', () => { + const { result } = setupAndSend() + + act(() => { + capturedCallbacks.onWorkflowStarted({ + workflow_run_id: 'wfr-1', + task_id: 'task-1', + conversation_id: null, + message_id: 'wf-msg-id', + }) + }) + + const answer = result.current.chatList.find(item => item.id === 'wf-msg-id') + expect(answer).toBeDefined() + }) + }) + + describe('onWorkflowFinished', () => { + it('should update workflow process status', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onWorkflowFinished({ data: { status: 'succeeded' } }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.status).toBe('succeeded') + }) + }) + + describe('onIterationStart / onIterationFinish', () => { + it('should push tracing entry on start', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onIterationStart({ + data: { id: 'iter-1', node_id: 'n-iter' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + const trace = answer!.workflowProcess!.tracing[0] + expect(trace.id).toBe('iter-1') + expect(trace.node_id).toBe('n-iter') + expect(trace.status).toBe('running') + }) + + it('should update matching tracing on finish', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onIterationStart({ + data: { id: 'iter-1', node_id: 'n-iter' }, + }) + }) + + act(() => { + capturedCallbacks.onIterationFinish({ + data: { id: 'iter-1', node_id: 'n-iter', output: 'done' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + const trace = answer!.workflowProcess!.tracing.find((t: any) => t.id === 'iter-1') + expect(trace).toBeDefined() + expect(trace!.node_id).toBe('n-iter') + expect((trace as any).output).toBe('done') + }) + + it('should not update tracing on finish when id does not match', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onIterationStart({ + data: { id: 'iter-1', node_id: 'n-iter' }, + }) + }) + + act(() => { + capturedCallbacks.onIterationFinish({ + data: { id: 'iter-nonexistent', node_id: 'n-other' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + expect((answer!.workflowProcess!.tracing[0] as any).output).toBeUndefined() + }) + }) + + describe('onLoopStart / onLoopFinish', () => { + it('should push tracing entry on start', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onLoopStart({ + data: { id: 'loop-1', node_id: 'n-loop' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + const trace = answer!.workflowProcess!.tracing[0] + expect(trace.id).toBe('loop-1') + expect(trace.node_id).toBe('n-loop') + expect(trace.status).toBe('running') + }) + + it('should update matching tracing on finish', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onLoopStart({ + data: { id: 'loop-1', node_id: 'n-loop' }, + }) + }) + + act(() => { + capturedCallbacks.onLoopFinish({ + data: { id: 'loop-1', node_id: 'n-loop', output: 'done' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + const trace = answer!.workflowProcess!.tracing[0] + expect(trace.id).toBe('loop-1') + expect(trace.node_id).toBe('n-loop') + expect((trace as any).output).toBe('done') + }) + + it('should not update tracing on finish when id does not match', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onLoopStart({ + data: { id: 'loop-1', node_id: 'n-loop' }, + }) + }) + + act(() => { + capturedCallbacks.onLoopFinish({ + data: { id: 'loop-nonexistent', node_id: 'n-other' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + expect((answer!.workflowProcess!.tracing[0] as any).output).toBeUndefined() + }) + }) + + describe('onNodeStarted / onNodeRetry / onNodeFinished', () => { + it('should add new tracing entry', () => { + const { result } = setupAndSend() + startWorkflow() + startNode('node-1', 'trace-1') + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + const trace = answer!.workflowProcess!.tracing[0] + expect(trace.id).toBe('trace-1') + expect(trace.node_id).toBe('node-1') + expect(trace.status).toBe('running') + }) + + it('should update existing tracing entry with same node_id', () => { + const { result } = setupAndSend() + startWorkflow() + startNode('node-1', 'trace-1') + startNode('node-1', 'trace-1-v2') + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + const trace = answer!.workflowProcess!.tracing[0] + expect(trace.id).toBe('trace-1-v2') + expect(trace.node_id).toBe('node-1') + expect(trace.status).toBe('running') + }) + + it('should push retry data to tracing', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onNodeRetry({ + data: { node_id: 'node-1', id: 'retry-1', retry_index: 1 }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + const trace = answer!.workflowProcess!.tracing[0] + expect(trace.id).toBe('retry-1') + expect(trace.node_id).toBe('node-1') + expect((trace as any).retry_index).toBe(1) + }) + + it('should update tracing entry on finish by id', () => { + const { result } = setupAndSend() + startWorkflow() + startNode('node-1', 'trace-1') + + act(() => { + capturedCallbacks.onNodeFinished({ + data: { node_id: 'node-1', id: 'trace-1', status: 'succeeded', outputs: { text: 'done' } }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + const trace = answer!.workflowProcess!.tracing[0] + expect(trace.id).toBe('trace-1') + expect(trace.status).toBe('succeeded') + expect((trace as any).outputs).toEqual({ text: 'done' }) + }) + + it('should not update tracing on finish when id does not match', () => { + const { result } = setupAndSend() + startWorkflow() + startNode('node-1', 'trace-1') + + act(() => { + capturedCallbacks.onNodeFinished({ + data: { node_id: 'node-x', id: 'trace-x', status: 'succeeded' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.tracing).toHaveLength(1) + const trace = answer!.workflowProcess!.tracing[0] + expect(trace.id).toBe('trace-1') + expect(trace.status).toBe('running') + }) + }) + + describe('onAgentLog', () => { + function setupWithNode() { + const hook = setupAndSend() + startWorkflow() + return hook + } + + it('should create execution_metadata.agent_log when no execution_metadata exists', () => { + const { result } = setupWithNode() + startNode('agent-node', 'trace-agent') + + act(() => { + capturedCallbacks.onAgentLog({ + data: { node_id: 'agent-node', message_id: 'log-1', content: 'init' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + const agentTrace = answer!.workflowProcess!.tracing.find((t: any) => t.node_id === 'agent-node') + expect(agentTrace!.execution_metadata!.agent_log).toHaveLength(1) + }) + + it('should create agent_log array when execution_metadata exists but no agent_log', () => { + const { result } = setupWithNode() + startNode('agent-node', 'trace-agent', { execution_metadata: { parallel_id: 'p1' } }) + + act(() => { + capturedCallbacks.onAgentLog({ + data: { node_id: 'agent-node', message_id: 'log-1', content: 'step1' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + const agentTrace = answer!.workflowProcess!.tracing.find((t: any) => t.node_id === 'agent-node') + expect(agentTrace!.execution_metadata!.agent_log).toHaveLength(1) + }) + + it('should update existing agent_log entry by message_id', () => { + const { result } = setupWithNode() + startNode('agent-node', 'trace-agent', { + execution_metadata: { agent_log: [{ message_id: 'log-1', content: 'v1' }] }, + }) + + act(() => { + capturedCallbacks.onAgentLog({ + data: { node_id: 'agent-node', message_id: 'log-1', content: 'v2' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + const agentTrace = answer!.workflowProcess!.tracing.find((t: any) => t.node_id === 'agent-node') + expect(agentTrace!.execution_metadata!.agent_log).toHaveLength(1) + expect((agentTrace!.execution_metadata!.agent_log as any[])[0].content).toBe('v2') + }) + + it('should push new agent_log entry when message_id does not match', () => { + const { result } = setupWithNode() + startNode('agent-node', 'trace-agent', { + execution_metadata: { agent_log: [{ message_id: 'log-1', content: 'v1' }] }, + }) + + act(() => { + capturedCallbacks.onAgentLog({ + data: { node_id: 'agent-node', message_id: 'log-2', content: 'new' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + const agentTrace = answer!.workflowProcess!.tracing.find((t: any) => t.node_id === 'agent-node') + expect(agentTrace!.execution_metadata!.agent_log).toHaveLength(2) + }) + + it('should not crash when node_id is not found in tracing', () => { + setupWithNode() + + act(() => { + capturedCallbacks.onAgentLog({ + data: { node_id: 'nonexistent-node', message_id: 'log-1', content: 'noop' }, + }) + }) + }) + }) + + describe('onHumanInputRequired', () => { + it('should add form data to humanInputFormDataList', () => { + const { result } = setupAndSend() + startWorkflow() + startNode('human-node', 'trace-human') + + act(() => { + capturedCallbacks.onHumanInputRequired({ + data: { node_id: 'human-node', form_token: 'token-1' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.humanInputFormDataList).toHaveLength(1) + expect(answer!.humanInputFormDataList![0].node_id).toBe('human-node') + expect((answer!.humanInputFormDataList![0] as any).form_token).toBe('token-1') + }) + + it('should update existing form for same node_id', () => { + const { result } = setupAndSend() + startWorkflow() + startNode('human-node', 'trace-human') + + act(() => { + capturedCallbacks.onHumanInputRequired({ + data: { node_id: 'human-node', form_token: 'token-1' }, + }) + }) + + act(() => { + capturedCallbacks.onHumanInputRequired({ + data: { node_id: 'human-node', form_token: 'token-2' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.humanInputFormDataList).toHaveLength(1) + expect((answer!.humanInputFormDataList![0] as any).form_token).toBe('token-2') + }) + + it('should push new form data for different node_id', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onHumanInputRequired({ + data: { node_id: 'human-node-1', form_token: 'token-1' }, + }) + }) + + act(() => { + capturedCallbacks.onHumanInputRequired({ + data: { node_id: 'human-node-2', form_token: 'token-2' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.humanInputFormDataList).toHaveLength(2) + expect(answer!.humanInputFormDataList![0].node_id).toBe('human-node-1') + expect(answer!.humanInputFormDataList![1].node_id).toBe('human-node-2') + }) + + it('should set tracing node status to Paused when tracing index found', () => { + const { result } = setupAndSend() + startWorkflow() + startNode('human-node', 'trace-human') + + act(() => { + capturedCallbacks.onHumanInputRequired({ + data: { node_id: 'human-node', form_token: 'token-1' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + const trace = answer!.workflowProcess!.tracing.find((t: any) => t.node_id === 'human-node') + expect(trace!.status).toBe('paused') + }) + }) + + describe('onHumanInputFormFilled', () => { + it('should remove form and add to filled list', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onHumanInputRequired({ + data: { node_id: 'human-node', form_token: 'token-1' }, + }) + }) + + act(() => { + capturedCallbacks.onHumanInputFormFilled({ + data: { node_id: 'human-node', form_data: { answer: 'yes' } }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.humanInputFormDataList).toHaveLength(0) + expect(answer!.humanInputFilledFormDataList).toHaveLength(1) + expect(answer!.humanInputFilledFormDataList![0].node_id).toBe('human-node') + expect((answer!.humanInputFilledFormDataList![0] as any).form_data).toEqual({ answer: 'yes' }) + }) + }) + + describe('onHumanInputFormTimeout', () => { + it('should update expiration_time on form data', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onHumanInputRequired({ + data: { node_id: 'human-node', form_token: 'token-1' }, + }) + }) + + act(() => { + capturedCallbacks.onHumanInputFormTimeout({ + data: { node_id: 'human-node', expiration_time: '2025-01-01T00:00:00Z' }, + }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + const form = answer!.humanInputFormDataList!.find((f: any) => f.node_id === 'human-node') + expect(form!.expiration_time).toBe('2025-01-01T00:00:00Z') + }) + }) + + describe('onWorkflowPaused', () => { + it('should set status to Paused', () => { + const { result } = setupAndSend() + startWorkflow() + + act(() => { + capturedCallbacks.onWorkflowPaused({ data: {} }) + }) + + const answer = result.current.chatList.find(item => item.isAnswer && !item.isOpeningStatement) + expect(answer!.workflowProcess!.status).toBe('paused') + }) + }) +}) From 7fe25f136569ec3796b9350d2f748c3e4e207396 Mon Sep 17 00:00:00 2001 From: Zhanyuan Guo <2364319479@qq.com> Date: Tue, 24 Mar 2026 15:08:55 +0800 Subject: [PATCH 32/34] fix(rate_limit): flush redis cache when __init__ is triggered by changing max_active_requests (#33830) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../app/features/rate_limiting/rate_limit.py | 12 +++++-- .../features/rate_limiting/test_rate_limit.py | 34 +++++++++++++++++-- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index 2ca1275a8a..e0f1759e5e 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -19,6 +19,7 @@ class RateLimit: _REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes _instance_dict: dict[str, "RateLimit"] = {} + max_active_requests: int def __new__(cls, client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: @@ -27,7 +28,13 @@ class RateLimit: return cls._instance_dict[client_id] def __init__(self, client_id: str, max_active_requests: int): + flush_cache = hasattr(self, "max_active_requests") and self.max_active_requests != max_active_requests self.max_active_requests = max_active_requests + # Only flush here if this instance has already been fully initialized, + # i.e. the Redis key attributes exist. Otherwise, rely on the flush at + # the end of initialization below. + if flush_cache and hasattr(self, "active_requests_key") and hasattr(self, "max_active_requests_key"): + self.flush_cache(use_local_value=True) # must be called after max_active_requests is set if self.disabled(): return @@ -41,8 +48,6 @@ class RateLimit: self.flush_cache(use_local_value=True) def flush_cache(self, use_local_value=False): - if self.disabled(): - return self.last_recalculate_time = time.time() # flush max active requests if use_local_value or not redis_client.exists(self.max_active_requests_key): @@ -50,7 +55,8 @@ class RateLimit: else: self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8")) redis_client.expire(self.max_active_requests_key, timedelta(days=1)) - + if self.disabled(): + return # flush max active requests (in-transit request list) if not redis_client.exists(self.active_requests_key): return diff --git a/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py b/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py index 3db10c1c72..538b130cac 100644 --- a/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py +++ b/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py @@ -68,8 +68,8 @@ class TestRateLimit: assert rate_limit.disabled() assert not hasattr(rate_limit, "initialized") - def test_should_skip_reinitialization_of_existing_instance(self, redis_patch): - """Test that existing instance doesn't reinitialize.""" + def test_should_flush_cache_when_reinitializing_existing_instance(self, redis_patch): + """Test existing instance refreshes Redis cache on reinitialization.""" redis_patch.configure_mock( **{ "exists.return_value": False, @@ -82,7 +82,37 @@ class TestRateLimit: RateLimit("client1", 10) + redis_patch.setex.assert_called_once_with( + "dify:rate_limit:client1:max_active_requests", + timedelta(days=1), + 10, + ) + + def test_should_reinitialize_after_being_disabled(self, redis_patch): + """Test disabled instance can be reinitialized and writes max_active_requests to Redis.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + } + ) + + # First construct with max_active_requests = 0 (disabled), which should skip initialization. + RateLimit("client1", 0) + + # Redis should not have been written to during disabled initialization. redis_patch.setex.assert_not_called() + redis_patch.reset_mock() + + # Reinitialize with a positive max_active_requests value; this should not raise + # and must write the max_active_requests key to Redis. + RateLimit("client1", 10) + + redis_patch.setex.assert_called_once_with( + "dify:rate_limit:client1:max_active_requests", + timedelta(days=1), + 10, + ) def test_should_be_disabled_when_max_requests_is_zero_or_negative(self): """Test disabled state for zero or negative limits.""" From 1674f8c2fb5ac6bc6e2468dfebbf8ae9096d4221 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Tue, 24 Mar 2026 15:10:05 +0800 Subject: [PATCH 33/34] fix: fix omitted app icon_type updates (#33988) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/app.py | 8 +- api/services/app_service.py | 10 +- .../services/test_app_service.py | 105 +++++++++++++++++- .../controllers/console/app/test_app_apis.py | 49 +++++++- .../unit_tests/services/test_app_service.py | 76 ++++++++++++- 5 files changed, 239 insertions(+), 9 deletions(-) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 5ac0e342e6..7e41260eeb 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -95,7 +95,7 @@ class CreateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400) mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode") - icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") @@ -103,7 +103,7 @@ class CreateAppPayload(BaseModel): class UpdateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400) - icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon") @@ -113,7 +113,7 @@ class UpdateAppPayload(BaseModel): class CopyAppPayload(BaseModel): name: str | None = Field(default=None, description="Name for the copied app") description: str | None = Field(default=None, description="Description for the copied app", max_length=400) - icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") @@ -594,7 +594,7 @@ class AppApi(Resource): args_dict: AppService.ArgsDict = { "name": args.name, "description": args.description or "", - "icon_type": args.icon_type or "", + "icon_type": args.icon_type, "icon": args.icon or "", "icon_background": args.icon_background or "", "use_icon_as_answer_icon": args.use_icon_as_answer_icon or False, diff --git a/api/services/app_service.py b/api/services/app_service.py index c5d1479a20..69c7c0c95a 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -241,7 +241,7 @@ class AppService: class ArgsDict(TypedDict): name: str description: str - icon_type: str + icon_type: IconType | str | None icon: str icon_background: str use_icon_as_answer_icon: bool @@ -257,7 +257,13 @@ class AppService: assert current_user is not None app.name = args["name"] app.description = args["description"] - app.icon_type = IconType(args["icon_type"]) if args["icon_type"] else None + icon_type = args.get("icon_type") + if icon_type is None: + resolved_icon_type = app.icon_type + else: + resolved_icon_type = IconType(icon_type) + + app.icon_type = resolved_icon_type app.icon = args["icon"] app.icon_background = args["icon_background"] app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index d79f80c009..9ca8729b77 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from constants.model_template import default_app_templates from models import Account -from models.model import App, Site +from models.model import App, IconType, Site from services.account_service import AccountService, TenantService from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -463,6 +463,109 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by + def test_update_app_should_preserve_icon_type_when_omitted( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """ + Test update_app keeps the persisted icon_type when the update payload omits it. + """ + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=generate_valid_password(fake), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + from services.app_service import AppService + + app_service = AppService() + app = app_service.create_app( + tenant.id, + { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + }, + account, + ) + + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.id = account.id + mock_current_user.current_tenant_id = account.current_tenant_id + + with patch("services.app_service.current_user", mock_current_user): + updated_app = app_service.update_app( + app, + { + "name": "Updated App Name", + "description": "Updated app description", + "icon_type": None, + "icon": "🔄", + "icon_background": "#FF8C42", + "use_icon_as_answer_icon": True, + }, + ) + + assert updated_app.icon_type == IconType.EMOJI + + def test_update_app_should_reject_empty_icon_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """ + Test update_app rejects an explicit empty icon_type. + """ + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=generate_valid_password(fake), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + from services.app_service import AppService + + app_service = AppService() + app = app_service.create_app( + tenant.id, + { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + }, + account, + ) + + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.id = account.id + mock_current_user.current_tenant_id = account.current_tenant_id + + with patch("services.app_service.current_user", mock_current_user): + with pytest.raises(ValueError): + app_service.update_app( + app, + { + "name": "Updated App Name", + "description": "Updated app description", + "icon_type": "", + "icon": "🔄", + "icon_background": "#FF8C42", + "use_icon_as_answer_icon": True, + }, + ) + def test_update_app_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app name update. diff --git a/api/tests/unit_tests/controllers/console/app/test_app_apis.py b/api/tests/unit_tests/controllers/console/app/test_app_apis.py index beb8ff55a5..1d1e119fd6 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_apis.py @@ -7,14 +7,19 @@ from __future__ import annotations import uuid from types import SimpleNamespace -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest +from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound +from controllers.console import console_ns from controllers.console.app import ( annotation as annotation_module, ) +from controllers.console.app import ( + app as app_module, +) from controllers.console.app import ( completion as completion_module, ) @@ -203,6 +208,48 @@ class TestCompletionEndpoints: method(app_model=MagicMock(id="app-1")) +class TestAppEndpoints: + """Tests for app endpoints.""" + + def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app, monkeypatch): + api = app_module.AppApi() + method = _unwrap(api.put) + payload = { + "name": "Updated App", + "description": "Updated description", + "icon": "🤖", + "icon_background": "#FFFFFF", + } + app_service = MagicMock() + app_service.update_app.return_value = SimpleNamespace() + response_model = MagicMock() + response_model.model_dump.return_value = {"id": "app-1"} + + monkeypatch.setattr(app_module, "AppService", lambda: app_service) + monkeypatch.setattr(app_module.AppDetailWithSite, "model_validate", MagicMock(return_value=response_model)) + + with ( + app.test_request_context("/console/api/apps/app-1", method="PUT", json=payload), + patch.object(type(console_ns), "payload", payload), + ): + response = method(app_model=SimpleNamespace(icon_type=app_module.IconType.EMOJI)) + + assert response == {"id": "app-1"} + assert app_service.update_app.call_args.args[1]["icon_type"] is None + + def test_update_app_payload_should_reject_empty_icon_type(self): + with pytest.raises(ValidationError): + app_module.UpdateAppPayload.model_validate( + { + "name": "Updated App", + "description": "Updated description", + "icon_type": "", + "icon": "🤖", + "icon_background": "#FFFFFF", + } + ) + + # ========== OpsTrace Tests ========== class TestOpsTraceEndpoints: """Tests for ops_trace endpoint.""" diff --git a/api/tests/unit_tests/services/test_app_service.py b/api/tests/unit_tests/services/test_app_service.py index bff8dc92c6..95fc28b1e7 100644 --- a/api/tests/unit_tests/services/test_app_service.py +++ b/api/tests/unit_tests/services/test_app_service.py @@ -9,7 +9,7 @@ import pytest from core.errors.error import ProviderTokenNotInitError from models import Account, Tenant -from models.model import App, AppMode +from models.model import App, AppMode, IconType from services.app_service import AppService @@ -411,6 +411,7 @@ class TestAppServiceGetAndUpdate: # Assert assert updated is app + assert updated.icon_type == IconType.IMAGE assert renamed is app assert iconed is app assert site_same is app @@ -419,6 +420,79 @@ class TestAppServiceGetAndUpdate: assert api_changed is app assert mock_db.session.commit.call_count >= 5 + def test_update_app_should_preserve_icon_type_when_not_provided(self, service: AppService) -> None: + """Test update_app keeps the existing icon_type when the payload omits it.""" + # Arrange + app = cast( + App, + SimpleNamespace( + name="old", + description="old", + icon_type=IconType.EMOJI, + icon="a", + icon_background="#111", + use_icon_as_answer_icon=False, + max_active_requests=1, + ), + ) + args = { + "name": "new", + "description": "new-desc", + "icon_type": None, + "icon": "new-icon", + "icon_background": "#222", + "use_icon_as_answer_icon": True, + "max_active_requests": 5, + } + user = SimpleNamespace(id="user-1") + + with ( + patch("services.app_service.current_user", user), + patch("services.app_service.db") as mock_db, + patch("services.app_service.naive_utc_now", return_value="now"), + ): + # Act + updated = service.update_app(app, args) + + # Assert + assert updated is app + assert updated.icon_type == IconType.EMOJI + mock_db.session.commit.assert_called_once() + + def test_update_app_should_reject_empty_icon_type(self, service: AppService) -> None: + """Test update_app rejects an explicit empty icon_type.""" + app = cast( + App, + SimpleNamespace( + name="old", + description="old", + icon_type=IconType.EMOJI, + icon="a", + icon_background="#111", + use_icon_as_answer_icon=False, + max_active_requests=1, + ), + ) + args = { + "name": "new", + "description": "new-desc", + "icon_type": "", + "icon": "new-icon", + "icon_background": "#222", + "use_icon_as_answer_icon": True, + "max_active_requests": 5, + } + user = SimpleNamespace(id="user-1") + + with ( + patch("services.app_service.current_user", user), + patch("services.app_service.db") as mock_db, + ): + with pytest.raises(ValueError): + service.update_app(app, args) + + mock_db.session.commit.assert_not_called() + class TestAppServiceDeleteAndMeta: """Test suite for delete and metadata methods.""" From 0c3d11f920000109c3eb54823d10a3feb493d61f Mon Sep 17 00:00:00 2001 From: Stephen Zhou Date: Tue, 24 Mar 2026 15:29:42 +0800 Subject: [PATCH 34/34] refactor: lazy load large modules (#33888) --- web/app/components/apps/index.tsx | 8 +++-- web/app/components/apps/list.tsx | 14 ++++---- .../base/amplitude/AmplitudeProvider.tsx | 9 ++--- .../__tests__/AmplitudeProvider.spec.tsx | 30 +++++++--------- .../base/amplitude/__tests__/index.spec.ts | 32 ----------------- .../base/amplitude/__tests__/utils.spec.ts | 6 ++-- web/app/components/base/amplitude/index.ts | 2 +- .../amplitude/lazy-amplitude-provider.tsx | 11 ++++++ web/app/components/base/amplitude/utils.ts | 10 +++--- .../components/devtools/agentation-loader.tsx | 13 +++++++ .../account-dropdown/__tests__/index.spec.tsx | 3 ++ web/app/components/header/app-nav/index.tsx | 8 +++-- .../components/lazy-sentry-initializer.tsx | 16 +++++++++ web/app/components/sentry-initializer.tsx | 7 ++-- web/app/layout.tsx | 34 +++++++++---------- web/config/index.ts | 2 ++ web/eslint-suppressions.json | 5 --- 17 files changed, 104 insertions(+), 106 deletions(-) delete mode 100644 web/app/components/base/amplitude/__tests__/index.spec.ts create mode 100644 web/app/components/base/amplitude/lazy-amplitude-provider.tsx create mode 100644 web/app/components/devtools/agentation-loader.tsx create mode 100644 web/app/components/lazy-sentry-initializer.tsx diff --git a/web/app/components/apps/index.tsx b/web/app/components/apps/index.tsx index dce9de190d..b6ca60bd7b 100644 --- a/web/app/components/apps/index.tsx +++ b/web/app/components/apps/index.tsx @@ -8,12 +8,14 @@ import AppListContext from '@/context/app-list-context' import useDocumentTitle from '@/hooks/use-document-title' import { useImportDSL } from '@/hooks/use-import-dsl' import { DSLImportMode } from '@/models/app' +import dynamic from '@/next/dynamic' import { fetchAppDetail } from '@/service/explore' -import DSLConfirmModal from '../app/create-from-dsl-modal/dsl-confirm-modal' -import CreateAppModal from '../explore/create-app-modal' -import TryApp from '../explore/try-app' import List from './list' +const DSLConfirmModal = dynamic(() => import('../app/create-from-dsl-modal/dsl-confirm-modal'), { ssr: false }) +const CreateAppModal = dynamic(() => import('../explore/create-app-modal'), { ssr: false }) +const TryApp = dynamic(() => import('../explore/try-app'), { ssr: false }) + const Apps = () => { const { t } = useTranslation() diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index 0d52bd468c..2ef344f816 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -5,11 +5,11 @@ import { useDebounceFn } from 'ahooks' import { parseAsStringLiteral, useQueryState } from 'nuqs' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' +import Checkbox from '@/app/components/base/checkbox' import Input from '@/app/components/base/input' import TabSliderNew from '@/app/components/base/tab-slider-new' import TagFilter from '@/app/components/base/tag-management/filter' import { useStore as useTagStore } from '@/app/components/base/tag-management/store' -import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { useGlobalPublicStore } from '@/context/global-public-context' @@ -205,12 +205,12 @@ const List: FC = ({ options={options} />
- + { - return IS_CLOUD_EDITION && !!AMPLITUDE_API_KEY -} - // Map URL pathname to English page name for consistent Amplitude tracking const getEnglishPageName = (pathname: string): string => { // Remove leading slash and get the first segment @@ -59,7 +54,7 @@ const AmplitudeProvider: FC = ({ }) => { useEffect(() => { // Only enable in Saas edition with valid API key - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return // Initialize Amplitude diff --git a/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx b/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx index b30da72091..5835634eb7 100644 --- a/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx +++ b/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx @@ -2,14 +2,24 @@ import * as amplitude from '@amplitude/analytics-browser' import { sessionReplayPlugin } from '@amplitude/plugin-session-replay-browser' import { render } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import AmplitudeProvider, { isAmplitudeEnabled } from '../AmplitudeProvider' +import AmplitudeProvider from '../AmplitudeProvider' const mockConfig = vi.hoisted(() => ({ AMPLITUDE_API_KEY: 'test-api-key', IS_CLOUD_EDITION: true, })) -vi.mock('@/config', () => mockConfig) +vi.mock('@/config', () => ({ + get AMPLITUDE_API_KEY() { + return mockConfig.AMPLITUDE_API_KEY + }, + get IS_CLOUD_EDITION() { + return mockConfig.IS_CLOUD_EDITION + }, + get isAmplitudeEnabled() { + return mockConfig.IS_CLOUD_EDITION && !!mockConfig.AMPLITUDE_API_KEY + }, +})) vi.mock('@amplitude/analytics-browser', () => ({ init: vi.fn(), @@ -27,22 +37,6 @@ describe('AmplitudeProvider', () => { mockConfig.IS_CLOUD_EDITION = true }) - describe('isAmplitudeEnabled', () => { - it('returns true when cloud edition and api key present', () => { - expect(isAmplitudeEnabled()).toBe(true) - }) - - it('returns false when cloud edition but no api key', () => { - mockConfig.AMPLITUDE_API_KEY = '' - expect(isAmplitudeEnabled()).toBe(false) - }) - - it('returns false when not cloud edition', () => { - mockConfig.IS_CLOUD_EDITION = false - expect(isAmplitudeEnabled()).toBe(false) - }) - }) - describe('Component', () => { it('initializes amplitude when enabled', () => { render() diff --git a/web/app/components/base/amplitude/__tests__/index.spec.ts b/web/app/components/base/amplitude/__tests__/index.spec.ts deleted file mode 100644 index 2d7ad6ab84..0000000000 --- a/web/app/components/base/amplitude/__tests__/index.spec.ts +++ /dev/null @@ -1,32 +0,0 @@ -import { describe, expect, it } from 'vitest' -import AmplitudeProvider, { isAmplitudeEnabled } from '../AmplitudeProvider' -import indexDefault, { - isAmplitudeEnabled as indexIsAmplitudeEnabled, - resetUser, - setUserId, - setUserProperties, - trackEvent, -} from '../index' -import { - resetUser as utilsResetUser, - setUserId as utilsSetUserId, - setUserProperties as utilsSetUserProperties, - trackEvent as utilsTrackEvent, -} from '../utils' - -describe('Amplitude index exports', () => { - it('exports AmplitudeProvider as default', () => { - expect(indexDefault).toBe(AmplitudeProvider) - }) - - it('exports isAmplitudeEnabled', () => { - expect(indexIsAmplitudeEnabled).toBe(isAmplitudeEnabled) - }) - - it('exports utils', () => { - expect(resetUser).toBe(utilsResetUser) - expect(setUserId).toBe(utilsSetUserId) - expect(setUserProperties).toBe(utilsSetUserProperties) - expect(trackEvent).toBe(utilsTrackEvent) - }) -}) diff --git a/web/app/components/base/amplitude/__tests__/utils.spec.ts b/web/app/components/base/amplitude/__tests__/utils.spec.ts index ecbc57e387..f1ff5db1e3 100644 --- a/web/app/components/base/amplitude/__tests__/utils.spec.ts +++ b/web/app/components/base/amplitude/__tests__/utils.spec.ts @@ -20,8 +20,10 @@ const MockIdentify = vi.hoisted(() => }, ) -vi.mock('../AmplitudeProvider', () => ({ - isAmplitudeEnabled: () => mockState.enabled, +vi.mock('@/config', () => ({ + get isAmplitudeEnabled() { + return mockState.enabled + }, })) vi.mock('@amplitude/analytics-browser', () => ({ diff --git a/web/app/components/base/amplitude/index.ts b/web/app/components/base/amplitude/index.ts index acc792339e..44cbf728e2 100644 --- a/web/app/components/base/amplitude/index.ts +++ b/web/app/components/base/amplitude/index.ts @@ -1,2 +1,2 @@ -export { default, isAmplitudeEnabled } from './AmplitudeProvider' +export { default } from './lazy-amplitude-provider' export { resetUser, setUserId, setUserProperties, trackEvent } from './utils' diff --git a/web/app/components/base/amplitude/lazy-amplitude-provider.tsx b/web/app/components/base/amplitude/lazy-amplitude-provider.tsx new file mode 100644 index 0000000000..5dfa0e7b53 --- /dev/null +++ b/web/app/components/base/amplitude/lazy-amplitude-provider.tsx @@ -0,0 +1,11 @@ +'use client' + +import type { FC } from 'react' +import type { IAmplitudeProps } from './AmplitudeProvider' +import dynamic from '@/next/dynamic' + +const AmplitudeProvider = dynamic(() => import('./AmplitudeProvider'), { ssr: false }) + +const LazyAmplitudeProvider: FC = props => + +export default LazyAmplitudeProvider diff --git a/web/app/components/base/amplitude/utils.ts b/web/app/components/base/amplitude/utils.ts index 57b96243ec..8faa8e852e 100644 --- a/web/app/components/base/amplitude/utils.ts +++ b/web/app/components/base/amplitude/utils.ts @@ -1,5 +1,5 @@ import * as amplitude from '@amplitude/analytics-browser' -import { isAmplitudeEnabled } from './AmplitudeProvider' +import { isAmplitudeEnabled } from '@/config' /** * Track custom event @@ -7,7 +7,7 @@ import { isAmplitudeEnabled } from './AmplitudeProvider' * @param eventProperties Event properties (optional) */ export const trackEvent = (eventName: string, eventProperties?: Record) => { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return amplitude.track(eventName, eventProperties) } @@ -17,7 +17,7 @@ export const trackEvent = (eventName: string, eventProperties?: Record { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return amplitude.setUserId(userId) } @@ -27,7 +27,7 @@ export const setUserId = (userId: string) => { * @param properties User properties */ export const setUserProperties = (properties: Record) => { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return const identifyEvent = new amplitude.Identify() Object.entries(properties).forEach(([key, value]) => { @@ -40,7 +40,7 @@ export const setUserProperties = (properties: Record) => { * Reset user (e.g., when user logs out) */ export const resetUser = () => { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return amplitude.reset() } diff --git a/web/app/components/devtools/agentation-loader.tsx b/web/app/components/devtools/agentation-loader.tsx new file mode 100644 index 0000000000..87e1b44c87 --- /dev/null +++ b/web/app/components/devtools/agentation-loader.tsx @@ -0,0 +1,13 @@ +'use client' + +import { IS_DEV } from '@/config' +import dynamic from '@/next/dynamic' + +const Agentation = dynamic(() => import('agentation').then(module => module.Agentation), { ssr: false }) + +export function AgentationLoader() { + if (!IS_DEV) + return null + + return +} diff --git a/web/app/components/header/account-dropdown/__tests__/index.spec.tsx b/web/app/components/header/account-dropdown/__tests__/index.spec.tsx index eb4d543e66..9d4226c33a 100644 --- a/web/app/components/header/account-dropdown/__tests__/index.spec.tsx +++ b/web/app/components/header/account-dropdown/__tests__/index.spec.tsx @@ -69,6 +69,7 @@ vi.mock('@/context/i18n', () => ({ const { mockConfig, mockEnv } = vi.hoisted(() => ({ mockConfig: { IS_CLOUD_EDITION: false, + AMPLITUDE_API_KEY: '', ZENDESK_WIDGET_KEY: '', SUPPORT_EMAIL_ADDRESS: '', }, @@ -80,6 +81,8 @@ const { mockConfig, mockEnv } = vi.hoisted(() => ({ })) vi.mock('@/config', () => ({ get IS_CLOUD_EDITION() { return mockConfig.IS_CLOUD_EDITION }, + get AMPLITUDE_API_KEY() { return mockConfig.AMPLITUDE_API_KEY }, + get isAmplitudeEnabled() { return mockConfig.IS_CLOUD_EDITION && !!mockConfig.AMPLITUDE_API_KEY }, get ZENDESK_WIDGET_KEY() { return mockConfig.ZENDESK_WIDGET_KEY }, get SUPPORT_EMAIL_ADDRESS() { return mockConfig.SUPPORT_EMAIL_ADDRESS }, IS_DEV: false, diff --git a/web/app/components/header/app-nav/index.tsx b/web/app/components/header/app-nav/index.tsx index 214b7612bb..54ddf75711 100644 --- a/web/app/components/header/app-nav/index.tsx +++ b/web/app/components/header/app-nav/index.tsx @@ -9,16 +9,18 @@ import { flatten } from 'es-toolkit/compat' import { produce } from 'immer' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' -import CreateAppTemplateDialog from '@/app/components/app/create-app-dialog' -import CreateAppModal from '@/app/components/app/create-app-modal' -import CreateFromDSLModal from '@/app/components/app/create-from-dsl-modal' import { useStore as useAppStore } from '@/app/components/app/store' import { useAppContext } from '@/context/app-context' +import dynamic from '@/next/dynamic' import { useParams } from '@/next/navigation' import { useInfiniteAppList } from '@/service/use-apps' import { AppModeEnum } from '@/types/app' import Nav from '../nav' +const CreateAppTemplateDialog = dynamic(() => import('@/app/components/app/create-app-dialog'), { ssr: false }) +const CreateAppModal = dynamic(() => import('@/app/components/app/create-app-modal'), { ssr: false }) +const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-from-dsl-modal'), { ssr: false }) + const AppNav = () => { const { t } = useTranslation() const { appId } = useParams() diff --git a/web/app/components/lazy-sentry-initializer.tsx b/web/app/components/lazy-sentry-initializer.tsx new file mode 100644 index 0000000000..8c29ca4f9a --- /dev/null +++ b/web/app/components/lazy-sentry-initializer.tsx @@ -0,0 +1,16 @@ +'use client' + +import { IS_DEV } from '@/config' +import { env } from '@/env' +import dynamic from '@/next/dynamic' + +const SentryInitializer = dynamic(() => import('./sentry-initializer'), { ssr: false }) + +const LazySentryInitializer = () => { + if (IS_DEV || !env.NEXT_PUBLIC_SENTRY_DSN) + return null + + return +} + +export default LazySentryInitializer diff --git a/web/app/components/sentry-initializer.tsx b/web/app/components/sentry-initializer.tsx index 8a7286f908..00c3af37c1 100644 --- a/web/app/components/sentry-initializer.tsx +++ b/web/app/components/sentry-initializer.tsx @@ -2,13 +2,10 @@ import * as Sentry from '@sentry/react' import { useEffect } from 'react' - import { IS_DEV } from '@/config' import { env } from '@/env' -const SentryInitializer = ({ - children, -}: { children: React.ReactElement }) => { +const SentryInitializer = () => { useEffect(() => { const SENTRY_DSN = env.NEXT_PUBLIC_SENTRY_DSN if (!IS_DEV && SENTRY_DSN) { @@ -24,7 +21,7 @@ const SentryInitializer = ({ }) } }, []) - return children + return null } export default SentryInitializer diff --git a/web/app/layout.tsx b/web/app/layout.tsx index be51c76f2e..1cf1bb0d94 100644 --- a/web/app/layout.tsx +++ b/web/app/layout.tsx @@ -1,9 +1,7 @@ import type { Viewport } from '@/next' -import { Agentation } from 'agentation' import { Provider as JotaiProvider } from 'jotai/react' import { ThemeProvider } from 'next-themes' import { NuqsAdapter } from 'nuqs/adapters/next/app' -import { IS_DEV } from '@/config' import GlobalPublicStoreProvider from '@/context/global-public-context' import { TanstackQueryInitializer } from '@/context/query-client' import { getDatasetMap } from '@/env' @@ -12,9 +10,10 @@ import { ToastProvider } from './components/base/toast' import { ToastHost } from './components/base/ui/toast' import { TooltipProvider } from './components/base/ui/tooltip' import BrowserInitializer from './components/browser-initializer' +import { AgentationLoader } from './components/devtools/agentation-loader' import { ReactScanLoader } from './components/devtools/react-scan/loader' +import LazySentryInitializer from './components/lazy-sentry-initializer' import { I18nServerProvider } from './components/provider/i18n-server' -import SentryInitializer from './components/sentry-initializer' import RoutePrefixHandle from './routePrefixHandle' import './styles/globals.css' import './styles/markdown.scss' @@ -57,6 +56,7 @@ const LocaleLayout = async ({ className="h-full select-auto" {...datasetMap} > +
- - - - - - - - {children} - - - - - - + + + + + + + {children} + + + + + - {IS_DEV && } +
diff --git a/web/config/index.ts b/web/config/index.ts index 3f7d26c623..eed914726c 100644 --- a/web/config/index.ts +++ b/web/config/index.ts @@ -42,6 +42,8 @@ export const AMPLITUDE_API_KEY = getStringConfig( '', ) +export const isAmplitudeEnabled = IS_CLOUD_EDITION && !!AMPLITUDE_API_KEY + export const IS_DEV = process.env.NODE_ENV === 'development' export const IS_PROD = process.env.NODE_ENV === 'production' diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json index 0f69e6ce33..c02d302f09 100644 --- a/web/eslint-suppressions.json +++ b/web/eslint-suppressions.json @@ -1501,11 +1501,6 @@ "count": 2 } }, - "app/components/base/amplitude/AmplitudeProvider.tsx": { - "react-refresh/only-export-components": { - "count": 1 - } - }, "app/components/base/amplitude/utils.ts": { "ts/no-explicit-any": { "count": 2