mirror of https://github.com/langgenius/dify.git
233 lines
8.1 KiB
Python
233 lines
8.1 KiB
Python
from types import SimpleNamespace
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
from flask import Flask, g
|
|
from flask_login import LoginManager, UserMixin
|
|
from pytest_mock import MockerFixture
|
|
|
|
import libs.login as login_module
|
|
from libs.login import current_user
|
|
from models.account import Account
|
|
|
|
|
|
class MockUser(UserMixin):
    """Lightweight ``UserMixin`` stand-in with a configurable authentication flag."""

    def __init__(self, id: str, is_authenticated: bool = True):
        # Stored privately; exposed through the ``is_authenticated`` property below.
        self._is_authenticated = is_authenticated
        self.id = id

    @property
    def is_authenticated(self) -> bool:
        """Return the authentication flag supplied at construction time."""
        return self._is_authenticated
|
|
|
|
|
|
@pytest.fixture
def login_app() -> Flask:
    """Create a testing Flask app with a ``LoginManager`` whose ``unauthorized`` handler is mocked."""
    flask_app = Flask(__name__)
    flask_app.config["TESTING"] = True

    manager = LoginManager()
    manager.init_app(flask_app)
    # Swap in a mock so tests can assert on the handler and on the value it returns.
    manager.unauthorized = MagicMock(return_value="Unauthorized")

    @manager.user_loader
    def load_user(_user_id: str):
        # The registered loader never resolves a user.
        return None

    return flask_app
|
|
|
|
|
|
@pytest.fixture
def csrf_check(mocker: MockerFixture) -> MagicMock:
    """Patch ``check_csrf_token`` on the login module and return the resulting mock."""
    patched_check = mocker.patch.object(login_module, "check_csrf_token")
    return patched_check
|
|
|
|
|
|
@pytest.fixture
def ensure_sync_spy(login_app: Flask, mocker: MockerFixture) -> MagicMock:
    """Patch ``login_app.ensure_sync`` with a call-recording mock that returns a pass-through wrapper."""

    def _passthrough(view):
        # Hand back a callable that simply forwards every argument to ``view``.
        def _invoke(*args, **kwargs):
            return view(*args, **kwargs)

        return _invoke

    return mocker.patch.object(login_app, "ensure_sync", side_effect=_passthrough)
|
|
|
|
|
|
class TestLoginRequired:
    """Tests for the ``login_required`` decorator exposed by the login module."""

    def test_authenticated_user_can_access_protected_view(
        self, login_app: Flask, csrf_check: MagicMock, ensure_sync_spy: MagicMock, mocker: MockerFixture
    ):
        """An authenticated user passes the CSRF check and reaches the wrapped view."""

        @login_module.login_required
        def protected_view():
            return "Protected content"

        authenticated_user = MockUser("test_user", is_authenticated=True)
        get_user = mocker.patch.object(login_module, "_get_user", return_value=authenticated_user)

        with login_app.test_request_context():
            result = protected_view()
            # Inspect the CSRF call while the request context is still active.
            csrf_check.assert_called_once()
            csrf_args = csrf_check.call_args.args
            assert csrf_args[0].method == "GET"
            assert csrf_args[1] == "test_user"

        assert result == "Protected content"
        get_user.assert_called_once_with()
        ensure_sync_spy.assert_called_once_with(protected_view.__wrapped__)
        login_app.login_manager.unauthorized.assert_not_called()

    @pytest.mark.parametrize(
        ("resolved_user", "description"),
        [
            pytest.param(None, "missing user", id="missing-user"),
            pytest.param(MockUser("test_user", is_authenticated=False), "unauthenticated user", id="unauthenticated"),
        ],
    )
    def test_unauthorized_access_returns_login_manager_response(
        self,
        login_app: Flask,
        csrf_check: MagicMock,
        ensure_sync_spy: MagicMock,
        mocker: MockerFixture,
        resolved_user: MockUser | None,
        description: str,
    ):
        """A missing or unauthenticated user receives the login manager's ``unauthorized`` response."""

        @login_module.login_required
        def protected_view():
            return "Protected content"

        get_user = mocker.patch.object(login_module, "_get_user", return_value=resolved_user)

        with login_app.test_request_context():
            result = protected_view()

        # ``description`` names the parametrized case in the failure message.
        assert result == "Unauthorized", description
        get_user.assert_called_once_with()
        login_app.login_manager.unauthorized.assert_called_once_with()
        # Neither the CSRF check nor the view dispatch may run for rejected requests.
        csrf_check.assert_not_called()
        ensure_sync_spy.assert_not_called()

    @pytest.mark.parametrize(
        ("method", "login_disabled"),
        [
            pytest.param("OPTIONS", False, id="options"),
            pytest.param("GET", True, id="login-disabled"),
        ],
    )
    def test_bypass_paths_skip_authentication_and_csrf(
        self,
        login_app: Flask,
        csrf_check: MagicMock,
        ensure_sync_spy: MagicMock,
        monkeypatch: pytest.MonkeyPatch,
        mocker: MockerFixture,
        method: str,
        login_disabled: bool,
    ):
        """OPTIONS requests and ``LOGIN_DISABLED`` skip user lookup, CSRF, and unauthorized handling."""

        @login_module.login_required
        def protected_view():
            return "Protected content"

        get_user = mocker.patch.object(login_module, "_get_user")
        monkeypatch.setattr(login_module.dify_config, "LOGIN_DISABLED", login_disabled)

        with login_app.test_request_context(method=method):
            result = protected_view()

        assert result == "Protected content"
        get_user.assert_not_called()
        ensure_sync_spy.assert_called_once_with(protected_view.__wrapped__)
        csrf_check.assert_not_called()
        login_app.login_manager.unauthorized.assert_not_called()
|
|
|
|
|
|
class TestGetUser:
    """Tests for the ``_get_user`` helper of the login module."""

    def test_get_user_returns_user_from_g(self, login_app: Flask):
        """A user already stored on ``g._login_user`` is what ``_get_user`` returns."""
        expected_user = MockUser("test_user")

        with login_app.test_request_context():
            g._login_user = expected_user
            resolved = login_module._get_user()
            assert resolved == expected_user
            assert resolved.id == "test_user"

    def test_get_user_loads_user_if_not_in_g(self, login_app: Flask):
        """When ``g`` holds no user, ``_get_user`` invokes the login manager's ``_load_user``."""
        expected_user = MockUser("test_user")

        def _populate_g() -> None:
            # Stand-in for the login manager's loader: it stores the user on ``g``.
            g._login_user = expected_user

        loader = MagicMock(side_effect=_populate_g)
        login_app.login_manager._load_user = loader

        with login_app.test_request_context():
            resolved = login_module._get_user()

        assert resolved == expected_user
        loader.assert_called_once_with()

    def test_get_user_returns_none_without_request_context(self):
        """Outside of any request context ``_get_user`` yields ``None``."""
        resolved = login_module._get_user()
        assert resolved is None
|
|
|
|
|
|
class TestCurrentUser:
    """Tests for the ``current_user`` proxy imported from the login module."""

    def test_current_user_proxy_returns_authenticated_user(self, login_app: Flask, mocker: MockerFixture):
        """The proxy forwards attribute access to the user returned by ``_get_user``."""
        authenticated_user = MockUser("test_user", is_authenticated=True)
        mocker.patch.object(login_module, "_get_user", return_value=authenticated_user)

        with login_app.test_request_context():
            assert current_user.id == "test_user"
            assert current_user.is_authenticated is True

    def test_current_user_proxy_raises_attribute_error_when_no_user(self, login_app: Flask, mocker: MockerFixture):
        """Attribute access on the proxy raises ``AttributeError`` when ``_get_user`` returns ``None``."""
        mocker.patch.object(login_module, "_get_user", return_value=None)

        with login_app.test_request_context(), pytest.raises(AttributeError):
            _ = current_user.id
|
|
|
|
|
|
class TestCurrentAccountWithTenant:
    """Tests for the ``current_account_with_tenant`` helper of the login module."""

    def test_returns_account_and_tenant_id(self, mocker: MockerFixture):
        """The helper returns the proxied ``Account`` together with its tenant id."""
        account = Account(name="Test User", email="test@example.com")
        account._current_tenant = SimpleNamespace(id="tenant-123")

        # Mimic a proxy: ``_get_current_object`` hands back the real account.
        proxy = MagicMock()
        proxy._get_current_object.return_value = account
        mocker.patch.object(login_module, "current_user", new=proxy)

        resolved_user, resolved_tenant_id = login_module.current_account_with_tenant()

        assert resolved_user is account
        assert resolved_tenant_id == "tenant-123"
        proxy._get_current_object.assert_called_once_with()

    def test_raises_when_current_user_is_not_account(self, mocker: MockerFixture):
        """A ``current_user`` that is not an ``Account`` causes a ``ValueError``."""
        non_account_user = MockUser("test_user")
        mocker.patch.object(login_module, "current_user", new=non_account_user)

        with pytest.raises(ValueError, match="current_user must be an Account instance"):
            login_module.current_account_with_tenant()

    def test_raises_when_account_has_no_tenant(self, mocker: MockerFixture):
        """An ``Account`` with no ``_current_tenant`` assigned triggers the tenant assertion."""
        tenantless_account = Account(name="Test User", email="test@example.com")
        mocker.patch.object(login_module, "current_user", new=tenantless_account)

        with pytest.raises(AssertionError, match="tenant information should be loaded"):
            login_module.current_account_with_tenant()
|