mirror of https://github.com/langgenius/dify.git
refactor: fix decoupled runtime CI regressions
This commit is contained in:
parent
23aca7f567
commit
fb113bf3a4
|
|
@ -223,7 +223,6 @@ class ChatAppRunner(AppRunner):
|
|||
model_parameters=application_generate_entity.model_conf.parameters,
|
||||
stop=stop,
|
||||
stream=application_generate_entity.stream,
|
||||
user=application_generate_entity.user_id,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
|
|
|
|||
|
|
@ -181,7 +181,6 @@ class CompletionAppRunner(AppRunner):
|
|||
model_parameters=application_generate_entity.model_conf.parameters,
|
||||
stop=stop,
|
||||
stream=application_generate_entity.stream,
|
||||
user=application_generate_entity.user_id,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
|
|
|
|||
|
|
@ -115,11 +115,11 @@ class LLMQuotaLayer(GraphEngineLayer):
|
|||
try:
|
||||
match node.node_type:
|
||||
case BuiltinNodeTypes.LLM:
|
||||
return cast("LLMNode", node).model_instance
|
||||
model_instance = cast("LLMNode", node).model_instance
|
||||
case BuiltinNodeTypes.PARAMETER_EXTRACTOR:
|
||||
return cast("ParameterExtractorNode", node).model_instance
|
||||
model_instance = cast("ParameterExtractorNode", node).model_instance
|
||||
case BuiltinNodeTypes.QUESTION_CLASSIFIER:
|
||||
return cast("QuestionClassifierNode", node).model_instance
|
||||
model_instance = cast("QuestionClassifierNode", node).model_instance
|
||||
case _:
|
||||
return None
|
||||
except AttributeError:
|
||||
|
|
@ -128,3 +128,12 @@ class LLMQuotaLayer(GraphEngineLayer):
|
|||
node.id,
|
||||
)
|
||||
return None
|
||||
|
||||
if isinstance(model_instance, ModelInstance):
|
||||
return model_instance
|
||||
|
||||
raw_model_instance = getattr(model_instance, "_model_instance", None)
|
||||
if isinstance(raw_model_instance, ModelInstance):
|
||||
return raw_model_instance
|
||||
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class CacheEmbedding(Embeddings):
|
|||
|
||||
@staticmethod
|
||||
def _bind_model_instance(model_instance: ModelInstance, user: str | None) -> ModelInstance:
|
||||
if user is None:
|
||||
if user is None or not isinstance(model_instance, ModelInstance):
|
||||
return model_instance
|
||||
|
||||
tenant_id = model_instance.provider_model_bundle.configuration.tenant_id
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import time
|
|||
from collections.abc import Generator, Mapping
|
||||
from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, TypedDict, Union, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
|
|
@ -31,7 +31,7 @@ from services.enterprise.plugin_manager_service import PluginCredentialType
|
|||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.nodes.tool.entities import ToolEntity
|
||||
pass
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
|
@ -62,7 +62,7 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi
|
|||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.nodes.tool.entities import ToolEntity
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -77,6 +77,14 @@ class EmojiIconDict(TypedDict):
|
|||
content: str
|
||||
|
||||
|
||||
class WorkflowToolRuntimeSpec(Protocol):
|
||||
provider_type: ToolProviderType
|
||||
provider_id: str
|
||||
tool_name: str
|
||||
tool_configurations: dict[str, Any]
|
||||
credential_id: str | None
|
||||
|
||||
|
||||
class ToolManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
|
||||
|
|
@ -405,7 +413,7 @@ class ToolManager:
|
|||
tenant_id: str,
|
||||
app_id: str,
|
||||
node_id: str,
|
||||
workflow_tool: "ToolEntity",
|
||||
workflow_tool: WorkflowToolRuntimeSpec,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
variable_pool: Optional["VariablePool"] = None,
|
||||
) -> Tool:
|
||||
|
|
|
|||
|
|
@ -343,9 +343,9 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
|
|||
*,
|
||||
provider_name: str,
|
||||
default_icon: str | None = None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
icon = default_icon
|
||||
icon_dark = None
|
||||
) -> tuple[str | Mapping[str, str] | None, str | Mapping[str, str] | None]:
|
||||
icon: str | Mapping[str, str] | None = default_icon
|
||||
icon_dark: str | Mapping[str, str] | None = None
|
||||
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins(self._run_context.tenant_id)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from .constants import FILE_MODEL_IDENTITY
|
||||
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
|
||||
from .file_factory import get_file_type_by_mime_type, standardize_file_type
|
||||
from .models import (
|
||||
File,
|
||||
FileUploadConfig,
|
||||
|
|
@ -16,4 +17,6 @@ __all__ = [
|
|||
"FileType",
|
||||
"FileUploadConfig",
|
||||
"ImageConfig",
|
||||
"get_file_type_by_mime_type",
|
||||
"standardize_file_type",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,40 @@
|
|||
from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
|
||||
|
||||
from .enums import FileType
|
||||
|
||||
|
||||
def standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType:
|
||||
"""
|
||||
Infer the actual file type from extension and mime type.
|
||||
"""
|
||||
guessed_type = None
|
||||
if extension:
|
||||
guessed_type = _get_file_type_by_extension(extension)
|
||||
if guessed_type is None and mime_type:
|
||||
guessed_type = get_file_type_by_mime_type(mime_type)
|
||||
return guessed_type or FileType.CUSTOM
|
||||
|
||||
|
||||
def _get_file_type_by_extension(extension: str) -> FileType | None:
|
||||
normalized_extension = extension.lstrip(".")
|
||||
if normalized_extension in IMAGE_EXTENSIONS:
|
||||
return FileType.IMAGE
|
||||
if normalized_extension in VIDEO_EXTENSIONS:
|
||||
return FileType.VIDEO
|
||||
if normalized_extension in AUDIO_EXTENSIONS:
|
||||
return FileType.AUDIO
|
||||
if normalized_extension in DOCUMENT_EXTENSIONS:
|
||||
return FileType.DOCUMENT
|
||||
return None
|
||||
|
||||
|
||||
def get_file_type_by_mime_type(mime_type: str) -> FileType:
|
||||
if "image" in mime_type:
|
||||
return FileType.IMAGE
|
||||
if "video" in mime_type:
|
||||
return FileType.VIDEO
|
||||
if "audio" in mime_type:
|
||||
return FileType.AUDIO
|
||||
if "text" in mime_type or "pdf" in mime_type:
|
||||
return FileType.DOCUMENT
|
||||
return FileType.CUSTOM
|
||||
|
|
@ -60,7 +60,7 @@ class ToolNodeRuntimeProtocol(Protocol):
|
|||
*,
|
||||
provider_name: str,
|
||||
default_icon: str | None = None,
|
||||
) -> tuple[str | None, str | None]: ...
|
||||
) -> tuple[str | Mapping[str, str] | None, str | Mapping[str, str] | None]: ...
|
||||
|
||||
|
||||
class HumanInputNodeRuntimeProtocol(Protocol):
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from dify_graph.enums import (
|
|||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from dify_graph.file import File, FileTransferMethod
|
||||
from dify_graph.file import File, FileTransferMethod, get_file_type_by_mime_type
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from dify_graph.nodes.base.node import Node
|
||||
|
|
@ -252,9 +252,9 @@ class ToolNode(Node[ToolNodeData]):
|
|||
if not tool_file:
|
||||
raise ToolFileError(f"tool file {tool_file_id} not found")
|
||||
|
||||
mapping = {
|
||||
mapping: dict[str, Any] = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
|
||||
"type": get_file_type_by_mime_type(tool_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
|
|
@ -270,7 +270,7 @@ class ToolNode(Node[ToolNodeData]):
|
|||
if not tool_file:
|
||||
raise ToolFileError(f"tool file {tool_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
mapping: dict[str, Any] = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from dify_graph.nodes import BuiltinNodeTypes
|
||||
|
|
@ -19,8 +20,9 @@ def handle(sender, **kwargs):
|
|||
if node_data.get("data", {}).get("type") == BuiltinNodeTypes.TOOL:
|
||||
try:
|
||||
tool_entity = ToolEntity.model_validate(node_data["data"])
|
||||
provider_type = ToolProviderType(tool_entity.provider_type.value)
|
||||
tool_runtime = ToolManager.get_tool_runtime(
|
||||
provider_type=tool_entity.provider_type,
|
||||
provider_type=provider_type,
|
||||
provider_id=tool_entity.provider_id,
|
||||
tool_name=tool_entity.tool_name,
|
||||
tenant_id=app.tenant_id,
|
||||
|
|
@ -30,7 +32,7 @@ def handle(sender, **kwargs):
|
|||
tenant_id=app.tenant_id,
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=tool_entity.provider_name,
|
||||
provider_type=tool_entity.provider_type,
|
||||
provider_type=provider_type,
|
||||
identity_id=f"WORKFLOW.{app.id}.{node_data.get('id')}",
|
||||
)
|
||||
manager.delete_tool_parameters_cache()
|
||||
|
|
|
|||
|
|
@ -12,9 +12,9 @@ from sqlalchemy import select
|
|||
from sqlalchemy.orm import Session
|
||||
from werkzeug.http import parse_options_header
|
||||
|
||||
from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
|
||||
from core.helper import ssrf_proxy
|
||||
from dify_graph.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers
|
||||
from dify_graph.file.file_factory import standardize_file_type
|
||||
from extensions.ext_database import db
|
||||
from models import MessageFile, ToolFile, UploadFile
|
||||
|
||||
|
|
@ -175,7 +175,7 @@ def _build_from_local_file(
|
|||
if row is None:
|
||||
raise ValueError("Invalid upload file")
|
||||
|
||||
detected_file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
|
||||
detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
|
||||
specified_type = mapping.get("type", "custom")
|
||||
|
||||
if strict_type_validation and detected_file_type.value != specified_type:
|
||||
|
|
@ -223,7 +223,7 @@ def _build_from_remote_url(
|
|||
if upload_file is None:
|
||||
raise ValueError("Invalid upload file")
|
||||
|
||||
detected_file_type = _standardize_file_type(
|
||||
detected_file_type = standardize_file_type(
|
||||
extension="." + upload_file.extension, mime_type=upload_file.mime_type
|
||||
)
|
||||
|
||||
|
|
@ -257,7 +257,7 @@ def _build_from_remote_url(
|
|||
mime_type, filename, file_size = _get_remote_file_info(url)
|
||||
extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin")
|
||||
|
||||
detected_file_type = _standardize_file_type(extension=extension, mime_type=mime_type)
|
||||
detected_file_type = standardize_file_type(extension=extension, mime_type=mime_type)
|
||||
specified_type = mapping.get("type")
|
||||
|
||||
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
|
||||
|
|
@ -387,7 +387,7 @@ def _build_from_tool_file(
|
|||
|
||||
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
|
||||
|
||||
detected_file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
|
||||
detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
|
||||
|
||||
specified_type = mapping.get("type")
|
||||
|
||||
|
|
@ -436,7 +436,7 @@ def _build_from_datasource_file(
|
|||
|
||||
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
|
||||
|
||||
detected_file_type = _standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
|
||||
detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
|
||||
|
||||
specified_type = mapping.get("type")
|
||||
|
||||
|
|
@ -503,49 +503,6 @@ def _is_file_valid_with_config(
|
|||
return True
|
||||
|
||||
|
||||
def _standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType:
|
||||
"""
|
||||
Infer the possible actual type of the file based on the extension and mime_type
|
||||
"""
|
||||
guessed_type = None
|
||||
if extension:
|
||||
guessed_type = _get_file_type_by_extension(extension)
|
||||
if guessed_type is None and mime_type:
|
||||
guessed_type = _get_file_type_by_mimetype(mime_type)
|
||||
return guessed_type or FileType.CUSTOM
|
||||
|
||||
|
||||
def _get_file_type_by_extension(extension: str) -> FileType | None:
|
||||
extension = extension.lstrip(".")
|
||||
if extension in IMAGE_EXTENSIONS:
|
||||
return FileType.IMAGE
|
||||
elif extension in VIDEO_EXTENSIONS:
|
||||
return FileType.VIDEO
|
||||
elif extension in AUDIO_EXTENSIONS:
|
||||
return FileType.AUDIO
|
||||
elif extension in DOCUMENT_EXTENSIONS:
|
||||
return FileType.DOCUMENT
|
||||
return None
|
||||
|
||||
|
||||
def _get_file_type_by_mimetype(mime_type: str) -> FileType | None:
|
||||
if "image" in mime_type:
|
||||
file_type = FileType.IMAGE
|
||||
elif "video" in mime_type:
|
||||
file_type = FileType.VIDEO
|
||||
elif "audio" in mime_type:
|
||||
file_type = FileType.AUDIO
|
||||
elif "text" in mime_type or "pdf" in mime_type:
|
||||
file_type = FileType.DOCUMENT
|
||||
else:
|
||||
file_type = FileType.CUSTOM
|
||||
return file_type
|
||||
|
||||
|
||||
def get_file_type_by_mime_type(mime_type: str) -> FileType:
|
||||
return _get_file_type_by_mimetype(mime_type) or FileType.CUSTOM
|
||||
|
||||
|
||||
class StorageKeyLoader:
|
||||
"""FileKeyLoader load the storage key from database for a list of files.
|
||||
This loader is batched, the database query count is constant regardless of the input size.
|
||||
|
|
|
|||
|
|
@ -124,11 +124,19 @@ class AppService:
|
|||
"completion_params": {},
|
||||
}
|
||||
else:
|
||||
provider, model = model_manager.get_default_provider_model_name(
|
||||
tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM
|
||||
)
|
||||
default_model_config["model"]["provider"] = provider
|
||||
default_model_config["model"]["name"] = model
|
||||
try:
|
||||
provider, model = model_manager.get_default_provider_model_name(
|
||||
tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Get default provider model failed, tenant_id: %s", tenant_id)
|
||||
provider = default_model_config["model"].get("provider")
|
||||
model = default_model_config["model"].get("name")
|
||||
|
||||
if provider:
|
||||
default_model_config["model"]["provider"] = provider
|
||||
if model:
|
||||
default_model_config["model"]["name"] = model
|
||||
default_model_dict = default_model_config["model"]
|
||||
|
||||
default_model_config["model"] = json.dumps(default_model_dict)
|
||||
|
|
|
|||
|
|
@ -416,7 +416,7 @@ class TestDatasetApiGet:
|
|||
"check_dataset_permission",
|
||||
return_value=None,
|
||||
),
|
||||
patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock,
|
||||
patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock,
|
||||
):
|
||||
# embedding models exist → embedding_available stays True
|
||||
provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = []
|
||||
|
|
@ -520,7 +520,7 @@ class TestDatasetApiGet:
|
|||
"check_dataset_permission",
|
||||
return_value=None,
|
||||
),
|
||||
patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock,
|
||||
patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock,
|
||||
):
|
||||
# embedding model NOT configured
|
||||
provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = []
|
||||
|
|
@ -579,7 +579,7 @@ class TestDatasetApiGet:
|
|||
"get_dataset_partial_member_list",
|
||||
return_value=partial_members,
|
||||
),
|
||||
patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock,
|
||||
patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock,
|
||||
):
|
||||
provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = []
|
||||
|
||||
|
|
|
|||
|
|
@ -941,11 +941,11 @@ class TestDatasetListApiGet:
|
|||
"""Test suite for DatasetListApi.get() endpoint.
|
||||
|
||||
``get`` has no billing decorators but calls ``current_user``,
|
||||
``DatasetService``, ``ProviderManager``, and ``marshal``.
|
||||
``DatasetService``, ``create_plugin_provider_manager``, and ``marshal``.
|
||||
"""
|
||||
|
||||
@patch("controllers.service_api.dataset.dataset.marshal")
|
||||
@patch("controllers.service_api.dataset.dataset.ProviderManager")
|
||||
@patch("controllers.service_api.dataset.dataset.create_plugin_provider_manager")
|
||||
@patch("controllers.service_api.dataset.dataset.current_user")
|
||||
@patch("controllers.service_api.dataset.dataset.DatasetService")
|
||||
def test_list_datasets_success(
|
||||
|
|
@ -1043,12 +1043,12 @@ class TestDatasetApiGet:
|
|||
"""Test suite for DatasetApi.get() endpoint.
|
||||
|
||||
``get`` has no billing decorators but calls ``DatasetService``,
|
||||
``ProviderManager``, ``marshal``, and ``current_user``.
|
||||
``create_plugin_provider_manager``, ``marshal``, and ``current_user``.
|
||||
"""
|
||||
|
||||
@patch("controllers.service_api.dataset.dataset.DatasetPermissionService")
|
||||
@patch("controllers.service_api.dataset.dataset.marshal")
|
||||
@patch("controllers.service_api.dataset.dataset.ProviderManager")
|
||||
@patch("controllers.service_api.dataset.dataset.create_plugin_provider_manager")
|
||||
@patch("controllers.service_api.dataset.dataset.current_user")
|
||||
@patch("controllers.service_api.dataset.dataset.DatasetService")
|
||||
def test_get_dataset_success(
|
||||
|
|
|
|||
|
|
@ -28,10 +28,7 @@ def mock_model_instance(mocker):
|
|||
def mock_model_manager(mocker, mock_model_instance):
|
||||
manager = mocker.MagicMock()
|
||||
manager.get_default_model_instance.return_value = mock_model_instance
|
||||
mocker.patch(
|
||||
"core.base.tts.app_generator_tts_publisher.ModelManager",
|
||||
return_value=manager,
|
||||
)
|
||||
mocker.patch("core.base.tts.app_generator_tts_publisher.ModelManager.for_tenant", return_value=manager)
|
||||
return manager
|
||||
|
||||
|
||||
|
|
@ -64,16 +61,14 @@ class TestInvoiceTTS:
|
|||
[None, "", " "],
|
||||
)
|
||||
def test_invoice_tts_empty_or_none_returns_none(self, text, mock_model_instance):
|
||||
result = _invoice_tts(text, mock_model_instance, "tenant", "voice1")
|
||||
result = _invoice_tts(text, mock_model_instance, "voice1")
|
||||
assert result is None
|
||||
mock_model_instance.invoke_tts.assert_not_called()
|
||||
|
||||
def test_invoice_tts_valid_text(self, mock_model_instance):
|
||||
result = _invoice_tts(" hello ", mock_model_instance, "tenant", "voice1")
|
||||
result = _invoice_tts(" hello ", mock_model_instance, "voice1")
|
||||
mock_model_instance.invoke_tts.assert_called_once_with(
|
||||
content_text="hello",
|
||||
user="responding_tts",
|
||||
tenant_id="tenant",
|
||||
voice="voice1",
|
||||
)
|
||||
assert result == [b"audio1", b"audio2"]
|
||||
|
|
|
|||
|
|
@ -350,7 +350,7 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None:
|
|||
mock_factory = Mock()
|
||||
mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"}
|
||||
|
||||
with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
|
||||
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.encrypter.encrypt_token",
|
||||
|
|
@ -380,7 +380,9 @@ def test_validate_provider_credentials_opens_session_when_not_passed() -> None:
|
|||
with patch("core.entities.provider_configuration.db") as mock_db:
|
||||
mock_db.engine = Mock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
|
||||
):
|
||||
validated = configuration.validate_provider_credentials(credentials={"region": "us"})
|
||||
|
||||
assert validated == {"region": "us"}
|
||||
|
|
@ -434,7 +436,7 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
|
|||
mock_factory.get_model_type_instance.return_value = mock_model_type_instance
|
||||
mock_factory.get_model_schema.return_value = mock_schema
|
||||
|
||||
with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
|
||||
model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
|
||||
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
|
||||
|
||||
|
|
@ -475,7 +477,7 @@ def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> N
|
|||
mock_factory = Mock()
|
||||
mock_factory.get_provider_schema.return_value = provider_schema
|
||||
|
||||
with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
|
||||
all_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=False)
|
||||
active_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=True)
|
||||
|
||||
|
|
@ -689,7 +691,7 @@ def test_validate_provider_credentials_handles_invalid_original_json() -> None:
|
|||
mock_factory = Mock()
|
||||
mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"}
|
||||
|
||||
with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
|
||||
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"):
|
||||
validated = configuration.validate_provider_credentials(
|
||||
credentials={"openai_api_key": HIDDEN_VALUE},
|
||||
|
|
@ -1034,7 +1036,7 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless
|
|||
mock_factory = Mock()
|
||||
mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"}
|
||||
|
||||
with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
|
||||
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"):
|
||||
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
|
||||
validated = configuration.validate_custom_model_credentials(
|
||||
|
|
@ -1050,7 +1052,9 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless
|
|||
mock_factory = Mock()
|
||||
mock_factory.model_credentials_validate.return_value = {"region": "us"}
|
||||
with _patched_session(session):
|
||||
with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
|
||||
):
|
||||
validated = configuration.validate_custom_model_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o",
|
||||
|
|
@ -1540,7 +1544,7 @@ def test_validate_provider_credentials_uses_empty_original_when_record_missing()
|
|||
mock_factory = Mock()
|
||||
mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"}
|
||||
|
||||
with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
|
||||
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
|
||||
validated = configuration.validate_provider_credentials(
|
||||
credentials={"openai_api_key": HIDDEN_VALUE},
|
||||
|
|
@ -1662,7 +1666,7 @@ def test_validate_custom_model_credentials_handles_invalid_original_json() -> No
|
|||
mock_factory = Mock()
|
||||
mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"}
|
||||
|
||||
with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
|
||||
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
|
||||
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
|
||||
validated = configuration.validate_custom_model_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ class TestCacheEmbeddingMultimodalDocuments:
|
|||
"""Create a mock ModelInstance for testing."""
|
||||
model_instance = Mock()
|
||||
model_instance.model = "vision-embedding-model"
|
||||
model_instance.model_name = "vision-embedding-model"
|
||||
model_instance.provider = "openai"
|
||||
model_instance.credentials = {"api_key": "test-key"}
|
||||
|
||||
|
|
@ -316,6 +317,7 @@ class TestCacheEmbeddingMultimodalQuery:
|
|||
"""Create a mock ModelInstance for testing."""
|
||||
model_instance = Mock()
|
||||
model_instance.model = "vision-embedding-model"
|
||||
model_instance.model_name = "vision-embedding-model"
|
||||
model_instance.provider = "openai"
|
||||
model_instance.credentials = {"api_key": "test-key"}
|
||||
return model_instance
|
||||
|
|
@ -467,6 +469,7 @@ class TestCacheEmbeddingQueryErrors:
|
|||
"""Create a mock ModelInstance for testing."""
|
||||
model_instance = Mock()
|
||||
model_instance.model = "text-embedding-ada-002"
|
||||
model_instance.model_name = "text-embedding-ada-002"
|
||||
model_instance.provider = "openai"
|
||||
model_instance.credentials = {"api_key": "test-key"}
|
||||
return model_instance
|
||||
|
|
@ -536,20 +539,20 @@ class TestCacheEmbeddingInitialization:
|
|||
"""Test CacheEmbedding initialization with user parameter."""
|
||||
model_instance = Mock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.model_name = "test-model"
|
||||
model_instance.provider = "test-provider"
|
||||
|
||||
cache_embedding = CacheEmbedding(model_instance, user="test-user")
|
||||
|
||||
assert cache_embedding._model_instance == model_instance
|
||||
assert cache_embedding._user == "test-user"
|
||||
|
||||
def test_initialization_without_user(self):
|
||||
"""Test CacheEmbedding initialization without user parameter."""
|
||||
model_instance = Mock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.model_name = "test-model"
|
||||
model_instance.provider = "test-provider"
|
||||
|
||||
cache_embedding = CacheEmbedding(model_instance)
|
||||
|
||||
assert cache_embedding._model_instance == model_instance
|
||||
assert cache_embedding._user is None
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
import threading
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_manager import ModelInstance
|
||||
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from dify_graph.graph_engine.entities.commands import CommandType
|
||||
from dify_graph.graph_events.node import NodeRunSucceededEvent
|
||||
|
|
@ -36,6 +38,11 @@ def _build_succeeded_event() -> NodeRunSucceededEvent:
|
|||
)
|
||||
|
||||
|
||||
def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]:
|
||||
raw_model_instance = ModelInstance.__new__(ModelInstance)
|
||||
return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance
|
||||
|
||||
|
||||
def test_deduct_quota_called_for_successful_llm_node() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
node = MagicMock()
|
||||
|
|
@ -44,7 +51,7 @@ def test_deduct_quota_called_for_successful_llm_node() -> None:
|
|||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.tenant_id = "tenant-id"
|
||||
node.require_run_context_value.return_value = _build_dify_context()
|
||||
node.model_instance = object()
|
||||
node.model_instance, raw_model_instance = _build_wrapped_model_instance()
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
|
||||
|
|
@ -52,7 +59,7 @@ def test_deduct_quota_called_for_successful_llm_node() -> None:
|
|||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
model_instance=node.model_instance,
|
||||
model_instance=raw_model_instance,
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
|
||||
|
|
@ -65,7 +72,7 @@ def test_deduct_quota_called_for_question_classifier_node() -> None:
|
|||
node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
|
||||
node.tenant_id = "tenant-id"
|
||||
node.require_run_context_value.return_value = _build_dify_context()
|
||||
node.model_instance = object()
|
||||
node.model_instance, raw_model_instance = _build_wrapped_model_instance()
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
|
||||
|
|
@ -73,7 +80,7 @@ def test_deduct_quota_called_for_question_classifier_node() -> None:
|
|||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
model_instance=node.model_instance,
|
||||
model_instance=raw_model_instance,
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
|
||||
|
|
@ -125,7 +132,7 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
|
|||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.tenant_id = "tenant-id"
|
||||
node.require_run_context_value.return_value = _build_dify_context()
|
||||
node.model_instance = object()
|
||||
node.model_instance, _ = _build_wrapped_model_instance()
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
|
|
@ -152,7 +159,7 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
|
|||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.model_instance = object()
|
||||
node.model_instance, _ = _build_wrapped_model_instance()
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
|
|
@ -178,7 +185,7 @@ def test_quota_precheck_passes_without_abort() -> None:
|
|||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.model_instance = object()
|
||||
node.model_instance, raw_model_instance = _build_wrapped_model_instance()
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
|
|
@ -186,5 +193,5 @@ def test_quota_precheck_passes_without_abort() -> None:
|
|||
layer.on_node_run_start(node)
|
||||
|
||||
assert not stop_event.is_set()
|
||||
mock_check.assert_called_once_with(model_instance=node.model_instance)
|
||||
mock_check.assert_called_once_with(model_instance=raw_model_instance)
|
||||
layer.command_channel.send_command.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -2,89 +2,55 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
|
||||
|
||||
class TestModerationModel:
|
||||
@pytest.fixture
|
||||
def mock_plugin_model_provider(self):
|
||||
return MagicMock(spec=PluginModelProviderEntity)
|
||||
@pytest.fixture
|
||||
def provider_schema() -> ProviderEntity:
|
||||
return ProviderEntity(
|
||||
provider="test_provider",
|
||||
label=I18nObject(en_US="test_provider"),
|
||||
supported_model_types=[ModelType.MODERATION],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def moderation_model(self, mock_plugin_model_provider):
|
||||
return ModerationModel(
|
||||
tenant_id="tenant_123",
|
||||
model_type=ModelType.MODERATION,
|
||||
plugin_id="plugin_123",
|
||||
provider_name="test_provider",
|
||||
plugin_model_provider=mock_plugin_model_provider,
|
||||
)
|
||||
|
||||
def test_model_type(self, moderation_model):
|
||||
assert moderation_model.model_type == ModelType.MODERATION
|
||||
@pytest.fixture
|
||||
def model_runtime() -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
def test_invoke_success(self, moderation_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
text = "test text"
|
||||
user = "user_123"
|
||||
|
||||
with (
|
||||
patch("core.plugin.impl.model.PluginModelClient") as mock_client_class,
|
||||
patch("time.perf_counter", return_value=1.0),
|
||||
):
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_moderation.return_value = True
|
||||
@pytest.fixture
|
||||
def moderation_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> ModerationModel:
|
||||
return ModerationModel(provider_schema=provider_schema, model_runtime=model_runtime)
|
||||
|
||||
result = moderation_model.invoke(model=model_name, credentials=credentials, text=text, user=user)
|
||||
|
||||
assert result is True
|
||||
assert moderation_model.started_at == 1.0
|
||||
mock_client.invoke_moderation.assert_called_once_with(
|
||||
tenant_id="tenant_123",
|
||||
user_id="user_123",
|
||||
plugin_id="plugin_123",
|
||||
provider="test_provider",
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
text=text,
|
||||
)
|
||||
def test_model_type(moderation_model: ModerationModel) -> None:
|
||||
assert moderation_model.model_type == ModelType.MODERATION
|
||||
|
||||
def test_invoke_success_no_user(self, moderation_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
text = "test text"
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_moderation.return_value = False
|
||||
def test_invoke_success(moderation_model: ModerationModel, model_runtime: MagicMock) -> None:
|
||||
with patch("time.perf_counter", return_value=1.0):
|
||||
model_runtime.invoke_moderation.return_value = True
|
||||
|
||||
result = moderation_model.invoke(model=model_name, credentials=credentials, text=text)
|
||||
result = moderation_model.invoke(model="test_model", credentials={"api_key": "abc"}, text="test text")
|
||||
|
||||
assert result is False
|
||||
mock_client.invoke_moderation.assert_called_once_with(
|
||||
tenant_id="tenant_123",
|
||||
user_id="unknown",
|
||||
plugin_id="plugin_123",
|
||||
provider="test_provider",
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
text=text,
|
||||
)
|
||||
assert result is True
|
||||
assert moderation_model.started_at == 1.0
|
||||
model_runtime.invoke_moderation.assert_called_once_with(
|
||||
provider="test_provider",
|
||||
model="test_model",
|
||||
credentials={"api_key": "abc"},
|
||||
text="test text",
|
||||
)
|
||||
|
||||
def test_invoke_exception(self, moderation_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
text = "test text"
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_moderation.side_effect = Exception("Test error")
|
||||
def test_invoke_exception(moderation_model: ModerationModel, model_runtime: MagicMock) -> None:
|
||||
model_runtime.invoke_moderation.side_effect = Exception("Test error")
|
||||
|
||||
with pytest.raises(InvokeError) as excinfo:
|
||||
moderation_model.invoke(model=model_name, credentials=credentials, text=text)
|
||||
|
||||
assert "[test_provider] Error: Test error" in str(excinfo.value.description)
|
||||
with pytest.raises(InvokeError, match="Test error"):
|
||||
moderation_model.invoke(model="test_model", credentials={"api_key": "abc"}, text="test text")
|
||||
|
|
|
|||
|
|
@ -1,71 +1,42 @@
|
|||
from datetime import datetime
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from dify_graph.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rerank_model() -> RerankModel:
|
||||
plugin_provider = PluginModelProviderEntity.model_construct(
|
||||
id="provider-id",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
provider="provider",
|
||||
tenant_id="tenant",
|
||||
plugin_unique_identifier="plugin-uid",
|
||||
plugin_id="plugin-id",
|
||||
declaration=MagicMock(),
|
||||
)
|
||||
return RerankModel.model_construct(
|
||||
tenant_id="tenant",
|
||||
model_type=ModelType.RERANK,
|
||||
plugin_id="plugin-id",
|
||||
provider_name="provider",
|
||||
plugin_model_provider=plugin_provider,
|
||||
def provider_schema() -> ProviderEntity:
|
||||
return ProviderEntity(
|
||||
provider="test_provider",
|
||||
label=I18nObject(en_US="test_provider"),
|
||||
supported_model_types=[ModelType.RERANK],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
)
|
||||
|
||||
|
||||
def test_model_type_is_rerank_by_default() -> None:
|
||||
plugin_provider = PluginModelProviderEntity.model_construct(
|
||||
id="provider-id",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
provider="provider",
|
||||
tenant_id="tenant",
|
||||
plugin_unique_identifier="plugin-uid",
|
||||
plugin_id="plugin-id",
|
||||
declaration=MagicMock(),
|
||||
)
|
||||
model = RerankModel(
|
||||
tenant_id="tenant",
|
||||
plugin_id="plugin-id",
|
||||
provider_name="provider",
|
||||
plugin_model_provider=plugin_provider,
|
||||
)
|
||||
assert model.model_type == ModelType.RERANK
|
||||
@pytest.fixture
|
||||
def model_runtime() -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
|
||||
def test_invoke_calls_plugin_and_passes_args(rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@pytest.fixture
|
||||
def rerank_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> RerankModel:
|
||||
return RerankModel(provider_schema=provider_schema, model_runtime=model_runtime)
|
||||
|
||||
|
||||
def test_model_type_is_rerank_by_default(rerank_model: RerankModel) -> None:
|
||||
assert rerank_model.model_type == ModelType.RERANK
|
||||
|
||||
|
||||
def test_invoke_calls_runtime_and_passes_args(rerank_model: RerankModel, model_runtime: MagicMock) -> None:
|
||||
expected = RerankResult(model="rerank", docs=[RerankDocument(index=0, text="a", score=0.5)])
|
||||
|
||||
class FakePluginModelClient:
|
||||
def __init__(self) -> None:
|
||||
self.invoke_rerank_called_with: dict[str, Any] | None = None
|
||||
|
||||
def invoke_rerank(self, **kwargs: Any) -> RerankResult:
|
||||
self.invoke_rerank_called_with = kwargs
|
||||
return expected
|
||||
|
||||
import core.plugin.impl.model as plugin_model_module
|
||||
|
||||
fake_client = FakePluginModelClient()
|
||||
monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client)
|
||||
model_runtime.invoke_rerank.return_value = expected
|
||||
|
||||
result = rerank_model.invoke(
|
||||
model="rerank",
|
||||
|
|
@ -74,76 +45,30 @@ def test_invoke_calls_plugin_and_passes_args(rerank_model: RerankModel, monkeypa
|
|||
docs=["d1", "d2"],
|
||||
score_threshold=0.2,
|
||||
top_n=10,
|
||||
user="user-1",
|
||||
)
|
||||
|
||||
assert result == expected
|
||||
assert fake_client.invoke_rerank_called_with == {
|
||||
"tenant_id": "tenant",
|
||||
"user_id": "user-1",
|
||||
"plugin_id": "plugin-id",
|
||||
"provider": "provider",
|
||||
"model": "rerank",
|
||||
"credentials": {"k": "v"},
|
||||
"query": "q",
|
||||
"docs": ["d1", "d2"],
|
||||
"score_threshold": 0.2,
|
||||
"top_n": 10,
|
||||
}
|
||||
model_runtime.invoke_rerank.assert_called_once_with(
|
||||
provider="test_provider",
|
||||
model="rerank",
|
||||
credentials={"k": "v"},
|
||||
query="q",
|
||||
docs=["d1", "d2"],
|
||||
score_threshold=0.2,
|
||||
top_n=10,
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_uses_unknown_user_when_not_provided(rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class FakePluginModelClient:
|
||||
def __init__(self) -> None:
|
||||
self.kwargs: dict[str, Any] | None = None
|
||||
def test_invoke_transforms_and_raises_on_runtime_error(rerank_model: RerankModel, model_runtime: MagicMock) -> None:
|
||||
model_runtime.invoke_rerank.side_effect = Exception("runtime down")
|
||||
|
||||
def invoke_rerank(self, **kwargs: Any) -> RerankResult:
|
||||
self.kwargs = kwargs
|
||||
return RerankResult(model="m", docs=[])
|
||||
|
||||
import core.plugin.impl.model as plugin_model_module
|
||||
|
||||
fake_client = FakePluginModelClient()
|
||||
monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client)
|
||||
|
||||
rerank_model.invoke(model="m", credentials={}, query="q", docs=["d"])
|
||||
assert fake_client.kwargs is not None
|
||||
assert fake_client.kwargs["user_id"] == "unknown"
|
||||
|
||||
|
||||
def test_invoke_transforms_and_raises_on_plugin_error(
|
||||
rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
class FakePluginModelClient:
|
||||
def invoke_rerank(self, **_: Any) -> RerankResult:
|
||||
raise ValueError("plugin down")
|
||||
|
||||
import core.plugin.impl.model as plugin_model_module
|
||||
|
||||
monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient)
|
||||
monkeypatch.setattr(rerank_model, "_transform_invoke_error", lambda e: RuntimeError(f"transformed: {e}"))
|
||||
|
||||
with pytest.raises(RuntimeError, match="transformed: plugin down"):
|
||||
with pytest.raises(InvokeError, match="runtime down"):
|
||||
rerank_model.invoke(model="m", credentials={}, query="q", docs=["d"])
|
||||
|
||||
|
||||
def test_invoke_multimodal_calls_plugin_and_passes_args(
|
||||
rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
def test_invoke_multimodal_calls_runtime_and_passes_args(rerank_model: RerankModel, model_runtime: MagicMock) -> None:
|
||||
expected = RerankResult(model="mm", docs=[RerankDocument(index=0, text="x", score=0.9)])
|
||||
|
||||
class FakePluginModelClient:
|
||||
def __init__(self) -> None:
|
||||
self.invoke_multimodal_rerank_called_with: dict[str, Any] | None = None
|
||||
|
||||
def invoke_multimodal_rerank(self, **kwargs: Any) -> RerankResult:
|
||||
self.invoke_multimodal_rerank_called_with = kwargs
|
||||
return expected
|
||||
|
||||
import core.plugin.impl.model as plugin_model_module
|
||||
|
||||
fake_client = FakePluginModelClient()
|
||||
monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client)
|
||||
model_runtime.invoke_multimodal_rerank.return_value = expected
|
||||
|
||||
query = {"type": "text", "text": "q"}
|
||||
docs = [{"type": "text", "text": "d1"}]
|
||||
|
|
@ -154,28 +79,15 @@ def test_invoke_multimodal_calls_plugin_and_passes_args(
|
|||
docs=docs,
|
||||
score_threshold=None,
|
||||
top_n=None,
|
||||
user=None,
|
||||
)
|
||||
|
||||
assert result == expected
|
||||
assert fake_client.invoke_multimodal_rerank_called_with is not None
|
||||
assert fake_client.invoke_multimodal_rerank_called_with["tenant_id"] == "tenant"
|
||||
assert fake_client.invoke_multimodal_rerank_called_with["user_id"] == "unknown"
|
||||
assert fake_client.invoke_multimodal_rerank_called_with["query"] == query
|
||||
assert fake_client.invoke_multimodal_rerank_called_with["docs"] == docs
|
||||
|
||||
|
||||
def test_invoke_multimodal_transforms_and_raises_on_plugin_error(
|
||||
rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
class FakePluginModelClient:
|
||||
def invoke_multimodal_rerank(self, **_: Any) -> RerankResult:
|
||||
raise ValueError("plugin down")
|
||||
|
||||
import core.plugin.impl.model as plugin_model_module
|
||||
|
||||
monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient)
|
||||
monkeypatch.setattr(rerank_model, "_transform_invoke_error", lambda e: RuntimeError(f"transformed: {e}"))
|
||||
|
||||
with pytest.raises(RuntimeError, match="transformed: plugin down"):
|
||||
rerank_model.invoke_multimodal_rerank(model="m", credentials={}, query={"q": 1}, docs=[{"d": 1}])
|
||||
model_runtime.invoke_multimodal_rerank.assert_called_once_with(
|
||||
provider="test_provider",
|
||||
model="mm",
|
||||
credentials={"k": "v"},
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=None,
|
||||
top_n=None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,87 +1,56 @@
|
|||
from io import BytesIO
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
|
||||
|
||||
class TestSpeech2TextModel:
|
||||
@pytest.fixture
|
||||
def mock_plugin_model_provider(self):
|
||||
return MagicMock(spec=PluginModelProviderEntity)
|
||||
@pytest.fixture
|
||||
def provider_schema() -> ProviderEntity:
|
||||
return ProviderEntity(
|
||||
provider="test_provider",
|
||||
label=I18nObject(en_US="test_provider"),
|
||||
supported_model_types=[ModelType.SPEECH2TEXT],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def speech2text_model(self, mock_plugin_model_provider):
|
||||
return Speech2TextModel(
|
||||
tenant_id="tenant_123",
|
||||
model_type=ModelType.SPEECH2TEXT,
|
||||
plugin_id="plugin_123",
|
||||
provider_name="test_provider",
|
||||
plugin_model_provider=mock_plugin_model_provider,
|
||||
)
|
||||
|
||||
def test_model_type(self, speech2text_model):
|
||||
assert speech2text_model.model_type == ModelType.SPEECH2TEXT
|
||||
@pytest.fixture
|
||||
def model_runtime() -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
def test_invoke_success(self, speech2text_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
file = BytesIO(b"audio data")
|
||||
user = "user_123"
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_speech_to_text.return_value = "transcribed text"
|
||||
@pytest.fixture
|
||||
def speech2text_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> Speech2TextModel:
|
||||
return Speech2TextModel(provider_schema=provider_schema, model_runtime=model_runtime)
|
||||
|
||||
result = speech2text_model.invoke(model=model_name, credentials=credentials, file=file, user=user)
|
||||
|
||||
assert result == "transcribed text"
|
||||
mock_client.invoke_speech_to_text.assert_called_once_with(
|
||||
tenant_id="tenant_123",
|
||||
user_id="user_123",
|
||||
plugin_id="plugin_123",
|
||||
provider="test_provider",
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
file=file,
|
||||
)
|
||||
def test_model_type(speech2text_model: Speech2TextModel) -> None:
|
||||
assert speech2text_model.model_type == ModelType.SPEECH2TEXT
|
||||
|
||||
def test_invoke_success_no_user(self, speech2text_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
file = BytesIO(b"audio data")
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_speech_to_text.return_value = "transcribed text"
|
||||
def test_invoke_success(speech2text_model: Speech2TextModel, model_runtime: MagicMock) -> None:
|
||||
file = BytesIO(b"audio data")
|
||||
model_runtime.invoke_speech_to_text.return_value = "transcribed text"
|
||||
|
||||
result = speech2text_model.invoke(model=model_name, credentials=credentials, file=file)
|
||||
result = speech2text_model.invoke(model="test_model", credentials={"api_key": "abc"}, file=file)
|
||||
|
||||
assert result == "transcribed text"
|
||||
mock_client.invoke_speech_to_text.assert_called_once_with(
|
||||
tenant_id="tenant_123",
|
||||
user_id="unknown",
|
||||
plugin_id="plugin_123",
|
||||
provider="test_provider",
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
file=file,
|
||||
)
|
||||
assert result == "transcribed text"
|
||||
model_runtime.invoke_speech_to_text.assert_called_once_with(
|
||||
provider="test_provider",
|
||||
model="test_model",
|
||||
credentials={"api_key": "abc"},
|
||||
file=file,
|
||||
)
|
||||
|
||||
def test_invoke_exception(self, speech2text_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
file = BytesIO(b"audio data")
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_speech_to_text.side_effect = Exception("Test error")
|
||||
def test_invoke_exception(speech2text_model: Speech2TextModel, model_runtime: MagicMock) -> None:
|
||||
model_runtime.invoke_speech_to_text.side_effect = Exception("Test error")
|
||||
|
||||
with pytest.raises(InvokeError) as excinfo:
|
||||
speech2text_model.invoke(model=model_name, credentials=credentials, file=file)
|
||||
|
||||
assert "[test_provider] Error: Test error" in str(excinfo.value.description)
|
||||
with pytest.raises(InvokeError, match="Test error"):
|
||||
speech2text_model.invoke(model="test_model", credentials={"api_key": "abc"}, file=BytesIO(b"audio data"))
|
||||
|
|
|
|||
|
|
@ -2,184 +2,145 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from core.entities.embedding_type import EmbeddingInputType
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
|
||||
|
||||
class TestTextEmbeddingModel:
|
||||
@pytest.fixture
|
||||
def mock_plugin_model_provider(self):
|
||||
return MagicMock(spec=PluginModelProviderEntity)
|
||||
@pytest.fixture
|
||||
def provider_schema() -> ProviderEntity:
|
||||
return ProviderEntity(
|
||||
provider="test_provider",
|
||||
label=I18nObject(en_US="test_provider"),
|
||||
supported_model_types=[ModelType.TEXT_EMBEDDING],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def text_embedding_model(self, mock_plugin_model_provider):
|
||||
return TextEmbeddingModel(
|
||||
tenant_id="tenant_123",
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
plugin_id="plugin_123",
|
||||
provider_name="test_provider",
|
||||
plugin_model_provider=mock_plugin_model_provider,
|
||||
)
|
||||
|
||||
def test_model_type(self, text_embedding_model):
|
||||
assert text_embedding_model.model_type == ModelType.TEXT_EMBEDDING
|
||||
@pytest.fixture
|
||||
def model_runtime() -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
def test_invoke_with_texts(self, text_embedding_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
texts = ["hello", "world"]
|
||||
user = "user_123"
|
||||
expected_result = MagicMock(spec=EmbeddingResult)
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_text_embedding.return_value = expected_result
|
||||
@pytest.fixture
|
||||
def text_embedding_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> TextEmbeddingModel:
|
||||
return TextEmbeddingModel(provider_schema=provider_schema, model_runtime=model_runtime)
|
||||
|
||||
result = text_embedding_model.invoke(model=model_name, credentials=credentials, texts=texts, user=user)
|
||||
|
||||
assert result == expected_result
|
||||
mock_client.invoke_text_embedding.assert_called_once_with(
|
||||
tenant_id="tenant_123",
|
||||
user_id="user_123",
|
||||
plugin_id="plugin_123",
|
||||
provider="test_provider",
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
texts=texts,
|
||||
input_type=EmbeddingInputType.DOCUMENT,
|
||||
)
|
||||
def test_model_type(text_embedding_model: TextEmbeddingModel) -> None:
|
||||
assert text_embedding_model.model_type == ModelType.TEXT_EMBEDDING
|
||||
|
||||
def test_invoke_with_multimodel_documents(self, text_embedding_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
multimodel_documents = [{"type": "text", "text": "hello"}]
|
||||
expected_result = MagicMock(spec=EmbeddingResult)
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_multimodal_embedding.return_value = expected_result
|
||||
def test_invoke_with_texts(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None:
|
||||
expected_result = MagicMock(spec=EmbeddingResult)
|
||||
model_runtime.invoke_text_embedding.return_value = expected_result
|
||||
|
||||
result = text_embedding_model.invoke(
|
||||
model=model_name, credentials=credentials, multimodel_documents=multimodel_documents
|
||||
)
|
||||
result = text_embedding_model.invoke(model="test_model", credentials={"api_key": "abc"}, texts=["hello", "world"])
|
||||
|
||||
assert result == expected_result
|
||||
mock_client.invoke_multimodal_embedding.assert_called_once_with(
|
||||
tenant_id="tenant_123",
|
||||
user_id="unknown",
|
||||
plugin_id="plugin_123",
|
||||
provider="test_provider",
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
documents=multimodel_documents,
|
||||
input_type=EmbeddingInputType.DOCUMENT,
|
||||
)
|
||||
assert result == expected_result
|
||||
model_runtime.invoke_text_embedding.assert_called_once_with(
|
||||
provider="test_provider",
|
||||
model="test_model",
|
||||
credentials={"api_key": "abc"},
|
||||
texts=["hello", "world"],
|
||||
input_type=EmbeddingInputType.DOCUMENT,
|
||||
)
|
||||
|
||||
def test_invoke_no_input(self, text_embedding_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
text_embedding_model.invoke(model=model_name, credentials=credentials)
|
||||
def test_invoke_with_multimodal_documents(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None:
|
||||
expected_result = MagicMock(spec=EmbeddingResult)
|
||||
model_runtime.invoke_multimodal_embedding.return_value = expected_result
|
||||
|
||||
assert "No texts or files provided" in str(excinfo.value)
|
||||
result = text_embedding_model.invoke(
|
||||
model="test_model",
|
||||
credentials={"api_key": "abc"},
|
||||
multimodel_documents=[{"type": "text", "text": "hello"}],
|
||||
)
|
||||
|
||||
def test_invoke_precedence(self, text_embedding_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
texts = ["hello"]
|
||||
multimodel_documents = [{"type": "text", "text": "world"}]
|
||||
expected_result = MagicMock(spec=EmbeddingResult)
|
||||
assert result == expected_result
|
||||
model_runtime.invoke_multimodal_embedding.assert_called_once_with(
|
||||
provider="test_provider",
|
||||
model="test_model",
|
||||
credentials={"api_key": "abc"},
|
||||
documents=[{"type": "text", "text": "hello"}],
|
||||
input_type=EmbeddingInputType.DOCUMENT,
|
||||
)
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_text_embedding.return_value = expected_result
|
||||
|
||||
result = text_embedding_model.invoke(
|
||||
model=model_name, credentials=credentials, texts=texts, multimodel_documents=multimodel_documents
|
||||
)
|
||||
def test_invoke_no_input(text_embedding_model: TextEmbeddingModel) -> None:
|
||||
with pytest.raises(ValueError, match="No texts or files provided"):
|
||||
text_embedding_model.invoke(model="test_model", credentials={"api_key": "abc"})
|
||||
|
||||
assert result == expected_result
|
||||
mock_client.invoke_text_embedding.assert_called_once()
|
||||
mock_client.invoke_multimodal_embedding.assert_not_called()
|
||||
|
||||
def test_invoke_exception(self, text_embedding_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
texts = ["hello"]
|
||||
def test_invoke_prefers_texts_over_multimodal_documents(
|
||||
text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock
|
||||
) -> None:
|
||||
expected_result = MagicMock(spec=EmbeddingResult)
|
||||
model_runtime.invoke_text_embedding.return_value = expected_result
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_text_embedding.side_effect = Exception("Test error")
|
||||
result = text_embedding_model.invoke(
|
||||
model="test_model",
|
||||
credentials={"api_key": "abc"},
|
||||
texts=["hello"],
|
||||
multimodel_documents=[{"type": "text", "text": "world"}],
|
||||
)
|
||||
|
||||
with pytest.raises(InvokeError) as excinfo:
|
||||
text_embedding_model.invoke(model=model_name, credentials=credentials, texts=texts)
|
||||
assert result == expected_result
|
||||
model_runtime.invoke_text_embedding.assert_called_once()
|
||||
model_runtime.invoke_multimodal_embedding.assert_not_called()
|
||||
|
||||
assert "[test_provider] Error: Test error" in str(excinfo.value.description)
|
||||
|
||||
def test_get_num_tokens(self, text_embedding_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
texts = ["hello", "world"]
|
||||
expected_tokens = [1, 1]
|
||||
def test_invoke_exception(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None:
|
||||
model_runtime.invoke_text_embedding.side_effect = Exception("Test error")
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.get_text_embedding_num_tokens.return_value = expected_tokens
|
||||
with pytest.raises(InvokeError, match="Test error"):
|
||||
text_embedding_model.invoke(model="test_model", credentials={"api_key": "abc"}, texts=["hello"])
|
||||
|
||||
result = text_embedding_model.get_num_tokens(model=model_name, credentials=credentials, texts=texts)
|
||||
|
||||
assert result == expected_tokens
|
||||
mock_client.get_text_embedding_num_tokens.assert_called_once_with(
|
||||
tenant_id="tenant_123",
|
||||
user_id="unknown",
|
||||
plugin_id="plugin_123",
|
||||
provider="test_provider",
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
texts=texts,
|
||||
)
|
||||
def test_get_num_tokens(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None:
|
||||
model_runtime.get_text_embedding_num_tokens.return_value = [1, 1]
|
||||
|
||||
def test_get_context_size(self, text_embedding_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
result = text_embedding_model.get_num_tokens(
|
||||
model="test_model", credentials={"api_key": "abc"}, texts=["hello", "world"]
|
||||
)
|
||||
|
||||
# Test case 1: Context size in schema
|
||||
mock_schema = MagicMock()
|
||||
mock_schema.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 2048}
|
||||
assert result == [1, 1]
|
||||
model_runtime.get_text_embedding_num_tokens.assert_called_once_with(
|
||||
provider="test_provider",
|
||||
model="test_model",
|
||||
credentials={"api_key": "abc"},
|
||||
texts=["hello", "world"],
|
||||
)
|
||||
|
||||
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
|
||||
assert text_embedding_model._get_context_size(model_name, credentials) == 2048
|
||||
|
||||
# Test case 2: No schema
|
||||
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None):
|
||||
assert text_embedding_model._get_context_size(model_name, credentials) == 1000
|
||||
def test_get_context_size(text_embedding_model: TextEmbeddingModel) -> None:
|
||||
mock_schema = MagicMock()
|
||||
mock_schema.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 2048}
|
||||
|
||||
# Test case 3: Context size NOT in schema properties
|
||||
mock_schema.model_properties = {}
|
||||
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
|
||||
assert text_embedding_model._get_context_size(model_name, credentials) == 1000
|
||||
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
|
||||
assert text_embedding_model._get_context_size("test_model", {"api_key": "abc"}) == 2048
|
||||
|
||||
def test_get_max_chunks(self, text_embedding_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None):
|
||||
assert text_embedding_model._get_context_size("test_model", {"api_key": "abc"}) == 1000
|
||||
|
||||
# Test case 1: Max chunks in schema
|
||||
mock_schema = MagicMock()
|
||||
mock_schema.model_properties = {ModelPropertyKey.MAX_CHUNKS: 10}
|
||||
mock_schema.model_properties = {}
|
||||
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
|
||||
assert text_embedding_model._get_context_size("test_model", {"api_key": "abc"}) == 1000
|
||||
|
||||
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
|
||||
assert text_embedding_model._get_max_chunks(model_name, credentials) == 10
|
||||
|
||||
# Test case 2: No schema
|
||||
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None):
|
||||
assert text_embedding_model._get_max_chunks(model_name, credentials) == 1
|
||||
def test_get_max_chunks(text_embedding_model: TextEmbeddingModel) -> None:
|
||||
mock_schema = MagicMock()
|
||||
mock_schema.model_properties = {ModelPropertyKey.MAX_CHUNKS: 10}
|
||||
|
||||
# Test case 3: Max chunks NOT in schema properties
|
||||
mock_schema.model_properties = {}
|
||||
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
|
||||
assert text_embedding_model._get_max_chunks(model_name, credentials) == 1
|
||||
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
|
||||
assert text_embedding_model._get_max_chunks("test_model", {"api_key": "abc"}) == 10
|
||||
|
||||
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None):
|
||||
assert text_embedding_model._get_max_chunks("test_model", {"api_key": "abc"}) == 1
|
||||
|
||||
mock_schema.model_properties = {}
|
||||
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
|
||||
assert text_embedding_model._get_max_chunks("test_model", {"api_key": "abc"}) == 1
|
||||
|
|
|
|||
|
|
@ -1,131 +1,83 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
|
||||
|
||||
class TestTTSModel:
|
||||
@pytest.fixture
|
||||
def mock_plugin_model_provider(self):
|
||||
return MagicMock(spec=PluginModelProviderEntity)
|
||||
@pytest.fixture
|
||||
def provider_schema() -> ProviderEntity:
|
||||
return ProviderEntity(
|
||||
provider="test_provider",
|
||||
label=I18nObject(en_US="test_provider"),
|
||||
supported_model_types=[ModelType.TTS],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def tts_model(self, mock_plugin_model_provider):
|
||||
return TTSModel(
|
||||
tenant_id="tenant_123",
|
||||
model_type=ModelType.TTS,
|
||||
plugin_id="plugin_123",
|
||||
provider_name="test_provider",
|
||||
plugin_model_provider=mock_plugin_model_provider,
|
||||
|
||||
@pytest.fixture
|
||||
def model_runtime() -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tts_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> TTSModel:
|
||||
return TTSModel(provider_schema=provider_schema, model_runtime=model_runtime)
|
||||
|
||||
|
||||
def test_model_type(tts_model: TTSModel) -> None:
|
||||
assert tts_model.model_type == ModelType.TTS
|
||||
|
||||
|
||||
def test_invoke_success(tts_model: TTSModel, model_runtime: MagicMock) -> None:
|
||||
model_runtime.invoke_tts.return_value = [b"audio_chunk"]
|
||||
|
||||
result = tts_model.invoke(
|
||||
model="test_model",
|
||||
credentials={"api_key": "abc"},
|
||||
content_text="Hello world",
|
||||
voice="alloy",
|
||||
)
|
||||
|
||||
assert list(result) == [b"audio_chunk"]
|
||||
model_runtime.invoke_tts.assert_called_once_with(
|
||||
provider="test_provider",
|
||||
model="test_model",
|
||||
credentials={"api_key": "abc"},
|
||||
content_text="Hello world",
|
||||
voice="alloy",
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_exception(tts_model: TTSModel, model_runtime: MagicMock) -> None:
|
||||
model_runtime.invoke_tts.side_effect = Exception("Test error")
|
||||
|
||||
with pytest.raises(InvokeError, match="Test error"):
|
||||
tts_model.invoke(
|
||||
model="test_model",
|
||||
credentials={"api_key": "abc"},
|
||||
content_text="Hello world",
|
||||
voice="alloy",
|
||||
)
|
||||
|
||||
def test_model_type(self, tts_model):
|
||||
assert tts_model.model_type == ModelType.TTS
|
||||
|
||||
def test_invoke_success(self, tts_model):
|
||||
model_name = "test_model"
|
||||
tenant_id = "ignored_tenant_id"
|
||||
credentials = {"api_key": "abc"}
|
||||
content_text = "Hello world"
|
||||
voice = "alloy"
|
||||
user = "user_123"
|
||||
def test_get_tts_model_voices(tts_model: TTSModel, model_runtime: MagicMock) -> None:
|
||||
model_runtime.get_tts_model_voices.return_value = [{"name": "Voice1"}]
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_tts.return_value = [b"audio_chunk"]
|
||||
result = tts_model.get_tts_model_voices(
|
||||
model="test_model",
|
||||
credentials={"api_key": "abc"},
|
||||
language="en-US",
|
||||
)
|
||||
|
||||
result = tts_model.invoke(
|
||||
model=model_name,
|
||||
tenant_id=tenant_id,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
user=user,
|
||||
)
|
||||
|
||||
assert list(result) == [b"audio_chunk"]
|
||||
mock_client.invoke_tts.assert_called_once_with(
|
||||
tenant_id="tenant_123",
|
||||
user_id="user_123",
|
||||
plugin_id="plugin_123",
|
||||
provider="test_provider",
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
def test_invoke_success_no_user(self, tts_model):
|
||||
model_name = "test_model"
|
||||
tenant_id = "ignored_tenant_id"
|
||||
credentials = {"api_key": "abc"}
|
||||
content_text = "Hello world"
|
||||
voice = "alloy"
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_tts.return_value = [b"audio_chunk"]
|
||||
|
||||
result = tts_model.invoke(
|
||||
model=model_name, tenant_id=tenant_id, credentials=credentials, content_text=content_text, voice=voice
|
||||
)
|
||||
|
||||
assert list(result) == [b"audio_chunk"]
|
||||
mock_client.invoke_tts.assert_called_once_with(
|
||||
tenant_id="tenant_123",
|
||||
user_id="unknown",
|
||||
plugin_id="plugin_123",
|
||||
provider="test_provider",
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
def test_invoke_exception(self, tts_model):
|
||||
model_name = "test_model"
|
||||
tenant_id = "ignored_tenant_id"
|
||||
credentials = {"api_key": "abc"}
|
||||
content_text = "Hello world"
|
||||
voice = "alloy"
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.invoke_tts.side_effect = Exception("Test error")
|
||||
|
||||
with pytest.raises(InvokeError) as excinfo:
|
||||
tts_model.invoke(
|
||||
model=model_name,
|
||||
tenant_id=tenant_id,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
assert "[test_provider] Error: Test error" in str(excinfo.value.description)
|
||||
|
||||
def test_get_tts_model_voices(self, tts_model):
|
||||
model_name = "test_model"
|
||||
credentials = {"api_key": "abc"}
|
||||
language = "en-US"
|
||||
|
||||
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
|
||||
mock_client = mock_client_class.return_value
|
||||
mock_client.get_tts_model_voices.return_value = [{"name": "Voice1"}]
|
||||
|
||||
result = tts_model.get_tts_model_voices(model=model_name, credentials=credentials, language=language)
|
||||
|
||||
assert result == [{"name": "Voice1"}]
|
||||
mock_client.get_tts_model_voices.assert_called_once_with(
|
||||
tenant_id="tenant_123",
|
||||
user_id="unknown",
|
||||
plugin_id="plugin_123",
|
||||
provider="test_provider",
|
||||
model=model_name,
|
||||
credentials=credentials,
|
||||
language=language,
|
||||
)
|
||||
assert result == [{"name": "Voice1"}]
|
||||
model_runtime.get_tts_model_voices.assert_called_once_with(
|
||||
provider="test_provider",
|
||||
model="test_model",
|
||||
credentials={"api_key": "abc"},
|
||||
language="en-US",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,522 +0,0 @@
|
|||
import logging
|
||||
from datetime import datetime
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from redis import RedisError
|
||||
|
||||
import contexts
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
|
||||
|
||||
def _provider_entity(
|
||||
*,
|
||||
provider: str,
|
||||
supported_model_types: list[ModelType] | None = None,
|
||||
models: list[AIModelEntity] | None = None,
|
||||
icon_small: I18nObject | None = None,
|
||||
icon_small_dark: I18nObject | None = None,
|
||||
) -> ProviderEntity:
|
||||
return ProviderEntity(
|
||||
provider=provider,
|
||||
label=I18nObject(en_US=provider),
|
||||
supported_model_types=supported_model_types or [ModelType.LLM],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
models=models or [],
|
||||
icon_small=icon_small,
|
||||
icon_small_dark=icon_small_dark,
|
||||
)
|
||||
|
||||
|
||||
def _plugin_provider(
|
||||
*, plugin_id: str, declaration: ProviderEntity, provider: str = "provider"
|
||||
) -> PluginModelProviderEntity:
|
||||
return PluginModelProviderEntity.model_construct(
|
||||
id=f"{plugin_id}-id",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
provider=provider,
|
||||
tenant_id="tenant",
|
||||
plugin_unique_identifier=f"{plugin_id}-uid",
|
||||
plugin_id=plugin_id,
|
||||
declaration=declaration,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_plugin_model_provider_context() -> None:
|
||||
contexts.plugin_model_providers_lock.set(Lock())
|
||||
contexts.plugin_model_providers.set(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_plugin_manager(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
|
||||
manager = MagicMock()
|
||||
|
||||
import core.plugin.impl.model as plugin_model_module
|
||||
|
||||
monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: manager)
|
||||
return manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def factory(fake_plugin_manager: MagicMock) -> ModelProviderFactory:
|
||||
return ModelProviderFactory(tenant_id="tenant")
|
||||
|
||||
|
||||
def test_get_plugin_model_providers_initializes_context_on_lookup_error(
|
||||
factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
declaration = _provider_entity(provider="openai")
|
||||
fake_plugin_manager.fetch_model_providers.return_value = [
|
||||
_plugin_provider(plugin_id="langgenius/openai", declaration=declaration)
|
||||
]
|
||||
|
||||
original_get = contexts.plugin_model_providers.get
|
||||
calls = {"n": 0}
|
||||
|
||||
def flaky_get() -> Any:
|
||||
calls["n"] += 1
|
||||
if calls["n"] == 1:
|
||||
raise LookupError
|
||||
return original_get()
|
||||
|
||||
monkeypatch.setattr(contexts.plugin_model_providers, "get", flaky_get)
|
||||
|
||||
providers = factory.get_plugin_model_providers()
|
||||
assert len(providers) == 1
|
||||
assert providers[0].declaration.provider == "langgenius/openai/openai"
|
||||
|
||||
|
||||
def test_get_plugin_model_providers_caches_and_does_not_refetch(
|
||||
factory: ModelProviderFactory, fake_plugin_manager: MagicMock
|
||||
) -> None:
|
||||
declaration = _provider_entity(provider="openai")
|
||||
fake_plugin_manager.fetch_model_providers.return_value = [
|
||||
_plugin_provider(plugin_id="langgenius/openai", declaration=declaration)
|
||||
]
|
||||
|
||||
first = factory.get_plugin_model_providers()
|
||||
second = factory.get_plugin_model_providers()
|
||||
|
||||
assert first is second
|
||||
fake_plugin_manager.fetch_model_providers.assert_called_once_with("tenant")
|
||||
|
||||
|
||||
def test_get_providers_returns_declarations(factory: ModelProviderFactory, fake_plugin_manager: MagicMock) -> None:
|
||||
d1 = _provider_entity(provider="openai")
|
||||
d2 = _provider_entity(provider="anthropic")
|
||||
fake_plugin_manager.fetch_model_providers.return_value = [
|
||||
_plugin_provider(plugin_id="langgenius/openai", declaration=d1),
|
||||
_plugin_provider(plugin_id="langgenius/anthropic", declaration=d2),
|
||||
]
|
||||
|
||||
providers = factory.get_providers()
|
||||
assert [p.provider for p in providers] == ["langgenius/openai/openai", "langgenius/anthropic/anthropic"]
|
||||
|
||||
|
||||
def test_get_plugin_model_provider_converts_short_provider_id(
|
||||
factory: ModelProviderFactory, fake_plugin_manager: MagicMock
|
||||
) -> None:
|
||||
declaration = _provider_entity(provider="openai")
|
||||
fake_plugin_manager.fetch_model_providers.return_value = [
|
||||
_plugin_provider(plugin_id="langgenius/openai", declaration=declaration)
|
||||
]
|
||||
|
||||
provider = factory.get_plugin_model_provider("openai")
|
||||
assert provider.declaration.provider == "langgenius/openai/openai"
|
||||
|
||||
|
||||
def test_get_plugin_model_provider_raises_on_invalid_provider(
|
||||
factory: ModelProviderFactory, fake_plugin_manager: MagicMock
|
||||
) -> None:
|
||||
declaration = _provider_entity(provider="openai")
|
||||
fake_plugin_manager.fetch_model_providers.return_value = [
|
||||
_plugin_provider(plugin_id="langgenius/openai", declaration=declaration)
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid provider"):
|
||||
factory.get_plugin_model_provider("langgenius/unknown/unknown")
|
||||
|
||||
|
||||
def test_get_provider_schema_returns_declaration(factory: ModelProviderFactory, fake_plugin_manager: MagicMock) -> None:
|
||||
declaration = _provider_entity(provider="openai")
|
||||
fake_plugin_manager.fetch_model_providers.return_value = [
|
||||
_plugin_provider(plugin_id="langgenius/openai", declaration=declaration)
|
||||
]
|
||||
|
||||
schema = factory.get_provider_schema("openai")
|
||||
assert schema.provider == "langgenius/openai/openai"
|
||||
|
||||
|
||||
def test_provider_credentials_validate_errors_when_schema_missing(
|
||||
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
schema = _provider_entity(provider="openai")
|
||||
schema.provider_credential_schema = None
|
||||
monkeypatch.setattr(
|
||||
factory,
|
||||
"get_plugin_model_provider",
|
||||
lambda **_: _plugin_provider(plugin_id="langgenius/openai", declaration=schema),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="does not have provider_credential_schema"):
|
||||
factory.provider_credentials_validate(provider="openai", credentials={"x": "y"})
|
||||
|
||||
|
||||
def test_provider_credentials_validate_filters_and_calls_plugin_validation(
|
||||
factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
schema = _provider_entity(provider="openai")
|
||||
schema.provider_credential_schema = MagicMock()
|
||||
plugin_provider = _plugin_provider(plugin_id="langgenius/openai", declaration=schema)
|
||||
monkeypatch.setattr(factory, "get_plugin_model_provider", lambda **_: plugin_provider)
|
||||
|
||||
fake_validator = MagicMock()
|
||||
fake_validator.validate_and_filter.return_value = {"filtered": True}
|
||||
monkeypatch.setattr(
|
||||
"dify_graph.model_runtime.model_providers.model_provider_factory.ProviderCredentialSchemaValidator",
|
||||
lambda _: fake_validator,
|
||||
)
|
||||
|
||||
filtered = factory.provider_credentials_validate(provider="openai", credentials={"raw": True})
|
||||
assert filtered == {"filtered": True}
|
||||
fake_plugin_manager.validate_provider_credentials.assert_called_once()
|
||||
kwargs = fake_plugin_manager.validate_provider_credentials.call_args.kwargs
|
||||
assert kwargs["plugin_id"] == "langgenius/openai"
|
||||
assert kwargs["provider"] == "provider"
|
||||
assert kwargs["credentials"] == {"filtered": True}
|
||||
|
||||
|
||||
def test_model_credentials_validate_errors_when_schema_missing(
|
||||
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
schema = _provider_entity(provider="openai")
|
||||
schema.model_credential_schema = None
|
||||
monkeypatch.setattr(
|
||||
factory,
|
||||
"get_plugin_model_provider",
|
||||
lambda **_: _plugin_provider(plugin_id="langgenius/openai", declaration=schema),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="does not have model_credential_schema"):
|
||||
factory.model_credentials_validate(
|
||||
provider="openai", model_type=ModelType.LLM, model="m", credentials={"x": "y"}
|
||||
)
|
||||
|
||||
|
||||
def test_model_credentials_validate_filters_and_calls_plugin_validation(
|
||||
factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
schema = _provider_entity(provider="openai")
|
||||
schema.model_credential_schema = MagicMock()
|
||||
plugin_provider = _plugin_provider(plugin_id="langgenius/openai", declaration=schema)
|
||||
monkeypatch.setattr(factory, "get_plugin_model_provider", lambda **_: plugin_provider)
|
||||
|
||||
fake_validator = MagicMock()
|
||||
fake_validator.validate_and_filter.return_value = {"filtered": True}
|
||||
monkeypatch.setattr(
|
||||
"dify_graph.model_runtime.model_providers.model_provider_factory.ModelCredentialSchemaValidator",
|
||||
lambda *_: fake_validator,
|
||||
)
|
||||
|
||||
filtered = factory.model_credentials_validate(
|
||||
provider="openai", model_type=ModelType.TEXT_EMBEDDING, model="m", credentials={"raw": True}
|
||||
)
|
||||
assert filtered == {"filtered": True}
|
||||
kwargs = fake_plugin_manager.validate_model_credentials.call_args.kwargs
|
||||
assert kwargs["plugin_id"] == "langgenius/openai"
|
||||
assert kwargs["provider"] == "provider"
|
||||
assert kwargs["model_type"] == "text-embedding"
|
||||
assert kwargs["model"] == "m"
|
||||
assert kwargs["credentials"] == {"filtered": True}
|
||||
|
||||
|
||||
def test_get_model_schema_cache_hit(factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
model_schema = AIModelEntity(
|
||||
model="m",
|
||||
label=I18nObject(en_US="m"),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
|
||||
parameter_rules=[],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov"))
|
||||
|
||||
with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis:
|
||||
mock_redis.get.return_value = model_schema.model_dump_json().encode()
|
||||
assert (
|
||||
factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials={"k": "v"})
|
||||
== model_schema
|
||||
)
|
||||
|
||||
|
||||
def test_get_model_schema_cache_invalid_json_deletes_key(
|
||||
factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
caplog.set_level(logging.WARNING)
|
||||
|
||||
with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis:
|
||||
mock_redis.get.return_value = b'{"model":"m"}'
|
||||
factory.plugin_model_manager.get_model_schema.return_value = None
|
||||
factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign]
|
||||
assert factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) is None
|
||||
assert mock_redis.delete.called
|
||||
assert any("Failed to validate cached plugin model schema" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_get_model_schema_cache_delete_redis_error_is_logged(
|
||||
factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
caplog.set_level(logging.WARNING)
|
||||
|
||||
with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis:
|
||||
mock_redis.get.return_value = b'{"model":"m"}'
|
||||
mock_redis.delete.side_effect = RedisError("nope")
|
||||
factory.plugin_model_manager.get_model_schema.return_value = None
|
||||
factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign]
|
||||
factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None)
|
||||
assert any("Failed to delete invalid plugin model schema cache" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_get_model_schema_redis_get_error_falls_back_to_plugin(
|
||||
factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
caplog.set_level(logging.WARNING)
|
||||
factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign]
|
||||
factory.plugin_model_manager.get_model_schema.return_value = None
|
||||
|
||||
with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis:
|
||||
mock_redis.get.side_effect = RedisError("down")
|
||||
assert factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) is None
|
||||
assert any("Failed to read plugin model schema cache" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_get_model_schema_cache_miss_sets_cache_and_handles_setex_error(
|
||||
factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
caplog.set_level(logging.WARNING)
|
||||
factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign]
|
||||
|
||||
model_schema = AIModelEntity(
|
||||
model="m",
|
||||
label=I18nObject(en_US="m"),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
|
||||
parameter_rules=[],
|
||||
)
|
||||
factory.plugin_model_manager.get_model_schema.return_value = model_schema
|
||||
|
||||
with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis:
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.setex.side_effect = RedisError("nope")
|
||||
assert (
|
||||
factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None)
|
||||
== model_schema
|
||||
)
|
||||
assert any("Failed to write plugin model schema cache" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_type", "expected_class"),
|
||||
[
|
||||
(ModelType.LLM, "LargeLanguageModel"),
|
||||
(ModelType.TEXT_EMBEDDING, "TextEmbeddingModel"),
|
||||
(ModelType.RERANK, "RerankModel"),
|
||||
(ModelType.SPEECH2TEXT, "Speech2TextModel"),
|
||||
(ModelType.MODERATION, "ModerationModel"),
|
||||
(ModelType.TTS, "TTSModel"),
|
||||
],
|
||||
)
|
||||
def test_get_model_type_instance_dispatches_by_type(
|
||||
factory: ModelProviderFactory, model_type: ModelType, expected_class: str, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov"))
|
||||
monkeypatch.setattr(factory, "get_plugin_model_provider", lambda *_: MagicMock(spec=PluginModelProviderEntity))
|
||||
|
||||
sentinel = object()
|
||||
monkeypatch.setattr(
|
||||
f"dify_graph.model_runtime.model_providers.model_provider_factory.{expected_class}",
|
||||
MagicMock(model_validate=lambda _: sentinel),
|
||||
)
|
||||
|
||||
assert factory.get_model_type_instance("langgenius/openai/openai", model_type) is sentinel
|
||||
|
||||
|
||||
def test_get_model_type_instance_raises_on_unsupported(
|
||||
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov"))
|
||||
monkeypatch.setattr(factory, "get_plugin_model_provider", lambda *_: MagicMock(spec=PluginModelProviderEntity))
|
||||
|
||||
class UnknownModelType:
|
||||
pass
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported model type"):
|
||||
factory.get_model_type_instance("langgenius/openai/openai", UnknownModelType()) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_get_models_filters_by_provider_and_model_type(
|
||||
factory: ModelProviderFactory, fake_plugin_manager: MagicMock
|
||||
) -> None:
|
||||
llm = AIModelEntity(
|
||||
model="m1",
|
||||
label=I18nObject(en_US="m1"),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
|
||||
parameter_rules=[],
|
||||
)
|
||||
embed = AIModelEntity(
|
||||
model="e1",
|
||||
label=I18nObject(en_US="e1"),
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
|
||||
parameter_rules=[],
|
||||
)
|
||||
|
||||
openai = _provider_entity(
|
||||
provider="openai", supported_model_types=[ModelType.LLM, ModelType.TEXT_EMBEDDING], models=[llm, embed]
|
||||
)
|
||||
anthropic = _provider_entity(provider="anthropic", supported_model_types=[ModelType.LLM], models=[llm])
|
||||
fake_plugin_manager.fetch_model_providers.return_value = [
|
||||
_plugin_provider(plugin_id="langgenius/openai", declaration=openai),
|
||||
_plugin_provider(plugin_id="langgenius/anthropic", declaration=anthropic),
|
||||
]
|
||||
|
||||
# ModelType filter picks only matching models
|
||||
providers = factory.get_models(model_type=ModelType.TEXT_EMBEDDING)
|
||||
assert len(providers) == 1
|
||||
assert providers[0].provider == "langgenius/openai/openai"
|
||||
assert [m.model for m in providers[0].models] == ["e1"]
|
||||
|
||||
# Provider filter excludes others
|
||||
providers = factory.get_models(provider="langgenius/anthropic/anthropic", model_type=ModelType.LLM)
|
||||
assert len(providers) == 1
|
||||
assert providers[0].provider == "langgenius/anthropic/anthropic"
|
||||
|
||||
|
||||
def test_get_models_provider_filter_skips_non_matching(
|
||||
factory: ModelProviderFactory, fake_plugin_manager: MagicMock
|
||||
) -> None:
|
||||
openai = _provider_entity(provider="openai")
|
||||
anthropic = _provider_entity(provider="anthropic")
|
||||
fake_plugin_manager.fetch_model_providers.return_value = [
|
||||
_plugin_provider(plugin_id="langgenius/openai", declaration=openai),
|
||||
_plugin_provider(plugin_id="langgenius/anthropic", declaration=anthropic),
|
||||
]
|
||||
|
||||
providers = factory.get_models(provider="langgenius/not-exist/not-exist", model_type=ModelType.LLM)
|
||||
assert providers == []
|
||||
|
||||
|
||||
def test_get_provider_icon_fetches_asset_and_returns_mime_type(
|
||||
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
provider_schema = _provider_entity(
|
||||
provider="langgenius/openai/openai",
|
||||
icon_small=I18nObject(en_US="icon.png", zh_Hans="icon-zh.png"),
|
||||
icon_small_dark=I18nObject(en_US="dark.svg", zh_Hans="dark-zh.svg"),
|
||||
)
|
||||
monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema)
|
||||
|
||||
class FakePluginAssetManager:
|
||||
def fetch_asset(self, tenant_id: str, id: str) -> bytes:
|
||||
assert tenant_id == "tenant"
|
||||
return f"bytes:{id}".encode()
|
||||
|
||||
import core.plugin.impl.asset as asset_module
|
||||
|
||||
monkeypatch.setattr(asset_module, "PluginAssetManager", FakePluginAssetManager)
|
||||
|
||||
data, mime = factory.get_provider_icon("openai", "icon_small", "en_US")
|
||||
assert data == b"bytes:icon.png"
|
||||
assert mime == "image/png"
|
||||
|
||||
data, mime = factory.get_provider_icon("openai", "icon_small_dark", "zh_Hans")
|
||||
assert data == b"bytes:dark-zh.svg"
|
||||
assert mime == "image/svg+xml"
|
||||
|
||||
|
||||
def test_get_provider_icon_uses_zh_hans_for_small_and_en_us_for_dark(
|
||||
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
provider_schema = _provider_entity(
|
||||
provider="langgenius/openai/openai",
|
||||
icon_small=I18nObject(en_US="icon-en.png", zh_Hans="icon-zh.png"),
|
||||
icon_small_dark=I18nObject(en_US="dark-en.svg", zh_Hans="dark-zh.svg"),
|
||||
)
|
||||
monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema)
|
||||
|
||||
class FakePluginAssetManager:
|
||||
def fetch_asset(self, tenant_id: str, id: str) -> bytes:
|
||||
return id.encode()
|
||||
|
||||
import core.plugin.impl.asset as asset_module
|
||||
|
||||
monkeypatch.setattr(asset_module, "PluginAssetManager", FakePluginAssetManager)
|
||||
|
||||
data, _ = factory.get_provider_icon("openai", "icon_small", "zh_Hans")
|
||||
assert data == b"icon-zh.png"
|
||||
|
||||
data, _ = factory.get_provider_icon("openai", "icon_small_dark", "en_US")
|
||||
assert data == b"dark-en.svg"
|
||||
|
||||
|
||||
def test_get_provider_icon_raises_for_missing_icons(
|
||||
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
provider_schema = _provider_entity(provider="langgenius/openai/openai")
|
||||
monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema)
|
||||
|
||||
with pytest.raises(ValueError, match="does not have small icon"):
|
||||
factory.get_provider_icon("openai", "icon_small", "en_US")
|
||||
|
||||
with pytest.raises(ValueError, match="does not have small dark icon"):
|
||||
factory.get_provider_icon("openai", "icon_small_dark", "en_US")
|
||||
|
||||
|
||||
def test_get_provider_icon_raises_for_unsupported_icon_type(
|
||||
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
provider_schema = _provider_entity(
|
||||
provider="langgenius/openai/openai",
|
||||
icon_small=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema)
|
||||
with pytest.raises(ValueError, match="Unsupported icon type"):
|
||||
factory.get_provider_icon("openai", "nope", "en_US")
|
||||
|
||||
|
||||
def test_get_provider_icon_raises_when_file_name_missing(
|
||||
factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
provider_schema = _provider_entity(
|
||||
provider="langgenius/openai/openai",
|
||||
icon_small=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema)
|
||||
with pytest.raises(ValueError, match="does not have icon"):
|
||||
factory.get_provider_icon("openai", "icon_small", "en_US")
|
||||
|
||||
|
||||
def test_get_plugin_id_and_provider_name_from_provider_handles_google_special_case(
|
||||
factory: ModelProviderFactory,
|
||||
) -> None:
|
||||
plugin_id, provider_name = factory.get_plugin_id_and_provider_name_from_provider("google")
|
||||
assert plugin_id == "langgenius/gemini"
|
||||
assert provider_name == "google"
|
||||
|
|
@ -336,7 +336,6 @@ def test_structured_output_parser():
|
|||
json_schema=case["json_schema"],
|
||||
stream=case["stream"],
|
||||
model_parameters={"temperature": 0.7, "max_tokens": 100},
|
||||
user="test_user",
|
||||
)
|
||||
|
||||
if case["expected_result_type"] == "generator":
|
||||
|
|
@ -367,7 +366,7 @@ def test_structured_output_parser():
|
|||
call_args = model_instance.invoke_llm.call_args
|
||||
|
||||
assert call_args.kwargs["stream"] == case["stream"]
|
||||
assert call_args.kwargs["user"] == "test_user"
|
||||
assert "user" not in call_args.kwargs
|
||||
assert "temperature" in call_args.kwargs["model_parameters"]
|
||||
assert "max_tokens" in call_args.kwargs["model_parameters"]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue