test: unit test cases for console.explore and tag module (#32186)

This commit is contained in:
rajatagarwal-oss 2026-03-10 08:55:00 +05:30 committed by GitHub
parent 4f835107b2
commit 01991f3536
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 4268 additions and 0 deletions

View File

@ -0,0 +1,402 @@
from io import BytesIO
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import InternalServerError
import controllers.console.explore.audio as audio_module
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
CompletionRequestError,
NoAudioUploadedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from dify_graph.model_runtime.errors.invoke import InvokeError
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
)
def unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
@pytest.fixture
def installed_app():
app = MagicMock()
app.app = MagicMock()
return app
@pytest.fixture
def audio_file():
return (BytesIO(b"audio"), "audio.wav")
class TestChatAudioApi:
def setup_method(self):
self.api = audio_module.ChatAudioApi()
self.method = unwrap(self.api.post)
def test_post_success(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
return_value={"text": "ok"},
),
):
resp = self.method(installed_app)
assert resp == {"text": "ok"}
def test_app_unavailable(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
side_effect=audio_module.services.errors.app_model_config.AppModelConfigBrokenError(),
),
):
with pytest.raises(AppUnavailableError):
self.method(installed_app)
def test_no_audio_uploaded(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
side_effect=NoAudioUploadedServiceError(),
),
):
with pytest.raises(NoAudioUploadedError):
self.method(installed_app)
def test_audio_too_large(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
side_effect=AudioTooLargeServiceError("too big"),
),
):
with pytest.raises(AudioTooLargeError):
self.method(installed_app)
def test_provider_quota_exceeded(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
side_effect=QuotaExceededError(),
),
):
with pytest.raises(ProviderQuotaExceededError):
self.method(installed_app)
def test_unknown_exception(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
side_effect=Exception("boom"),
),
):
with pytest.raises(InternalServerError):
self.method(installed_app)
def test_unsupported_audio_type(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
side_effect=audio_module.UnsupportedAudioTypeServiceError(),
),
):
with pytest.raises(audio_module.UnsupportedAudioTypeError):
self.method(installed_app)
def test_provider_not_support_speech_to_text(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
side_effect=audio_module.ProviderNotSupportSpeechToTextServiceError(),
),
):
with pytest.raises(audio_module.ProviderNotSupportSpeechToTextError):
self.method(installed_app)
def test_provider_not_initialized(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
side_effect=ProviderTokenNotInitError("not init"),
),
):
with pytest.raises(ProviderNotInitializeError):
self.method(installed_app)
def test_model_currently_not_supported(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
side_effect=ModelCurrentlyNotSupportError(),
),
):
with pytest.raises(ProviderModelCurrentlyNotSupportError):
self.method(installed_app)
def test_invoke_error_asr(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
side_effect=InvokeError("invoke failed"),
),
):
with pytest.raises(CompletionRequestError):
self.method(installed_app)
class TestChatTextApi:
def setup_method(self):
self.api = audio_module.ChatTextApi()
self.method = unwrap(self.api.post)
def test_post_success(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"message_id": "m1", "text": "hello", "voice": "v1"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
return_value={"audio": "ok"},
),
):
resp = self.method(installed_app)
assert resp == {"audio": "ok"}
def test_provider_not_initialized(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"text": "hi"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
side_effect=ProviderTokenNotInitError("not init"),
),
):
with pytest.raises(ProviderNotInitializeError):
self.method(installed_app)
def test_model_not_supported(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"text": "hi"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
side_effect=ModelCurrentlyNotSupportError(),
),
):
with pytest.raises(ProviderModelCurrentlyNotSupportError):
self.method(installed_app)
def test_invoke_error(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"text": "hi"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
side_effect=InvokeError("invoke failed"),
),
):
with pytest.raises(CompletionRequestError):
self.method(installed_app)
def test_unknown_exception(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"text": "hi"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
side_effect=Exception("boom"),
),
):
with pytest.raises(InternalServerError):
self.method(installed_app)
def test_app_unavailable_tts(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"text": "hi"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
side_effect=audio_module.services.errors.app_model_config.AppModelConfigBrokenError(),
),
):
with pytest.raises(AppUnavailableError):
self.method(installed_app)
def test_no_audio_uploaded_tts(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"text": "hi"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
side_effect=NoAudioUploadedServiceError(),
),
):
with pytest.raises(NoAudioUploadedError):
self.method(installed_app)
def test_audio_too_large_tts(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"text": "hi"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
side_effect=AudioTooLargeServiceError("too big"),
),
):
with pytest.raises(AudioTooLargeError):
self.method(installed_app)
def test_unsupported_audio_type_tts(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"text": "hi"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
side_effect=audio_module.UnsupportedAudioTypeServiceError(),
),
):
with pytest.raises(audio_module.UnsupportedAudioTypeError):
self.method(installed_app)
def test_provider_not_support_speech_to_text_tts(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"text": "hi"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
side_effect=audio_module.ProviderNotSupportSpeechToTextServiceError(),
),
):
with pytest.raises(audio_module.ProviderNotSupportSpeechToTextError):
self.method(installed_app)
def test_quota_exceeded_tts(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"text": "hi"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
side_effect=QuotaExceededError(),
),
):
with pytest.raises(ProviderQuotaExceededError):
self.method(installed_app)

View File

@ -0,0 +1,100 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
import controllers.console.explore.banner as banner_module
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestBannerApi:
def test_get_banners_with_requested_language(self, app):
api = banner_module.BannerApi()
method = unwrap(api.get)
banner = MagicMock()
banner.id = "b1"
banner.content = {"text": "hello"}
banner.link = "https://example.com"
banner.sort = 1
banner.status = "enabled"
banner.created_at = datetime(2024, 1, 1)
query = MagicMock()
query.where.return_value = query
query.order_by.return_value = query
query.all.return_value = [banner]
session = MagicMock()
session.query.return_value = query
with app.test_request_context("/?language=fr-FR"), patch.object(banner_module.db, "session", session):
result = method(api)
assert result == [
{
"id": "b1",
"content": {"text": "hello"},
"link": "https://example.com",
"sort": 1,
"status": "enabled",
"created_at": "2024-01-01T00:00:00",
}
]
def test_get_banners_fallback_to_en_us(self, app):
api = banner_module.BannerApi()
method = unwrap(api.get)
banner = MagicMock()
banner.id = "b2"
banner.content = {"text": "fallback"}
banner.link = None
banner.sort = 1
banner.status = "enabled"
banner.created_at = None
query = MagicMock()
query.where.return_value = query
query.order_by.return_value = query
query.all.side_effect = [
[],
[banner],
]
session = MagicMock()
session.query.return_value = query
with app.test_request_context("/?language=es-ES"), patch.object(banner_module.db, "session", session):
result = method(api)
assert result == [
{
"id": "b2",
"content": {"text": "fallback"},
"link": None,
"sort": 1,
"status": "enabled",
"created_at": None,
}
]
def test_get_banners_default_language_en_us(self, app):
api = banner_module.BannerApi()
method = unwrap(api.get)
query = MagicMock()
query.where.return_value = query
query.order_by.return_value = query
query.all.return_value = []
session = MagicMock()
session.query.return_value = query
with app.test_request_context("/"), patch.object(banner_module.db, "session", session):
result = method(api)
assert result == []

View File

@ -0,0 +1,459 @@
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from werkzeug.exceptions import InternalServerError
import controllers.console.explore.completion as completion_module
from controllers.console.app.error import (
ConversationCompletedError,
)
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from models import Account
from models.model import AppMode
from services.errors.llm import InvokeRateLimitError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def user():
return MagicMock(spec=Account)
@pytest.fixture
def completion_app():
return MagicMock(app=MagicMock(mode=AppMode.COMPLETION))
@pytest.fixture
def chat_app():
return MagicMock(app=MagicMock(mode=AppMode.CHAT))
@pytest.fixture
def payload_data():
return {"inputs": {}, "query": "hi"}
@pytest.fixture
def payload_patch(payload_data):
return patch.object(
type(completion_module.console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload_data,
)
class TestCompletionApi:
def test_post_success(self, app, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
return_value={"ok": True},
),
patch.object(
completion_module.helper,
"compact_generate_response",
return_value=("ok", 200),
),
):
result = method(completion_app)
assert result == ("ok", 200)
def test_post_wrong_app_mode(self):
api = completion_module.CompletionApi()
method = unwrap(api.post)
installed_app = MagicMock(app=MagicMock(mode=AppMode.CHAT))
with pytest.raises(NotCompletionAppError):
method(installed_app)
def test_conversation_completed(self, app, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.services.errors.conversation.ConversationCompletedError(),
),
):
with pytest.raises(ConversationCompletedError):
method(completion_app)
def test_internal_error(self, app, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=Exception("boom"),
),
):
with pytest.raises(InternalServerError):
method(completion_app)
def test_conversation_not_exists(self, app, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.services.errors.conversation.ConversationNotExistsError(),
),
):
with pytest.raises(completion_module.NotFound):
method(completion_app)
def test_app_unavailable(self, app, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.services.errors.app_model_config.AppModelConfigBrokenError(),
),
):
with pytest.raises(completion_module.AppUnavailableError):
method(completion_app)
def test_provider_not_initialized(self, app, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.ProviderTokenNotInitError("not init"),
),
):
with pytest.raises(completion_module.ProviderNotInitializeError):
method(completion_app)
def test_quota_exceeded(self, app, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.QuotaExceededError(),
),
):
with pytest.raises(completion_module.ProviderQuotaExceededError):
method(completion_app)
def test_model_not_supported(self, app, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.ModelCurrentlyNotSupportError(),
),
):
with pytest.raises(completion_module.ProviderModelCurrentlyNotSupportError):
method(completion_app)
def test_invoke_error(self, app, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.InvokeError("invoke failed"),
),
):
with pytest.raises(completion_module.CompletionRequestError):
method(completion_app)
class TestCompletionStopApi:
def test_stop_success(self, completion_app, user):
api = completion_module.CompletionStopApi()
method = unwrap(api.post)
user.id = "u1"
with (
patch.object(completion_module, "current_user", user),
patch.object(completion_module.AppTaskService, "stop_task"),
):
resp, status = method(completion_app, "task-1")
assert status == 200
assert resp == {"result": "success"}
def test_stop_wrong_app_mode(self):
api = completion_module.CompletionStopApi()
method = unwrap(api.post)
installed_app = MagicMock(app=MagicMock(mode=AppMode.CHAT))
with pytest.raises(NotCompletionAppError):
method(installed_app, "task")
class TestChatApi:
def test_post_success(self, app, chat_app, user, payload_patch):
api = completion_module.ChatApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
return_value={"ok": True},
),
patch.object(
completion_module.helper,
"compact_generate_response",
return_value=("ok", 200),
),
):
result = method(chat_app)
assert result == ("ok", 200)
def test_post_not_chat_app(self):
api = completion_module.ChatApi()
method = unwrap(api.post)
installed_app = MagicMock(app=MagicMock(mode=AppMode.COMPLETION))
with pytest.raises(NotChatAppError):
method(installed_app)
def test_rate_limit_error(self, app, chat_app, user, payload_patch):
api = completion_module.ChatApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=InvokeRateLimitError("limit"),
),
):
with pytest.raises(InvokeRateLimitHttpError):
method(chat_app)
def test_conversation_completed_chat(self, app, chat_app, user, payload_patch):
api = completion_module.ChatApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.services.errors.conversation.ConversationCompletedError(),
),
):
with pytest.raises(ConversationCompletedError):
method(chat_app)
def test_conversation_not_exists_chat(self, app, chat_app, user, payload_patch):
api = completion_module.ChatApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.services.errors.conversation.ConversationNotExistsError(),
),
):
with pytest.raises(completion_module.NotFound):
method(chat_app)
def test_app_unavailable_chat(self, app, chat_app, user, payload_patch):
api = completion_module.ChatApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.services.errors.app_model_config.AppModelConfigBrokenError(),
),
):
with pytest.raises(completion_module.AppUnavailableError):
method(chat_app)
def test_provider_not_initialized_chat(self, app, chat_app, user, payload_patch):
api = completion_module.ChatApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.ProviderTokenNotInitError("not init"),
),
):
with pytest.raises(completion_module.ProviderNotInitializeError):
method(chat_app)
def test_quota_exceeded_chat(self, app, chat_app, user, payload_patch):
api = completion_module.ChatApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.QuotaExceededError(),
),
):
with pytest.raises(completion_module.ProviderQuotaExceededError):
method(chat_app)
def test_model_not_supported_chat(self, app, chat_app, user, payload_patch):
api = completion_module.ChatApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.ModelCurrentlyNotSupportError(),
),
):
with pytest.raises(completion_module.ProviderModelCurrentlyNotSupportError):
method(chat_app)
def test_invoke_error_chat(self, app, chat_app, user, payload_patch):
api = completion_module.ChatApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.InvokeError("invoke failed"),
),
):
with pytest.raises(completion_module.CompletionRequestError):
method(chat_app)
def test_internal_error_chat(self, app, chat_app, user, payload_patch):
api = completion_module.ChatApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=Exception("boom"),
),
):
with pytest.raises(InternalServerError):
method(chat_app)
class TestChatStopApi:
def test_stop_success(self, chat_app, user):
api = completion_module.ChatStopApi()
method = unwrap(api.post)
user.id = "u1"
with (
patch.object(completion_module, "current_user", user),
patch.object(completion_module.AppTaskService, "stop_task"),
):
resp, status = method(chat_app, "task-1")
assert status == 200
assert resp == {"result": "success"}
def test_stop_not_chat_app(self):
api = completion_module.ChatStopApi()
method = unwrap(api.post)
installed_app = MagicMock(app=MagicMock(mode=AppMode.COMPLETION))
with pytest.raises(NotChatAppError):
method(installed_app, "task")

View File

@ -0,0 +1,232 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
import controllers.console.explore.conversation as conversation_module
from controllers.console.explore.error import NotChatAppError
from models import Account
from models.model import AppMode
from services.errors.conversation import (
ConversationNotExistsError,
LastConversationNotExistsError,
)
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class FakeConversation:
def __init__(self, cid):
self.id = cid
self.name = "test"
self.inputs = {}
self.status = "normal"
self.introduction = ""
@pytest.fixture
def chat_app():
app_model = MagicMock(mode=AppMode.CHAT, id="app-id")
return MagicMock(app=app_model)
@pytest.fixture
def non_chat_app():
app_model = MagicMock(mode=AppMode.COMPLETION)
return MagicMock(app=app_model)
@pytest.fixture
def user():
user = MagicMock(spec=Account)
user.id = "uid"
return user
@pytest.fixture(autouse=True)
def mock_db_and_session():
with (
patch.object(
conversation_module,
"db",
MagicMock(session=MagicMock(), engine=MagicMock()),
),
patch(
"controllers.console.explore.conversation.Session",
MagicMock(),
),
):
yield
class TestConversationListApi:
def test_get_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
pagination = MagicMock(
limit=20,
has_more=False,
data=[FakeConversation("c1"), FakeConversation("c2")],
)
with (
app.test_request_context("/?limit=20"),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.WebConversationService,
"pagination_by_last_id",
return_value=pagination,
),
):
result = method(chat_app)
assert result["limit"] == 20
assert result["has_more"] is False
assert len(result["data"]) == 2
def test_last_conversation_not_exists(self, app: Flask, chat_app, user):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.WebConversationService,
"pagination_by_last_id",
side_effect=LastConversationNotExistsError(),
),
):
with pytest.raises(NotFound):
method(chat_app)
def test_wrong_app_mode(self, app: Flask, non_chat_app):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
with app.test_request_context("/"):
with pytest.raises(NotChatAppError):
method(non_chat_app)
class TestConversationApi:
def test_delete_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.ConversationService,
"delete",
),
):
result = method(chat_app, "cid")
body, status = result
assert status == 204
assert body["result"] == "success"
def test_delete_not_found(self, app: Flask, chat_app, user):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.ConversationService,
"delete",
side_effect=ConversationNotExistsError(),
),
):
with pytest.raises(NotFound):
method(chat_app, "cid")
def test_delete_wrong_app_mode(self, app: Flask, non_chat_app):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
with app.test_request_context("/"):
with pytest.raises(NotChatAppError):
method(non_chat_app, "cid")
class TestConversationRenameApi:
def test_rename_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationRenameApi()
method = unwrap(api.post)
conversation = FakeConversation("cid")
with (
app.test_request_context("/", json={"name": "new"}),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.ConversationService,
"rename",
return_value=conversation,
),
):
result = method(chat_app, "cid")
assert result["id"] == "cid"
def test_rename_not_found(self, app: Flask, chat_app, user):
api = conversation_module.ConversationRenameApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"name": "new"}),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.ConversationService,
"rename",
side_effect=ConversationNotExistsError(),
),
):
with pytest.raises(NotFound):
method(chat_app, "cid")
class TestConversationPinApi:
def test_pin_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationPinApi()
method = unwrap(api.patch)
with (
app.test_request_context("/"),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.WebConversationService,
"pin",
),
):
result = method(chat_app, "cid")
assert result == {"result": "success"}
class TestConversationUnPinApi:
def test_unpin_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationUnPinApi()
method = unwrap(api.patch)
with (
app.test_request_context("/"),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.WebConversationService,
"unpin",
),
):
result = method(chat_app, "cid")
assert result == {"result": "success"}

View File

@ -0,0 +1,363 @@
from datetime import datetime
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
import controllers.console.explore.installed_app as module
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def tenant_id():
return "t1"
@pytest.fixture
def current_user(tenant_id):
user = MagicMock()
user.id = "u1"
user.current_tenant = MagicMock(id=tenant_id)
return user
@pytest.fixture
def installed_app():
app = MagicMock()
app.id = "ia1"
app.app = MagicMock(id="a1")
app.app_owner_tenant_id = "t2"
app.is_pinned = False
app.last_used_at = datetime(2024, 1, 1)
return app
@pytest.fixture
def payload_patch():
def _patch(payload):
return patch.object(
type(module.console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
)
return _patch
class TestInstalledAppsListApi:
def test_get_installed_apps(self, app, current_user, tenant_id, installed_app):
api = module.InstalledAppsListApi()
method = unwrap(api.get)
session = MagicMock()
session.scalars.return_value.all.return_value = [installed_app]
with (
app.test_request_context("/"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
patch.object(module.TenantService, "get_user_role", return_value="owner"),
patch.object(
module.FeatureService,
"get_system_features",
return_value=MagicMock(webapp_auth=MagicMock(enabled=False)),
),
):
result = method(api)
assert "installed_apps" in result
assert result["installed_apps"][0]["editable"] is True
assert result["installed_apps"][0]["uninstallable"] is False
def test_get_installed_apps_with_app_id_filter(self, app, current_user, tenant_id):
api = module.InstalledAppsListApi()
method = unwrap(api.get)
session = MagicMock()
session.scalars.return_value.all.return_value = []
with (
app.test_request_context("/?app_id=a1"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
patch.object(module.TenantService, "get_user_role", return_value="member"),
patch.object(
module.FeatureService,
"get_system_features",
return_value=MagicMock(webapp_auth=MagicMock(enabled=False)),
),
):
result = method(api)
assert result == {"installed_apps": []}
def test_get_installed_apps_with_webapp_auth_enabled(self, app, current_user, tenant_id, installed_app):
"""Test filtering when webapp_auth is enabled."""
api = module.InstalledAppsListApi()
method = unwrap(api.get)
session = MagicMock()
session.scalars.return_value.all.return_value = [installed_app]
mock_webapp_setting = MagicMock()
mock_webapp_setting.access_mode = "restricted"
with (
app.test_request_context("/"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
patch.object(module.TenantService, "get_user_role", return_value="owner"),
patch.object(
module.FeatureService,
"get_system_features",
return_value=MagicMock(webapp_auth=MagicMock(enabled=True)),
),
patch.object(
module.EnterpriseService.WebAppAuth,
"batch_get_app_access_mode_by_id",
return_value={"a1": mock_webapp_setting},
),
patch.object(
module.EnterpriseService.WebAppAuth,
"batch_is_user_allowed_to_access_webapps",
return_value={"a1": True},
),
):
result = method(api)
assert len(result["installed_apps"]) == 1
def test_get_installed_apps_with_webapp_auth_user_denied(self, app, current_user, tenant_id, installed_app):
"""Test filtering when user doesn't have access."""
api = module.InstalledAppsListApi()
method = unwrap(api.get)
session = MagicMock()
session.scalars.return_value.all.return_value = [installed_app]
mock_webapp_setting = MagicMock()
mock_webapp_setting.access_mode = "restricted"
with (
app.test_request_context("/"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
patch.object(module.TenantService, "get_user_role", return_value="member"),
patch.object(
module.FeatureService,
"get_system_features",
return_value=MagicMock(webapp_auth=MagicMock(enabled=True)),
),
patch.object(
module.EnterpriseService.WebAppAuth,
"batch_get_app_access_mode_by_id",
return_value={"a1": mock_webapp_setting},
),
patch.object(
module.EnterpriseService.WebAppAuth,
"batch_is_user_allowed_to_access_webapps",
return_value={"a1": False},
),
):
result = method(api)
assert result["installed_apps"] == []
def test_get_installed_apps_with_sso_verified_access(self, app, current_user, tenant_id, installed_app):
"""Test that sso_verified access mode apps are skipped in filtering."""
api = module.InstalledAppsListApi()
method = unwrap(api.get)
session = MagicMock()
session.scalars.return_value.all.return_value = [installed_app]
mock_webapp_setting = MagicMock()
mock_webapp_setting.access_mode = "sso_verified"
with (
app.test_request_context("/"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
patch.object(module.TenantService, "get_user_role", return_value="owner"),
patch.object(
module.FeatureService,
"get_system_features",
return_value=MagicMock(webapp_auth=MagicMock(enabled=True)),
),
patch.object(
module.EnterpriseService.WebAppAuth,
"batch_get_app_access_mode_by_id",
return_value={"a1": mock_webapp_setting},
),
):
result = method(api)
assert len(result["installed_apps"]) == 0
def test_get_installed_apps_filters_null_apps(self, app, current_user, tenant_id):
"""Test that installed apps with null app are filtered out."""
api = module.InstalledAppsListApi()
method = unwrap(api.get)
installed_app_with_null = MagicMock()
installed_app_with_null.app = None
session = MagicMock()
session.scalars.return_value.all.return_value = [installed_app_with_null]
with (
app.test_request_context("/"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
patch.object(module.TenantService, "get_user_role", return_value="owner"),
patch.object(
module.FeatureService,
"get_system_features",
return_value=MagicMock(webapp_auth=MagicMock(enabled=False)),
),
):
result = method(api)
assert result["installed_apps"] == []
def test_get_installed_apps_current_tenant_none(self, app, tenant_id, installed_app):
"""Test error when current_user.current_tenant is None."""
api = module.InstalledAppsListApi()
method = unwrap(api.get)
current_user = MagicMock()
current_user.current_tenant = None
session = MagicMock()
session.scalars.return_value.all.return_value = [installed_app]
with (
app.test_request_context("/"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
):
with pytest.raises(ValueError, match="current_user.current_tenant must not be None"):
method(api)
class TestInstalledAppsCreateApi:
def test_post_success(self, app, tenant_id, payload_patch):
api = module.InstalledAppsListApi()
method = unwrap(api.post)
recommended = MagicMock()
recommended.install_count = 0
app_entity = MagicMock()
app_entity.id = "a1"
app_entity.is_public = True
app_entity.tenant_id = "t2"
session = MagicMock()
session.query.return_value.where.return_value.first.side_effect = [
recommended,
app_entity,
None,
]
with (
app.test_request_context("/", json={"app_id": "a1"}),
payload_patch({"app_id": "a1"}),
patch.object(module.db, "session", session),
patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)),
):
result = method(api)
assert result == {"message": "App installed successfully"}
assert recommended.install_count == 1
def test_post_recommended_not_found(self, app, payload_patch):
api = module.InstalledAppsListApi()
method = unwrap(api.post)
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = None
with (
app.test_request_context("/", json={"app_id": "a1"}),
payload_patch({"app_id": "a1"}),
patch.object(module.db, "session", session),
):
with pytest.raises(NotFound):
method(api)
def test_post_app_not_public(self, app, tenant_id, payload_patch):
api = module.InstalledAppsListApi()
method = unwrap(api.post)
recommended = MagicMock()
app_entity = MagicMock(is_public=False)
session = MagicMock()
session.query.return_value.where.return_value.first.side_effect = [
recommended,
app_entity,
]
with (
app.test_request_context("/", json={"app_id": "a1"}),
payload_patch({"app_id": "a1"}),
patch.object(module.db, "session", session),
patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)),
):
with pytest.raises(Forbidden):
method(api)
class TestInstalledAppApi:
def test_delete_success(self, tenant_id, installed_app):
api = module.InstalledAppApi()
method = unwrap(api.delete)
with (
patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)),
patch.object(module.db, "session"),
):
resp, status = method(installed_app)
assert status == 204
assert resp["result"] == "success"
def test_delete_owned_by_current_tenant(self, tenant_id):
api = module.InstalledAppApi()
method = unwrap(api.delete)
installed_app = MagicMock(app_owner_tenant_id=tenant_id)
with patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)):
with pytest.raises(BadRequest):
method(installed_app)
def test_patch_update_pin(self, app, payload_patch, installed_app):
api = module.InstalledAppApi()
method = unwrap(api.patch)
with (
app.test_request_context("/", json={"is_pinned": True}),
payload_patch({"is_pinned": True}),
patch.object(module.db, "session"),
):
result = method(installed_app)
assert installed_app.is_pinned is True
assert result["result"] == "success"
def test_patch_no_change(self, app, payload_patch, installed_app):
api = module.InstalledAppApi()
method = unwrap(api.patch)
with app.test_request_context("/", json={}), payload_patch({}), patch.object(module.db, "session"):
result = method(installed_app)
assert result["result"] == "success"

View File

@ -0,0 +1,552 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import InternalServerError, NotFound
import controllers.console.explore.message as module
from controllers.console.app.error import (
AppMoreLikeThisDisabledError,
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.explore.error import (
AppSuggestedQuestionsAfterAnswerDisabledError,
NotChatAppError,
NotCompletionAppError,
)
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from dify_graph.model_runtime.errors.invoke import InvokeError
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import (
FirstMessageNotExistsError,
MessageNotExistsError,
SuggestedQuestionsAfterAnswerDisabledError,
)
def unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def make_message():
msg = MagicMock()
msg.id = "m1"
msg.conversation_id = "11111111-1111-1111-1111-111111111111"
msg.parent_message_id = None
msg.inputs = {}
msg.query = "hello"
msg.re_sign_file_url_answer = ""
msg.user_feedback = MagicMock(rating=None)
msg.status = "success"
msg.error = None
return msg
class TestMessageListApi:
def test_get_success(self, app):
api = module.MessageListApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
pagination = MagicMock(
limit=20,
has_more=False,
data=[make_message(), make_message()],
)
with (
app.test_request_context(
"/",
query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"pagination_by_first_id",
return_value=pagination,
),
):
result = method(installed_app)
assert result["limit"] == 20
assert result["has_more"] is False
assert len(result["data"]) == 2
def test_get_not_chat_app(self):
api = module.MessageListApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
with pytest.raises(NotChatAppError):
method(installed_app)
def test_conversation_not_exists(self, app):
api = module.MessageListApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
app.test_request_context(
"/",
query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"pagination_by_first_id",
side_effect=ConversationNotExistsError(),
),
):
with pytest.raises(NotFound):
method(installed_app)
def test_first_message_not_exists(self, app):
api = module.MessageListApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
app.test_request_context(
"/",
query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"pagination_by_first_id",
side_effect=FirstMessageNotExistsError(),
),
):
with pytest.raises(NotFound):
method(installed_app)
class TestMessageFeedbackApi:
def test_post_success(self, app):
api = module.MessageFeedbackApi()
method = unwrap(api.post)
installed_app = MagicMock()
installed_app.app = MagicMock()
with (
app.test_request_context("/", json={"rating": "like"}),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"create_feedback",
),
):
result = method(installed_app, "mid")
assert result["result"] == "success"
def test_message_not_exists(self, app):
api = module.MessageFeedbackApi()
method = unwrap(api.post)
installed_app = MagicMock()
installed_app.app = MagicMock()
with (
app.test_request_context("/", json={}),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"create_feedback",
side_effect=MessageNotExistsError(),
),
):
with pytest.raises(NotFound):
method(installed_app, "mid")
class TestMessageMoreLikeThisApi:
def test_get_success(self, app):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with (
app.test_request_context(
"/",
query_string={"response_mode": "blocking"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.AppGenerateService,
"generate_more_like_this",
return_value={"ok": True},
),
patch.object(
module.helper,
"compact_generate_response",
return_value=("ok", 200),
),
):
resp = method(installed_app, "mid")
assert resp == ("ok", 200)
def test_not_completion_app(self):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
with pytest.raises(NotCompletionAppError):
method(installed_app, "mid")
def test_more_like_this_disabled(self, app):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with (
app.test_request_context(
"/",
query_string={"response_mode": "blocking"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.AppGenerateService,
"generate_more_like_this",
side_effect=module.MoreLikeThisDisabledError(),
),
):
with pytest.raises(AppMoreLikeThisDisabledError):
method(installed_app, "mid")
def test_message_not_exists_more_like_this(self, app):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with (
app.test_request_context(
"/",
query_string={"response_mode": "blocking"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.AppGenerateService,
"generate_more_like_this",
side_effect=MessageNotExistsError(),
),
):
with pytest.raises(NotFound):
method(installed_app, "mid")
def test_provider_not_init_more_like_this(self, app):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with (
app.test_request_context(
"/",
query_string={"response_mode": "blocking"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.AppGenerateService,
"generate_more_like_this",
side_effect=ProviderTokenNotInitError("test"),
),
):
with pytest.raises(ProviderNotInitializeError):
method(installed_app, "mid")
def test_quota_exceeded_more_like_this(self, app):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with (
app.test_request_context(
"/",
query_string={"response_mode": "blocking"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.AppGenerateService,
"generate_more_like_this",
side_effect=QuotaExceededError(),
),
):
with pytest.raises(ProviderQuotaExceededError):
method(installed_app, "mid")
def test_model_not_support_more_like_this(self, app):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with (
app.test_request_context(
"/",
query_string={"response_mode": "blocking"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.AppGenerateService,
"generate_more_like_this",
side_effect=ModelCurrentlyNotSupportError(),
),
):
with pytest.raises(ProviderModelCurrentlyNotSupportError):
method(installed_app, "mid")
def test_invoke_error_more_like_this(self, app):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with (
app.test_request_context(
"/",
query_string={"response_mode": "blocking"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.AppGenerateService,
"generate_more_like_this",
side_effect=InvokeError("test error"),
),
):
with pytest.raises(CompletionRequestError):
method(installed_app, "mid")
def test_unexpected_error_more_like_this(self, app):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with (
app.test_request_context(
"/",
query_string={"response_mode": "blocking"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.AppGenerateService,
"generate_more_like_this",
side_effect=Exception("unexpected"),
),
):
with pytest.raises(InternalServerError):
method(installed_app, "mid")
class TestMessageSuggestedQuestionApi:
def test_get_success(self):
api = module.MessageSuggestedQuestionApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"get_suggested_questions_after_answer",
return_value=["q1", "q2"],
),
):
result = method(installed_app, "mid")
assert result["data"] == ["q1", "q2"]
def test_not_chat_app(self):
api = module.MessageSuggestedQuestionApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
with pytest.raises(NotChatAppError):
method(installed_app, "mid")
def test_disabled(self):
api = module.MessageSuggestedQuestionApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"get_suggested_questions_after_answer",
side_effect=SuggestedQuestionsAfterAnswerDisabledError(),
),
):
with pytest.raises(AppSuggestedQuestionsAfterAnswerDisabledError):
method(installed_app, "mid")
def test_message_not_exists_suggested_question(self):
api = module.MessageSuggestedQuestionApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"get_suggested_questions_after_answer",
side_effect=MessageNotExistsError(),
),
):
with pytest.raises(NotFound):
method(installed_app, "mid")
def test_conversation_not_exists_suggested_question(self):
api = module.MessageSuggestedQuestionApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"get_suggested_questions_after_answer",
side_effect=ConversationNotExistsError(),
),
):
with pytest.raises(NotFound):
method(installed_app, "mid")
def test_provider_not_init_suggested_question(self):
api = module.MessageSuggestedQuestionApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"get_suggested_questions_after_answer",
side_effect=ProviderTokenNotInitError("test"),
),
):
with pytest.raises(ProviderNotInitializeError):
method(installed_app, "mid")
def test_quota_exceeded_suggested_question(self):
api = module.MessageSuggestedQuestionApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"get_suggested_questions_after_answer",
side_effect=QuotaExceededError(),
),
):
with pytest.raises(ProviderQuotaExceededError):
method(installed_app, "mid")
def test_model_not_support_suggested_question(self):
api = module.MessageSuggestedQuestionApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"get_suggested_questions_after_answer",
side_effect=ModelCurrentlyNotSupportError(),
),
):
with pytest.raises(ProviderModelCurrentlyNotSupportError):
method(installed_app, "mid")
def test_invoke_error_suggested_question(self):
api = module.MessageSuggestedQuestionApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"get_suggested_questions_after_answer",
side_effect=InvokeError("test error"),
),
):
with pytest.raises(CompletionRequestError):
method(installed_app, "mid")
def test_unexpected_error_suggested_question(self):
api = module.MessageSuggestedQuestionApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"get_suggested_questions_after_answer",
side_effect=Exception("unexpected"),
),
):
with pytest.raises(InternalServerError):
method(installed_app, "mid")

View File

@ -0,0 +1,140 @@
from unittest.mock import MagicMock, patch
import pytest
import controllers.console.explore.parameter as module
from controllers.console.app.error import AppUnavailableError
from models.model import AppMode
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestAppParameterApi:
def test_get_app_none(self):
api = module.AppParameterApi()
method = unwrap(api.get)
installed_app = MagicMock(app=None)
with pytest.raises(AppUnavailableError):
method(installed_app)
def test_get_advanced_chat_workflow(self):
api = module.AppParameterApi()
method = unwrap(api.get)
workflow = MagicMock()
workflow.features_dict = {"f": "v"}
workflow.user_input_form.return_value = [{"name": "x"}]
app = MagicMock(
mode=AppMode.ADVANCED_CHAT,
workflow=workflow,
)
installed_app = MagicMock(app=app)
with (
patch.object(
module,
"get_parameters_from_feature_dict",
return_value={"any": "thing"},
),
patch.object(
module.fields.Parameters,
"model_validate",
return_value=MagicMock(model_dump=lambda **_: {"ok": True}),
),
):
result = method(installed_app)
assert result == {"ok": True}
def test_get_advanced_chat_workflow_missing(self):
api = module.AppParameterApi()
method = unwrap(api.get)
app = MagicMock(
mode=AppMode.ADVANCED_CHAT,
workflow=None,
)
installed_app = MagicMock(app=app)
with pytest.raises(AppUnavailableError):
method(installed_app)
def test_get_non_workflow_app(self):
api = module.AppParameterApi()
method = unwrap(api.get)
app_model_config = MagicMock()
app_model_config.to_dict.return_value = {"user_input_form": [{"name": "y"}]}
app = MagicMock(
mode=AppMode.CHAT,
app_model_config=app_model_config,
)
installed_app = MagicMock(app=app)
with (
patch.object(
module,
"get_parameters_from_feature_dict",
return_value={"whatever": 123},
),
patch.object(
module.fields.Parameters,
"model_validate",
return_value=MagicMock(model_dump=lambda **_: {"ok": True}),
),
):
result = method(installed_app)
assert result == {"ok": True}
def test_get_non_workflow_missing_config(self):
api = module.AppParameterApi()
method = unwrap(api.get)
app = MagicMock(
mode=AppMode.CHAT,
app_model_config=None,
)
installed_app = MagicMock(app=app)
with pytest.raises(AppUnavailableError):
method(installed_app)
class TestExploreAppMetaApi:
def test_get_meta_success(self):
api = module.ExploreAppMetaApi()
method = unwrap(api.get)
app = MagicMock()
installed_app = MagicMock(app=app)
with patch.object(
module.AppService,
"get_app_meta",
return_value={"meta": "ok"},
):
result = method(installed_app)
assert result == {"meta": "ok"}
def test_get_meta_app_missing(self):
api = module.ExploreAppMetaApi()
method = unwrap(api.get)
installed_app = MagicMock(app=None)
with pytest.raises(ValueError):
method(installed_app)

View File

@ -0,0 +1,92 @@
from unittest.mock import MagicMock, patch
import controllers.console.explore.recommended_app as module
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestRecommendedAppListApi:
def test_get_with_language_param(self, app):
api = module.RecommendedAppListApi()
method = unwrap(api.get)
result_data = {"recommended_apps": [], "categories": []}
with (
app.test_request_context("/", query_string={"language": "en-US"}),
patch.object(module, "current_user", MagicMock(interface_language="fr-FR")),
patch.object(
module.RecommendedAppService,
"get_recommended_apps_and_categories",
return_value=result_data,
) as service_mock,
):
result = method(api)
service_mock.assert_called_once_with("en-US")
assert result == result_data
def test_get_fallback_to_user_language(self, app):
api = module.RecommendedAppListApi()
method = unwrap(api.get)
result_data = {"recommended_apps": [], "categories": []}
with (
app.test_request_context("/", query_string={"language": "invalid"}),
patch.object(module, "current_user", MagicMock(interface_language="fr-FR")),
patch.object(
module.RecommendedAppService,
"get_recommended_apps_and_categories",
return_value=result_data,
) as service_mock,
):
result = method(api)
service_mock.assert_called_once_with("fr-FR")
assert result == result_data
def test_get_fallback_to_default_language(self, app):
api = module.RecommendedAppListApi()
method = unwrap(api.get)
result_data = {"recommended_apps": [], "categories": []}
with (
app.test_request_context("/"),
patch.object(module, "current_user", MagicMock(interface_language=None)),
patch.object(
module.RecommendedAppService,
"get_recommended_apps_and_categories",
return_value=result_data,
) as service_mock,
):
result = method(api)
service_mock.assert_called_once_with(module.languages[0])
assert result == result_data
class TestRecommendedAppApi:
def test_get_success(self, app):
api = module.RecommendedAppApi()
method = unwrap(api.get)
result_data = {"id": "app1"}
with (
app.test_request_context("/"),
patch.object(
module.RecommendedAppService,
"get_recommend_app_detail",
return_value=result_data,
) as service_mock,
):
result = method(api, "11111111-1111-1111-1111-111111111111")
service_mock.assert_called_once_with("11111111-1111-1111-1111-111111111111")
assert result == result_data

View File

@ -0,0 +1,154 @@
from unittest.mock import MagicMock, PropertyMock, patch
from uuid import uuid4
import pytest
from werkzeug.exceptions import NotFound
import controllers.console.explore.saved_message as module
from controllers.console.explore.error import NotCompletionAppError
from services.errors.message import MessageNotExistsError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def make_saved_message():
msg = MagicMock()
msg.id = str(uuid4())
msg.message_id = str(uuid4())
msg.app_id = str(uuid4())
msg.inputs = {}
msg.query = "hello"
msg.answer = "world"
msg.user_feedback = MagicMock(rating="like")
msg.created_at = None
return msg
@pytest.fixture
def payload_patch():
def _patch(payload):
return patch.object(
type(module.console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
)
return _patch
class TestSavedMessageListApi:
def test_get_success(self, app):
api = module.SavedMessageListApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
pagination = MagicMock(
limit=20,
has_more=False,
data=[make_saved_message(), make_saved_message()],
)
with (
app.test_request_context("/", query_string={}),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.SavedMessageService,
"pagination_by_last_id",
return_value=pagination,
),
):
result = method(installed_app)
assert result["limit"] == 20
assert result["has_more"] is False
assert len(result["data"]) == 2
def test_get_not_completion_app(self):
api = module.SavedMessageListApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
with pytest.raises(NotCompletionAppError):
method(installed_app)
def test_post_success(self, app, payload_patch):
api = module.SavedMessageListApi()
method = unwrap(api.post)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
payload = {"message_id": str(uuid4())}
with (
app.test_request_context("/", json=payload),
payload_patch(payload),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(module.SavedMessageService, "save") as save_mock,
):
result = method(installed_app)
save_mock.assert_called_once()
assert result == {"result": "success"}
def test_post_message_not_exists(self, app, payload_patch):
api = module.SavedMessageListApi()
method = unwrap(api.post)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
payload = {"message_id": str(uuid4())}
with (
app.test_request_context("/", json=payload),
payload_patch(payload),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.SavedMessageService,
"save",
side_effect=MessageNotExistsError(),
),
):
with pytest.raises(NotFound):
method(installed_app)
class TestSavedMessageApi:
def test_delete_success(self):
api = module.SavedMessageApi()
method = unwrap(api.delete)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(module.SavedMessageService, "delete") as delete_mock,
):
result, status = method(installed_app, str(uuid4()))
delete_mock.assert_called_once()
assert status == 204
assert result == {"result": "success"}
def test_delete_not_completion_app(self):
api = module.SavedMessageApi()
method = unwrap(api.delete)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
with pytest.raises(NotCompletionAppError):
method(installed_app, str(uuid4()))

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,151 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import InternalServerError
from controllers.console.explore.error import NotWorkflowAppError
from controllers.console.explore.workflow import (
InstalledAppWorkflowRunApi,
InstalledAppWorkflowTaskStopApi,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from models.model import AppMode
from services.errors.llm import InvokeRateLimitError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def user():
return MagicMock()
@pytest.fixture
def workflow_app():
app = MagicMock()
app.mode = AppMode.WORKFLOW
return app
@pytest.fixture
def installed_workflow_app(workflow_app):
return MagicMock(app=workflow_app)
@pytest.fixture
def non_workflow_installed_app():
app = MagicMock()
app.mode = AppMode.CHAT
return MagicMock(app=app)
@pytest.fixture
def payload():
return {"inputs": {"a": 1}}
class TestInstalledAppWorkflowRunApi:
def test_not_workflow_app(self, app, non_workflow_installed_app):
api = InstalledAppWorkflowRunApi()
method = unwrap(api.post)
with (
app.test_request_context("/"),
patch(
"controllers.console.explore.workflow.current_account_with_tenant",
return_value=(MagicMock(), None),
),
):
with pytest.raises(NotWorkflowAppError):
method(non_workflow_installed_app)
def test_success(self, app, installed_workflow_app, user, payload):
api = InstalledAppWorkflowRunApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.explore.workflow.current_account_with_tenant",
return_value=(user, None),
),
patch(
"controllers.console.explore.workflow.AppGenerateService.generate",
return_value=MagicMock(),
) as generate_mock,
):
result = method(installed_workflow_app)
generate_mock.assert_called_once()
assert result is not None
def test_rate_limit_error(self, app, installed_workflow_app, user, payload):
api = InstalledAppWorkflowRunApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.explore.workflow.current_account_with_tenant",
return_value=(user, None),
),
patch(
"controllers.console.explore.workflow.AppGenerateService.generate",
side_effect=InvokeRateLimitError("rate limit"),
),
):
with pytest.raises(InvokeRateLimitHttpError):
method(installed_workflow_app)
def test_unexpected_exception(self, app, installed_workflow_app, user, payload):
api = InstalledAppWorkflowRunApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.explore.workflow.current_account_with_tenant",
return_value=(user, None),
),
patch(
"controllers.console.explore.workflow.AppGenerateService.generate",
side_effect=Exception("boom"),
),
):
with pytest.raises(InternalServerError):
method(installed_workflow_app)
class TestInstalledAppWorkflowTaskStopApi:
def test_not_workflow_app(self, non_workflow_installed_app):
api = InstalledAppWorkflowTaskStopApi()
method = unwrap(api.post)
with pytest.raises(NotWorkflowAppError):
method(non_workflow_installed_app, "task-1")
def test_success(self, installed_workflow_app):
api = InstalledAppWorkflowTaskStopApi()
method = unwrap(api.post)
with (
patch("controllers.console.explore.workflow.AppQueueManager.set_stop_flag_no_user_check") as stop_flag,
patch("controllers.console.explore.workflow.GraphEngineManager.send_stop_command") as send_stop,
):
result = method(installed_workflow_app, "task-1")
stop_flag.assert_called_once_with("task-1")
send_stop.assert_called_once_with("task-1")
assert result == {"result": "success"}

View File

@ -0,0 +1,244 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden, NotFound
from controllers.console.explore.error import (
AppAccessDeniedError,
TrialAppLimitExceeded,
TrialAppNotAllowed,
)
from controllers.console.explore.wraps import (
InstalledAppResource,
TrialAppResource,
installed_app_required,
trial_app_required,
trial_feature_enable,
user_allowed_to_access_app,
)
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def test_installed_app_required_not_found():
@installed_app_required
def view(installed_app):
return "ok"
with (
patch(
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
):
q.return_value.where.return_value.first.return_value = None
with pytest.raises(NotFound):
view("app-id")
def test_installed_app_required_app_deleted():
installed_app = MagicMock(app=None)
@installed_app_required
def view(installed_app):
return "ok"
with (
patch(
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
patch("controllers.console.explore.wraps.db.session.delete"),
patch("controllers.console.explore.wraps.db.session.commit"),
):
q.return_value.where.return_value.first.return_value = installed_app
with pytest.raises(NotFound):
view("app-id")
def test_installed_app_required_success():
installed_app = MagicMock(app=MagicMock())
@installed_app_required
def view(installed_app):
return installed_app
with (
patch(
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
):
q.return_value.where.return_value.first.return_value = installed_app
result = view("app-id")
assert result == installed_app
def test_user_allowed_to_access_app_denied():
installed_app = MagicMock(app_id="app-1")
@user_allowed_to_access_app
def view(installed_app):
return "ok"
feature = MagicMock()
feature.webapp_auth.enabled = True
with (
patch(
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(id="user-1"), None),
),
patch(
"controllers.console.explore.wraps.FeatureService.get_system_features",
return_value=feature,
),
patch(
"controllers.console.explore.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp",
return_value=False,
),
):
with pytest.raises(AppAccessDeniedError):
view(installed_app)
def test_user_allowed_to_access_app_success():
installed_app = MagicMock(app_id="app-1")
@user_allowed_to_access_app
def view(installed_app):
return "ok"
feature = MagicMock()
feature.webapp_auth.enabled = True
with (
patch(
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(id="user-1"), None),
),
patch(
"controllers.console.explore.wraps.FeatureService.get_system_features",
return_value=feature,
),
patch(
"controllers.console.explore.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp",
return_value=True,
),
):
assert view(installed_app) == "ok"
def test_trial_app_required_not_allowed():
@trial_app_required
def view(app):
return "ok"
with (
patch(
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(id="user-1"), None),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
):
q.return_value.where.return_value.first.return_value = None
with pytest.raises(TrialAppNotAllowed):
view("app-id")
def test_trial_app_required_limit_exceeded():
trial_app = MagicMock(trial_limit=1, app=MagicMock())
record = MagicMock(count=1)
@trial_app_required
def view(app):
return "ok"
with (
patch(
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(id="user-1"), None),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
):
q.return_value.where.return_value.first.side_effect = [
trial_app,
record,
]
with pytest.raises(TrialAppLimitExceeded):
view("app-id")
def test_trial_app_required_success():
trial_app = MagicMock(trial_limit=2, app=MagicMock())
record = MagicMock(count=1)
@trial_app_required
def view(app):
return app
with (
patch(
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(id="user-1"), None),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
):
q.return_value.where.return_value.first.side_effect = [
trial_app,
record,
]
result = view("app-id")
assert result == trial_app.app
def test_trial_feature_enable_disabled():
@trial_feature_enable
def view():
return "ok"
features = MagicMock(enable_trial_app=False)
with patch(
"controllers.console.explore.wraps.FeatureService.get_system_features",
return_value=features,
):
with pytest.raises(Forbidden):
view()
def test_trial_feature_enable_enabled():
@trial_feature_enable
def view():
return "ok"
features = MagicMock(enable_trial_app=True)
with patch(
"controllers.console.explore.wraps.FeatureService.get_system_features",
return_value=features,
):
assert view() == "ok"
def test_installed_app_resource_decorators():
decorators = InstalledAppResource.method_decorators
assert len(decorators) == 4
def test_trial_app_resource_decorators():
decorators = TrialAppResource.method_decorators
assert len(decorators) == 3

View File

@ -0,0 +1,278 @@
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.tag.tags import (
TagBindingCreateApi,
TagBindingDeleteApi,
TagListApi,
TagUpdateDeleteApi,
)
def unwrap(func):
"""
Recursively unwrap decorated functions.
"""
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
app = Flask("test_tag")
app.config["TESTING"] = True
return app
@pytest.fixture
def admin_user():
return MagicMock(
id="user-1",
has_edit_permission=True,
is_dataset_editor=True,
)
@pytest.fixture
def readonly_user():
return MagicMock(
id="user-2",
has_edit_permission=False,
is_dataset_editor=False,
)
@pytest.fixture
def tag():
tag = MagicMock()
tag.id = "tag-1"
tag.name = "test-tag"
tag.type = "knowledge"
return tag
@pytest.fixture
def payload_patch():
def _patch(payload):
return patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
)
return _patch
class TestTagListApi:
def test_get_success(self, app):
api = TagListApi()
method = unwrap(api.get)
with app.test_request_context("/?type=knowledge"):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.tag.tags.TagService.get_tags",
return_value=[{"id": "1", "name": "tag"}],
),
):
result, status = method(api)
assert status == 200
assert isinstance(result, list)
def test_post_success(self, app, admin_user, tag, payload_patch):
api = TagListApi()
method = unwrap(api.post)
payload = {"name": "test-tag", "type": "knowledge"}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch(
"controllers.console.tag.tags.TagService.save_tags",
return_value=tag,
),
):
result, status = method(api)
assert status == 200
assert result["name"] == "test-tag"
def test_post_forbidden(self, app, readonly_user, payload_patch):
api = TagListApi()
method = unwrap(api.post)
payload = {"name": "x"}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
),
payload_patch(payload),
):
with pytest.raises(Forbidden):
method(api)
class TestTagUpdateDeleteApi:
def test_patch_success(self, app, admin_user, tag, payload_patch):
api = TagUpdateDeleteApi()
method = unwrap(api.patch)
payload = {"name": "updated", "type": "knowledge"}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch(
"controllers.console.tag.tags.TagService.update_tags",
return_value=tag,
),
patch(
"controllers.console.tag.tags.TagService.get_tag_binding_count",
return_value=3,
),
):
result, status = method(api, "tag-1")
assert status == 200
assert result["binding_count"] == 3
def test_patch_forbidden(self, app, readonly_user, payload_patch):
api = TagUpdateDeleteApi()
method = unwrap(api.patch)
payload = {"name": "x"}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
),
payload_patch(payload),
):
with pytest.raises(Forbidden):
method(api, "tag-1")
def test_delete_success(self, app, admin_user):
api = TagUpdateDeleteApi()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, "tenant-1"),
),
patch("controllers.console.tag.tags.TagService.delete_tag") as delete_mock,
):
result, status = method(api, "tag-1")
delete_mock.assert_called_once_with("tag-1")
assert status == 204
class TestTagBindingCreateApi:
def test_create_success(self, app, admin_user, payload_patch):
api = TagBindingCreateApi()
method = unwrap(api.post)
payload = {
"tag_ids": ["tag-1"],
"target_id": "target-1",
"type": "knowledge",
}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock,
):
result, status = method(api)
save_mock.assert_called_once()
assert status == 200
assert result["result"] == "success"
def test_create_forbidden(self, app, readonly_user, payload_patch):
api = TagBindingCreateApi()
method = unwrap(api.post)
with app.test_request_context("/", json={}):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
),
payload_patch({}),
):
with pytest.raises(Forbidden):
method(api)
class TestTagBindingDeleteApi:
def test_remove_success(self, app, admin_user, payload_patch):
api = TagBindingDeleteApi()
method = unwrap(api.post)
payload = {
"tag_id": "tag-1",
"target_id": "target-1",
"type": "knowledge",
}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock,
):
result, status = method(api)
delete_mock.assert_called_once()
assert status == 200
assert result["result"] == "success"
def test_remove_forbidden(self, app, readonly_user, payload_patch):
api = TagBindingDeleteApi()
method = unwrap(api.post)
with app.test_request_context("/", json={}):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
),
payload_patch({}),
):
with pytest.raises(Forbidden):
method(api)