mirror of https://github.com/langgenius/dify.git
test: migrate oauth server service tests to testcontainers (#33958)
This commit is contained in:
parent
0492ed7034
commit
f2c71f3668
|
|
@ -0,0 +1,174 @@
|
|||
"""Testcontainers integration tests for OAuthServerService."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import (
|
||||
OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY,
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY,
|
||||
OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY,
|
||||
OAuthGrantType,
|
||||
OAuthServerService,
|
||||
)
|
||||
|
||||
|
||||
class TestOAuthServerServiceGetProviderApp:
|
||||
"""DB-backed tests for get_oauth_provider_app."""
|
||||
|
||||
def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp:
|
||||
app = OAuthProviderApp(
|
||||
app_icon="icon.png",
|
||||
client_id=client_id,
|
||||
client_secret=str(uuid4()),
|
||||
app_label={"en-US": "Test OAuth App"},
|
||||
redirect_uris=["https://example.com/callback"],
|
||||
scope="read",
|
||||
)
|
||||
db_session_with_containers.add(app)
|
||||
db_session_with_containers.commit()
|
||||
return app
|
||||
|
||||
def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers):
|
||||
client_id = f"client-{uuid4()}"
|
||||
created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id)
|
||||
|
||||
result = OAuthServerService.get_oauth_provider_app(client_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.client_id == client_id
|
||||
assert result.id == created.id
|
||||
|
||||
def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers):
|
||||
result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestOAuthServerServiceTokenOperations:
|
||||
"""Redis-backed tests for token sign/validate operations."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis(self):
|
||||
with patch("services.oauth_server.redis_client") as mock:
|
||||
yield mock
|
||||
|
||||
def test_sign_authorization_code_stores_and_returns_code(self, mock_redis):
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111")
|
||||
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
|
||||
code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1")
|
||||
|
||||
assert code == str(deterministic_uuid)
|
||||
mock_redis.set.assert_called_once_with(
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=code),
|
||||
"user-1",
|
||||
ex=600,
|
||||
)
|
||||
|
||||
def test_sign_access_token_raises_bad_request_for_invalid_code(self, mock_redis):
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
with pytest.raises(BadRequest, match="invalid code"):
|
||||
OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
code="bad-code",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
def test_sign_access_token_issues_tokens_for_valid_code(self, mock_redis):
|
||||
token_uuids = [
|
||||
uuid.UUID("00000000-0000-0000-0000-000000000201"),
|
||||
uuid.UUID("00000000-0000-0000-0000-000000000202"),
|
||||
]
|
||||
with patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids):
|
||||
mock_redis.get.return_value = b"user-1"
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
code="code-1",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
assert access_token == str(token_uuids[0])
|
||||
assert refresh_token == str(token_uuids[1])
|
||||
code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1")
|
||||
mock_redis.delete.assert_called_once_with(code_key)
|
||||
mock_redis.set.assert_any_call(
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
mock_redis.set.assert_any_call(
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
def test_sign_access_token_raises_bad_request_for_invalid_refresh_token(self, mock_redis):
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
with pytest.raises(BadRequest, match="invalid refresh token"):
|
||||
OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||
refresh_token="stale-token",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
def test_sign_access_token_issues_new_token_for_valid_refresh(self, mock_redis):
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301")
|
||||
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
|
||||
mock_redis.get.return_value = b"user-1"
|
||||
|
||||
access_token, returned_refresh = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||
refresh_token="refresh-1",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
assert access_token == str(deterministic_uuid)
|
||||
assert returned_refresh == "refresh-1"
|
||||
|
||||
def test_sign_access_token_returns_none_for_unknown_grant_type(self, mock_redis):
|
||||
grant_type = cast(OAuthGrantType, "invalid-grant-type")
|
||||
|
||||
result = OAuthServerService.sign_oauth_access_token(grant_type=grant_type, client_id="client-1")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_sign_refresh_token_stores_with_expected_expiry(self, mock_redis):
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401")
|
||||
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
|
||||
refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2")
|
||||
|
||||
assert refresh_token == str(deterministic_uuid)
|
||||
mock_redis.set.assert_called_once_with(
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token),
|
||||
"user-2",
|
||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
def test_validate_access_token_returns_none_when_not_found(self, mock_redis):
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_validate_access_token_loads_user_when_exists(self, mock_redis):
|
||||
mock_redis.get.return_value = b"user-88"
|
||||
expected_user = MagicMock()
|
||||
|
||||
with patch("services.oauth_server.AccountService.load_user", return_value=expected_user) as mock_load:
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
|
||||
|
||||
assert result is expected_user
|
||||
mock_load.assert_called_once_with("user-88")
|
||||
|
|
@ -1,224 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from services.oauth_server import (
|
||||
OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY,
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY,
|
||||
OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY,
|
||||
OAuthGrantType,
|
||||
OAuthServerService,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(mocker: MockerFixture) -> MagicMock:
|
||||
return mocker.patch("services.oauth_server.redis_client")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(mocker: MockerFixture) -> MagicMock:
|
||||
"""Mock the OAuth server Session context manager."""
|
||||
mocker.patch("services.oauth_server.db", SimpleNamespace(engine=object()))
|
||||
session = MagicMock()
|
||||
session_cm = MagicMock()
|
||||
session_cm.__enter__.return_value = session
|
||||
mocker.patch("services.oauth_server.Session", return_value=session_cm)
|
||||
return session
|
||||
|
||||
|
||||
def test_get_oauth_provider_app_should_return_app_when_record_exists(mock_session: MagicMock) -> None:
|
||||
# Arrange
|
||||
mock_execute_result = MagicMock()
|
||||
expected_app = MagicMock()
|
||||
mock_execute_result.scalar_one_or_none.return_value = expected_app
|
||||
mock_session.execute.return_value = mock_execute_result
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.get_oauth_provider_app("client-1")
|
||||
|
||||
# Assert
|
||||
assert result is expected_app
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_execute_result.scalar_one_or_none.assert_called_once()
|
||||
|
||||
|
||||
def test_sign_oauth_authorization_code_should_store_code_and_return_value(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111")
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
||||
|
||||
# Act
|
||||
code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1")
|
||||
|
||||
# Assert
|
||||
expected_code = str(deterministic_uuid)
|
||||
assert code == expected_code
|
||||
mock_redis_client.set.assert_called_once_with(
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=expected_code),
|
||||
"user-1",
|
||||
ex=600,
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_raise_bad_request_when_authorization_code_is_invalid(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(BadRequest, match="invalid code"):
|
||||
OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
code="bad-code",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_issue_access_and_refresh_token_when_authorization_code_is_valid(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
token_uuids = [
|
||||
uuid.UUID("00000000-0000-0000-0000-000000000201"),
|
||||
uuid.UUID("00000000-0000-0000-0000-000000000202"),
|
||||
]
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids)
|
||||
mock_redis_client.get.return_value = b"user-1"
|
||||
code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1")
|
||||
|
||||
# Act
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
code="code-1",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert access_token == str(token_uuids[0])
|
||||
assert refresh_token == str(token_uuids[1])
|
||||
mock_redis_client.delete.assert_called_once_with(code_key)
|
||||
mock_redis_client.set.assert_any_call(
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
mock_redis_client.set.assert_any_call(
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_raise_bad_request_when_refresh_token_is_invalid(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(BadRequest, match="invalid refresh token"):
|
||||
OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||
refresh_token="stale-token",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_issue_new_access_token_when_refresh_token_is_valid(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301")
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
||||
mock_redis_client.get.return_value = b"user-1"
|
||||
|
||||
# Act
|
||||
access_token, returned_refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||
refresh_token="refresh-1",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert access_token == str(deterministic_uuid)
|
||||
assert returned_refresh_token == "refresh-1"
|
||||
mock_redis_client.set.assert_called_once_with(
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_with_unknown_grant_type_should_return_none() -> None:
|
||||
# Arrange
|
||||
grant_type = cast(OAuthGrantType, "invalid-grant-type")
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=grant_type,
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_sign_oauth_refresh_token_should_store_token_with_expected_expiry(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401")
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
||||
|
||||
# Act
|
||||
refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2")
|
||||
|
||||
# Assert
|
||||
assert refresh_token == str(deterministic_uuid)
|
||||
mock_redis_client.set.assert_called_once_with(
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token),
|
||||
"user-2",
|
||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_oauth_access_token_should_return_none_when_token_not_found(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_validate_oauth_access_token_should_load_user_when_token_exists(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = b"user-88"
|
||||
expected_user = MagicMock()
|
||||
mock_load_user = mocker.patch("services.oauth_server.AccountService.load_user", return_value=expected_user)
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
|
||||
|
||||
# Assert
|
||||
assert result is expected_user
|
||||
mock_load_user.assert_called_once_with("user-88")
|
||||
Loading…
Reference in New Issue