diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_statistic.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_statistic.py new file mode 100644 index 0000000000..9329b83d95 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_statistic.py @@ -0,0 +1,334 @@ +"""Controller integration tests for console statistic routes.""" + +from datetime import timedelta +from decimal import Decimal +from unittest.mock import patch +from uuid import uuid4 + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import InvokeFrom +from libs.datetime_utils import naive_utc_now +from models.enums import ConversationFromSource, FeedbackFromSource, FeedbackRating +from models.model import AppMode, Conversation, Message, MessageFeedback +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + create_console_app, +) + + +def _create_conversation( + db_session: Session, + app_id: str, + account_id: str, + *, + mode: AppMode, + created_at_offset_days: int = 0, +) -> Conversation: + created_at = naive_utc_now() + timedelta(days=created_at_offset_days) + conversation = Conversation( + app_id=app_id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=mode, + name="Stats Conversation", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + from_source=ConversationFromSource.CONSOLE, + from_account_id=account_id, + created_at=created_at, + updated_at=created_at, + ) + db_session.add(conversation) + db_session.commit() + return conversation + + +def _create_message( + db_session: Session, + app_id: str, + conversation_id: str, + *, + from_account_id: str | None, + from_end_user_id: str | None = None, + message_tokens: int = 1, + answer_tokens: int = 1, + total_price: Decimal = Decimal("0.01"), + provider_response_latency: float = 1.0, + created_at_offset_days: int = 0, +) -> Message: + created_at = naive_utc_now() + timedelta(days=created_at_offset_days) + message = Message( + app_id=app_id, + model_provider=None, + model_id="", + override_model_configs=None, + conversation_id=conversation_id, + inputs={}, + query="Hello", + message={"type": "text", "content": "Hello"}, + message_tokens=message_tokens, + message_unit_price=Decimal("0.001"), + message_price_unit=Decimal("0.001"), + answer="Hi there", + answer_tokens=answer_tokens, + answer_unit_price=Decimal("0.001"), + answer_price_unit=Decimal("0.001"), + parent_message_id=None, + provider_response_latency=provider_response_latency, + total_price=total_price, + currency="USD", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, + from_end_user_id=from_end_user_id, + from_account_id=from_account_id, + created_at=created_at, + updated_at=created_at, + app_mode=AppMode.CHAT, + ) + db_session.add(message) + db_session.commit() + return message + + +def _create_like_feedback( + db_session: Session, + app_id: str, + conversation_id: str, + message_id: str, + account_id: str, +) -> None: + db_session.add( + MessageFeedback( + app_id=app_id, + conversation_id=conversation_id, + message_id=message_id, + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.ADMIN, + from_account_id=account_id, + ) + ) + db_session.commit() + + +def test_daily_message_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-messages", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["message_count"] == 1 + + +def test_daily_conversation_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-conversations", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["conversation_count"] == 1 + + +def test_daily_terminals_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=None, + from_end_user_id=str(uuid4()), + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-end-users", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["terminal_count"] == 1 + + +def test_daily_token_cost_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=account.id, + message_tokens=40, + answer_tokens=60, + total_price=Decimal("0.02"), + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/token-costs", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["data"][0]["token_count"] == 100 + assert payload["data"][0]["total_price"] == "0.02" + + +def test_average_session_interaction_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/average-session-interactions", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["interactions"] == 2.0 + + +def test_user_satisfaction_rate_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + first = _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + for _ in range(9): + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + _create_like_feedback(db_session_with_containers, app.id, conversation.id, first.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/user-satisfaction-rate", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["rate"] == 100.0 + + +def test_average_response_time_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.COMPLETION) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=account.id, + provider_response_latency=1.234, + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/average-response-time", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["latency"] == 1234.0 + + +def test_tokens_per_second_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=account.id, + answer_tokens=31, + provider_response_latency=2.0, + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/tokens-per-second", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["tps"] == 15.5 + + +def test_invalid_time_range( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + with patch("controllers.console.app.statistic.parse_time_range", side_effect=ValueError("Invalid time")): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-messages?start=invalid&end=invalid", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "Invalid time" + + +def test_time_range_params_passed( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + import datetime + + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + start = datetime.datetime.now() + end = datetime.datetime.now() + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(start, end)) as mock_parse: + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-messages?start=something&end=something", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + mock_parse.assert_called_once_with("something", "something", "UTC") diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py new file mode 100644 index 0000000000..bce39bfde6 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py @@ -0,0 +1,128 @@ +"""Controller integration tests for API key data source auth routes.""" + +import json +from unittest.mock import patch + +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from models.source import DataSourceApiKeyAuthBinding +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, +) + + +def test_get_api_key_auth_data_source( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + binding = DataSourceApiKeyAuthBinding( + tenant_id=tenant.id, + category="api_key", + provider="custom_provider", + credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}), + disabled=False, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + + response = test_client_with_containers.get( + "/console/api/api-key-auth/data-source", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert len(payload["sources"]) == 1 + assert payload["sources"][0]["provider"] == "custom_provider" + + +def test_get_api_key_auth_data_source_empty( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + response = test_client_with_containers.get( + "/console/api/api-key-auth/data-source", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"sources": []} + + +def test_create_binding_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with ( + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"), + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth"), + ): + response = test_client_with_containers.post( + "/console/api/api-key-auth/data-source/binding", + json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + + +def test_create_binding_failure( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with ( + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"), + patch( + "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth", + side_effect=ValueError("Invalid structure"), + ), + ): + response = test_client_with_containers.post( + "/console/api/api-key-auth/data-source/binding", + json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 500 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "auth_failed" + assert payload["message"] == "Invalid structure" + + +def test_delete_binding_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + binding = DataSourceApiKeyAuthBinding( + tenant_id=tenant.id, + category="api_key", + provider="custom_provider", + credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}), + disabled=False, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + + response = test_client_with_containers.delete( + f"/console/api/api-key-auth/data-source/{binding.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert db_session_with_containers.scalar( + select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.id == binding.id) + ) is None diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth_server.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth_server.py new file mode 100644 index 0000000000..0b800d4031 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth_server.py @@ -0,0 +1,328 @@ +"""Controller integration tests for console OAuth server routes.""" + +from datetime import datetime +from unittest.mock import patch + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from models import Account +from models.model import OAuthProviderApp +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + ensure_dify_setup, +) + + +def _create_oauth_provider_app(db_session: Session) -> OAuthProviderApp: + oauth_provider_app = OAuthProviderApp( + app_icon="icon_url", + client_id="test_client_id", + client_secret="test_secret", + app_label={"en-US": "Test App"}, + redirect_uris=["http://localhost/callback"], + scope="read,write", + ) + db_session.add(oauth_provider_app) + db_session.commit() + return oauth_provider_app + + +def test_oauth_provider_successful_post( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + _create_oauth_provider_app(db_session_with_containers) + ensure_dify_setup(db_session_with_containers) + + response = test_client_with_containers.post( + "/console/api/oauth/provider", + json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"}, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["app_icon"] == "icon_url" + assert payload["app_label"] == {"en-US": "Test App"} + assert payload["scope"] == "read,write" + + +def test_oauth_provider_invalid_redirect_uri( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + _create_oauth_provider_app(db_session_with_containers) + ensure_dify_setup(db_session_with_containers) + + response = test_client_with_containers.post( + "/console/api/oauth/provider", + json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"}, + ) + + assert response.status_code == 400 + payload = response.get_json() + assert payload is not None + assert payload["message"] == "redirect_uri is invalid" + + +def test_oauth_provider_invalid_client_id( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + response = test_client_with_containers.post( + "/console/api/oauth/provider", + json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"}, + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["message"] == "client_id is invalid" + + +def test_oauth_authorize_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + _create_oauth_provider_app(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code", + return_value="auth_code_123", + ) as mock_sign: + response = test_client_with_containers.post( + "/console/api/oauth/provider/authorize", + json={"client_id": "test_client_id"}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"code": "auth_code_123"} + mock_sign.assert_called_once_with("test_client_id", account.id) + + +def test_oauth_token_authorization_code_grant( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + _create_oauth_provider_app(db_session_with_containers) + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token", + return_value=("access_123", "refresh_123"), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "test_secret", + "redirect_uri": "http://localhost/callback", + }, + ) + + assert response.status_code == 200 + assert response.get_json() == { + "access_token": "access_123", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "refresh_123", + } + + +def test_oauth_token_authorization_code_grant_missing_code( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + _create_oauth_provider_app(db_session_with_containers) + ensure_dify_setup(db_session_with_containers) + + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "client_secret": "test_secret", + "redirect_uri": "http://localhost/callback", + }, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "code is required" + + +def test_oauth_token_authorization_code_grant_invalid_secret( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + _create_oauth_provider_app(db_session_with_containers) + ensure_dify_setup(db_session_with_containers) + + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "invalid_secret", + "redirect_uri": "http://localhost/callback", + }, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "client_secret is invalid" + + +def test_oauth_token_authorization_code_grant_invalid_redirect_uri( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + _create_oauth_provider_app(db_session_with_containers) + ensure_dify_setup(db_session_with_containers) + + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "test_secret", + "redirect_uri": "http://invalid/callback", + }, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "redirect_uri is invalid" + + +def test_oauth_token_refresh_token_grant( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + _create_oauth_provider_app(db_session_with_containers) + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token", + return_value=("new_access", "new_refresh"), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"}, + ) + + assert response.status_code == 200 + assert response.get_json() == { + "access_token": "new_access", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "new_refresh", + } + + +def test_oauth_token_refresh_token_grant_missing_token( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + _create_oauth_provider_app(db_session_with_containers) + ensure_dify_setup(db_session_with_containers) + + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={"client_id": "test_client_id", "grant_type": "refresh_token"}, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "refresh_token is required" + + +def test_oauth_token_invalid_grant_type( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + _create_oauth_provider_app(db_session_with_containers) + ensure_dify_setup(db_session_with_containers) + + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={"client_id": "test_client_id", "grant_type": "invalid_grant"}, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "invalid grant_type" + + +def test_oauth_account_successful_retrieval( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + _create_oauth_provider_app(db_session_with_containers) + ensure_dify_setup(db_session_with_containers) + account = Account( + email="test@example.com", + name="Test User", + interface_language="en-US", + status="active", + ) + account.avatar = "avatar_url" + account.timezone = "UTC" + account.created_at = datetime.utcnow() + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token", + return_value=account, + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/account", + json={"client_id": "test_client_id"}, + headers={"Authorization": "Bearer valid_access_token"}, + ) + + assert response.status_code == 200 + assert response.get_json() == { + "name": "Test User", + "email": "test@example.com", + "avatar": "avatar_url", + "interface_language": "en-US", + "timezone": "UTC", + } + + +def test_oauth_account_missing_authorization_header( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + _create_oauth_provider_app(db_session_with_containers) + ensure_dify_setup(db_session_with_containers) + + response = test_client_with_containers.post( + "/console/api/oauth/provider/account", + json={"client_id": "test_client_id"}, + ) + + assert response.status_code == 401 + assert response.get_json() == {"error": "Authorization header is required"} + + +def test_oauth_account_invalid_authorization_header_format( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + _create_oauth_provider_app(db_session_with_containers) + ensure_dify_setup(db_session_with_containers) + + response = test_client_with_containers.post( + "/console/api/oauth/provider/account", + json={"client_id": "test_client_id"}, + headers={"Authorization": "InvalidFormat"}, + ) + + assert response.status_code == 401 + assert response.get_json() == {"error": "Invalid Authorization header format"} diff --git a/api/tests/unit_tests/controllers/console/app/test_statistic.py b/api/tests/unit_tests/controllers/console/app/test_statistic.py deleted file mode 100644 index beba23385d..0000000000 --- a/api/tests/unit_tests/controllers/console/app/test_statistic.py +++ /dev/null @@ -1,275 +0,0 @@ -from decimal import Decimal -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask, request -from werkzeug.local import LocalProxy - -from controllers.console.app.statistic import ( - AverageResponseTimeStatistic, - AverageSessionInteractionStatistic, - DailyConversationStatistic, - DailyMessageStatistic, - DailyTerminalsStatistic, - DailyTokenCostStatistic, - TokensPerSecondStatistic, - UserSatisfactionRateStatistic, -) -from models import App, AppMode - - -@pytest.fixture -def app(): - flask_app = Flask(__name__) - flask_app.config["TESTING"] = True - return flask_app - - -@pytest.fixture -def mock_account(): - from models.account import Account, AccountStatus - - account = MagicMock(spec=Account) - account.id = "user_123" - account.timezone = "UTC" - account.status = AccountStatus.ACTIVE - account.is_admin_or_owner = True - account.current_tenant.current_role = "owner" - account.has_edit_permission = True - return account - - -@pytest.fixture -def mock_app_model(): - app_model = MagicMock(spec=App) - app_model.id = "app_123" - app_model.mode = AppMode.CHAT - app_model.tenant_id = "tenant_123" - return app_model - - -@pytest.fixture(autouse=True) -def mock_csrf(): - with patch("libs.login.check_csrf_token") as mock: - yield mock - - -def setup_test_context( - test_app, endpoint_class, route_path, mock_account, mock_app_model, mock_rs, mock_parse_ret=(None, None) -): - with ( - patch("controllers.console.app.statistic.db") as mock_db_stat, - patch("controllers.console.app.wraps.db") as mock_db_wraps, - patch("controllers.console.wraps.db", mock_db_wraps), - patch( - "controllers.console.app.statistic.current_account_with_tenant", return_value=(mock_account, "tenant_123") - ), - patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - ): - mock_conn = MagicMock() - mock_conn.execute.return_value = mock_rs - - mock_begin = MagicMock() - mock_begin.__enter__.return_value = mock_conn - mock_db_stat.engine.begin.return_value = mock_begin - - mock_query = MagicMock() - mock_query.filter.return_value.first.return_value = mock_app_model - mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model - mock_query.where.return_value.first.return_value = mock_app_model - mock_query.where.return_value.where.return_value.first.return_value = mock_app_model - mock_db_wraps.session.query.return_value = mock_query - - proxy_mock = LocalProxy(lambda: mock_account) - - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): - with test_app.test_request_context(route_path, method="GET"): - request.view_args = {"app_id": "app_123"} - api_instance = endpoint_class() - response = api_instance.get(app_id="app_123") - return response - - -class TestStatisticEndpoints: - def test_daily_message_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.message_count = 10 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - DailyMessageStatistic, - "/apps/app_123/statistics/daily-messages?start=2023-01-01 00:00&end=2023-01-02 00:00", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["message_count"] == 10 - - def test_daily_conversation_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.conversation_count = 5 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - DailyConversationStatistic, - "/apps/app_123/statistics/daily-conversations", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["conversation_count"] == 5 - - def test_daily_terminals_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.terminal_count = 2 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - DailyTerminalsStatistic, - "/apps/app_123/statistics/daily-end-users", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["terminal_count"] == 2 - - def test_daily_token_cost_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.token_count = 100 - mock_row.total_price = Decimal("0.02") - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - DailyTokenCostStatistic, - "/apps/app_123/statistics/token-costs", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["token_count"] == 100 - assert response.json["data"][0]["total_price"] == "0.02" - - def test_average_session_interaction_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.interactions = Decimal("3.523") - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - AverageSessionInteractionStatistic, - "/apps/app_123/statistics/average-session-interactions", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["interactions"] == 3.52 - - def test_user_satisfaction_rate_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.message_count = 100 - mock_row.feedback_count = 10 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - UserSatisfactionRateStatistic, - "/apps/app_123/statistics/user-satisfaction-rate", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["rate"] == 100.0 - - def test_average_response_time_statistic(self, app, mock_account, mock_app_model): - mock_app_model.mode = AppMode.COMPLETION - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.latency = 1.234 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - AverageResponseTimeStatistic, - "/apps/app_123/statistics/average-response-time", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["latency"] == 1234.0 - - def test_tokens_per_second_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.tokens_per_second = 15.5 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - TokensPerSecondStatistic, - "/apps/app_123/statistics/tokens-per-second", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["tps"] == 15.5 - - @patch("controllers.console.app.statistic.parse_time_range") - def test_invalid_time_range(self, mock_parse, app, mock_account, mock_app_model): - mock_parse.side_effect = ValueError("Invalid time") - - from werkzeug.exceptions import BadRequest - - with pytest.raises(BadRequest): - setup_test_context( - app, - DailyMessageStatistic, - "/apps/app_123/statistics/daily-messages?start=invalid&end=invalid", - mock_account, - mock_app_model, - [], - ) - - @patch("controllers.console.app.statistic.parse_time_range") - def test_time_range_params_passed(self, mock_parse, app, mock_account, mock_app_model): - import datetime - - start = datetime.datetime.now() - end = datetime.datetime.now() - mock_parse.return_value = (start, end) - - response = setup_test_context( - app, - DailyMessageStatistic, - "/apps/app_123/statistics/daily-messages?start=something&end=something", - mock_account, - mock_app_model, - [], - ) - assert response.status_code == 200 - mock_parse.assert_called_once() diff --git a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py deleted file mode 100644 index bc4c7e0993..0000000000 --- a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py +++ /dev/null @@ -1,209 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask - -from controllers.console.auth.data_source_bearer_auth import ( - ApiKeyAuthDataSource, - ApiKeyAuthDataSourceBinding, - ApiKeyAuthDataSourceBindingDelete, -) -from controllers.console.auth.error import ApiKeyAuthFailedError - - -class TestApiKeyAuthDataSource: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - app.config["WTF_CSRF_ENABLED"] = False - return app - - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list") - def test_get_api_key_auth_data_source(self, mock_get_list, mock_db, mock_csrf, app): - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - mock_binding = MagicMock() - mock_binding.id = "bind_123" - mock_binding.category = "api_key" - mock_binding.provider = "custom_provider" - mock_binding.disabled = False - mock_binding.created_at.timestamp.return_value = 1620000000 - mock_binding.updated_at.timestamp.return_value = 1620000001 - - mock_get_list.return_value = [mock_binding] - - with ( - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch( - "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", - return_value=(mock_account, "tenant_123"), - ), - ): - with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock): - api_instance = ApiKeyAuthDataSource() - response = api_instance.get() - - assert "sources" in response - assert len(response["sources"]) == 1 - assert response["sources"][0]["provider"] == "custom_provider" - - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list") - def test_get_api_key_auth_data_source_empty(self, mock_get_list, mock_db, mock_csrf, app): - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - mock_get_list.return_value = None - - with ( - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch( - "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", - return_value=(mock_account, "tenant_123"), - ), - ): - with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock): - api_instance = ApiKeyAuthDataSource() - response = api_instance.get() - - assert "sources" in response - assert len(response["sources"]) == 0 - - -class TestApiKeyAuthDataSourceBinding: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - app.config["WTF_CSRF_ENABLED"] = False - return app - - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args") - def test_create_binding_successful(self, mock_validate, mock_create, mock_db, mock_csrf, app): - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - with ( - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch( - "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", - return_value=(mock_account, "tenant_123"), - ), - ): - with app.test_request_context( - "/console/api/api-key-auth/data-source/binding", - method="POST", - json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, - ): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): - api_instance = ApiKeyAuthDataSourceBinding() - response = api_instance.post() - - assert response[0]["result"] == "success" - assert response[1] == 200 - mock_validate.assert_called_once() - mock_create.assert_called_once() - - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args") - def test_create_binding_failure(self, mock_validate, mock_create, mock_db, mock_csrf, app): - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - mock_create.side_effect = ValueError("Invalid structure") - - with ( - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch( - "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", - return_value=(mock_account, "tenant_123"), - ), - ): - with app.test_request_context( - "/console/api/api-key-auth/data-source/binding", - method="POST", - json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, - ): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): - api_instance = ApiKeyAuthDataSourceBinding() - with pytest.raises(ApiKeyAuthFailedError, match="Invalid structure"): - api_instance.post() - - -class TestApiKeyAuthDataSourceBindingDelete: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - app.config["WTF_CSRF_ENABLED"] = False - return app - - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth") - def test_delete_binding_successful(self, mock_delete, mock_db, mock_csrf, app): - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - with ( - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch( - "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", - return_value=(mock_account, "tenant_123"), - ), - ): - with app.test_request_context("/console/api/api-key-auth/data-source/binding_123", method="DELETE"): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): - api_instance = ApiKeyAuthDataSourceBindingDelete() - response = api_instance.delete("binding_123") - - assert response[0]["result"] == "success" - assert response[1] == 204 - mock_delete.assert_called_once_with("tenant_123", "binding_123") diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py b/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py deleted file mode 100644 index fc5663e72d..0000000000 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py +++ /dev/null @@ -1,417 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask -from werkzeug.exceptions import BadRequest, NotFound - -from controllers.console.auth.oauth_server import ( - OAuthServerAppApi, - OAuthServerUserAccountApi, - OAuthServerUserAuthorizeApi, - OAuthServerUserTokenApi, -) - - -class TestOAuthServerAppApi: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @pytest.fixture - def mock_oauth_provider_app(self): - from models.model import OAuthProviderApp - - oauth_app = MagicMock(spec=OAuthProviderApp) - oauth_app.client_id = "test_client_id" - oauth_app.redirect_uris = ["http://localhost/callback"] - oauth_app.app_icon = "icon_url" - oauth_app.app_label = "Test App" - oauth_app.scope = "read,write" - return oauth_app - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_successful_post(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider", - method="POST", - json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"}, - ): - api_instance = OAuthServerAppApi() - response = api_instance.post() - - assert response["app_icon"] == "icon_url" - assert response["app_label"] == "Test App" - assert response["scope"] == "read,write" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider", - method="POST", - json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"}, - ): - api_instance = OAuthServerAppApi() - with pytest.raises(BadRequest, match="redirect_uri is invalid"): - api_instance.post() - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_invalid_client_id(self, mock_get_app, mock_db, app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = None - - with app.test_request_context( - "/oauth/provider", - method="POST", - json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"}, - ): - api_instance = OAuthServerAppApi() - with pytest.raises(NotFound, match="client_id is invalid"): - api_instance.post() - - -class TestOAuthServerUserAuthorizeApi: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @pytest.fixture - def mock_oauth_provider_app(self): - oauth_app = MagicMock() - oauth_app.client_id = "test_client_id" - return oauth_app - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - @patch("controllers.console.auth.oauth_server.current_account_with_tenant") - @patch("controllers.console.wraps.current_account_with_tenant") - @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code") - @patch("libs.login.check_csrf_token") - def test_successful_authorize( - self, mock_csrf, mock_sign, mock_wrap_current, mock_current, mock_get_app, mock_db, app, mock_oauth_provider_app - ): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - mock_account = MagicMock() - mock_account.id = "user_123" - from models.account import AccountStatus - - mock_account.status = AccountStatus.ACTIVE - - mock_current.return_value = (mock_account, MagicMock()) - mock_wrap_current.return_value = (mock_account, MagicMock()) - - mock_sign.return_value = "auth_code_123" - - with app.test_request_context("/oauth/provider/authorize", method="POST", json={"client_id": "test_client_id"}): - with patch("libs.login.current_user", mock_account): - api_instance = OAuthServerUserAuthorizeApi() - response = api_instance.post() - - assert response["code"] == "auth_code_123" - mock_sign.assert_called_once_with("test_client_id", "user_123") - - -class TestOAuthServerUserTokenApi: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @pytest.fixture - def mock_oauth_provider_app(self): - from models.model import OAuthProviderApp - - oauth_app = MagicMock(spec=OAuthProviderApp) - oauth_app.client_id = "test_client_id" - oauth_app.client_secret = "test_secret" - oauth_app.redirect_uris = ["http://localhost/callback"] - return oauth_app - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token") - def test_authorization_code_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - mock_sign.return_value = ("access_123", "refresh_123") - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "authorization_code", - "code": "auth_code", - "client_secret": "test_secret", - "redirect_uri": "http://localhost/callback", - }, - ): - api_instance = OAuthServerUserTokenApi() - response = api_instance.post() - - assert response["access_token"] == "access_123" - assert response["refresh_token"] == "refresh_123" - assert response["token_type"] == "Bearer" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_authorization_code_grant_missing_code(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "authorization_code", - "client_secret": "test_secret", - "redirect_uri": "http://localhost/callback", - }, - ): - api_instance = OAuthServerUserTokenApi() - with pytest.raises(BadRequest, match="code is required"): - api_instance.post() - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_authorization_code_grant_invalid_secret(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "authorization_code", - "code": "auth_code", - "client_secret": "invalid_secret", - "redirect_uri": "http://localhost/callback", - }, - ): - api_instance = OAuthServerUserTokenApi() - with pytest.raises(BadRequest, match="client_secret is invalid"): - api_instance.post() - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_authorization_code_grant_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "authorization_code", - "code": "auth_code", - "client_secret": "test_secret", - "redirect_uri": "http://invalid/callback", - }, - ): - api_instance = OAuthServerUserTokenApi() - with pytest.raises(BadRequest, match="redirect_uri is invalid"): - api_instance.post() - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token") - def test_refresh_token_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - mock_sign.return_value = ("new_access", "new_refresh") - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"}, - ): - api_instance = OAuthServerUserTokenApi() - response = api_instance.post() - - assert response["access_token"] == "new_access" - assert response["refresh_token"] == "new_refresh" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_refresh_token_grant_missing_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "refresh_token", - }, - ): - api_instance = OAuthServerUserTokenApi() - with pytest.raises(BadRequest, match="refresh_token is required"): - api_instance.post() - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_invalid_grant_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "invalid_grant", - }, - ): - api_instance = OAuthServerUserTokenApi() - with pytest.raises(BadRequest, match="invalid grant_type"): - api_instance.post() - - -class TestOAuthServerUserAccountApi: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @pytest.fixture - def mock_oauth_provider_app(self): - from models.model import OAuthProviderApp - - oauth_app = MagicMock(spec=OAuthProviderApp) - oauth_app.client_id = "test_client_id" - return oauth_app - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - @patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token") - def test_successful_account_retrieval(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - mock_account = MagicMock() - mock_account.name = "Test User" - mock_account.email = "test@example.com" - mock_account.avatar = "avatar_url" - mock_account.interface_language = "en-US" - mock_account.timezone = "UTC" - mock_validate.return_value = mock_account - - with app.test_request_context( - "/oauth/provider/account", - method="POST", - json={"client_id": "test_client_id"}, - headers={"Authorization": "Bearer valid_access_token"}, - ): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response["name"] == "Test User" - assert response["email"] == "test@example.com" - assert response["avatar"] == "avatar_url" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_missing_authorization_header(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context("/oauth/provider/account", method="POST", json={"client_id": "test_client_id"}): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response.status_code == 401 - assert response.json["error"] == "Authorization header is required" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_invalid_authorization_header_format(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/account", - method="POST", - json={"client_id": "test_client_id"}, - headers={"Authorization": "InvalidFormat"}, - ): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response.status_code == 401 - assert response.json["error"] == "Invalid Authorization header format" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_invalid_token_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/account", - method="POST", - json={"client_id": "test_client_id"}, - headers={"Authorization": "Basic something"}, - ): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response.status_code == 401 - assert response.json["error"] == "token_type is invalid" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_missing_access_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/account", - method="POST", - json={"client_id": "test_client_id"}, - headers={"Authorization": "Bearer "}, - ): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response.status_code == 401 - assert response.json["error"] == "Invalid Authorization header format" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - @patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token") - def test_invalid_access_token(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - mock_validate.return_value = None - - with app.test_request_context( - "/oauth/provider/account", - method="POST", - json={"client_id": "test_client_id"}, - headers={"Authorization": "Bearer invalid_token"}, - ): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response.status_code == 401 - assert response.json["error"] == "access_token or client_id is invalid"