test(api): harden login decorator tests

This commit is contained in:
WH-2099 2026-03-24 22:30:46 +08:00
parent 5f15d46ded
commit 63da05c35a
No known key found for this signature in database
1 changed files with 12 additions and 34 deletions

View File

@ -39,38 +39,21 @@ def login_app() -> Flask:
return app
@pytest.fixture(autouse=True)
def reset_login_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(login_module.dify_config, "LOGIN_DISABLED", False)
@pytest.fixture
def csrf_check(mocker: MockerFixture) -> MagicMock:
return mocker.patch.object(login_module, "check_csrf_token")
@pytest.fixture
def ensure_sync_spy(login_app: Flask, mocker: MockerFixture) -> MagicMock:
def _ensure_sync(func):
return lambda *args, **kwargs: func(*args, **kwargs)
return mocker.patch.object(login_app, "ensure_sync", side_effect=_ensure_sync)
def _assert_ensure_sync_called_once_with_view(ensure_sync_spy: MagicMock) -> None:
ensure_sync_spy.assert_called_once()
called_view = ensure_sync_spy.call_args.args[0]
assert callable(called_view)
assert called_view.__name__ == "protected_view"
def _patch_current_user(mocker: MockerFixture, resolved_user: MockUser | Account | None) -> MagicMock:
current_user_proxy = MagicMock()
current_user_proxy._get_current_object.return_value = resolved_user
mocker.patch.object(login_module, "current_user", new=current_user_proxy)
return current_user_proxy
class TestLoginRequired:
"""Test cases for login_required decorator."""
def test_authenticated_user_can_access_protected_view(
self, login_app: Flask, csrf_check: MagicMock, ensure_sync_spy: MagicMock, mocker: MockerFixture
self, login_app: Flask, csrf_check: MagicMock, mocker: MockerFixture
):
"""Test that authenticated users can access protected views."""
@ -79,7 +62,7 @@ class TestLoginRequired:
return "Protected content"
mock_user = MockUser("test_user", is_authenticated=True)
current_user_proxy = _patch_current_user(mocker, mock_user)
resolve_user = mocker.patch.object(login_module, "_resolve_current_user", return_value=mock_user)
with login_app.test_request_context():
result = protected_view()
@ -88,8 +71,7 @@ class TestLoginRequired:
assert csrf_check.call_args.args[1] == "test_user"
assert result == "Protected content"
current_user_proxy._get_current_object.assert_called_once_with()
_assert_ensure_sync_called_once_with_view(ensure_sync_spy)
resolve_user.assert_called_once_with()
login_app.login_manager.unauthorized.assert_not_called()
@pytest.mark.parametrize(
@ -103,7 +85,6 @@ class TestLoginRequired:
self,
login_app: Flask,
csrf_check: MagicMock,
ensure_sync_spy: MagicMock,
mocker: MockerFixture,
resolved_user: MockUser | None,
description: str,
@ -114,16 +95,15 @@ class TestLoginRequired:
def protected_view():
return "Protected content"
current_user_proxy = _patch_current_user(mocker, resolved_user)
resolve_user = mocker.patch.object(login_module, "_resolve_current_user", return_value=resolved_user)
with login_app.test_request_context():
result = protected_view()
assert result == "Unauthorized", description
current_user_proxy._get_current_object.assert_called_once_with()
resolve_user.assert_called_once_with()
login_app.login_manager.unauthorized.assert_called_once_with()
csrf_check.assert_not_called()
ensure_sync_spy.assert_not_called()
@pytest.mark.parametrize(
("method", "login_disabled"),
@ -136,7 +116,6 @@ class TestLoginRequired:
self,
login_app: Flask,
csrf_check: MagicMock,
ensure_sync_spy: MagicMock,
monkeypatch: pytest.MonkeyPatch,
mocker: MockerFixture,
method: str,
@ -148,14 +127,13 @@ class TestLoginRequired:
def protected_view():
return "Protected content"
current_user_proxy = _patch_current_user(mocker, MockUser("test_user"))
resolve_user = mocker.patch.object(login_module, "_resolve_current_user")
monkeypatch.setattr(login_module.dify_config, "LOGIN_DISABLED", login_disabled)
with login_app.test_request_context(method=method):
result = protected_view()
assert result == "Protected content"
current_user_proxy._get_current_object.assert_not_called()
_assert_ensure_sync_called_once_with_view(ensure_sync_spy)
resolve_user.assert_not_called()
csrf_check.assert_not_called()
login_app.login_manager.unauthorized.assert_not_called()