mirror of https://github.com/langgenius/dify.git
test: unit test cases for console.explore and tag module (#32186)
This commit is contained in:
parent
4f835107b2
commit
01991f3536
|
|
@ -0,0 +1,402 @@
|
|||
from io import BytesIO
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import controllers.console.explore.audio as audio_module
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def installed_app():
|
||||
app = MagicMock()
|
||||
app.app = MagicMock()
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def audio_file():
|
||||
return (BytesIO(b"audio"), "audio.wav")
|
||||
|
||||
|
||||
class TestChatAudioApi:
|
||||
def setup_method(self):
|
||||
self.api = audio_module.ChatAudioApi()
|
||||
self.method = unwrap(self.api.post)
|
||||
|
||||
def test_post_success(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
return_value={"text": "ok"},
|
||||
),
|
||||
):
|
||||
resp = self.method(installed_app)
|
||||
|
||||
assert resp == {"text": "ok"}
|
||||
|
||||
def test_app_unavailable(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
side_effect=audio_module.services.errors.app_model_config.AppModelConfigBrokenError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_no_audio_uploaded(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
side_effect=NoAudioUploadedServiceError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NoAudioUploadedError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_audio_too_large(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
side_effect=AudioTooLargeServiceError("too big"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(AudioTooLargeError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_provider_quota_exceeded(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
side_effect=QuotaExceededError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_unknown_exception(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
side_effect=Exception("boom"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InternalServerError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_unsupported_audio_type(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
side_effect=audio_module.UnsupportedAudioTypeServiceError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(audio_module.UnsupportedAudioTypeError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_provider_not_support_speech_to_text(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
side_effect=audio_module.ProviderNotSupportSpeechToTextServiceError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(audio_module.ProviderNotSupportSpeechToTextError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_provider_not_initialized(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
side_effect=ProviderTokenNotInitError("not init"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_model_currently_not_supported(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
side_effect=ModelCurrentlyNotSupportError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_invoke_error_asr(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
side_effect=InvokeError("invoke failed"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
self.method(installed_app)
|
||||
|
||||
|
||||
class TestChatTextApi:
|
||||
def setup_method(self):
|
||||
self.api = audio_module.ChatTextApi()
|
||||
self.method = unwrap(self.api.post)
|
||||
|
||||
def test_post_success(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"message_id": "m1", "text": "hello", "voice": "v1"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
return_value={"audio": "ok"},
|
||||
),
|
||||
):
|
||||
resp = self.method(installed_app)
|
||||
|
||||
assert resp == {"audio": "ok"}
|
||||
|
||||
def test_provider_not_initialized(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"text": "hi"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
side_effect=ProviderTokenNotInitError("not init"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_model_not_supported(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"text": "hi"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
side_effect=ModelCurrentlyNotSupportError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_invoke_error(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"text": "hi"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
side_effect=InvokeError("invoke failed"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_unknown_exception(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"text": "hi"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
side_effect=Exception("boom"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InternalServerError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_app_unavailable_tts(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"text": "hi"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
side_effect=audio_module.services.errors.app_model_config.AppModelConfigBrokenError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_no_audio_uploaded_tts(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"text": "hi"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
side_effect=NoAudioUploadedServiceError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NoAudioUploadedError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_audio_too_large_tts(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"text": "hi"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
side_effect=AudioTooLargeServiceError("too big"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(AudioTooLargeError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_unsupported_audio_type_tts(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"text": "hi"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
side_effect=audio_module.UnsupportedAudioTypeServiceError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(audio_module.UnsupportedAudioTypeError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_provider_not_support_speech_to_text_tts(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"text": "hi"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
side_effect=audio_module.ProviderNotSupportSpeechToTextServiceError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(audio_module.ProviderNotSupportSpeechToTextError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_quota_exceeded_tts(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"text": "hi"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
side_effect=QuotaExceededError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
self.method(installed_app)
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import controllers.console.explore.banner as banner_module
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestBannerApi:
|
||||
def test_get_banners_with_requested_language(self, app):
|
||||
api = banner_module.BannerApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
banner = MagicMock()
|
||||
banner.id = "b1"
|
||||
banner.content = {"text": "hello"}
|
||||
banner.link = "https://example.com"
|
||||
banner.sort = 1
|
||||
banner.status = "enabled"
|
||||
banner.created_at = datetime(2024, 1, 1)
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.order_by.return_value = query
|
||||
query.all.return_value = [banner]
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
|
||||
with app.test_request_context("/?language=fr-FR"), patch.object(banner_module.db, "session", session):
|
||||
result = method(api)
|
||||
|
||||
assert result == [
|
||||
{
|
||||
"id": "b1",
|
||||
"content": {"text": "hello"},
|
||||
"link": "https://example.com",
|
||||
"sort": 1,
|
||||
"status": "enabled",
|
||||
"created_at": "2024-01-01T00:00:00",
|
||||
}
|
||||
]
|
||||
|
||||
def test_get_banners_fallback_to_en_us(self, app):
|
||||
api = banner_module.BannerApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
banner = MagicMock()
|
||||
banner.id = "b2"
|
||||
banner.content = {"text": "fallback"}
|
||||
banner.link = None
|
||||
banner.sort = 1
|
||||
banner.status = "enabled"
|
||||
banner.created_at = None
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.order_by.return_value = query
|
||||
query.all.side_effect = [
|
||||
[],
|
||||
[banner],
|
||||
]
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
|
||||
with app.test_request_context("/?language=es-ES"), patch.object(banner_module.db, "session", session):
|
||||
result = method(api)
|
||||
|
||||
assert result == [
|
||||
{
|
||||
"id": "b2",
|
||||
"content": {"text": "fallback"},
|
||||
"link": None,
|
||||
"sort": 1,
|
||||
"status": "enabled",
|
||||
"created_at": None,
|
||||
}
|
||||
]
|
||||
|
||||
def test_get_banners_default_language_en_us(self, app):
|
||||
api = banner_module.BannerApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.order_by.return_value = query
|
||||
query.all.return_value = []
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
|
||||
with app.test_request_context("/"), patch.object(banner_module.db, "session", session):
|
||||
result = method(api)
|
||||
|
||||
assert result == []
|
||||
|
|
@ -0,0 +1,459 @@
|
|||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import controllers.console.explore.completion as completion_module
|
||||
from controllers.console.app.error import (
|
||||
ConversationCompletedError,
|
||||
)
|
||||
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user():
|
||||
return MagicMock(spec=Account)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def completion_app():
|
||||
return MagicMock(app=MagicMock(mode=AppMode.COMPLETION))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chat_app():
|
||||
return MagicMock(app=MagicMock(mode=AppMode.CHAT))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def payload_data():
|
||||
return {"inputs": {}, "query": "hi"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def payload_patch(payload_data):
|
||||
return patch.object(
|
||||
type(completion_module.console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload_data,
|
||||
)
|
||||
|
||||
|
||||
class TestCompletionApi:
|
||||
def test_post_success(self, app, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
patch.object(
|
||||
completion_module.helper,
|
||||
"compact_generate_response",
|
||||
return_value=("ok", 200),
|
||||
),
|
||||
):
|
||||
result = method(completion_app)
|
||||
|
||||
assert result == ("ok", 200)
|
||||
|
||||
def test_post_wrong_app_mode(self):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
installed_app = MagicMock(app=MagicMock(mode=AppMode.CHAT))
|
||||
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
method(installed_app)
|
||||
|
||||
def test_conversation_completed(self, app, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.services.errors.conversation.ConversationCompletedError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ConversationCompletedError):
|
||||
method(completion_app)
|
||||
|
||||
def test_internal_error(self, app, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=Exception("boom"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InternalServerError):
|
||||
method(completion_app)
|
||||
|
||||
def test_conversation_not_exists(self, app, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.services.errors.conversation.ConversationNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.NotFound):
|
||||
method(completion_app)
|
||||
|
||||
def test_app_unavailable(self, app, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.services.errors.app_model_config.AppModelConfigBrokenError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.AppUnavailableError):
|
||||
method(completion_app)
|
||||
|
||||
def test_provider_not_initialized(self, app, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.ProviderTokenNotInitError("not init"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderNotInitializeError):
|
||||
method(completion_app)
|
||||
|
||||
def test_quota_exceeded(self, app, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.QuotaExceededError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderQuotaExceededError):
|
||||
method(completion_app)
|
||||
|
||||
def test_model_not_supported(self, app, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.ModelCurrentlyNotSupportError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderModelCurrentlyNotSupportError):
|
||||
method(completion_app)
|
||||
|
||||
def test_invoke_error(self, app, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.InvokeError("invoke failed"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.CompletionRequestError):
|
||||
method(completion_app)
|
||||
|
||||
|
||||
class TestCompletionStopApi:
|
||||
def test_stop_success(self, completion_app, user):
|
||||
api = completion_module.CompletionStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user.id = "u1"
|
||||
|
||||
with (
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(completion_module.AppTaskService, "stop_task"),
|
||||
):
|
||||
resp, status = method(completion_app, "task-1")
|
||||
|
||||
assert status == 200
|
||||
assert resp == {"result": "success"}
|
||||
|
||||
def test_stop_wrong_app_mode(self):
|
||||
api = completion_module.CompletionStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
installed_app = MagicMock(app=MagicMock(mode=AppMode.CHAT))
|
||||
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
method(installed_app, "task")
|
||||
|
||||
|
||||
class TestChatApi:
|
||||
def test_post_success(self, app, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
patch.object(
|
||||
completion_module.helper,
|
||||
"compact_generate_response",
|
||||
return_value=("ok", 200),
|
||||
),
|
||||
):
|
||||
result = method(chat_app)
|
||||
|
||||
assert result == ("ok", 200)
|
||||
|
||||
def test_post_not_chat_app(self):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
installed_app = MagicMock(app=MagicMock(mode=AppMode.COMPLETION))
|
||||
|
||||
with pytest.raises(NotChatAppError):
|
||||
method(installed_app)
|
||||
|
||||
def test_rate_limit_error(self, app, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=InvokeRateLimitError("limit"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
method(chat_app)
|
||||
|
||||
def test_conversation_completed_chat(self, app, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.services.errors.conversation.ConversationCompletedError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ConversationCompletedError):
|
||||
method(chat_app)
|
||||
|
||||
def test_conversation_not_exists_chat(self, app, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.services.errors.conversation.ConversationNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.NotFound):
|
||||
method(chat_app)
|
||||
|
||||
def test_app_unavailable_chat(self, app, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.services.errors.app_model_config.AppModelConfigBrokenError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.AppUnavailableError):
|
||||
method(chat_app)
|
||||
|
||||
def test_provider_not_initialized_chat(self, app, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.ProviderTokenNotInitError("not init"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderNotInitializeError):
|
||||
method(chat_app)
|
||||
|
||||
def test_quota_exceeded_chat(self, app, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.QuotaExceededError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderQuotaExceededError):
|
||||
method(chat_app)
|
||||
|
||||
def test_model_not_supported_chat(self, app, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.ModelCurrentlyNotSupportError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderModelCurrentlyNotSupportError):
|
||||
method(chat_app)
|
||||
|
||||
def test_invoke_error_chat(self, app, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.InvokeError("invoke failed"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.CompletionRequestError):
|
||||
method(chat_app)
|
||||
|
||||
def test_internal_error_chat(self, app, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=Exception("boom"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InternalServerError):
|
||||
method(chat_app)
|
||||
|
||||
|
||||
class TestChatStopApi:
|
||||
def test_stop_success(self, chat_app, user):
|
||||
api = completion_module.ChatStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user.id = "u1"
|
||||
|
||||
with (
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(completion_module.AppTaskService, "stop_task"),
|
||||
):
|
||||
resp, status = method(chat_app, "task-1")
|
||||
|
||||
assert status == 200
|
||||
assert resp == {"result": "success"}
|
||||
|
||||
def test_stop_not_chat_app(self):
|
||||
api = completion_module.ChatStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
installed_app = MagicMock(app=MagicMock(mode=AppMode.COMPLETION))
|
||||
|
||||
with pytest.raises(NotChatAppError):
|
||||
method(installed_app, "task")
|
||||
|
|
@ -0,0 +1,232 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import controllers.console.explore.conversation as conversation_module
|
||||
from controllers.console.explore.error import NotChatAppError
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
from services.errors.conversation import (
|
||||
ConversationNotExistsError,
|
||||
LastConversationNotExistsError,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class FakeConversation:
|
||||
def __init__(self, cid):
|
||||
self.id = cid
|
||||
self.name = "test"
|
||||
self.inputs = {}
|
||||
self.status = "normal"
|
||||
self.introduction = ""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chat_app():
|
||||
app_model = MagicMock(mode=AppMode.CHAT, id="app-id")
|
||||
return MagicMock(app=app_model)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def non_chat_app():
|
||||
app_model = MagicMock(mode=AppMode.COMPLETION)
|
||||
return MagicMock(app=app_model)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user():
|
||||
user = MagicMock(spec=Account)
|
||||
user.id = "uid"
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_db_and_session():
|
||||
with (
|
||||
patch.object(
|
||||
conversation_module,
|
||||
"db",
|
||||
MagicMock(session=MagicMock(), engine=MagicMock()),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.conversation.Session",
|
||||
MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestConversationListApi:
|
||||
def test_get_success(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pagination = MagicMock(
|
||||
limit=20,
|
||||
has_more=False,
|
||||
data=[FakeConversation("c1"), FakeConversation("c2")],
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?limit=20"),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.WebConversationService,
|
||||
"pagination_by_last_id",
|
||||
return_value=pagination,
|
||||
),
|
||||
):
|
||||
result = method(chat_app)
|
||||
|
||||
assert result["limit"] == 20
|
||||
assert result["has_more"] is False
|
||||
assert len(result["data"]) == 2
|
||||
|
||||
def test_last_conversation_not_exists(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.WebConversationService,
|
||||
"pagination_by_last_id",
|
||||
side_effect=LastConversationNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(chat_app)
|
||||
|
||||
def test_wrong_app_mode(self, app: Flask, non_chat_app):
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
method(non_chat_app)
|
||||
|
||||
|
||||
class TestConversationApi:
|
||||
def test_delete_success(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.ConversationService,
|
||||
"delete",
|
||||
),
|
||||
):
|
||||
result = method(chat_app, "cid")
|
||||
|
||||
body, status = result
|
||||
assert status == 204
|
||||
assert body["result"] == "success"
|
||||
|
||||
def test_delete_not_found(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.ConversationService,
|
||||
"delete",
|
||||
side_effect=ConversationNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(chat_app, "cid")
|
||||
|
||||
def test_delete_wrong_app_mode(self, app: Flask, non_chat_app):
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
method(non_chat_app, "cid")
|
||||
|
||||
|
||||
class TestConversationRenameApi:
|
||||
def test_rename_success(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationRenameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
conversation = FakeConversation("cid")
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "new"}),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.ConversationService,
|
||||
"rename",
|
||||
return_value=conversation,
|
||||
),
|
||||
):
|
||||
result = method(chat_app, "cid")
|
||||
|
||||
assert result["id"] == "cid"
|
||||
|
||||
def test_rename_not_found(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationRenameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "new"}),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.ConversationService,
|
||||
"rename",
|
||||
side_effect=ConversationNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(chat_app, "cid")
|
||||
|
||||
|
||||
class TestConversationPinApi:
|
||||
def test_pin_success(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationPinApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.WebConversationService,
|
||||
"pin",
|
||||
),
|
||||
):
|
||||
result = method(chat_app, "cid")
|
||||
|
||||
assert result == {"result": "success"}
|
||||
|
||||
|
||||
class TestConversationUnPinApi:
|
||||
def test_unpin_success(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationUnPinApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.WebConversationService,
|
||||
"unpin",
|
||||
),
|
||||
):
|
||||
result = method(chat_app, "cid")
|
||||
|
||||
assert result == {"result": "success"}
|
||||
|
|
@ -0,0 +1,363 @@
|
|||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
import controllers.console.explore.installed_app as module
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_id():
|
||||
return "t1"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def current_user(tenant_id):
|
||||
user = MagicMock()
|
||||
user.id = "u1"
|
||||
user.current_tenant = MagicMock(id=tenant_id)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def installed_app():
|
||||
app = MagicMock()
|
||||
app.id = "ia1"
|
||||
app.app = MagicMock(id="a1")
|
||||
app.app_owner_tenant_id = "t2"
|
||||
app.is_pinned = False
|
||||
app.last_used_at = datetime(2024, 1, 1)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def payload_patch():
|
||||
def _patch(payload):
|
||||
return patch.object(
|
||||
type(module.console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
)
|
||||
|
||||
return _patch
|
||||
|
||||
|
||||
class TestInstalledAppsListApi:
|
||||
def test_get_installed_apps(self, app, current_user, tenant_id, installed_app):
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = [installed_app]
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module.TenantService, "get_user_role", return_value="owner"),
|
||||
patch.object(
|
||||
module.FeatureService,
|
||||
"get_system_features",
|
||||
return_value=MagicMock(webapp_auth=MagicMock(enabled=False)),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert "installed_apps" in result
|
||||
assert result["installed_apps"][0]["editable"] is True
|
||||
assert result["installed_apps"][0]["uninstallable"] is False
|
||||
|
||||
def test_get_installed_apps_with_app_id_filter(self, app, current_user, tenant_id):
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
|
||||
with (
|
||||
app.test_request_context("/?app_id=a1"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module.TenantService, "get_user_role", return_value="member"),
|
||||
patch.object(
|
||||
module.FeatureService,
|
||||
"get_system_features",
|
||||
return_value=MagicMock(webapp_auth=MagicMock(enabled=False)),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result == {"installed_apps": []}
|
||||
|
||||
def test_get_installed_apps_with_webapp_auth_enabled(self, app, current_user, tenant_id, installed_app):
|
||||
"""Test filtering when webapp_auth is enabled."""
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = [installed_app]
|
||||
|
||||
mock_webapp_setting = MagicMock()
|
||||
mock_webapp_setting.access_mode = "restricted"
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module.TenantService, "get_user_role", return_value="owner"),
|
||||
patch.object(
|
||||
module.FeatureService,
|
||||
"get_system_features",
|
||||
return_value=MagicMock(webapp_auth=MagicMock(enabled=True)),
|
||||
),
|
||||
patch.object(
|
||||
module.EnterpriseService.WebAppAuth,
|
||||
"batch_get_app_access_mode_by_id",
|
||||
return_value={"a1": mock_webapp_setting},
|
||||
),
|
||||
patch.object(
|
||||
module.EnterpriseService.WebAppAuth,
|
||||
"batch_is_user_allowed_to_access_webapps",
|
||||
return_value={"a1": True},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert len(result["installed_apps"]) == 1
|
||||
|
||||
def test_get_installed_apps_with_webapp_auth_user_denied(self, app, current_user, tenant_id, installed_app):
|
||||
"""Test filtering when user doesn't have access."""
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = [installed_app]
|
||||
|
||||
mock_webapp_setting = MagicMock()
|
||||
mock_webapp_setting.access_mode = "restricted"
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module.TenantService, "get_user_role", return_value="member"),
|
||||
patch.object(
|
||||
module.FeatureService,
|
||||
"get_system_features",
|
||||
return_value=MagicMock(webapp_auth=MagicMock(enabled=True)),
|
||||
),
|
||||
patch.object(
|
||||
module.EnterpriseService.WebAppAuth,
|
||||
"batch_get_app_access_mode_by_id",
|
||||
return_value={"a1": mock_webapp_setting},
|
||||
),
|
||||
patch.object(
|
||||
module.EnterpriseService.WebAppAuth,
|
||||
"batch_is_user_allowed_to_access_webapps",
|
||||
return_value={"a1": False},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["installed_apps"] == []
|
||||
|
||||
def test_get_installed_apps_with_sso_verified_access(self, app, current_user, tenant_id, installed_app):
|
||||
"""Test that sso_verified access mode apps are skipped in filtering."""
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = [installed_app]
|
||||
|
||||
mock_webapp_setting = MagicMock()
|
||||
mock_webapp_setting.access_mode = "sso_verified"
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module.TenantService, "get_user_role", return_value="owner"),
|
||||
patch.object(
|
||||
module.FeatureService,
|
||||
"get_system_features",
|
||||
return_value=MagicMock(webapp_auth=MagicMock(enabled=True)),
|
||||
),
|
||||
patch.object(
|
||||
module.EnterpriseService.WebAppAuth,
|
||||
"batch_get_app_access_mode_by_id",
|
||||
return_value={"a1": mock_webapp_setting},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert len(result["installed_apps"]) == 0
|
||||
|
||||
def test_get_installed_apps_filters_null_apps(self, app, current_user, tenant_id):
|
||||
"""Test that installed apps with null app are filtered out."""
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app_with_null = MagicMock()
|
||||
installed_app_with_null.app = None
|
||||
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = [installed_app_with_null]
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module.TenantService, "get_user_role", return_value="owner"),
|
||||
patch.object(
|
||||
module.FeatureService,
|
||||
"get_system_features",
|
||||
return_value=MagicMock(webapp_auth=MagicMock(enabled=False)),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["installed_apps"] == []
|
||||
|
||||
def test_get_installed_apps_current_tenant_none(self, app, tenant_id, installed_app):
|
||||
"""Test error when current_user.current_tenant is None."""
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
current_user = MagicMock()
|
||||
current_user.current_tenant = None
|
||||
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = [installed_app]
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
):
|
||||
with pytest.raises(ValueError, match="current_user.current_tenant must not be None"):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestInstalledAppsCreateApi:
|
||||
def test_post_success(self, app, tenant_id, payload_patch):
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
recommended = MagicMock()
|
||||
recommended.install_count = 0
|
||||
|
||||
app_entity = MagicMock()
|
||||
app_entity.id = "a1"
|
||||
app_entity.is_public = True
|
||||
app_entity.tenant_id = "t2"
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.side_effect = [
|
||||
recommended,
|
||||
app_entity,
|
||||
None,
|
||||
]
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"app_id": "a1"}),
|
||||
payload_patch({"app_id": "a1"}),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result == {"message": "App installed successfully"}
|
||||
assert recommended.install_count == 1
|
||||
|
||||
def test_post_recommended_not_found(self, app, payload_patch):
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"app_id": "a1"}),
|
||||
payload_patch({"app_id": "a1"}),
|
||||
patch.object(module.db, "session", session),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api)
|
||||
|
||||
def test_post_app_not_public(self, app, tenant_id, payload_patch):
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
recommended = MagicMock()
|
||||
app_entity = MagicMock(is_public=False)
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.side_effect = [
|
||||
recommended,
|
||||
app_entity,
|
||||
]
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"app_id": "a1"}),
|
||||
payload_patch({"app_id": "a1"}),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestInstalledAppApi:
|
||||
def test_delete_success(self, tenant_id, installed_app):
|
||||
api = module.InstalledAppApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)),
|
||||
patch.object(module.db, "session"),
|
||||
):
|
||||
resp, status = method(installed_app)
|
||||
|
||||
assert status == 204
|
||||
assert resp["result"] == "success"
|
||||
|
||||
def test_delete_owned_by_current_tenant(self, tenant_id):
|
||||
api = module.InstalledAppApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
installed_app = MagicMock(app_owner_tenant_id=tenant_id)
|
||||
|
||||
with patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)):
|
||||
with pytest.raises(BadRequest):
|
||||
method(installed_app)
|
||||
|
||||
def test_patch_update_pin(self, app, payload_patch, installed_app):
|
||||
api = module.InstalledAppApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"is_pinned": True}),
|
||||
payload_patch({"is_pinned": True}),
|
||||
patch.object(module.db, "session"),
|
||||
):
|
||||
result = method(installed_app)
|
||||
|
||||
assert installed_app.is_pinned is True
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_patch_no_change(self, app, payload_patch, installed_app):
|
||||
api = module.InstalledAppApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
with app.test_request_context("/", json={}), payload_patch({}), patch.object(module.db, "session"):
|
||||
result = method(installed_app)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
|
@ -0,0 +1,552 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import controllers.console.explore.message as module
|
||||
from controllers.console.app.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.explore.error import (
|
||||
AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
)
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
MessageNotExistsError,
|
||||
SuggestedQuestionsAfterAnswerDisabledError,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def make_message():
|
||||
msg = MagicMock()
|
||||
msg.id = "m1"
|
||||
msg.conversation_id = "11111111-1111-1111-1111-111111111111"
|
||||
msg.parent_message_id = None
|
||||
msg.inputs = {}
|
||||
msg.query = "hello"
|
||||
msg.re_sign_file_url_answer = ""
|
||||
msg.user_feedback = MagicMock(rating=None)
|
||||
msg.status = "success"
|
||||
msg.error = None
|
||||
return msg
|
||||
|
||||
|
||||
class TestMessageListApi:
|
||||
def test_get_success(self, app):
|
||||
api = module.MessageListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
pagination = MagicMock(
|
||||
limit=20,
|
||||
has_more=False,
|
||||
data=[make_message(), make_message()],
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"pagination_by_first_id",
|
||||
return_value=pagination,
|
||||
),
|
||||
):
|
||||
result = method(installed_app)
|
||||
|
||||
assert result["limit"] == 20
|
||||
assert result["has_more"] is False
|
||||
assert len(result["data"]) == 2
|
||||
|
||||
def test_get_not_chat_app(self):
|
||||
api = module.MessageListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
|
||||
with pytest.raises(NotChatAppError):
|
||||
method(installed_app)
|
||||
|
||||
def test_conversation_not_exists(self, app):
|
||||
api = module.MessageListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"pagination_by_first_id",
|
||||
side_effect=ConversationNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(installed_app)
|
||||
|
||||
def test_first_message_not_exists(self, app):
|
||||
api = module.MessageListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"pagination_by_first_id",
|
||||
side_effect=FirstMessageNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(installed_app)
|
||||
|
||||
|
||||
class TestMessageFeedbackApi:
|
||||
def test_post_success(self, app):
|
||||
api = module.MessageFeedbackApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"rating": "like"}),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"create_feedback",
|
||||
),
|
||||
):
|
||||
result = method(installed_app, "mid")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_message_not_exists(self, app):
|
||||
api = module.MessageFeedbackApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"create_feedback",
|
||||
side_effect=MessageNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(installed_app, "mid")
|
||||
|
||||
|
||||
class TestMessageMoreLikeThisApi:
|
||||
def test_get_success(self, app):
|
||||
api = module.MessageMoreLikeThisApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"response_mode": "blocking"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.AppGenerateService,
|
||||
"generate_more_like_this",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
patch.object(
|
||||
module.helper,
|
||||
"compact_generate_response",
|
||||
return_value=("ok", 200),
|
||||
),
|
||||
):
|
||||
resp = method(installed_app, "mid")
|
||||
|
||||
assert resp == ("ok", 200)
|
||||
|
||||
def test_not_completion_app(self):
|
||||
api = module.MessageMoreLikeThisApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_more_like_this_disabled(self, app):
|
||||
api = module.MessageMoreLikeThisApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"response_mode": "blocking"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.AppGenerateService,
|
||||
"generate_more_like_this",
|
||||
side_effect=module.MoreLikeThisDisabledError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(AppMoreLikeThisDisabledError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_message_not_exists_more_like_this(self, app):
|
||||
api = module.MessageMoreLikeThisApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"response_mode": "blocking"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.AppGenerateService,
|
||||
"generate_more_like_this",
|
||||
side_effect=MessageNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_provider_not_init_more_like_this(self, app):
|
||||
api = module.MessageMoreLikeThisApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"response_mode": "blocking"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.AppGenerateService,
|
||||
"generate_more_like_this",
|
||||
side_effect=ProviderTokenNotInitError("test"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_quota_exceeded_more_like_this(self, app):
|
||||
api = module.MessageMoreLikeThisApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"response_mode": "blocking"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.AppGenerateService,
|
||||
"generate_more_like_this",
|
||||
side_effect=QuotaExceededError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_model_not_support_more_like_this(self, app):
|
||||
api = module.MessageMoreLikeThisApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"response_mode": "blocking"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.AppGenerateService,
|
||||
"generate_more_like_this",
|
||||
side_effect=ModelCurrentlyNotSupportError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_invoke_error_more_like_this(self, app):
|
||||
api = module.MessageMoreLikeThisApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"response_mode": "blocking"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.AppGenerateService,
|
||||
"generate_more_like_this",
|
||||
side_effect=InvokeError("test error"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_unexpected_error_more_like_this(self, app):
|
||||
api = module.MessageMoreLikeThisApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"response_mode": "blocking"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.AppGenerateService,
|
||||
"generate_more_like_this",
|
||||
side_effect=Exception("unexpected"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InternalServerError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
|
||||
class TestMessageSuggestedQuestionApi:
|
||||
def test_get_success(self):
|
||||
api = module.MessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
return_value=["q1", "q2"],
|
||||
),
|
||||
):
|
||||
result = method(installed_app, "mid")
|
||||
|
||||
assert result["data"] == ["q1", "q2"]
|
||||
|
||||
def test_not_chat_app(self):
|
||||
api = module.MessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
|
||||
with pytest.raises(NotChatAppError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_disabled(self):
|
||||
api = module.MessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
side_effect=SuggestedQuestionsAfterAnswerDisabledError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(AppSuggestedQuestionsAfterAnswerDisabledError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_message_not_exists_suggested_question(self):
|
||||
api = module.MessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
side_effect=MessageNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_conversation_not_exists_suggested_question(self):
|
||||
api = module.MessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
side_effect=ConversationNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_provider_not_init_suggested_question(self):
|
||||
api = module.MessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
side_effect=ProviderTokenNotInitError("test"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_quota_exceeded_suggested_question(self):
|
||||
api = module.MessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
side_effect=QuotaExceededError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_model_not_support_suggested_question(self):
|
||||
api = module.MessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
side_effect=ModelCurrentlyNotSupportError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_invoke_error_suggested_question(self):
|
||||
api = module.MessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
side_effect=InvokeError("test error"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_unexpected_error_suggested_question(self):
|
||||
api = module.MessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
side_effect=Exception("unexpected"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InternalServerError):
|
||||
method(installed_app, "mid")
|
||||
|
|
@ -0,0 +1,140 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import controllers.console.explore.parameter as module
|
||||
from controllers.console.app.error import AppUnavailableError
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestAppParameterApi:
|
||||
def test_get_app_none(self):
|
||||
api = module.AppParameterApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock(app=None)
|
||||
|
||||
with pytest.raises(AppUnavailableError):
|
||||
method(installed_app)
|
||||
|
||||
def test_get_advanced_chat_workflow(self):
|
||||
api = module.AppParameterApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
workflow = MagicMock()
|
||||
workflow.features_dict = {"f": "v"}
|
||||
workflow.user_input_form.return_value = [{"name": "x"}]
|
||||
|
||||
app = MagicMock(
|
||||
mode=AppMode.ADVANCED_CHAT,
|
||||
workflow=workflow,
|
||||
)
|
||||
|
||||
installed_app = MagicMock(app=app)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
module,
|
||||
"get_parameters_from_feature_dict",
|
||||
return_value={"any": "thing"},
|
||||
),
|
||||
patch.object(
|
||||
module.fields.Parameters,
|
||||
"model_validate",
|
||||
return_value=MagicMock(model_dump=lambda **_: {"ok": True}),
|
||||
),
|
||||
):
|
||||
result = method(installed_app)
|
||||
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_get_advanced_chat_workflow_missing(self):
|
||||
api = module.AppParameterApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
app = MagicMock(
|
||||
mode=AppMode.ADVANCED_CHAT,
|
||||
workflow=None,
|
||||
)
|
||||
|
||||
installed_app = MagicMock(app=app)
|
||||
|
||||
with pytest.raises(AppUnavailableError):
|
||||
method(installed_app)
|
||||
|
||||
def test_get_non_workflow_app(self):
|
||||
api = module.AppParameterApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
app_model_config = MagicMock()
|
||||
app_model_config.to_dict.return_value = {"user_input_form": [{"name": "y"}]}
|
||||
|
||||
app = MagicMock(
|
||||
mode=AppMode.CHAT,
|
||||
app_model_config=app_model_config,
|
||||
)
|
||||
|
||||
installed_app = MagicMock(app=app)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
module,
|
||||
"get_parameters_from_feature_dict",
|
||||
return_value={"whatever": 123},
|
||||
),
|
||||
patch.object(
|
||||
module.fields.Parameters,
|
||||
"model_validate",
|
||||
return_value=MagicMock(model_dump=lambda **_: {"ok": True}),
|
||||
),
|
||||
):
|
||||
result = method(installed_app)
|
||||
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_get_non_workflow_missing_config(self):
|
||||
api = module.AppParameterApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
app = MagicMock(
|
||||
mode=AppMode.CHAT,
|
||||
app_model_config=None,
|
||||
)
|
||||
|
||||
installed_app = MagicMock(app=app)
|
||||
|
||||
with pytest.raises(AppUnavailableError):
|
||||
method(installed_app)
|
||||
|
||||
|
||||
class TestExploreAppMetaApi:
|
||||
def test_get_meta_success(self):
|
||||
api = module.ExploreAppMetaApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
app = MagicMock()
|
||||
installed_app = MagicMock(app=app)
|
||||
|
||||
with patch.object(
|
||||
module.AppService,
|
||||
"get_app_meta",
|
||||
return_value={"meta": "ok"},
|
||||
):
|
||||
result = method(installed_app)
|
||||
|
||||
assert result == {"meta": "ok"}
|
||||
|
||||
def test_get_meta_app_missing(self):
|
||||
api = module.ExploreAppMetaApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock(app=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
method(installed_app)
|
||||
|
|
@ -0,0 +1,92 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import controllers.console.explore.recommended_app as module
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestRecommendedAppListApi:
|
||||
def test_get_with_language_param(self, app):
|
||||
api = module.RecommendedAppListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
result_data = {"recommended_apps": [], "categories": []}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"language": "en-US"}),
|
||||
patch.object(module, "current_user", MagicMock(interface_language="fr-FR")),
|
||||
patch.object(
|
||||
module.RecommendedAppService,
|
||||
"get_recommended_apps_and_categories",
|
||||
return_value=result_data,
|
||||
) as service_mock,
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
service_mock.assert_called_once_with("en-US")
|
||||
assert result == result_data
|
||||
|
||||
def test_get_fallback_to_user_language(self, app):
|
||||
api = module.RecommendedAppListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
result_data = {"recommended_apps": [], "categories": []}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"language": "invalid"}),
|
||||
patch.object(module, "current_user", MagicMock(interface_language="fr-FR")),
|
||||
patch.object(
|
||||
module.RecommendedAppService,
|
||||
"get_recommended_apps_and_categories",
|
||||
return_value=result_data,
|
||||
) as service_mock,
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
service_mock.assert_called_once_with("fr-FR")
|
||||
assert result == result_data
|
||||
|
||||
def test_get_fallback_to_default_language(self, app):
|
||||
api = module.RecommendedAppListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
result_data = {"recommended_apps": [], "categories": []}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_user", MagicMock(interface_language=None)),
|
||||
patch.object(
|
||||
module.RecommendedAppService,
|
||||
"get_recommended_apps_and_categories",
|
||||
return_value=result_data,
|
||||
) as service_mock,
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
service_mock.assert_called_once_with(module.languages[0])
|
||||
assert result == result_data
|
||||
|
||||
|
||||
class TestRecommendedAppApi:
|
||||
def test_get_success(self, app):
|
||||
api = module.RecommendedAppApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
result_data = {"id": "app1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
module.RecommendedAppService,
|
||||
"get_recommend_app_detail",
|
||||
return_value=result_data,
|
||||
) as service_mock,
|
||||
):
|
||||
result = method(api, "11111111-1111-1111-1111-111111111111")
|
||||
|
||||
service_mock.assert_called_once_with("11111111-1111-1111-1111-111111111111")
|
||||
assert result == result_data
|
||||
|
|
@ -0,0 +1,154 @@
|
|||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import controllers.console.explore.saved_message as module
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def make_saved_message():
|
||||
msg = MagicMock()
|
||||
msg.id = str(uuid4())
|
||||
msg.message_id = str(uuid4())
|
||||
msg.app_id = str(uuid4())
|
||||
msg.inputs = {}
|
||||
msg.query = "hello"
|
||||
msg.answer = "world"
|
||||
msg.user_feedback = MagicMock(rating="like")
|
||||
msg.created_at = None
|
||||
return msg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def payload_patch():
|
||||
def _patch(payload):
|
||||
return patch.object(
|
||||
type(module.console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
)
|
||||
|
||||
return _patch
|
||||
|
||||
|
||||
class TestSavedMessageListApi:
|
||||
def test_get_success(self, app):
|
||||
api = module.SavedMessageListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
pagination = MagicMock(
|
||||
limit=20,
|
||||
has_more=False,
|
||||
data=[make_saved_message(), make_saved_message()],
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={}),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.SavedMessageService,
|
||||
"pagination_by_last_id",
|
||||
return_value=pagination,
|
||||
),
|
||||
):
|
||||
result = method(installed_app)
|
||||
|
||||
assert result["limit"] == 20
|
||||
assert result["has_more"] is False
|
||||
assert len(result["data"]) == 2
|
||||
|
||||
def test_get_not_completion_app(self):
|
||||
api = module.SavedMessageListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
method(installed_app)
|
||||
|
||||
def test_post_success(self, app, payload_patch):
|
||||
api = module.SavedMessageListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
payload = {"message_id": str(uuid4())}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
payload_patch(payload),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(module.SavedMessageService, "save") as save_mock,
|
||||
):
|
||||
result = method(installed_app)
|
||||
|
||||
save_mock.assert_called_once()
|
||||
assert result == {"result": "success"}
|
||||
|
||||
def test_post_message_not_exists(self, app, payload_patch):
|
||||
api = module.SavedMessageListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
payload = {"message_id": str(uuid4())}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
payload_patch(payload),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.SavedMessageService,
|
||||
"save",
|
||||
side_effect=MessageNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(installed_app)
|
||||
|
||||
|
||||
class TestSavedMessageApi:
|
||||
def test_delete_success(self):
|
||||
api = module.SavedMessageApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(module.SavedMessageService, "delete") as delete_mock,
|
||||
):
|
||||
result, status = method(installed_app, str(uuid4()))
|
||||
|
||||
delete_mock.assert_called_once()
|
||||
assert status == 204
|
||||
assert result == {"result": "success"}
|
||||
|
||||
def test_delete_not_completion_app(self):
|
||||
api = module.SavedMessageApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
method(installed_app, str(uuid4()))
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,151 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.console.explore.error import NotWorkflowAppError
|
||||
from controllers.console.explore.workflow import (
|
||||
InstalledAppWorkflowRunApi,
|
||||
InstalledAppWorkflowTaskStopApi,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from models.model import AppMode
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user():
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_app():
|
||||
app = MagicMock()
|
||||
app.mode = AppMode.WORKFLOW
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def installed_workflow_app(workflow_app):
|
||||
return MagicMock(app=workflow_app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def non_workflow_installed_app():
|
||||
app = MagicMock()
|
||||
app.mode = AppMode.CHAT
|
||||
return MagicMock(app=app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def payload():
|
||||
return {"inputs": {"a": 1}}
|
||||
|
||||
|
||||
class TestInstalledAppWorkflowRunApi:
|
||||
def test_not_workflow_app(self, app, non_workflow_installed_app):
|
||||
api = InstalledAppWorkflowRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.current_account_with_tenant",
|
||||
return_value=(MagicMock(), None),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
method(non_workflow_installed_app)
|
||||
|
||||
def test_success(self, app, installed_workflow_app, user, payload):
|
||||
api = InstalledAppWorkflowRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.current_account_with_tenant",
|
||||
return_value=(user, None),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.AppGenerateService.generate",
|
||||
return_value=MagicMock(),
|
||||
) as generate_mock,
|
||||
):
|
||||
result = method(installed_workflow_app)
|
||||
|
||||
generate_mock.assert_called_once()
|
||||
assert result is not None
|
||||
|
||||
def test_rate_limit_error(self, app, installed_workflow_app, user, payload):
|
||||
api = InstalledAppWorkflowRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.current_account_with_tenant",
|
||||
return_value=(user, None),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.AppGenerateService.generate",
|
||||
side_effect=InvokeRateLimitError("rate limit"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
method(installed_workflow_app)
|
||||
|
||||
def test_unexpected_exception(self, app, installed_workflow_app, user, payload):
|
||||
api = InstalledAppWorkflowRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.current_account_with_tenant",
|
||||
return_value=(user, None),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.AppGenerateService.generate",
|
||||
side_effect=Exception("boom"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InternalServerError):
|
||||
method(installed_workflow_app)
|
||||
|
||||
|
||||
class TestInstalledAppWorkflowTaskStopApi:
|
||||
def test_not_workflow_app(self, non_workflow_installed_app):
|
||||
api = InstalledAppWorkflowTaskStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
method(non_workflow_installed_app, "task-1")
|
||||
|
||||
def test_success(self, installed_workflow_app):
|
||||
api = InstalledAppWorkflowTaskStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
patch("controllers.console.explore.workflow.AppQueueManager.set_stop_flag_no_user_check") as stop_flag,
|
||||
patch("controllers.console.explore.workflow.GraphEngineManager.send_stop_command") as send_stop,
|
||||
):
|
||||
result = method(installed_workflow_app, "task-1")
|
||||
|
||||
stop_flag.assert_called_once_with("task-1")
|
||||
send_stop.assert_called_once_with("task-1")
|
||||
assert result == {"result": "success"}
|
||||
|
|
@ -0,0 +1,244 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.console.explore.error import (
|
||||
AppAccessDeniedError,
|
||||
TrialAppLimitExceeded,
|
||||
TrialAppNotAllowed,
|
||||
)
|
||||
from controllers.console.explore.wraps import (
|
||||
InstalledAppResource,
|
||||
TrialAppResource,
|
||||
installed_app_required,
|
||||
trial_app_required,
|
||||
trial_feature_enable,
|
||||
user_allowed_to_access_app,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def test_installed_app_required_not_found():
|
||||
@installed_app_required
|
||||
def view(installed_app):
|
||||
return "ok"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
view("app-id")
|
||||
|
||||
|
||||
def test_installed_app_required_app_deleted():
|
||||
installed_app = MagicMock(app=None)
|
||||
|
||||
@installed_app_required
|
||||
def view(installed_app):
|
||||
return "ok"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
patch("controllers.console.explore.wraps.db.session.delete"),
|
||||
patch("controllers.console.explore.wraps.db.session.commit"),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = installed_app
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
view("app-id")
|
||||
|
||||
|
||||
def test_installed_app_required_success():
|
||||
installed_app = MagicMock(app=MagicMock())
|
||||
|
||||
@installed_app_required
|
||||
def view(installed_app):
|
||||
return installed_app
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = installed_app
|
||||
|
||||
result = view("app-id")
|
||||
assert result == installed_app
|
||||
|
||||
|
||||
def test_user_allowed_to_access_app_denied():
|
||||
installed_app = MagicMock(app_id="app-1")
|
||||
|
||||
@user_allowed_to_access_app
|
||||
def view(installed_app):
|
||||
return "ok"
|
||||
|
||||
feature = MagicMock()
|
||||
feature.webapp_auth.enabled = True
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="user-1"), None),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.wraps.FeatureService.get_system_features",
|
||||
return_value=feature,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp",
|
||||
return_value=False,
|
||||
),
|
||||
):
|
||||
with pytest.raises(AppAccessDeniedError):
|
||||
view(installed_app)
|
||||
|
||||
|
||||
def test_user_allowed_to_access_app_success():
|
||||
installed_app = MagicMock(app_id="app-1")
|
||||
|
||||
@user_allowed_to_access_app
|
||||
def view(installed_app):
|
||||
return "ok"
|
||||
|
||||
feature = MagicMock()
|
||||
feature.webapp_auth.enabled = True
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="user-1"), None),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.wraps.FeatureService.get_system_features",
|
||||
return_value=feature,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
assert view(installed_app) == "ok"
|
||||
|
||||
|
||||
def test_trial_app_required_not_allowed():
|
||||
@trial_app_required
|
||||
def view(app):
|
||||
return "ok"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="user-1"), None),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(TrialAppNotAllowed):
|
||||
view("app-id")
|
||||
|
||||
|
||||
def test_trial_app_required_limit_exceeded():
|
||||
trial_app = MagicMock(trial_limit=1, app=MagicMock())
|
||||
record = MagicMock(count=1)
|
||||
|
||||
@trial_app_required
|
||||
def view(app):
|
||||
return "ok"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="user-1"), None),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
):
|
||||
q.return_value.where.return_value.first.side_effect = [
|
||||
trial_app,
|
||||
record,
|
||||
]
|
||||
|
||||
with pytest.raises(TrialAppLimitExceeded):
|
||||
view("app-id")
|
||||
|
||||
|
||||
def test_trial_app_required_success():
|
||||
trial_app = MagicMock(trial_limit=2, app=MagicMock())
|
||||
record = MagicMock(count=1)
|
||||
|
||||
@trial_app_required
|
||||
def view(app):
|
||||
return app
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="user-1"), None),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
):
|
||||
q.return_value.where.return_value.first.side_effect = [
|
||||
trial_app,
|
||||
record,
|
||||
]
|
||||
|
||||
result = view("app-id")
|
||||
assert result == trial_app.app
|
||||
|
||||
|
||||
def test_trial_feature_enable_disabled():
|
||||
@trial_feature_enable
|
||||
def view():
|
||||
return "ok"
|
||||
|
||||
features = MagicMock(enable_trial_app=False)
|
||||
|
||||
with patch(
|
||||
"controllers.console.explore.wraps.FeatureService.get_system_features",
|
||||
return_value=features,
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
view()
|
||||
|
||||
|
||||
def test_trial_feature_enable_enabled():
|
||||
@trial_feature_enable
|
||||
def view():
|
||||
return "ok"
|
||||
|
||||
features = MagicMock(enable_trial_app=True)
|
||||
|
||||
with patch(
|
||||
"controllers.console.explore.wraps.FeatureService.get_system_features",
|
||||
return_value=features,
|
||||
):
|
||||
assert view() == "ok"
|
||||
|
||||
|
||||
def test_installed_app_resource_decorators():
|
||||
decorators = InstalledAppResource.method_decorators
|
||||
assert len(decorators) == 4
|
||||
|
||||
|
||||
def test_trial_app_resource_decorators():
|
||||
decorators = TrialAppResource.method_decorators
|
||||
assert len(decorators) == 3
|
||||
|
|
@ -0,0 +1,278 @@
|
|||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.tag.tags import (
|
||||
TagBindingCreateApi,
|
||||
TagBindingDeleteApi,
|
||||
TagListApi,
|
||||
TagUpdateDeleteApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""
|
||||
Recursively unwrap decorated functions.
|
||||
"""
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask("test_tag")
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user():
|
||||
return MagicMock(
|
||||
id="user-1",
|
||||
has_edit_permission=True,
|
||||
is_dataset_editor=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def readonly_user():
|
||||
return MagicMock(
|
||||
id="user-2",
|
||||
has_edit_permission=False,
|
||||
is_dataset_editor=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tag():
|
||||
tag = MagicMock()
|
||||
tag.id = "tag-1"
|
||||
tag.name = "test-tag"
|
||||
tag.type = "knowledge"
|
||||
return tag
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def payload_patch():
|
||||
def _patch(payload):
|
||||
return patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
)
|
||||
|
||||
return _patch
|
||||
|
||||
|
||||
class TestTagListApi:
|
||||
def test_get_success(self, app):
|
||||
api = TagListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/?type=knowledge"):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.tag.tags.TagService.get_tags",
|
||||
return_value=[{"id": "1", "name": "tag"}],
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_post_success(self, app, admin_user, tag, payload_patch):
|
||||
api = TagListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "test-tag", "type": "knowledge"}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch(
|
||||
"controllers.console.tag.tags.TagService.save_tags",
|
||||
return_value=tag,
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["name"] == "test-tag"
|
||||
|
||||
def test_post_forbidden(self, app, readonly_user, payload_patch):
|
||||
api = TagListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "x"}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(readonly_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestTagUpdateDeleteApi:
|
||||
def test_patch_success(self, app, admin_user, tag, payload_patch):
|
||||
api = TagUpdateDeleteApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {"name": "updated", "type": "knowledge"}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch(
|
||||
"controllers.console.tag.tags.TagService.update_tags",
|
||||
return_value=tag,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.tag.tags.TagService.get_tag_binding_count",
|
||||
return_value=3,
|
||||
),
|
||||
):
|
||||
result, status = method(api, "tag-1")
|
||||
|
||||
assert status == 200
|
||||
assert result["binding_count"] == 3
|
||||
|
||||
def test_patch_forbidden(self, app, readonly_user, payload_patch):
|
||||
api = TagUpdateDeleteApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {"name": "x"}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(readonly_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "tag-1")
|
||||
|
||||
def test_delete_success(self, app, admin_user):
|
||||
api = TagUpdateDeleteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, "tenant-1"),
|
||||
),
|
||||
patch("controllers.console.tag.tags.TagService.delete_tag") as delete_mock,
|
||||
):
|
||||
result, status = method(api, "tag-1")
|
||||
|
||||
delete_mock.assert_called_once_with("tag-1")
|
||||
assert status == 204
|
||||
|
||||
|
||||
class TestTagBindingCreateApi:
|
||||
def test_create_success(self, app, admin_user, payload_patch):
|
||||
api = TagBindingCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"tag_ids": ["tag-1"],
|
||||
"target_id": "target-1",
|
||||
"type": "knowledge",
|
||||
}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
save_mock.assert_called_once()
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_create_forbidden(self, app, readonly_user, payload_patch):
|
||||
api = TagBindingCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with app.test_request_context("/", json={}):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(readonly_user, None),
|
||||
),
|
||||
payload_patch({}),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestTagBindingDeleteApi:
|
||||
def test_remove_success(self, app, admin_user, payload_patch):
|
||||
api = TagBindingDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"tag_id": "tag-1",
|
||||
"target_id": "target-1",
|
||||
"type": "knowledge",
|
||||
}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
delete_mock.assert_called_once()
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_remove_forbidden(self, app, readonly_user, payload_patch):
|
||||
api = TagBindingDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with app.test_request_context("/", json={}):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(readonly_user, None),
|
||||
),
|
||||
payload_patch({}),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
Loading…
Reference in New Issue