From 6c19e759691b5b94029975853badd22545c9204c Mon Sep 17 00:00:00 2001 From: Dev Sharma <50591491+cryptus-neoxys@users.noreply.github.com> Date: Mon, 9 Mar 2026 13:28:34 +0530 Subject: [PATCH] test: improve unit tests for controllers.web (#32150) Co-authored-by: Rajat Agarwal --- api/controllers/web/message.py | 2 +- .../unit_tests/controllers/web/__init__.py | 0 .../unit_tests/controllers/web/conftest.py | 85 ++++ .../unit_tests/controllers/web/test_app.py | 165 +++++++ .../unit_tests/controllers/web/test_audio.py | 135 ++++++ .../controllers/web/test_completion.py | 161 +++++++ .../controllers/web/test_conversation.py | 183 ++++++++ .../unit_tests/controllers/web/test_error.py | 75 ++++ .../controllers/web/test_feature.py | 38 ++ .../unit_tests/controllers/web/test_files.py | 89 ++++ .../controllers/web/test_message_endpoints.py | 156 +++++++ .../controllers/web/test_passport.py | 103 +++++ .../controllers/web/test_pydantic_models.py | 423 ++++++++++++++++++ .../controllers/web/test_remote_files.py | 147 ++++++ .../controllers/web/test_saved_message.py | 97 ++++ .../unit_tests/controllers/web/test_site.py | 126 ++++++ .../controllers/web/test_web_login.py | 114 ++++- .../controllers/web/test_web_passport.py | 192 ++++++++ .../controllers/web/test_workflow.py | 95 ++++ .../controllers/web/test_workflow_events.py | 127 ++++++ .../unit_tests/controllers/web/test_wraps.py | 393 ++++++++++++++++ 21 files changed, 2904 insertions(+), 2 deletions(-) create mode 100644 api/tests/unit_tests/controllers/web/__init__.py create mode 100644 api/tests/unit_tests/controllers/web/conftest.py create mode 100644 api/tests/unit_tests/controllers/web/test_app.py create mode 100644 api/tests/unit_tests/controllers/web/test_audio.py create mode 100644 api/tests/unit_tests/controllers/web/test_completion.py create mode 100644 api/tests/unit_tests/controllers/web/test_conversation.py create mode 100644 api/tests/unit_tests/controllers/web/test_error.py create mode 100644 api/tests/unit_tests/controllers/web/test_feature.py create mode 100644 api/tests/unit_tests/controllers/web/test_files.py create mode 100644 api/tests/unit_tests/controllers/web/test_message_endpoints.py create mode 100644 api/tests/unit_tests/controllers/web/test_passport.py create mode 100644 api/tests/unit_tests/controllers/web/test_pydantic_models.py create mode 100644 api/tests/unit_tests/controllers/web/test_remote_files.py create mode 100644 api/tests/unit_tests/controllers/web/test_saved_message.py create mode 100644 api/tests/unit_tests/controllers/web/test_site.py create mode 100644 api/tests/unit_tests/controllers/web/test_web_passport.py create mode 100644 api/tests/unit_tests/controllers/web/test_workflow.py create mode 100644 api/tests/unit_tests/controllers/web/test_workflow_events.py create mode 100644 api/tests/unit_tests/controllers/web/test_wraps.py diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index bbae1ce266..2b60691949 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -239,7 +239,7 @@ class MessageSuggestedQuestionApi(WebApiResource): def get(self, app_model, end_user, message_id): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: - raise NotCompletionAppError() + raise NotChatAppError() message_id = str(message_id) diff --git a/api/tests/unit_tests/controllers/web/__init__.py b/api/tests/unit_tests/controllers/web/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/web/conftest.py b/api/tests/unit_tests/controllers/web/conftest.py new file mode 100644 index 0000000000..274d78c9cf --- /dev/null +++ b/api/tests/unit_tests/controllers/web/conftest.py @@ -0,0 +1,85 @@ +"""Shared fixtures for controllers.web unit tests.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest +from flask import Flask + + +@pytest.fixture +def app() -> Flask: + """Minimal Flask app for request contexts.""" + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +class FakeSession: + """Stand-in for db.session that returns pre-seeded objects by model class name.""" + + def __init__(self, mapping: dict[str, Any] | None = None): + self._mapping: dict[str, Any] = mapping or {} + self._model_name: str | None = None + + def query(self, model: type) -> FakeSession: + self._model_name = model.__name__ + return self + + def where(self, *_args: object, **_kwargs: object) -> FakeSession: + return self + + def first(self) -> Any: + assert self._model_name is not None + return self._mapping.get(self._model_name) + + +class FakeDB: + """Minimal db stub exposing engine and session.""" + + def __init__(self, session: FakeSession | None = None): + self.session = session or FakeSession() + self.engine = object() + + +def make_app_model( + *, + app_id: str = "app-1", + tenant_id: str = "tenant-1", + mode: str = "chat", + enable_site: bool = True, + status: str = "normal", +) -> SimpleNamespace: + """Build a fake App model with common defaults.""" + tenant = SimpleNamespace( + id=tenant_id, + status="normal", + plan="basic", + custom_config_dict={}, + ) + return SimpleNamespace( + id=app_id, + tenant_id=tenant_id, + tenant=tenant, + mode=mode, + enable_site=enable_site, + status=status, + workflow=None, + app_model_config=None, + ) + + +def make_end_user( + *, + user_id: str = "end-user-1", + session_id: str = "session-1", + external_user_id: str = "ext-user-1", +) -> SimpleNamespace: + """Build a fake EndUser model with common defaults.""" + return SimpleNamespace( + id=user_id, + session_id=session_id, + external_user_id=external_user_id, + ) diff --git a/api/tests/unit_tests/controllers/web/test_app.py b/api/tests/unit_tests/controllers/web/test_app.py new file mode 100644 index 0000000000..ce7ae27188 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_app.py @@ -0,0 +1,165 @@ +"""Unit tests for controllers.web.app endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.app import AppAccessMode, AppMeta, AppParameterApi, AppWebAuthPermission +from controllers.web.error import AppUnavailableError + + +# --------------------------------------------------------------------------- +# AppParameterApi +# --------------------------------------------------------------------------- +class TestAppParameterApi: + def test_advanced_chat_mode_uses_workflow(self, app: Flask) -> None: + features_dict = {"opening_statement": "Hello"} + workflow = SimpleNamespace( + features_dict=features_dict, + user_input_form=lambda to_old_structure=False: [], + ) + app_model = SimpleNamespace(mode="advanced-chat", workflow=workflow) + + with ( + app.test_request_context("/parameters"), + patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params, + patch("controllers.web.app.fields.Parameters") as mock_fields, + ): + mock_fields.model_validate.return_value.model_dump.return_value = {"result": "ok"} + result = AppParameterApi().get(app_model, SimpleNamespace()) + + mock_params.assert_called_once_with(features_dict=features_dict, user_input_form=[]) + assert result == {"result": "ok"} + + def test_workflow_mode_uses_workflow(self, app: Flask) -> None: + features_dict = {} + workflow = SimpleNamespace( + features_dict=features_dict, + user_input_form=lambda to_old_structure=False: [{"var": "x"}], + ) + app_model = SimpleNamespace(mode="workflow", workflow=workflow) + + with ( + app.test_request_context("/parameters"), + patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params, + patch("controllers.web.app.fields.Parameters") as mock_fields, + ): + mock_fields.model_validate.return_value.model_dump.return_value = {} + AppParameterApi().get(app_model, SimpleNamespace()) + + mock_params.assert_called_once_with(features_dict=features_dict, user_input_form=[{"var": "x"}]) + + def test_advanced_chat_mode_no_workflow_raises(self, app: Flask) -> None: + app_model = SimpleNamespace(mode="advanced-chat", workflow=None) + with app.test_request_context("/parameters"): + with pytest.raises(AppUnavailableError): + AppParameterApi().get(app_model, SimpleNamespace()) + + def test_standard_mode_uses_app_model_config(self, app: Flask) -> None: + config = SimpleNamespace(to_dict=lambda: {"user_input_form": [{"var": "y"}], "key": "val"}) + app_model = SimpleNamespace(mode="chat", app_model_config=config) + + with ( + app.test_request_context("/parameters"), + patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params, + patch("controllers.web.app.fields.Parameters") as mock_fields, + ): + mock_fields.model_validate.return_value.model_dump.return_value = {} + AppParameterApi().get(app_model, SimpleNamespace()) + + call_kwargs = mock_params.call_args + assert call_kwargs.kwargs["user_input_form"] == [{"var": "y"}] + + def test_standard_mode_no_config_raises(self, app: Flask) -> None: + app_model = SimpleNamespace(mode="chat", app_model_config=None) + with app.test_request_context("/parameters"): + with pytest.raises(AppUnavailableError): + AppParameterApi().get(app_model, SimpleNamespace()) + + +# --------------------------------------------------------------------------- +# AppMeta +# --------------------------------------------------------------------------- +class TestAppMeta: + @patch("controllers.web.app.AppService") + def test_get_returns_meta(self, mock_service_cls: MagicMock, app: Flask) -> None: + mock_service_cls.return_value.get_app_meta.return_value = {"tool_icons": {}} + app_model = SimpleNamespace(id="app-1") + + with app.test_request_context("/meta"): + result = AppMeta().get(app_model, SimpleNamespace()) + + assert result == {"tool_icons": {}} + + +# --------------------------------------------------------------------------- +# AppAccessMode +# --------------------------------------------------------------------------- +class TestAppAccessMode: + @patch("controllers.web.app.FeatureService.get_system_features") + def test_returns_public_when_webapp_auth_disabled(self, mock_features: MagicMock, app: Flask) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + + with app.test_request_context("/webapp/access-mode?appId=app-1"): + result = AppAccessMode().get() + + assert result == {"accessMode": "public"} + + @patch("controllers.web.app.EnterpriseService.WebAppAuth.get_app_access_mode_by_id") + @patch("controllers.web.app.FeatureService.get_system_features") + def test_returns_access_mode_with_app_id( + self, mock_features: MagicMock, mock_access: MagicMock, app: Flask + ) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)) + mock_access.return_value = SimpleNamespace(access_mode="internal") + + with app.test_request_context("/webapp/access-mode?appId=app-1"): + result = AppAccessMode().get() + + assert result == {"accessMode": "internal"} + mock_access.assert_called_once_with("app-1") + + @patch("controllers.web.app.AppService.get_app_id_by_code", return_value="resolved-id") + @patch("controllers.web.app.EnterpriseService.WebAppAuth.get_app_access_mode_by_id") + @patch("controllers.web.app.FeatureService.get_system_features") + def test_resolves_app_code_to_id( + self, mock_features: MagicMock, mock_access: MagicMock, mock_resolve: MagicMock, app: Flask + ) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)) + mock_access.return_value = SimpleNamespace(access_mode="external") + + with app.test_request_context("/webapp/access-mode?appCode=code1"): + result = AppAccessMode().get() + + mock_resolve.assert_called_once_with("code1") + mock_access.assert_called_once_with("resolved-id") + assert result == {"accessMode": "external"} + + @patch("controllers.web.app.FeatureService.get_system_features") + def test_raises_when_no_app_id_or_code(self, mock_features: MagicMock, app: Flask) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)) + + with app.test_request_context("/webapp/access-mode"): + with pytest.raises(ValueError, match="appId or appCode"): + AppAccessMode().get() + + +# --------------------------------------------------------------------------- +# AppWebAuthPermission +# --------------------------------------------------------------------------- +class TestAppWebAuthPermission: + @patch("controllers.web.app.WebAppAuthService.is_app_require_permission_check", return_value=False) + def test_returns_true_when_no_permission_check_required(self, mock_check: MagicMock, app: Flask) -> None: + with app.test_request_context("/webapp/permission?appId=app-1", headers={"X-App-Code": "code1"}): + result = AppWebAuthPermission().get() + + assert result == {"result": True} + + def test_raises_when_missing_app_id(self, app: Flask) -> None: + with app.test_request_context("/webapp/permission", headers={"X-App-Code": "code1"}): + with pytest.raises(ValueError, match="appId"): + AppWebAuthPermission().get() diff --git a/api/tests/unit_tests/controllers/web/test_audio.py b/api/tests/unit_tests/controllers/web/test_audio.py new file mode 100644 index 0000000000..01f34345aa --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_audio.py @@ -0,0 +1,135 @@ +"""Unit tests for controllers.web.audio endpoints.""" + +from __future__ import annotations + +from io import BytesIO +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.audio import AudioApi, TextApi +from controllers.web.error import ( + AudioTooLargeError, + CompletionRequestError, + NoAudioUploadedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, + UnsupportedAudioTypeError, +) +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from dify_graph.model_runtime.errors.invoke import InvokeError +from services.errors.audio import ( + AudioTooLargeServiceError, + NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, + UnsupportedAudioTypeServiceError, +) + + +def _app_model() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="chat") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1", external_user_id="ext-1") + + +# --------------------------------------------------------------------------- +# AudioApi (audio-to-text) +# --------------------------------------------------------------------------- +class TestAudioApi: + @patch("controllers.web.audio.AudioService.transcript_asr", return_value={"text": "hello"}) + def test_happy_path(self, mock_asr: MagicMock, app: Flask) -> None: + app.config["RESTX_MASK_HEADER"] = "X-Fields" + data = {"file": (BytesIO(b"fake-audio"), "test.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + result = AudioApi().post(_app_model(), _end_user()) + + assert result == {"text": "hello"} + + @patch("controllers.web.audio.AudioService.transcript_asr", side_effect=NoAudioUploadedServiceError()) + def test_no_audio_uploaded(self, mock_asr: MagicMock, app: Flask) -> None: + data = {"file": (BytesIO(b""), "empty.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(NoAudioUploadedError): + AudioApi().post(_app_model(), _end_user()) + + @patch("controllers.web.audio.AudioService.transcript_asr", side_effect=AudioTooLargeServiceError("too big")) + def test_audio_too_large(self, mock_asr: MagicMock, app: Flask) -> None: + data = {"file": (BytesIO(b"big"), "big.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(AudioTooLargeError): + AudioApi().post(_app_model(), _end_user()) + + @patch("controllers.web.audio.AudioService.transcript_asr", side_effect=UnsupportedAudioTypeServiceError()) + def test_unsupported_type(self, mock_asr: MagicMock, app: Flask) -> None: + data = {"file": (BytesIO(b"bad"), "bad.xyz")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(UnsupportedAudioTypeError): + AudioApi().post(_app_model(), _end_user()) + + @patch( + "controllers.web.audio.AudioService.transcript_asr", + side_effect=ProviderNotSupportSpeechToTextServiceError(), + ) + def test_provider_not_support(self, mock_asr: MagicMock, app: Flask) -> None: + data = {"file": (BytesIO(b"x"), "x.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(ProviderNotSupportSpeechToTextError): + AudioApi().post(_app_model(), _end_user()) + + @patch( + "controllers.web.audio.AudioService.transcript_asr", + side_effect=ProviderTokenNotInitError(description="no token"), + ) + def test_provider_not_init(self, mock_asr: MagicMock, app: Flask) -> None: + data = {"file": (BytesIO(b"x"), "x.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(ProviderNotInitializeError): + AudioApi().post(_app_model(), _end_user()) + + @patch("controllers.web.audio.AudioService.transcript_asr", side_effect=QuotaExceededError()) + def test_quota_exceeded(self, mock_asr: MagicMock, app: Flask) -> None: + data = {"file": (BytesIO(b"x"), "x.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(ProviderQuotaExceededError): + AudioApi().post(_app_model(), _end_user()) + + @patch("controllers.web.audio.AudioService.transcript_asr", side_effect=ModelCurrentlyNotSupportError()) + def test_model_not_support(self, mock_asr: MagicMock, app: Flask) -> None: + data = {"file": (BytesIO(b"x"), "x.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + AudioApi().post(_app_model(), _end_user()) + + +# --------------------------------------------------------------------------- +# TextApi (text-to-audio) +# --------------------------------------------------------------------------- +class TestTextApi: + @patch("controllers.web.audio.AudioService.transcript_tts", return_value="audio-bytes") + @patch("controllers.web.audio.web_ns") + def test_happy_path(self, mock_ns: MagicMock, mock_tts: MagicMock, app: Flask) -> None: + mock_ns.payload = {"text": "hello", "voice": "alloy"} + + with app.test_request_context("/text-to-audio", method="POST"): + result = TextApi().post(_app_model(), _end_user()) + + assert result == "audio-bytes" + mock_tts.assert_called_once() + + @patch( + "controllers.web.audio.AudioService.transcript_tts", + side_effect=InvokeError(description="invoke failed"), + ) + @patch("controllers.web.audio.web_ns") + def test_invoke_error_mapped(self, mock_ns: MagicMock, mock_tts: MagicMock, app: Flask) -> None: + mock_ns.payload = {"text": "hello"} + + with app.test_request_context("/text-to-audio", method="POST"): + with pytest.raises(CompletionRequestError): + TextApi().post(_app_model(), _end_user()) diff --git a/api/tests/unit_tests/controllers/web/test_completion.py b/api/tests/unit_tests/controllers/web/test_completion.py new file mode 100644 index 0000000000..e88bcf2ae6 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_completion.py @@ -0,0 +1,161 @@ +"""Unit tests for controllers.web.completion endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi +from controllers.web.error import ( + CompletionRequestError, + NotChatAppError, + NotCompletionAppError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from dify_graph.model_runtime.errors.invoke import InvokeError + + +def _completion_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="completion") + + +def _chat_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="chat") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# --------------------------------------------------------------------------- +# CompletionApi +# --------------------------------------------------------------------------- +class TestCompletionApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/completion-messages", method="POST"): + with pytest.raises(NotCompletionAppError): + CompletionApi().post(_chat_app(), _end_user()) + + @patch("controllers.web.completion.helper.compact_generate_response", return_value={"answer": "hi"}) + @patch("controllers.web.completion.AppGenerateService.generate") + @patch("controllers.web.completion.web_ns") + def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}, "query": "test"} + mock_gen.return_value = "response-obj" + + with app.test_request_context("/completion-messages", method="POST"): + result = CompletionApi().post(_completion_app(), _end_user()) + + assert result == {"answer": "hi"} + + @patch( + "controllers.web.completion.AppGenerateService.generate", + side_effect=ProviderTokenNotInitError(description="not init"), + ) + @patch("controllers.web.completion.web_ns") + def test_provider_not_init_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}} + + with app.test_request_context("/completion-messages", method="POST"): + with pytest.raises(ProviderNotInitializeError): + CompletionApi().post(_completion_app(), _end_user()) + + @patch( + "controllers.web.completion.AppGenerateService.generate", + side_effect=QuotaExceededError(), + ) + @patch("controllers.web.completion.web_ns") + def test_quota_exceeded_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}} + + with app.test_request_context("/completion-messages", method="POST"): + with pytest.raises(ProviderQuotaExceededError): + CompletionApi().post(_completion_app(), _end_user()) + + @patch( + "controllers.web.completion.AppGenerateService.generate", + side_effect=ModelCurrentlyNotSupportError(), + ) + @patch("controllers.web.completion.web_ns") + def test_model_not_support_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}} + + with app.test_request_context("/completion-messages", method="POST"): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + CompletionApi().post(_completion_app(), _end_user()) + + +# --------------------------------------------------------------------------- +# CompletionStopApi +# --------------------------------------------------------------------------- +class TestCompletionStopApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/completion-messages/task-1/stop", method="POST"): + with pytest.raises(NotCompletionAppError): + CompletionStopApi().post(_chat_app(), _end_user(), "task-1") + + @patch("controllers.web.completion.AppTaskService.stop_task") + def test_stop_success(self, mock_stop: MagicMock, app: Flask) -> None: + with app.test_request_context("/completion-messages/task-1/stop", method="POST"): + result, status = CompletionStopApi().post(_completion_app(), _end_user(), "task-1") + + assert status == 200 + assert result == {"result": "success"} + + +# --------------------------------------------------------------------------- +# ChatApi +# --------------------------------------------------------------------------- +class TestChatApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/chat-messages", method="POST"): + with pytest.raises(NotChatAppError): + ChatApi().post(_completion_app(), _end_user()) + + @patch("controllers.web.completion.helper.compact_generate_response", return_value={"answer": "reply"}) + @patch("controllers.web.completion.AppGenerateService.generate") + @patch("controllers.web.completion.web_ns") + def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}, "query": "hi"} + mock_gen.return_value = "response" + + with app.test_request_context("/chat-messages", method="POST"): + result = ChatApi().post(_chat_app(), _end_user()) + + assert result == {"answer": "reply"} + + @patch( + "controllers.web.completion.AppGenerateService.generate", + side_effect=InvokeError(description="rate limit"), + ) + @patch("controllers.web.completion.web_ns") + def test_invoke_error_mapped(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}, "query": "x"} + + with app.test_request_context("/chat-messages", method="POST"): + with pytest.raises(CompletionRequestError): + ChatApi().post(_chat_app(), _end_user()) + + +# --------------------------------------------------------------------------- +# ChatStopApi +# --------------------------------------------------------------------------- +class TestChatStopApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/chat-messages/task-1/stop", method="POST"): + with pytest.raises(NotChatAppError): + ChatStopApi().post(_completion_app(), _end_user(), "task-1") + + @patch("controllers.web.completion.AppTaskService.stop_task") + def test_stop_success(self, mock_stop: MagicMock, app: Flask) -> None: + with app.test_request_context("/chat-messages/task-1/stop", method="POST"): + result, status = ChatStopApi().post(_chat_app(), _end_user(), "task-1") + + assert status == 200 + assert result == {"result": "success"} diff --git a/api/tests/unit_tests/controllers/web/test_conversation.py b/api/tests/unit_tests/controllers/web/test_conversation.py new file mode 100644 index 0000000000..e5adbbbf66 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_conversation.py @@ -0,0 +1,183 @@ +"""Unit tests for controllers.web.conversation endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from flask import Flask +from werkzeug.exceptions import NotFound + +from controllers.web.conversation import ( + ConversationApi, + ConversationListApi, + ConversationPinApi, + ConversationRenameApi, + ConversationUnPinApi, +) +from controllers.web.error import NotChatAppError +from services.errors.conversation import ConversationNotExistsError + + +def _chat_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="chat") + + +def _completion_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="completion") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# --------------------------------------------------------------------------- +# ConversationListApi +# --------------------------------------------------------------------------- +class TestConversationListApi: + def test_non_chat_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/conversations"): + with pytest.raises(NotChatAppError): + ConversationListApi().get(_completion_app(), _end_user()) + + @patch("controllers.web.conversation.WebConversationService.pagination_by_last_id") + @patch("controllers.web.conversation.db") + def test_happy_path(self, mock_db: MagicMock, mock_paginate: MagicMock, app: Flask) -> None: + conv_id = str(uuid4()) + conv = SimpleNamespace( + id=conv_id, + name="Test", + inputs={}, + status="normal", + introduction="", + created_at=1700000000, + updated_at=1700000000, + ) + mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[conv]) + mock_db.engine = "engine" + + session_mock = MagicMock() + session_ctx = MagicMock() + session_ctx.__enter__ = MagicMock(return_value=session_mock) + session_ctx.__exit__ = MagicMock(return_value=False) + + with ( + app.test_request_context("/conversations?limit=20"), + patch("controllers.web.conversation.Session", return_value=session_ctx), + ): + result = ConversationListApi().get(_chat_app(), _end_user()) + + assert result["limit"] == 20 + assert result["has_more"] is False + + +# --------------------------------------------------------------------------- +# ConversationApi (delete) +# --------------------------------------------------------------------------- +class TestConversationApi: + def test_non_chat_mode_raises(self, app: Flask) -> None: + with app.test_request_context(f"/conversations/{uuid4()}"): + with pytest.raises(NotChatAppError): + ConversationApi().delete(_completion_app(), _end_user(), uuid4()) + + @patch("controllers.web.conversation.ConversationService.delete") + def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None: + c_id = uuid4() + with app.test_request_context(f"/conversations/{c_id}"): + result, status = ConversationApi().delete(_chat_app(), _end_user(), c_id) + + assert status == 204 + assert result["result"] == "success" + + @patch("controllers.web.conversation.ConversationService.delete", side_effect=ConversationNotExistsError()) + def test_delete_not_found(self, mock_delete: MagicMock, app: Flask) -> None: + c_id = uuid4() + with app.test_request_context(f"/conversations/{c_id}"): + with pytest.raises(NotFound, match="Conversation Not Exists"): + ConversationApi().delete(_chat_app(), _end_user(), c_id) + + +# --------------------------------------------------------------------------- +# ConversationRenameApi +# --------------------------------------------------------------------------- +class TestConversationRenameApi: + def test_non_chat_mode_raises(self, app: Flask) -> None: + with app.test_request_context(f"/conversations/{uuid4()}/name", method="POST", json={"name": "x"}): + with pytest.raises(NotChatAppError): + ConversationRenameApi().post(_completion_app(), _end_user(), uuid4()) + + @patch("controllers.web.conversation.ConversationService.rename") + @patch("controllers.web.conversation.web_ns") + def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None: + c_id = uuid4() + mock_ns.payload = {"name": "New Name", "auto_generate": False} + conv = SimpleNamespace( + id=str(c_id), + name="New Name", + inputs={}, + status="normal", + introduction="", + created_at=1700000000, + updated_at=1700000000, + ) + mock_rename.return_value = conv + + with app.test_request_context(f"/conversations/{c_id}/name", method="POST", json={"name": "New Name"}): + result = ConversationRenameApi().post(_chat_app(), _end_user(), c_id) + + assert result["name"] == "New Name" + + @patch( + "controllers.web.conversation.ConversationService.rename", + side_effect=ConversationNotExistsError(), + ) + @patch("controllers.web.conversation.web_ns") + def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None: + c_id = uuid4() + mock_ns.payload = {"name": "X", "auto_generate": False} + + with app.test_request_context(f"/conversations/{c_id}/name", method="POST", json={"name": "X"}): + with pytest.raises(NotFound, match="Conversation Not Exists"): + ConversationRenameApi().post(_chat_app(), _end_user(), c_id) + + +# --------------------------------------------------------------------------- +# ConversationPinApi / ConversationUnPinApi +# --------------------------------------------------------------------------- +class TestConversationPinApi: + def test_non_chat_mode_raises(self, app: Flask) -> None: + with app.test_request_context(f"/conversations/{uuid4()}/pin", method="PATCH"): + with pytest.raises(NotChatAppError): + ConversationPinApi().patch(_completion_app(), _end_user(), uuid4()) + + @patch("controllers.web.conversation.WebConversationService.pin") + def test_pin_success(self, mock_pin: MagicMock, app: Flask) -> None: + c_id = uuid4() + with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"): + result = ConversationPinApi().patch(_chat_app(), _end_user(), c_id) + + assert result["result"] == "success" + + @patch("controllers.web.conversation.WebConversationService.pin", side_effect=ConversationNotExistsError()) + def test_pin_not_found(self, mock_pin: MagicMock, app: Flask) -> None: + c_id = uuid4() + with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"): + with pytest.raises(NotFound): + ConversationPinApi().patch(_chat_app(), _end_user(), c_id) + + +class TestConversationUnPinApi: + def test_non_chat_mode_raises(self, app: Flask) -> None: + with app.test_request_context(f"/conversations/{uuid4()}/unpin", method="PATCH"): + with pytest.raises(NotChatAppError): + ConversationUnPinApi().patch(_completion_app(), _end_user(), uuid4()) + + @patch("controllers.web.conversation.WebConversationService.unpin") + def test_unpin_success(self, mock_unpin: MagicMock, app: Flask) -> None: + c_id = uuid4() + with app.test_request_context(f"/conversations/{c_id}/unpin", method="PATCH"): + result = ConversationUnPinApi().patch(_chat_app(), _end_user(), c_id) + + assert result["result"] == "success" diff --git a/api/tests/unit_tests/controllers/web/test_error.py b/api/tests/unit_tests/controllers/web/test_error.py new file mode 100644 index 0000000000..0387d002ba --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_error.py @@ -0,0 +1,75 @@ +"""Unit tests for controllers.web.error HTTP exception classes.""" + +from __future__ import annotations + +import pytest + +from controllers.web.error import ( + AppMoreLikeThisDisabledError, + AppSuggestedQuestionsAfterAnswerDisabledError, + AppUnavailableError, + AudioTooLargeError, + CompletionRequestError, + ConversationCompletedError, + InvalidArgumentError, + InvokeRateLimitError, + NoAudioUploadedError, + NotChatAppError, + NotCompletionAppError, + NotFoundError, + NotWorkflowAppError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, + UnsupportedAudioTypeError, + WebAppAuthAccessDeniedError, + WebAppAuthRequiredError, + WebFormRateLimitExceededError, +) + +_ERROR_SPECS: list[tuple[type, str, int]] = [ + (AppUnavailableError, "app_unavailable", 400), + (NotCompletionAppError, "not_completion_app", 400), + (NotChatAppError, "not_chat_app", 400), + (NotWorkflowAppError, "not_workflow_app", 400), + (ConversationCompletedError, "conversation_completed", 400), + (ProviderNotInitializeError, "provider_not_initialize", 400), + (ProviderQuotaExceededError, "provider_quota_exceeded", 400), + (ProviderModelCurrentlyNotSupportError, "model_currently_not_support", 400), + (CompletionRequestError, "completion_request_error", 400), + (AppMoreLikeThisDisabledError, "app_more_like_this_disabled", 403), + (AppSuggestedQuestionsAfterAnswerDisabledError, "app_suggested_questions_after_answer_disabled", 403), + (NoAudioUploadedError, "no_audio_uploaded", 400), + (AudioTooLargeError, "audio_too_large", 413), + (UnsupportedAudioTypeError, "unsupported_audio_type", 415), + (ProviderNotSupportSpeechToTextError, "provider_not_support_speech_to_text", 400), + (WebAppAuthRequiredError, "web_sso_auth_required", 401), + (WebAppAuthAccessDeniedError, "web_app_access_denied", 401), + (InvokeRateLimitError, "rate_limit_error", 429), + (WebFormRateLimitExceededError, "web_form_rate_limit_exceeded", 429), + (NotFoundError, "not_found", 404), + (InvalidArgumentError, "invalid_param", 400), +] + + +@pytest.mark.parametrize( + ("cls", "expected_code", "expected_status"), + _ERROR_SPECS, + ids=[cls.__name__ for cls, _, _ in _ERROR_SPECS], +) +def test_error_class_attributes(cls: type, expected_code: str, expected_status: int) -> None: + """Each error class exposes the correct error_code and HTTP status code.""" + assert cls.error_code == expected_code + assert cls.code == expected_status + + +def test_error_classes_have_description() -> None: + """Every error class has a description (string or None for generic errors).""" + # NotFoundError and InvalidArgumentError use None description by design + _NO_DESCRIPTION = {NotFoundError, InvalidArgumentError} + for cls, _, _ in _ERROR_SPECS: + if cls in _NO_DESCRIPTION: + continue + assert isinstance(cls.description, str), f"{cls.__name__} missing description" + assert len(cls.description) > 0, f"{cls.__name__} has empty description" diff --git a/api/tests/unit_tests/controllers/web/test_feature.py b/api/tests/unit_tests/controllers/web/test_feature.py new file mode 100644 index 0000000000..fe45d5f059 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_feature.py @@ -0,0 +1,38 @@ +"""Unit tests for controllers.web.feature endpoints.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from flask import Flask + +from controllers.web.feature import SystemFeatureApi + + +class TestSystemFeatureApi: + @patch("controllers.web.feature.FeatureService.get_system_features") + def test_returns_system_features(self, mock_features: MagicMock, app: Flask) -> None: + mock_model = MagicMock() + mock_model.model_dump.return_value = {"sso_enforced_for_signin": False, "webapp_auth": {"enabled": False}} + mock_features.return_value = mock_model + + with app.test_request_context("/system-features"): + result = SystemFeatureApi().get() + + assert result == {"sso_enforced_for_signin": False, "webapp_auth": {"enabled": False}} + mock_features.assert_called_once() + + @patch("controllers.web.feature.FeatureService.get_system_features") + def test_unauthenticated_access(self, mock_features: MagicMock, app: Flask) -> None: + """SystemFeatureApi is unauthenticated by design — no WebApiResource decorator.""" + mock_model = MagicMock() + mock_model.model_dump.return_value = {} + mock_features.return_value = mock_model + + # Verify it's a bare Resource, not WebApiResource + from flask_restx import Resource + + from controllers.web.wraps import WebApiResource + + assert issubclass(SystemFeatureApi, Resource) + assert not issubclass(SystemFeatureApi, WebApiResource) diff --git a/api/tests/unit_tests/controllers/web/test_files.py b/api/tests/unit_tests/controllers/web/test_files.py new file mode 100644 index 0000000000..a3921b0373 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_files.py @@ -0,0 +1,89 @@ +"""Unit tests for controllers.web.files endpoints.""" + +from __future__ import annotations + +from io import BytesIO +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.common.errors import ( + FilenameNotExistsError, + FileTooLargeError, + NoFileUploadedError, + TooManyFilesError, +) +from controllers.web.files import FileApi + + +def _app_model() -> SimpleNamespace: + return SimpleNamespace(id="app-1") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +class TestFileApi: + def test_no_file_uploaded(self, app: Flask) -> None: + with app.test_request_context("/files/upload", method="POST", content_type="multipart/form-data"): + with pytest.raises(NoFileUploadedError): + FileApi().post(_app_model(), _end_user()) + + def test_too_many_files(self, app: Flask) -> None: + data = { + "file": (BytesIO(b"a"), "a.txt"), + "file2": (BytesIO(b"b"), "b.txt"), + } + with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"): + # Now has "file" key but len(request.files) > 1 + with pytest.raises(TooManyFilesError): + FileApi().post(_app_model(), _end_user()) + + def test_filename_missing(self, app: Flask) -> None: + data = {"file": (BytesIO(b"content"), "")} + with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(FilenameNotExistsError): + FileApi().post(_app_model(), _end_user()) + + @patch("controllers.web.files.FileService") + @patch("controllers.web.files.db") + def test_upload_success(self, mock_db: MagicMock, mock_file_svc_cls: MagicMock, app: Flask) -> None: + mock_db.engine = "engine" + from datetime import datetime + + upload_file = SimpleNamespace( + id="file-1", + name="test.txt", + size=100, + extension="txt", + mime_type="text/plain", + created_by="eu-1", + created_at=datetime(2024, 1, 1), + ) + mock_file_svc_cls.return_value.upload_file.return_value = upload_file + + data = {"file": (BytesIO(b"content"), "test.txt")} + with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"): + result, status = FileApi().post(_app_model(), _end_user()) + + assert status == 201 + assert result["id"] == "file-1" + assert result["name"] == "test.txt" + + @patch("controllers.web.files.FileService") + @patch("controllers.web.files.db") + def test_file_too_large_from_service(self, mock_db: MagicMock, mock_file_svc_cls: MagicMock, app: Flask) -> None: + import services.errors.file + + mock_db.engine = "engine" + mock_file_svc_cls.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError( + description="max 10MB" + ) + + data = {"file": (BytesIO(b"big"), "big.txt")} + with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(FileTooLargeError): + FileApi().post(_app_model(), _end_user()) diff --git a/api/tests/unit_tests/controllers/web/test_message_endpoints.py b/api/tests/unit_tests/controllers/web/test_message_endpoints.py new file mode 100644 index 0000000000..89ab93d8d4 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_message_endpoints.py @@ -0,0 +1,156 @@ +"""Unit tests for controllers.web.message — feedback, more-like-this, suggested questions.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from flask import Flask +from werkzeug.exceptions import NotFound + +from controllers.web.error import ( + AppMoreLikeThisDisabledError, + NotChatAppError, + NotCompletionAppError, +) +from controllers.web.message import ( + MessageFeedbackApi, + MessageMoreLikeThisApi, + MessageSuggestedQuestionApi, +) +from services.errors.app import MoreLikeThisDisabledError +from services.errors.message import MessageNotExistsError + + +def _chat_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="chat") + + +def _completion_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="completion") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# --------------------------------------------------------------------------- +# MessageFeedbackApi +# --------------------------------------------------------------------------- +class TestMessageFeedbackApi: + @patch("controllers.web.message.MessageService.create_feedback") + @patch("controllers.web.message.web_ns") + def test_feedback_success(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None: + mock_ns.payload = {"rating": "like", "content": "great"} + msg_id = uuid4() + + with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"): + result = MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id) + + assert result == {"result": "success"} + mock_create.assert_called_once() + + @patch("controllers.web.message.MessageService.create_feedback") + @patch("controllers.web.message.web_ns") + def test_feedback_null_rating(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None: + mock_ns.payload = {"rating": None} + msg_id = uuid4() + + with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"): + result = MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id) + + assert result == {"result": "success"} + + @patch( + "controllers.web.message.MessageService.create_feedback", + side_effect=MessageNotExistsError(), + ) + @patch("controllers.web.message.web_ns") + def test_feedback_message_not_found(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None: + mock_ns.payload = {"rating": "dislike"} + msg_id = uuid4() + + with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"): + with pytest.raises(NotFound, match="Message Not Exists"): + MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id) + + +# --------------------------------------------------------------------------- +# MessageMoreLikeThisApi +# --------------------------------------------------------------------------- +class TestMessageMoreLikeThisApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"): + with pytest.raises(NotCompletionAppError): + MessageMoreLikeThisApi().get(_chat_app(), _end_user(), msg_id) + + @patch("controllers.web.message.helper.compact_generate_response", return_value={"answer": "similar"}) + @patch("controllers.web.message.AppGenerateService.generate_more_like_this") + def test_happy_path(self, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None: + msg_id = uuid4() + mock_gen.return_value = "response" + + with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"): + result = MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id) + + assert result == {"answer": "similar"} + + @patch( + "controllers.web.message.AppGenerateService.generate_more_like_this", + side_effect=MessageNotExistsError(), + ) + def test_message_not_found(self, mock_gen: MagicMock, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"): + with pytest.raises(NotFound, match="Message Not Exists"): + MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id) + + @patch( + "controllers.web.message.AppGenerateService.generate_more_like_this", + side_effect=MoreLikeThisDisabledError(), + ) + def test_feature_disabled(self, mock_gen: MagicMock, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"): + with pytest.raises(AppMoreLikeThisDisabledError): + MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id) + + +# --------------------------------------------------------------------------- +# MessageSuggestedQuestionApi +# --------------------------------------------------------------------------- +class TestMessageSuggestedQuestionApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/messages/{msg_id}/suggested-questions"): + with pytest.raises(NotChatAppError): + MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id) + + def test_wrong_mode_raises(self, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/messages/{msg_id}/suggested-questions"): + with pytest.raises(NotChatAppError): + MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id) + + @patch("controllers.web.message.MessageService.get_suggested_questions_after_answer") + def test_happy_path(self, mock_suggest: MagicMock, app: Flask) -> None: + msg_id = uuid4() + mock_suggest.return_value = ["What about X?", "Tell me more about Y."] + + with app.test_request_context(f"/messages/{msg_id}/suggested-questions"): + result = MessageSuggestedQuestionApi().get(_chat_app(), _end_user(), msg_id) + + assert result["data"] == ["What about X?", "Tell me more about Y."] + + @patch( + "controllers.web.message.MessageService.get_suggested_questions_after_answer", + side_effect=MessageNotExistsError(), + ) + def test_message_not_found(self, mock_suggest: MagicMock, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/messages/{msg_id}/suggested-questions"): + with pytest.raises(NotFound, match="Message not found"): + MessageSuggestedQuestionApi().get(_chat_app(), _end_user(), msg_id) diff --git a/api/tests/unit_tests/controllers/web/test_passport.py b/api/tests/unit_tests/controllers/web/test_passport.py new file mode 100644 index 0000000000..58d58626b2 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_passport.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +from werkzeug.exceptions import NotFound, Unauthorized + +from controllers.web.error import WebAppAuthRequiredError +from controllers.web.passport import ( + PassportService, + decode_enterprise_webapp_user_id, + exchange_token_for_existing_web_user, + generate_session_id, +) +from services.webapp_auth_service import WebAppAuthType + + +def test_decode_enterprise_webapp_user_id_none() -> None: + assert decode_enterprise_webapp_user_id(None) is None + + +def test_decode_enterprise_webapp_user_id_invalid_source(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(PassportService, "verify", lambda *_args, **_kwargs: {"token_source": "bad"}) + with pytest.raises(Unauthorized): + decode_enterprise_webapp_user_id("token") + + +def test_decode_enterprise_webapp_user_id_valid(monkeypatch: pytest.MonkeyPatch) -> None: + decoded = {"token_source": "webapp_login_token", "user_id": "u1"} + monkeypatch.setattr(PassportService, "verify", lambda *_args, **_kwargs: decoded) + assert decode_enterprise_webapp_user_id("token") == decoded + + +def test_exchange_token_public_flow(monkeypatch: pytest.MonkeyPatch) -> None: + site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal") + app_model = SimpleNamespace(id="a1", status="normal", enable_site=True) + + def _scalar_side_effect(*_args, **_kwargs): + if not hasattr(_scalar_side_effect, "calls"): + _scalar_side_effect.calls = 0 + _scalar_side_effect.calls += 1 + return site if _scalar_side_effect.calls == 1 else app_model + + db_session = SimpleNamespace(scalar=_scalar_side_effect) + monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session)) + monkeypatch.setattr("controllers.web.passport._exchange_for_public_app_token", lambda *_args, **_kwargs: "resp") + + decoded = {"auth_type": "public"} + result = exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.PUBLIC) + assert result == "resp" + + +def test_exchange_token_requires_external(monkeypatch: pytest.MonkeyPatch) -> None: + site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal") + app_model = SimpleNamespace(id="a1", status="normal", enable_site=True) + + def _scalar_side_effect(*_args, **_kwargs): + if not hasattr(_scalar_side_effect, "calls"): + _scalar_side_effect.calls = 0 + _scalar_side_effect.calls += 1 + return site if _scalar_side_effect.calls == 1 else app_model + + db_session = SimpleNamespace(scalar=_scalar_side_effect) + monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session)) + + decoded = {"auth_type": "internal"} + with pytest.raises(WebAppAuthRequiredError): + exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.EXTERNAL) + + +def test_exchange_token_missing_session_id(monkeypatch: pytest.MonkeyPatch) -> None: + site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal") + app_model = SimpleNamespace(id="a1", status="normal", enable_site=True, tenant_id="t1") + + def _scalar_side_effect(*_args, **_kwargs): + if not hasattr(_scalar_side_effect, "calls"): + _scalar_side_effect.calls = 0 + _scalar_side_effect.calls += 1 + if _scalar_side_effect.calls == 1: + return site + if _scalar_side_effect.calls == 2: + return app_model + return None + + db_session = SimpleNamespace(scalar=_scalar_side_effect, add=lambda *_a, **_k: None, commit=lambda: None) + monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session)) + + decoded = {"auth_type": "internal"} + with pytest.raises(NotFound): + exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.INTERNAL) + + +def test_generate_session_id(monkeypatch: pytest.MonkeyPatch) -> None: + counts = [1, 0] + + def _scalar(*_args, **_kwargs): + return counts.pop(0) + + db_session = SimpleNamespace(scalar=_scalar) + monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session)) + + session_id = generate_session_id() + assert session_id diff --git a/api/tests/unit_tests/controllers/web/test_pydantic_models.py b/api/tests/unit_tests/controllers/web/test_pydantic_models.py new file mode 100644 index 0000000000..dcf8133712 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_pydantic_models.py @@ -0,0 +1,423 @@ +"""Unit tests for Pydantic models defined in controllers.web modules. + +Covers validation logic, field defaults, constraints, and custom validators +for all ~15 Pydantic models across the web controller layer. +""" + +from __future__ import annotations + +from uuid import uuid4 + +import pytest +from pydantic import ValidationError + +# --------------------------------------------------------------------------- +# app.py models +# --------------------------------------------------------------------------- +from controllers.web.app import AppAccessModeQuery + + +class TestAppAccessModeQuery: + def test_alias_resolution(self) -> None: + q = AppAccessModeQuery.model_validate({"appId": "abc", "appCode": "xyz"}) + assert q.app_id == "abc" + assert q.app_code == "xyz" + + def test_defaults_to_none(self) -> None: + q = AppAccessModeQuery.model_validate({}) + assert q.app_id is None + assert q.app_code is None + + def test_accepts_snake_case(self) -> None: + q = AppAccessModeQuery(app_id="id1", app_code="code1") + assert q.app_id == "id1" + assert q.app_code == "code1" + + +# --------------------------------------------------------------------------- +# audio.py models +# --------------------------------------------------------------------------- +from controllers.web.audio import TextToAudioPayload + + +class TestTextToAudioPayload: + def test_defaults(self) -> None: + p = TextToAudioPayload.model_validate({}) + assert p.message_id is None + assert p.voice is None + assert p.text is None + assert p.streaming is None + + def test_valid_uuid_message_id(self) -> None: + uid = str(uuid4()) + p = TextToAudioPayload(message_id=uid) + assert p.message_id == uid + + def test_none_message_id_passthrough(self) -> None: + p = TextToAudioPayload(message_id=None) + assert p.message_id is None + + def test_invalid_uuid_message_id(self) -> None: + with pytest.raises(ValidationError, match="not a valid uuid"): + TextToAudioPayload(message_id="not-a-uuid") + + +# --------------------------------------------------------------------------- +# completion.py models +# --------------------------------------------------------------------------- +from controllers.web.completion import ChatMessagePayload, CompletionMessagePayload + + +class TestCompletionMessagePayload: + def test_defaults(self) -> None: + p = CompletionMessagePayload(inputs={}) + assert p.query == "" + assert p.files is None + assert p.response_mode is None + assert p.retriever_from == "web_app" + + def test_accepts_full_payload(self) -> None: + p = CompletionMessagePayload( + inputs={"key": "val"}, + query="test", + files=[{"id": "f1"}], + response_mode="streaming", + ) + assert p.response_mode == "streaming" + assert p.files == [{"id": "f1"}] + + def test_invalid_response_mode(self) -> None: + with pytest.raises(ValidationError): + CompletionMessagePayload(inputs={}, response_mode="invalid") + + +class TestChatMessagePayload: + def test_valid_uuid_fields(self) -> None: + cid = str(uuid4()) + pid = str(uuid4()) + p = ChatMessagePayload(inputs={}, query="hi", conversation_id=cid, parent_message_id=pid) + assert p.conversation_id == cid + assert p.parent_message_id == pid + + def test_none_uuid_fields(self) -> None: + p = ChatMessagePayload(inputs={}, query="hi") + assert p.conversation_id is None + assert p.parent_message_id is None + + def test_invalid_conversation_id(self) -> None: + with pytest.raises(ValidationError, match="not a valid uuid"): + ChatMessagePayload(inputs={}, query="hi", conversation_id="bad") + + def test_invalid_parent_message_id(self) -> None: + with pytest.raises(ValidationError, match="not a valid uuid"): + ChatMessagePayload(inputs={}, query="hi", parent_message_id="bad") + + def test_query_required(self) -> None: + with pytest.raises(ValidationError): + ChatMessagePayload(inputs={}) + + +# --------------------------------------------------------------------------- +# conversation.py models +# --------------------------------------------------------------------------- +from controllers.web.conversation import ConversationListQuery, ConversationRenamePayload + + +class TestConversationListQuery: + def test_defaults(self) -> None: + q = ConversationListQuery() + assert q.last_id is None + assert q.limit == 20 + assert q.pinned is None + assert q.sort_by == "-updated_at" + + def test_limit_lower_bound(self) -> None: + with pytest.raises(ValidationError): + ConversationListQuery(limit=0) + + def test_limit_upper_bound(self) -> None: + with pytest.raises(ValidationError): + ConversationListQuery(limit=101) + + def test_limit_boundaries_valid(self) -> None: + assert ConversationListQuery(limit=1).limit == 1 + assert ConversationListQuery(limit=100).limit == 100 + + def test_valid_sort_by_options(self) -> None: + for opt in ("created_at", "-created_at", "updated_at", "-updated_at"): + assert ConversationListQuery(sort_by=opt).sort_by == opt + + def test_invalid_sort_by(self) -> None: + with pytest.raises(ValidationError): + ConversationListQuery(sort_by="invalid") + + def test_valid_last_id(self) -> None: + uid = str(uuid4()) + assert ConversationListQuery(last_id=uid).last_id == uid + + def test_invalid_last_id(self) -> None: + with pytest.raises(ValidationError, match="not a valid uuid"): + ConversationListQuery(last_id="not-uuid") + + +class TestConversationRenamePayload: + def test_auto_generate_true_no_name_required(self) -> None: + p = ConversationRenamePayload(auto_generate=True) + assert p.name is None + + def test_auto_generate_false_requires_name(self) -> None: + with pytest.raises(ValidationError, match="name is required"): + ConversationRenamePayload(auto_generate=False) + + def test_auto_generate_false_blank_name_rejected(self) -> None: + with pytest.raises(ValidationError, match="name is required"): + ConversationRenamePayload(auto_generate=False, name=" ") + + def test_auto_generate_false_with_valid_name(self) -> None: + p = ConversationRenamePayload(auto_generate=False, name="My Chat") + assert p.name == "My Chat" + + def test_defaults(self) -> None: + p = ConversationRenamePayload(name="test") + assert p.auto_generate is False + assert p.name == "test" + + +# --------------------------------------------------------------------------- +# message.py models +# --------------------------------------------------------------------------- +from controllers.web.message import MessageFeedbackPayload, MessageListQuery, MessageMoreLikeThisQuery + + +class TestMessageListQuery: + def test_valid_query(self) -> None: + cid = str(uuid4()) + q = MessageListQuery(conversation_id=cid) + assert q.conversation_id == cid + assert q.first_id is None + assert q.limit == 20 + + def test_invalid_conversation_id(self) -> None: + with pytest.raises(ValidationError, match="not a valid uuid"): + MessageListQuery(conversation_id="bad") + + def test_limit_bounds(self) -> None: + cid = str(uuid4()) + with pytest.raises(ValidationError): + MessageListQuery(conversation_id=cid, limit=0) + with pytest.raises(ValidationError): + MessageListQuery(conversation_id=cid, limit=101) + + def test_valid_first_id(self) -> None: + cid = str(uuid4()) + fid = str(uuid4()) + q = MessageListQuery(conversation_id=cid, first_id=fid) + assert q.first_id == fid + + def test_invalid_first_id(self) -> None: + cid = str(uuid4()) + with pytest.raises(ValidationError, match="not a valid uuid"): + MessageListQuery(conversation_id=cid, first_id="invalid") + + +class TestMessageFeedbackPayload: + def test_defaults(self) -> None: + p = MessageFeedbackPayload() + assert p.rating is None + assert p.content is None + + def test_valid_ratings(self) -> None: + assert MessageFeedbackPayload(rating="like").rating == "like" + assert MessageFeedbackPayload(rating="dislike").rating == "dislike" + + def test_invalid_rating(self) -> None: + with pytest.raises(ValidationError): + MessageFeedbackPayload(rating="neutral") + + +class TestMessageMoreLikeThisQuery: + def test_valid_modes(self) -> None: + assert MessageMoreLikeThisQuery(response_mode="blocking").response_mode == "blocking" + assert MessageMoreLikeThisQuery(response_mode="streaming").response_mode == "streaming" + + def test_invalid_mode(self) -> None: + with pytest.raises(ValidationError): + MessageMoreLikeThisQuery(response_mode="invalid") + + def test_required(self) -> None: + with pytest.raises(ValidationError): + MessageMoreLikeThisQuery() + + +# --------------------------------------------------------------------------- +# remote_files.py models +# --------------------------------------------------------------------------- +from controllers.web.remote_files import RemoteFileUploadPayload + + +class TestRemoteFileUploadPayload: + def test_valid_url(self) -> None: + p = RemoteFileUploadPayload(url="https://example.com/file.pdf") + assert str(p.url) == "https://example.com/file.pdf" + + def test_invalid_url(self) -> None: + with pytest.raises(ValidationError): + RemoteFileUploadPayload(url="not-a-url") + + def test_url_required(self) -> None: + with pytest.raises(ValidationError): + RemoteFileUploadPayload() + + +# --------------------------------------------------------------------------- +# saved_message.py models +# --------------------------------------------------------------------------- +from controllers.web.saved_message import SavedMessageCreatePayload, SavedMessageListQuery + + +class TestSavedMessageListQuery: + def test_defaults(self) -> None: + q = SavedMessageListQuery() + assert q.last_id is None + assert q.limit == 20 + + def test_limit_bounds(self) -> None: + with pytest.raises(ValidationError): + SavedMessageListQuery(limit=0) + with pytest.raises(ValidationError): + SavedMessageListQuery(limit=101) + + def test_valid_last_id(self) -> None: + uid = str(uuid4()) + q = SavedMessageListQuery(last_id=uid) + assert q.last_id == uid + + def test_empty_last_id(self) -> None: + q = SavedMessageListQuery(last_id="") + assert q.last_id == "" + + +class TestSavedMessageCreatePayload: + def test_valid_message_id(self) -> None: + uid = str(uuid4()) + p = SavedMessageCreatePayload(message_id=uid) + assert p.message_id == uid + + def test_required(self) -> None: + with pytest.raises(ValidationError): + SavedMessageCreatePayload() + + +# --------------------------------------------------------------------------- +# workflow.py models +# --------------------------------------------------------------------------- +from controllers.web.workflow import WorkflowRunPayload + + +class TestWorkflowRunPayload: + def test_defaults(self) -> None: + p = WorkflowRunPayload(inputs={}) + assert p.inputs == {} + assert p.files is None + + def test_with_files(self) -> None: + p = WorkflowRunPayload(inputs={"k": "v"}, files=[{"id": "f1"}]) + assert p.files == [{"id": "f1"}] + + def test_inputs_required(self) -> None: + with pytest.raises(ValidationError): + WorkflowRunPayload() + + +# --------------------------------------------------------------------------- +# forgot_password.py models +# --------------------------------------------------------------------------- +from controllers.web.forgot_password import ( + ForgotPasswordCheckPayload, + ForgotPasswordResetPayload, + ForgotPasswordSendPayload, +) + + +class TestForgotPasswordSendPayload: + def test_valid_email(self) -> None: + p = ForgotPasswordSendPayload(email="user@example.com") + assert p.email == "user@example.com" + + def test_invalid_email(self) -> None: + with pytest.raises(ValidationError, match="not a valid email"): + ForgotPasswordSendPayload(email="not-an-email") + + def test_language_optional(self) -> None: + p = ForgotPasswordSendPayload(email="a@b.com") + assert p.language is None + + +class TestForgotPasswordCheckPayload: + def test_valid(self) -> None: + p = ForgotPasswordCheckPayload(email="a@b.com", code="1234", token="tok") + assert p.email == "a@b.com" + assert p.code == "1234" + assert p.token == "tok" + + def test_empty_token_rejected(self) -> None: + with pytest.raises(ValidationError): + ForgotPasswordCheckPayload(email="a@b.com", code="1234", token="") + + +class TestForgotPasswordResetPayload: + def test_valid_passwords(self) -> None: + p = ForgotPasswordResetPayload(token="tok", new_password="Valid1234", password_confirm="Valid1234") + assert p.new_password == "Valid1234" + + def test_weak_password_rejected(self) -> None: + with pytest.raises(ValidationError, match="Password must contain"): + ForgotPasswordResetPayload(token="tok", new_password="short", password_confirm="short") + + def test_letters_only_password_rejected(self) -> None: + with pytest.raises(ValidationError, match="Password must contain"): + ForgotPasswordResetPayload(token="tok", new_password="abcdefghi", password_confirm="abcdefghi") + + def test_digits_only_password_rejected(self) -> None: + with pytest.raises(ValidationError, match="Password must contain"): + ForgotPasswordResetPayload(token="tok", new_password="123456789", password_confirm="123456789") + + +# --------------------------------------------------------------------------- +# login.py models +# --------------------------------------------------------------------------- +from controllers.web.login import EmailCodeLoginSendPayload, EmailCodeLoginVerifyPayload, LoginPayload + + +class TestLoginPayload: + def test_valid(self) -> None: + p = LoginPayload(email="a@b.com", password="Valid1234") + assert p.email == "a@b.com" + + def test_invalid_email(self) -> None: + with pytest.raises(ValidationError, match="not a valid email"): + LoginPayload(email="bad", password="Valid1234") + + def test_weak_password(self) -> None: + with pytest.raises(ValidationError, match="Password must contain"): + LoginPayload(email="a@b.com", password="weak") + + +class TestEmailCodeLoginSendPayload: + def test_valid(self) -> None: + p = EmailCodeLoginSendPayload(email="a@b.com") + assert p.language is None + + def test_with_language(self) -> None: + p = EmailCodeLoginSendPayload(email="a@b.com", language="zh-Hans") + assert p.language == "zh-Hans" + + +class TestEmailCodeLoginVerifyPayload: + def test_valid(self) -> None: + p = EmailCodeLoginVerifyPayload(email="a@b.com", code="1234", token="tok") + assert p.code == "1234" + + def test_empty_token_rejected(self) -> None: + with pytest.raises(ValidationError): + EmailCodeLoginVerifyPayload(email="a@b.com", code="1234", token="") diff --git a/api/tests/unit_tests/controllers/web/test_remote_files.py b/api/tests/unit_tests/controllers/web/test_remote_files.py new file mode 100644 index 0000000000..8554f440b7 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_remote_files.py @@ -0,0 +1,147 @@ +"""Unit tests for controllers.web.remote_files endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.common.errors import FileTooLargeError, RemoteFileUploadError +from controllers.web.remote_files import RemoteFileInfoApi, RemoteFileUploadApi + + +def _app_model() -> SimpleNamespace: + return SimpleNamespace(id="app-1") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# --------------------------------------------------------------------------- +# RemoteFileInfoApi +# --------------------------------------------------------------------------- +class TestRemoteFileInfoApi: + @patch("controllers.web.remote_files.ssrf_proxy") + def test_head_success(self, mock_proxy: MagicMock, app: Flask) -> None: + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {"Content-Type": "application/pdf", "Content-Length": "1024"} + mock_proxy.head.return_value = mock_resp + + with app.test_request_context("/remote-files/https%3A%2F%2Fexample.com%2Ffile.pdf"): + result = RemoteFileInfoApi().get(_app_model(), _end_user(), "https%3A%2F%2Fexample.com%2Ffile.pdf") + + assert result["file_type"] == "application/pdf" + assert result["file_length"] == 1024 + + @patch("controllers.web.remote_files.ssrf_proxy") + def test_fallback_to_get(self, mock_proxy: MagicMock, app: Flask) -> None: + head_resp = MagicMock() + head_resp.status_code = 405 # Method not allowed + get_resp = MagicMock() + get_resp.status_code = 200 + get_resp.headers = {"Content-Type": "text/plain", "Content-Length": "42"} + get_resp.raise_for_status = MagicMock() + mock_proxy.head.return_value = head_resp + mock_proxy.get.return_value = get_resp + + with app.test_request_context("/remote-files/https%3A%2F%2Fexample.com%2Ffile.txt"): + result = RemoteFileInfoApi().get(_app_model(), _end_user(), "https%3A%2F%2Fexample.com%2Ffile.txt") + + assert result["file_type"] == "text/plain" + mock_proxy.get.assert_called_once() + + +# --------------------------------------------------------------------------- +# RemoteFileUploadApi +# --------------------------------------------------------------------------- +class TestRemoteFileUploadApi: + @patch("controllers.web.remote_files.file_helpers.get_signed_file_url", return_value="https://signed-url") + @patch("controllers.web.remote_files.FileService") + @patch("controllers.web.remote_files.helpers.guess_file_info_from_response") + @patch("controllers.web.remote_files.ssrf_proxy") + @patch("controllers.web.remote_files.web_ns") + @patch("controllers.web.remote_files.db") + def test_upload_success( + self, + mock_db: MagicMock, + mock_ns: MagicMock, + mock_proxy: MagicMock, + mock_guess: MagicMock, + mock_file_svc_cls: MagicMock, + mock_signed: MagicMock, + app: Flask, + ) -> None: + mock_db.engine = "engine" + mock_ns.payload = {"url": "https://example.com/file.pdf"} + head_resp = MagicMock() + head_resp.status_code = 200 + head_resp.content = b"pdf-content" + head_resp.request.method = "HEAD" + mock_proxy.head.return_value = head_resp + get_resp = MagicMock() + get_resp.content = b"pdf-content" + mock_proxy.get.return_value = get_resp + + mock_guess.return_value = SimpleNamespace( + filename="file.pdf", extension="pdf", mimetype="application/pdf", size=100 + ) + mock_file_svc_cls.is_file_size_within_limit.return_value = True + + from datetime import datetime + + upload_file = SimpleNamespace( + id="f-1", + name="file.pdf", + size=100, + extension="pdf", + mime_type="application/pdf", + created_by="eu-1", + created_at=datetime(2024, 1, 1), + ) + mock_file_svc_cls.return_value.upload_file.return_value = upload_file + + with app.test_request_context("/remote-files/upload", method="POST"): + result, status = RemoteFileUploadApi().post(_app_model(), _end_user()) + + assert status == 201 + assert result["id"] == "f-1" + + @patch("controllers.web.remote_files.FileService.is_file_size_within_limit", return_value=False) + @patch("controllers.web.remote_files.helpers.guess_file_info_from_response") + @patch("controllers.web.remote_files.ssrf_proxy") + @patch("controllers.web.remote_files.web_ns") + def test_file_too_large( + self, + mock_ns: MagicMock, + mock_proxy: MagicMock, + mock_guess: MagicMock, + mock_size_check: MagicMock, + app: Flask, + ) -> None: + mock_ns.payload = {"url": "https://example.com/big.zip"} + head_resp = MagicMock() + head_resp.status_code = 200 + mock_proxy.head.return_value = head_resp + mock_guess.return_value = SimpleNamespace( + filename="big.zip", extension="zip", mimetype="application/zip", size=999999999 + ) + + with app.test_request_context("/remote-files/upload", method="POST"): + with pytest.raises(FileTooLargeError): + RemoteFileUploadApi().post(_app_model(), _end_user()) + + @patch("controllers.web.remote_files.ssrf_proxy") + @patch("controllers.web.remote_files.web_ns") + def test_fetch_failure_raises(self, mock_ns: MagicMock, mock_proxy: MagicMock, app: Flask) -> None: + import httpx + + mock_ns.payload = {"url": "https://example.com/bad"} + mock_proxy.head.side_effect = httpx.RequestError("connection failed") + + with app.test_request_context("/remote-files/upload", method="POST"): + with pytest.raises(RemoteFileUploadError): + RemoteFileUploadApi().post(_app_model(), _end_user()) diff --git a/api/tests/unit_tests/controllers/web/test_saved_message.py b/api/tests/unit_tests/controllers/web/test_saved_message.py new file mode 100644 index 0000000000..3d55804912 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_saved_message.py @@ -0,0 +1,97 @@ +"""Unit tests for controllers.web.saved_message endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from flask import Flask +from werkzeug.exceptions import NotFound + +from controllers.web.error import NotCompletionAppError +from controllers.web.saved_message import SavedMessageApi, SavedMessageListApi +from services.errors.message import MessageNotExistsError + + +def _completion_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="completion") + + +def _chat_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="chat") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# --------------------------------------------------------------------------- +# SavedMessageListApi (GET) +# --------------------------------------------------------------------------- +class TestSavedMessageListApiGet: + def test_non_completion_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/saved-messages"): + with pytest.raises(NotCompletionAppError): + SavedMessageListApi().get(_chat_app(), _end_user()) + + @patch("controllers.web.saved_message.SavedMessageService.pagination_by_last_id") + def test_happy_path(self, mock_paginate: MagicMock, app: Flask) -> None: + mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[]) + + with app.test_request_context("/saved-messages?limit=20"): + result = SavedMessageListApi().get(_completion_app(), _end_user()) + + assert result["limit"] == 20 + assert result["has_more"] is False + + +# --------------------------------------------------------------------------- +# SavedMessageListApi (POST) +# --------------------------------------------------------------------------- +class TestSavedMessageListApiPost: + def test_non_completion_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/saved-messages", method="POST"): + with pytest.raises(NotCompletionAppError): + SavedMessageListApi().post(_chat_app(), _end_user()) + + @patch("controllers.web.saved_message.SavedMessageService.save") + @patch("controllers.web.saved_message.web_ns") + def test_save_success(self, mock_ns: MagicMock, mock_save: MagicMock, app: Flask) -> None: + msg_id = str(uuid4()) + mock_ns.payload = {"message_id": msg_id} + + with app.test_request_context("/saved-messages", method="POST"): + result = SavedMessageListApi().post(_completion_app(), _end_user()) + + assert result["result"] == "success" + + @patch("controllers.web.saved_message.SavedMessageService.save", side_effect=MessageNotExistsError()) + @patch("controllers.web.saved_message.web_ns") + def test_save_not_found(self, mock_ns: MagicMock, mock_save: MagicMock, app: Flask) -> None: + mock_ns.payload = {"message_id": str(uuid4())} + + with app.test_request_context("/saved-messages", method="POST"): + with pytest.raises(NotFound, match="Message Not Exists"): + SavedMessageListApi().post(_completion_app(), _end_user()) + + +# --------------------------------------------------------------------------- +# SavedMessageApi (DELETE) +# --------------------------------------------------------------------------- +class TestSavedMessageApi: + def test_non_completion_mode_raises(self, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/saved-messages/{msg_id}", method="DELETE"): + with pytest.raises(NotCompletionAppError): + SavedMessageApi().delete(_chat_app(), _end_user(), msg_id) + + @patch("controllers.web.saved_message.SavedMessageService.delete") + def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/saved-messages/{msg_id}", method="DELETE"): + result, status = SavedMessageApi().delete(_completion_app(), _end_user(), msg_id) + + assert status == 204 + assert result["result"] == "success" diff --git a/api/tests/unit_tests/controllers/web/test_site.py b/api/tests/unit_tests/controllers/web/test_site.py new file mode 100644 index 0000000000..557bf93e9e --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_site.py @@ -0,0 +1,126 @@ +"""Unit tests for controllers.web.site endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import Forbidden + +from controllers.web.site import AppSiteApi, AppSiteInfo + + +def _tenant(*, status: str = "normal") -> SimpleNamespace: + return SimpleNamespace( + id="tenant-1", + status=status, + plan="basic", + custom_config_dict={"remove_webapp_brand": False, "replace_webapp_logo": False}, + ) + + +def _site() -> SimpleNamespace: + return SimpleNamespace( + title="Site", + icon_type="emoji", + icon="robot", + icon_background="#fff", + description="desc", + default_language="en", + chat_color_theme="light", + chat_color_theme_inverted=False, + copyright=None, + privacy_policy=None, + custom_disclaimer=None, + prompt_public=False, + show_workflow_steps=True, + use_icon_as_answer_icon=False, + ) + + +# --------------------------------------------------------------------------- +# AppSiteApi +# --------------------------------------------------------------------------- +class TestAppSiteApi: + @patch("controllers.web.site.FeatureService.get_features") + @patch("controllers.web.site.db") + def test_happy_path(self, mock_db: MagicMock, mock_features: MagicMock, app: Flask) -> None: + app.config["RESTX_MASK_HEADER"] = "X-Fields" + mock_features.return_value = SimpleNamespace(can_replace_logo=False) + site_obj = _site() + mock_db.session.query.return_value.where.return_value.first.return_value = site_obj + tenant = _tenant() + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) + end_user = SimpleNamespace(id="eu-1") + + with app.test_request_context("/site"): + result = AppSiteApi().get(app_model, end_user) + + # marshal_with serializes AppSiteInfo to a dict + assert result["app_id"] == "app-1" + assert result["plan"] == "basic" + assert result["enable_site"] is True + + @patch("controllers.web.site.db") + def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None: + app.config["RESTX_MASK_HEADER"] = "X-Fields" + mock_db.session.query.return_value.where.return_value.first.return_value = None + tenant = _tenant() + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant) + end_user = SimpleNamespace(id="eu-1") + + with app.test_request_context("/site"): + with pytest.raises(Forbidden): + AppSiteApi().get(app_model, end_user) + + @patch("controllers.web.site.db") + def test_archived_tenant_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None: + app.config["RESTX_MASK_HEADER"] = "X-Fields" + from models.account import TenantStatus + + mock_db.session.query.return_value.where.return_value.first.return_value = _site() + tenant = SimpleNamespace( + id="tenant-1", + status=TenantStatus.ARCHIVE, + plan="basic", + custom_config_dict={}, + ) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant) + end_user = SimpleNamespace(id="eu-1") + + with app.test_request_context("/site"): + with pytest.raises(Forbidden): + AppSiteApi().get(app_model, end_user) + + +# --------------------------------------------------------------------------- +# AppSiteInfo +# --------------------------------------------------------------------------- +class TestAppSiteInfo: + def test_basic_fields(self) -> None: + tenant = _tenant() + site_obj = _site() + info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", False) + + assert info.app_id == "app-1" + assert info.end_user_id == "eu-1" + assert info.enable_site is True + assert info.plan == "basic" + assert info.can_replace_logo is False + assert info.model_config is None + + @patch("controllers.web.site.dify_config", SimpleNamespace(FILES_URL="https://files.example.com")) + def test_can_replace_logo_sets_custom_config(self) -> None: + tenant = SimpleNamespace( + id="tenant-1", + plan="pro", + custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": True}, + ) + site_obj = _site() + info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", True) + + assert info.can_replace_logo is True + assert info.custom_config["remove_webapp_brand"] is True + assert "webapp-logo" in info.custom_config["replace_webapp_logo"] diff --git a/api/tests/unit_tests/controllers/web/test_web_login.py b/api/tests/unit_tests/controllers/web/test_web_login.py index e62993e8d5..0661c02578 100644 --- a/api/tests/unit_tests/controllers/web/test_web_login.py +++ b/api/tests/unit_tests/controllers/web/test_web_login.py @@ -5,7 +5,8 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask -from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi +import services.errors.account +from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi, LoginApi, LoginStatusApi, LogoutApi def encode_code(code: str) -> str: @@ -89,3 +90,114 @@ class TestEmailCodeLoginApi: mock_revoke_token.assert_called_once_with("token-123") mock_login.assert_called_once() mock_reset_login_rate.assert_called_once_with("user@example.com") + + +class TestLoginApi: + @patch("controllers.web.login.WebAppAuthService.login", return_value="access-tok") + @patch("controllers.web.login.WebAppAuthService.authenticate") + def test_login_success(self, mock_auth: MagicMock, mock_login: MagicMock, app: Flask) -> None: + mock_auth.return_value = MagicMock() + + with app.test_request_context( + "/web/login", + method="POST", + json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, + ): + response = LoginApi().post() + + assert response.get_json()["data"]["access_token"] == "access-tok" + mock_auth.assert_called_once() + + @patch( + "controllers.web.login.WebAppAuthService.authenticate", + side_effect=services.errors.account.AccountLoginError(), + ) + def test_login_banned_account(self, mock_auth: MagicMock, app: Flask) -> None: + from controllers.console.error import AccountBannedError + + with app.test_request_context( + "/web/login", + method="POST", + json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, + ): + with pytest.raises(AccountBannedError): + LoginApi().post() + + @patch( + "controllers.web.login.WebAppAuthService.authenticate", + side_effect=services.errors.account.AccountPasswordError(), + ) + def test_login_wrong_password(self, mock_auth: MagicMock, app: Flask) -> None: + from controllers.console.auth.error import AuthenticationFailedError + + with app.test_request_context( + "/web/login", + method="POST", + json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, + ): + with pytest.raises(AuthenticationFailedError): + LoginApi().post() + + +class TestLoginStatusApi: + @patch("controllers.web.login.extract_webapp_access_token", return_value=None) + def test_no_app_code_returns_logged_in_false(self, mock_extract: MagicMock, app: Flask) -> None: + with app.test_request_context("/web/login/status"): + result = LoginStatusApi().get() + + assert result["logged_in"] is False + assert result["app_logged_in"] is False + + @patch("controllers.web.login.decode_jwt_token") + @patch("controllers.web.login.PassportService") + @patch("controllers.web.login.WebAppAuthService.is_app_require_permission_check", return_value=False) + @patch("controllers.web.login.AppService.get_app_id_by_code", return_value="app-1") + @patch("controllers.web.login.extract_webapp_access_token", return_value="tok") + def test_public_app_user_logged_in( + self, + mock_extract: MagicMock, + mock_app_id: MagicMock, + mock_perm: MagicMock, + mock_passport: MagicMock, + mock_decode: MagicMock, + app: Flask, + ) -> None: + mock_decode.return_value = (MagicMock(), MagicMock()) + + with app.test_request_context("/web/login/status?app_code=code1"): + result = LoginStatusApi().get() + + assert result["logged_in"] is True + assert result["app_logged_in"] is True + + @patch("controllers.web.login.decode_jwt_token", side_effect=Exception("bad")) + @patch("controllers.web.login.PassportService") + @patch("controllers.web.login.WebAppAuthService.is_app_require_permission_check", return_value=True) + @patch("controllers.web.login.AppService.get_app_id_by_code", return_value="app-1") + @patch("controllers.web.login.extract_webapp_access_token", return_value="tok") + def test_private_app_passport_fails( + self, + mock_extract: MagicMock, + mock_app_id: MagicMock, + mock_perm: MagicMock, + mock_passport_cls: MagicMock, + mock_decode: MagicMock, + app: Flask, + ) -> None: + mock_passport_cls.return_value.verify.side_effect = Exception("bad") + + with app.test_request_context("/web/login/status?app_code=code1"): + result = LoginStatusApi().get() + + assert result["logged_in"] is False + assert result["app_logged_in"] is False + + +class TestLogoutApi: + @patch("controllers.web.login.clear_webapp_access_token_from_cookie") + def test_logout_success(self, mock_clear: MagicMock, app: Flask) -> None: + with app.test_request_context("/web/logout", method="POST"): + response = LogoutApi().post() + + assert response.get_json() == {"result": "success"} + mock_clear.assert_called_once() diff --git a/api/tests/unit_tests/controllers/web/test_web_passport.py b/api/tests/unit_tests/controllers/web/test_web_passport.py new file mode 100644 index 0000000000..19b1d8504a --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_web_passport.py @@ -0,0 +1,192 @@ +"""Unit tests for controllers.web.passport — token issuance and enterprise auth exchange.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import NotFound, Unauthorized + +from controllers.web.error import WebAppAuthRequiredError +from controllers.web.passport import ( + PassportResource, + decode_enterprise_webapp_user_id, + exchange_token_for_existing_web_user, + generate_session_id, +) +from services.webapp_auth_service import WebAppAuthType + + +# --------------------------------------------------------------------------- +# decode_enterprise_webapp_user_id +# --------------------------------------------------------------------------- +class TestDecodeEnterpriseWebappUserId: + def test_none_token_returns_none(self) -> None: + assert decode_enterprise_webapp_user_id(None) is None + + @patch("controllers.web.passport.PassportService") + def test_valid_token_returns_decoded(self, mock_passport_cls: MagicMock) -> None: + mock_passport_cls.return_value.verify.return_value = { + "token_source": "webapp_login_token", + "user_id": "u1", + } + result = decode_enterprise_webapp_user_id("valid-jwt") + assert result["user_id"] == "u1" + + @patch("controllers.web.passport.PassportService") + def test_wrong_source_raises_unauthorized(self, mock_passport_cls: MagicMock) -> None: + mock_passport_cls.return_value.verify.return_value = { + "token_source": "other_source", + } + with pytest.raises(Unauthorized, match="Expected 'webapp_login_token'"): + decode_enterprise_webapp_user_id("bad-jwt") + + @patch("controllers.web.passport.PassportService") + def test_missing_source_raises_unauthorized(self, mock_passport_cls: MagicMock) -> None: + mock_passport_cls.return_value.verify.return_value = {} + with pytest.raises(Unauthorized, match="Expected 'webapp_login_token'"): + decode_enterprise_webapp_user_id("no-source-jwt") + + +# --------------------------------------------------------------------------- +# generate_session_id +# --------------------------------------------------------------------------- +class TestGenerateSessionId: + @patch("controllers.web.passport.db") + def test_returns_unique_session_id(self, mock_db: MagicMock) -> None: + mock_db.session.scalar.return_value = 0 + sid = generate_session_id() + assert isinstance(sid, str) + assert len(sid) == 36 # UUID format + + @patch("controllers.web.passport.db") + def test_retries_on_collision(self, mock_db: MagicMock) -> None: + # First call returns count=1 (collision), second returns 0 + mock_db.session.scalar.side_effect = [1, 0] + sid = generate_session_id() + assert isinstance(sid, str) + assert mock_db.session.scalar.call_count == 2 + + +# --------------------------------------------------------------------------- +# exchange_token_for_existing_web_user +# --------------------------------------------------------------------------- +class TestExchangeTokenForExistingWebUser: + @patch("controllers.web.passport.PassportService") + @patch("controllers.web.passport.db") + def test_external_auth_type_mismatch_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None: + site = SimpleNamespace(code="code1", app_id="app-1") + app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1") + mock_db.session.scalar.side_effect = [site, app_model] + + decoded = {"user_id": "u1", "auth_type": "internal"} # mismatch: expected "external" + with pytest.raises(WebAppAuthRequiredError, match="external"): + exchange_token_for_existing_web_user( + app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.EXTERNAL + ) + + @patch("controllers.web.passport.PassportService") + @patch("controllers.web.passport.db") + def test_internal_auth_type_mismatch_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None: + site = SimpleNamespace(code="code1", app_id="app-1") + app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1") + mock_db.session.scalar.side_effect = [site, app_model] + + decoded = {"user_id": "u1", "auth_type": "external"} # mismatch: expected "internal" + with pytest.raises(WebAppAuthRequiredError, match="internal"): + exchange_token_for_existing_web_user( + app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.INTERNAL + ) + + @patch("controllers.web.passport.PassportService") + @patch("controllers.web.passport.db") + def test_site_not_found_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None: + mock_db.session.scalar.return_value = None + decoded = {"user_id": "u1", "auth_type": "external"} + with pytest.raises(NotFound): + exchange_token_for_existing_web_user( + app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.EXTERNAL + ) + + +# --------------------------------------------------------------------------- +# PassportResource.get +# --------------------------------------------------------------------------- +class TestPassportResource: + @patch("controllers.web.passport.FeatureService.get_system_features") + def test_missing_app_code_raises_unauthorized(self, mock_features: MagicMock, app: Flask) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + with app.test_request_context("/passport"): + with pytest.raises(Unauthorized, match="X-App-Code"): + PassportResource().get() + + @patch("controllers.web.passport.PassportService") + @patch("controllers.web.passport.generate_session_id", return_value="new-sess-id") + @patch("controllers.web.passport.db") + @patch("controllers.web.passport.FeatureService.get_system_features") + def test_creates_new_end_user_when_no_user_id( + self, + mock_features: MagicMock, + mock_db: MagicMock, + mock_gen_session: MagicMock, + mock_passport_cls: MagicMock, + app: Flask, + ) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + site = SimpleNamespace(app_id="app-1", code="code1") + app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1") + mock_db.session.scalar.side_effect = [site, app_model] + mock_passport_cls.return_value.issue.return_value = "issued-token" + + with app.test_request_context("/passport", headers={"X-App-Code": "code1"}): + response = PassportResource().get() + + assert response.get_json()["access_token"] == "issued-token" + mock_db.session.add.assert_called_once() + mock_db.session.commit.assert_called_once() + + @patch("controllers.web.passport.PassportService") + @patch("controllers.web.passport.db") + @patch("controllers.web.passport.FeatureService.get_system_features") + def test_reuses_existing_end_user_when_user_id_provided( + self, + mock_features: MagicMock, + mock_db: MagicMock, + mock_passport_cls: MagicMock, + app: Flask, + ) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + site = SimpleNamespace(app_id="app-1", code="code1") + app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1") + existing_user = SimpleNamespace(id="eu-1", session_id="sess-existing") + mock_db.session.scalar.side_effect = [site, app_model, existing_user] + mock_passport_cls.return_value.issue.return_value = "reused-token" + + with app.test_request_context("/passport?user_id=sess-existing", headers={"X-App-Code": "code1"}): + response = PassportResource().get() + + assert response.get_json()["access_token"] == "reused-token" + # Should not create a new end user + mock_db.session.add.assert_not_called() + + @patch("controllers.web.passport.db") + @patch("controllers.web.passport.FeatureService.get_system_features") + def test_site_not_found_raises(self, mock_features: MagicMock, mock_db: MagicMock, app: Flask) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + mock_db.session.scalar.return_value = None + with app.test_request_context("/passport", headers={"X-App-Code": "code1"}): + with pytest.raises(NotFound): + PassportResource().get() + + @patch("controllers.web.passport.db") + @patch("controllers.web.passport.FeatureService.get_system_features") + def test_disabled_app_raises_not_found(self, mock_features: MagicMock, mock_db: MagicMock, app: Flask) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + site = SimpleNamespace(app_id="app-1", code="code1") + disabled_app = SimpleNamespace(id="app-1", status="normal", enable_site=False) + mock_db.session.scalar.side_effect = [site, disabled_app] + with app.test_request_context("/passport", headers={"X-App-Code": "code1"}): + with pytest.raises(NotFound): + PassportResource().get() diff --git a/api/tests/unit_tests/controllers/web/test_workflow.py b/api/tests/unit_tests/controllers/web/test_workflow.py new file mode 100644 index 0000000000..0973340527 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_workflow.py @@ -0,0 +1,95 @@ +"""Unit tests for controllers.web.workflow endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.error import ( + NotWorkflowAppError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.web.workflow import WorkflowRunApi, WorkflowTaskStopApi +from core.errors.error import ProviderTokenNotInitError, QuotaExceededError + + +def _workflow_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="workflow") + + +def _chat_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="chat") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# --------------------------------------------------------------------------- +# WorkflowRunApi +# --------------------------------------------------------------------------- +class TestWorkflowRunApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/workflows/run", method="POST"): + with pytest.raises(NotWorkflowAppError): + WorkflowRunApi().post(_chat_app(), _end_user()) + + @patch("controllers.web.workflow.helper.compact_generate_response", return_value={"result": "ok"}) + @patch("controllers.web.workflow.AppGenerateService.generate") + @patch("controllers.web.workflow.web_ns") + def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {"key": "val"}} + mock_gen.return_value = "response" + + with app.test_request_context("/workflows/run", method="POST"): + result = WorkflowRunApi().post(_workflow_app(), _end_user()) + + assert result == {"result": "ok"} + + @patch( + "controllers.web.workflow.AppGenerateService.generate", + side_effect=ProviderTokenNotInitError(description="not init"), + ) + @patch("controllers.web.workflow.web_ns") + def test_provider_not_init(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}} + + with app.test_request_context("/workflows/run", method="POST"): + with pytest.raises(ProviderNotInitializeError): + WorkflowRunApi().post(_workflow_app(), _end_user()) + + @patch( + "controllers.web.workflow.AppGenerateService.generate", + side_effect=QuotaExceededError(), + ) + @patch("controllers.web.workflow.web_ns") + def test_quota_exceeded(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}} + + with app.test_request_context("/workflows/run", method="POST"): + with pytest.raises(ProviderQuotaExceededError): + WorkflowRunApi().post(_workflow_app(), _end_user()) + + +# --------------------------------------------------------------------------- +# WorkflowTaskStopApi +# --------------------------------------------------------------------------- +class TestWorkflowTaskStopApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"): + with pytest.raises(NotWorkflowAppError): + WorkflowTaskStopApi().post(_chat_app(), _end_user(), "task-1") + + @patch("controllers.web.workflow.GraphEngineManager.send_stop_command") + @patch("controllers.web.workflow.AppQueueManager.set_stop_flag_no_user_check") + def test_stop_calls_both_mechanisms(self, mock_legacy: MagicMock, mock_graph: MagicMock, app: Flask) -> None: + with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"): + result = WorkflowTaskStopApi().post(_workflow_app(), _end_user(), "task-1") + + assert result == {"result": "success"} + mock_legacy.assert_called_once_with("task-1") + mock_graph.assert_called_once_with("task-1") diff --git a/api/tests/unit_tests/controllers/web/test_workflow_events.py b/api/tests/unit_tests/controllers/web/test_workflow_events.py new file mode 100644 index 0000000000..64c09b5e22 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_workflow_events.py @@ -0,0 +1,127 @@ +"""Unit tests for controllers.web.workflow_events endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.error import NotFoundError +from controllers.web.workflow_events import WorkflowEventsApi +from models.enums import CreatorUserRole + + +def _workflow_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", tenant_id="tenant-1", mode="workflow") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# --------------------------------------------------------------------------- +# WorkflowEventsApi +# --------------------------------------------------------------------------- +class TestWorkflowEventsApi: + @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory") + @patch("controllers.web.workflow_events.db") + def test_workflow_run_not_found(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None: + mock_db.engine = "engine" + mock_repo = MagicMock() + mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = None + mock_factory.create_api_workflow_run_repository.return_value = mock_repo + + with app.test_request_context("/workflow/run-1/events"): + with pytest.raises(NotFoundError): + WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1") + + @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory") + @patch("controllers.web.workflow_events.db") + def test_workflow_run_wrong_app(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None: + mock_db.engine = "engine" + run = SimpleNamespace( + id="run-1", + app_id="other-app", + created_by_role=CreatorUserRole.END_USER, + created_by="eu-1", + finished_at=None, + ) + mock_repo = MagicMock() + mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run + mock_factory.create_api_workflow_run_repository.return_value = mock_repo + + with app.test_request_context("/workflow/run-1/events"): + with pytest.raises(NotFoundError): + WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1") + + @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory") + @patch("controllers.web.workflow_events.db") + def test_workflow_run_not_created_by_end_user( + self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask + ) -> None: + mock_db.engine = "engine" + run = SimpleNamespace( + id="run-1", + app_id="app-1", + created_by_role=CreatorUserRole.ACCOUNT, + created_by="eu-1", + finished_at=None, + ) + mock_repo = MagicMock() + mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run + mock_factory.create_api_workflow_run_repository.return_value = mock_repo + + with app.test_request_context("/workflow/run-1/events"): + with pytest.raises(NotFoundError): + WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1") + + @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory") + @patch("controllers.web.workflow_events.db") + def test_workflow_run_wrong_end_user(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None: + mock_db.engine = "engine" + run = SimpleNamespace( + id="run-1", + app_id="app-1", + created_by_role=CreatorUserRole.END_USER, + created_by="other-user", + finished_at=None, + ) + mock_repo = MagicMock() + mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run + mock_factory.create_api_workflow_run_repository.return_value = mock_repo + + with app.test_request_context("/workflow/run-1/events"): + with pytest.raises(NotFoundError): + WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1") + + @patch("controllers.web.workflow_events.WorkflowResponseConverter") + @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory") + @patch("controllers.web.workflow_events.db") + def test_finished_run_returns_sse_response( + self, mock_db: MagicMock, mock_factory: MagicMock, mock_converter: MagicMock, app: Flask + ) -> None: + from datetime import datetime + + mock_db.engine = "engine" + run = SimpleNamespace( + id="run-1", + app_id="app-1", + created_by_role=CreatorUserRole.END_USER, + created_by="eu-1", + finished_at=datetime(2024, 1, 1), + ) + mock_repo = MagicMock() + mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run + mock_factory.create_api_workflow_run_repository.return_value = mock_repo + + finish_response = MagicMock() + finish_response.model_dump.return_value = {"task_id": "run-1"} + finish_response.event.value = "workflow_finished" + mock_converter.workflow_run_result_to_finish_response.return_value = finish_response + + with app.test_request_context("/workflow/run-1/events"): + response = WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1") + + assert response.mimetype == "text/event-stream" diff --git a/api/tests/unit_tests/controllers/web/test_wraps.py b/api/tests/unit_tests/controllers/web/test_wraps.py new file mode 100644 index 0000000000..85049ae975 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_wraps.py @@ -0,0 +1,393 @@ +"""Unit tests for controllers.web.wraps — JWT auth decorator and validation helpers.""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import BadRequest, NotFound, Unauthorized + +from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError +from controllers.web.wraps import ( + _validate_user_accessibility, + _validate_webapp_token, + decode_jwt_token, +) + + +# --------------------------------------------------------------------------- +# _validate_webapp_token +# --------------------------------------------------------------------------- +class TestValidateWebappToken: + def test_enterprise_enabled_and_app_auth_requires_webapp_source(self) -> None: + """When both flags are true, a non-webapp source must raise.""" + decoded = {"token_source": "other"} + with pytest.raises(WebAppAuthRequiredError): + _validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True) + + def test_enterprise_enabled_and_app_auth_accepts_webapp_source(self) -> None: + decoded = {"token_source": "webapp"} + _validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True) + + def test_enterprise_enabled_and_app_auth_missing_source_raises(self) -> None: + decoded = {} + with pytest.raises(WebAppAuthRequiredError): + _validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True) + + def test_public_app_rejects_webapp_source(self) -> None: + """When auth is not required, a webapp-sourced token must be rejected.""" + decoded = {"token_source": "webapp"} + with pytest.raises(Unauthorized): + _validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False) + + def test_public_app_accepts_non_webapp_source(self) -> None: + decoded = {"token_source": "other"} + _validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False) + + def test_public_app_accepts_no_source(self) -> None: + decoded = {} + _validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False) + + def test_system_enabled_but_app_public(self) -> None: + """system_webapp_auth_enabled=True but app is public — webapp source rejected.""" + decoded = {"token_source": "webapp"} + with pytest.raises(Unauthorized): + _validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=True) + + +# --------------------------------------------------------------------------- +# _validate_user_accessibility +# --------------------------------------------------------------------------- +class TestValidateUserAccessibility: + def test_skips_when_auth_disabled(self) -> None: + """No checks when system or app auth is disabled.""" + _validate_user_accessibility( + decoded={}, + app_code="code", + app_web_auth_enabled=False, + system_webapp_auth_enabled=False, + webapp_settings=None, + ) + + def test_missing_user_id_raises(self) -> None: + decoded = {} + with pytest.raises(WebAppAuthRequiredError): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=SimpleNamespace(access_mode="internal"), + ) + + def test_missing_webapp_settings_raises(self) -> None: + decoded = {"user_id": "u1"} + with pytest.raises(WebAppAuthRequiredError, match="settings not found"): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=None, + ) + + def test_missing_auth_type_raises(self) -> None: + decoded = {"user_id": "u1", "granted_at": 1} + settings = SimpleNamespace(access_mode="public") + with pytest.raises(WebAppAuthAccessDeniedError, match="auth_type"): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=settings, + ) + + def test_missing_granted_at_raises(self) -> None: + decoded = {"user_id": "u1", "auth_type": "external"} + settings = SimpleNamespace(access_mode="public") + with pytest.raises(WebAppAuthAccessDeniedError, match="granted_at"): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=settings, + ) + + @patch("controllers.web.wraps.EnterpriseService.get_app_sso_settings_last_update_time") + @patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False) + def test_external_auth_type_checks_sso_update_time( + self, mock_perm_check: MagicMock, mock_sso_time: MagicMock + ) -> None: + # granted_at is before SSO update time → denied + mock_sso_time.return_value = datetime.now(UTC) + old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp()) + decoded = {"user_id": "u1", "auth_type": "external", "granted_at": old_granted} + settings = SimpleNamespace(access_mode="public") + with pytest.raises(WebAppAuthAccessDeniedError, match="SSO settings"): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=settings, + ) + + @patch("controllers.web.wraps.EnterpriseService.get_workspace_sso_settings_last_update_time") + @patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False) + def test_internal_auth_type_checks_workspace_sso_update_time( + self, mock_perm_check: MagicMock, mock_workspace_sso: MagicMock + ) -> None: + mock_workspace_sso.return_value = datetime.now(UTC) + old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp()) + decoded = {"user_id": "u1", "auth_type": "internal", "granted_at": old_granted} + settings = SimpleNamespace(access_mode="public") + with pytest.raises(WebAppAuthAccessDeniedError, match="SSO settings"): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=settings, + ) + + @patch("controllers.web.wraps.EnterpriseService.get_app_sso_settings_last_update_time") + @patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False) + def test_external_auth_passes_when_granted_after_sso_update( + self, mock_perm_check: MagicMock, mock_sso_time: MagicMock + ) -> None: + mock_sso_time.return_value = datetime.now(UTC) - timedelta(hours=2) + recent_granted = int(datetime.now(UTC).timestamp()) + decoded = {"user_id": "u1", "auth_type": "external", "granted_at": recent_granted} + settings = SimpleNamespace(access_mode="public") + # Should not raise + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=settings, + ) + + @patch("controllers.web.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp", return_value=False) + @patch("controllers.web.wraps.AppService.get_app_id_by_code", return_value="app-id-1") + @patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=True) + def test_permission_check_denies_unauthorized_user( + self, mock_perm: MagicMock, mock_app_id: MagicMock, mock_allowed: MagicMock + ) -> None: + decoded = {"user_id": "u1", "auth_type": "external", "granted_at": int(datetime.now(UTC).timestamp())} + settings = SimpleNamespace(access_mode="internal") + with pytest.raises(WebAppAuthAccessDeniedError): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=settings, + ) + + +# --------------------------------------------------------------------------- +# decode_jwt_token +# --------------------------------------------------------------------------- +class TestDecodeJwtToken: + @patch("controllers.web.wraps._validate_user_accessibility") + @patch("controllers.web.wraps._validate_webapp_token") + @patch("controllers.web.wraps.EnterpriseService.WebAppAuth.get_app_access_mode_by_id") + @patch("controllers.web.wraps.AppService.get_app_id_by_code") + @patch("controllers.web.wraps.FeatureService.get_system_features") + @patch("controllers.web.wraps.PassportService") + @patch("controllers.web.wraps.extract_webapp_passport") + @patch("controllers.web.wraps.db") + def test_happy_path( + self, + mock_db: MagicMock, + mock_extract: MagicMock, + mock_passport_cls: MagicMock, + mock_features: MagicMock, + mock_app_id: MagicMock, + mock_access_mode: MagicMock, + mock_validate_token: MagicMock, + mock_validate_user: MagicMock, + app: Flask, + ) -> None: + mock_extract.return_value = "jwt-token" + mock_passport_cls.return_value.verify.return_value = { + "app_code": "code1", + "app_id": "app-1", + "end_user_id": "eu-1", + } + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + + app_model = SimpleNamespace(id="app-1", enable_site=True) + site = SimpleNamespace(code="code1") + end_user = SimpleNamespace(id="eu-1", session_id="sess-1") + + # Configure session mock to return correct objects via scalar() + session_mock = MagicMock() + session_mock.scalar.side_effect = [app_model, site, end_user] + session_ctx = MagicMock() + session_ctx.__enter__ = MagicMock(return_value=session_mock) + session_ctx.__exit__ = MagicMock(return_value=False) + mock_db.engine = "engine" + + with patch("controllers.web.wraps.Session", return_value=session_ctx): + with app.test_request_context("/", headers={"X-App-Code": "code1"}): + result_app, result_user = decode_jwt_token() + + assert result_app.id == "app-1" + assert result_user.id == "eu-1" + + @patch("controllers.web.wraps.FeatureService.get_system_features") + @patch("controllers.web.wraps.extract_webapp_passport") + def test_missing_token_raises_unauthorized( + self, mock_extract: MagicMock, mock_features: MagicMock, app: Flask + ) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + mock_extract.return_value = None + + with app.test_request_context("/", headers={"X-App-Code": "code1"}): + with pytest.raises(Unauthorized): + decode_jwt_token() + + @patch("controllers.web.wraps.FeatureService.get_system_features") + @patch("controllers.web.wraps.PassportService") + @patch("controllers.web.wraps.extract_webapp_passport") + @patch("controllers.web.wraps.db") + def test_missing_app_raises_not_found( + self, + mock_db: MagicMock, + mock_extract: MagicMock, + mock_passport_cls: MagicMock, + mock_features: MagicMock, + app: Flask, + ) -> None: + mock_extract.return_value = "jwt-token" + mock_passport_cls.return_value.verify.return_value = { + "app_code": "code1", + "app_id": "app-1", + "end_user_id": "eu-1", + } + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + + session_mock = MagicMock() + session_mock.scalar.return_value = None # No app found + session_ctx = MagicMock() + session_ctx.__enter__ = MagicMock(return_value=session_mock) + session_ctx.__exit__ = MagicMock(return_value=False) + mock_db.engine = "engine" + + with patch("controllers.web.wraps.Session", return_value=session_ctx): + with app.test_request_context("/", headers={"X-App-Code": "code1"}): + with pytest.raises(NotFound): + decode_jwt_token() + + @patch("controllers.web.wraps.FeatureService.get_system_features") + @patch("controllers.web.wraps.PassportService") + @patch("controllers.web.wraps.extract_webapp_passport") + @patch("controllers.web.wraps.db") + def test_disabled_site_raises_bad_request( + self, + mock_db: MagicMock, + mock_extract: MagicMock, + mock_passport_cls: MagicMock, + mock_features: MagicMock, + app: Flask, + ) -> None: + mock_extract.return_value = "jwt-token" + mock_passport_cls.return_value.verify.return_value = { + "app_code": "code1", + "app_id": "app-1", + "end_user_id": "eu-1", + } + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + + app_model = SimpleNamespace(id="app-1", enable_site=False) + + session_mock = MagicMock() + # scalar calls: app_model, site (code found), then end_user + session_mock.scalar.side_effect = [app_model, SimpleNamespace(code="code1"), None] + session_ctx = MagicMock() + session_ctx.__enter__ = MagicMock(return_value=session_mock) + session_ctx.__exit__ = MagicMock(return_value=False) + mock_db.engine = "engine" + + with patch("controllers.web.wraps.Session", return_value=session_ctx): + with app.test_request_context("/", headers={"X-App-Code": "code1"}): + with pytest.raises(BadRequest, match="Site is disabled"): + decode_jwt_token() + + @patch("controllers.web.wraps.FeatureService.get_system_features") + @patch("controllers.web.wraps.PassportService") + @patch("controllers.web.wraps.extract_webapp_passport") + @patch("controllers.web.wraps.db") + def test_missing_end_user_raises_not_found( + self, + mock_db: MagicMock, + mock_extract: MagicMock, + mock_passport_cls: MagicMock, + mock_features: MagicMock, + app: Flask, + ) -> None: + mock_extract.return_value = "jwt-token" + mock_passport_cls.return_value.verify.return_value = { + "app_code": "code1", + "app_id": "app-1", + "end_user_id": "eu-1", + } + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + + app_model = SimpleNamespace(id="app-1", enable_site=True) + site = SimpleNamespace(code="code1") + + session_mock = MagicMock() + session_mock.scalar.side_effect = [app_model, site, None] # end_user is None + session_ctx = MagicMock() + session_ctx.__enter__ = MagicMock(return_value=session_mock) + session_ctx.__exit__ = MagicMock(return_value=False) + mock_db.engine = "engine" + + with patch("controllers.web.wraps.Session", return_value=session_ctx): + with app.test_request_context("/", headers={"X-App-Code": "code1"}): + with pytest.raises(NotFound): + decode_jwt_token() + + @patch("controllers.web.wraps.FeatureService.get_system_features") + @patch("controllers.web.wraps.PassportService") + @patch("controllers.web.wraps.extract_webapp_passport") + @patch("controllers.web.wraps.db") + def test_user_id_mismatch_raises_unauthorized( + self, + mock_db: MagicMock, + mock_extract: MagicMock, + mock_passport_cls: MagicMock, + mock_features: MagicMock, + app: Flask, + ) -> None: + mock_extract.return_value = "jwt-token" + mock_passport_cls.return_value.verify.return_value = { + "app_code": "code1", + "app_id": "app-1", + "end_user_id": "eu-1", + } + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + + app_model = SimpleNamespace(id="app-1", enable_site=True) + site = SimpleNamespace(code="code1") + end_user = SimpleNamespace(id="eu-1", session_id="sess-1") + + session_mock = MagicMock() + session_mock.scalar.side_effect = [app_model, site, end_user] + session_ctx = MagicMock() + session_ctx.__enter__ = MagicMock(return_value=session_mock) + session_ctx.__exit__ = MagicMock(return_value=False) + mock_db.engine = "engine" + + with patch("controllers.web.wraps.Session", return_value=session_ctx): + with app.test_request_context("/", headers={"X-App-Code": "code1"}): + with pytest.raises(Unauthorized, match="expired"): + decode_jwt_token(user_id="different-user")