From dd4f504b39d1afbfeaf51199789da779e006f654 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Mon, 23 Mar 2026 18:53:05 +0100 Subject: [PATCH 1/6] refactor: select in remaining console app controllers (#33969) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/conversation.py | 10 +++------- api/controllers/console/app/generator.py | 2 +- api/controllers/console/app/mcp_server.py | 14 +++++++------- api/controllers/console/app/model_config.py | 4 +--- api/controllers/console/app/site.py | 5 +++-- api/controllers/console/app/wraps.py | 10 +++++----- .../controllers/console/app/test_app_apis.py | 8 ++------ .../console/app/test_conversation_api.py | 12 ++---------- .../app/test_conversation_read_timestamp.py | 2 +- .../controllers/console/app/test_generator_api.py | 12 ++++-------- .../console/app/test_model_config_api.py | 5 +---- .../controllers/console/app/test_wraps.py | 8 ++------ 12 files changed, 32 insertions(+), 60 deletions(-) diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 74750981dd..d329d22309 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -458,9 +458,7 @@ class ChatConversationApi(Resource): args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore subquery = ( - db.session.query( - Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id") - ) + sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")) .outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id) .subquery() ) @@ -595,10 +593,8 @@ class ChatConversationDetailApi(Resource): def _get_conversation(app_model, conversation_id): current_user, _ = current_account_with_tenant() - conversation = ( - db.session.query(Conversation) - .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) - .first() + conversation = db.session.scalar( + sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1) ) if not conversation: diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index af4ac450bb..442d0d2324 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -168,7 +168,7 @@ class InstructionGenerateApi(Resource): try: # Generate from nothing for a workflow node if (args.current in (code_template, "")) and args.node_id != "": - app = db.session.query(App).where(App.id == args.flow_id).first() + app = db.session.get(App, args.flow_id) if not app: return {"error": f"app {args.flow_id} not found"}, 400 workflow = WorkflowService().get_draft_workflow(app_model=app) diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 4b20418b53..412fc8795a 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -2,6 +2,7 @@ import json from flask_restx import Resource, marshal_with from pydantic import BaseModel, Field +from sqlalchemy import select from werkzeug.exceptions import NotFound from controllers.console import console_ns @@ -47,7 +48,7 @@ class AppMCPServerController(Resource): @get_app_model @marshal_with(app_server_model) def get(self, app_model): - server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first() + server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1)) return server @console_ns.doc("create_app_mcp_server") @@ -98,7 +99,7 @@ class AppMCPServerController(Resource): @edit_permission_required def put(self, app_model): payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {}) - server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first() + server = db.session.get(AppMCPServer, payload.id) if not server: raise NotFound() @@ -135,11 +136,10 @@ class AppMCPServerRefreshController(Resource): @edit_permission_required def get(self, server_id): _, current_tenant_id = current_account_with_tenant() - server = ( - db.session.query(AppMCPServer) - .where(AppMCPServer.id == server_id) - .where(AppMCPServer.tenant_id == current_tenant_id) - .first() + server = db.session.scalar( + select(AppMCPServer) + .where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id) + .limit(1) ) if not server: raise NotFound() diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index a85e54fb51..e9bd30ba7e 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -69,9 +69,7 @@ class ModelConfigResource(Resource): if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent: # get original app model config - original_app_model_config = ( - db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first() - ) + original_app_model_config = db.session.get(AppModelConfig, app_model.app_model_config_id) if original_app_model_config is None: raise ValueError("Original app model config not found") agent_mode = original_app_model_config.agent_mode_dict diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index db218d8b81..7f44a99ff1 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -2,6 +2,7 @@ from typing import Literal from flask_restx import Resource, marshal_with from pydantic import BaseModel, Field, field_validator +from sqlalchemy import select from werkzeug.exceptions import NotFound from constants.languages import supported_language @@ -75,7 +76,7 @@ class AppSite(Resource): def post(self, app_model): args = AppSiteUpdatePayload.model_validate(console_ns.payload or {}) current_user, _ = current_account_with_tenant() - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise NotFound @@ -124,7 +125,7 @@ class AppSiteAccessTokenReset(Resource): @marshal_with(app_site_model) def post(self, app_model): current_user, _ = current_account_with_tenant() - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise NotFound diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index e687d980fa..493022ffea 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -2,6 +2,8 @@ from collections.abc import Callable from functools import wraps from typing import ParamSpec, TypeVar, Union +from sqlalchemy import select + from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from libs.login import current_account_with_tenant @@ -15,16 +17,14 @@ R1 = TypeVar("R1") def _load_app_model(app_id: str) -> App | None: _, current_tenant_id = current_account_with_tenant() - app_model = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app_model = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) return app_model def _load_app_model_with_trial(app_id: str) -> App | None: - app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first() + app_model = db.session.scalar(select(App).where(App.id == app_id, App.status == "normal").limit(1)) return app_model diff --git a/api/tests/unit_tests/controllers/console/app/test_app_apis.py b/api/tests/unit_tests/controllers/console/app/test_app_apis.py index 60b8ee96fe..beb8ff55a5 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_apis.py @@ -281,12 +281,10 @@ class TestSiteEndpoints: method = _unwrap(api.post) site = MagicMock() - query = MagicMock() - query.where.return_value.first.return_value = site monkeypatch.setattr( site_module.db, "session", - MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None), + MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None), ) monkeypatch.setattr( site_module, @@ -305,12 +303,10 @@ class TestSiteEndpoints: method = _unwrap(api.post) site = MagicMock() - query = MagicMock() - query.where.return_value.first.return_value = site monkeypatch.setattr( site_module.db, "session", - MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None), + MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None), ) monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code") monkeypatch.setattr( diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py index 5db8e5c332..11b3b3470d 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py @@ -82,12 +82,8 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None: conversation = SimpleNamespace(id="c1", app_id="app-1") - query = MagicMock() - query.where.return_value = query - query.first.return_value = conversation - session = MagicMock() - session.query.return_value = query + session.scalar.return_value = conversation monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) monkeypatch.setattr(conversation_module.db, "session", session) @@ -101,12 +97,8 @@ def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> No def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None: - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - session = MagicMock() - session.query.return_value = query + session.scalar.return_value = None monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) monkeypatch.setattr(conversation_module.db, "session", session) diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py index 460da06ecc..f588ab261d 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py @@ -24,7 +24,7 @@ def test_get_conversation_mark_read_keeps_updated_at_unchanged(): ), patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session, ): - mock_session.query.return_value.where.return_value.first.return_value = conversation + mock_session.scalar.return_value = conversation _get_conversation(app_model, "conversation-id") diff --git a/api/tests/unit_tests/controllers/console/app/test_generator_api.py b/api/tests/unit_tests/controllers/console/app/test_generator_api.py index f83bc18da3..e64c508b82 100644 --- a/api/tests/unit_tests/controllers/console/app/test_generator_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_generator_api.py @@ -73,8 +73,7 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: None)) with app.test_request_context( "/console/api/instruction-generate", @@ -99,8 +98,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) _install_workflow_service(monkeypatch, workflow=None) with app.test_request_context( @@ -126,8 +124,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) workflow = SimpleNamespace(graph_dict={"nodes": []}) _install_workflow_service(monkeypatch, workflow=workflow) @@ -155,8 +152,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) workflow = SimpleNamespace( graph_dict={ diff --git a/api/tests/unit_tests/controllers/console/app/test_model_config_api.py b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py index 61d92bb5c7..a0e2edb8cf 100644 --- a/api/tests/unit_tests/controllers/console/app/test_model_config_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py @@ -92,10 +92,7 @@ def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatc ) session = MagicMock() - query = MagicMock() - query.where.return_value = query - query.first.return_value = original_config - session.query.return_value = query + session.get.return_value = original_config monkeypatch.setattr(model_config_module.db, "session", session) monkeypatch.setattr( diff --git a/api/tests/unit_tests/controllers/console/app/test_wraps.py b/api/tests/unit_tests/controllers/console/app/test_wraps.py index 7664e492da..b5f751f5a5 100644 --- a/api/tests/unit_tests/controllers/console/app/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/app/test_wraps.py @@ -11,10 +11,8 @@ from models.model import AppMode def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None: app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1")) - monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model)) @wraps_module.get_app_model def handler(app_model): @@ -25,10 +23,8 @@ def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None: def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None: app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1")) - monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model)) @wraps_module.get_app_model(mode=[AppMode.COMPLETION]) def handler(app_model): From 0492ed703457f186b5b7d29d4d8e813d088539c4 Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 23 Mar 2026 12:54:33 -0500 Subject: [PATCH 2/6] test: migrate api tools manage service tests to testcontainers (#33956) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../tools/test_api_tools_manage_service.py | 148 ++++ .../tools/test_api_tools_manage_service.py | 643 ------------------ 2 files changed, 148 insertions(+), 643 deletions(-) delete mode 100644 api/tests/unit_tests/services/tools/test_api_tools_manage_service.py diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index bffdca623a..d3e765055a 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -536,3 +536,151 @@ class TestApiToolManageService: # Verify mock interactions mock_external_service_dependencies["encrypter"].assert_called_once() mock_external_service_dependencies["provider_controller"].from_db.assert_called_once() + + def test_delete_api_tool_provider_success( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test successful deletion of an API tool provider.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + schema = self._create_test_openapi_schema() + provider_name = fake.unique.word() + + ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + icon={"content": "🔧", "background": "#FFF"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy="", + custom_disclaimer="", + labels=[], + ) + + provider = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) + .first() + ) + assert provider is not None + + result = ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, provider_name) + + assert result == {"result": "success"} + deleted = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) + .first() + ) + assert deleted is None + + def test_delete_api_tool_provider_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test deletion raises ValueError when provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="you have not added provider"): + ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent") + + def test_update_api_tool_provider_not_found( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test update raises ValueError when original provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="does not exists"): + ApiToolManageService.update_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name="new-name", + original_provider="nonexistent", + icon={}, + credentials={"auth_type": "none"}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema=self._create_test_openapi_schema(), + privacy_policy=None, + custom_disclaimer="", + labels=[], + ) + + def test_update_api_tool_provider_missing_auth_type( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test update raises ValueError when auth_type is missing from credentials.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + schema = self._create_test_openapi_schema() + provider_name = fake.unique.word() + + ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + icon={"content": "🔧", "background": "#FFF"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy="", + custom_disclaimer="", + labels=[], + ) + + with pytest.raises(ValueError, match="auth_type is required"): + ApiToolManageService.update_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + original_provider=provider_name, + icon={}, + credentials={}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy=None, + custom_disclaimer="", + labels=[], + ) + + def test_list_api_tool_provider_tools_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test listing tools raises ValueError when provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="you have not added provider"): + ApiToolManageService.list_api_tool_provider_tools(account.id, tenant.id, "nonexistent") + + def test_test_api_tool_preview_invalid_schema_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test preview raises ValueError for invalid schema type.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="invalid schema type"): + ApiToolManageService.test_api_tool_preview( + tenant_id=tenant.id, + provider_name="provider-a", + tool_name="tool-a", + credentials={"auth_type": "none"}, + parameters={}, + schema_type="bad-schema-type", + schema="schema", + ) diff --git a/api/tests/unit_tests/services/tools/test_api_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_api_tools_manage_service.py deleted file mode 100644 index ce44818886..0000000000 --- a/api/tests/unit_tests/services/tools/test_api_tools_manage_service.py +++ /dev/null @@ -1,643 +0,0 @@ -from __future__ import annotations - -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture - -from core.tools.entities.tool_entities import ApiProviderSchemaType -from services.tools.api_tools_manage_service import ApiToolManageService - - -@pytest.fixture -def mock_db(mocker: MockerFixture) -> MagicMock: - # Arrange - mocked_db = mocker.patch("services.tools.api_tools_manage_service.db") - mocked_db.session = MagicMock() - return mocked_db - - -def _tool_bundle(operation_id: str = "tool-1") -> SimpleNamespace: - return SimpleNamespace(operation_id=operation_id) - - -def test_parser_api_schema_should_return_schema_payload_when_schema_is_valid(mocker: MockerFixture) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI.value), - ) - - # Act - result = ApiToolManageService.parser_api_schema("valid-schema") - - # Assert - assert result["schema_type"] == ApiProviderSchemaType.OPENAPI.value - assert len(result["credentials_schema"]) == 3 - assert "warning" in result - - -def test_parser_api_schema_should_raise_value_error_when_parser_raises(mocker: MockerFixture) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - side_effect=RuntimeError("bad schema"), - ) - - # Act + Assert - with pytest.raises(ValueError, match="invalid schema: invalid schema: bad schema"): - ApiToolManageService.parser_api_schema("invalid") - - -def test_convert_schema_to_tool_bundles_should_return_tool_bundles_when_valid(mocker: MockerFixture) -> None: - # Arrange - expected = ([_tool_bundle("a"), _tool_bundle("b")], ApiProviderSchemaType.SWAGGER) - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=expected, - ) - extra_info: dict[str, str] = {} - - # Act - result = ApiToolManageService.convert_schema_to_tool_bundles("schema", extra_info=extra_info) - - # Assert - assert result == expected - - -def test_convert_schema_to_tool_bundles_should_raise_value_error_when_parser_fails(mocker: MockerFixture) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - side_effect=ValueError("parse failed"), - ) - - # Act + Assert - with pytest.raises(ValueError, match="invalid schema: parse failed"): - ApiToolManageService.convert_schema_to_tool_bundles("schema") - - -def test_create_api_tool_provider_should_raise_error_when_provider_already_exists( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = object() - - # Act + Assert - with pytest.raises(ValueError, match="provider provider-a already exists"): - ApiToolManageService.create_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name=" provider-a ", - icon={"emoji": "X"}, - credentials={"auth_type": "none"}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy="privacy", - custom_disclaimer="custom", - labels=[], - ) - - -def test_create_api_tool_provider_should_raise_error_when_tool_count_exceeds_limit( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - many_tools = [_tool_bundle(str(i)) for i in range(101)] - mocker.patch.object( - ApiToolManageService, - "convert_schema_to_tool_bundles", - return_value=(many_tools, ApiProviderSchemaType.OPENAPI), - ) - - # Act + Assert - with pytest.raises(ValueError, match="the number of apis should be less than 100"): - ApiToolManageService.create_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-a", - icon={"emoji": "X"}, - credentials={"auth_type": "none"}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy="privacy", - custom_disclaimer="custom", - labels=[], - ) - - -def test_create_api_tool_provider_should_raise_error_when_auth_type_is_missing( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - mocker.patch.object( - ApiToolManageService, - "convert_schema_to_tool_bundles", - return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), - ) - - # Act + Assert - with pytest.raises(ValueError, match="auth_type is required"): - ApiToolManageService.create_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-a", - icon={"emoji": "X"}, - credentials={}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy="privacy", - custom_disclaimer="custom", - labels=[], - ) - - -def test_create_api_tool_provider_should_create_provider_when_input_is_valid( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - mocker.patch.object( - ApiToolManageService, - "convert_schema_to_tool_bundles", - return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), - ) - mock_controller = MagicMock() - mocker.patch( - "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", - return_value=mock_controller, - ) - mock_encrypter = MagicMock() - mock_encrypter.encrypt.return_value = {"auth_type": "none"} - mocker.patch( - "services.tools.api_tools_manage_service.create_tool_provider_encrypter", - return_value=(mock_encrypter, MagicMock()), - ) - mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels") - - # Act - result = ApiToolManageService.create_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-a", - icon={"emoji": "X"}, - credentials={"auth_type": "none"}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy="privacy", - custom_disclaimer="custom", - labels=["news"], - ) - - # Assert - assert result == {"result": "success"} - mock_controller.load_bundled_tools.assert_called_once() - mock_db.session.add.assert_called_once() - mock_db.session.commit.assert_called_once() - - -def test_get_api_tool_provider_remote_schema_should_return_schema_when_response_is_valid( - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.get", - return_value=SimpleNamespace(status_code=200, text="schema-content"), - ) - mocker.patch.object(ApiToolManageService, "parser_api_schema", return_value={"ok": True}) - - # Act - result = ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema") - - # Assert - assert result == {"schema": "schema-content"} - - -@pytest.mark.parametrize("status_code", [400, 404, 500]) -def test_get_api_tool_provider_remote_schema_should_raise_error_when_remote_fetch_is_invalid( - status_code: int, - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.get", - return_value=SimpleNamespace(status_code=status_code, text="schema-content"), - ) - mock_logger = mocker.patch("services.tools.api_tools_manage_service.logger") - - # Act + Assert - with pytest.raises(ValueError, match="invalid schema, please check the url you provided"): - ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema") - mock_logger.exception.assert_called_once() - - -def test_list_api_tool_provider_tools_should_raise_error_when_provider_not_found( - mock_db: MagicMock, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="you have not added provider provider-a"): - ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a") - - -def test_list_api_tool_provider_tools_should_return_converted_tools_when_provider_exists( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = SimpleNamespace(tools=[_tool_bundle("tool-a"), _tool_bundle("tool-b")]) - mock_db.session.query.return_value.where.return_value.first.return_value = provider - controller = MagicMock() - mocker.patch( - "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller", - return_value=controller, - ) - mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["search"]) - mock_convert = mocker.patch( - "services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity", - side_effect=[{"name": "tool-a"}, {"name": "tool-b"}], - ) - - # Act - result = ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a") - - # Assert - assert result == [{"name": "tool-a"}, {"name": "tool-b"}] - assert mock_convert.call_count == 2 - - -def test_update_api_tool_provider_should_raise_error_when_original_provider_not_found( - mock_db: MagicMock, -) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="api provider provider-a does not exists"): - ApiToolManageService.update_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-a", - original_provider="provider-a", - icon={}, - credentials={"auth_type": "none"}, - _schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy=None, - custom_disclaimer="custom", - labels=[], - ) - - -def test_update_api_tool_provider_should_raise_error_when_auth_type_missing( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = SimpleNamespace(credentials={}, name="old") - mock_db.session.query.return_value.where.return_value.first.return_value = provider - mocker.patch.object( - ApiToolManageService, - "convert_schema_to_tool_bundles", - return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), - ) - - # Act + Assert - with pytest.raises(ValueError, match="auth_type is required"): - ApiToolManageService.update_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-a", - original_provider="provider-a", - icon={}, - credentials={}, - _schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy=None, - custom_disclaimer="custom", - labels=[], - ) - - -def test_update_api_tool_provider_should_update_provider_and_preserve_masked_credentials( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = SimpleNamespace( - credentials={"auth_type": "none", "api_key_value": "encrypted-old"}, - name="old", - icon="", - schema="", - description="", - schema_type_str="", - tools_str="", - privacy_policy="", - custom_disclaimer="", - credentials_str="", - ) - mock_db.session.query.return_value.where.return_value.first.return_value = provider - mocker.patch.object( - ApiToolManageService, - "convert_schema_to_tool_bundles", - return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), - ) - controller = MagicMock() - mocker.patch( - "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", - return_value=controller, - ) - cache = MagicMock() - encrypter = MagicMock() - encrypter.decrypt.return_value = {"auth_type": "none", "api_key_value": "plain-old"} - encrypter.mask_plugin_credentials.return_value = {"api_key_value": "***"} - encrypter.encrypt.return_value = {"auth_type": "none", "api_key_value": "encrypted-new"} - mocker.patch( - "services.tools.api_tools_manage_service.create_tool_provider_encrypter", - return_value=(encrypter, cache), - ) - mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels") - - # Act - result = ApiToolManageService.update_api_tool_provider( - user_id="user-1", - tenant_id="tenant-1", - provider_name="provider-new", - original_provider="provider-old", - icon={"emoji": "E"}, - credentials={"auth_type": "none", "api_key_value": "***"}, - _schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - privacy_policy="privacy", - custom_disclaimer="custom", - labels=["news"], - ) - - # Assert - assert result == {"result": "success"} - assert provider.name == "provider-new" - assert provider.privacy_policy == "privacy" - assert provider.credentials_str != "" - cache.delete.assert_called_once() - mock_db.session.commit.assert_called_once() - - -def test_delete_api_tool_provider_should_raise_error_when_provider_missing(mock_db: MagicMock) -> None: - # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="you have not added provider provider-a"): - ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a") - - -def test_delete_api_tool_provider_should_delete_provider_when_exists(mock_db: MagicMock) -> None: - # Arrange - provider = object() - mock_db.session.query.return_value.where.return_value.first.return_value = provider - - # Act - result = ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a") - - # Assert - assert result == {"result": "success"} - mock_db.session.delete.assert_called_once_with(provider) - mock_db.session.commit.assert_called_once() - - -def test_get_api_tool_provider_should_delegate_to_tool_manager(mocker: MockerFixture) -> None: - # Arrange - expected = {"provider": "value"} - mock_get = mocker.patch( - "services.tools.api_tools_manage_service.ToolManager.user_get_api_provider", - return_value=expected, - ) - - # Act - result = ApiToolManageService.get_api_tool_provider("user-1", "tenant-1", "provider-a") - - # Assert - assert result == expected - mock_get.assert_called_once_with(provider="provider-a", tenant_id="tenant-1") - - -def test_test_api_tool_preview_should_raise_error_for_invalid_schema_type() -> None: - # Arrange - schema_type = "bad-schema-type" - - # Act + Assert - with pytest.raises(ValueError, match="invalid schema type"): - ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-a", - credentials={"auth_type": "none"}, - parameters={}, - schema_type=schema_type, # type: ignore[arg-type] - schema="schema", - ) - - -def test_test_api_tool_preview_should_raise_error_when_schema_parser_fails(mocker: MockerFixture) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - side_effect=RuntimeError("invalid"), - ) - - # Act + Assert - with pytest.raises(ValueError, match="invalid schema"): - ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-a", - credentials={"auth_type": "none"}, - parameters={}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - ) - - -def test_test_api_tool_preview_should_raise_error_when_tool_name_is_invalid( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), - ) - mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id") - - # Act + Assert - with pytest.raises(ValueError, match="invalid tool name tool-b"): - ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-b", - credentials={"auth_type": "none"}, - parameters={}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - ) - - -def test_test_api_tool_preview_should_raise_error_when_auth_type_missing( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), - ) - mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id") - - # Act + Assert - with pytest.raises(ValueError, match="auth_type is required"): - ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-a", - credentials={}, - parameters={}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - ) - - -def test_test_api_tool_preview_should_return_error_payload_when_tool_validation_raises( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"}) - mock_db.session.query.return_value.where.return_value.first.return_value = db_provider - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), - ) - provider_controller = MagicMock() - tool_obj = MagicMock() - tool_obj.fork_tool_runtime.return_value = tool_obj - tool_obj.validate_credentials.side_effect = ValueError("validation failed") - provider_controller.get_tool.return_value = tool_obj - mocker.patch( - "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", - return_value=provider_controller, - ) - mock_encrypter = MagicMock() - mock_encrypter.decrypt.return_value = {"auth_type": "none"} - mock_encrypter.mask_plugin_credentials.return_value = {} - mocker.patch( - "services.tools.api_tools_manage_service.create_tool_provider_encrypter", - return_value=(mock_encrypter, MagicMock()), - ) - - # Act - result = ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-a", - credentials={"auth_type": "none"}, - parameters={}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - ) - - # Assert - assert result == {"error": "validation failed"} - - -def test_test_api_tool_preview_should_return_result_payload_when_validation_succeeds( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"}) - mock_db.session.query.return_value.where.return_value.first.return_value = db_provider - mocker.patch( - "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", - return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), - ) - provider_controller = MagicMock() - tool_obj = MagicMock() - tool_obj.fork_tool_runtime.return_value = tool_obj - tool_obj.validate_credentials.return_value = {"ok": True} - provider_controller.get_tool.return_value = tool_obj - mocker.patch( - "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", - return_value=provider_controller, - ) - mock_encrypter = MagicMock() - mock_encrypter.decrypt.return_value = {"auth_type": "none"} - mock_encrypter.mask_plugin_credentials.return_value = {} - mocker.patch( - "services.tools.api_tools_manage_service.create_tool_provider_encrypter", - return_value=(mock_encrypter, MagicMock()), - ) - - # Act - result = ApiToolManageService.test_api_tool_preview( - tenant_id="tenant-1", - provider_name="provider-a", - tool_name="tool-a", - credentials={"auth_type": "none"}, - parameters={"x": "1"}, - schema_type=ApiProviderSchemaType.OPENAPI, - schema="schema", - ) - - # Assert - assert result == {"result": {"ok": True}} - - -def test_list_api_tools_should_return_all_user_providers_with_converted_tools( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider_one = SimpleNamespace(name="p1") - provider_two = SimpleNamespace(name="p2") - mock_db.session.scalars.return_value.all.return_value = [provider_one, provider_two] - - controller_one = MagicMock() - controller_one.get_tools.return_value = ["tool-a"] - controller_two = MagicMock() - controller_two.get_tools.return_value = ["tool-b", "tool-c"] - - user_provider_one = SimpleNamespace(labels=[], tools=[]) - user_provider_two = SimpleNamespace(labels=[], tools=[]) - - mocker.patch( - "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller", - side_effect=[controller_one, controller_two], - ) - mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["news"]) - mocker.patch( - "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_user_provider", - side_effect=[user_provider_one, user_provider_two], - ) - mocker.patch("services.tools.api_tools_manage_service.ToolTransformService.repack_provider") - mock_convert = mocker.patch( - "services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity", - side_effect=[{"name": "tool-a"}, {"name": "tool-b"}, {"name": "tool-c"}], - ) - - # Act - result = ApiToolManageService.list_api_tools("tenant-1") - - # Assert - assert len(result) == 2 - assert user_provider_one.tools == [{"name": "tool-a"}] - assert user_provider_two.tools == [{"name": "tool-b"}, {"name": "tool-c"}] - assert mock_convert.call_count == 3 From f2c71f3668227097f8d3770d1c28020bf17d51a5 Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 23 Mar 2026 13:15:22 -0500 Subject: [PATCH 3/6] test: migrate oauth server service tests to testcontainers (#33958) --- .../services/test_oauth_server_service.py | 174 ++++++++++++++ .../services/test_oauth_server_service.py | 224 ------------------ 2 files changed, 174 insertions(+), 224 deletions(-) create mode 100644 api/tests/test_containers_integration_tests/services/test_oauth_server_service.py delete mode 100644 api/tests/unit_tests/services/test_oauth_server_service.py diff --git a/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py new file mode 100644 index 0000000000..c146a5924b --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py @@ -0,0 +1,174 @@ +"""Testcontainers integration tests for OAuthServerService.""" + +from __future__ import annotations + +import uuid +from typing import cast +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from werkzeug.exceptions import BadRequest + +from models.model import OAuthProviderApp +from services.oauth_server import ( + OAUTH_ACCESS_TOKEN_EXPIRES_IN, + OAUTH_ACCESS_TOKEN_REDIS_KEY, + OAUTH_AUTHORIZATION_CODE_REDIS_KEY, + OAUTH_REFRESH_TOKEN_EXPIRES_IN, + OAUTH_REFRESH_TOKEN_REDIS_KEY, + OAuthGrantType, + OAuthServerService, +) + + +class TestOAuthServerServiceGetProviderApp: + """DB-backed tests for get_oauth_provider_app.""" + + def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp: + app = OAuthProviderApp( + app_icon="icon.png", + client_id=client_id, + client_secret=str(uuid4()), + app_label={"en-US": "Test OAuth App"}, + redirect_uris=["https://example.com/callback"], + scope="read", + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + return app + + def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers): + client_id = f"client-{uuid4()}" + created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id) + + result = OAuthServerService.get_oauth_provider_app(client_id) + + assert result is not None + assert result.client_id == client_id + assert result.id == created.id + + def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers): + result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}") + + assert result is None + + +class TestOAuthServerServiceTokenOperations: + """Redis-backed tests for token sign/validate operations.""" + + @pytest.fixture + def mock_redis(self): + with patch("services.oauth_server.redis_client") as mock: + yield mock + + def test_sign_authorization_code_stores_and_returns_code(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1") + + assert code == str(deterministic_uuid) + mock_redis.set.assert_called_once_with( + OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=code), + "user-1", + ex=600, + ) + + def test_sign_access_token_raises_bad_request_for_invalid_code(self, mock_redis): + mock_redis.get.return_value = None + + with pytest.raises(BadRequest, match="invalid code"): + OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + code="bad-code", + client_id="client-1", + ) + + def test_sign_access_token_issues_tokens_for_valid_code(self, mock_redis): + token_uuids = [ + uuid.UUID("00000000-0000-0000-0000-000000000201"), + uuid.UUID("00000000-0000-0000-0000-000000000202"), + ] + with patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids): + mock_redis.get.return_value = b"user-1" + + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + code="code-1", + client_id="client-1", + ) + + assert access_token == str(token_uuids[0]) + assert refresh_token == str(token_uuids[1]) + code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1") + mock_redis.delete.assert_called_once_with(code_key) + mock_redis.set.assert_any_call( + OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token), + b"user-1", + ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN, + ) + mock_redis.set.assert_any_call( + OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token), + b"user-1", + ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, + ) + + def test_sign_access_token_raises_bad_request_for_invalid_refresh_token(self, mock_redis): + mock_redis.get.return_value = None + + with pytest.raises(BadRequest, match="invalid refresh token"): + OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.REFRESH_TOKEN, + refresh_token="stale-token", + client_id="client-1", + ) + + def test_sign_access_token_issues_new_token_for_valid_refresh(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + mock_redis.get.return_value = b"user-1" + + access_token, returned_refresh = OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.REFRESH_TOKEN, + refresh_token="refresh-1", + client_id="client-1", + ) + + assert access_token == str(deterministic_uuid) + assert returned_refresh == "refresh-1" + + def test_sign_access_token_returns_none_for_unknown_grant_type(self, mock_redis): + grant_type = cast(OAuthGrantType, "invalid-grant-type") + + result = OAuthServerService.sign_oauth_access_token(grant_type=grant_type, client_id="client-1") + + assert result is None + + def test_sign_refresh_token_stores_with_expected_expiry(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2") + + assert refresh_token == str(deterministic_uuid) + mock_redis.set.assert_called_once_with( + OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token), + "user-2", + ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, + ) + + def test_validate_access_token_returns_none_when_not_found(self, mock_redis): + mock_redis.get.return_value = None + + result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token") + + assert result is None + + def test_validate_access_token_loads_user_when_exists(self, mock_redis): + mock_redis.get.return_value = b"user-88" + expected_user = MagicMock() + + with patch("services.oauth_server.AccountService.load_user", return_value=expected_user) as mock_load: + result = OAuthServerService.validate_oauth_access_token("client-1", "access-token") + + assert result is expected_user + mock_load.assert_called_once_with("user-88") diff --git a/api/tests/unit_tests/services/test_oauth_server_service.py b/api/tests/unit_tests/services/test_oauth_server_service.py deleted file mode 100644 index 231ceb74dc..0000000000 --- a/api/tests/unit_tests/services/test_oauth_server_service.py +++ /dev/null @@ -1,224 +0,0 @@ -from __future__ import annotations - -import uuid -from types import SimpleNamespace -from typing import cast -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture -from werkzeug.exceptions import BadRequest - -from services.oauth_server import ( - OAUTH_ACCESS_TOKEN_EXPIRES_IN, - OAUTH_ACCESS_TOKEN_REDIS_KEY, - OAUTH_AUTHORIZATION_CODE_REDIS_KEY, - OAUTH_REFRESH_TOKEN_EXPIRES_IN, - OAUTH_REFRESH_TOKEN_REDIS_KEY, - OAuthGrantType, - OAuthServerService, -) - - -@pytest.fixture -def mock_redis_client(mocker: MockerFixture) -> MagicMock: - return mocker.patch("services.oauth_server.redis_client") - - -@pytest.fixture -def mock_session(mocker: MockerFixture) -> MagicMock: - """Mock the OAuth server Session context manager.""" - mocker.patch("services.oauth_server.db", SimpleNamespace(engine=object())) - session = MagicMock() - session_cm = MagicMock() - session_cm.__enter__.return_value = session - mocker.patch("services.oauth_server.Session", return_value=session_cm) - return session - - -def test_get_oauth_provider_app_should_return_app_when_record_exists(mock_session: MagicMock) -> None: - # Arrange - mock_execute_result = MagicMock() - expected_app = MagicMock() - mock_execute_result.scalar_one_or_none.return_value = expected_app - mock_session.execute.return_value = mock_execute_result - - # Act - result = OAuthServerService.get_oauth_provider_app("client-1") - - # Assert - assert result is expected_app - mock_session.execute.assert_called_once() - mock_execute_result.scalar_one_or_none.assert_called_once() - - -def test_sign_oauth_authorization_code_should_store_code_and_return_value( - mocker: MockerFixture, mock_redis_client: MagicMock -) -> None: - # Arrange - deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111") - mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid) - - # Act - code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1") - - # Assert - expected_code = str(deterministic_uuid) - assert code == expected_code - mock_redis_client.set.assert_called_once_with( - OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=expected_code), - "user-1", - ex=600, - ) - - -def test_sign_oauth_access_token_should_raise_bad_request_when_authorization_code_is_invalid( - mock_redis_client: MagicMock, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - - # Act + Assert - with pytest.raises(BadRequest, match="invalid code"): - OAuthServerService.sign_oauth_access_token( - grant_type=OAuthGrantType.AUTHORIZATION_CODE, - code="bad-code", - client_id="client-1", - ) - - -def test_sign_oauth_access_token_should_issue_access_and_refresh_token_when_authorization_code_is_valid( - mocker: MockerFixture, mock_redis_client: MagicMock -) -> None: - # Arrange - token_uuids = [ - uuid.UUID("00000000-0000-0000-0000-000000000201"), - uuid.UUID("00000000-0000-0000-0000-000000000202"), - ] - mocker.patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids) - mock_redis_client.get.return_value = b"user-1" - code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1") - - # Act - access_token, refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type=OAuthGrantType.AUTHORIZATION_CODE, - code="code-1", - client_id="client-1", - ) - - # Assert - assert access_token == str(token_uuids[0]) - assert refresh_token == str(token_uuids[1]) - mock_redis_client.delete.assert_called_once_with(code_key) - mock_redis_client.set.assert_any_call( - OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token), - b"user-1", - ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN, - ) - mock_redis_client.set.assert_any_call( - OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token), - b"user-1", - ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, - ) - - -def test_sign_oauth_access_token_should_raise_bad_request_when_refresh_token_is_invalid( - mock_redis_client: MagicMock, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - - # Act + Assert - with pytest.raises(BadRequest, match="invalid refresh token"): - OAuthServerService.sign_oauth_access_token( - grant_type=OAuthGrantType.REFRESH_TOKEN, - refresh_token="stale-token", - client_id="client-1", - ) - - -def test_sign_oauth_access_token_should_issue_new_access_token_when_refresh_token_is_valid( - mocker: MockerFixture, mock_redis_client: MagicMock -) -> None: - # Arrange - deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301") - mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid) - mock_redis_client.get.return_value = b"user-1" - - # Act - access_token, returned_refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type=OAuthGrantType.REFRESH_TOKEN, - refresh_token="refresh-1", - client_id="client-1", - ) - - # Assert - assert access_token == str(deterministic_uuid) - assert returned_refresh_token == "refresh-1" - mock_redis_client.set.assert_called_once_with( - OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token), - b"user-1", - ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN, - ) - - -def test_sign_oauth_access_token_with_unknown_grant_type_should_return_none() -> None: - # Arrange - grant_type = cast(OAuthGrantType, "invalid-grant-type") - - # Act - result = OAuthServerService.sign_oauth_access_token( - grant_type=grant_type, - client_id="client-1", - ) - - # Assert - assert result is None - - -def test_sign_oauth_refresh_token_should_store_token_with_expected_expiry( - mocker: MockerFixture, mock_redis_client: MagicMock -) -> None: - # Arrange - deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401") - mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid) - - # Act - refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2") - - # Assert - assert refresh_token == str(deterministic_uuid) - mock_redis_client.set.assert_called_once_with( - OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token), - "user-2", - ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, - ) - - -def test_validate_oauth_access_token_should_return_none_when_token_not_found( - mock_redis_client: MagicMock, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - - # Act - result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token") - - # Assert - assert result is None - - -def test_validate_oauth_access_token_should_load_user_when_token_exists( - mocker: MockerFixture, mock_redis_client: MagicMock -) -> None: - # Arrange - mock_redis_client.get.return_value = b"user-88" - expected_user = MagicMock() - mock_load_user = mocker.patch("services.oauth_server.AccountService.load_user", return_value=expected_user) - - # Act - result = OAuthServerService.validate_oauth_access_token("client-1", "access-token") - - # Assert - assert result is expected_user - mock_load_user.assert_called_once_with("user-88") From 5d2cb3cd803c0ef152c4c90a1e2752297538f175 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Mon, 23 Mar 2026 21:37:51 +0100 Subject: [PATCH 4/6] refactor: use EnumText for DocumentSegment.type (#33979) --- api/models/dataset.py | 5 ++++- api/models/enums.py | 7 +++++++ api/services/dataset_service.py | 5 +++-- api/tests/unit_tests/services/segment_service.py | 3 ++- 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/api/models/dataset.py b/api/models/dataset.py index d0163e6984..e3cbbf9cb9 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -43,6 +43,7 @@ from .enums import ( IndexingStatus, ProcessRuleMode, SegmentStatus, + SegmentType, SummaryStatus, ) from .model import App, Tag, TagBinding, UploadFile @@ -998,7 +999,9 @@ class ChildChunk(Base): # indexing fields index_node_id = mapped_column(String(255), nullable=True) index_node_hash = mapped_column(String(255), nullable=True) - type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'")) + type: Mapped[SegmentType] = mapped_column( + EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'") + ) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) diff --git a/api/models/enums.py b/api/models/enums.py index 8aca1df2b4..cdec7b2f12 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -222,6 +222,13 @@ class DatasetMetadataType(StrEnum): TIME = "time" +class SegmentType(StrEnum): + """Document segment type""" + + AUTOMATIC = "automatic" + CUSTOMIZED = "customized" + + class SegmentStatus(StrEnum): """Document segment status""" diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index cdab90a3dc..ba4ab6757f 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -58,6 +58,7 @@ from models.enums import ( IndexingStatus, ProcessRuleMode, SegmentStatus, + SegmentType, ) from models.model import UploadFile from models.provider_ids import ModelProviderID @@ -3786,7 +3787,7 @@ class SegmentService: child_chunk.word_count = len(child_chunk.content) child_chunk.updated_by = current_user.id child_chunk.updated_at = naive_utc_now() - child_chunk.type = "customized" + child_chunk.type = SegmentType.CUSTOMIZED update_child_chunks.append(child_chunk) else: new_child_chunks_args.append(child_chunk_update_args) @@ -3845,7 +3846,7 @@ class SegmentService: child_chunk.word_count = len(content) child_chunk.updated_by = current_user.id child_chunk.updated_at = naive_utc_now() - child_chunk.type = "customized" + child_chunk.type = SegmentType.CUSTOMIZED db.session.add(child_chunk) VectorService.update_child_chunk_vector([], [child_chunk], [], dataset) db.session.commit() diff --git a/api/tests/unit_tests/services/segment_service.py b/api/tests/unit_tests/services/segment_service.py index affbc8d0b5..cc2c0a8032 100644 --- a/api/tests/unit_tests/services/segment_service.py +++ b/api/tests/unit_tests/services/segment_service.py @@ -4,6 +4,7 @@ import pytest from models.account import Account from models.dataset import ChildChunk, Dataset, Document, DocumentSegment +from models.enums import SegmentType from services.dataset_service import SegmentService from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError @@ -77,7 +78,7 @@ class SegmentTestDataFactory: chunk.word_count = word_count chunk.index_node_id = f"node-{chunk_id}" chunk.index_node_hash = "hash-123" - chunk.type = "automatic" + chunk.type = SegmentType.AUTOMATIC chunk.created_by = "user-123" chunk.updated_by = None chunk.updated_at = None From cc17c8e883accea6a07a8be059b61da0fa4e2455 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Mon, 23 Mar 2026 21:38:29 +0100 Subject: [PATCH 5/6] refactor: use EnumText for TidbAuthBinding.status and MessageFile.type (#33975) --- .../vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py | 3 ++- .../rag/datasource/vdb/tidb_on_qdrant/tidb_service.py | 3 ++- api/models/dataset.py | 5 ++++- api/models/model.py | 4 ++-- api/schedule/create_tidb_serverless_task.py | 3 ++- api/schedule/update_tidb_serverless_status_task.py | 6 +++++- .../services/test_messages_clean_service.py | 3 ++- .../app/task_pipeline/test_easy_ui_message_end_files.py | 8 ++++---- 8 files changed, 23 insertions(+), 12 deletions(-) diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index 71b6fa0a9b..3c1d5e015f 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -33,6 +33,7 @@ from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset, TidbAuthBinding +from models.enums import TidbAuthBindingStatus if TYPE_CHECKING: from qdrant_client import grpc # noqa @@ -452,7 +453,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): password=new_cluster["password"], tenant_id=dataset.tenant_id, active=True, - status="ACTIVE", + status=TidbAuthBindingStatus.ACTIVE, ) db.session.add(new_tidb_auth_binding) db.session.commit() diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index 754c149241..06b17b9e62 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -9,6 +9,7 @@ from configs import dify_config from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus class TidbService: @@ -170,7 +171,7 @@ class TidbService: userPrefix = item["userPrefix"] if state == "ACTIVE" and len(userPrefix) > 0: cluster_info = tidb_serverless_list_map[item["clusterId"]] - cluster_info.status = "ACTIVE" + cluster_info.status = TidbAuthBindingStatus.ACTIVE cluster_info.account = f"{userPrefix}.root" db.session.add(cluster_info) db.session.commit() diff --git a/api/models/dataset.py b/api/models/dataset.py index e3cbbf9cb9..4c6152ed3f 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -45,6 +45,7 @@ from .enums import ( SegmentStatus, SegmentType, SummaryStatus, + TidbAuthBindingStatus, ) from .model import App, Tag, TagBinding, UploadFile from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index @@ -1242,7 +1243,9 @@ class TidbAuthBinding(TypeBase): cluster_id: Mapped[str] = mapped_column(String(255), nullable=False) cluster_name: Mapped[str] = mapped_column(String(255), nullable=False) active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) - status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'")) + status: Mapped[TidbAuthBindingStatus] = mapped_column( + EnumText(TidbAuthBindingStatus, length=255), nullable=False, server_default=sa.text("'CREATING'") + ) account: Mapped[str] = mapped_column(String(255), nullable=False) password: Mapped[str] = mapped_column(String(255), nullable=False) created_at: Mapped[datetime] = mapped_column( diff --git a/api/models/model.py b/api/models/model.py index 4541a3b23a..05233f8711 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -21,7 +21,7 @@ from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.tools.signature import sign_tool_file from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from dify_graph.file import helpers as file_helpers from extensions.storage.storage_type import StorageType from libs.helper import generate_string # type: ignore[import-not-found] @@ -1785,7 +1785,7 @@ class MessageFile(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(String(255), nullable=False) + type: Mapped[FileType] = mapped_column(EnumText(FileType, length=255), nullable=False) transfer_method: Mapped[FileTransferMethod] = mapped_column( EnumText(FileTransferMethod, length=255), nullable=False ) diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index 8b9d973d6d..6ceb3ef856 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -8,6 +8,7 @@ from configs import dify_config from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from extensions.ext_database import db from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus @app.celery.task(queue="dataset") @@ -57,7 +58,7 @@ def create_clusters(batch_size): account=new_cluster["account"], password=new_cluster["password"], active=False, - status="CREATING", + status=TidbAuthBindingStatus.CREATING, ) db.session.add(tidb_auth_binding) db.session.commit() diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index 1befa0e8b5..10003b1b97 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -9,6 +9,7 @@ from configs import dify_config from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from extensions.ext_database import db from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus @app.celery.task(queue="dataset") @@ -18,7 +19,10 @@ def update_tidb_serverless_status_task(): try: # check the number of idle tidb serverless tidb_serverless_list = db.session.scalars( - select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") + select(TidbAuthBinding).where( + TidbAuthBinding.active == False, + TidbAuthBinding.status == TidbAuthBindingStatus.CREATING, + ) ).all() if len(tidb_serverless_list) == 0: return diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index 8707f2e827..57bbc73b50 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -8,6 +8,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from dify_graph.file.enums import FileType from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -253,7 +254,7 @@ class TestMessagesCleanServiceIntegration: # MessageFile file = MessageFile( message_id=message.id, - type="image", + type=FileType.IMAGE, transfer_method="local_file", url="http://example.com/test.jpg", belongs_to=MessageFileBelongsTo.USER, diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py index 582990c88a..37dd116470 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py @@ -21,7 +21,7 @@ from sqlalchemy.orm import Session from core.app.entities.task_entities import MessageEndStreamResponse from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline -from dify_graph.file.enums import FileTransferMethod +from dify_graph.file.enums import FileTransferMethod, FileType from models.model import MessageFile, UploadFile @@ -51,7 +51,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.LOCAL_FILE message_file.upload_file_id = str(uuid.uuid4()) message_file.url = None - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture @@ -63,7 +63,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.REMOTE_URL message_file.upload_file_id = None message_file.url = "https://example.com/image.jpg" - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture @@ -75,7 +75,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.TOOL_FILE message_file.upload_file_id = None message_file.url = "tool_file_123.png" - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture From 49a1fae55561d8924df19f7085bef4a5a944264d Mon Sep 17 00:00:00 2001 From: Desel72 Date: Mon, 23 Mar 2026 16:04:34 -0500 Subject: [PATCH 6/6] test: migrate password reset tests to testcontainers (#33974) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../console/auth/test_password_reset.py | 109 +++--------------- 1 file changed, 17 insertions(+), 92 deletions(-) rename api/tests/{unit_tests => test_containers_integration_tests}/controllers/console/auth/test_password_reset.py (81%) diff --git a/api/tests/unit_tests/controllers/console/auth/test_password_reset.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py similarity index 81% rename from api/tests/unit_tests/controllers/console/auth/test_password_reset.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py index 9488cf528e..8f9db287e3 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py @@ -1,17 +1,10 @@ -""" -Test suite for password reset authentication flows. +"""Testcontainers integration tests for password reset authentication flows.""" -This module tests the password reset mechanism including: -- Password reset email sending -- Verification code validation -- Password reset with token -- Rate limiting and security checks -""" +from __future__ import annotations from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.error import ( EmailCodeError, @@ -28,31 +21,12 @@ from controllers.console.auth.forgot_password import ( from controllers.console.error import AccountNotFound, EmailSendIpLimitError -@pytest.fixture(autouse=True) -def _mock_forgot_password_session(): - with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls: - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session - mock_session_cls.return_value.__exit__.return_value = None - yield mock_session - - -@pytest.fixture(autouse=True) -def _mock_forgot_password_db(): - with patch("controllers.console.auth.forgot_password.db") as mock_db: - mock_db.engine = MagicMock() - yield mock_db - - class TestForgotPasswordSendEmailApi: """Test cases for sending password reset emails.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def mock_account(self): @@ -62,7 +36,6 @@ class TestForgotPasswordSendEmailApi: account.name = "Test User" return account - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @@ -73,20 +46,10 @@ class TestForgotPasswordSendEmailApi: mock_send_email, mock_get_account, mock_is_ip_limit, - mock_wraps_db, app, mock_account, ): - """ - Test successful password reset email sending. - - Verifies that: - - Email is sent to valid account - - Reset token is generated and returned - - IP rate limiting is checked - """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_account.return_value = mock_account mock_send_email.return_value = "reset_token_123" @@ -104,9 +67,8 @@ class TestForgotPasswordSendEmailApi: assert response["data"] == "reset_token_123" mock_send_email.assert_called_once() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") - def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, mock_db, app): + def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app): """ Test password reset email blocked by IP rate limit. @@ -115,7 +77,6 @@ class TestForgotPasswordSendEmailApi: - No email is sent when rate limited """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = True # Act & Assert @@ -133,7 +94,6 @@ class TestForgotPasswordSendEmailApi: (None, "en-US"), # Defaults to en-US when not provided ], ) - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @@ -144,7 +104,6 @@ class TestForgotPasswordSendEmailApi: mock_send_email, mock_get_account, mock_is_ip_limit, - mock_wraps_db, app, mock_account, language_input, @@ -158,7 +117,6 @@ class TestForgotPasswordSendEmailApi: - Unsupported languages default to en-US """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_account.return_value = mock_account mock_send_email.return_value = "token" @@ -180,13 +138,9 @@ class TestForgotPasswordCheckApi: """Test cases for verifying password reset codes.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @@ -199,7 +153,6 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_get_data, mock_is_rate_limit, - mock_db, app, ): """ @@ -212,7 +165,6 @@ class TestForgotPasswordCheckApi: - Rate limit is reset on success """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_generate_token.return_value = (None, "new_token") @@ -236,7 +188,6 @@ class TestForgotPasswordCheckApi: ) mock_reset_rate_limit.assert_called_once_with("test@example.com") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @@ -249,10 +200,8 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_get_data, mock_is_rate_limit, - mock_db, app, ): - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"} mock_generate_token.return_value = (None, "fresh-token") @@ -271,9 +220,8 @@ class TestForgotPasswordCheckApi: mock_revoke_token.assert_called_once_with("upper_token") mock_reset_rate_limit.assert_called_once_with("user@example.com") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") - def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app): + def test_verify_code_rate_limited(self, mock_is_rate_limit, app): """ Test code verification blocked by rate limit. @@ -282,7 +230,6 @@ class TestForgotPasswordCheckApi: - Prevents brute force attacks on verification codes """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = True # Act & Assert @@ -295,10 +242,9 @@ class TestForgotPasswordCheckApi: with pytest.raises(EmailPasswordResetLimitError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app): """ Test code verification with invalid token. @@ -306,7 +252,6 @@ class TestForgotPasswordCheckApi: - InvalidTokenError is raised for invalid/expired tokens """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = None @@ -320,10 +265,9 @@ class TestForgotPasswordCheckApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app): """ Test code verification with mismatched email. @@ -332,7 +276,6 @@ class TestForgotPasswordCheckApi: - Prevents token abuse """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "original@example.com", "code": "123456"} @@ -346,11 +289,10 @@ class TestForgotPasswordCheckApi: with pytest.raises(InvalidEmailError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit") - def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app): """ Test code verification with incorrect code. @@ -359,7 +301,6 @@ class TestForgotPasswordCheckApi: - Rate limit counter is incremented """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} @@ -380,11 +321,8 @@ class TestForgotPasswordResetApi: """Test cases for resetting password with verified token.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def mock_account(self): @@ -394,7 +332,6 @@ class TestForgotPasswordResetApi: account.name = "Test User" return account - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @@ -405,7 +342,6 @@ class TestForgotPasswordResetApi: mock_get_account, mock_revoke_token, mock_get_data, - mock_wraps_db, app, mock_account, ): @@ -418,7 +354,6 @@ class TestForgotPasswordResetApi: - Success response is returned """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} mock_get_account.return_value = mock_account mock_get_tenants.return_value = [MagicMock()] @@ -436,9 +371,8 @@ class TestForgotPasswordResetApi: assert response["result"] == "success" mock_revoke_token.assert_called_once_with("valid_token") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_mismatch(self, mock_get_data, mock_db, app): + def test_reset_password_mismatch(self, mock_get_data, app): """ Test password reset with mismatched passwords. @@ -447,7 +381,6 @@ class TestForgotPasswordResetApi: - No password update occurs """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} # Act & Assert @@ -460,9 +393,8 @@ class TestForgotPasswordResetApi: with pytest.raises(PasswordMismatchError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_invalid_token(self, mock_get_data, mock_db, app): + def test_reset_password_invalid_token(self, mock_get_data, app): """ Test password reset with invalid token. @@ -470,7 +402,6 @@ class TestForgotPasswordResetApi: - InvalidTokenError is raised for invalid/expired tokens """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = None # Act & Assert @@ -483,9 +414,8 @@ class TestForgotPasswordResetApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_wrong_phase(self, mock_get_data, mock_db, app): + def test_reset_password_wrong_phase(self, mock_get_data, app): """ Test password reset with token not in reset phase. @@ -494,7 +424,6 @@ class TestForgotPasswordResetApi: - Prevents use of verification-phase tokens for reset """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "verify"} # Act & Assert @@ -507,13 +436,10 @@ class TestForgotPasswordResetApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") - def test_reset_password_account_not_found( - self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app - ): + def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app): """ Test password reset for non-existent account. @@ -521,7 +447,6 @@ class TestForgotPasswordResetApi: - AccountNotFound is raised when account doesn't exist """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"} mock_get_account.return_value = None