diff --git a/api/tests/unit_tests/libs/test_login.py b/api/tests/unit_tests/libs/test_login.py index 1ec8c82b32..4d71be35e6 100644 --- a/api/tests/unit_tests/libs/test_login.py +++ b/api/tests/unit_tests/libs/test_login.py @@ -39,38 +39,21 @@ def login_app() -> Flask: return app +@pytest.fixture(autouse=True) +def reset_login_disabled(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(login_module.dify_config, "LOGIN_DISABLED", False) + + @pytest.fixture def csrf_check(mocker: MockerFixture) -> MagicMock: return mocker.patch.object(login_module, "check_csrf_token") -@pytest.fixture -def ensure_sync_spy(login_app: Flask, mocker: MockerFixture) -> MagicMock: - def _ensure_sync(func): - return lambda *args, **kwargs: func(*args, **kwargs) - - return mocker.patch.object(login_app, "ensure_sync", side_effect=_ensure_sync) - - -def _assert_ensure_sync_called_once_with_view(ensure_sync_spy: MagicMock) -> None: - ensure_sync_spy.assert_called_once() - called_view = ensure_sync_spy.call_args.args[0] - assert callable(called_view) - assert called_view.__name__ == "protected_view" - - -def _patch_current_user(mocker: MockerFixture, resolved_user: MockUser | Account | None) -> MagicMock: - current_user_proxy = MagicMock() - current_user_proxy._get_current_object.return_value = resolved_user - mocker.patch.object(login_module, "current_user", new=current_user_proxy) - return current_user_proxy - - class TestLoginRequired: """Test cases for login_required decorator.""" def test_authenticated_user_can_access_protected_view( - self, login_app: Flask, csrf_check: MagicMock, ensure_sync_spy: MagicMock, mocker: MockerFixture + self, login_app: Flask, csrf_check: MagicMock, mocker: MockerFixture ): """Test that authenticated users can access protected views.""" @@ -79,7 +62,7 @@ class TestLoginRequired: return "Protected content" mock_user = MockUser("test_user", is_authenticated=True) - current_user_proxy = _patch_current_user(mocker, mock_user) + resolve_user = mocker.patch.object(login_module, "_resolve_current_user", return_value=mock_user) with login_app.test_request_context(): result = protected_view() @@ -88,8 +71,7 @@ class TestLoginRequired: assert csrf_check.call_args.args[1] == "test_user" assert result == "Protected content" - current_user_proxy._get_current_object.assert_called_once_with() - _assert_ensure_sync_called_once_with_view(ensure_sync_spy) + resolve_user.assert_called_once_with() login_app.login_manager.unauthorized.assert_not_called() @pytest.mark.parametrize( @@ -103,7 +85,6 @@ class TestLoginRequired: self, login_app: Flask, csrf_check: MagicMock, - ensure_sync_spy: MagicMock, mocker: MockerFixture, resolved_user: MockUser | None, description: str, @@ -114,16 +95,15 @@ class TestLoginRequired: def protected_view(): return "Protected content" - current_user_proxy = _patch_current_user(mocker, resolved_user) + resolve_user = mocker.patch.object(login_module, "_resolve_current_user", return_value=resolved_user) with login_app.test_request_context(): result = protected_view() assert result == "Unauthorized", description - current_user_proxy._get_current_object.assert_called_once_with() + resolve_user.assert_called_once_with() login_app.login_manager.unauthorized.assert_called_once_with() csrf_check.assert_not_called() - ensure_sync_spy.assert_not_called() @pytest.mark.parametrize( ("method", "login_disabled"), @@ -136,7 +116,6 @@ class TestLoginRequired: self, login_app: Flask, csrf_check: MagicMock, - ensure_sync_spy: MagicMock, monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, method: str, @@ -148,14 +127,13 @@ class TestLoginRequired: def protected_view(): return "Protected content" - current_user_proxy = _patch_current_user(mocker, MockUser("test_user")) + resolve_user = mocker.patch.object(login_module, "_resolve_current_user") monkeypatch.setattr(login_module.dify_config, "LOGIN_DISABLED", login_disabled) with login_app.test_request_context(method=method): result = protected_view() assert result == "Protected content" - current_user_proxy._get_current_object.assert_not_called() - _assert_ensure_sync_called_once_with_view(ensure_sync_spy) + resolve_user.assert_not_called() csrf_check.assert_not_called() login_app.login_manager.unauthorized.assert_not_called()