From 1dfa9f41f304feb3438b3cadcc434302113c5033 Mon Sep 17 00:00:00 2001 From: WH-2099 Date: Tue, 24 Mar 2026 21:06:58 +0800 Subject: [PATCH] fix(api): restore compatibility imports for tests --- api/core/workflow/human_input_compat.py | 6 +--- api/dify_graph/constants.py | 2 ++ api/dify_graph/nodes/human_input/enums.py | 5 ++++ api/libs/login.py | 4 ++- .../test_model_load_balancing_service.py | 7 ----- api/tests/unit_tests/libs/test_login.py | 30 ++++++++++++++----- 6 files changed, 33 insertions(+), 21 deletions(-) create mode 100644 api/dify_graph/constants.py diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_compat.py index 41b24635d9..498a537d03 100644 --- a/api/core/workflow/human_input_compat.py +++ b/api/core/workflow/human_input_compat.py @@ -19,15 +19,11 @@ from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter from dify_graph.enums import BuiltinNodeTypes from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from dify_graph.nodes.human_input.enums import DeliveryMethodType from dify_graph.runtime import VariablePool from dify_graph.variables.consts import SELECTORS_LENGTH -class DeliveryMethodType(enum.StrEnum): - WEBAPP = enum.auto() - EMAIL = enum.auto() - - class EmailRecipientType(enum.StrEnum): BOUND = "member" MEMBER = BOUND diff --git a/api/dify_graph/constants.py b/api/dify_graph/constants.py new file mode 100644 index 0000000000..5b2e895161 --- /dev/null +++ b/api/dify_graph/constants.py @@ -0,0 +1,2 @@ +CONVERSATION_VARIABLE_NODE_ID = "conversation" +ENVIRONMENT_VARIABLE_NODE_ID = "env" diff --git a/api/dify_graph/nodes/human_input/enums.py b/api/dify_graph/nodes/human_input/enums.py index 3fb0ab4499..5964d64dd7 100644 --- a/api/dify_graph/nodes/human_input/enums.py +++ b/api/dify_graph/nodes/human_input/enums.py @@ -25,6 +25,11 @@ class HumanInputFormKind(enum.StrEnum): DELIVERY_TEST = enum.auto() # Form created for delivery tests. +class DeliveryMethodType(enum.StrEnum): + WEBAPP = enum.auto() + EMAIL = enum.auto() + + class ButtonStyle(enum.StrEnum): """Button styles for user actions.""" diff --git a/api/libs/login.py b/api/libs/login.py index 10447d9ab7..dce332b01d 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections.abc import Callable from functools import wraps -from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar +from typing import TYPE_CHECKING, Any from flask import current_app, g, has_request_context, request from flask_login.config import EXEMPT_METHODS @@ -42,6 +42,8 @@ def current_account_with_tenant(): return user, user.current_tenant_id +from typing import ParamSpec, TypeVar + P = ParamSpec("P") R = TypeVar("R") diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py index 38a93e5391..ca6e7afeab 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -22,9 +22,6 @@ class TestModelLoadBalancingService: "services.model_load_balancing_service.create_plugin_provider_manager", autospec=True ) as mock_provider_manager, patch("services.model_load_balancing_service.LBModelManager", autospec=True) as mock_lb_model_manager, - patch( - "services.model_load_balancing_service.create_plugin_model_provider_factory", autospec=True - ) as mock_model_provider_factory, patch("services.model_load_balancing_service.encrypter", autospec=True) as mock_encrypter, ): # Setup default mock returns @@ -48,9 +45,6 @@ class TestModelLoadBalancingService: # Mock LBModelManager mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0) - # Mock ModelProviderFactory - mock_model_provider_factory_instance = mock_model_provider_factory.return_value - # Mock credential schemas mock_credential_schema = MagicMock() mock_credential_schema.credential_form_schemas = [] @@ -63,7 +57,6 @@ class TestModelLoadBalancingService: yield { "provider_manager": mock_provider_manager, "lb_model_manager": mock_lb_model_manager, - "model_provider_factory": mock_model_provider_factory, "encrypter": mock_encrypter, "provider_config": mock_provider_config, "provider_model_setting": mock_provider_model_setting, diff --git a/api/tests/unit_tests/libs/test_login.py b/api/tests/unit_tests/libs/test_login.py index 64c2d9af87..1ec8c82b32 100644 --- a/api/tests/unit_tests/libs/test_login.py +++ b/api/tests/unit_tests/libs/test_login.py @@ -52,6 +52,20 @@ def ensure_sync_spy(login_app: Flask, mocker: MockerFixture) -> MagicMock: return mocker.patch.object(login_app, "ensure_sync", side_effect=_ensure_sync) +def _assert_ensure_sync_called_once_with_view(ensure_sync_spy: MagicMock) -> None: + ensure_sync_spy.assert_called_once() + called_view = ensure_sync_spy.call_args.args[0] + assert callable(called_view) + assert called_view.__name__ == "protected_view" + + +def _patch_current_user(mocker: MockerFixture, resolved_user: MockUser | Account | None) -> MagicMock: + current_user_proxy = MagicMock() + current_user_proxy._get_current_object.return_value = resolved_user + mocker.patch.object(login_module, "current_user", new=current_user_proxy) + return current_user_proxy + + class TestLoginRequired: """Test cases for login_required decorator.""" @@ -65,7 +79,7 @@ class TestLoginRequired: return "Protected content" mock_user = MockUser("test_user", is_authenticated=True) - get_user = mocker.patch.object(login_module, "_get_user", return_value=mock_user) + current_user_proxy = _patch_current_user(mocker, mock_user) with login_app.test_request_context(): result = protected_view() @@ -74,8 +88,8 @@ class TestLoginRequired: assert csrf_check.call_args.args[1] == "test_user" assert result == "Protected content" - get_user.assert_called_once_with() - ensure_sync_spy.assert_called_once_with(protected_view.__wrapped__) + current_user_proxy._get_current_object.assert_called_once_with() + _assert_ensure_sync_called_once_with_view(ensure_sync_spy) login_app.login_manager.unauthorized.assert_not_called() @pytest.mark.parametrize( @@ -100,13 +114,13 @@ class TestLoginRequired: def protected_view(): return "Protected content" - get_user = mocker.patch.object(login_module, "_get_user", return_value=resolved_user) + current_user_proxy = _patch_current_user(mocker, resolved_user) with login_app.test_request_context(): result = protected_view() assert result == "Unauthorized", description - get_user.assert_called_once_with() + current_user_proxy._get_current_object.assert_called_once_with() login_app.login_manager.unauthorized.assert_called_once_with() csrf_check.assert_not_called() ensure_sync_spy.assert_not_called() @@ -134,14 +148,14 @@ class TestLoginRequired: def protected_view(): return "Protected content" - get_user = mocker.patch.object(login_module, "_get_user") + current_user_proxy = _patch_current_user(mocker, MockUser("test_user")) monkeypatch.setattr(login_module.dify_config, "LOGIN_DISABLED", login_disabled) with login_app.test_request_context(method=method): result = protected_view() assert result == "Protected content" - get_user.assert_not_called() - ensure_sync_spy.assert_called_once_with(protected_view.__wrapped__) + current_user_proxy._get_current_object.assert_not_called() + _assert_ensure_sync_called_once_with_view(ensure_sync_spy) csrf_check.assert_not_called() login_app.login_manager.unauthorized.assert_not_called()