mirror of https://github.com/langgenius/dify.git
fix more
This commit is contained in:
parent
1bb160b567
commit
ff15a0426e
|
|
@ -0,0 +1,334 @@
|
|||
"""Controller integration tests for console statistic routes."""
|
||||
|
||||
from datetime import timedelta
|
||||
from decimal import Decimal
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import ConversationFromSource, FeedbackFromSource, FeedbackRating
|
||||
from models.model import AppMode, Conversation, Message, MessageFeedback
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
create_console_account_and_tenant,
|
||||
create_console_app,
|
||||
)
|
||||
|
||||
|
||||
def _create_conversation(
|
||||
db_session: Session,
|
||||
app_id: str,
|
||||
account_id: str,
|
||||
*,
|
||||
mode: AppMode,
|
||||
created_at_offset_days: int = 0,
|
||||
) -> Conversation:
|
||||
created_at = naive_utc_now() + timedelta(days=created_at_offset_days)
|
||||
conversation = Conversation(
|
||||
app_id=app_id,
|
||||
app_model_config_id=None,
|
||||
model_provider=None,
|
||||
model_id="",
|
||||
override_model_configs=None,
|
||||
mode=mode,
|
||||
name="Stats Conversation",
|
||||
inputs={},
|
||||
introduction="",
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_account_id=account_id,
|
||||
created_at=created_at,
|
||||
updated_at=created_at,
|
||||
)
|
||||
db_session.add(conversation)
|
||||
db_session.commit()
|
||||
return conversation
|
||||
|
||||
|
||||
def _create_message(
|
||||
db_session: Session,
|
||||
app_id: str,
|
||||
conversation_id: str,
|
||||
*,
|
||||
from_account_id: str | None,
|
||||
from_end_user_id: str | None = None,
|
||||
message_tokens: int = 1,
|
||||
answer_tokens: int = 1,
|
||||
total_price: Decimal = Decimal("0.01"),
|
||||
provider_response_latency: float = 1.0,
|
||||
created_at_offset_days: int = 0,
|
||||
) -> Message:
|
||||
created_at = naive_utc_now() + timedelta(days=created_at_offset_days)
|
||||
message = Message(
|
||||
app_id=app_id,
|
||||
model_provider=None,
|
||||
model_id="",
|
||||
override_model_configs=None,
|
||||
conversation_id=conversation_id,
|
||||
inputs={},
|
||||
query="Hello",
|
||||
message={"type": "text", "content": "Hello"},
|
||||
message_tokens=message_tokens,
|
||||
message_unit_price=Decimal("0.001"),
|
||||
message_price_unit=Decimal("0.001"),
|
||||
answer="Hi there",
|
||||
answer_tokens=answer_tokens,
|
||||
answer_unit_price=Decimal("0.001"),
|
||||
answer_price_unit=Decimal("0.001"),
|
||||
parent_message_id=None,
|
||||
provider_response_latency=provider_response_latency,
|
||||
total_price=total_price,
|
||||
currency="USD",
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_end_user_id=from_end_user_id,
|
||||
from_account_id=from_account_id,
|
||||
created_at=created_at,
|
||||
updated_at=created_at,
|
||||
app_mode=AppMode.CHAT,
|
||||
)
|
||||
db_session.add(message)
|
||||
db_session.commit()
|
||||
return message
|
||||
|
||||
|
||||
def _create_like_feedback(
|
||||
db_session: Session,
|
||||
app_id: str,
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
account_id: str,
|
||||
) -> None:
|
||||
db_session.add(
|
||||
MessageFeedback(
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
message_id=message_id,
|
||||
rating=FeedbackRating.LIKE,
|
||||
from_source=FeedbackFromSource.ADMIN,
|
||||
from_account_id=account_id,
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def test_daily_message_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/daily-messages",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["data"][0]["message_count"] == 1
|
||||
|
||||
|
||||
def test_daily_conversation_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/daily-conversations",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["data"][0]["conversation_count"] == 1
|
||||
|
||||
|
||||
def test_daily_terminals_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
_create_message(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
conversation.id,
|
||||
from_account_id=None,
|
||||
from_end_user_id=str(uuid4()),
|
||||
)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/daily-end-users",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["data"][0]["terminal_count"] == 1
|
||||
|
||||
|
||||
def test_daily_token_cost_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
_create_message(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
conversation.id,
|
||||
from_account_id=account.id,
|
||||
message_tokens=40,
|
||||
answer_tokens=60,
|
||||
total_price=Decimal("0.02"),
|
||||
)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/token-costs",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload["data"][0]["token_count"] == 100
|
||||
assert payload["data"][0]["total_price"] == "0.02"
|
||||
|
||||
|
||||
def test_average_session_interaction_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/average-session-interactions",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["data"][0]["interactions"] == 2.0
|
||||
|
||||
|
||||
def test_user_satisfaction_rate_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
first = _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
for _ in range(9):
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
_create_like_feedback(db_session_with_containers, app.id, conversation.id, first.id, account.id)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/user-satisfaction-rate",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["data"][0]["rate"] == 100.0
|
||||
|
||||
|
||||
def test_average_response_time_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.COMPLETION)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
_create_message(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
conversation.id,
|
||||
from_account_id=account.id,
|
||||
provider_response_latency=1.234,
|
||||
)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/average-response-time",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["data"][0]["latency"] == 1234.0
|
||||
|
||||
|
||||
def test_tokens_per_second_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
_create_message(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
conversation.id,
|
||||
from_account_id=account.id,
|
||||
answer_tokens=31,
|
||||
provider_response_latency=2.0,
|
||||
)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/tokens-per-second",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["data"][0]["tps"] == 15.5
|
||||
|
||||
|
||||
def test_invalid_time_range(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", side_effect=ValueError("Invalid time")):
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/daily-messages?start=invalid&end=invalid",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.get_json()["message"] == "Invalid time"
|
||||
|
||||
|
||||
def test_time_range_params_passed(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
import datetime
|
||||
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
start = datetime.datetime.now()
|
||||
end = datetime.datetime.now()
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(start, end)) as mock_parse:
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/daily-messages?start=something&end=something",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_parse.assert_called_once_with("something", "something", "UTC")
|
||||
|
|
@ -0,0 +1,128 @@
|
|||
"""Controller integration tests for API key data source auth routes."""
|
||||
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.source import DataSourceApiKeyAuthBinding
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
create_console_account_and_tenant,
|
||||
)
|
||||
|
||||
|
||||
def test_get_api_key_auth_data_source(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
binding = DataSourceApiKeyAuthBinding(
|
||||
tenant_id=tenant.id,
|
||||
category="api_key",
|
||||
provider="custom_provider",
|
||||
credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}),
|
||||
disabled=False,
|
||||
)
|
||||
db_session_with_containers.add(binding)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
"/console/api/api-key-auth/data-source",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert len(payload["sources"]) == 1
|
||||
assert payload["sources"][0]["provider"] == "custom_provider"
|
||||
|
||||
|
||||
def test_get_api_key_auth_data_source_empty(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
"/console/api/api-key-auth/data-source",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"sources": []}
|
||||
|
||||
|
||||
def test_create_binding_successful(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
|
||||
with (
|
||||
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"),
|
||||
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth"),
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/api-key-auth/data-source/binding",
|
||||
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"result": "success"}
|
||||
|
||||
|
||||
def test_create_binding_failure(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
|
||||
with (
|
||||
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth",
|
||||
side_effect=ValueError("Invalid structure"),
|
||||
),
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/api-key-auth/data-source/binding",
|
||||
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["code"] == "auth_failed"
|
||||
assert payload["message"] == "Invalid structure"
|
||||
|
||||
|
||||
def test_delete_binding_successful(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
binding = DataSourceApiKeyAuthBinding(
|
||||
tenant_id=tenant.id,
|
||||
category="api_key",
|
||||
provider="custom_provider",
|
||||
credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}),
|
||||
disabled=False,
|
||||
)
|
||||
db_session_with_containers.add(binding)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
response = test_client_with_containers.delete(
|
||||
f"/console/api/api-key-auth/data-source/{binding.id}",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 204
|
||||
assert db_session_with_containers.scalar(
|
||||
select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.id == binding.id)
|
||||
) is None
|
||||
|
|
@ -0,0 +1,328 @@
|
|||
"""Controller integration tests for console OAuth server routes."""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models import Account
|
||||
from models.model import OAuthProviderApp
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
create_console_account_and_tenant,
|
||||
ensure_dify_setup,
|
||||
)
|
||||
|
||||
|
||||
def _create_oauth_provider_app(db_session: Session) -> OAuthProviderApp:
|
||||
oauth_provider_app = OAuthProviderApp(
|
||||
app_icon="icon_url",
|
||||
client_id="test_client_id",
|
||||
client_secret="test_secret",
|
||||
app_label={"en-US": "Test App"},
|
||||
redirect_uris=["http://localhost/callback"],
|
||||
scope="read,write",
|
||||
)
|
||||
db_session.add(oauth_provider_app)
|
||||
db_session.commit()
|
||||
return oauth_provider_app
|
||||
|
||||
|
||||
def test_oauth_provider_successful_post(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
_create_oauth_provider_app(db_session_with_containers)
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider",
|
||||
json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["app_icon"] == "icon_url"
|
||||
assert payload["app_label"] == {"en-US": "Test App"}
|
||||
assert payload["scope"] == "read,write"
|
||||
|
||||
|
||||
def test_oauth_provider_invalid_redirect_uri(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
_create_oauth_provider_app(db_session_with_containers)
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider",
|
||||
json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["message"] == "redirect_uri is invalid"
|
||||
|
||||
|
||||
def test_oauth_provider_invalid_client_id(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider",
|
||||
json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["message"] == "client_id is invalid"
|
||||
|
||||
|
||||
def test_oauth_authorize_successful(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
_create_oauth_provider_app(db_session_with_containers)
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code",
|
||||
return_value="auth_code_123",
|
||||
) as mock_sign:
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/authorize",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"code": "auth_code_123"}
|
||||
mock_sign.assert_called_once_with("test_client_id", account.id)
|
||||
|
||||
|
||||
def test_oauth_token_authorization_code_grant(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
_create_oauth_provider_app(db_session_with_containers)
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token",
|
||||
return_value=("access_123", "refresh_123"),
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/token",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"code": "auth_code",
|
||||
"client_secret": "test_secret",
|
||||
"redirect_uri": "http://localhost/callback",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {
|
||||
"access_token": "access_123",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "refresh_123",
|
||||
}
|
||||
|
||||
|
||||
def test_oauth_token_authorization_code_grant_missing_code(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
_create_oauth_provider_app(db_session_with_containers)
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/token",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"client_secret": "test_secret",
|
||||
"redirect_uri": "http://localhost/callback",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.get_json()["message"] == "code is required"
|
||||
|
||||
|
||||
def test_oauth_token_authorization_code_grant_invalid_secret(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
_create_oauth_provider_app(db_session_with_containers)
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/token",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"code": "auth_code",
|
||||
"client_secret": "invalid_secret",
|
||||
"redirect_uri": "http://localhost/callback",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.get_json()["message"] == "client_secret is invalid"
|
||||
|
||||
|
||||
def test_oauth_token_authorization_code_grant_invalid_redirect_uri(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
_create_oauth_provider_app(db_session_with_containers)
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/token",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"code": "auth_code",
|
||||
"client_secret": "test_secret",
|
||||
"redirect_uri": "http://invalid/callback",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.get_json()["message"] == "redirect_uri is invalid"
|
||||
|
||||
|
||||
def test_oauth_token_refresh_token_grant(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
_create_oauth_provider_app(db_session_with_containers)
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token",
|
||||
return_value=("new_access", "new_refresh"),
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/token",
|
||||
json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {
|
||||
"access_token": "new_access",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "new_refresh",
|
||||
}
|
||||
|
||||
|
||||
def test_oauth_token_refresh_token_grant_missing_token(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
_create_oauth_provider_app(db_session_with_containers)
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/token",
|
||||
json={"client_id": "test_client_id", "grant_type": "refresh_token"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.get_json()["message"] == "refresh_token is required"
|
||||
|
||||
|
||||
def test_oauth_token_invalid_grant_type(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
_create_oauth_provider_app(db_session_with_containers)
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/token",
|
||||
json={"client_id": "test_client_id", "grant_type": "invalid_grant"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.get_json()["message"] == "invalid grant_type"
|
||||
|
||||
|
||||
def test_oauth_account_successful_retrieval(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
_create_oauth_provider_app(db_session_with_containers)
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
account = Account(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
account.avatar = "avatar_url"
|
||||
account.timezone = "UTC"
|
||||
account.created_at = datetime.utcnow()
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token",
|
||||
return_value=account,
|
||||
):
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/account",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "Bearer valid_access_token"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {
|
||||
"name": "Test User",
|
||||
"email": "test@example.com",
|
||||
"avatar": "avatar_url",
|
||||
"interface_language": "en-US",
|
||||
"timezone": "UTC",
|
||||
}
|
||||
|
||||
|
||||
def test_oauth_account_missing_authorization_header(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
_create_oauth_provider_app(db_session_with_containers)
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/account",
|
||||
json={"client_id": "test_client_id"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.get_json() == {"error": "Authorization header is required"}
|
||||
|
||||
|
||||
def test_oauth_account_invalid_authorization_header_format(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
_create_oauth_provider_app(db_session_with_containers)
|
||||
ensure_dify_setup(db_session_with_containers)
|
||||
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/oauth/provider/account",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "InvalidFormat"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.get_json() == {"error": "Invalid Authorization header format"}
|
||||
|
|
@ -1,275 +0,0 @@
|
|||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, request
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.app.statistic import (
|
||||
AverageResponseTimeStatistic,
|
||||
AverageSessionInteractionStatistic,
|
||||
DailyConversationStatistic,
|
||||
DailyMessageStatistic,
|
||||
DailyTerminalsStatistic,
|
||||
DailyTokenCostStatistic,
|
||||
TokensPerSecondStatistic,
|
||||
UserSatisfactionRateStatistic,
|
||||
)
|
||||
from models import App, AppMode
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
account = MagicMock(spec=Account)
|
||||
account.id = "user_123"
|
||||
account.timezone = "UTC"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.is_admin_or_owner = True
|
||||
account.current_tenant.current_role = "owner"
|
||||
account.has_edit_permission = True
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model():
|
||||
app_model = MagicMock(spec=App)
|
||||
app_model.id = "app_123"
|
||||
app_model.mode = AppMode.CHAT
|
||||
app_model.tenant_id = "tenant_123"
|
||||
return app_model
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_csrf():
|
||||
with patch("libs.login.check_csrf_token") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
def setup_test_context(
|
||||
test_app, endpoint_class, route_path, mock_account, mock_app_model, mock_rs, mock_parse_ret=(None, None)
|
||||
):
|
||||
with (
|
||||
patch("controllers.console.app.statistic.db") as mock_db_stat,
|
||||
patch("controllers.console.app.wraps.db") as mock_db_wraps,
|
||||
patch("controllers.console.wraps.db", mock_db_wraps),
|
||||
patch(
|
||||
"controllers.console.app.statistic.current_account_with_tenant", return_value=(mock_account, "tenant_123")
|
||||
),
|
||||
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
):
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.execute.return_value = mock_rs
|
||||
|
||||
mock_begin = MagicMock()
|
||||
mock_begin.__enter__.return_value = mock_conn
|
||||
mock_db_stat.engine.begin.return_value = mock_begin
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = mock_app_model
|
||||
mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model
|
||||
mock_query.where.return_value.first.return_value = mock_app_model
|
||||
mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
|
||||
mock_db_wraps.session.query.return_value = mock_query
|
||||
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
with test_app.test_request_context(route_path, method="GET"):
|
||||
request.view_args = {"app_id": "app_123"}
|
||||
api_instance = endpoint_class()
|
||||
response = api_instance.get(app_id="app_123")
|
||||
return response
|
||||
|
||||
|
||||
class TestStatisticEndpoints:
|
||||
def test_daily_message_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.message_count = 10
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
DailyMessageStatistic,
|
||||
"/apps/app_123/statistics/daily-messages?start=2023-01-01 00:00&end=2023-01-02 00:00",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["message_count"] == 10
|
||||
|
||||
def test_daily_conversation_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.conversation_count = 5
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
DailyConversationStatistic,
|
||||
"/apps/app_123/statistics/daily-conversations",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["conversation_count"] == 5
|
||||
|
||||
def test_daily_terminals_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.terminal_count = 2
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
DailyTerminalsStatistic,
|
||||
"/apps/app_123/statistics/daily-end-users",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["terminal_count"] == 2
|
||||
|
||||
def test_daily_token_cost_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.token_count = 100
|
||||
mock_row.total_price = Decimal("0.02")
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
DailyTokenCostStatistic,
|
||||
"/apps/app_123/statistics/token-costs",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["token_count"] == 100
|
||||
assert response.json["data"][0]["total_price"] == "0.02"
|
||||
|
||||
def test_average_session_interaction_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.interactions = Decimal("3.523")
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
AverageSessionInteractionStatistic,
|
||||
"/apps/app_123/statistics/average-session-interactions",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["interactions"] == 3.52
|
||||
|
||||
def test_user_satisfaction_rate_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.message_count = 100
|
||||
mock_row.feedback_count = 10
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
UserSatisfactionRateStatistic,
|
||||
"/apps/app_123/statistics/user-satisfaction-rate",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["rate"] == 100.0
|
||||
|
||||
def test_average_response_time_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_app_model.mode = AppMode.COMPLETION
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.latency = 1.234
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
AverageResponseTimeStatistic,
|
||||
"/apps/app_123/statistics/average-response-time",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["latency"] == 1234.0
|
||||
|
||||
def test_tokens_per_second_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.tokens_per_second = 15.5
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
TokensPerSecondStatistic,
|
||||
"/apps/app_123/statistics/tokens-per-second",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["tps"] == 15.5
|
||||
|
||||
@patch("controllers.console.app.statistic.parse_time_range")
|
||||
def test_invalid_time_range(self, mock_parse, app, mock_account, mock_app_model):
|
||||
mock_parse.side_effect = ValueError("Invalid time")
|
||||
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
with pytest.raises(BadRequest):
|
||||
setup_test_context(
|
||||
app,
|
||||
DailyMessageStatistic,
|
||||
"/apps/app_123/statistics/daily-messages?start=invalid&end=invalid",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[],
|
||||
)
|
||||
|
||||
@patch("controllers.console.app.statistic.parse_time_range")
|
||||
def test_time_range_params_passed(self, mock_parse, app, mock_account, mock_app_model):
|
||||
import datetime
|
||||
|
||||
start = datetime.datetime.now()
|
||||
end = datetime.datetime.now()
|
||||
mock_parse.return_value = (start, end)
|
||||
|
||||
response = setup_test_context(
|
||||
app,
|
||||
DailyMessageStatistic,
|
||||
"/apps/app_123/statistics/daily-messages?start=something&end=something",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_parse.assert_called_once()
|
||||
|
|
@ -1,209 +0,0 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.data_source_bearer_auth import (
|
||||
ApiKeyAuthDataSource,
|
||||
ApiKeyAuthDataSourceBinding,
|
||||
ApiKeyAuthDataSourceBindingDelete,
|
||||
)
|
||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||
|
||||
|
||||
class TestApiKeyAuthDataSource:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["WTF_CSRF_ENABLED"] = False
|
||||
return app
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list")
|
||||
def test_get_api_key_auth_data_source(self, mock_get_list, mock_db, mock_csrf, app):
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
mock_binding = MagicMock()
|
||||
mock_binding.id = "bind_123"
|
||||
mock_binding.category = "api_key"
|
||||
mock_binding.provider = "custom_provider"
|
||||
mock_binding.disabled = False
|
||||
mock_binding.created_at.timestamp.return_value = 1620000000
|
||||
mock_binding.updated_at.timestamp.return_value = 1620000001
|
||||
|
||||
mock_get_list.return_value = [mock_binding]
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant_123"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
api_instance = ApiKeyAuthDataSource()
|
||||
response = api_instance.get()
|
||||
|
||||
assert "sources" in response
|
||||
assert len(response["sources"]) == 1
|
||||
assert response["sources"][0]["provider"] == "custom_provider"
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list")
|
||||
def test_get_api_key_auth_data_source_empty(self, mock_get_list, mock_db, mock_csrf, app):
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
mock_get_list.return_value = None
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant_123"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
api_instance = ApiKeyAuthDataSource()
|
||||
response = api_instance.get()
|
||||
|
||||
assert "sources" in response
|
||||
assert len(response["sources"]) == 0
|
||||
|
||||
|
||||
class TestApiKeyAuthDataSourceBinding:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["WTF_CSRF_ENABLED"] = False
|
||||
return app
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args")
|
||||
def test_create_binding_successful(self, mock_validate, mock_create, mock_db, mock_csrf, app):
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant_123"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/console/api/api-key-auth/data-source/binding",
|
||||
method="POST",
|
||||
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
|
||||
):
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
api_instance = ApiKeyAuthDataSourceBinding()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response[0]["result"] == "success"
|
||||
assert response[1] == 200
|
||||
mock_validate.assert_called_once()
|
||||
mock_create.assert_called_once()
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args")
|
||||
def test_create_binding_failure(self, mock_validate, mock_create, mock_db, mock_csrf, app):
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
mock_create.side_effect = ValueError("Invalid structure")
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant_123"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/console/api/api-key-auth/data-source/binding",
|
||||
method="POST",
|
||||
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
|
||||
):
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
api_instance = ApiKeyAuthDataSourceBinding()
|
||||
with pytest.raises(ApiKeyAuthFailedError, match="Invalid structure"):
|
||||
api_instance.post()
|
||||
|
||||
|
||||
class TestApiKeyAuthDataSourceBindingDelete:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["WTF_CSRF_ENABLED"] = False
|
||||
return app
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth")
|
||||
def test_delete_binding_successful(self, mock_delete, mock_db, mock_csrf, app):
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant_123"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/console/api/api-key-auth/data-source/binding_123", method="DELETE"):
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
api_instance = ApiKeyAuthDataSourceBindingDelete()
|
||||
response = api_instance.delete("binding_123")
|
||||
|
||||
assert response[0]["result"] == "success"
|
||||
assert response[1] == 204
|
||||
mock_delete.assert_called_once_with("tenant_123", "binding_123")
|
||||
|
|
@ -1,417 +0,0 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.auth.oauth_server import (
|
||||
OAuthServerAppApi,
|
||||
OAuthServerUserAccountApi,
|
||||
OAuthServerUserAuthorizeApi,
|
||||
OAuthServerUserTokenApi,
|
||||
)
|
||||
|
||||
|
||||
class TestOAuthServerAppApi:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider_app(self):
|
||||
from models.model import OAuthProviderApp
|
||||
|
||||
oauth_app = MagicMock(spec=OAuthProviderApp)
|
||||
oauth_app.client_id = "test_client_id"
|
||||
oauth_app.redirect_uris = ["http://localhost/callback"]
|
||||
oauth_app.app_icon = "icon_url"
|
||||
oauth_app.app_label = "Test App"
|
||||
oauth_app.scope = "read,write"
|
||||
return oauth_app
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_successful_post(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"},
|
||||
):
|
||||
api_instance = OAuthServerAppApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response["app_icon"] == "icon_url"
|
||||
assert response["app_label"] == "Test App"
|
||||
assert response["scope"] == "read,write"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"},
|
||||
):
|
||||
api_instance = OAuthServerAppApi()
|
||||
with pytest.raises(BadRequest, match="redirect_uri is invalid"):
|
||||
api_instance.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_invalid_client_id(self, mock_get_app, mock_db, app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider",
|
||||
method="POST",
|
||||
json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"},
|
||||
):
|
||||
api_instance = OAuthServerAppApi()
|
||||
with pytest.raises(NotFound, match="client_id is invalid"):
|
||||
api_instance.post()
|
||||
|
||||
|
||||
class TestOAuthServerUserAuthorizeApi:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider_app(self):
|
||||
oauth_app = MagicMock()
|
||||
oauth_app.client_id = "test_client_id"
|
||||
return oauth_app
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
@patch("controllers.console.auth.oauth_server.current_account_with_tenant")
|
||||
@patch("controllers.console.wraps.current_account_with_tenant")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code")
|
||||
@patch("libs.login.check_csrf_token")
|
||||
def test_successful_authorize(
|
||||
self, mock_csrf, mock_sign, mock_wrap_current, mock_current, mock_get_app, mock_db, app, mock_oauth_provider_app
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
mock_account = MagicMock()
|
||||
mock_account.id = "user_123"
|
||||
from models.account import AccountStatus
|
||||
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
|
||||
mock_current.return_value = (mock_account, MagicMock())
|
||||
mock_wrap_current.return_value = (mock_account, MagicMock())
|
||||
|
||||
mock_sign.return_value = "auth_code_123"
|
||||
|
||||
with app.test_request_context("/oauth/provider/authorize", method="POST", json={"client_id": "test_client_id"}):
|
||||
with patch("libs.login.current_user", mock_account):
|
||||
api_instance = OAuthServerUserAuthorizeApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response["code"] == "auth_code_123"
|
||||
mock_sign.assert_called_once_with("test_client_id", "user_123")
|
||||
|
||||
|
||||
class TestOAuthServerUserTokenApi:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider_app(self):
|
||||
from models.model import OAuthProviderApp
|
||||
|
||||
oauth_app = MagicMock(spec=OAuthProviderApp)
|
||||
oauth_app.client_id = "test_client_id"
|
||||
oauth_app.client_secret = "test_secret"
|
||||
oauth_app.redirect_uris = ["http://localhost/callback"]
|
||||
return oauth_app
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token")
|
||||
def test_authorization_code_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
mock_sign.return_value = ("access_123", "refresh_123")
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"code": "auth_code",
|
||||
"client_secret": "test_secret",
|
||||
"redirect_uri": "http://localhost/callback",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response["access_token"] == "access_123"
|
||||
assert response["refresh_token"] == "refresh_123"
|
||||
assert response["token_type"] == "Bearer"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_authorization_code_grant_missing_code(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"client_secret": "test_secret",
|
||||
"redirect_uri": "http://localhost/callback",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
with pytest.raises(BadRequest, match="code is required"):
|
||||
api_instance.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_authorization_code_grant_invalid_secret(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"code": "auth_code",
|
||||
"client_secret": "invalid_secret",
|
||||
"redirect_uri": "http://localhost/callback",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
with pytest.raises(BadRequest, match="client_secret is invalid"):
|
||||
api_instance.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_authorization_code_grant_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"code": "auth_code",
|
||||
"client_secret": "test_secret",
|
||||
"redirect_uri": "http://invalid/callback",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
with pytest.raises(BadRequest, match="redirect_uri is invalid"):
|
||||
api_instance.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token")
|
||||
def test_refresh_token_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
mock_sign.return_value = ("new_access", "new_refresh")
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response["access_token"] == "new_access"
|
||||
assert response["refresh_token"] == "new_refresh"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_refresh_token_grant_missing_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "refresh_token",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
with pytest.raises(BadRequest, match="refresh_token is required"):
|
||||
api_instance.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_invalid_grant_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "invalid_grant",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
with pytest.raises(BadRequest, match="invalid grant_type"):
|
||||
api_instance.post()
|
||||
|
||||
|
||||
class TestOAuthServerUserAccountApi:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider_app(self):
|
||||
from models.model import OAuthProviderApp
|
||||
|
||||
oauth_app = MagicMock(spec=OAuthProviderApp)
|
||||
oauth_app.client_id = "test_client_id"
|
||||
return oauth_app
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token")
|
||||
def test_successful_account_retrieval(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
mock_account = MagicMock()
|
||||
mock_account.name = "Test User"
|
||||
mock_account.email = "test@example.com"
|
||||
mock_account.avatar = "avatar_url"
|
||||
mock_account.interface_language = "en-US"
|
||||
mock_account.timezone = "UTC"
|
||||
mock_validate.return_value = mock_account
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/account",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "Bearer valid_access_token"},
|
||||
):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response["name"] == "Test User"
|
||||
assert response["email"] == "test@example.com"
|
||||
assert response["avatar"] == "avatar_url"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_missing_authorization_header(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context("/oauth/provider/account", method="POST", json={"client_id": "test_client_id"}):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json["error"] == "Authorization header is required"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_invalid_authorization_header_format(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/account",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "InvalidFormat"},
|
||||
):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json["error"] == "Invalid Authorization header format"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_invalid_token_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/account",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "Basic something"},
|
||||
):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json["error"] == "token_type is invalid"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_missing_access_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/account",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "Bearer "},
|
||||
):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json["error"] == "Invalid Authorization header format"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token")
|
||||
def test_invalid_access_token(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
mock_validate.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/account",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "Bearer invalid_token"},
|
||||
):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json["error"] == "access_token or client_id is invalid"
|
||||
Loading…
Reference in New Issue