mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/dev-sh/services-retention-ent-plugin-uts
This commit is contained in:
commit
afbe2b3430
|
|
@ -0,0 +1,320 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, request
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.app.error import (
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.app.message import (
|
||||
ChatMessageListApi,
|
||||
ChatMessagesQuery,
|
||||
FeedbackExportQuery,
|
||||
MessageAnnotationCountApi,
|
||||
MessageApi,
|
||||
MessageFeedbackApi,
|
||||
MessageFeedbackExportApi,
|
||||
MessageFeedbackPayload,
|
||||
MessageSuggestedQuestionApi,
|
||||
)
|
||||
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from models import App, AppMode
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
flask_app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
account = MagicMock(spec=Account)
|
||||
account.id = "user_123"
|
||||
account.timezone = "UTC"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.is_admin_or_owner = True
|
||||
account.current_tenant.current_role = "owner"
|
||||
account.has_edit_permission = True
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model():
|
||||
app_model = MagicMock(spec=App)
|
||||
app_model.id = "app_123"
|
||||
app_model.mode = AppMode.CHAT
|
||||
app_model.tenant_id = "tenant_123"
|
||||
return app_model
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_csrf():
|
||||
with patch("libs.login.check_csrf_token") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
import contextlib
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def setup_test_context(
|
||||
test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None, qs=None
|
||||
):
|
||||
with (
|
||||
patch("extensions.ext_database.db") as mock_db,
|
||||
patch("controllers.console.app.wraps.db", mock_db),
|
||||
patch("controllers.console.wraps.db", mock_db),
|
||||
patch("controllers.console.app.message.db", mock_db),
|
||||
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch("controllers.console.app.message.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
):
|
||||
# Set up a generic query mock that usually returns mock_app_model when getting app
|
||||
app_query_mock = MagicMock()
|
||||
app_query_mock.filter.return_value.first.return_value = mock_app_model
|
||||
app_query_mock.filter.return_value.filter.return_value.first.return_value = mock_app_model
|
||||
app_query_mock.where.return_value.first.return_value = mock_app_model
|
||||
app_query_mock.where.return_value.where.return_value.first.return_value = mock_app_model
|
||||
|
||||
data_query_mock = MagicMock()
|
||||
|
||||
def query_side_effect(*args, **kwargs):
|
||||
if args and hasattr(args[0], "__name__") and args[0].__name__ == "App":
|
||||
return app_query_mock
|
||||
return data_query_mock
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
mock_db.data_query = data_query_mock
|
||||
|
||||
# Let the caller override the stat db query logic
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
|
||||
query_string = "&".join([f"{k}={v}" for k, v in (qs or {}).items()])
|
||||
full_path = f"{route_path}?{query_string}" if qs else route_path
|
||||
|
||||
with (
|
||||
patch("libs.login.current_user", proxy_mock),
|
||||
patch("flask_login.current_user", proxy_mock),
|
||||
patch("controllers.console.app.message.attach_message_extra_contents", return_value=None),
|
||||
):
|
||||
with test_app.test_request_context(full_path, method=method, json=payload):
|
||||
request.view_args = {"app_id": "app_123"}
|
||||
|
||||
if "suggested-questions" in route_path:
|
||||
# simplistic extraction for message_id
|
||||
parts = route_path.split("chat-messages/")
|
||||
if len(parts) > 1:
|
||||
request.view_args["message_id"] = parts[1].split("/")[0]
|
||||
elif "messages/" in route_path and "chat-messages" not in route_path:
|
||||
parts = route_path.split("messages/")
|
||||
if len(parts) > 1:
|
||||
request.view_args["message_id"] = parts[1].split("/")[0]
|
||||
|
||||
api_instance = endpoint_class()
|
||||
|
||||
# Check if it has a dispatch_request or method
|
||||
if hasattr(api_instance, method.lower()):
|
||||
yield api_instance, mock_db, request.view_args
|
||||
|
||||
|
||||
class TestMessageValidators:
|
||||
def test_chat_messages_query_validators(self):
|
||||
# Test empty_to_none
|
||||
assert ChatMessagesQuery.empty_to_none("") is None
|
||||
assert ChatMessagesQuery.empty_to_none("val") == "val"
|
||||
|
||||
# Test validate_uuid
|
||||
assert ChatMessagesQuery.validate_uuid(None) is None
|
||||
assert (
|
||||
ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000")
|
||||
== "123e4567-e89b-12d3-a456-426614174000"
|
||||
)
|
||||
|
||||
def test_message_feedback_validators(self):
|
||||
assert (
|
||||
MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000")
|
||||
== "123e4567-e89b-12d3-a456-426614174000"
|
||||
)
|
||||
|
||||
def test_feedback_export_validators(self):
|
||||
assert FeedbackExportQuery.parse_bool(None) is None
|
||||
assert FeedbackExportQuery.parse_bool(True) is True
|
||||
assert FeedbackExportQuery.parse_bool("1") is True
|
||||
assert FeedbackExportQuery.parse_bool("0") is False
|
||||
assert FeedbackExportQuery.parse_bool("off") is False
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
FeedbackExportQuery.parse_bool("invalid")
|
||||
|
||||
|
||||
class TestMessageEndpoints:
|
||||
def test_chat_message_list_not_found(self, app, mock_account, mock_app_model):
|
||||
with setup_test_context(
|
||||
app,
|
||||
ChatMessageListApi,
|
||||
"/apps/app_123/chat-messages",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000"},
|
||||
) as (api, mock_db, v_args):
|
||||
mock_db.data_query.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
api.get(**v_args)
|
||||
|
||||
def test_chat_message_list_success(self, app, mock_account, mock_app_model):
|
||||
with setup_test_context(
|
||||
app,
|
||||
ChatMessageListApi,
|
||||
"/apps/app_123/chat-messages",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000", "limit": 1},
|
||||
) as (api, mock_db, v_args):
|
||||
mock_conv = MagicMock()
|
||||
mock_conv.id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.id = "msg_123"
|
||||
mock_msg.feedbacks = []
|
||||
mock_msg.annotation = None
|
||||
mock_msg.annotation_hit_history = None
|
||||
mock_msg.agent_thoughts = []
|
||||
mock_msg.message_files = []
|
||||
mock_msg.extra_contents = []
|
||||
mock_msg.message = {}
|
||||
mock_msg.message_metadata_dict = {}
|
||||
|
||||
# mock returns
|
||||
q_mock = mock_db.data_query
|
||||
q_mock.where.return_value.first.side_effect = [mock_conv]
|
||||
q_mock.where.return_value.order_by.return_value.limit.return_value.all.return_value = [mock_msg]
|
||||
mock_db.session.scalar.return_value = False
|
||||
|
||||
resp = api.get(**v_args)
|
||||
assert resp["limit"] == 1
|
||||
assert resp["has_more"] is False
|
||||
assert len(resp["data"]) == 1
|
||||
|
||||
def test_message_feedback_not_found(self, app, mock_account, mock_app_model):
|
||||
with setup_test_context(
|
||||
app,
|
||||
MessageFeedbackApi,
|
||||
"/apps/app_123/feedbacks",
|
||||
"POST",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
payload={"message_id": "123e4567-e89b-12d3-a456-426614174000"},
|
||||
) as (api, mock_db, v_args):
|
||||
mock_db.data_query.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
api.post(**v_args)
|
||||
|
||||
def test_message_feedback_success(self, app, mock_account, mock_app_model):
|
||||
payload = {"message_id": "123e4567-e89b-12d3-a456-426614174000", "rating": "like"}
|
||||
with setup_test_context(
|
||||
app, MessageFeedbackApi, "/apps/app_123/feedbacks", "POST", mock_account, mock_app_model, payload=payload
|
||||
) as (api, mock_db, v_args):
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.admin_feedback = None
|
||||
mock_db.data_query.where.return_value.first.return_value = mock_msg
|
||||
|
||||
resp = api.post(**v_args)
|
||||
assert resp == {"result": "success"}
|
||||
|
||||
def test_message_annotation_count(self, app, mock_account, mock_app_model):
|
||||
with setup_test_context(
|
||||
app, MessageAnnotationCountApi, "/apps/app_123/annotations/count", "GET", mock_account, mock_app_model
|
||||
) as (api, mock_db, v_args):
|
||||
mock_db.data_query.where.return_value.count.return_value = 5
|
||||
|
||||
resp = api.get(**v_args)
|
||||
assert resp == {"count": 5}
|
||||
|
||||
@patch("controllers.console.app.message.MessageService")
|
||||
def test_message_suggested_questions_success(self, mock_msg_srv, app, mock_account, mock_app_model):
|
||||
mock_msg_srv.get_suggested_questions_after_answer.return_value = ["q1", "q2"]
|
||||
|
||||
with setup_test_context(
|
||||
app,
|
||||
MessageSuggestedQuestionApi,
|
||||
"/apps/app_123/chat-messages/msg_123/suggested-questions",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
) as (api, mock_db, v_args):
|
||||
resp = api.get(**v_args)
|
||||
assert resp == {"data": ["q1", "q2"]}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exc", "expected_exc"),
|
||||
[
|
||||
(MessageNotExistsError, NotFound),
|
||||
(ConversationNotExistsError, NotFound),
|
||||
(ProviderTokenNotInitError, ProviderNotInitializeError),
|
||||
(QuotaExceededError, ProviderQuotaExceededError),
|
||||
(ModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError),
|
||||
(SuggestedQuestionsAfterAnswerDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError),
|
||||
(Exception, InternalServerError),
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.app.message.MessageService")
|
||||
def test_message_suggested_questions_errors(
|
||||
self, mock_msg_srv, exc, expected_exc, app, mock_account, mock_app_model
|
||||
):
|
||||
mock_msg_srv.get_suggested_questions_after_answer.side_effect = exc()
|
||||
|
||||
with setup_test_context(
|
||||
app,
|
||||
MessageSuggestedQuestionApi,
|
||||
"/apps/app_123/chat-messages/msg_123/suggested-questions",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
) as (api, mock_db, v_args):
|
||||
with pytest.raises(expected_exc):
|
||||
api.get(**v_args)
|
||||
|
||||
@patch("services.feedback_service.FeedbackService.export_feedbacks")
|
||||
def test_message_feedback_export_success(self, mock_export, app, mock_account, mock_app_model):
|
||||
mock_export.return_value = {"exported": True}
|
||||
|
||||
with setup_test_context(
|
||||
app, MessageFeedbackExportApi, "/apps/app_123/feedbacks/export", "GET", mock_account, mock_app_model
|
||||
) as (api, mock_db, v_args):
|
||||
resp = api.get(**v_args)
|
||||
assert resp == {"exported": True}
|
||||
|
||||
def test_message_api_get_success(self, app, mock_account, mock_app_model):
|
||||
with setup_test_context(
|
||||
app, MessageApi, "/apps/app_123/messages/msg_123", "GET", mock_account, mock_app_model
|
||||
) as (api, mock_db, v_args):
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.id = "msg_123"
|
||||
mock_msg.feedbacks = []
|
||||
mock_msg.annotation = None
|
||||
mock_msg.annotation_hit_history = None
|
||||
mock_msg.agent_thoughts = []
|
||||
mock_msg.message_files = []
|
||||
mock_msg.extra_contents = []
|
||||
mock_msg.message = {}
|
||||
mock_msg.message_metadata_dict = {}
|
||||
|
||||
mock_db.data_query.where.return_value.first.return_value = mock_msg
|
||||
|
||||
resp = api.get(**v_args)
|
||||
assert resp["id"] == "msg_123"
|
||||
|
|
@ -0,0 +1,275 @@
|
|||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, request
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.app.statistic import (
|
||||
AverageResponseTimeStatistic,
|
||||
AverageSessionInteractionStatistic,
|
||||
DailyConversationStatistic,
|
||||
DailyMessageStatistic,
|
||||
DailyTerminalsStatistic,
|
||||
DailyTokenCostStatistic,
|
||||
TokensPerSecondStatistic,
|
||||
UserSatisfactionRateStatistic,
|
||||
)
|
||||
from models import App, AppMode
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
account = MagicMock(spec=Account)
|
||||
account.id = "user_123"
|
||||
account.timezone = "UTC"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.is_admin_or_owner = True
|
||||
account.current_tenant.current_role = "owner"
|
||||
account.has_edit_permission = True
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model():
|
||||
app_model = MagicMock(spec=App)
|
||||
app_model.id = "app_123"
|
||||
app_model.mode = AppMode.CHAT
|
||||
app_model.tenant_id = "tenant_123"
|
||||
return app_model
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_csrf():
|
||||
with patch("libs.login.check_csrf_token") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
def setup_test_context(
|
||||
test_app, endpoint_class, route_path, mock_account, mock_app_model, mock_rs, mock_parse_ret=(None, None)
|
||||
):
|
||||
with (
|
||||
patch("controllers.console.app.statistic.db") as mock_db_stat,
|
||||
patch("controllers.console.app.wraps.db") as mock_db_wraps,
|
||||
patch("controllers.console.wraps.db", mock_db_wraps),
|
||||
patch(
|
||||
"controllers.console.app.statistic.current_account_with_tenant", return_value=(mock_account, "tenant_123")
|
||||
),
|
||||
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
):
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.execute.return_value = mock_rs
|
||||
|
||||
mock_begin = MagicMock()
|
||||
mock_begin.__enter__.return_value = mock_conn
|
||||
mock_db_stat.engine.begin.return_value = mock_begin
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = mock_app_model
|
||||
mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model
|
||||
mock_query.where.return_value.first.return_value = mock_app_model
|
||||
mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
|
||||
mock_db_wraps.session.query.return_value = mock_query
|
||||
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
with test_app.test_request_context(route_path, method="GET"):
|
||||
request.view_args = {"app_id": "app_123"}
|
||||
api_instance = endpoint_class()
|
||||
response = api_instance.get(app_id="app_123")
|
||||
return response
|
||||
|
||||
|
||||
class TestStatisticEndpoints:
|
||||
def test_daily_message_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.message_count = 10
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
DailyMessageStatistic,
|
||||
"/apps/app_123/statistics/daily-messages?start=2023-01-01 00:00&end=2023-01-02 00:00",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["message_count"] == 10
|
||||
|
||||
def test_daily_conversation_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.conversation_count = 5
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
DailyConversationStatistic,
|
||||
"/apps/app_123/statistics/daily-conversations",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["conversation_count"] == 5
|
||||
|
||||
def test_daily_terminals_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.terminal_count = 2
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
DailyTerminalsStatistic,
|
||||
"/apps/app_123/statistics/daily-end-users",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["terminal_count"] == 2
|
||||
|
||||
def test_daily_token_cost_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.token_count = 100
|
||||
mock_row.total_price = Decimal("0.02")
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
DailyTokenCostStatistic,
|
||||
"/apps/app_123/statistics/token-costs",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["token_count"] == 100
|
||||
assert response.json["data"][0]["total_price"] == "0.02"
|
||||
|
||||
def test_average_session_interaction_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.interactions = Decimal("3.523")
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
AverageSessionInteractionStatistic,
|
||||
"/apps/app_123/statistics/average-session-interactions",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["interactions"] == 3.52
|
||||
|
||||
def test_user_satisfaction_rate_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.message_count = 100
|
||||
mock_row.feedback_count = 10
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
UserSatisfactionRateStatistic,
|
||||
"/apps/app_123/statistics/user-satisfaction-rate",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["rate"] == 100.0
|
||||
|
||||
def test_average_response_time_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_app_model.mode = AppMode.COMPLETION
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.latency = 1.234
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
AverageResponseTimeStatistic,
|
||||
"/apps/app_123/statistics/average-response-time",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["latency"] == 1234.0
|
||||
|
||||
def test_tokens_per_second_statistic(self, app, mock_account, mock_app_model):
|
||||
mock_row = MagicMock()
|
||||
mock_row.date = "2023-01-01"
|
||||
mock_row.tokens_per_second = 15.5
|
||||
mock_row.interactions = Decimal(0)
|
||||
|
||||
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
|
||||
response = setup_test_context(
|
||||
app,
|
||||
TokensPerSecondStatistic,
|
||||
"/apps/app_123/statistics/tokens-per-second",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[mock_row],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json["data"][0]["tps"] == 15.5
|
||||
|
||||
@patch("controllers.console.app.statistic.parse_time_range")
|
||||
def test_invalid_time_range(self, mock_parse, app, mock_account, mock_app_model):
|
||||
mock_parse.side_effect = ValueError("Invalid time")
|
||||
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
with pytest.raises(BadRequest):
|
||||
setup_test_context(
|
||||
app,
|
||||
DailyMessageStatistic,
|
||||
"/apps/app_123/statistics/daily-messages?start=invalid&end=invalid",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[],
|
||||
)
|
||||
|
||||
@patch("controllers.console.app.statistic.parse_time_range")
|
||||
def test_time_range_params_passed(self, mock_parse, app, mock_account, mock_app_model):
|
||||
import datetime
|
||||
|
||||
start = datetime.datetime.now()
|
||||
end = datetime.datetime.now()
|
||||
mock_parse.return_value = (start, end)
|
||||
|
||||
response = setup_test_context(
|
||||
app,
|
||||
DailyMessageStatistic,
|
||||
"/apps/app_123/statistics/daily-messages?start=something&end=something",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
[],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_parse.assert_called_once()
|
||||
|
|
@ -0,0 +1,313 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, request
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.app.error import DraftWorkflowNotExist
|
||||
from controllers.console.app.workflow_draft_variable import (
|
||||
ConversationVariableCollectionApi,
|
||||
EnvironmentVariableCollectionApi,
|
||||
NodeVariableCollectionApi,
|
||||
SystemVariableCollectionApi,
|
||||
VariableApi,
|
||||
VariableResetApi,
|
||||
WorkflowVariableCollectionApi,
|
||||
)
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from models import App, AppMode
|
||||
from models.enums import DraftVariableType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
flask_app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
account = MagicMock(spec=Account)
|
||||
account.id = "user_123"
|
||||
account.timezone = "UTC"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.is_admin_or_owner = True
|
||||
account.current_tenant.current_role = "owner"
|
||||
account.has_edit_permission = True
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model():
|
||||
app_model = MagicMock(spec=App)
|
||||
app_model.id = "app_123"
|
||||
app_model.mode = AppMode.WORKFLOW
|
||||
app_model.tenant_id = "tenant_123"
|
||||
return app_model
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_csrf():
|
||||
with patch("libs.login.check_csrf_token") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
def setup_test_context(test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None):
|
||||
with (
|
||||
patch("controllers.console.app.wraps.db") as mock_db_wraps,
|
||||
patch("controllers.console.wraps.db", mock_db_wraps),
|
||||
patch("controllers.console.app.workflow_draft_variable.db"),
|
||||
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
):
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = mock_app_model
|
||||
mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model
|
||||
mock_query.where.return_value.first.return_value = mock_app_model
|
||||
mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
|
||||
mock_db_wraps.session.query.return_value = mock_query
|
||||
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
with test_app.test_request_context(route_path, method=method, json=payload):
|
||||
request.view_args = {"app_id": "app_123"}
|
||||
# extract node_id or variable_id from path manually since view_args overrides
|
||||
if "nodes/" in route_path:
|
||||
request.view_args["node_id"] = route_path.split("nodes/")[1].split("/")[0]
|
||||
if "variables/" in route_path:
|
||||
# simplistic extraction
|
||||
parts = route_path.split("variables/")
|
||||
if len(parts) > 1 and parts[1] and parts[1] != "reset":
|
||||
request.view_args["variable_id"] = parts[1].split("/")[0]
|
||||
|
||||
api_instance = endpoint_class()
|
||||
# we just call dispatch_request to avoid manual argument passing
|
||||
if hasattr(api_instance, method.lower()):
|
||||
func = getattr(api_instance, method.lower())
|
||||
return func(**request.view_args)
|
||||
|
||||
|
||||
class TestWorkflowDraftVariableEndpoints:
|
||||
@staticmethod
|
||||
def _mock_workflow_variable(variable_type: DraftVariableType = DraftVariableType.NODE) -> MagicMock:
|
||||
class DummyValueType:
|
||||
def exposed_type(self):
|
||||
return DraftVariableType.NODE
|
||||
|
||||
mock_var = MagicMock()
|
||||
mock_var.app_id = "app_123"
|
||||
mock_var.id = "var_123"
|
||||
mock_var.name = "test_var"
|
||||
mock_var.description = ""
|
||||
mock_var.get_variable_type.return_value = variable_type
|
||||
mock_var.get_selector.return_value = []
|
||||
mock_var.value_type = DummyValueType()
|
||||
mock_var.edited = False
|
||||
mock_var.visible = True
|
||||
mock_var.file_id = None
|
||||
mock_var.variable_file = None
|
||||
mock_var.is_truncated.return_value = False
|
||||
mock_var.get_value.return_value.model_copy.return_value.value = "test_value"
|
||||
return mock_var
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_workflow_variable_collection_get_success(
|
||||
self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model
|
||||
):
|
||||
mock_wf_srv.return_value.is_workflow_exist.return_value = True
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
mock_draft_srv.return_value.list_variables_without_values.return_value = WorkflowDraftVariableList(
|
||||
variables=[], total=0
|
||||
)
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
WorkflowVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/variables?page=1&limit=20",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp == {"items": [], "total": 0}
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
|
||||
def test_workflow_variable_collection_get_not_exist(self, mock_wf_srv, app, mock_account, mock_app_model):
|
||||
mock_wf_srv.return_value.is_workflow_exist.return_value = False
|
||||
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
setup_test_context(
|
||||
app,
|
||||
WorkflowVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_workflow_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
WorkflowVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/variables",
|
||||
"DELETE",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_node_variable_collection_get_success(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
mock_draft_srv.return_value.list_node_variables.return_value = WorkflowDraftVariableList(variables=[])
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
NodeVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/nodes/node_123/variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp == {"items": []}
|
||||
|
||||
def test_node_variable_collection_get_invalid_node_id(self, app, mock_account, mock_app_model):
|
||||
with pytest.raises(InvalidArgumentError):
|
||||
setup_test_context(
|
||||
app,
|
||||
NodeVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/nodes/sys/variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_node_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
NodeVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/nodes/node_123/variables",
|
||||
"DELETE",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_variable_api_get_success(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
|
||||
|
||||
resp = setup_test_context(
|
||||
app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model
|
||||
)
|
||||
assert resp["id"] == "var_123"
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_variable_api_get_not_found(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
mock_draft_srv.return_value.get_variable.return_value = None
|
||||
|
||||
with pytest.raises(NotFoundError):
|
||||
setup_test_context(
|
||||
app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model
|
||||
)
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_variable_api_patch_success(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
VariableApi,
|
||||
"/apps/app_123/workflows/draft/variables/var_123",
|
||||
"PATCH",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
payload={"name": "new_name"},
|
||||
)
|
||||
assert resp["id"] == "var_123"
|
||||
mock_draft_srv.return_value.update_variable.assert_called_once()
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_variable_api_delete_success(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
|
||||
|
||||
resp = setup_test_context(
|
||||
app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "DELETE", mock_account, mock_app_model
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
mock_draft_srv.return_value.delete_variable.assert_called_once()
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_variable_reset_api_put_success(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model):
|
||||
mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock()
|
||||
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
|
||||
mock_draft_srv.return_value.reset_variable.return_value = None # means no content
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
VariableResetApi,
|
||||
"/apps/app_123/workflows/draft/variables/var_123/reset",
|
||||
"PUT",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_conversation_variable_collection_get(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model):
|
||||
mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock()
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
mock_draft_srv.return_value.list_conversation_variables.return_value = WorkflowDraftVariableList(variables=[])
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
ConversationVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/conversation-variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp == {"items": []}
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_system_variable_collection_get(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
mock_draft_srv.return_value.list_system_variables.return_value = WorkflowDraftVariableList(variables=[])
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
SystemVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/system-variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp == {"items": []}
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
|
||||
def test_environment_variable_collection_get(self, mock_wf_srv, app, mock_account, mock_app_model):
|
||||
mock_wf = MagicMock()
|
||||
mock_wf.environment_variables = []
|
||||
mock_wf_srv.return_value.get_draft_workflow.return_value = mock_wf
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
EnvironmentVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/environment-variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp == {"items": []}
|
||||
|
|
@ -0,0 +1,209 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.data_source_bearer_auth import (
|
||||
ApiKeyAuthDataSource,
|
||||
ApiKeyAuthDataSourceBinding,
|
||||
ApiKeyAuthDataSourceBindingDelete,
|
||||
)
|
||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||
|
||||
|
||||
class TestApiKeyAuthDataSource:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["WTF_CSRF_ENABLED"] = False
|
||||
return app
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list")
|
||||
def test_get_api_key_auth_data_source(self, mock_get_list, mock_db, mock_csrf, app):
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
mock_binding = MagicMock()
|
||||
mock_binding.id = "bind_123"
|
||||
mock_binding.category = "api_key"
|
||||
mock_binding.provider = "custom_provider"
|
||||
mock_binding.disabled = False
|
||||
mock_binding.created_at.timestamp.return_value = 1620000000
|
||||
mock_binding.updated_at.timestamp.return_value = 1620000001
|
||||
|
||||
mock_get_list.return_value = [mock_binding]
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant_123"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
api_instance = ApiKeyAuthDataSource()
|
||||
response = api_instance.get()
|
||||
|
||||
assert "sources" in response
|
||||
assert len(response["sources"]) == 1
|
||||
assert response["sources"][0]["provider"] == "custom_provider"
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list")
|
||||
def test_get_api_key_auth_data_source_empty(self, mock_get_list, mock_db, mock_csrf, app):
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
mock_get_list.return_value = None
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant_123"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
api_instance = ApiKeyAuthDataSource()
|
||||
response = api_instance.get()
|
||||
|
||||
assert "sources" in response
|
||||
assert len(response["sources"]) == 0
|
||||
|
||||
|
||||
class TestApiKeyAuthDataSourceBinding:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["WTF_CSRF_ENABLED"] = False
|
||||
return app
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args")
|
||||
def test_create_binding_successful(self, mock_validate, mock_create, mock_db, mock_csrf, app):
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant_123"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/console/api/api-key-auth/data-source/binding",
|
||||
method="POST",
|
||||
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
|
||||
):
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
api_instance = ApiKeyAuthDataSourceBinding()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response[0]["result"] == "success"
|
||||
assert response[1] == 200
|
||||
mock_validate.assert_called_once()
|
||||
mock_create.assert_called_once()
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args")
|
||||
def test_create_binding_failure(self, mock_validate, mock_create, mock_db, mock_csrf, app):
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
mock_create.side_effect = ValueError("Invalid structure")
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant_123"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/console/api/api-key-auth/data-source/binding",
|
||||
method="POST",
|
||||
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
|
||||
):
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
api_instance = ApiKeyAuthDataSourceBinding()
|
||||
with pytest.raises(ApiKeyAuthFailedError, match="Invalid structure"):
|
||||
api_instance.post()
|
||||
|
||||
|
||||
class TestApiKeyAuthDataSourceBindingDelete:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["WTF_CSRF_ENABLED"] = False
|
||||
return app
|
||||
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth")
|
||||
def test_delete_binding_successful(self, mock_delete, mock_db, mock_csrf, app):
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant_123"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/console/api/api-key-auth/data-source/binding_123", method="DELETE"):
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
api_instance = ApiKeyAuthDataSourceBindingDelete()
|
||||
response = api_instance.delete("binding_123")
|
||||
|
||||
assert response[0]["result"] == "success"
|
||||
assert response[1] == 204
|
||||
mock_delete.assert_called_once_with("tenant_123", "binding_123")
|
||||
|
|
@ -0,0 +1,192 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.auth.data_source_oauth import (
|
||||
OAuthDataSource,
|
||||
OAuthDataSourceBinding,
|
||||
OAuthDataSourceCallback,
|
||||
OAuthDataSourceSync,
|
||||
)
|
||||
|
||||
|
||||
class TestOAuthDataSource:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
@patch("flask_login.current_user")
|
||||
@patch("libs.login.current_user")
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None)
|
||||
def test_get_oauth_url_successful(
|
||||
self, mock_db, mock_csrf, mock_libs_user, mock_flask_user, mock_get_providers, app
|
||||
):
|
||||
mock_oauth_provider = MagicMock()
|
||||
mock_oauth_provider.get_authorization_url.return_value = "http://oauth.provider/auth"
|
||||
mock_get_providers.return_value = {"notion": mock_oauth_provider}
|
||||
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
mock_libs_user.return_value = mock_account
|
||||
mock_flask_user.return_value = mock_account
|
||||
|
||||
# also patch current_account_with_tenant
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion", method="GET"):
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
api_instance = OAuthDataSource()
|
||||
response = api_instance.get("notion")
|
||||
|
||||
assert response[0]["data"] == "http://oauth.provider/auth"
|
||||
assert response[1] == 200
|
||||
mock_oauth_provider.get_authorization_url.assert_called_once()
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
@patch("flask_login.current_user")
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
def test_get_oauth_url_invalid_provider(self, mock_db, mock_csrf, mock_flask_user, mock_get_providers, app):
|
||||
mock_get_providers.return_value = {"notion": MagicMock()}
|
||||
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
|
||||
with app.test_request_context("/console/api/oauth/data-source/unknown_provider", method="GET"):
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
api_instance = OAuthDataSource()
|
||||
response = api_instance.get("unknown_provider")
|
||||
|
||||
assert response[0]["error"] == "Invalid provider"
|
||||
assert response[1] == 400
|
||||
|
||||
|
||||
class TestOAuthDataSourceCallback:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
def test_oauth_callback_successful(self, mock_get_providers, app):
|
||||
provider_mock = MagicMock()
|
||||
mock_get_providers.return_value = {"notion": provider_mock}
|
||||
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion/callback?code=mock_code", method="GET"):
|
||||
api_instance = OAuthDataSourceCallback()
|
||||
response = api_instance.get("notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
assert "code=mock_code" in response.location
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
def test_oauth_callback_missing_code(self, mock_get_providers, app):
|
||||
provider_mock = MagicMock()
|
||||
mock_get_providers.return_value = {"notion": provider_mock}
|
||||
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion/callback", method="GET"):
|
||||
api_instance = OAuthDataSourceCallback()
|
||||
response = api_instance.get("notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
assert "error=Access denied" in response.location
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
def test_oauth_callback_invalid_provider(self, mock_get_providers, app):
|
||||
mock_get_providers.return_value = {"notion": MagicMock()}
|
||||
|
||||
with app.test_request_context("/console/api/oauth/data-source/invalid/callback?code=mock_code", method="GET"):
|
||||
api_instance = OAuthDataSourceCallback()
|
||||
response = api_instance.get("invalid")
|
||||
|
||||
assert response[0]["error"] == "Invalid provider"
|
||||
assert response[1] == 400
|
||||
|
||||
|
||||
class TestOAuthDataSourceBinding:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
def test_get_binding_successful(self, mock_get_providers, app):
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.get_access_token.return_value = None
|
||||
mock_get_providers.return_value = {"notion": mock_provider}
|
||||
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=auth_code_123", method="GET"):
|
||||
api_instance = OAuthDataSourceBinding()
|
||||
response = api_instance.get("notion")
|
||||
|
||||
assert response[0]["result"] == "success"
|
||||
assert response[1] == 200
|
||||
mock_provider.get_access_token.assert_called_once_with("auth_code_123")
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
def test_get_binding_missing_code(self, mock_get_providers, app):
|
||||
mock_get_providers.return_value = {"notion": MagicMock()}
|
||||
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=", method="GET"):
|
||||
api_instance = OAuthDataSourceBinding()
|
||||
response = api_instance.get("notion")
|
||||
|
||||
assert response[0]["error"] == "Invalid code"
|
||||
assert response[1] == 400
|
||||
|
||||
|
||||
class TestOAuthDataSourceSync:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
def test_sync_successful(self, mock_db, mock_csrf, mock_get_providers, app):
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.sync_data_source.return_value = None
|
||||
mock_get_providers.return_value = {"notion": mock_provider}
|
||||
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "user_123"
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion/binding_123/sync", method="GET"):
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
api_instance = OAuthDataSourceSync()
|
||||
# The route pattern uses <uuid:binding_id>, so we just pass a string for unit testing
|
||||
response = api_instance.get("notion", "binding_123")
|
||||
|
||||
assert response[0]["result"] == "success"
|
||||
assert response[1] == 200
|
||||
mock_provider.sync_data_source.assert_called_once_with("binding_123")
|
||||
|
|
@ -0,0 +1,417 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.auth.oauth_server import (
|
||||
OAuthServerAppApi,
|
||||
OAuthServerUserAccountApi,
|
||||
OAuthServerUserAuthorizeApi,
|
||||
OAuthServerUserTokenApi,
|
||||
)
|
||||
|
||||
|
||||
class TestOAuthServerAppApi:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider_app(self):
|
||||
from models.model import OAuthProviderApp
|
||||
|
||||
oauth_app = MagicMock(spec=OAuthProviderApp)
|
||||
oauth_app.client_id = "test_client_id"
|
||||
oauth_app.redirect_uris = ["http://localhost/callback"]
|
||||
oauth_app.app_icon = "icon_url"
|
||||
oauth_app.app_label = "Test App"
|
||||
oauth_app.scope = "read,write"
|
||||
return oauth_app
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_successful_post(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"},
|
||||
):
|
||||
api_instance = OAuthServerAppApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response["app_icon"] == "icon_url"
|
||||
assert response["app_label"] == "Test App"
|
||||
assert response["scope"] == "read,write"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"},
|
||||
):
|
||||
api_instance = OAuthServerAppApi()
|
||||
with pytest.raises(BadRequest, match="redirect_uri is invalid"):
|
||||
api_instance.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_invalid_client_id(self, mock_get_app, mock_db, app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider",
|
||||
method="POST",
|
||||
json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"},
|
||||
):
|
||||
api_instance = OAuthServerAppApi()
|
||||
with pytest.raises(NotFound, match="client_id is invalid"):
|
||||
api_instance.post()
|
||||
|
||||
|
||||
class TestOAuthServerUserAuthorizeApi:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider_app(self):
|
||||
oauth_app = MagicMock()
|
||||
oauth_app.client_id = "test_client_id"
|
||||
return oauth_app
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
@patch("controllers.console.auth.oauth_server.current_account_with_tenant")
|
||||
@patch("controllers.console.wraps.current_account_with_tenant")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code")
|
||||
@patch("libs.login.check_csrf_token")
|
||||
def test_successful_authorize(
|
||||
self, mock_csrf, mock_sign, mock_wrap_current, mock_current, mock_get_app, mock_db, app, mock_oauth_provider_app
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
mock_account = MagicMock()
|
||||
mock_account.id = "user_123"
|
||||
from models.account import AccountStatus
|
||||
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
|
||||
mock_current.return_value = (mock_account, MagicMock())
|
||||
mock_wrap_current.return_value = (mock_account, MagicMock())
|
||||
|
||||
mock_sign.return_value = "auth_code_123"
|
||||
|
||||
with app.test_request_context("/oauth/provider/authorize", method="POST", json={"client_id": "test_client_id"}):
|
||||
with patch("libs.login.current_user", mock_account):
|
||||
api_instance = OAuthServerUserAuthorizeApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response["code"] == "auth_code_123"
|
||||
mock_sign.assert_called_once_with("test_client_id", "user_123")
|
||||
|
||||
|
||||
class TestOAuthServerUserTokenApi:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider_app(self):
|
||||
from models.model import OAuthProviderApp
|
||||
|
||||
oauth_app = MagicMock(spec=OAuthProviderApp)
|
||||
oauth_app.client_id = "test_client_id"
|
||||
oauth_app.client_secret = "test_secret"
|
||||
oauth_app.redirect_uris = ["http://localhost/callback"]
|
||||
return oauth_app
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token")
|
||||
def test_authorization_code_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
mock_sign.return_value = ("access_123", "refresh_123")
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"code": "auth_code",
|
||||
"client_secret": "test_secret",
|
||||
"redirect_uri": "http://localhost/callback",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response["access_token"] == "access_123"
|
||||
assert response["refresh_token"] == "refresh_123"
|
||||
assert response["token_type"] == "Bearer"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_authorization_code_grant_missing_code(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"client_secret": "test_secret",
|
||||
"redirect_uri": "http://localhost/callback",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
with pytest.raises(BadRequest, match="code is required"):
|
||||
api_instance.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_authorization_code_grant_invalid_secret(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"code": "auth_code",
|
||||
"client_secret": "invalid_secret",
|
||||
"redirect_uri": "http://localhost/callback",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
with pytest.raises(BadRequest, match="client_secret is invalid"):
|
||||
api_instance.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_authorization_code_grant_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "authorization_code",
|
||||
"code": "auth_code",
|
||||
"client_secret": "test_secret",
|
||||
"redirect_uri": "http://invalid/callback",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
with pytest.raises(BadRequest, match="redirect_uri is invalid"):
|
||||
api_instance.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token")
|
||||
def test_refresh_token_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
mock_sign.return_value = ("new_access", "new_refresh")
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response["access_token"] == "new_access"
|
||||
assert response["refresh_token"] == "new_refresh"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_refresh_token_grant_missing_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "refresh_token",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
with pytest.raises(BadRequest, match="refresh_token is required"):
|
||||
api_instance.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_invalid_grant_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/token",
|
||||
method="POST",
|
||||
json={
|
||||
"client_id": "test_client_id",
|
||||
"grant_type": "invalid_grant",
|
||||
},
|
||||
):
|
||||
api_instance = OAuthServerUserTokenApi()
|
||||
with pytest.raises(BadRequest, match="invalid grant_type"):
|
||||
api_instance.post()
|
||||
|
||||
|
||||
class TestOAuthServerUserAccountApi:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider_app(self):
|
||||
from models.model import OAuthProviderApp
|
||||
|
||||
oauth_app = MagicMock(spec=OAuthProviderApp)
|
||||
oauth_app.client_id = "test_client_id"
|
||||
return oauth_app
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token")
|
||||
def test_successful_account_retrieval(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
mock_account = MagicMock()
|
||||
mock_account.name = "Test User"
|
||||
mock_account.email = "test@example.com"
|
||||
mock_account.avatar = "avatar_url"
|
||||
mock_account.interface_language = "en-US"
|
||||
mock_account.timezone = "UTC"
|
||||
mock_validate.return_value = mock_account
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/account",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "Bearer valid_access_token"},
|
||||
):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response["name"] == "Test User"
|
||||
assert response["email"] == "test@example.com"
|
||||
assert response["avatar"] == "avatar_url"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_missing_authorization_header(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context("/oauth/provider/account", method="POST", json={"client_id": "test_client_id"}):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json["error"] == "Authorization header is required"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_invalid_authorization_header_format(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/account",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "InvalidFormat"},
|
||||
):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json["error"] == "Invalid Authorization header format"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_invalid_token_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/account",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "Basic something"},
|
||||
):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json["error"] == "token_type is invalid"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
def test_missing_access_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/account",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "Bearer "},
|
||||
):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json["error"] == "Invalid Authorization header format"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
|
||||
@patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token")
|
||||
def test_invalid_access_token(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_app.return_value = mock_oauth_provider_app
|
||||
mock_validate.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
"/oauth/provider/account",
|
||||
method="POST",
|
||||
json={"client_id": "test_client_id"},
|
||||
headers={"Authorization": "Bearer invalid_token"},
|
||||
):
|
||||
api_instance = OAuthServerUserAccountApi()
|
||||
response = api_instance.post()
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json["error"] == "access_token or client_id is invalid"
|
||||
Loading…
Reference in New Issue