refactor: fix decoupled runtime CI regressions

This commit is contained in:
-LAN- 2026-03-16 05:33:49 +08:00
parent 23aca7f567
commit fb113bf3a4
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
26 changed files with 428 additions and 1157 deletions

View File

@ -223,7 +223,6 @@ class ChatAppRunner(AppRunner):
model_parameters=application_generate_entity.model_conf.parameters,
stop=stop,
stream=application_generate_entity.stream,
user=application_generate_entity.user_id,
)
# handle invoke result

View File

@ -181,7 +181,6 @@ class CompletionAppRunner(AppRunner):
model_parameters=application_generate_entity.model_conf.parameters,
stop=stop,
stream=application_generate_entity.stream,
user=application_generate_entity.user_id,
)
# handle invoke result

View File

@ -115,11 +115,11 @@ class LLMQuotaLayer(GraphEngineLayer):
try:
match node.node_type:
case BuiltinNodeTypes.LLM:
return cast("LLMNode", node).model_instance
model_instance = cast("LLMNode", node).model_instance
case BuiltinNodeTypes.PARAMETER_EXTRACTOR:
return cast("ParameterExtractorNode", node).model_instance
model_instance = cast("ParameterExtractorNode", node).model_instance
case BuiltinNodeTypes.QUESTION_CLASSIFIER:
return cast("QuestionClassifierNode", node).model_instance
model_instance = cast("QuestionClassifierNode", node).model_instance
case _:
return None
except AttributeError:
@ -128,3 +128,12 @@ class LLMQuotaLayer(GraphEngineLayer):
node.id,
)
return None
if isinstance(model_instance, ModelInstance):
return model_instance
raw_model_instance = getattr(model_instance, "_model_instance", None)
if isinstance(raw_model_instance, ModelInstance):
return raw_model_instance
return None

View File

@ -26,7 +26,7 @@ class CacheEmbedding(Embeddings):
@staticmethod
def _bind_model_instance(model_instance: ModelInstance, user: str | None) -> ModelInstance:
if user is None:
if user is None or not isinstance(model_instance, ModelInstance):
return model_instance
tenant_id = model_instance.provider_model_bundle.configuration.tenant_id

View File

@ -5,7 +5,7 @@ import time
from collections.abc import Generator, Mapping
from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast
from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, TypedDict, Union, cast
import sqlalchemy as sa
from sqlalchemy import select
@ -31,7 +31,7 @@ from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService
if TYPE_CHECKING:
from dify_graph.nodes.tool.entities import ToolEntity
pass
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
@ -62,7 +62,7 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi
from services.tools.tools_transform_service import ToolTransformService
if TYPE_CHECKING:
from dify_graph.nodes.tool.entities import ToolEntity
pass
logger = logging.getLogger(__name__)
@ -77,6 +77,14 @@ class EmojiIconDict(TypedDict):
content: str
class WorkflowToolRuntimeSpec(Protocol):
    # Structural (duck-typed) interface for workflow-tool runtime configuration
    # objects: any object exposing these attributes satisfies the Protocol, no
    # inheritance required. Used as the parameter type for ToolManager helpers
    # that previously required the concrete ToolEntity (see the @ -405 hunk).
    provider_type: ToolProviderType  # which tool provider family this tool belongs to
    provider_id: str  # identifier of the concrete provider instance
    tool_name: str  # name of the tool within the provider
    tool_configurations: dict[str, Any]  # per-tool parameter/config values
    credential_id: str | None  # optional credential reference; None means default/no credential
class ToolManager:
_builtin_provider_lock = Lock()
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
@ -405,7 +413,7 @@ class ToolManager:
tenant_id: str,
app_id: str,
node_id: str,
workflow_tool: "ToolEntity",
workflow_tool: WorkflowToolRuntimeSpec,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional["VariablePool"] = None,
) -> Tool:

View File

@ -343,9 +343,9 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
*,
provider_name: str,
default_icon: str | None = None,
) -> tuple[str | None, str | None]:
icon = default_icon
icon_dark = None
) -> tuple[str | Mapping[str, str] | None, str | Mapping[str, str] | None]:
icon: str | Mapping[str, str] | None = default_icon
icon_dark: str | Mapping[str, str] | None = None
manager = PluginInstaller()
plugins = manager.list_plugins(self._run_context.tenant_id)

View File

@ -1,5 +1,6 @@
from .constants import FILE_MODEL_IDENTITY
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
from .file_factory import get_file_type_by_mime_type, standardize_file_type
from .models import (
File,
FileUploadConfig,
@ -16,4 +17,6 @@ __all__ = [
"FileType",
"FileUploadConfig",
"ImageConfig",
"get_file_type_by_mime_type",
"standardize_file_type",
]

View File

@ -0,0 +1,40 @@
from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
from .enums import FileType
def standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType:
    """Determine a file's FileType, preferring the extension over the MIME type.

    The extension is consulted first; only when it yields no match is the MIME
    type inspected. Falls back to ``FileType.CUSTOM`` when neither source
    produces a classification.
    """
    inferred: FileType | None = None
    if extension:
        inferred = _get_file_type_by_extension(extension)
    if inferred is None and mime_type:
        inferred = get_file_type_by_mime_type(mime_type)
    return inferred if inferred is not None else FileType.CUSTOM
def _get_file_type_by_extension(extension: str) -> FileType | None:
    """Map a filename extension (with or without a leading dot) to a FileType.

    Returns ``None`` when the extension is not in any known category.
    """
    stripped = extension.lstrip(".")
    # Checked in the same order as the original if/elif chain.
    for known_extensions, file_type in (
        (IMAGE_EXTENSIONS, FileType.IMAGE),
        (VIDEO_EXTENSIONS, FileType.VIDEO),
        (AUDIO_EXTENSIONS, FileType.AUDIO),
        (DOCUMENT_EXTENSIONS, FileType.DOCUMENT),
    ):
        if stripped in known_extensions:
            return file_type
    return None
def get_file_type_by_mime_type(mime_type: str) -> FileType:
    """Classify a MIME type string into a FileType.

    Matching is substring-based; anything unrecognized becomes
    ``FileType.CUSTOM``.
    """
    for needle, file_type in (
        ("image", FileType.IMAGE),
        ("video", FileType.VIDEO),
        ("audio", FileType.AUDIO),
    ):
        if needle in mime_type:
            return file_type
    # Documents match on either of two markers, so they are handled separately.
    if "text" in mime_type or "pdf" in mime_type:
        return FileType.DOCUMENT
    return FileType.CUSTOM

View File

@ -60,7 +60,7 @@ class ToolNodeRuntimeProtocol(Protocol):
*,
provider_name: str,
default_icon: str | None = None,
) -> tuple[str | None, str | None]: ...
) -> tuple[str | Mapping[str, str] | None, str | Mapping[str, str] | None]: ...
class HumanInputNodeRuntimeProtocol(Protocol):

View File

@ -8,7 +8,7 @@ from dify_graph.enums import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.file import File, FileTransferMethod
from dify_graph.file import File, FileTransferMethod, get_file_type_by_mime_type
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from dify_graph.nodes.base.node import Node
@ -252,9 +252,9 @@ class ToolNode(Node[ToolNodeData]):
if not tool_file:
raise ToolFileError(f"tool file {tool_file_id} not found")
mapping = {
mapping: dict[str, Any] = {
"tool_file_id": tool_file_id,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"type": get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
@ -270,7 +270,7 @@ class ToolNode(Node[ToolNodeData]):
if not tool_file:
raise ToolFileError(f"tool file {tool_file_id} not exists")
mapping = {
mapping: dict[str, Any] = {
"tool_file_id": tool_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}

View File

@ -1,5 +1,6 @@
import logging
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from dify_graph.nodes import BuiltinNodeTypes
@ -19,8 +20,9 @@ def handle(sender, **kwargs):
if node_data.get("data", {}).get("type") == BuiltinNodeTypes.TOOL:
try:
tool_entity = ToolEntity.model_validate(node_data["data"])
provider_type = ToolProviderType(tool_entity.provider_type.value)
tool_runtime = ToolManager.get_tool_runtime(
provider_type=tool_entity.provider_type,
provider_type=provider_type,
provider_id=tool_entity.provider_id,
tool_name=tool_entity.tool_name,
tenant_id=app.tenant_id,
@ -30,7 +32,7 @@ def handle(sender, **kwargs):
tenant_id=app.tenant_id,
tool_runtime=tool_runtime,
provider_name=tool_entity.provider_name,
provider_type=tool_entity.provider_type,
provider_type=provider_type,
identity_id=f"WORKFLOW.{app.id}.{node_data.get('id')}",
)
manager.delete_tool_parameters_cache()

View File

@ -12,9 +12,9 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.http import parse_options_header
from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
from core.helper import ssrf_proxy
from dify_graph.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers
from dify_graph.file.file_factory import standardize_file_type
from extensions.ext_database import db
from models import MessageFile, ToolFile, UploadFile
@ -175,7 +175,7 @@ def _build_from_local_file(
if row is None:
raise ValueError("Invalid upload file")
detected_file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
specified_type = mapping.get("type", "custom")
if strict_type_validation and detected_file_type.value != specified_type:
@ -223,7 +223,7 @@ def _build_from_remote_url(
if upload_file is None:
raise ValueError("Invalid upload file")
detected_file_type = _standardize_file_type(
detected_file_type = standardize_file_type(
extension="." + upload_file.extension, mime_type=upload_file.mime_type
)
@ -257,7 +257,7 @@ def _build_from_remote_url(
mime_type, filename, file_size = _get_remote_file_info(url)
extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin")
detected_file_type = _standardize_file_type(extension=extension, mime_type=mime_type)
detected_file_type = standardize_file_type(extension=extension, mime_type=mime_type)
specified_type = mapping.get("type")
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
@ -387,7 +387,7 @@ def _build_from_tool_file(
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
detected_file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
specified_type = mapping.get("type")
@ -436,7 +436,7 @@ def _build_from_datasource_file(
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
detected_file_type = _standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
specified_type = mapping.get("type")
@ -503,49 +503,6 @@ def _is_file_valid_with_config(
return True
def _standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType:
"""
Infer the possible actual type of the file based on the extension and mime_type
"""
guessed_type = None
if extension:
guessed_type = _get_file_type_by_extension(extension)
if guessed_type is None and mime_type:
guessed_type = _get_file_type_by_mimetype(mime_type)
return guessed_type or FileType.CUSTOM
def _get_file_type_by_extension(extension: str) -> FileType | None:
extension = extension.lstrip(".")
if extension in IMAGE_EXTENSIONS:
return FileType.IMAGE
elif extension in VIDEO_EXTENSIONS:
return FileType.VIDEO
elif extension in AUDIO_EXTENSIONS:
return FileType.AUDIO
elif extension in DOCUMENT_EXTENSIONS:
return FileType.DOCUMENT
return None
def _get_file_type_by_mimetype(mime_type: str) -> FileType | None:
if "image" in mime_type:
file_type = FileType.IMAGE
elif "video" in mime_type:
file_type = FileType.VIDEO
elif "audio" in mime_type:
file_type = FileType.AUDIO
elif "text" in mime_type or "pdf" in mime_type:
file_type = FileType.DOCUMENT
else:
file_type = FileType.CUSTOM
return file_type
def get_file_type_by_mime_type(mime_type: str) -> FileType:
return _get_file_type_by_mimetype(mime_type) or FileType.CUSTOM
class StorageKeyLoader:
"""FileKeyLoader load the storage key from database for a list of files.
This loader is batched, the database query count is constant regardless of the input size.

View File

@ -124,11 +124,19 @@ class AppService:
"completion_params": {},
}
else:
provider, model = model_manager.get_default_provider_model_name(
tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM
)
default_model_config["model"]["provider"] = provider
default_model_config["model"]["name"] = model
try:
provider, model = model_manager.get_default_provider_model_name(
tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM
)
except Exception:
logger.exception("Get default provider model failed, tenant_id: %s", tenant_id)
provider = default_model_config["model"].get("provider")
model = default_model_config["model"].get("name")
if provider:
default_model_config["model"]["provider"] = provider
if model:
default_model_config["model"]["name"] = model
default_model_dict = default_model_config["model"]
default_model_config["model"] = json.dumps(default_model_dict)

View File

@ -416,7 +416,7 @@ class TestDatasetApiGet:
"check_dataset_permission",
return_value=None,
),
patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock,
patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock,
):
# embedding models exist → embedding_available stays True
provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = []
@ -520,7 +520,7 @@ class TestDatasetApiGet:
"check_dataset_permission",
return_value=None,
),
patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock,
patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock,
):
# embedding model NOT configured
provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = []
@ -579,7 +579,7 @@ class TestDatasetApiGet:
"get_dataset_partial_member_list",
return_value=partial_members,
),
patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock,
patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock,
):
provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = []

View File

@ -941,11 +941,11 @@ class TestDatasetListApiGet:
"""Test suite for DatasetListApi.get() endpoint.
``get`` has no billing decorators but calls ``current_user``,
``DatasetService``, ``ProviderManager``, and ``marshal``.
``DatasetService``, ``create_plugin_provider_manager``, and ``marshal``.
"""
@patch("controllers.service_api.dataset.dataset.marshal")
@patch("controllers.service_api.dataset.dataset.ProviderManager")
@patch("controllers.service_api.dataset.dataset.create_plugin_provider_manager")
@patch("controllers.service_api.dataset.dataset.current_user")
@patch("controllers.service_api.dataset.dataset.DatasetService")
def test_list_datasets_success(
@ -1043,12 +1043,12 @@ class TestDatasetApiGet:
"""Test suite for DatasetApi.get() endpoint.
``get`` has no billing decorators but calls ``DatasetService``,
``ProviderManager``, ``marshal``, and ``current_user``.
``create_plugin_provider_manager``, ``marshal``, and ``current_user``.
"""
@patch("controllers.service_api.dataset.dataset.DatasetPermissionService")
@patch("controllers.service_api.dataset.dataset.marshal")
@patch("controllers.service_api.dataset.dataset.ProviderManager")
@patch("controllers.service_api.dataset.dataset.create_plugin_provider_manager")
@patch("controllers.service_api.dataset.dataset.current_user")
@patch("controllers.service_api.dataset.dataset.DatasetService")
def test_get_dataset_success(

View File

@ -28,10 +28,7 @@ def mock_model_instance(mocker):
def mock_model_manager(mocker, mock_model_instance):
manager = mocker.MagicMock()
manager.get_default_model_instance.return_value = mock_model_instance
mocker.patch(
"core.base.tts.app_generator_tts_publisher.ModelManager",
return_value=manager,
)
mocker.patch("core.base.tts.app_generator_tts_publisher.ModelManager.for_tenant", return_value=manager)
return manager
@ -64,16 +61,14 @@ class TestInvoiceTTS:
[None, "", " "],
)
def test_invoice_tts_empty_or_none_returns_none(self, text, mock_model_instance):
result = _invoice_tts(text, mock_model_instance, "tenant", "voice1")
result = _invoice_tts(text, mock_model_instance, "voice1")
assert result is None
mock_model_instance.invoke_tts.assert_not_called()
def test_invoice_tts_valid_text(self, mock_model_instance):
result = _invoice_tts(" hello ", mock_model_instance, "tenant", "voice1")
result = _invoice_tts(" hello ", mock_model_instance, "voice1")
mock_model_instance.invoke_tts.assert_called_once_with(
content_text="hello",
user="responding_tts",
tenant_id="tenant",
voice="voice1",
)
assert result == [b"audio1", b"audio2"]

View File

@ -350,7 +350,7 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None:
mock_factory = Mock()
mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"}
with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"):
with patch(
"core.entities.provider_configuration.encrypter.encrypt_token",
@ -380,7 +380,9 @@ def test_validate_provider_credentials_opens_session_when_not_passed() -> None:
with patch("core.entities.provider_configuration.db") as mock_db:
mock_db.engine = Mock()
mock_session_cls.return_value.__enter__.return_value = mock_session
with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
):
validated = configuration.validate_provider_credentials(credentials={"region": "us"})
assert validated == {"region": "us"}
@ -434,7 +436,7 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
mock_factory.get_model_type_instance.return_value = mock_model_type_instance
mock_factory.get_model_schema.return_value = mock_schema
with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
@ -475,7 +477,7 @@ def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> N
mock_factory = Mock()
mock_factory.get_provider_schema.return_value = provider_schema
with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
all_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=False)
active_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=True)
@ -689,7 +691,7 @@ def test_validate_provider_credentials_handles_invalid_original_json() -> None:
mock_factory = Mock()
mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"}
with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"):
validated = configuration.validate_provider_credentials(
credentials={"openai_api_key": HIDDEN_VALUE},
@ -1034,7 +1036,7 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless
mock_factory = Mock()
mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"}
with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
validated = configuration.validate_custom_model_credentials(
@ -1050,7 +1052,9 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless
mock_factory = Mock()
mock_factory.model_credentials_validate.return_value = {"region": "us"}
with _patched_session(session):
with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
):
validated = configuration.validate_custom_model_credentials(
model_type=ModelType.LLM,
model="gpt-4o",
@ -1540,7 +1544,7 @@ def test_validate_provider_credentials_uses_empty_original_when_record_missing()
mock_factory = Mock()
mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"}
with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
validated = configuration.validate_provider_credentials(
credentials={"openai_api_key": HIDDEN_VALUE},
@ -1662,7 +1666,7 @@ def test_validate_custom_model_credentials_handles_invalid_original_json() -> No
mock_factory = Mock()
mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"}
with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
validated = configuration.validate_custom_model_credentials(
model_type=ModelType.LLM,

View File

@ -28,6 +28,7 @@ class TestCacheEmbeddingMultimodalDocuments:
"""Create a mock ModelInstance for testing."""
model_instance = Mock()
model_instance.model = "vision-embedding-model"
model_instance.model_name = "vision-embedding-model"
model_instance.provider = "openai"
model_instance.credentials = {"api_key": "test-key"}
@ -316,6 +317,7 @@ class TestCacheEmbeddingMultimodalQuery:
"""Create a mock ModelInstance for testing."""
model_instance = Mock()
model_instance.model = "vision-embedding-model"
model_instance.model_name = "vision-embedding-model"
model_instance.provider = "openai"
model_instance.credentials = {"api_key": "test-key"}
return model_instance
@ -467,6 +469,7 @@ class TestCacheEmbeddingQueryErrors:
"""Create a mock ModelInstance for testing."""
model_instance = Mock()
model_instance.model = "text-embedding-ada-002"
model_instance.model_name = "text-embedding-ada-002"
model_instance.provider = "openai"
model_instance.credentials = {"api_key": "test-key"}
return model_instance
@ -536,20 +539,20 @@ class TestCacheEmbeddingInitialization:
"""Test CacheEmbedding initialization with user parameter."""
model_instance = Mock()
model_instance.model = "test-model"
model_instance.model_name = "test-model"
model_instance.provider = "test-provider"
cache_embedding = CacheEmbedding(model_instance, user="test-user")
assert cache_embedding._model_instance == model_instance
assert cache_embedding._user == "test-user"
def test_initialization_without_user(self):
"""Test CacheEmbedding initialization without user parameter."""
model_instance = Mock()
model_instance.model = "test-model"
model_instance.model_name = "test-model"
model_instance.provider = "test-provider"
cache_embedding = CacheEmbedding(model_instance)
assert cache_embedding._model_instance == model_instance
assert cache_embedding._user is None

View File

@ -1,10 +1,12 @@
import threading
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.graph_engine.entities.commands import CommandType
from dify_graph.graph_events.node import NodeRunSucceededEvent
@ -36,6 +38,11 @@ def _build_succeeded_event() -> NodeRunSucceededEvent:
)
def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]:
    """Build a wrapper object exposing a bare ModelInstance via ``_model_instance``.

    ``__new__`` is used so no real provider setup runs; returns both the
    wrapper and the underlying raw instance for assertions.
    """
    raw = ModelInstance.__new__(ModelInstance)
    wrapper = SimpleNamespace(_model_instance=raw)
    return wrapper, raw
def test_deduct_quota_called_for_successful_llm_node() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
@ -44,7 +51,7 @@ def test_deduct_quota_called_for_successful_llm_node() -> None:
node.node_type = BuiltinNodeTypes.LLM
node.tenant_id = "tenant-id"
node.require_run_context_value.return_value = _build_dify_context()
node.model_instance = object()
node.model_instance, raw_model_instance = _build_wrapped_model_instance()
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
@ -52,7 +59,7 @@ def test_deduct_quota_called_for_successful_llm_node() -> None:
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
model_instance=node.model_instance,
model_instance=raw_model_instance,
usage=result_event.node_run_result.llm_usage,
)
@ -65,7 +72,7 @@ def test_deduct_quota_called_for_question_classifier_node() -> None:
node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
node.tenant_id = "tenant-id"
node.require_run_context_value.return_value = _build_dify_context()
node.model_instance = object()
node.model_instance, raw_model_instance = _build_wrapped_model_instance()
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
@ -73,7 +80,7 @@ def test_deduct_quota_called_for_question_classifier_node() -> None:
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
model_instance=node.model_instance,
model_instance=raw_model_instance,
usage=result_event.node_run_result.llm_usage,
)
@ -125,7 +132,7 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
node.node_type = BuiltinNodeTypes.LLM
node.tenant_id = "tenant-id"
node.require_run_context_value.return_value = _build_dify_context()
node.model_instance = object()
node.model_instance, _ = _build_wrapped_model_instance()
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
@ -152,7 +159,7 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
node = MagicMock()
node.id = "llm-node-id"
node.node_type = BuiltinNodeTypes.LLM
node.model_instance = object()
node.model_instance, _ = _build_wrapped_model_instance()
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
@ -178,7 +185,7 @@ def test_quota_precheck_passes_without_abort() -> None:
node = MagicMock()
node.id = "llm-node-id"
node.node_type = BuiltinNodeTypes.LLM
node.model_instance = object()
node.model_instance, raw_model_instance = _build_wrapped_model_instance()
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
@ -186,5 +193,5 @@ def test_quota_precheck_passes_without_abort() -> None:
layer.on_node_run_start(node)
assert not stop_event.is_set()
mock_check.assert_called_once_with(model_instance=node.model_instance)
mock_check.assert_called_once_with(model_instance=raw_model_instance)
layer.command_channel.send_command.assert_not_called()

View File

@ -2,89 +2,55 @@ from unittest.mock import MagicMock, patch
import pytest
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from dify_graph.model_runtime.errors.invoke import InvokeError
from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel
class TestModerationModel:
@pytest.fixture
def mock_plugin_model_provider(self):
return MagicMock(spec=PluginModelProviderEntity)
@pytest.fixture
def provider_schema() -> ProviderEntity:
return ProviderEntity(
provider="test_provider",
label=I18nObject(en_US="test_provider"),
supported_model_types=[ModelType.MODERATION],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
)
@pytest.fixture
def moderation_model(self, mock_plugin_model_provider):
return ModerationModel(
tenant_id="tenant_123",
model_type=ModelType.MODERATION,
plugin_id="plugin_123",
provider_name="test_provider",
plugin_model_provider=mock_plugin_model_provider,
)
def test_model_type(self, moderation_model):
assert moderation_model.model_type == ModelType.MODERATION
@pytest.fixture
def model_runtime() -> MagicMock:
return MagicMock()
def test_invoke_success(self, moderation_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
text = "test text"
user = "user_123"
with (
patch("core.plugin.impl.model.PluginModelClient") as mock_client_class,
patch("time.perf_counter", return_value=1.0),
):
mock_client = mock_client_class.return_value
mock_client.invoke_moderation.return_value = True
@pytest.fixture
def moderation_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> ModerationModel:
return ModerationModel(provider_schema=provider_schema, model_runtime=model_runtime)
result = moderation_model.invoke(model=model_name, credentials=credentials, text=text, user=user)
assert result is True
assert moderation_model.started_at == 1.0
mock_client.invoke_moderation.assert_called_once_with(
tenant_id="tenant_123",
user_id="user_123",
plugin_id="plugin_123",
provider="test_provider",
model=model_name,
credentials=credentials,
text=text,
)
def test_model_type(moderation_model: ModerationModel) -> None:
assert moderation_model.model_type == ModelType.MODERATION
def test_invoke_success_no_user(self, moderation_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
text = "test text"
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_moderation.return_value = False
def test_invoke_success(moderation_model: ModerationModel, model_runtime: MagicMock) -> None:
with patch("time.perf_counter", return_value=1.0):
model_runtime.invoke_moderation.return_value = True
result = moderation_model.invoke(model=model_name, credentials=credentials, text=text)
result = moderation_model.invoke(model="test_model", credentials={"api_key": "abc"}, text="test text")
assert result is False
mock_client.invoke_moderation.assert_called_once_with(
tenant_id="tenant_123",
user_id="unknown",
plugin_id="plugin_123",
provider="test_provider",
model=model_name,
credentials=credentials,
text=text,
)
assert result is True
assert moderation_model.started_at == 1.0
model_runtime.invoke_moderation.assert_called_once_with(
provider="test_provider",
model="test_model",
credentials={"api_key": "abc"},
text="test text",
)
def test_invoke_exception(self, moderation_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
text = "test text"
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_moderation.side_effect = Exception("Test error")
def test_invoke_exception(moderation_model: ModerationModel, model_runtime: MagicMock) -> None:
model_runtime.invoke_moderation.side_effect = Exception("Test error")
with pytest.raises(InvokeError) as excinfo:
moderation_model.invoke(model=model_name, credentials=credentials, text=text)
assert "[test_provider] Error: Test error" in str(excinfo.value.description)
with pytest.raises(InvokeError, match="Test error"):
moderation_model.invoke(model="test_model", credentials={"api_key": "abc"}, text="test text")

View File

@ -1,71 +1,42 @@
from datetime import datetime
from typing import Any
from unittest.mock import MagicMock
import pytest
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from dify_graph.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from dify_graph.model_runtime.errors.invoke import InvokeError
from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel
@pytest.fixture
def rerank_model() -> RerankModel:
plugin_provider = PluginModelProviderEntity.model_construct(
id="provider-id",
created_at=datetime.now(),
updated_at=datetime.now(),
provider="provider",
tenant_id="tenant",
plugin_unique_identifier="plugin-uid",
plugin_id="plugin-id",
declaration=MagicMock(),
)
return RerankModel.model_construct(
tenant_id="tenant",
model_type=ModelType.RERANK,
plugin_id="plugin-id",
provider_name="provider",
plugin_model_provider=plugin_provider,
def provider_schema() -> ProviderEntity:
return ProviderEntity(
provider="test_provider",
label=I18nObject(en_US="test_provider"),
supported_model_types=[ModelType.RERANK],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
)
def test_model_type_is_rerank_by_default() -> None:
plugin_provider = PluginModelProviderEntity.model_construct(
id="provider-id",
created_at=datetime.now(),
updated_at=datetime.now(),
provider="provider",
tenant_id="tenant",
plugin_unique_identifier="plugin-uid",
plugin_id="plugin-id",
declaration=MagicMock(),
)
model = RerankModel(
tenant_id="tenant",
plugin_id="plugin-id",
provider_name="provider",
plugin_model_provider=plugin_provider,
)
assert model.model_type == ModelType.RERANK
@pytest.fixture
def model_runtime() -> MagicMock:
return MagicMock()
def test_invoke_calls_plugin_and_passes_args(rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch) -> None:
@pytest.fixture
def rerank_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> RerankModel:
return RerankModel(provider_schema=provider_schema, model_runtime=model_runtime)
def test_model_type_is_rerank_by_default(rerank_model: RerankModel) -> None:
assert rerank_model.model_type == ModelType.RERANK
def test_invoke_calls_runtime_and_passes_args(rerank_model: RerankModel, model_runtime: MagicMock) -> None:
expected = RerankResult(model="rerank", docs=[RerankDocument(index=0, text="a", score=0.5)])
class FakePluginModelClient:
def __init__(self) -> None:
self.invoke_rerank_called_with: dict[str, Any] | None = None
def invoke_rerank(self, **kwargs: Any) -> RerankResult:
self.invoke_rerank_called_with = kwargs
return expected
import core.plugin.impl.model as plugin_model_module
fake_client = FakePluginModelClient()
monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client)
model_runtime.invoke_rerank.return_value = expected
result = rerank_model.invoke(
model="rerank",
@ -74,76 +45,30 @@ def test_invoke_calls_plugin_and_passes_args(rerank_model: RerankModel, monkeypa
docs=["d1", "d2"],
score_threshold=0.2,
top_n=10,
user="user-1",
)
assert result == expected
assert fake_client.invoke_rerank_called_with == {
"tenant_id": "tenant",
"user_id": "user-1",
"plugin_id": "plugin-id",
"provider": "provider",
"model": "rerank",
"credentials": {"k": "v"},
"query": "q",
"docs": ["d1", "d2"],
"score_threshold": 0.2,
"top_n": 10,
}
model_runtime.invoke_rerank.assert_called_once_with(
provider="test_provider",
model="rerank",
credentials={"k": "v"},
query="q",
docs=["d1", "d2"],
score_threshold=0.2,
top_n=10,
)
def test_invoke_uses_unknown_user_when_not_provided(rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch) -> None:
class FakePluginModelClient:
def __init__(self) -> None:
self.kwargs: dict[str, Any] | None = None
def test_invoke_transforms_and_raises_on_runtime_error(rerank_model: RerankModel, model_runtime: MagicMock) -> None:
model_runtime.invoke_rerank.side_effect = Exception("runtime down")
def invoke_rerank(self, **kwargs: Any) -> RerankResult:
self.kwargs = kwargs
return RerankResult(model="m", docs=[])
import core.plugin.impl.model as plugin_model_module
fake_client = FakePluginModelClient()
monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client)
rerank_model.invoke(model="m", credentials={}, query="q", docs=["d"])
assert fake_client.kwargs is not None
assert fake_client.kwargs["user_id"] == "unknown"
def test_invoke_transforms_and_raises_on_plugin_error(
rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch
) -> None:
class FakePluginModelClient:
def invoke_rerank(self, **_: Any) -> RerankResult:
raise ValueError("plugin down")
import core.plugin.impl.model as plugin_model_module
monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient)
monkeypatch.setattr(rerank_model, "_transform_invoke_error", lambda e: RuntimeError(f"transformed: {e}"))
with pytest.raises(RuntimeError, match="transformed: plugin down"):
with pytest.raises(InvokeError, match="runtime down"):
rerank_model.invoke(model="m", credentials={}, query="q", docs=["d"])
def test_invoke_multimodal_calls_plugin_and_passes_args(
rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch
) -> None:
def test_invoke_multimodal_calls_runtime_and_passes_args(rerank_model: RerankModel, model_runtime: MagicMock) -> None:
expected = RerankResult(model="mm", docs=[RerankDocument(index=0, text="x", score=0.9)])
class FakePluginModelClient:
def __init__(self) -> None:
self.invoke_multimodal_rerank_called_with: dict[str, Any] | None = None
def invoke_multimodal_rerank(self, **kwargs: Any) -> RerankResult:
self.invoke_multimodal_rerank_called_with = kwargs
return expected
import core.plugin.impl.model as plugin_model_module
fake_client = FakePluginModelClient()
monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client)
model_runtime.invoke_multimodal_rerank.return_value = expected
query = {"type": "text", "text": "q"}
docs = [{"type": "text", "text": "d1"}]
@ -154,28 +79,15 @@ def test_invoke_multimodal_calls_plugin_and_passes_args(
docs=docs,
score_threshold=None,
top_n=None,
user=None,
)
assert result == expected
assert fake_client.invoke_multimodal_rerank_called_with is not None
assert fake_client.invoke_multimodal_rerank_called_with["tenant_id"] == "tenant"
assert fake_client.invoke_multimodal_rerank_called_with["user_id"] == "unknown"
assert fake_client.invoke_multimodal_rerank_called_with["query"] == query
assert fake_client.invoke_multimodal_rerank_called_with["docs"] == docs
def test_invoke_multimodal_transforms_and_raises_on_plugin_error(
rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch
) -> None:
class FakePluginModelClient:
def invoke_multimodal_rerank(self, **_: Any) -> RerankResult:
raise ValueError("plugin down")
import core.plugin.impl.model as plugin_model_module
monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient)
monkeypatch.setattr(rerank_model, "_transform_invoke_error", lambda e: RuntimeError(f"transformed: {e}"))
with pytest.raises(RuntimeError, match="transformed: plugin down"):
rerank_model.invoke_multimodal_rerank(model="m", credentials={}, query={"q": 1}, docs=[{"d": 1}])
model_runtime.invoke_multimodal_rerank.assert_called_once_with(
provider="test_provider",
model="mm",
credentials={"k": "v"},
query=query,
docs=docs,
score_threshold=None,
top_n=None,
)

View File

@ -1,87 +1,56 @@
from io import BytesIO
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from dify_graph.model_runtime.errors.invoke import InvokeError
from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
class TestSpeech2TextModel:
@pytest.fixture
def mock_plugin_model_provider(self):
return MagicMock(spec=PluginModelProviderEntity)
@pytest.fixture
def provider_schema() -> ProviderEntity:
return ProviderEntity(
provider="test_provider",
label=I18nObject(en_US="test_provider"),
supported_model_types=[ModelType.SPEECH2TEXT],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
)
@pytest.fixture
def speech2text_model(self, mock_plugin_model_provider):
return Speech2TextModel(
tenant_id="tenant_123",
model_type=ModelType.SPEECH2TEXT,
plugin_id="plugin_123",
provider_name="test_provider",
plugin_model_provider=mock_plugin_model_provider,
)
def test_model_type(self, speech2text_model):
assert speech2text_model.model_type == ModelType.SPEECH2TEXT
@pytest.fixture
def model_runtime() -> MagicMock:
return MagicMock()
def test_invoke_success(self, speech2text_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
file = BytesIO(b"audio data")
user = "user_123"
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_speech_to_text.return_value = "transcribed text"
@pytest.fixture
def speech2text_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> Speech2TextModel:
return Speech2TextModel(provider_schema=provider_schema, model_runtime=model_runtime)
result = speech2text_model.invoke(model=model_name, credentials=credentials, file=file, user=user)
assert result == "transcribed text"
mock_client.invoke_speech_to_text.assert_called_once_with(
tenant_id="tenant_123",
user_id="user_123",
plugin_id="plugin_123",
provider="test_provider",
model=model_name,
credentials=credentials,
file=file,
)
def test_model_type(speech2text_model: Speech2TextModel) -> None:
assert speech2text_model.model_type == ModelType.SPEECH2TEXT
def test_invoke_success_no_user(self, speech2text_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
file = BytesIO(b"audio data")
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_speech_to_text.return_value = "transcribed text"
def test_invoke_success(speech2text_model: Speech2TextModel, model_runtime: MagicMock) -> None:
file = BytesIO(b"audio data")
model_runtime.invoke_speech_to_text.return_value = "transcribed text"
result = speech2text_model.invoke(model=model_name, credentials=credentials, file=file)
result = speech2text_model.invoke(model="test_model", credentials={"api_key": "abc"}, file=file)
assert result == "transcribed text"
mock_client.invoke_speech_to_text.assert_called_once_with(
tenant_id="tenant_123",
user_id="unknown",
plugin_id="plugin_123",
provider="test_provider",
model=model_name,
credentials=credentials,
file=file,
)
assert result == "transcribed text"
model_runtime.invoke_speech_to_text.assert_called_once_with(
provider="test_provider",
model="test_model",
credentials={"api_key": "abc"},
file=file,
)
def test_invoke_exception(self, speech2text_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
file = BytesIO(b"audio data")
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_speech_to_text.side_effect = Exception("Test error")
def test_invoke_exception(speech2text_model: Speech2TextModel, model_runtime: MagicMock) -> None:
model_runtime.invoke_speech_to_text.side_effect = Exception("Test error")
with pytest.raises(InvokeError) as excinfo:
speech2text_model.invoke(model=model_name, credentials=credentials, file=file)
assert "[test_provider] Error: Test error" in str(excinfo.value.description)
with pytest.raises(InvokeError, match="Test error"):
speech2text_model.invoke(model="test_model", credentials={"api_key": "abc"}, file=BytesIO(b"audio data"))

View File

@ -2,184 +2,145 @@ from unittest.mock import MagicMock, patch
import pytest
from core.entities.embedding_type import EmbeddingInputType
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
from dify_graph.model_runtime.errors.invoke import InvokeError
from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
class TestTextEmbeddingModel:
@pytest.fixture
def mock_plugin_model_provider(self):
return MagicMock(spec=PluginModelProviderEntity)
@pytest.fixture
def provider_schema() -> ProviderEntity:
return ProviderEntity(
provider="test_provider",
label=I18nObject(en_US="test_provider"),
supported_model_types=[ModelType.TEXT_EMBEDDING],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
)
@pytest.fixture
def text_embedding_model(self, mock_plugin_model_provider):
return TextEmbeddingModel(
tenant_id="tenant_123",
model_type=ModelType.TEXT_EMBEDDING,
plugin_id="plugin_123",
provider_name="test_provider",
plugin_model_provider=mock_plugin_model_provider,
)
def test_model_type(self, text_embedding_model):
assert text_embedding_model.model_type == ModelType.TEXT_EMBEDDING
@pytest.fixture
def model_runtime() -> MagicMock:
return MagicMock()
def test_invoke_with_texts(self, text_embedding_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
texts = ["hello", "world"]
user = "user_123"
expected_result = MagicMock(spec=EmbeddingResult)
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_text_embedding.return_value = expected_result
@pytest.fixture
def text_embedding_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> TextEmbeddingModel:
return TextEmbeddingModel(provider_schema=provider_schema, model_runtime=model_runtime)
result = text_embedding_model.invoke(model=model_name, credentials=credentials, texts=texts, user=user)
assert result == expected_result
mock_client.invoke_text_embedding.assert_called_once_with(
tenant_id="tenant_123",
user_id="user_123",
plugin_id="plugin_123",
provider="test_provider",
model=model_name,
credentials=credentials,
texts=texts,
input_type=EmbeddingInputType.DOCUMENT,
)
def test_model_type(text_embedding_model: TextEmbeddingModel) -> None:
assert text_embedding_model.model_type == ModelType.TEXT_EMBEDDING
def test_invoke_with_multimodel_documents(self, text_embedding_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
multimodel_documents = [{"type": "text", "text": "hello"}]
expected_result = MagicMock(spec=EmbeddingResult)
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_multimodal_embedding.return_value = expected_result
def test_invoke_with_texts(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None:
expected_result = MagicMock(spec=EmbeddingResult)
model_runtime.invoke_text_embedding.return_value = expected_result
result = text_embedding_model.invoke(
model=model_name, credentials=credentials, multimodel_documents=multimodel_documents
)
result = text_embedding_model.invoke(model="test_model", credentials={"api_key": "abc"}, texts=["hello", "world"])
assert result == expected_result
mock_client.invoke_multimodal_embedding.assert_called_once_with(
tenant_id="tenant_123",
user_id="unknown",
plugin_id="plugin_123",
provider="test_provider",
model=model_name,
credentials=credentials,
documents=multimodel_documents,
input_type=EmbeddingInputType.DOCUMENT,
)
assert result == expected_result
model_runtime.invoke_text_embedding.assert_called_once_with(
provider="test_provider",
model="test_model",
credentials={"api_key": "abc"},
texts=["hello", "world"],
input_type=EmbeddingInputType.DOCUMENT,
)
def test_invoke_no_input(self, text_embedding_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
with pytest.raises(ValueError) as excinfo:
text_embedding_model.invoke(model=model_name, credentials=credentials)
def test_invoke_with_multimodal_documents(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None:
expected_result = MagicMock(spec=EmbeddingResult)
model_runtime.invoke_multimodal_embedding.return_value = expected_result
assert "No texts or files provided" in str(excinfo.value)
result = text_embedding_model.invoke(
model="test_model",
credentials={"api_key": "abc"},
multimodel_documents=[{"type": "text", "text": "hello"}],
)
def test_invoke_precedence(self, text_embedding_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
texts = ["hello"]
multimodel_documents = [{"type": "text", "text": "world"}]
expected_result = MagicMock(spec=EmbeddingResult)
assert result == expected_result
model_runtime.invoke_multimodal_embedding.assert_called_once_with(
provider="test_provider",
model="test_model",
credentials={"api_key": "abc"},
documents=[{"type": "text", "text": "hello"}],
input_type=EmbeddingInputType.DOCUMENT,
)
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_text_embedding.return_value = expected_result
result = text_embedding_model.invoke(
model=model_name, credentials=credentials, texts=texts, multimodel_documents=multimodel_documents
)
def test_invoke_no_input(text_embedding_model: TextEmbeddingModel) -> None:
with pytest.raises(ValueError, match="No texts or files provided"):
text_embedding_model.invoke(model="test_model", credentials={"api_key": "abc"})
assert result == expected_result
mock_client.invoke_text_embedding.assert_called_once()
mock_client.invoke_multimodal_embedding.assert_not_called()
def test_invoke_exception(self, text_embedding_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
texts = ["hello"]
def test_invoke_prefers_texts_over_multimodal_documents(
text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock
) -> None:
expected_result = MagicMock(spec=EmbeddingResult)
model_runtime.invoke_text_embedding.return_value = expected_result
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_text_embedding.side_effect = Exception("Test error")
result = text_embedding_model.invoke(
model="test_model",
credentials={"api_key": "abc"},
texts=["hello"],
multimodel_documents=[{"type": "text", "text": "world"}],
)
with pytest.raises(InvokeError) as excinfo:
text_embedding_model.invoke(model=model_name, credentials=credentials, texts=texts)
assert result == expected_result
model_runtime.invoke_text_embedding.assert_called_once()
model_runtime.invoke_multimodal_embedding.assert_not_called()
assert "[test_provider] Error: Test error" in str(excinfo.value.description)
def test_get_num_tokens(self, text_embedding_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
texts = ["hello", "world"]
expected_tokens = [1, 1]
def test_invoke_exception(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None:
model_runtime.invoke_text_embedding.side_effect = Exception("Test error")
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.get_text_embedding_num_tokens.return_value = expected_tokens
with pytest.raises(InvokeError, match="Test error"):
text_embedding_model.invoke(model="test_model", credentials={"api_key": "abc"}, texts=["hello"])
result = text_embedding_model.get_num_tokens(model=model_name, credentials=credentials, texts=texts)
assert result == expected_tokens
mock_client.get_text_embedding_num_tokens.assert_called_once_with(
tenant_id="tenant_123",
user_id="unknown",
plugin_id="plugin_123",
provider="test_provider",
model=model_name,
credentials=credentials,
texts=texts,
)
def test_get_num_tokens(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None:
model_runtime.get_text_embedding_num_tokens.return_value = [1, 1]
def test_get_context_size(self, text_embedding_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
result = text_embedding_model.get_num_tokens(
model="test_model", credentials={"api_key": "abc"}, texts=["hello", "world"]
)
# Test case 1: Context size in schema
mock_schema = MagicMock()
mock_schema.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 2048}
assert result == [1, 1]
model_runtime.get_text_embedding_num_tokens.assert_called_once_with(
provider="test_provider",
model="test_model",
credentials={"api_key": "abc"},
texts=["hello", "world"],
)
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
assert text_embedding_model._get_context_size(model_name, credentials) == 2048
# Test case 2: No schema
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None):
assert text_embedding_model._get_context_size(model_name, credentials) == 1000
def test_get_context_size(text_embedding_model: TextEmbeddingModel) -> None:
mock_schema = MagicMock()
mock_schema.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 2048}
# Test case 3: Context size NOT in schema properties
mock_schema.model_properties = {}
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
assert text_embedding_model._get_context_size(model_name, credentials) == 1000
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
assert text_embedding_model._get_context_size("test_model", {"api_key": "abc"}) == 2048
def test_get_max_chunks(self, text_embedding_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None):
assert text_embedding_model._get_context_size("test_model", {"api_key": "abc"}) == 1000
# Test case 1: Max chunks in schema
mock_schema = MagicMock()
mock_schema.model_properties = {ModelPropertyKey.MAX_CHUNKS: 10}
mock_schema.model_properties = {}
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
assert text_embedding_model._get_context_size("test_model", {"api_key": "abc"}) == 1000
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
assert text_embedding_model._get_max_chunks(model_name, credentials) == 10
# Test case 2: No schema
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None):
assert text_embedding_model._get_max_chunks(model_name, credentials) == 1
def test_get_max_chunks(text_embedding_model: TextEmbeddingModel) -> None:
mock_schema = MagicMock()
mock_schema.model_properties = {ModelPropertyKey.MAX_CHUNKS: 10}
# Test case 3: Max chunks NOT in schema properties
mock_schema.model_properties = {}
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
assert text_embedding_model._get_max_chunks(model_name, credentials) == 1
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
assert text_embedding_model._get_max_chunks("test_model", {"api_key": "abc"}) == 10
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None):
assert text_embedding_model._get_max_chunks("test_model", {"api_key": "abc"}) == 1
mock_schema.model_properties = {}
with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema):
assert text_embedding_model._get_max_chunks("test_model", {"api_key": "abc"}) == 1

View File

@ -1,131 +1,83 @@
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from dify_graph.model_runtime.errors.invoke import InvokeError
from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel
class TestTTSModel:
@pytest.fixture
def mock_plugin_model_provider(self):
return MagicMock(spec=PluginModelProviderEntity)
@pytest.fixture
def provider_schema() -> ProviderEntity:
return ProviderEntity(
provider="test_provider",
label=I18nObject(en_US="test_provider"),
supported_model_types=[ModelType.TTS],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
)
@pytest.fixture
def tts_model(self, mock_plugin_model_provider):
return TTSModel(
tenant_id="tenant_123",
model_type=ModelType.TTS,
plugin_id="plugin_123",
provider_name="test_provider",
plugin_model_provider=mock_plugin_model_provider,
@pytest.fixture
def model_runtime() -> MagicMock:
return MagicMock()
@pytest.fixture
def tts_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> TTSModel:
return TTSModel(provider_schema=provider_schema, model_runtime=model_runtime)
def test_model_type(tts_model: TTSModel) -> None:
assert tts_model.model_type == ModelType.TTS
def test_invoke_success(tts_model: TTSModel, model_runtime: MagicMock) -> None:
model_runtime.invoke_tts.return_value = [b"audio_chunk"]
result = tts_model.invoke(
model="test_model",
credentials={"api_key": "abc"},
content_text="Hello world",
voice="alloy",
)
assert list(result) == [b"audio_chunk"]
model_runtime.invoke_tts.assert_called_once_with(
provider="test_provider",
model="test_model",
credentials={"api_key": "abc"},
content_text="Hello world",
voice="alloy",
)
def test_invoke_exception(tts_model: TTSModel, model_runtime: MagicMock) -> None:
model_runtime.invoke_tts.side_effect = Exception("Test error")
with pytest.raises(InvokeError, match="Test error"):
tts_model.invoke(
model="test_model",
credentials={"api_key": "abc"},
content_text="Hello world",
voice="alloy",
)
def test_model_type(self, tts_model):
assert tts_model.model_type == ModelType.TTS
def test_invoke_success(self, tts_model):
model_name = "test_model"
tenant_id = "ignored_tenant_id"
credentials = {"api_key": "abc"}
content_text = "Hello world"
voice = "alloy"
user = "user_123"
def test_get_tts_model_voices(tts_model: TTSModel, model_runtime: MagicMock) -> None:
model_runtime.get_tts_model_voices.return_value = [{"name": "Voice1"}]
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_tts.return_value = [b"audio_chunk"]
result = tts_model.get_tts_model_voices(
model="test_model",
credentials={"api_key": "abc"},
language="en-US",
)
result = tts_model.invoke(
model=model_name,
tenant_id=tenant_id,
credentials=credentials,
content_text=content_text,
voice=voice,
user=user,
)
assert list(result) == [b"audio_chunk"]
mock_client.invoke_tts.assert_called_once_with(
tenant_id="tenant_123",
user_id="user_123",
plugin_id="plugin_123",
provider="test_provider",
model=model_name,
credentials=credentials,
content_text=content_text,
voice=voice,
)
def test_invoke_success_no_user(self, tts_model):
model_name = "test_model"
tenant_id = "ignored_tenant_id"
credentials = {"api_key": "abc"}
content_text = "Hello world"
voice = "alloy"
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_tts.return_value = [b"audio_chunk"]
result = tts_model.invoke(
model=model_name, tenant_id=tenant_id, credentials=credentials, content_text=content_text, voice=voice
)
assert list(result) == [b"audio_chunk"]
mock_client.invoke_tts.assert_called_once_with(
tenant_id="tenant_123",
user_id="unknown",
plugin_id="plugin_123",
provider="test_provider",
model=model_name,
credentials=credentials,
content_text=content_text,
voice=voice,
)
def test_invoke_exception(self, tts_model):
model_name = "test_model"
tenant_id = "ignored_tenant_id"
credentials = {"api_key": "abc"}
content_text = "Hello world"
voice = "alloy"
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.invoke_tts.side_effect = Exception("Test error")
with pytest.raises(InvokeError) as excinfo:
tts_model.invoke(
model=model_name,
tenant_id=tenant_id,
credentials=credentials,
content_text=content_text,
voice=voice,
)
assert "[test_provider] Error: Test error" in str(excinfo.value.description)
def test_get_tts_model_voices(self, tts_model):
model_name = "test_model"
credentials = {"api_key": "abc"}
language = "en-US"
with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class:
mock_client = mock_client_class.return_value
mock_client.get_tts_model_voices.return_value = [{"name": "Voice1"}]
result = tts_model.get_tts_model_voices(model=model_name, credentials=credentials, language=language)
assert result == [{"name": "Voice1"}]
mock_client.get_tts_model_voices.assert_called_once_with(
tenant_id="tenant_123",
user_id="unknown",
plugin_id="plugin_123",
provider="test_provider",
model=model_name,
credentials=credentials,
language=language,
)
assert result == [{"name": "Voice1"}]
model_runtime.get_tts_model_voices.assert_called_once_with(
provider="test_provider",
model="test_model",
credentials={"api_key": "abc"},
language="en-US",
)

View File

@ -1,522 +0,0 @@
import logging
from datetime import datetime
from threading import Lock
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from redis import RedisError
import contexts
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelPropertyKey,
ModelType,
)
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
def _provider_entity(
*,
provider: str,
supported_model_types: list[ModelType] | None = None,
models: list[AIModelEntity] | None = None,
icon_small: I18nObject | None = None,
icon_small_dark: I18nObject | None = None,
) -> ProviderEntity:
return ProviderEntity(
provider=provider,
label=I18nObject(en_US=provider),
supported_model_types=supported_model_types or [ModelType.LLM],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
models=models or [],
icon_small=icon_small,
icon_small_dark=icon_small_dark,
)
def _plugin_provider(
*, plugin_id: str, declaration: ProviderEntity, provider: str = "provider"
) -> PluginModelProviderEntity:
return PluginModelProviderEntity.model_construct(
id=f"{plugin_id}-id",
created_at=datetime.now(),
updated_at=datetime.now(),
provider=provider,
tenant_id="tenant",
plugin_unique_identifier=f"{plugin_id}-uid",
plugin_id=plugin_id,
declaration=declaration,
)
@pytest.fixture(autouse=True)
def _reset_plugin_model_provider_context() -> None:
contexts.plugin_model_providers_lock.set(Lock())
contexts.plugin_model_providers.set(None)
@pytest.fixture
def fake_plugin_manager(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
    """Replace PluginModelClient with a MagicMock and return that mock for stubbing."""
    manager = MagicMock()
    import core.plugin.impl.model as plugin_model_module
    monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: manager)
    return manager
@pytest.fixture
def factory(fake_plugin_manager: MagicMock) -> ModelProviderFactory:
    """A ModelProviderFactory for a fixed tenant, backed by the fake plugin manager."""
    return ModelProviderFactory(tenant_id="tenant")
def test_get_plugin_model_providers_initializes_context_on_lookup_error(
    factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A LookupError from the context var is recovered by initializing and fetching providers."""
    declaration = _provider_entity(provider="openai")
    fake_plugin_manager.fetch_model_providers.return_value = [
        _plugin_provider(plugin_id="langgenius/openai", declaration=declaration)
    ]
    original_get = contexts.plugin_model_providers.get
    calls = {"n": 0}
    def flaky_get() -> Any:
        # First lookup raises (mimicking an uninitialized context var);
        # later lookups delegate to the real getter.
        calls["n"] += 1
        if calls["n"] == 1:
            raise LookupError
        return original_get()
    monkeypatch.setattr(contexts.plugin_model_providers, "get", flaky_get)
    providers = factory.get_plugin_model_providers()
    assert len(providers) == 1
    assert providers[0].declaration.provider == "langgenius/openai/openai"
def test_get_plugin_model_providers_caches_and_does_not_refetch(
    factory: ModelProviderFactory, fake_plugin_manager: MagicMock
) -> None:
    """A second call returns the same cached list; the plugin daemon is hit only once."""
    declaration = _provider_entity(provider="openai")
    fake_plugin_manager.fetch_model_providers.return_value = [
        _plugin_provider(plugin_id="langgenius/openai", declaration=declaration)
    ]
    first = factory.get_plugin_model_providers()
    second = factory.get_plugin_model_providers()
    # Identity check: the cached object itself is reused.
    assert first is second
    fake_plugin_manager.fetch_model_providers.assert_called_once_with("tenant")
def test_get_providers_returns_declarations(factory: ModelProviderFactory, fake_plugin_manager: MagicMock) -> None:
    """get_providers surfaces every fetched declaration with its fully-qualified provider id."""
    declarations = {
        "langgenius/openai": _provider_entity(provider="openai"),
        "langgenius/anthropic": _provider_entity(provider="anthropic"),
    }
    fake_plugin_manager.fetch_model_providers.return_value = [
        _plugin_provider(plugin_id=pid, declaration=decl) for pid, decl in declarations.items()
    ]
    assert [p.provider for p in factory.get_providers()] == [
        "langgenius/openai/openai",
        "langgenius/anthropic/anthropic",
    ]
def test_get_plugin_model_provider_converts_short_provider_id(
    factory: ModelProviderFactory, fake_plugin_manager: MagicMock
) -> None:
    """A bare provider name resolves to the fully-qualified plugin provider id."""
    plugin_provider = _plugin_provider(
        plugin_id="langgenius/openai", declaration=_provider_entity(provider="openai")
    )
    fake_plugin_manager.fetch_model_providers.return_value = [plugin_provider]
    resolved = factory.get_plugin_model_provider("openai")
    assert resolved.declaration.provider == "langgenius/openai/openai"
def test_get_plugin_model_provider_raises_on_invalid_provider(
    factory: ModelProviderFactory, fake_plugin_manager: MagicMock
) -> None:
    """Looking up a provider id not present in the fetched set raises ValueError."""
    declaration = _provider_entity(provider="openai")
    fake_plugin_manager.fetch_model_providers.return_value = [
        _plugin_provider(plugin_id="langgenius/openai", declaration=declaration)
    ]
    with pytest.raises(ValueError, match="Invalid provider"):
        factory.get_plugin_model_provider("langgenius/unknown/unknown")
def test_get_provider_schema_returns_declaration(factory: ModelProviderFactory, fake_plugin_manager: MagicMock) -> None:
    """get_provider_schema returns the declaration with the fully-qualified provider id."""
    declaration = _provider_entity(provider="openai")
    fake_plugin_manager.fetch_model_providers.return_value = [
        _plugin_provider(plugin_id="langgenius/openai", declaration=declaration)
    ]
    schema = factory.get_provider_schema("openai")
    assert schema.provider == "langgenius/openai/openai"
def test_provider_credentials_validate_errors_when_schema_missing(
    factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Validation fails fast when the declaration has no provider_credential_schema."""
    schema = _provider_entity(provider="openai")
    schema.provider_credential_schema = None
    monkeypatch.setattr(
        factory,
        "get_plugin_model_provider",
        lambda **_: _plugin_provider(plugin_id="langgenius/openai", declaration=schema),
    )
    with pytest.raises(ValueError, match="does not have provider_credential_schema"):
        factory.provider_credentials_validate(provider="openai", credentials={"x": "y"})
def test_provider_credentials_validate_filters_and_calls_plugin_validation(
    factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Credentials are filtered by the schema validator, then sent to the plugin daemon."""
    schema = _provider_entity(provider="openai")
    schema.provider_credential_schema = MagicMock()
    plugin_provider = _plugin_provider(plugin_id="langgenius/openai", declaration=schema)
    monkeypatch.setattr(factory, "get_plugin_model_provider", lambda **_: plugin_provider)
    fake_validator = MagicMock()
    fake_validator.validate_and_filter.return_value = {"filtered": True}
    monkeypatch.setattr(
        "dify_graph.model_runtime.model_providers.model_provider_factory.ProviderCredentialSchemaValidator",
        lambda _: fake_validator,
    )
    filtered = factory.provider_credentials_validate(provider="openai", credentials={"raw": True})
    assert filtered == {"filtered": True}
    fake_plugin_manager.validate_provider_credentials.assert_called_once()
    # The daemon must receive the *filtered* credentials, not the raw input.
    kwargs = fake_plugin_manager.validate_provider_credentials.call_args.kwargs
    assert kwargs["plugin_id"] == "langgenius/openai"
    assert kwargs["provider"] == "provider"
    assert kwargs["credentials"] == {"filtered": True}
def test_model_credentials_validate_errors_when_schema_missing(
    factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Model-credential validation fails fast when the declaration has no model_credential_schema."""
    schema = _provider_entity(provider="openai")
    schema.model_credential_schema = None
    monkeypatch.setattr(
        factory,
        "get_plugin_model_provider",
        lambda **_: _plugin_provider(plugin_id="langgenius/openai", declaration=schema),
    )
    with pytest.raises(ValueError, match="does not have model_credential_schema"):
        factory.model_credentials_validate(
            provider="openai", model_type=ModelType.LLM, model="m", credentials={"x": "y"}
        )
def test_model_credentials_validate_filters_and_calls_plugin_validation(
    factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Model credentials are schema-filtered and forwarded to the plugin daemon with the serialized model type."""
    schema = _provider_entity(provider="openai")
    schema.model_credential_schema = MagicMock()
    plugin_provider = _plugin_provider(plugin_id="langgenius/openai", declaration=schema)
    monkeypatch.setattr(factory, "get_plugin_model_provider", lambda **_: plugin_provider)
    fake_validator = MagicMock()
    fake_validator.validate_and_filter.return_value = {"filtered": True}
    monkeypatch.setattr(
        "dify_graph.model_runtime.model_providers.model_provider_factory.ModelCredentialSchemaValidator",
        lambda *_: fake_validator,
    )
    filtered = factory.model_credentials_validate(
        provider="openai", model_type=ModelType.TEXT_EMBEDDING, model="m", credentials={"raw": True}
    )
    assert filtered == {"filtered": True}
    # The enum is serialized to its wire value ("text-embedding") for the daemon call.
    kwargs = fake_plugin_manager.validate_model_credentials.call_args.kwargs
    assert kwargs["plugin_id"] == "langgenius/openai"
    assert kwargs["provider"] == "provider"
    assert kwargs["model_type"] == "text-embedding"
    assert kwargs["model"] == "m"
    assert kwargs["credentials"] == {"filtered": True}
def test_get_model_schema_cache_hit(factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch) -> None:
    """A valid cached JSON blob in redis is deserialized and returned without a daemon call."""
    model_schema = AIModelEntity(
        model="m",
        label=I18nObject(en_US="m"),
        model_type=ModelType.LLM,
        fetch_from=FetchFrom.PREDEFINED_MODEL,
        model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
        parameter_rules=[],
    )
    monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov"))
    with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis:
        mock_redis.get.return_value = model_schema.model_dump_json().encode()
        assert (
            factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials={"k": "v"})
            == model_schema
        )
def test_get_model_schema_cache_invalid_json_deletes_key(
    factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture
) -> None:
    """A cached blob that fails schema validation is deleted and a warning is logged."""
    caplog.set_level(logging.WARNING)
    with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis:
        # Valid JSON but missing required AIModelEntity fields -> validation failure.
        mock_redis.get.return_value = b'{"model":"m"}'
        factory.plugin_model_manager.get_model_schema.return_value = None
        factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign]
        assert factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) is None
        assert mock_redis.delete.called
    assert any("Failed to validate cached plugin model schema" in r.message for r in caplog.records)
def test_get_model_schema_cache_delete_redis_error_is_logged(
    factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture
) -> None:
    """A RedisError while deleting an invalid cache entry is logged, not raised."""
    caplog.set_level(logging.WARNING)
    with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis:
        mock_redis.get.return_value = b'{"model":"m"}'
        mock_redis.delete.side_effect = RedisError("nope")
        factory.plugin_model_manager.get_model_schema.return_value = None
        factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign]
        factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None)
    assert any("Failed to delete invalid plugin model schema cache" in r.message for r in caplog.records)
def test_get_model_schema_redis_get_error_falls_back_to_plugin(
    factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture
) -> None:
    """A RedisError on cache read falls through to the plugin daemon and logs a warning."""
    caplog.set_level(logging.WARNING)
    factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign]
    factory.plugin_model_manager.get_model_schema.return_value = None
    with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis:
        mock_redis.get.side_effect = RedisError("down")
        assert factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) is None
    assert any("Failed to read plugin model schema cache" in r.message for r in caplog.records)
def test_get_model_schema_cache_miss_sets_cache_and_handles_setex_error(
    factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture
) -> None:
    """On cache miss the daemon result is returned even if writing it back to redis fails."""
    caplog.set_level(logging.WARNING)
    factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign]
    model_schema = AIModelEntity(
        model="m",
        label=I18nObject(en_US="m"),
        model_type=ModelType.LLM,
        fetch_from=FetchFrom.PREDEFINED_MODEL,
        model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
        parameter_rules=[],
    )
    factory.plugin_model_manager.get_model_schema.return_value = model_schema
    with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis:
        mock_redis.get.return_value = None
        mock_redis.setex.side_effect = RedisError("nope")
        assert (
            factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None)
            == model_schema
        )
    assert any("Failed to write plugin model schema cache" in r.message for r in caplog.records)
@pytest.mark.parametrize(
    ("model_type", "expected_class"),
    [
        (ModelType.LLM, "LargeLanguageModel"),
        (ModelType.TEXT_EMBEDDING, "TextEmbeddingModel"),
        (ModelType.RERANK, "RerankModel"),
        (ModelType.SPEECH2TEXT, "Speech2TextModel"),
        (ModelType.MODERATION, "ModerationModel"),
        (ModelType.TTS, "TTSModel"),
    ],
)
def test_get_model_type_instance_dispatches_by_type(
    factory: ModelProviderFactory, model_type: ModelType, expected_class: str, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Each ModelType dispatches to its matching model class in the factory module."""
    monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov"))
    monkeypatch.setattr(factory, "get_plugin_model_provider", lambda *_: MagicMock(spec=PluginModelProviderEntity))
    sentinel = object()
    # Patch the dispatched class's model_validate so a correct dispatch yields the sentinel.
    monkeypatch.setattr(
        f"dify_graph.model_runtime.model_providers.model_provider_factory.{expected_class}",
        MagicMock(model_validate=lambda _: sentinel),
    )
    assert factory.get_model_type_instance("langgenius/openai/openai", model_type) is sentinel
def test_get_model_type_instance_raises_on_unsupported(
    factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
) -> None:
    """An object that is not a known ModelType raises ValueError from the dispatch."""
    monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov"))
    monkeypatch.setattr(factory, "get_plugin_model_provider", lambda *_: MagicMock(spec=PluginModelProviderEntity))
    class UnknownModelType:
        pass
    with pytest.raises(ValueError, match="Unsupported model type"):
        factory.get_model_type_instance("langgenius/openai/openai", UnknownModelType()) # type: ignore[arg-type]
def test_get_models_filters_by_provider_and_model_type(
    factory: ModelProviderFactory, fake_plugin_manager: MagicMock
) -> None:
    """get_models narrows results by model type and by fully-qualified provider id."""
    llm = AIModelEntity(
        model="m1",
        label=I18nObject(en_US="m1"),
        model_type=ModelType.LLM,
        fetch_from=FetchFrom.PREDEFINED_MODEL,
        model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
        parameter_rules=[],
    )
    embed = AIModelEntity(
        model="e1",
        label=I18nObject(en_US="e1"),
        model_type=ModelType.TEXT_EMBEDDING,
        fetch_from=FetchFrom.PREDEFINED_MODEL,
        model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024},
        parameter_rules=[],
    )
    openai = _provider_entity(
        provider="openai", supported_model_types=[ModelType.LLM, ModelType.TEXT_EMBEDDING], models=[llm, embed]
    )
    anthropic = _provider_entity(provider="anthropic", supported_model_types=[ModelType.LLM], models=[llm])
    fake_plugin_manager.fetch_model_providers.return_value = [
        _plugin_provider(plugin_id="langgenius/openai", declaration=openai),
        _plugin_provider(plugin_id="langgenius/anthropic", declaration=anthropic),
    ]
    # ModelType filter picks only matching models
    providers = factory.get_models(model_type=ModelType.TEXT_EMBEDDING)
    assert len(providers) == 1
    assert providers[0].provider == "langgenius/openai/openai"
    assert [m.model for m in providers[0].models] == ["e1"]
    # Provider filter excludes others
    providers = factory.get_models(provider="langgenius/anthropic/anthropic", model_type=ModelType.LLM)
    assert len(providers) == 1
    assert providers[0].provider == "langgenius/anthropic/anthropic"
def test_get_models_provider_filter_skips_non_matching(
    factory: ModelProviderFactory, fake_plugin_manager: MagicMock
) -> None:
    """Filtering by a provider id that matches nothing yields an empty list."""
    fake_plugin_manager.fetch_model_providers.return_value = [
        _plugin_provider(plugin_id="langgenius/openai", declaration=_provider_entity(provider="openai")),
        _plugin_provider(plugin_id="langgenius/anthropic", declaration=_provider_entity(provider="anthropic")),
    ]
    result = factory.get_models(provider="langgenius/not-exist/not-exist", model_type=ModelType.LLM)
    assert result == []
def test_get_provider_icon_fetches_asset_and_returns_mime_type(
    factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Icon bytes come from the asset manager and the MIME type derives from the file extension."""
    provider_schema = _provider_entity(
        provider="langgenius/openai/openai",
        icon_small=I18nObject(en_US="icon.png", zh_Hans="icon-zh.png"),
        icon_small_dark=I18nObject(en_US="dark.svg", zh_Hans="dark-zh.svg"),
    )
    monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema)
    class FakePluginAssetManager:
        # Echo the requested asset id back so the test can assert which file was fetched.
        def fetch_asset(self, tenant_id: str, id: str) -> bytes:
            assert tenant_id == "tenant"
            return f"bytes:{id}".encode()
    import core.plugin.impl.asset as asset_module
    monkeypatch.setattr(asset_module, "PluginAssetManager", FakePluginAssetManager)
    data, mime = factory.get_provider_icon("openai", "icon_small", "en_US")
    assert data == b"bytes:icon.png"
    assert mime == "image/png"
    data, mime = factory.get_provider_icon("openai", "icon_small_dark", "zh_Hans")
    assert data == b"bytes:dark-zh.svg"
    assert mime == "image/svg+xml"
def test_get_provider_icon_uses_zh_hans_for_small_and_en_us_for_dark(
    factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
) -> None:
    """The requested language selects the matching localized icon file name."""
    provider_schema = _provider_entity(
        provider="langgenius/openai/openai",
        icon_small=I18nObject(en_US="icon-en.png", zh_Hans="icon-zh.png"),
        icon_small_dark=I18nObject(en_US="dark-en.svg", zh_Hans="dark-zh.svg"),
    )
    monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema)
    class FakePluginAssetManager:
        # Echo the asset id so the assertion reveals which localized name was resolved.
        def fetch_asset(self, tenant_id: str, id: str) -> bytes:
            return id.encode()
    import core.plugin.impl.asset as asset_module
    monkeypatch.setattr(asset_module, "PluginAssetManager", FakePluginAssetManager)
    data, _ = factory.get_provider_icon("openai", "icon_small", "zh_Hans")
    assert data == b"icon-zh.png"
    data, _ = factory.get_provider_icon("openai", "icon_small_dark", "en_US")
    assert data == b"dark-en.svg"
def test_get_provider_icon_raises_for_missing_icons(
    factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Requesting an icon the declaration does not define raises a descriptive ValueError."""
    provider_schema = _provider_entity(provider="langgenius/openai/openai")
    monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema)
    with pytest.raises(ValueError, match="does not have small icon"):
        factory.get_provider_icon("openai", "icon_small", "en_US")
    with pytest.raises(ValueError, match="does not have small dark icon"):
        factory.get_provider_icon("openai", "icon_small_dark", "en_US")
def test_get_provider_icon_raises_for_unsupported_icon_type(
    factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
) -> None:
    """An icon-type name other than the supported ones raises ValueError."""
    provider_schema = _provider_entity(
        provider="langgenius/openai/openai",
        icon_small=I18nObject(en_US="", zh_Hans=""),
    )
    monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema)
    with pytest.raises(ValueError, match="Unsupported icon type"):
        factory.get_provider_icon("openai", "nope", "en_US")
def test_get_provider_icon_raises_when_file_name_missing(
    factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch
) -> None:
    """An icon entry with an empty localized file name raises ValueError."""
    provider_schema = _provider_entity(
        provider="langgenius/openai/openai",
        icon_small=I18nObject(en_US="", zh_Hans=""),
    )
    monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema)
    with pytest.raises(ValueError, match="does not have icon"):
        factory.get_provider_icon("openai", "icon_small", "en_US")
def test_get_plugin_id_and_provider_name_from_provider_handles_google_special_case(
    factory: ModelProviderFactory,
) -> None:
    """The legacy "google" provider name maps to the langgenius/gemini plugin."""
    result = factory.get_plugin_id_and_provider_name_from_provider("google")
    assert result == ("langgenius/gemini", "google")

View File

@ -336,7 +336,6 @@ def test_structured_output_parser():
json_schema=case["json_schema"],
stream=case["stream"],
model_parameters={"temperature": 0.7, "max_tokens": 100},
user="test_user",
)
if case["expected_result_type"] == "generator":
@ -367,7 +366,7 @@ def test_structured_output_parser():
call_args = model_instance.invoke_llm.call_args
assert call_args.kwargs["stream"] == case["stream"]
assert call_args.kwargs["user"] == "test_user"
assert "user" not in call_args.kwargs
assert "temperature" in call_args.kwargs["model_parameters"]
assert "max_tokens" in call_args.kwargs["model_parameters"]