mirror of https://github.com/langgenius/dify.git
test: migrate oauth server service tests to testcontainers (#33958)
This commit is contained in:
parent
0492ed7034
commit
f2c71f3668
|
|
@ -0,0 +1,174 @@
|
||||||
|
"""Testcontainers integration tests for OAuthServerService."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import cast
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from werkzeug.exceptions import BadRequest
|
||||||
|
|
||||||
|
from models.model import OAuthProviderApp
|
||||||
|
from services.oauth_server import (
|
||||||
|
OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||||
|
OAUTH_ACCESS_TOKEN_REDIS_KEY,
|
||||||
|
OAUTH_AUTHORIZATION_CODE_REDIS_KEY,
|
||||||
|
OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||||
|
OAUTH_REFRESH_TOKEN_REDIS_KEY,
|
||||||
|
OAuthGrantType,
|
||||||
|
OAuthServerService,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuthServerServiceGetProviderApp:
|
||||||
|
"""DB-backed tests for get_oauth_provider_app."""
|
||||||
|
|
||||||
|
def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp:
|
||||||
|
app = OAuthProviderApp(
|
||||||
|
app_icon="icon.png",
|
||||||
|
client_id=client_id,
|
||||||
|
client_secret=str(uuid4()),
|
||||||
|
app_label={"en-US": "Test OAuth App"},
|
||||||
|
redirect_uris=["https://example.com/callback"],
|
||||||
|
scope="read",
|
||||||
|
)
|
||||||
|
db_session_with_containers.add(app)
|
||||||
|
db_session_with_containers.commit()
|
||||||
|
return app
|
||||||
|
|
||||||
|
def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers):
|
||||||
|
client_id = f"client-{uuid4()}"
|
||||||
|
created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id)
|
||||||
|
|
||||||
|
result = OAuthServerService.get_oauth_provider_app(client_id)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.client_id == client_id
|
||||||
|
assert result.id == created.id
|
||||||
|
|
||||||
|
def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers):
|
||||||
|
result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuthServerServiceTokenOperations:
|
||||||
|
"""Redis-backed tests for token sign/validate operations."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_redis(self):
|
||||||
|
with patch("services.oauth_server.redis_client") as mock:
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
def test_sign_authorization_code_stores_and_returns_code(self, mock_redis):
|
||||||
|
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111")
|
||||||
|
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
|
||||||
|
code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1")
|
||||||
|
|
||||||
|
assert code == str(deterministic_uuid)
|
||||||
|
mock_redis.set.assert_called_once_with(
|
||||||
|
OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=code),
|
||||||
|
"user-1",
|
||||||
|
ex=600,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_sign_access_token_raises_bad_request_for_invalid_code(self, mock_redis):
|
||||||
|
mock_redis.get.return_value = None
|
||||||
|
|
||||||
|
with pytest.raises(BadRequest, match="invalid code"):
|
||||||
|
OAuthServerService.sign_oauth_access_token(
|
||||||
|
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||||
|
code="bad-code",
|
||||||
|
client_id="client-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_sign_access_token_issues_tokens_for_valid_code(self, mock_redis):
|
||||||
|
token_uuids = [
|
||||||
|
uuid.UUID("00000000-0000-0000-0000-000000000201"),
|
||||||
|
uuid.UUID("00000000-0000-0000-0000-000000000202"),
|
||||||
|
]
|
||||||
|
with patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids):
|
||||||
|
mock_redis.get.return_value = b"user-1"
|
||||||
|
|
||||||
|
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||||
|
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||||
|
code="code-1",
|
||||||
|
client_id="client-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert access_token == str(token_uuids[0])
|
||||||
|
assert refresh_token == str(token_uuids[1])
|
||||||
|
code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1")
|
||||||
|
mock_redis.delete.assert_called_once_with(code_key)
|
||||||
|
mock_redis.set.assert_any_call(
|
||||||
|
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
|
||||||
|
b"user-1",
|
||||||
|
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||||
|
)
|
||||||
|
mock_redis.set.assert_any_call(
|
||||||
|
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token),
|
||||||
|
b"user-1",
|
||||||
|
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_sign_access_token_raises_bad_request_for_invalid_refresh_token(self, mock_redis):
|
||||||
|
mock_redis.get.return_value = None
|
||||||
|
|
||||||
|
with pytest.raises(BadRequest, match="invalid refresh token"):
|
||||||
|
OAuthServerService.sign_oauth_access_token(
|
||||||
|
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||||
|
refresh_token="stale-token",
|
||||||
|
client_id="client-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_sign_access_token_issues_new_token_for_valid_refresh(self, mock_redis):
|
||||||
|
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301")
|
||||||
|
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
|
||||||
|
mock_redis.get.return_value = b"user-1"
|
||||||
|
|
||||||
|
access_token, returned_refresh = OAuthServerService.sign_oauth_access_token(
|
||||||
|
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||||
|
refresh_token="refresh-1",
|
||||||
|
client_id="client-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert access_token == str(deterministic_uuid)
|
||||||
|
assert returned_refresh == "refresh-1"
|
||||||
|
|
||||||
|
def test_sign_access_token_returns_none_for_unknown_grant_type(self, mock_redis):
|
||||||
|
grant_type = cast(OAuthGrantType, "invalid-grant-type")
|
||||||
|
|
||||||
|
result = OAuthServerService.sign_oauth_access_token(grant_type=grant_type, client_id="client-1")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_sign_refresh_token_stores_with_expected_expiry(self, mock_redis):
|
||||||
|
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401")
|
||||||
|
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
|
||||||
|
refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2")
|
||||||
|
|
||||||
|
assert refresh_token == str(deterministic_uuid)
|
||||||
|
mock_redis.set.assert_called_once_with(
|
||||||
|
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token),
|
||||||
|
"user-2",
|
||||||
|
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_validate_access_token_returns_none_when_not_found(self, mock_redis):
|
||||||
|
mock_redis.get.return_value = None
|
||||||
|
|
||||||
|
result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_validate_access_token_loads_user_when_exists(self, mock_redis):
|
||||||
|
mock_redis.get.return_value = b"user-88"
|
||||||
|
expected_user = MagicMock()
|
||||||
|
|
||||||
|
with patch("services.oauth_server.AccountService.load_user", return_value=expected_user) as mock_load:
|
||||||
|
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
|
||||||
|
|
||||||
|
assert result is expected_user
|
||||||
|
mock_load.assert_called_once_with("user-88")
|
||||||
|
|
@ -1,224 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import cast
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pytest_mock import MockerFixture
|
|
||||||
from werkzeug.exceptions import BadRequest
|
|
||||||
|
|
||||||
from services.oauth_server import (
|
|
||||||
OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
|
||||||
OAUTH_ACCESS_TOKEN_REDIS_KEY,
|
|
||||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY,
|
|
||||||
OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
|
||||||
OAUTH_REFRESH_TOKEN_REDIS_KEY,
|
|
||||||
OAuthGrantType,
|
|
||||||
OAuthServerService,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_redis_client(mocker: MockerFixture) -> MagicMock:
|
|
||||||
return mocker.patch("services.oauth_server.redis_client")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_session(mocker: MockerFixture) -> MagicMock:
|
|
||||||
"""Mock the OAuth server Session context manager."""
|
|
||||||
mocker.patch("services.oauth_server.db", SimpleNamespace(engine=object()))
|
|
||||||
session = MagicMock()
|
|
||||||
session_cm = MagicMock()
|
|
||||||
session_cm.__enter__.return_value = session
|
|
||||||
mocker.patch("services.oauth_server.Session", return_value=session_cm)
|
|
||||||
return session
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_oauth_provider_app_should_return_app_when_record_exists(mock_session: MagicMock) -> None:
|
|
||||||
# Arrange
|
|
||||||
mock_execute_result = MagicMock()
|
|
||||||
expected_app = MagicMock()
|
|
||||||
mock_execute_result.scalar_one_or_none.return_value = expected_app
|
|
||||||
mock_session.execute.return_value = mock_execute_result
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OAuthServerService.get_oauth_provider_app("client-1")
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result is expected_app
|
|
||||||
mock_session.execute.assert_called_once()
|
|
||||||
mock_execute_result.scalar_one_or_none.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
def test_sign_oauth_authorization_code_should_store_code_and_return_value(
|
|
||||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
|
||||||
) -> None:
|
|
||||||
# Arrange
|
|
||||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111")
|
|
||||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1")
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
expected_code = str(deterministic_uuid)
|
|
||||||
assert code == expected_code
|
|
||||||
mock_redis_client.set.assert_called_once_with(
|
|
||||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=expected_code),
|
|
||||||
"user-1",
|
|
||||||
ex=600,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_sign_oauth_access_token_should_raise_bad_request_when_authorization_code_is_invalid(
|
|
||||||
mock_redis_client: MagicMock,
|
|
||||||
) -> None:
|
|
||||||
# Arrange
|
|
||||||
mock_redis_client.get.return_value = None
|
|
||||||
|
|
||||||
# Act + Assert
|
|
||||||
with pytest.raises(BadRequest, match="invalid code"):
|
|
||||||
OAuthServerService.sign_oauth_access_token(
|
|
||||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
|
||||||
code="bad-code",
|
|
||||||
client_id="client-1",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_sign_oauth_access_token_should_issue_access_and_refresh_token_when_authorization_code_is_valid(
|
|
||||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
|
||||||
) -> None:
|
|
||||||
# Arrange
|
|
||||||
token_uuids = [
|
|
||||||
uuid.UUID("00000000-0000-0000-0000-000000000201"),
|
|
||||||
uuid.UUID("00000000-0000-0000-0000-000000000202"),
|
|
||||||
]
|
|
||||||
mocker.patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids)
|
|
||||||
mock_redis_client.get.return_value = b"user-1"
|
|
||||||
code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1")
|
|
||||||
|
|
||||||
# Act
|
|
||||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
|
||||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
|
||||||
code="code-1",
|
|
||||||
client_id="client-1",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert access_token == str(token_uuids[0])
|
|
||||||
assert refresh_token == str(token_uuids[1])
|
|
||||||
mock_redis_client.delete.assert_called_once_with(code_key)
|
|
||||||
mock_redis_client.set.assert_any_call(
|
|
||||||
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
|
|
||||||
b"user-1",
|
|
||||||
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
|
||||||
)
|
|
||||||
mock_redis_client.set.assert_any_call(
|
|
||||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token),
|
|
||||||
b"user-1",
|
|
||||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_sign_oauth_access_token_should_raise_bad_request_when_refresh_token_is_invalid(
|
|
||||||
mock_redis_client: MagicMock,
|
|
||||||
) -> None:
|
|
||||||
# Arrange
|
|
||||||
mock_redis_client.get.return_value = None
|
|
||||||
|
|
||||||
# Act + Assert
|
|
||||||
with pytest.raises(BadRequest, match="invalid refresh token"):
|
|
||||||
OAuthServerService.sign_oauth_access_token(
|
|
||||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
|
||||||
refresh_token="stale-token",
|
|
||||||
client_id="client-1",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_sign_oauth_access_token_should_issue_new_access_token_when_refresh_token_is_valid(
|
|
||||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
|
||||||
) -> None:
|
|
||||||
# Arrange
|
|
||||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301")
|
|
||||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
|
||||||
mock_redis_client.get.return_value = b"user-1"
|
|
||||||
|
|
||||||
# Act
|
|
||||||
access_token, returned_refresh_token = OAuthServerService.sign_oauth_access_token(
|
|
||||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
|
||||||
refresh_token="refresh-1",
|
|
||||||
client_id="client-1",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert access_token == str(deterministic_uuid)
|
|
||||||
assert returned_refresh_token == "refresh-1"
|
|
||||||
mock_redis_client.set.assert_called_once_with(
|
|
||||||
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
|
|
||||||
b"user-1",
|
|
||||||
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_sign_oauth_access_token_with_unknown_grant_type_should_return_none() -> None:
|
|
||||||
# Arrange
|
|
||||||
grant_type = cast(OAuthGrantType, "invalid-grant-type")
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OAuthServerService.sign_oauth_access_token(
|
|
||||||
grant_type=grant_type,
|
|
||||||
client_id="client-1",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_sign_oauth_refresh_token_should_store_token_with_expected_expiry(
|
|
||||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
|
||||||
) -> None:
|
|
||||||
# Arrange
|
|
||||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401")
|
|
||||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2")
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert refresh_token == str(deterministic_uuid)
|
|
||||||
mock_redis_client.set.assert_called_once_with(
|
|
||||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token),
|
|
||||||
"user-2",
|
|
||||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_oauth_access_token_should_return_none_when_token_not_found(
|
|
||||||
mock_redis_client: MagicMock,
|
|
||||||
) -> None:
|
|
||||||
# Arrange
|
|
||||||
mock_redis_client.get.return_value = None
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token")
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_oauth_access_token_should_load_user_when_token_exists(
|
|
||||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
|
||||||
) -> None:
|
|
||||||
# Arrange
|
|
||||||
mock_redis_client.get.return_value = b"user-88"
|
|
||||||
expected_user = MagicMock()
|
|
||||||
mock_load_user = mocker.patch("services.oauth_server.AccountService.load_user", return_value=expected_user)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result is expected_user
|
|
||||||
mock_load_user.assert_called_once_with("user-88")
|
|
||||||
Loading…
Reference in New Issue