fix(api): restore compatibility imports for tests

This commit is contained in:
WH-2099 2026-03-24 21:06:58 +08:00
parent 78da2d3131
commit 1dfa9f41f3
No known key found for this signature in database
6 changed files with 33 additions and 21 deletions

View File

@ -19,15 +19,11 @@ from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.nodes.human_input.enums import DeliveryMethodType
from dify_graph.runtime import VariablePool
from dify_graph.variables.consts import SELECTORS_LENGTH
class DeliveryMethodType(enum.StrEnum):
WEBAPP = enum.auto()
EMAIL = enum.auto()
class EmailRecipientType(enum.StrEnum):
BOUND = "member"
MEMBER = BOUND

View File

@ -0,0 +1,2 @@
CONVERSATION_VARIABLE_NODE_ID = "conversation"
ENVIRONMENT_VARIABLE_NODE_ID = "env"

View File

@ -25,6 +25,11 @@ class HumanInputFormKind(enum.StrEnum):
DELIVERY_TEST = enum.auto() # Form created for delivery tests.
class DeliveryMethodType(enum.StrEnum):
WEBAPP = enum.auto()
EMAIL = enum.auto()
class ButtonStyle(enum.StrEnum):
"""Button styles for user actions."""

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
from typing import TYPE_CHECKING, Any
from flask import current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS
@ -42,6 +42,8 @@ def current_account_with_tenant():
return user, user.current_tenant_id
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")

View File

@ -22,9 +22,6 @@ class TestModelLoadBalancingService:
"services.model_load_balancing_service.create_plugin_provider_manager", autospec=True
) as mock_provider_manager,
patch("services.model_load_balancing_service.LBModelManager", autospec=True) as mock_lb_model_manager,
patch(
"services.model_load_balancing_service.create_plugin_model_provider_factory", autospec=True
) as mock_model_provider_factory,
patch("services.model_load_balancing_service.encrypter", autospec=True) as mock_encrypter,
):
# Setup default mock returns
@ -48,9 +45,6 @@ class TestModelLoadBalancingService:
# Mock LBModelManager
mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0)
# Mock ModelProviderFactory
mock_model_provider_factory_instance = mock_model_provider_factory.return_value
# Mock credential schemas
mock_credential_schema = MagicMock()
mock_credential_schema.credential_form_schemas = []
@ -63,7 +57,6 @@ class TestModelLoadBalancingService:
yield {
"provider_manager": mock_provider_manager,
"lb_model_manager": mock_lb_model_manager,
"model_provider_factory": mock_model_provider_factory,
"encrypter": mock_encrypter,
"provider_config": mock_provider_config,
"provider_model_setting": mock_provider_model_setting,

View File

@ -52,6 +52,20 @@ def ensure_sync_spy(login_app: Flask, mocker: MockerFixture) -> MagicMock:
return mocker.patch.object(login_app, "ensure_sync", side_effect=_ensure_sync)
def _assert_ensure_sync_called_once_with_view(ensure_sync_spy: MagicMock) -> None:
ensure_sync_spy.assert_called_once()
called_view = ensure_sync_spy.call_args.args[0]
assert callable(called_view)
assert called_view.__name__ == "protected_view"
def _patch_current_user(mocker: MockerFixture, resolved_user: MockUser | Account | None) -> MagicMock:
current_user_proxy = MagicMock()
current_user_proxy._get_current_object.return_value = resolved_user
mocker.patch.object(login_module, "current_user", new=current_user_proxy)
return current_user_proxy
class TestLoginRequired:
"""Test cases for login_required decorator."""
@ -65,7 +79,7 @@ class TestLoginRequired:
return "Protected content"
mock_user = MockUser("test_user", is_authenticated=True)
get_user = mocker.patch.object(login_module, "_get_user", return_value=mock_user)
current_user_proxy = _patch_current_user(mocker, mock_user)
with login_app.test_request_context():
result = protected_view()
@ -74,8 +88,8 @@ class TestLoginRequired:
assert csrf_check.call_args.args[1] == "test_user"
assert result == "Protected content"
get_user.assert_called_once_with()
ensure_sync_spy.assert_called_once_with(protected_view.__wrapped__)
current_user_proxy._get_current_object.assert_called_once_with()
_assert_ensure_sync_called_once_with_view(ensure_sync_spy)
login_app.login_manager.unauthorized.assert_not_called()
@pytest.mark.parametrize(
@ -100,13 +114,13 @@ class TestLoginRequired:
def protected_view():
return "Protected content"
get_user = mocker.patch.object(login_module, "_get_user", return_value=resolved_user)
current_user_proxy = _patch_current_user(mocker, resolved_user)
with login_app.test_request_context():
result = protected_view()
assert result == "Unauthorized", description
get_user.assert_called_once_with()
current_user_proxy._get_current_object.assert_called_once_with()
login_app.login_manager.unauthorized.assert_called_once_with()
csrf_check.assert_not_called()
ensure_sync_spy.assert_not_called()
@ -134,14 +148,14 @@ class TestLoginRequired:
def protected_view():
return "Protected content"
get_user = mocker.patch.object(login_module, "_get_user")
current_user_proxy = _patch_current_user(mocker, MockUser("test_user"))
monkeypatch.setattr(login_module.dify_config, "LOGIN_DISABLED", login_disabled)
with login_app.test_request_context(method=method):
result = protected_view()
assert result == "Protected content"
get_user.assert_not_called()
ensure_sync_spy.assert_called_once_with(protected_view.__wrapped__)
current_user_proxy._get_current_object.assert_not_called()
_assert_ensure_sync_called_once_with_view(ensure_sync_spy)
csrf_check.assert_not_called()
login_app.login_manager.unauthorized.assert_not_called()