From fb113bf3a47ad1779e1ae973a6e559a2ce395ebf Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 16 Mar 2026 05:33:49 +0800 Subject: [PATCH] refactor: fix decoupled runtime CI regressions --- api/core/app/apps/chat/app_runner.py | 1 - api/core/app/apps/completion/app_runner.py | 1 - api/core/app/workflow/layers/llm_quota.py | 15 +- api/core/rag/embedding/cached_embedding.py | 2 +- api/core/tools/tool_manager.py | 16 +- api/core/workflow/node_runtime.py | 6 +- api/dify_graph/file/__init__.py | 3 + api/dify_graph/file/file_factory.py | 40 ++ api/dify_graph/nodes/runtime.py | 2 +- api/dify_graph/nodes/tool/tool_node.py | 8 +- ...rameters_cache_when_sync_draft_workflow.py | 6 +- api/factories/file_factory.py | 55 +- api/services/app_service.py | 18 +- .../console/datasets/test_datasets.py | 6 +- .../service_api/dataset/test_dataset.py | 8 +- .../base/test_app_generator_tts_publisher.py | 11 +- .../test_entities_provider_configuration.py | 22 +- .../rag/embedding/test_cached_embedding.py | 7 +- .../graph_engine/layers/test_llm_quota.py | 23 +- .../__base/test_moderation_model.py | 102 ++-- .../__base/test_rerank_model.py | 180 ++---- .../__base/test_speech2text_model.py | 99 ++-- .../__base/test_text_embedding_model.py | 243 ++++---- .../model_providers/__base/test_tts_model.py | 186 +++---- .../test_model_provider_factory.py | 522 ------------------ .../test_structured_output_parser.py | 3 +- 26 files changed, 428 insertions(+), 1157 deletions(-) create mode 100644 api/dify_graph/file/file_factory.py delete mode 100644 api/tests/unit_tests/dify_graph/model_runtime/model_providers/test_model_provider_factory.py diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index f63b38fc86..f8656aac02 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -223,7 +223,6 @@ class ChatAppRunner(AppRunner): model_parameters=application_generate_entity.model_conf.parameters, stop=stop, stream=application_generate_entity.stream, - user=application_generate_entity.user_id, ) # handle invoke result diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 56a4519879..a62a6ad0ab 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -181,7 +181,6 @@ class CompletionAppRunner(AppRunner): model_parameters=application_generate_entity.model_conf.parameters, stop=stop, stream=application_generate_entity.stream, - user=application_generate_entity.user_id, ) # handle invoke result diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index 706caf120e..4bbd229cbb 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -115,11 +115,11 @@ class LLMQuotaLayer(GraphEngineLayer): try: match node.node_type: case BuiltinNodeTypes.LLM: - return cast("LLMNode", node).model_instance + model_instance = cast("LLMNode", node).model_instance case BuiltinNodeTypes.PARAMETER_EXTRACTOR: - return cast("ParameterExtractorNode", node).model_instance + model_instance = cast("ParameterExtractorNode", node).model_instance case BuiltinNodeTypes.QUESTION_CLASSIFIER: - return cast("QuestionClassifierNode", node).model_instance + model_instance = cast("QuestionClassifierNode", node).model_instance case _: return None except AttributeError: @@ -128,3 +128,12 @@ class LLMQuotaLayer(GraphEngineLayer): node.id, ) return None + + if isinstance(model_instance, ModelInstance): + return model_instance + + raw_model_instance = getattr(model_instance, "_model_instance", None) + if isinstance(raw_model_instance, ModelInstance): + return raw_model_instance + + return None diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 3c706d800b..08f851ebd8 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -26,7 +26,7 @@ class CacheEmbedding(Embeddings): @staticmethod def _bind_model_instance(model_instance: ModelInstance, user: str | None) -> ModelInstance: - if user is None: + if user is None or not isinstance(model_instance, ModelInstance): return model_instance tenant_id = model_instance.provider_model_bundle.configuration.tenant_id diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 23a877b7e3..34a1d8aa5d 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -5,7 +5,7 @@ import time from collections.abc import Generator, Mapping from os import listdir, path from threading import Lock -from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast +from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, TypedDict, Union, cast import sqlalchemy as sa from sqlalchemy import select @@ -31,7 +31,7 @@ from services.enterprise.plugin_manager_service import PluginCredentialType from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: - from dify_graph.nodes.tool.entities import ToolEntity + pass from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom @@ -62,7 +62,7 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi from services.tools.tools_transform_service import ToolTransformService if TYPE_CHECKING: - from dify_graph.nodes.tool.entities import ToolEntity + pass logger = logging.getLogger(__name__) @@ -77,6 +77,14 @@ class EmojiIconDict(TypedDict): content: str +class WorkflowToolRuntimeSpec(Protocol): + provider_type: ToolProviderType + provider_id: str + tool_name: str + tool_configurations: dict[str, Any] + credential_id: str | None + + class ToolManager: _builtin_provider_lock = Lock() _hardcoded_providers: dict[str, BuiltinToolProviderController] = {} @@ -405,7 +413,7 @@ class ToolManager: tenant_id: str, app_id: str, node_id: str, - workflow_tool: "ToolEntity", + workflow_tool: WorkflowToolRuntimeSpec, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, variable_pool: Optional["VariablePool"] = None, ) -> Tool: diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py index 107e55993a..a07830f7e7 100644 --- a/api/core/workflow/node_runtime.py +++ b/api/core/workflow/node_runtime.py @@ -343,9 +343,9 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): *, provider_name: str, default_icon: str | None = None, - ) -> tuple[str | None, str | None]: - icon = default_icon - icon_dark = None + ) -> tuple[str | Mapping[str, str] | None, str | Mapping[str, str] | None]: + icon: str | Mapping[str, str] | None = default_icon + icon_dark: str | Mapping[str, str] | None = None manager = PluginInstaller() plugins = manager.list_plugins(self._run_context.tenant_id) diff --git a/api/dify_graph/file/__init__.py b/api/dify_graph/file/__init__.py index 44749ebec3..4908ae9795 100644 --- a/api/dify_graph/file/__init__.py +++ b/api/dify_graph/file/__init__.py @@ -1,5 +1,6 @@ from .constants import FILE_MODEL_IDENTITY from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType +from .file_factory import get_file_type_by_mime_type, standardize_file_type from .models import ( File, FileUploadConfig, @@ -16,4 +17,6 @@ __all__ = [ "FileType", "FileUploadConfig", "ImageConfig", + "get_file_type_by_mime_type", + "standardize_file_type", ] diff --git a/api/dify_graph/file/file_factory.py b/api/dify_graph/file/file_factory.py new file mode 100644 index 0000000000..b8e48e36fc --- /dev/null +++ b/api/dify_graph/file/file_factory.py @@ -0,0 +1,40 @@ +from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS + +from .enums import FileType + + +def standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType: + """ + Infer the actual file type from extension and mime type. + """ + guessed_type = None + if extension: + guessed_type = _get_file_type_by_extension(extension) + if guessed_type is None and mime_type: + guessed_type = get_file_type_by_mime_type(mime_type) + return guessed_type or FileType.CUSTOM + + +def _get_file_type_by_extension(extension: str) -> FileType | None: + normalized_extension = extension.lstrip(".") + if normalized_extension in IMAGE_EXTENSIONS: + return FileType.IMAGE + if normalized_extension in VIDEO_EXTENSIONS: + return FileType.VIDEO + if normalized_extension in AUDIO_EXTENSIONS: + return FileType.AUDIO + if normalized_extension in DOCUMENT_EXTENSIONS: + return FileType.DOCUMENT + return None + + +def get_file_type_by_mime_type(mime_type: str) -> FileType: + if "image" in mime_type: + return FileType.IMAGE + if "video" in mime_type: + return FileType.VIDEO + if "audio" in mime_type: + return FileType.AUDIO + if "text" in mime_type or "pdf" in mime_type: + return FileType.DOCUMENT + return FileType.CUSTOM diff --git a/api/dify_graph/nodes/runtime.py b/api/dify_graph/nodes/runtime.py index 01f3398807..fc1ee84c6d 100644 --- a/api/dify_graph/nodes/runtime.py +++ b/api/dify_graph/nodes/runtime.py @@ -60,7 +60,7 @@ class ToolNodeRuntimeProtocol(Protocol): *, provider_name: str, default_icon: str | None = None, - ) -> tuple[str | None, str | None]: ... + ) -> tuple[str | Mapping[str, str] | None, str | Mapping[str, str] | None]: ... class HumanInputNodeRuntimeProtocol(Protocol): diff --git a/api/dify_graph/nodes/tool/tool_node.py b/api/dify_graph/nodes/tool/tool_node.py index dc6aedac19..567cfd1692 100644 --- a/api/dify_graph/nodes/tool/tool_node.py +++ b/api/dify_graph/nodes/tool/tool_node.py @@ -8,7 +8,7 @@ from dify_graph.enums import ( WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.file import File, FileTransferMethod +from dify_graph.file import File, FileTransferMethod, get_file_type_by_mime_type from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent from dify_graph.nodes.base.node import Node @@ -252,9 +252,9 @@ class ToolNode(Node[ToolNodeData]): if not tool_file: raise ToolFileError(f"tool file {tool_file_id} not found") - mapping = { + mapping: dict[str, Any] = { "tool_file_id": tool_file_id, - "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), + "type": get_file_type_by_mime_type(tool_file.mimetype), "transfer_method": transfer_method, "url": url, } @@ -270,7 +270,7 @@ class ToolNode(Node[ToolNodeData]): if not tool_file: raise ToolFileError(f"tool file {tool_file_id} not exists") - mapping = { + mapping: dict[str, Any] = { "tool_file_id": tool_file_id, "transfer_method": FileTransferMethod.TOOL_FILE, } diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index c43e99f0f4..bd5bc08bff 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -1,5 +1,6 @@ import logging +from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from dify_graph.nodes import BuiltinNodeTypes @@ -19,8 +20,9 @@ def handle(sender, **kwargs): if node_data.get("data", {}).get("type") == BuiltinNodeTypes.TOOL: try: tool_entity = ToolEntity.model_validate(node_data["data"]) + provider_type = ToolProviderType(tool_entity.provider_type.value) tool_runtime = ToolManager.get_tool_runtime( - provider_type=tool_entity.provider_type, + provider_type=provider_type, provider_id=tool_entity.provider_id, tool_name=tool_entity.tool_name, tenant_id=app.tenant_id, @@ -30,7 +32,7 @@ def handle(sender, **kwargs): tenant_id=app.tenant_id, tool_runtime=tool_runtime, provider_name=tool_entity.provider_name, - provider_type=tool_entity.provider_type, + provider_type=provider_type, identity_id=f"WORKFLOW.{app.id}.{node_data.get('id')}", ) manager.delete_tool_parameters_cache() diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index cb07ba58ae..187d6d0bb7 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -12,9 +12,9 @@ from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.http import parse_options_header -from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS from core.helper import ssrf_proxy from dify_graph.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers +from dify_graph.file.file_factory import standardize_file_type from extensions.ext_database import db from models import MessageFile, ToolFile, UploadFile @@ -175,7 +175,7 @@ def _build_from_local_file( if row is None: raise ValueError("Invalid upload file") - detected_file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) + detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) specified_type = mapping.get("type", "custom") if strict_type_validation and detected_file_type.value != specified_type: @@ -223,7 +223,7 @@ def _build_from_remote_url( if upload_file is None: raise ValueError("Invalid upload file") - detected_file_type = _standardize_file_type( + detected_file_type = standardize_file_type( extension="." + upload_file.extension, mime_type=upload_file.mime_type ) @@ -257,7 +257,7 @@ def _build_from_remote_url( mime_type, filename, file_size = _get_remote_file_info(url) extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin") - detected_file_type = _standardize_file_type(extension=extension, mime_type=mime_type) + detected_file_type = standardize_file_type(extension=extension, mime_type=mime_type) specified_type = mapping.get("type") if strict_type_validation and specified_type and detected_file_type.value != specified_type: @@ -387,7 +387,7 @@ def _build_from_tool_file( extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" - detected_file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype) + detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype) specified_type = mapping.get("type") @@ -436,7 +436,7 @@ def _build_from_datasource_file( extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" - detected_file_type = _standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) + detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) specified_type = mapping.get("type") @@ -503,49 +503,6 @@ def _is_file_valid_with_config( return True -def _standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType: - """ - Infer the possible actual type of the file based on the extension and mime_type - """ - guessed_type = None - if extension: - guessed_type = _get_file_type_by_extension(extension) - if guessed_type is None and mime_type: - guessed_type = _get_file_type_by_mimetype(mime_type) - return guessed_type or FileType.CUSTOM - - -def _get_file_type_by_extension(extension: str) -> FileType | None: - extension = extension.lstrip(".") - if extension in IMAGE_EXTENSIONS: - return FileType.IMAGE - elif extension in VIDEO_EXTENSIONS: - return FileType.VIDEO - elif extension in AUDIO_EXTENSIONS: - return FileType.AUDIO - elif extension in DOCUMENT_EXTENSIONS: - return FileType.DOCUMENT - return None - - -def _get_file_type_by_mimetype(mime_type: str) -> FileType | None: - if "image" in mime_type: - file_type = FileType.IMAGE - elif "video" in mime_type: - file_type = FileType.VIDEO - elif "audio" in mime_type: - file_type = FileType.AUDIO - elif "text" in mime_type or "pdf" in mime_type: - file_type = FileType.DOCUMENT - else: - file_type = FileType.CUSTOM - return file_type - - -def get_file_type_by_mime_type(mime_type: str) -> FileType: - return _get_file_type_by_mimetype(mime_type) or FileType.CUSTOM - - class StorageKeyLoader: """FileKeyLoader load the storage key from database for a list of files. This loader is batched, the database query count is constant regardless of the input size. diff --git a/api/services/app_service.py b/api/services/app_service.py index dd38942180..5c14477f38 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -124,11 +124,19 @@ class AppService: "completion_params": {}, } else: - provider, model = model_manager.get_default_provider_model_name( - tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM - ) - default_model_config["model"]["provider"] = provider - default_model_config["model"]["name"] = model + try: + provider, model = model_manager.get_default_provider_model_name( + tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM + ) + except Exception: + logger.exception("Get default provider model failed, tenant_id: %s", tenant_id) + provider = default_model_config["model"].get("provider") + model = default_model_config["model"].get("name") + + if provider: + default_model_config["model"]["provider"] = provider + if model: + default_model_config["model"]["name"] = model default_model_dict = default_model_config["model"] default_model_config["model"] = json.dumps(default_model_dict) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index 0ee76e504b..0313764b24 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -416,7 +416,7 @@ class TestDatasetApiGet: "check_dataset_permission", return_value=None, ), - patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): # embedding models exist → embedding_available stays True provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] @@ -520,7 +520,7 @@ class TestDatasetApiGet: "check_dataset_permission", return_value=None, ), - patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): # embedding model NOT configured provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] @@ -579,7 +579,7 @@ class TestDatasetApiGet: "get_dataset_partial_member_list", return_value=partial_members, ), - patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py index 7cb2f1050c..3001a46081 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py @@ -941,11 +941,11 @@ class TestDatasetListApiGet: """Test suite for DatasetListApi.get() endpoint. ``get`` has no billing decorators but calls ``current_user``, - ``DatasetService``, ``ProviderManager``, and ``marshal``. + ``DatasetService``, ``create_plugin_provider_manager``, and ``marshal``. """ @patch("controllers.service_api.dataset.dataset.marshal") - @patch("controllers.service_api.dataset.dataset.ProviderManager") + @patch("controllers.service_api.dataset.dataset.create_plugin_provider_manager") @patch("controllers.service_api.dataset.dataset.current_user") @patch("controllers.service_api.dataset.dataset.DatasetService") def test_list_datasets_success( @@ -1043,12 +1043,12 @@ class TestDatasetApiGet: """Test suite for DatasetApi.get() endpoint. ``get`` has no billing decorators but calls ``DatasetService``, - ``ProviderManager``, ``marshal``, and ``current_user``. + ``create_plugin_provider_manager``, ``marshal``, and ``current_user``. """ @patch("controllers.service_api.dataset.dataset.DatasetPermissionService") @patch("controllers.service_api.dataset.dataset.marshal") - @patch("controllers.service_api.dataset.dataset.ProviderManager") + @patch("controllers.service_api.dataset.dataset.create_plugin_provider_manager") @patch("controllers.service_api.dataset.dataset.current_user") @patch("controllers.service_api.dataset.dataset.DatasetService") def test_get_dataset_success( diff --git a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py index 3759b6aa37..eb1599bacc 100644 --- a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py +++ b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py @@ -28,10 +28,7 @@ def mock_model_instance(mocker): def mock_model_manager(mocker, mock_model_instance): manager = mocker.MagicMock() manager.get_default_model_instance.return_value = mock_model_instance - mocker.patch( - "core.base.tts.app_generator_tts_publisher.ModelManager", - return_value=manager, - ) + mocker.patch("core.base.tts.app_generator_tts_publisher.ModelManager.for_tenant", return_value=manager) return manager @@ -64,16 +61,14 @@ class TestInvoiceTTS: [None, "", " "], ) def test_invoice_tts_empty_or_none_returns_none(self, text, mock_model_instance): - result = _invoice_tts(text, mock_model_instance, "tenant", "voice1") + result = _invoice_tts(text, mock_model_instance, "voice1") assert result is None mock_model_instance.invoke_tts.assert_not_called() def test_invoice_tts_valid_text(self, mock_model_instance): - result = _invoice_tts(" hello ", mock_model_instance, "tenant", "voice1") + result = _invoice_tts(" hello ", mock_model_instance, "voice1") mock_model_instance.invoke_tts.assert_called_once_with( content_text="hello", - user="responding_tts", - tenant_id="tenant", voice="voice1", ) assert result == [b"audio1", b"audio2"] diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index 95d58757f1..37d8853a73 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -350,7 +350,7 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): with patch( "core.entities.provider_configuration.encrypter.encrypt_token", @@ -380,7 +380,9 @@ def test_validate_provider_credentials_opens_session_when_not_passed() -> None: with patch("core.entities.provider_configuration.db") as mock_db: mock_db.engine = Mock() mock_session_cls.return_value.__enter__.return_value = mock_session - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): validated = configuration.validate_provider_credentials(credentials={"region": "us"}) assert validated == {"region": "us"} @@ -434,7 +436,7 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None: mock_factory.get_model_type_instance.return_value = mock_model_type_instance mock_factory.get_model_schema.return_value = mock_schema - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): model_type_instance = configuration.get_model_type_instance(ModelType.LLM) model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) @@ -475,7 +477,7 @@ def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> N mock_factory = Mock() mock_factory.get_provider_schema.return_value = provider_schema - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): all_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=False) active_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=True) @@ -689,7 +691,7 @@ def test_validate_provider_credentials_handles_invalid_original_json() -> None: mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): validated = configuration.validate_provider_credentials( credentials={"openai_api_key": HIDDEN_VALUE}, @@ -1034,7 +1036,7 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): validated = configuration.validate_custom_model_credentials( @@ -1050,7 +1052,9 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"region": "us"} with _patched_session(session): - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): validated = configuration.validate_custom_model_credentials( model_type=ModelType.LLM, model="gpt-4o", @@ -1540,7 +1544,7 @@ def test_validate_provider_credentials_uses_empty_original_when_record_missing() mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): validated = configuration.validate_provider_credentials( credentials={"openai_api_key": HIDDEN_VALUE}, @@ -1662,7 +1666,7 @@ def test_validate_custom_model_credentials_handles_invalid_original_json() -> No mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): validated = configuration.validate_custom_model_credentials( model_type=ModelType.LLM, diff --git a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py index a0db25174d..d7a50b8603 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py +++ b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py @@ -28,6 +28,7 @@ class TestCacheEmbeddingMultimodalDocuments: """Create a mock ModelInstance for testing.""" model_instance = Mock() model_instance.model = "vision-embedding-model" + model_instance.model_name = "vision-embedding-model" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} @@ -316,6 +317,7 @@ class TestCacheEmbeddingMultimodalQuery: """Create a mock ModelInstance for testing.""" model_instance = Mock() model_instance.model = "vision-embedding-model" + model_instance.model_name = "vision-embedding-model" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} return model_instance @@ -467,6 +469,7 @@ class TestCacheEmbeddingQueryErrors: """Create a mock ModelInstance for testing.""" model_instance = Mock() model_instance.model = "text-embedding-ada-002" + model_instance.model_name = "text-embedding-ada-002" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} return model_instance @@ -536,20 +539,20 @@ class TestCacheEmbeddingInitialization: """Test CacheEmbedding initialization with user parameter.""" model_instance = Mock() model_instance.model = "test-model" + model_instance.model_name = "test-model" model_instance.provider = "test-provider" cache_embedding = CacheEmbedding(model_instance, user="test-user") assert cache_embedding._model_instance == model_instance - assert cache_embedding._user == "test-user" def test_initialization_without_user(self): """Test CacheEmbedding initialization without user parameter.""" model_instance = Mock() model_instance.model = "test-model" + model_instance.model_name = "test-model" model_instance.provider = "test-provider" cache_embedding = CacheEmbedding(model_instance) assert cache_embedding._model_instance == model_instance - assert cache_embedding._user is None diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index db9b85640a..4dba88932f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -1,10 +1,12 @@ import threading from datetime import datetime +from types import SimpleNamespace from unittest.mock import MagicMock, patch from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.errors.error import QuotaExceededError +from core.model_manager import ModelInstance from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.graph_engine.entities.commands import CommandType from dify_graph.graph_events.node import NodeRunSucceededEvent @@ -36,6 +38,11 @@ def _build_succeeded_event() -> NodeRunSucceededEvent: ) +def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]: + raw_model_instance = ModelInstance.__new__(ModelInstance) + return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance + + def test_deduct_quota_called_for_successful_llm_node() -> None: layer = LLMQuotaLayer() node = MagicMock() @@ -44,7 +51,7 @@ def test_deduct_quota_called_for_successful_llm_node() -> None: node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" node.require_run_context_value.return_value = _build_dify_context() - node.model_instance = object() + node.model_instance, raw_model_instance = _build_wrapped_model_instance() result_event = _build_succeeded_event() with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: @@ -52,7 +59,7 @@ def test_deduct_quota_called_for_successful_llm_node() -> None: mock_deduct.assert_called_once_with( tenant_id="tenant-id", - model_instance=node.model_instance, + model_instance=raw_model_instance, usage=result_event.node_run_result.llm_usage, ) @@ -65,7 +72,7 @@ def test_deduct_quota_called_for_question_classifier_node() -> None: node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER node.tenant_id = "tenant-id" node.require_run_context_value.return_value = _build_dify_context() - node.model_instance = object() + node.model_instance, raw_model_instance = _build_wrapped_model_instance() result_event = _build_succeeded_event() with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: @@ -73,7 +80,7 @@ def test_deduct_quota_called_for_question_classifier_node() -> None: mock_deduct.assert_called_once_with( tenant_id="tenant-id", - model_instance=node.model_instance, + model_instance=raw_model_instance, usage=result_event.node_run_result.llm_usage, ) @@ -125,7 +132,7 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None: node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" node.require_run_context_value.return_value = _build_dify_context() - node.model_instance = object() + node.model_instance, _ = _build_wrapped_model_instance() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -152,7 +159,7 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None: node = MagicMock() node.id = "llm-node-id" node.node_type = BuiltinNodeTypes.LLM - node.model_instance = object() + node.model_instance, _ = _build_wrapped_model_instance() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -178,7 +185,7 @@ def test_quota_precheck_passes_without_abort() -> None: node = MagicMock() node.id = "llm-node-id" node.node_type = BuiltinNodeTypes.LLM - node.model_instance = object() + node.model_instance, raw_model_instance = _build_wrapped_model_instance() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -186,5 +193,5 @@ def test_quota_precheck_passes_without_abort() -> None: layer.on_node_run_start(node) assert not stop_event.is_set() - mock_check.assert_called_once_with(model_instance=node.model_instance) + mock_check.assert_called_once_with(model_instance=raw_model_instance) layer.command_channel.send_command.assert_not_called() diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_moderation_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_moderation_model.py index 6ccc44ceb8..3548c25a01 100644 --- a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_moderation_model.py +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_moderation_model.py @@ -2,89 +2,55 @@ from unittest.mock import MagicMock, patch import pytest -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.common_entities import I18nObject from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from dify_graph.model_runtime.errors.invoke import InvokeError from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel -class TestModerationModel: - @pytest.fixture - def mock_plugin_model_provider(self): - return MagicMock(spec=PluginModelProviderEntity) +@pytest.fixture +def provider_schema() -> ProviderEntity: + return ProviderEntity( + provider="test_provider", + label=I18nObject(en_US="test_provider"), + supported_model_types=[ModelType.MODERATION], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) - @pytest.fixture - def moderation_model(self, mock_plugin_model_provider): - return ModerationModel( - tenant_id="tenant_123", - model_type=ModelType.MODERATION, - plugin_id="plugin_123", - provider_name="test_provider", - plugin_model_provider=mock_plugin_model_provider, - ) - def test_model_type(self, moderation_model): - assert moderation_model.model_type == ModelType.MODERATION +@pytest.fixture +def model_runtime() -> MagicMock: + return MagicMock() - def test_invoke_success(self, moderation_model): - model_name = "test_model" - credentials = {"api_key": "abc"} - text = "test text" - user = "user_123" - with ( - patch("core.plugin.impl.model.PluginModelClient") as mock_client_class, - patch("time.perf_counter", return_value=1.0), - ): - mock_client = mock_client_class.return_value - mock_client.invoke_moderation.return_value = True +@pytest.fixture +def moderation_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> ModerationModel: + return ModerationModel(provider_schema=provider_schema, model_runtime=model_runtime) - result = moderation_model.invoke(model=model_name, credentials=credentials, text=text, user=user) - assert result is True - assert moderation_model.started_at == 1.0 - mock_client.invoke_moderation.assert_called_once_with( - tenant_id="tenant_123", - user_id="user_123", - plugin_id="plugin_123", - provider="test_provider", - model=model_name, - credentials=credentials, - text=text, - ) +def test_model_type(moderation_model: ModerationModel) -> None: + assert moderation_model.model_type == ModelType.MODERATION - def test_invoke_success_no_user(self, moderation_model): - model_name = "test_model" - credentials = {"api_key": "abc"} - text = "test text" - with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: - mock_client = mock_client_class.return_value - mock_client.invoke_moderation.return_value = False +def test_invoke_success(moderation_model: ModerationModel, model_runtime: MagicMock) -> None: + with patch("time.perf_counter", return_value=1.0): + model_runtime.invoke_moderation.return_value = True - result = moderation_model.invoke(model=model_name, credentials=credentials, text=text) + result = moderation_model.invoke(model="test_model", credentials={"api_key": "abc"}, text="test text") - assert result is False - mock_client.invoke_moderation.assert_called_once_with( - tenant_id="tenant_123", - user_id="unknown", - plugin_id="plugin_123", - provider="test_provider", - model=model_name, - credentials=credentials, - text=text, - ) + assert result is True + assert moderation_model.started_at == 1.0 + model_runtime.invoke_moderation.assert_called_once_with( + provider="test_provider", + model="test_model", + credentials={"api_key": "abc"}, + text="test text", + ) - def test_invoke_exception(self, moderation_model): - model_name = "test_model" - credentials = {"api_key": "abc"} - text = "test text" - with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: - mock_client = mock_client_class.return_value - mock_client.invoke_moderation.side_effect = Exception("Test error") +def test_invoke_exception(moderation_model: ModerationModel, model_runtime: MagicMock) -> None: + model_runtime.invoke_moderation.side_effect = Exception("Test error") - with pytest.raises(InvokeError) as excinfo: - moderation_model.invoke(model=model_name, credentials=credentials, text=text) - - assert "[test_provider] Error: Test error" in str(excinfo.value.description) + with pytest.raises(InvokeError, match="Test error"): + moderation_model.invoke(model="test_model", credentials={"api_key": "abc"}, text="test text") diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_rerank_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_rerank_model.py index 67828894b3..5528b0e72f 100644 --- a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_rerank_model.py +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_rerank_model.py @@ -1,71 +1,42 @@ -from datetime import datetime -from typing import Any from unittest.mock import MagicMock import pytest -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.common_entities import I18nObject from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from dify_graph.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from dify_graph.model_runtime.errors.invoke import InvokeError from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel @pytest.fixture -def rerank_model() -> RerankModel: - plugin_provider = PluginModelProviderEntity.model_construct( - id="provider-id", - created_at=datetime.now(), - updated_at=datetime.now(), - provider="provider", - tenant_id="tenant", - plugin_unique_identifier="plugin-uid", - plugin_id="plugin-id", - declaration=MagicMock(), - ) - return RerankModel.model_construct( - tenant_id="tenant", - model_type=ModelType.RERANK, - plugin_id="plugin-id", - provider_name="provider", - plugin_model_provider=plugin_provider, +def provider_schema() -> ProviderEntity: + return ProviderEntity( + provider="test_provider", + label=I18nObject(en_US="test_provider"), + supported_model_types=[ModelType.RERANK], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], ) -def test_model_type_is_rerank_by_default() -> None: - plugin_provider = PluginModelProviderEntity.model_construct( - id="provider-id", - created_at=datetime.now(), - updated_at=datetime.now(), - provider="provider", - tenant_id="tenant", - plugin_unique_identifier="plugin-uid", - plugin_id="plugin-id", - declaration=MagicMock(), - ) - model = RerankModel( - tenant_id="tenant", - plugin_id="plugin-id", - provider_name="provider", - plugin_model_provider=plugin_provider, - ) - assert model.model_type == ModelType.RERANK +@pytest.fixture +def model_runtime() -> MagicMock: + return MagicMock() -def test_invoke_calls_plugin_and_passes_args(rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch) -> None: +@pytest.fixture +def rerank_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> RerankModel: + return RerankModel(provider_schema=provider_schema, model_runtime=model_runtime) + + +def test_model_type_is_rerank_by_default(rerank_model: RerankModel) -> None: + assert rerank_model.model_type == ModelType.RERANK + + +def test_invoke_calls_runtime_and_passes_args(rerank_model: RerankModel, model_runtime: MagicMock) -> None: expected = RerankResult(model="rerank", docs=[RerankDocument(index=0, text="a", score=0.5)]) - - class FakePluginModelClient: - def __init__(self) -> None: - self.invoke_rerank_called_with: dict[str, Any] | None = None - - def invoke_rerank(self, **kwargs: Any) -> RerankResult: - self.invoke_rerank_called_with = kwargs - return expected - - import core.plugin.impl.model as plugin_model_module - - fake_client = FakePluginModelClient() - monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client) + model_runtime.invoke_rerank.return_value = expected result = rerank_model.invoke( model="rerank", @@ -74,76 +45,30 @@ def test_invoke_calls_plugin_and_passes_args(rerank_model: RerankModel, monkeypa docs=["d1", "d2"], score_threshold=0.2, top_n=10, - user="user-1", ) assert result == expected - assert fake_client.invoke_rerank_called_with == { - "tenant_id": "tenant", - "user_id": "user-1", - "plugin_id": "plugin-id", - "provider": "provider", - "model": "rerank", - "credentials": {"k": "v"}, - "query": "q", - "docs": ["d1", "d2"], - "score_threshold": 0.2, - "top_n": 10, - } + model_runtime.invoke_rerank.assert_called_once_with( + provider="test_provider", + model="rerank", + credentials={"k": "v"}, + query="q", + docs=["d1", "d2"], + score_threshold=0.2, + top_n=10, + ) -def test_invoke_uses_unknown_user_when_not_provided(rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch) -> None: - class FakePluginModelClient: - def __init__(self) -> None: - self.kwargs: dict[str, Any] | None = None +def test_invoke_transforms_and_raises_on_runtime_error(rerank_model: RerankModel, model_runtime: MagicMock) -> None: + model_runtime.invoke_rerank.side_effect = Exception("runtime down") - def invoke_rerank(self, **kwargs: Any) -> RerankResult: - self.kwargs = kwargs - return RerankResult(model="m", docs=[]) - - import core.plugin.impl.model as plugin_model_module - - fake_client = FakePluginModelClient() - monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client) - - rerank_model.invoke(model="m", credentials={}, query="q", docs=["d"]) - assert fake_client.kwargs is not None - assert fake_client.kwargs["user_id"] == "unknown" - - -def test_invoke_transforms_and_raises_on_plugin_error( - rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch -) -> None: - class FakePluginModelClient: - def invoke_rerank(self, **_: Any) -> RerankResult: - raise ValueError("plugin down") - - import core.plugin.impl.model as plugin_model_module - - monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient) - monkeypatch.setattr(rerank_model, "_transform_invoke_error", lambda e: RuntimeError(f"transformed: {e}")) - - with pytest.raises(RuntimeError, match="transformed: plugin down"): + with pytest.raises(InvokeError, match="runtime down"): rerank_model.invoke(model="m", credentials={}, query="q", docs=["d"]) -def test_invoke_multimodal_calls_plugin_and_passes_args( - rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch -) -> None: +def test_invoke_multimodal_calls_runtime_and_passes_args(rerank_model: RerankModel, model_runtime: MagicMock) -> None: expected = RerankResult(model="mm", docs=[RerankDocument(index=0, text="x", score=0.9)]) - - class FakePluginModelClient: - def __init__(self) -> None: - self.invoke_multimodal_rerank_called_with: dict[str, Any] | None = None - - def invoke_multimodal_rerank(self, **kwargs: Any) -> RerankResult: - self.invoke_multimodal_rerank_called_with = kwargs - return expected - - import core.plugin.impl.model as plugin_model_module - - fake_client = FakePluginModelClient() - monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client) + model_runtime.invoke_multimodal_rerank.return_value = expected query = {"type": "text", "text": "q"} docs = [{"type": "text", "text": "d1"}] @@ -154,28 +79,15 @@ def test_invoke_multimodal_calls_plugin_and_passes_args( docs=docs, score_threshold=None, top_n=None, - user=None, ) assert result == expected - assert fake_client.invoke_multimodal_rerank_called_with is not None - assert fake_client.invoke_multimodal_rerank_called_with["tenant_id"] == "tenant" - assert fake_client.invoke_multimodal_rerank_called_with["user_id"] == "unknown" - assert fake_client.invoke_multimodal_rerank_called_with["query"] == query - assert fake_client.invoke_multimodal_rerank_called_with["docs"] == docs - - -def test_invoke_multimodal_transforms_and_raises_on_plugin_error( - rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch -) -> None: - class FakePluginModelClient: - def invoke_multimodal_rerank(self, **_: Any) -> RerankResult: - raise ValueError("plugin down") - - import core.plugin.impl.model as plugin_model_module - - monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient) - monkeypatch.setattr(rerank_model, "_transform_invoke_error", lambda e: RuntimeError(f"transformed: {e}")) - - with pytest.raises(RuntimeError, match="transformed: plugin down"): - rerank_model.invoke_multimodal_rerank(model="m", credentials={}, query={"q": 1}, docs=[{"d": 1}]) + model_runtime.invoke_multimodal_rerank.assert_called_once_with( + provider="test_provider", + model="mm", + credentials={"k": "v"}, + query=query, + docs=docs, + score_threshold=None, + top_n=None, + ) diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_speech2text_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_speech2text_model.py index f891718dc6..f1c4a0d523 100644 --- a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_speech2text_model.py +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_speech2text_model.py @@ -1,87 +1,56 @@ from io import BytesIO -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.common_entities import I18nObject from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from dify_graph.model_runtime.errors.invoke import InvokeError from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -class TestSpeech2TextModel: - @pytest.fixture - def mock_plugin_model_provider(self): - return MagicMock(spec=PluginModelProviderEntity) +@pytest.fixture +def provider_schema() -> ProviderEntity: + return ProviderEntity( + provider="test_provider", + label=I18nObject(en_US="test_provider"), + supported_model_types=[ModelType.SPEECH2TEXT], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) - @pytest.fixture - def speech2text_model(self, mock_plugin_model_provider): - return Speech2TextModel( - tenant_id="tenant_123", - model_type=ModelType.SPEECH2TEXT, - plugin_id="plugin_123", - provider_name="test_provider", - plugin_model_provider=mock_plugin_model_provider, - ) - def test_model_type(self, speech2text_model): - assert speech2text_model.model_type == ModelType.SPEECH2TEXT +@pytest.fixture +def model_runtime() -> MagicMock: + return MagicMock() - def test_invoke_success(self, speech2text_model): - model_name = "test_model" - credentials = {"api_key": "abc"} - file = BytesIO(b"audio data") - user = "user_123" - with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: - mock_client = mock_client_class.return_value - mock_client.invoke_speech_to_text.return_value = "transcribed text" +@pytest.fixture +def speech2text_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> Speech2TextModel: + return Speech2TextModel(provider_schema=provider_schema, model_runtime=model_runtime) - result = speech2text_model.invoke(model=model_name, credentials=credentials, file=file, user=user) - assert result == "transcribed text" - mock_client.invoke_speech_to_text.assert_called_once_with( - tenant_id="tenant_123", - user_id="user_123", - plugin_id="plugin_123", - provider="test_provider", - model=model_name, - credentials=credentials, - file=file, - ) +def test_model_type(speech2text_model: Speech2TextModel) -> None: + assert speech2text_model.model_type == ModelType.SPEECH2TEXT - def test_invoke_success_no_user(self, speech2text_model): - model_name = "test_model" - credentials = {"api_key": "abc"} - file = BytesIO(b"audio data") - with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: - mock_client = mock_client_class.return_value - mock_client.invoke_speech_to_text.return_value = "transcribed text" +def test_invoke_success(speech2text_model: Speech2TextModel, model_runtime: MagicMock) -> None: + file = BytesIO(b"audio data") + model_runtime.invoke_speech_to_text.return_value = "transcribed text" - result = speech2text_model.invoke(model=model_name, credentials=credentials, file=file) + result = speech2text_model.invoke(model="test_model", credentials={"api_key": "abc"}, file=file) - assert result == "transcribed text" - mock_client.invoke_speech_to_text.assert_called_once_with( - tenant_id="tenant_123", - user_id="unknown", - plugin_id="plugin_123", - provider="test_provider", - model=model_name, - credentials=credentials, - file=file, - ) + assert result == "transcribed text" + model_runtime.invoke_speech_to_text.assert_called_once_with( + provider="test_provider", + model="test_model", + credentials={"api_key": "abc"}, + file=file, + ) - def test_invoke_exception(self, speech2text_model): - model_name = "test_model" - credentials = {"api_key": "abc"} - file = BytesIO(b"audio data") - with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: - mock_client = mock_client_class.return_value - mock_client.invoke_speech_to_text.side_effect = Exception("Test error") +def test_invoke_exception(speech2text_model: Speech2TextModel, model_runtime: MagicMock) -> None: + model_runtime.invoke_speech_to_text.side_effect = Exception("Test error") - with pytest.raises(InvokeError) as excinfo: - speech2text_model.invoke(model=model_name, credentials=credentials, file=file) - - assert "[test_provider] Error: Test error" in str(excinfo.value.description) + with pytest.raises(InvokeError, match="Test error"): + speech2text_model.invoke(model="test_model", credentials={"api_key": "abc"}, file=BytesIO(b"audio data")) diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_text_embedding_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_text_embedding_model.py index c8f0a2ad49..74be13d6ec 100644 --- a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_text_embedding_model.py +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_text_embedding_model.py @@ -2,184 +2,145 @@ from unittest.mock import MagicMock, patch import pytest -from core.entities.embedding_type import EmbeddingInputType -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.common_entities import I18nObject from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult +from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult from dify_graph.model_runtime.errors.invoke import InvokeError from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -class TestTextEmbeddingModel: - @pytest.fixture - def mock_plugin_model_provider(self): - return MagicMock(spec=PluginModelProviderEntity) +@pytest.fixture +def provider_schema() -> ProviderEntity: + return ProviderEntity( + provider="test_provider", + label=I18nObject(en_US="test_provider"), + supported_model_types=[ModelType.TEXT_EMBEDDING], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) - @pytest.fixture - def text_embedding_model(self, mock_plugin_model_provider): - return TextEmbeddingModel( - tenant_id="tenant_123", - model_type=ModelType.TEXT_EMBEDDING, - plugin_id="plugin_123", - provider_name="test_provider", - plugin_model_provider=mock_plugin_model_provider, - ) - def test_model_type(self, text_embedding_model): - assert text_embedding_model.model_type == ModelType.TEXT_EMBEDDING +@pytest.fixture +def model_runtime() -> MagicMock: + return MagicMock() - def test_invoke_with_texts(self, text_embedding_model): - model_name = "test_model" - credentials = {"api_key": "abc"} - texts = ["hello", "world"] - user = "user_123" - expected_result = MagicMock(spec=EmbeddingResult) - with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: - mock_client = mock_client_class.return_value - mock_client.invoke_text_embedding.return_value = expected_result +@pytest.fixture +def text_embedding_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> TextEmbeddingModel: + return TextEmbeddingModel(provider_schema=provider_schema, model_runtime=model_runtime) - result = text_embedding_model.invoke(model=model_name, credentials=credentials, texts=texts, user=user) - assert result == expected_result - mock_client.invoke_text_embedding.assert_called_once_with( - tenant_id="tenant_123", - user_id="user_123", - plugin_id="plugin_123", - provider="test_provider", - model=model_name, - credentials=credentials, - texts=texts, - input_type=EmbeddingInputType.DOCUMENT, - ) +def test_model_type(text_embedding_model: TextEmbeddingModel) -> None: + assert text_embedding_model.model_type == ModelType.TEXT_EMBEDDING - def test_invoke_with_multimodel_documents(self, text_embedding_model): - model_name = "test_model" - credentials = {"api_key": "abc"} - multimodel_documents = [{"type": "text", "text": "hello"}] - expected_result = MagicMock(spec=EmbeddingResult) - with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: - mock_client = mock_client_class.return_value - mock_client.invoke_multimodal_embedding.return_value = expected_result +def test_invoke_with_texts(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None: + expected_result = MagicMock(spec=EmbeddingResult) + model_runtime.invoke_text_embedding.return_value = expected_result - result = text_embedding_model.invoke( - model=model_name, credentials=credentials, multimodel_documents=multimodel_documents - ) + result = text_embedding_model.invoke(model="test_model", credentials={"api_key": "abc"}, texts=["hello", "world"]) - assert result == expected_result - mock_client.invoke_multimodal_embedding.assert_called_once_with( - tenant_id="tenant_123", - user_id="unknown", - plugin_id="plugin_123", - provider="test_provider", - model=model_name, - credentials=credentials, - documents=multimodel_documents, - input_type=EmbeddingInputType.DOCUMENT, - ) + assert result == expected_result + model_runtime.invoke_text_embedding.assert_called_once_with( + provider="test_provider", + model="test_model", + credentials={"api_key": "abc"}, + texts=["hello", "world"], + input_type=EmbeddingInputType.DOCUMENT, + ) - def test_invoke_no_input(self, text_embedding_model): - model_name = "test_model" - credentials = {"api_key": "abc"} - with pytest.raises(ValueError) as excinfo: - text_embedding_model.invoke(model=model_name, credentials=credentials) +def test_invoke_with_multimodal_documents(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None: + expected_result = MagicMock(spec=EmbeddingResult) + model_runtime.invoke_multimodal_embedding.return_value = expected_result - assert "No texts or files provided" in str(excinfo.value) + result = text_embedding_model.invoke( + model="test_model", + credentials={"api_key": "abc"}, + multimodel_documents=[{"type": "text", "text": "hello"}], + ) - def test_invoke_precedence(self, text_embedding_model): - model_name = "test_model" - credentials = {"api_key": "abc"} - texts = ["hello"] - multimodel_documents = [{"type": "text", "text": "world"}] - expected_result = MagicMock(spec=EmbeddingResult) + assert result == expected_result + model_runtime.invoke_multimodal_embedding.assert_called_once_with( + provider="test_provider", + model="test_model", + credentials={"api_key": "abc"}, + documents=[{"type": "text", "text": "hello"}], + input_type=EmbeddingInputType.DOCUMENT, + ) - with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: - mock_client = mock_client_class.return_value - mock_client.invoke_text_embedding.return_value = expected_result - result = text_embedding_model.invoke( - model=model_name, credentials=credentials, texts=texts, multimodel_documents=multimodel_documents - ) +def test_invoke_no_input(text_embedding_model: TextEmbeddingModel) -> None: + with pytest.raises(ValueError, match="No texts or files provided"): + text_embedding_model.invoke(model="test_model", credentials={"api_key": "abc"}) - assert result == expected_result - mock_client.invoke_text_embedding.assert_called_once() - mock_client.invoke_multimodal_embedding.assert_not_called() - def test_invoke_exception(self, text_embedding_model): - model_name = "test_model" - credentials = {"api_key": "abc"} - texts = ["hello"] +def test_invoke_prefers_texts_over_multimodal_documents( + text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock +) -> None: + expected_result = MagicMock(spec=EmbeddingResult) + model_runtime.invoke_text_embedding.return_value = expected_result - with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: - mock_client = mock_client_class.return_value - mock_client.invoke_text_embedding.side_effect = Exception("Test error") + result = text_embedding_model.invoke( + model="test_model", + credentials={"api_key": "abc"}, + texts=["hello"], + multimodel_documents=[{"type": "text", "text": "world"}], + ) - with pytest.raises(InvokeError) as excinfo: - text_embedding_model.invoke(model=model_name, credentials=credentials, texts=texts) + assert result == expected_result + model_runtime.invoke_text_embedding.assert_called_once() + model_runtime.invoke_multimodal_embedding.assert_not_called() - assert "[test_provider] Error: Test error" in str(excinfo.value.description) - def test_get_num_tokens(self, text_embedding_model): - model_name = "test_model" - credentials = {"api_key": "abc"} - texts = ["hello", "world"] - expected_tokens = [1, 1] +def test_invoke_exception(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None: + model_runtime.invoke_text_embedding.side_effect = Exception("Test error") - with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: - mock_client = mock_client_class.return_value - mock_client.get_text_embedding_num_tokens.return_value = expected_tokens + with pytest.raises(InvokeError, match="Test error"): + text_embedding_model.invoke(model="test_model", credentials={"api_key": "abc"}, texts=["hello"]) - result = text_embedding_model.get_num_tokens(model=model_name, credentials=credentials, texts=texts) - assert result == expected_tokens - mock_client.get_text_embedding_num_tokens.assert_called_once_with( - tenant_id="tenant_123", - user_id="unknown", - plugin_id="plugin_123", - provider="test_provider", - model=model_name, - credentials=credentials, - texts=texts, - ) +def test_get_num_tokens(text_embedding_model: TextEmbeddingModel, model_runtime: MagicMock) -> None: + model_runtime.get_text_embedding_num_tokens.return_value = [1, 1] - def test_get_context_size(self, text_embedding_model): - model_name = "test_model" - credentials = {"api_key": "abc"} + result = text_embedding_model.get_num_tokens( + model="test_model", credentials={"api_key": "abc"}, texts=["hello", "world"] + ) - # Test case 1: Context size in schema - mock_schema = MagicMock() - mock_schema.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 2048} + assert result == [1, 1] + model_runtime.get_text_embedding_num_tokens.assert_called_once_with( + provider="test_provider", + model="test_model", + credentials={"api_key": "abc"}, + texts=["hello", "world"], + ) - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): - assert text_embedding_model._get_context_size(model_name, credentials) == 2048 - # Test case 2: No schema - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None): - assert text_embedding_model._get_context_size(model_name, credentials) == 1000 +def test_get_context_size(text_embedding_model: TextEmbeddingModel) -> None: + mock_schema = MagicMock() + mock_schema.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 2048} - # Test case 3: Context size NOT in schema properties - mock_schema.model_properties = {} - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): - assert text_embedding_model._get_context_size(model_name, credentials) == 1000 + with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): + assert text_embedding_model._get_context_size("test_model", {"api_key": "abc"}) == 2048 - def test_get_max_chunks(self, text_embedding_model): - model_name = "test_model" - credentials = {"api_key": "abc"} + with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None): + assert text_embedding_model._get_context_size("test_model", {"api_key": "abc"}) == 1000 - # Test case 1: Max chunks in schema - mock_schema = MagicMock() - mock_schema.model_properties = {ModelPropertyKey.MAX_CHUNKS: 10} + mock_schema.model_properties = {} + with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): + assert text_embedding_model._get_context_size("test_model", {"api_key": "abc"}) == 1000 - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): - assert text_embedding_model._get_max_chunks(model_name, credentials) == 10 - # Test case 2: No schema - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None): - assert text_embedding_model._get_max_chunks(model_name, credentials) == 1 +def test_get_max_chunks(text_embedding_model: TextEmbeddingModel) -> None: + mock_schema = MagicMock() + mock_schema.model_properties = {ModelPropertyKey.MAX_CHUNKS: 10} - # Test case 3: Max chunks NOT in schema properties - mock_schema.model_properties = {} - with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): - assert text_embedding_model._get_max_chunks(model_name, credentials) == 1 + with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): + assert text_embedding_model._get_max_chunks("test_model", {"api_key": "abc"}) == 10 + + with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None): + assert text_embedding_model._get_max_chunks("test_model", {"api_key": "abc"}) == 1 + + mock_schema.model_properties = {} + with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): + assert text_embedding_model._get_max_chunks("test_model", {"api_key": "abc"}) == 1 diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_tts_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_tts_model.py index b1aca9baa3..f5ae390867 100644 --- a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_tts_model.py +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_tts_model.py @@ -1,131 +1,83 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.common_entities import I18nObject from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from dify_graph.model_runtime.errors.invoke import InvokeError from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel -class TestTTSModel: - @pytest.fixture - def mock_plugin_model_provider(self): - return MagicMock(spec=PluginModelProviderEntity) +@pytest.fixture +def provider_schema() -> ProviderEntity: + return ProviderEntity( + provider="test_provider", + label=I18nObject(en_US="test_provider"), + supported_model_types=[ModelType.TTS], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) - @pytest.fixture - def tts_model(self, mock_plugin_model_provider): - return TTSModel( - tenant_id="tenant_123", - model_type=ModelType.TTS, - plugin_id="plugin_123", - provider_name="test_provider", - plugin_model_provider=mock_plugin_model_provider, + +@pytest.fixture +def model_runtime() -> MagicMock: + return MagicMock() + + +@pytest.fixture +def tts_model(provider_schema: ProviderEntity, model_runtime: MagicMock) -> TTSModel: + return TTSModel(provider_schema=provider_schema, model_runtime=model_runtime) + + +def test_model_type(tts_model: TTSModel) -> None: + assert tts_model.model_type == ModelType.TTS + + +def test_invoke_success(tts_model: TTSModel, model_runtime: MagicMock) -> None: + model_runtime.invoke_tts.return_value = [b"audio_chunk"] + + result = tts_model.invoke( + model="test_model", + credentials={"api_key": "abc"}, + content_text="Hello world", + voice="alloy", + ) + + assert list(result) == [b"audio_chunk"] + model_runtime.invoke_tts.assert_called_once_with( + provider="test_provider", + model="test_model", + credentials={"api_key": "abc"}, + content_text="Hello world", + voice="alloy", + ) + + +def test_invoke_exception(tts_model: TTSModel, model_runtime: MagicMock) -> None: + model_runtime.invoke_tts.side_effect = Exception("Test error") + + with pytest.raises(InvokeError, match="Test error"): + tts_model.invoke( + model="test_model", + credentials={"api_key": "abc"}, + content_text="Hello world", + voice="alloy", ) - def test_model_type(self, tts_model): - assert tts_model.model_type == ModelType.TTS - def test_invoke_success(self, tts_model): - model_name = "test_model" - tenant_id = "ignored_tenant_id" - credentials = {"api_key": "abc"} - content_text = "Hello world" - voice = "alloy" - user = "user_123" +def test_get_tts_model_voices(tts_model: TTSModel, model_runtime: MagicMock) -> None: + model_runtime.get_tts_model_voices.return_value = [{"name": "Voice1"}] - with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: - mock_client = mock_client_class.return_value - mock_client.invoke_tts.return_value = [b"audio_chunk"] + result = tts_model.get_tts_model_voices( + model="test_model", + credentials={"api_key": "abc"}, + language="en-US", + ) - result = tts_model.invoke( - model=model_name, - tenant_id=tenant_id, - credentials=credentials, - content_text=content_text, - voice=voice, - user=user, - ) - - assert list(result) == [b"audio_chunk"] - mock_client.invoke_tts.assert_called_once_with( - tenant_id="tenant_123", - user_id="user_123", - plugin_id="plugin_123", - provider="test_provider", - model=model_name, - credentials=credentials, - content_text=content_text, - voice=voice, - ) - - def test_invoke_success_no_user(self, tts_model): - model_name = "test_model" - tenant_id = "ignored_tenant_id" - credentials = {"api_key": "abc"} - content_text = "Hello world" - voice = "alloy" - - with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: - mock_client = mock_client_class.return_value - mock_client.invoke_tts.return_value = [b"audio_chunk"] - - result = tts_model.invoke( - model=model_name, tenant_id=tenant_id, credentials=credentials, content_text=content_text, voice=voice - ) - - assert list(result) == [b"audio_chunk"] - mock_client.invoke_tts.assert_called_once_with( - tenant_id="tenant_123", - user_id="unknown", - plugin_id="plugin_123", - provider="test_provider", - model=model_name, - credentials=credentials, - content_text=content_text, - voice=voice, - ) - - def test_invoke_exception(self, tts_model): - model_name = "test_model" - tenant_id = "ignored_tenant_id" - credentials = {"api_key": "abc"} - content_text = "Hello world" - voice = "alloy" - - with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: - mock_client = mock_client_class.return_value - mock_client.invoke_tts.side_effect = Exception("Test error") - - with pytest.raises(InvokeError) as excinfo: - tts_model.invoke( - model=model_name, - tenant_id=tenant_id, - credentials=credentials, - content_text=content_text, - voice=voice, - ) - - assert "[test_provider] Error: Test error" in str(excinfo.value.description) - - def test_get_tts_model_voices(self, tts_model): - model_name = "test_model" - credentials = {"api_key": "abc"} - language = "en-US" - - with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: - mock_client = mock_client_class.return_value - mock_client.get_tts_model_voices.return_value = [{"name": "Voice1"}] - - result = tts_model.get_tts_model_voices(model=model_name, credentials=credentials, language=language) - - assert result == [{"name": "Voice1"}] - mock_client.get_tts_model_voices.assert_called_once_with( - tenant_id="tenant_123", - user_id="unknown", - plugin_id="plugin_123", - provider="test_provider", - model=model_name, - credentials=credentials, - language=language, - ) + assert result == [{"name": "Voice1"}] + model_runtime.get_tts_model_voices.assert_called_once_with( + provider="test_provider", + model="test_model", + credentials={"api_key": "abc"}, + language="en-US", + ) diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/test_model_provider_factory.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/test_model_provider_factory.py deleted file mode 100644 index 1ad0210375..0000000000 --- a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/test_model_provider_factory.py +++ /dev/null @@ -1,522 +0,0 @@ -import logging -from datetime import datetime -from threading import Lock -from typing import Any -from unittest.mock import MagicMock, patch - -import pytest -from redis import RedisError - -import contexts -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ( - AIModelEntity, - FetchFrom, - ModelPropertyKey, - ModelType, -) -from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory - - -def _provider_entity( - *, - provider: str, - supported_model_types: list[ModelType] | None = None, - models: list[AIModelEntity] | None = None, - icon_small: I18nObject | None = None, - icon_small_dark: I18nObject | None = None, -) -> ProviderEntity: - return ProviderEntity( - provider=provider, - label=I18nObject(en_US=provider), - supported_model_types=supported_model_types or [ModelType.LLM], - configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], - models=models or [], - icon_small=icon_small, - icon_small_dark=icon_small_dark, - ) - - -def _plugin_provider( - *, plugin_id: str, declaration: ProviderEntity, provider: str = "provider" -) -> PluginModelProviderEntity: - return PluginModelProviderEntity.model_construct( - id=f"{plugin_id}-id", - created_at=datetime.now(), - updated_at=datetime.now(), - provider=provider, - tenant_id="tenant", - plugin_unique_identifier=f"{plugin_id}-uid", - plugin_id=plugin_id, - declaration=declaration, - ) - - -@pytest.fixture(autouse=True) -def _reset_plugin_model_provider_context() -> None: - contexts.plugin_model_providers_lock.set(Lock()) - contexts.plugin_model_providers.set(None) - - -@pytest.fixture -def fake_plugin_manager(monkeypatch: pytest.MonkeyPatch) -> MagicMock: - manager = MagicMock() - - import core.plugin.impl.model as plugin_model_module - - monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: manager) - return manager - - -@pytest.fixture -def factory(fake_plugin_manager: MagicMock) -> ModelProviderFactory: - return ModelProviderFactory(tenant_id="tenant") - - -def test_get_plugin_model_providers_initializes_context_on_lookup_error( - factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch -) -> None: - declaration = _provider_entity(provider="openai") - fake_plugin_manager.fetch_model_providers.return_value = [ - _plugin_provider(plugin_id="langgenius/openai", declaration=declaration) - ] - - original_get = contexts.plugin_model_providers.get - calls = {"n": 0} - - def flaky_get() -> Any: - calls["n"] += 1 - if calls["n"] == 1: - raise LookupError - return original_get() - - monkeypatch.setattr(contexts.plugin_model_providers, "get", flaky_get) - - providers = factory.get_plugin_model_providers() - assert len(providers) == 1 - assert providers[0].declaration.provider == "langgenius/openai/openai" - - -def test_get_plugin_model_providers_caches_and_does_not_refetch( - factory: ModelProviderFactory, fake_plugin_manager: MagicMock -) -> None: - declaration = _provider_entity(provider="openai") - fake_plugin_manager.fetch_model_providers.return_value = [ - _plugin_provider(plugin_id="langgenius/openai", declaration=declaration) - ] - - first = factory.get_plugin_model_providers() - second = factory.get_plugin_model_providers() - - assert first is second - fake_plugin_manager.fetch_model_providers.assert_called_once_with("tenant") - - -def test_get_providers_returns_declarations(factory: ModelProviderFactory, fake_plugin_manager: MagicMock) -> None: - d1 = _provider_entity(provider="openai") - d2 = _provider_entity(provider="anthropic") - fake_plugin_manager.fetch_model_providers.return_value = [ - _plugin_provider(plugin_id="langgenius/openai", declaration=d1), - _plugin_provider(plugin_id="langgenius/anthropic", declaration=d2), - ] - - providers = factory.get_providers() - assert [p.provider for p in providers] == ["langgenius/openai/openai", "langgenius/anthropic/anthropic"] - - -def test_get_plugin_model_provider_converts_short_provider_id( - factory: ModelProviderFactory, fake_plugin_manager: MagicMock -) -> None: - declaration = _provider_entity(provider="openai") - fake_plugin_manager.fetch_model_providers.return_value = [ - _plugin_provider(plugin_id="langgenius/openai", declaration=declaration) - ] - - provider = factory.get_plugin_model_provider("openai") - assert provider.declaration.provider == "langgenius/openai/openai" - - -def test_get_plugin_model_provider_raises_on_invalid_provider( - factory: ModelProviderFactory, fake_plugin_manager: MagicMock -) -> None: - declaration = _provider_entity(provider="openai") - fake_plugin_manager.fetch_model_providers.return_value = [ - _plugin_provider(plugin_id="langgenius/openai", declaration=declaration) - ] - - with pytest.raises(ValueError, match="Invalid provider"): - factory.get_plugin_model_provider("langgenius/unknown/unknown") - - -def test_get_provider_schema_returns_declaration(factory: ModelProviderFactory, fake_plugin_manager: MagicMock) -> None: - declaration = _provider_entity(provider="openai") - fake_plugin_manager.fetch_model_providers.return_value = [ - _plugin_provider(plugin_id="langgenius/openai", declaration=declaration) - ] - - schema = factory.get_provider_schema("openai") - assert schema.provider == "langgenius/openai/openai" - - -def test_provider_credentials_validate_errors_when_schema_missing( - factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch -) -> None: - schema = _provider_entity(provider="openai") - schema.provider_credential_schema = None - monkeypatch.setattr( - factory, - "get_plugin_model_provider", - lambda **_: _plugin_provider(plugin_id="langgenius/openai", declaration=schema), - ) - - with pytest.raises(ValueError, match="does not have provider_credential_schema"): - factory.provider_credentials_validate(provider="openai", credentials={"x": "y"}) - - -def test_provider_credentials_validate_filters_and_calls_plugin_validation( - factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch -) -> None: - schema = _provider_entity(provider="openai") - schema.provider_credential_schema = MagicMock() - plugin_provider = _plugin_provider(plugin_id="langgenius/openai", declaration=schema) - monkeypatch.setattr(factory, "get_plugin_model_provider", lambda **_: plugin_provider) - - fake_validator = MagicMock() - fake_validator.validate_and_filter.return_value = {"filtered": True} - monkeypatch.setattr( - "dify_graph.model_runtime.model_providers.model_provider_factory.ProviderCredentialSchemaValidator", - lambda _: fake_validator, - ) - - filtered = factory.provider_credentials_validate(provider="openai", credentials={"raw": True}) - assert filtered == {"filtered": True} - fake_plugin_manager.validate_provider_credentials.assert_called_once() - kwargs = fake_plugin_manager.validate_provider_credentials.call_args.kwargs - assert kwargs["plugin_id"] == "langgenius/openai" - assert kwargs["provider"] == "provider" - assert kwargs["credentials"] == {"filtered": True} - - -def test_model_credentials_validate_errors_when_schema_missing( - factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch -) -> None: - schema = _provider_entity(provider="openai") - schema.model_credential_schema = None - monkeypatch.setattr( - factory, - "get_plugin_model_provider", - lambda **_: _plugin_provider(plugin_id="langgenius/openai", declaration=schema), - ) - - with pytest.raises(ValueError, match="does not have model_credential_schema"): - factory.model_credentials_validate( - provider="openai", model_type=ModelType.LLM, model="m", credentials={"x": "y"} - ) - - -def test_model_credentials_validate_filters_and_calls_plugin_validation( - factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch -) -> None: - schema = _provider_entity(provider="openai") - schema.model_credential_schema = MagicMock() - plugin_provider = _plugin_provider(plugin_id="langgenius/openai", declaration=schema) - monkeypatch.setattr(factory, "get_plugin_model_provider", lambda **_: plugin_provider) - - fake_validator = MagicMock() - fake_validator.validate_and_filter.return_value = {"filtered": True} - monkeypatch.setattr( - "dify_graph.model_runtime.model_providers.model_provider_factory.ModelCredentialSchemaValidator", - lambda *_: fake_validator, - ) - - filtered = factory.model_credentials_validate( - provider="openai", model_type=ModelType.TEXT_EMBEDDING, model="m", credentials={"raw": True} - ) - assert filtered == {"filtered": True} - kwargs = fake_plugin_manager.validate_model_credentials.call_args.kwargs - assert kwargs["plugin_id"] == "langgenius/openai" - assert kwargs["provider"] == "provider" - assert kwargs["model_type"] == "text-embedding" - assert kwargs["model"] == "m" - assert kwargs["credentials"] == {"filtered": True} - - -def test_get_model_schema_cache_hit(factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch) -> None: - model_schema = AIModelEntity( - model="m", - label=I18nObject(en_US="m"), - model_type=ModelType.LLM, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, - parameter_rules=[], - ) - - monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov")) - - with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis: - mock_redis.get.return_value = model_schema.model_dump_json().encode() - assert ( - factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials={"k": "v"}) - == model_schema - ) - - -def test_get_model_schema_cache_invalid_json_deletes_key( - factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture -) -> None: - caplog.set_level(logging.WARNING) - - with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis: - mock_redis.get.return_value = b'{"model":"m"}' - factory.plugin_model_manager.get_model_schema.return_value = None - factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign] - assert factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) is None - assert mock_redis.delete.called - assert any("Failed to validate cached plugin model schema" in r.message for r in caplog.records) - - -def test_get_model_schema_cache_delete_redis_error_is_logged( - factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture -) -> None: - caplog.set_level(logging.WARNING) - - with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis: - mock_redis.get.return_value = b'{"model":"m"}' - mock_redis.delete.side_effect = RedisError("nope") - factory.plugin_model_manager.get_model_schema.return_value = None - factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign] - factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) - assert any("Failed to delete invalid plugin model schema cache" in r.message for r in caplog.records) - - -def test_get_model_schema_redis_get_error_falls_back_to_plugin( - factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture -) -> None: - caplog.set_level(logging.WARNING) - factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign] - factory.plugin_model_manager.get_model_schema.return_value = None - - with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis: - mock_redis.get.side_effect = RedisError("down") - assert factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) is None - assert any("Failed to read plugin model schema cache" in r.message for r in caplog.records) - - -def test_get_model_schema_cache_miss_sets_cache_and_handles_setex_error( - factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture -) -> None: - caplog.set_level(logging.WARNING) - factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign] - - model_schema = AIModelEntity( - model="m", - label=I18nObject(en_US="m"), - model_type=ModelType.LLM, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, - parameter_rules=[], - ) - factory.plugin_model_manager.get_model_schema.return_value = model_schema - - with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis: - mock_redis.get.return_value = None - mock_redis.setex.side_effect = RedisError("nope") - assert ( - factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) - == model_schema - ) - assert any("Failed to write plugin model schema cache" in r.message for r in caplog.records) - - -@pytest.mark.parametrize( - ("model_type", "expected_class"), - [ - (ModelType.LLM, "LargeLanguageModel"), - (ModelType.TEXT_EMBEDDING, "TextEmbeddingModel"), - (ModelType.RERANK, "RerankModel"), - (ModelType.SPEECH2TEXT, "Speech2TextModel"), - (ModelType.MODERATION, "ModerationModel"), - (ModelType.TTS, "TTSModel"), - ], -) -def test_get_model_type_instance_dispatches_by_type( - factory: ModelProviderFactory, model_type: ModelType, expected_class: str, monkeypatch: pytest.MonkeyPatch -) -> None: - monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov")) - monkeypatch.setattr(factory, "get_plugin_model_provider", lambda *_: MagicMock(spec=PluginModelProviderEntity)) - - sentinel = object() - monkeypatch.setattr( - f"dify_graph.model_runtime.model_providers.model_provider_factory.{expected_class}", - MagicMock(model_validate=lambda _: sentinel), - ) - - assert factory.get_model_type_instance("langgenius/openai/openai", model_type) is sentinel - - -def test_get_model_type_instance_raises_on_unsupported( - factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch -) -> None: - monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov")) - monkeypatch.setattr(factory, "get_plugin_model_provider", lambda *_: MagicMock(spec=PluginModelProviderEntity)) - - class UnknownModelType: - pass - - with pytest.raises(ValueError, match="Unsupported model type"): - factory.get_model_type_instance("langgenius/openai/openai", UnknownModelType()) # type: ignore[arg-type] - - -def test_get_models_filters_by_provider_and_model_type( - factory: ModelProviderFactory, fake_plugin_manager: MagicMock -) -> None: - llm = AIModelEntity( - model="m1", - label=I18nObject(en_US="m1"), - model_type=ModelType.LLM, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, - parameter_rules=[], - ) - embed = AIModelEntity( - model="e1", - label=I18nObject(en_US="e1"), - model_type=ModelType.TEXT_EMBEDDING, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, - parameter_rules=[], - ) - - openai = _provider_entity( - provider="openai", supported_model_types=[ModelType.LLM, ModelType.TEXT_EMBEDDING], models=[llm, embed] - ) - anthropic = _provider_entity(provider="anthropic", supported_model_types=[ModelType.LLM], models=[llm]) - fake_plugin_manager.fetch_model_providers.return_value = [ - _plugin_provider(plugin_id="langgenius/openai", declaration=openai), - _plugin_provider(plugin_id="langgenius/anthropic", declaration=anthropic), - ] - - # ModelType filter picks only matching models - providers = factory.get_models(model_type=ModelType.TEXT_EMBEDDING) - assert len(providers) == 1 - assert providers[0].provider == "langgenius/openai/openai" - assert [m.model for m in providers[0].models] == ["e1"] - - # Provider filter excludes others - providers = factory.get_models(provider="langgenius/anthropic/anthropic", model_type=ModelType.LLM) - assert len(providers) == 1 - assert providers[0].provider == "langgenius/anthropic/anthropic" - - -def test_get_models_provider_filter_skips_non_matching( - factory: ModelProviderFactory, fake_plugin_manager: MagicMock -) -> None: - openai = _provider_entity(provider="openai") - anthropic = _provider_entity(provider="anthropic") - fake_plugin_manager.fetch_model_providers.return_value = [ - _plugin_provider(plugin_id="langgenius/openai", declaration=openai), - _plugin_provider(plugin_id="langgenius/anthropic", declaration=anthropic), - ] - - providers = factory.get_models(provider="langgenius/not-exist/not-exist", model_type=ModelType.LLM) - assert providers == [] - - -def test_get_provider_icon_fetches_asset_and_returns_mime_type( - factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch -) -> None: - provider_schema = _provider_entity( - provider="langgenius/openai/openai", - icon_small=I18nObject(en_US="icon.png", zh_Hans="icon-zh.png"), - icon_small_dark=I18nObject(en_US="dark.svg", zh_Hans="dark-zh.svg"), - ) - monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema) - - class FakePluginAssetManager: - def fetch_asset(self, tenant_id: str, id: str) -> bytes: - assert tenant_id == "tenant" - return f"bytes:{id}".encode() - - import core.plugin.impl.asset as asset_module - - monkeypatch.setattr(asset_module, "PluginAssetManager", FakePluginAssetManager) - - data, mime = factory.get_provider_icon("openai", "icon_small", "en_US") - assert data == b"bytes:icon.png" - assert mime == "image/png" - - data, mime = factory.get_provider_icon("openai", "icon_small_dark", "zh_Hans") - assert data == b"bytes:dark-zh.svg" - assert mime == "image/svg+xml" - - -def test_get_provider_icon_uses_zh_hans_for_small_and_en_us_for_dark( - factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch -) -> None: - provider_schema = _provider_entity( - provider="langgenius/openai/openai", - icon_small=I18nObject(en_US="icon-en.png", zh_Hans="icon-zh.png"), - icon_small_dark=I18nObject(en_US="dark-en.svg", zh_Hans="dark-zh.svg"), - ) - monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema) - - class FakePluginAssetManager: - def fetch_asset(self, tenant_id: str, id: str) -> bytes: - return id.encode() - - import core.plugin.impl.asset as asset_module - - monkeypatch.setattr(asset_module, "PluginAssetManager", FakePluginAssetManager) - - data, _ = factory.get_provider_icon("openai", "icon_small", "zh_Hans") - assert data == b"icon-zh.png" - - data, _ = factory.get_provider_icon("openai", "icon_small_dark", "en_US") - assert data == b"dark-en.svg" - - -def test_get_provider_icon_raises_for_missing_icons( - factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch -) -> None: - provider_schema = _provider_entity(provider="langgenius/openai/openai") - monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema) - - with pytest.raises(ValueError, match="does not have small icon"): - factory.get_provider_icon("openai", "icon_small", "en_US") - - with pytest.raises(ValueError, match="does not have small dark icon"): - factory.get_provider_icon("openai", "icon_small_dark", "en_US") - - -def test_get_provider_icon_raises_for_unsupported_icon_type( - factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch -) -> None: - provider_schema = _provider_entity( - provider="langgenius/openai/openai", - icon_small=I18nObject(en_US="", zh_Hans=""), - ) - monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema) - with pytest.raises(ValueError, match="Unsupported icon type"): - factory.get_provider_icon("openai", "nope", "en_US") - - -def test_get_provider_icon_raises_when_file_name_missing( - factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch -) -> None: - provider_schema = _provider_entity( - provider="langgenius/openai/openai", - icon_small=I18nObject(en_US="", zh_Hans=""), - ) - monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema) - with pytest.raises(ValueError, match="does not have icon"): - factory.get_provider_icon("openai", "icon_small", "en_US") - - -def test_get_plugin_id_and_provider_name_from_provider_handles_google_special_case( - factory: ModelProviderFactory, -) -> None: - plugin_id, provider_name = factory.get_plugin_id_and_provider_name_from_provider("google") - assert plugin_id == "langgenius/gemini" - assert provider_name == "google" diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py index 7ec1343f98..6a5fc4e417 100644 --- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py +++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py @@ -336,7 +336,6 @@ def test_structured_output_parser(): json_schema=case["json_schema"], stream=case["stream"], model_parameters={"temperature": 0.7, "max_tokens": 100}, - user="test_user", ) if case["expected_result_type"] == "generator": @@ -367,7 +366,7 @@ def test_structured_output_parser(): call_args = model_instance.invoke_llm.call_args assert call_args.kwargs["stream"] == case["stream"] - assert call_args.kwargs["user"] == "test_user" + assert "user" not in call_args.kwargs assert "temperature" in call_args.kwargs["model_parameters"] assert "max_tokens" in call_args.kwargs["model_parameters"]