mirror of https://github.com/langgenius/dify.git
test(api): harden login decorator tests
This commit is contained in:
parent
5f15d46ded
commit
63da05c35a
|
|
@ -39,38 +39,21 @@ def login_app() -> Flask:
|
|||
return app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_login_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(login_module.dify_config, "LOGIN_DISABLED", False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def csrf_check(mocker: MockerFixture) -> MagicMock:
|
||||
return mocker.patch.object(login_module, "check_csrf_token")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ensure_sync_spy(login_app: Flask, mocker: MockerFixture) -> MagicMock:
|
||||
def _ensure_sync(func):
|
||||
return lambda *args, **kwargs: func(*args, **kwargs)
|
||||
|
||||
return mocker.patch.object(login_app, "ensure_sync", side_effect=_ensure_sync)
|
||||
|
||||
|
||||
def _assert_ensure_sync_called_once_with_view(ensure_sync_spy: MagicMock) -> None:
|
||||
ensure_sync_spy.assert_called_once()
|
||||
called_view = ensure_sync_spy.call_args.args[0]
|
||||
assert callable(called_view)
|
||||
assert called_view.__name__ == "protected_view"
|
||||
|
||||
|
||||
def _patch_current_user(mocker: MockerFixture, resolved_user: MockUser | Account | None) -> MagicMock:
|
||||
current_user_proxy = MagicMock()
|
||||
current_user_proxy._get_current_object.return_value = resolved_user
|
||||
mocker.patch.object(login_module, "current_user", new=current_user_proxy)
|
||||
return current_user_proxy
|
||||
|
||||
|
||||
class TestLoginRequired:
|
||||
"""Test cases for login_required decorator."""
|
||||
|
||||
def test_authenticated_user_can_access_protected_view(
|
||||
self, login_app: Flask, csrf_check: MagicMock, ensure_sync_spy: MagicMock, mocker: MockerFixture
|
||||
self, login_app: Flask, csrf_check: MagicMock, mocker: MockerFixture
|
||||
):
|
||||
"""Test that authenticated users can access protected views."""
|
||||
|
||||
|
|
@ -79,7 +62,7 @@ class TestLoginRequired:
|
|||
return "Protected content"
|
||||
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
current_user_proxy = _patch_current_user(mocker, mock_user)
|
||||
resolve_user = mocker.patch.object(login_module, "_resolve_current_user", return_value=mock_user)
|
||||
|
||||
with login_app.test_request_context():
|
||||
result = protected_view()
|
||||
|
|
@ -88,8 +71,7 @@ class TestLoginRequired:
|
|||
assert csrf_check.call_args.args[1] == "test_user"
|
||||
|
||||
assert result == "Protected content"
|
||||
current_user_proxy._get_current_object.assert_called_once_with()
|
||||
_assert_ensure_sync_called_once_with_view(ensure_sync_spy)
|
||||
resolve_user.assert_called_once_with()
|
||||
login_app.login_manager.unauthorized.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
@ -103,7 +85,6 @@ class TestLoginRequired:
|
|||
self,
|
||||
login_app: Flask,
|
||||
csrf_check: MagicMock,
|
||||
ensure_sync_spy: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
resolved_user: MockUser | None,
|
||||
description: str,
|
||||
|
|
@ -114,16 +95,15 @@ class TestLoginRequired:
|
|||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
current_user_proxy = _patch_current_user(mocker, resolved_user)
|
||||
resolve_user = mocker.patch.object(login_module, "_resolve_current_user", return_value=resolved_user)
|
||||
|
||||
with login_app.test_request_context():
|
||||
result = protected_view()
|
||||
|
||||
assert result == "Unauthorized", description
|
||||
current_user_proxy._get_current_object.assert_called_once_with()
|
||||
resolve_user.assert_called_once_with()
|
||||
login_app.login_manager.unauthorized.assert_called_once_with()
|
||||
csrf_check.assert_not_called()
|
||||
ensure_sync_spy.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method", "login_disabled"),
|
||||
|
|
@ -136,7 +116,6 @@ class TestLoginRequired:
|
|||
self,
|
||||
login_app: Flask,
|
||||
csrf_check: MagicMock,
|
||||
ensure_sync_spy: MagicMock,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mocker: MockerFixture,
|
||||
method: str,
|
||||
|
|
@ -148,14 +127,13 @@ class TestLoginRequired:
|
|||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
current_user_proxy = _patch_current_user(mocker, MockUser("test_user"))
|
||||
resolve_user = mocker.patch.object(login_module, "_resolve_current_user")
|
||||
monkeypatch.setattr(login_module.dify_config, "LOGIN_DISABLED", login_disabled)
|
||||
|
||||
with login_app.test_request_context(method=method):
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
current_user_proxy._get_current_object.assert_not_called()
|
||||
_assert_ensure_sync_called_once_with_view(ensure_sync_spy)
|
||||
resolve_user.assert_not_called()
|
||||
csrf_check.assert_not_called()
|
||||
login_app.login_manager.unauthorized.assert_not_called()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue