diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index e9bd30ba7e..8bb5aa2c1b 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -88,6 +88,7 @@ class ModelConfigResource(Resource): tenant_id=current_tenant_id, app_id=app_model.id, agent_tool=agent_tool_entity, + user_id=current_user.id, ) manager = ToolParameterConfigurationManager( tenant_id=current_tenant_id, @@ -127,6 +128,7 @@ class ModelConfigResource(Resource): tenant_id=current_tenant_id, app_id=app_model.id, agent_tool=agent_tool_entity, + user_id=current_user.id, ) except Exception: continue diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index e0aeff0082..9ff4e6afde 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -140,6 +140,7 @@ class BaseAgentRunner(AppRunner): tenant_id=self.tenant_id, app_id=self.app_config.app_id, agent_tool=tool, + user_id=self.user_id, invoke_from=self.application_generate_entity.invoke_from, ) assert tool_entity.entity.description diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 1890ab5a42..cc75effe1f 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -2,7 +2,7 @@ from collections.abc import Mapping from typing import Any from core.app.app_config.entities import ModelConfigEntity -from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from models.model import AppModelConfigDict from models.provider_ids import ModelProviderID @@ -53,9 +53,12 @@ class ModelConfigManager: if not isinstance(config["model"], dict): raise ValueError("model must be of object type") + # Keep provider discovery and provider-backed model listing on the same + # request-scoped runtime so caller scope and provider caches stay aligned. + assembly = create_plugin_model_assembly(tenant_id=tenant_id) + # model.provider - model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id) - provider_entities = model_provider_factory.get_providers() + provider_entities = assembly.model_provider_factory.get_providers() model_provider_names = [provider.provider for provider in provider_entities] if "provider" not in config["model"]: raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") @@ -70,8 +73,7 @@ class ModelConfigManager: if "name" not in config["model"]: raise ValueError("model.name is required") - provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) - models = provider_manager.get_configurations(tenant_id).get_models( + models = assembly.provider_manager.get_configurations(tenant_id).get_models( provider=config["model"]["provider"], model_type=ModelType.LLM ) diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 5418dd0490..297b4fa90e 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -105,8 +105,8 @@ class ProviderConfiguration(BaseModel): """Attach the already-composed runtime for request-bound call chains.""" self._bound_model_runtime = model_runtime - def _get_model_provider_factory(self) -> ModelProviderFactory: - """Reuse a bound runtime when available, else fall back to tenant scope.""" + def get_model_provider_factory(self) -> ModelProviderFactory: + """Return a provider factory that preserves any request-bound runtime.""" if self._bound_model_runtime is not None: return ModelProviderFactory(model_runtime=self._bound_model_runtime) return create_plugin_model_provider_factory(tenant_id=self.tenant_id) @@ -362,7 +362,7 @@ class ProviderConfiguration(BaseModel): tenant_id=self.tenant_id, token=original_credentials[key] ) - model_provider_factory = self._get_model_provider_factory() + model_provider_factory = self.get_model_provider_factory() validated_credentials = model_provider_factory.provider_credentials_validate( provider=self.provider.provider, credentials=credentials ) @@ -921,7 +921,7 @@ class ProviderConfiguration(BaseModel): tenant_id=self.tenant_id, token=original_credentials[key] ) - model_provider_factory = self._get_model_provider_factory() + model_provider_factory = self.get_model_provider_factory() validated_credentials = model_provider_factory.model_credentials_validate( provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) @@ -1407,7 +1407,7 @@ class ProviderConfiguration(BaseModel): :param model_type: model type :return: """ - model_provider_factory = self._get_model_provider_factory() + model_provider_factory = self.get_model_provider_factory() # Get model instance of LLM return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) @@ -1416,7 +1416,7 @@ class ProviderConfiguration(BaseModel): """ Get model schema """ - model_provider_factory = self._get_model_provider_factory() + model_provider_factory = self.get_model_provider_factory() return model_provider_factory.get_model_schema( provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) @@ -1518,7 +1518,7 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_provider_factory = self._get_model_provider_factory() + model_provider_factory = self.get_model_provider_factory() provider_schema = model_provider_factory.get_provider_schema(self.provider.provider) model_types: list[ModelType] = [] diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index e929097d2f..5f8c69ff58 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -277,18 +277,22 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): } @classmethod - def get_system_model_max_tokens(cls, tenant_id: str) -> int: + def get_system_model_max_tokens(cls, tenant_id: str, user_id: str | None = None) -> int: """ get system model max tokens """ - return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id) + return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id, user_id=user_id) @classmethod - def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int: + def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage], user_id: str | None = None) -> int: """ get prompt tokens """ - return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages) + return ModelInvocationUtils.calculate_tokens( + tenant_id=tenant_id, + prompt_messages=prompt_messages, + user_id=user_id, + ) @classmethod def invoke_system_model( @@ -306,6 +310,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): tool_type=ToolProviderType.PLUGIN, tool_name="plugin", prompt_messages=prompt_messages, + caller_user_id=user_id, ) @classmethod @@ -313,7 +318,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke summary """ - max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id) + max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id, user_id=user_id) content = payload.text SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language @@ -332,6 +337,7 @@ Here is the extra instruction you need to follow: cls.get_prompt_tokens( tenant_id=tenant.id, prompt_messages=[UserPromptMessage(content=content)], + user_id=user_id, ) < max_tokens * 0.6 ): @@ -344,6 +350,7 @@ Here is the extra instruction you need to follow: SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)), UserPromptMessage(content=content), ], + user_id=user_id, ) def summarize(content: str) -> str: @@ -401,6 +408,7 @@ Here is the extra instruction you need to follow: cls.get_prompt_tokens( tenant_id=tenant.id, prompt_messages=[UserPromptMessage(content=result)], + user_id=user_id, ) > max_tokens * 0.7 ): diff --git a/api/core/plugin/backwards_invocation/tool.py b/api/core/plugin/backwards_invocation/tool.py index c2d1574e67..0585494269 100644 --- a/api/core/plugin/backwards_invocation/tool.py +++ b/api/core/plugin/backwards_invocation/tool.py @@ -31,7 +31,13 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation): # get tool runtime try: tool_runtime = ToolManager.get_tool_runtime_from_plugin( - tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id + tool_type, + tenant_id, + provider, + tool_name, + tool_parameters, + user_id=user_id, + credential_id=credential_id, ) response = ToolEngine.generic_invoke( tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1 diff --git a/api/core/plugin/impl/model_runtime_factory.py b/api/core/plugin/impl/model_runtime_factory.py index a5d1c81d5d..efdf93576a 100644 --- a/api/core/plugin/impl/model_runtime_factory.py +++ b/api/core/plugin/impl/model_runtime_factory.py @@ -3,12 +3,64 @@ from __future__ import annotations from typing import TYPE_CHECKING from core.plugin.impl.model import PluginModelClient +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory if TYPE_CHECKING: from core.model_manager import ModelManager from core.plugin.impl.model_runtime import PluginModelRuntime from core.provider_manager import ProviderManager - from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory + + +class PluginModelAssembly: + """Compose request-scoped model views on top of a single plugin runtime.""" + + tenant_id: str + user_id: str | None + _model_runtime: PluginModelRuntime | None + _model_provider_factory: ModelProviderFactory | None + _provider_manager: ProviderManager | None + _model_manager: ModelManager | None + + def __init__(self, *, tenant_id: str, user_id: str | None = None) -> None: + self.tenant_id = tenant_id + self.user_id = user_id + self._model_runtime = None + self._model_provider_factory = None + self._provider_manager = None + self._model_manager = None + + @property + def model_runtime(self) -> PluginModelRuntime: + if self._model_runtime is None: + self._model_runtime = create_plugin_model_runtime(tenant_id=self.tenant_id, user_id=self.user_id) + return self._model_runtime + + @property + def model_provider_factory(self) -> ModelProviderFactory: + if self._model_provider_factory is None: + self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime) + return self._model_provider_factory + + @property + def provider_manager(self) -> ProviderManager: + if self._provider_manager is None: + from core.provider_manager import ProviderManager + + self._provider_manager = ProviderManager(model_runtime=self.model_runtime) + return self._provider_manager + + @property + def model_manager(self) -> ModelManager: + if self._model_manager is None: + from core.model_manager import ModelManager + + self._model_manager = ModelManager(provider_manager=self.provider_manager) + return self._model_manager + + +def create_plugin_model_assembly(*, tenant_id: str, user_id: str | None = None) -> PluginModelAssembly: + """Create a request-scoped assembly that shares one plugin runtime across model views.""" + return PluginModelAssembly(tenant_id=tenant_id, user_id=user_id) def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -> PluginModelRuntime: @@ -24,22 +76,14 @@ def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) - def create_plugin_model_provider_factory(*, tenant_id: str, user_id: str | None = None) -> ModelProviderFactory: """Create a tenant-bound model provider factory for service flows.""" - from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory - - return ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id=tenant_id, user_id=user_id)) + return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_provider_factory def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None) -> ProviderManager: """Create a tenant-bound provider manager for service flows.""" - from core.provider_manager import ProviderManager - - return ProviderManager(model_runtime=create_plugin_model_runtime(tenant_id=tenant_id, user_id=user_id)) + return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).provider_manager def create_plugin_model_manager(*, tenant_id: str, user_id: str | None = None) -> ModelManager: """Create a tenant-bound model manager for service flows.""" - from core.model_manager import ModelManager - - return ModelManager( - provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id), - ) + return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_manager diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py index 961d13f90a..5154bc9805 100644 --- a/api/core/tools/__base/tool_runtime.py +++ b/api/core/tools/__base/tool_runtime.py @@ -9,10 +9,14 @@ from core.tools.entities.tool_entities import ToolInvokeFrom class ToolRuntime(BaseModel): """ - Meta data of a tool call processing + Meta data of a tool call processing. + + ``user_id`` is optional so read-only tooling flows can stay tenant-scoped, + while execution paths may bind caller identity for model runtime lookups. """ tenant_id: str + user_id: str | None = None tool_id: str | None = None invoke_from: InvokeFrom | None = None tool_invoke_from: ToolInvokeFrom | None = None diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index bcf58394ba..64a2c697fe 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -53,6 +53,7 @@ class BuiltinTool(Tool): tool_type=ToolProviderType.BUILT_IN, tool_name=self.entity.identity.name, prompt_messages=prompt_messages, + caller_user_id=self.runtime.user_id, ) def tool_provider_type(self) -> ToolProviderType: @@ -69,6 +70,7 @@ class BuiltinTool(Tool): return ModelInvocationUtils.get_max_llm_context_tokens( tenant_id=self.runtime.tenant_id or "", + user_id=self.runtime.user_id, ) def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int: @@ -82,7 +84,9 @@ class BuiltinTool(Tool): raise ValueError("runtime is required") return ModelInvocationUtils.calculate_tokens( - tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages + tenant_id=self.runtime.tenant_id or "", + prompt_messages=prompt_messages, + user_id=self.runtime.user_id, ) def summary(self, user_id: str, content: str) -> str: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 0ed663ddbd..9e58610f77 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -184,6 +184,7 @@ class ToolManager: provider_id: str, tool_name: str, tenant_id: str, + user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, credential_id: str | None = None, @@ -195,6 +196,7 @@ class ToolManager: :param provider_id: the id of the provider :param tool_name: the name of the tool :param tenant_id: the tenant id + :param user_id: the caller id bound to runtime-scoped model/tool lookups :param invoke_from: invoke from :param tool_invoke_from: the tool invoke from :param credential_id: the credential id @@ -213,6 +215,7 @@ class ToolManager: return builtin_tool.fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -321,6 +324,7 @@ class ToolManager: return builtin_tool.fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials=dict(decrypted_credentials), credential_type=CredentialType.of(builtin_provider.credential_type), runtime_parameters={}, @@ -338,6 +342,7 @@ class ToolManager: return api_provider.get_tool(tool_name).fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials=dict(encrypter.decrypt(credentials)), invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -361,6 +366,7 @@ class ToolManager: return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -369,9 +375,21 @@ class ToolManager: elif provider_type == ToolProviderType.APP: raise NotImplementedError("app provider not implemented") elif provider_type == ToolProviderType.PLUGIN: - return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) + plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) + runtime = getattr(plugin_tool, "runtime", None) + if runtime is not None: + runtime.user_id = user_id + runtime.invoke_from = invoke_from + runtime.tool_invoke_from = tool_invoke_from + return plugin_tool elif provider_type == ToolProviderType.MCP: - return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name) + mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name) + runtime = getattr(mcp_tool, "runtime", None) + if runtime is not None: + runtime.user_id = user_id + runtime.invoke_from = invoke_from + runtime.tool_invoke_from = tool_invoke_from + return mcp_tool else: raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") @@ -381,6 +399,7 @@ class ToolManager: tenant_id: str, app_id: str, agent_tool: AgentToolEntity, + user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, variable_pool: Optional["VariablePool"] = None, ) -> Tool: @@ -392,6 +411,7 @@ class ToolManager: provider_id=agent_tool.provider_id, tool_name=agent_tool.tool_name, tenant_id=tenant_id, + user_id=user_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.AGENT, credential_id=agent_tool.credential_id, @@ -423,6 +443,7 @@ class ToolManager: app_id: str, node_id: str, workflow_tool: WorkflowToolRuntimeSpec, + user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, variable_pool: Optional["VariablePool"] = None, ) -> Tool: @@ -435,6 +456,7 @@ class ToolManager: provider_id=workflow_tool.provider_id, tool_name=workflow_tool.tool_name, tenant_id=tenant_id, + user_id=user_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.WORKFLOW, credential_id=workflow_tool.credential_id, @@ -467,6 +489,7 @@ class ToolManager: provider: str, tool_name: str, tool_parameters: dict[str, Any], + user_id: str | None = None, credential_id: str | None = None, ) -> Tool: """ @@ -477,6 +500,7 @@ class ToolManager: provider_id=provider, tool_name=tool_name, tenant_id=tenant_id, + user_id=user_id, invoke_from=InvokeFrom.SERVICE_API, tool_invoke_from=ToolInvokeFrom.PLUGIN, credential_id=credential_id, diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 9a6c0c50de..5d49bf9f23 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -34,11 +34,12 @@ class ModelInvocationUtils: @staticmethod def get_max_llm_context_tokens( tenant_id: str, + user_id: str | None = None, ) -> int: """ get max llm context tokens of the model """ - model_manager = ModelManager.for_tenant(tenant_id=tenant_id) + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -60,13 +61,13 @@ class ModelInvocationUtils: return max_tokens @staticmethod - def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int: + def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage], user_id: str | None = None) -> int: """ calculate tokens from prompt messages and model parameters """ # get model instance - model_manager = ModelManager.for_tenant(tenant_id=tenant_id) + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM) if not model_instance: @@ -79,7 +80,12 @@ class ModelInvocationUtils: @staticmethod def invoke( - user_id: str, tenant_id: str, tool_type: ToolProviderType, tool_name: str, prompt_messages: list[PromptMessage] + user_id: str, + tenant_id: str, + tool_type: ToolProviderType, + tool_name: str, + prompt_messages: list[PromptMessage], + caller_user_id: str | None = None, ) -> LLMResult: """ invoke model with parameters in user's own context @@ -93,7 +99,7 @@ class ModelInvocationUtils: """ # get model manager - model_manager = ModelManager.for_tenant(tenant_id=tenant_id) + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=caller_user_id or user_id) # get model instance model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py index 2e0a3c8928..10c8963c90 100644 --- a/api/core/workflow/node_runtime.py +++ b/api/core/workflow/node_runtime.py @@ -337,6 +337,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): self._run_context.app_id, node_id, self._build_tool_runtime_spec(node_data), + self._run_context.user_id, self._run_context.invoke_from, variable_pool, ) diff --git a/api/core/workflow/nodes/agent/runtime_support.py b/api/core/workflow/nodes/agent/runtime_support.py index 06eeb32da1..48fcd6a749 100644 --- a/api/core/workflow/nodes/agent/runtime_support.py +++ b/api/core/workflow/nodes/agent/runtime_support.py @@ -12,9 +12,9 @@ from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity from core.agent.plugin_entities import AgentStrategyParameter from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_manager import ModelInstance, ModelManager +from core.model_manager import ModelInstance from core.plugin.entities.request import InvokeCredentials -from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType from core.tools.tool_manager import ToolManager from core.workflow.system_variables import SystemVariableKey, get_system_text @@ -141,6 +141,7 @@ class AgentRuntimeSupport: tenant_id, app_id, entity, + user_id, invoke_from, runtime_variable_pool, ) @@ -242,8 +243,8 @@ class AgentRuntimeSupport: user_id: str, value: dict[str, Any], ) -> tuple[ModelInstance, AIModelEntity | None]: - provider_manager = create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id) - provider_model_bundle = provider_manager.get_provider_model_bundle( + assembly = create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id) + provider_model_bundle = assembly.provider_manager.get_provider_model_bundle( tenant_id=tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM, @@ -255,7 +256,7 @@ class AgentRuntimeSupport: ) provider_name = provider_model_bundle.configuration.provider.provider model_type_instance = provider_model_bundle.model_type_instance - model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance( + model_instance = assembly.model_manager.get_model_instance( tenant_id=tenant_id, provider=provider_name, model_type=ModelType(value.get("model_type", "")), diff --git a/api/libs/login.py b/api/libs/login.py index bd5cb5f30d..42c186e048 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections.abc import Callable from functools import wraps -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar from flask import current_app, g, has_request_context, request from flask_login.config import EXEMPT_METHODS @@ -34,8 +34,6 @@ def current_account_with_tenant(): return user, user.current_tenant_id -from typing import ParamSpec, TypeVar - P = ParamSpec("P") R = TypeVar("R") diff --git a/api/services/app_service.py b/api/services/app_service.py index 203af702ae..aba0256d12 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -205,6 +205,7 @@ class AppService: tenant_id=current_user.current_tenant_id, app_id=app.id, agent_tool=agent_tool_entity, + user_id=current_user.id, ) manager = ToolParameterConfigurationManager( tenant_id=current_user.current_tenant_id, diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 2e927bc693..926e04d503 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -10,13 +10,14 @@ from core.entities.provider_configuration import ProviderConfiguration from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.model_manager import LBModelManager -from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager from core.provider_manager import ProviderManager from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.entities.provider_entities import ( ModelCredentialSchema, ProviderCredentialSchema, ) +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.enums import CredentialSourceType @@ -223,8 +224,8 @@ class ModelLoadBalancingService: :param config_id: load balancing config id :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) + provider_configurations = provider_manager.get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -496,8 +497,8 @@ class ModelLoadBalancingService: :param config_id: load balancing config id :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) + assembly = create_plugin_model_assembly(tenant_id=tenant_id) + provider_configurations = assembly.provider_manager.get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -533,6 +534,7 @@ class ModelLoadBalancingService: model=model, credentials=credentials, load_balancing_model_config=load_balancing_model_config, + model_provider_factory=assembly.model_provider_factory, ) def _custom_credentials_validate( @@ -543,6 +545,7 @@ class ModelLoadBalancingService: model: str, credentials: dict, load_balancing_model_config: LoadBalancingModelConfig | None = None, + model_provider_factory: ModelProviderFactory | None = None, validate: bool = True, ): """ @@ -553,6 +556,7 @@ class ModelLoadBalancingService: :param model: model name :param credentials: credentials :param load_balancing_model_config: load balancing model config + :param model_provider_factory: model provider factory sharing the active runtime :param validate: validate credentials :return: """ @@ -582,7 +586,8 @@ class ModelLoadBalancingService: credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key]) if validate: - model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id) + if model_provider_factory is None: + model_provider_factory = provider_configuration.get_model_provider_factory() if isinstance(credential_schemas, ModelCredentialSchema): credentials = model_provider_factory.model_credentials_validate( provider=provider_configuration.provider.provider, diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 051578a570..1872299b62 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -13,7 +13,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.app.file_access import DatabaseFileAccessController -from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager from core.repositories import DifyCoreRepositoryFactory from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl from core.trigger.constants import is_trigger_node_type @@ -491,12 +491,15 @@ class WorkflowService: :raises ValueError: If the model configuration is invalid or credentials fail policy checks """ try: - from core.model_manager import ModelManager from dify_graph.model_runtime.entities.model_entities import ModelType + # Model instance resolution and provider status lookup must reuse the + # same request-scoped runtime so validation does not silently split + # provider discovery and credential reads across different caches. + assembly = create_plugin_model_assembly(tenant_id=tenant_id) + # Get model instance to validate provider+model combination - model_manager = ModelManager.for_tenant(tenant_id=tenant_id) - model_manager.get_model_instance( + assembly.model_manager.get_model_instance( tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name ) @@ -505,8 +508,7 @@ class WorkflowService: # If it fails, an exception will be raised # Additionally, check the model status to ensure it's ACTIVE - provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) - provider_configurations = provider_manager.get_configurations(tenant_id) + provider_configurations = assembly.provider_manager.get_configurations(tenant_id) models = provider_configurations.get_models(provider=provider, model_type=ModelType.LLM) target_model = None diff --git a/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py new file mode 100644 index 0000000000..033d22aa47 --- /dev/null +++ b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py @@ -0,0 +1,57 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager +from core.app.app_config.entities import ModelConfigEntity +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey +from models.provider_ids import ModelProviderID + + +def test_validate_and_set_defaults_reuses_single_model_assembly(): + provider_name = str(ModelProviderID("openai")) + provider_entity = SimpleNamespace(provider=provider_name) + model = SimpleNamespace(model="gpt-4o-mini", model_properties={ModelPropertyKey.MODE: "chat"}) + provider_configurations = SimpleNamespace(get_models=lambda **kwargs: [model]) + assembly = SimpleNamespace( + model_provider_factory=SimpleNamespace(get_providers=lambda: [provider_entity]), + provider_manager=SimpleNamespace(get_configurations=lambda tenant_id: provider_configurations), + ) + config = { + "model": { + "provider": "openai", + "name": "gpt-4o-mini", + "completion_params": {"stop": []}, + } + } + + with patch( + "core.app.app_config.easy_ui_based_app.model_config.manager.create_plugin_model_assembly", + return_value=assembly, + ) as mock_assembly: + result, keys = ModelConfigManager.validate_and_set_defaults("tenant-1", config) + + assert result["model"]["provider"] == provider_name + assert result["model"]["mode"] == "chat" + assert keys == ["model"] + mock_assembly.assert_called_once_with(tenant_id="tenant-1") + + +def test_convert_keeps_model_config_shape(): + config = { + "model": { + "provider": "openai", + "name": "gpt-4o-mini", + "mode": "chat", + "completion_params": {"temperature": 0.3, "stop": ["END"]}, + } + } + + result = ModelConfigManager.convert(config) + + assert result == ModelConfigEntity( + provider="openai", + model="gpt-4o-mini", + mode="chat", + parameters={"temperature": 0.3}, + stop=["END"], + ) diff --git a/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py new file mode 100644 index 0000000000..7491e79f30 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py @@ -0,0 +1,36 @@ +from unittest.mock import Mock, patch + +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly + + +def test_plugin_model_assembly_reuses_single_runtime_across_views(): + runtime = Mock(name="runtime") + provider_factory = Mock(name="provider_factory") + provider_manager = Mock(name="provider_manager") + model_manager = Mock(name="model_manager") + + with ( + patch( + "core.plugin.impl.model_runtime_factory.create_plugin_model_runtime", + return_value=runtime, + ) as mock_runtime_factory, + patch( + "core.plugin.impl.model_runtime_factory.ModelProviderFactory", + return_value=provider_factory, + ) as mock_provider_factory_cls, + patch("core.provider_manager.ProviderManager", return_value=provider_manager) as mock_provider_manager_cls, + patch("core.model_manager.ModelManager", return_value=model_manager) as mock_model_manager_cls, + ): + assembly = create_plugin_model_assembly(tenant_id="tenant-1", user_id="user-1") + + assert assembly.model_provider_factory is provider_factory + assert assembly.provider_manager is provider_manager + assert assembly.model_manager is model_manager + assert assembly.model_provider_factory is provider_factory + assert assembly.provider_manager is provider_manager + assert assembly.model_manager is model_manager + + mock_runtime_factory.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_provider_factory_cls.assert_called_once_with(model_runtime=runtime) + mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime) + mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager) diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py new file mode 100644 index 0000000000..59a9c229d0 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py @@ -0,0 +1,61 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation +from core.plugin.entities.request import RequestInvokeSummary +from dify_graph.model_runtime.entities.message_entities import UserPromptMessage + + +def test_system_model_helpers_forward_user_id(): + with ( + patch( + "core.plugin.backwards_invocation.model.ModelInvocationUtils.get_max_llm_context_tokens", + return_value=4096, + ) as mock_max_tokens, + patch( + "core.plugin.backwards_invocation.model.ModelInvocationUtils.calculate_tokens", + return_value=7, + ) as mock_prompt_tokens, + ): + assert PluginModelBackwardsInvocation.get_system_model_max_tokens("tenant-1", user_id="user-1") == 4096 + assert ( + PluginModelBackwardsInvocation.get_prompt_tokens( + "tenant-1", + [UserPromptMessage(content="hello")], + user_id="user-1", + ) + == 7 + ) + + mock_max_tokens.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_prompt_tokens.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="hello")], + user_id="user-1", + ) + + +def test_invoke_summary_uses_same_user_scope_for_token_helpers(): + tenant = SimpleNamespace(id="tenant-1") + payload = RequestInvokeSummary(text="short", instruction="keep it concise") + + with ( + patch.object( + PluginModelBackwardsInvocation, + "get_system_model_max_tokens", + return_value=100, + ) as mock_max_tokens, + patch.object( + PluginModelBackwardsInvocation, + "get_prompt_tokens", + return_value=10, + ) as mock_prompt_tokens, + ): + assert PluginModelBackwardsInvocation.invoke_summary("user-1", tenant, payload) == "short" + + mock_max_tokens.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_prompt_tokens.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="short")], + user_id="user-1", + ) diff --git a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py index f123f60a34..ccf6ddccaf 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py @@ -27,12 +27,12 @@ class _BuiltinDummyTool(BuiltinTool): yield self.create_text_message("ok") -def _build_tool() -> _BuiltinDummyTool: +def _build_tool(user_id: str | None = None) -> _BuiltinDummyTool: entity = ToolEntity( identity=ToolIdentity(author="author", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"), parameters=[], ) - runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER) + runtime = ToolRuntime(tenant_id="tenant-1", user_id=user_id, invoke_from=InvokeFrom.DEBUGGER) return _BuiltinDummyTool(provider="provider-a", entity=entity, runtime=runtime) @@ -45,7 +45,7 @@ def test_builtin_tool_fork_and_provider_type(): def test_invoke_model_calls_model_invocation_utils_invoke(): - tool = _build_tool() + tool = _build_tool(user_id="runtime-user") with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.invoke", return_value="result") as mock_invoke: assert ( tool.invoke_model( @@ -55,19 +55,47 @@ def test_invoke_model_calls_model_invocation_utils_invoke(): ) == "result" ) - mock_invoke.assert_called_once() + mock_invoke.assert_called_once_with( + user_id="u1", + tenant_id="tenant-1", + tool_type=ToolProviderType.BUILT_IN, + tool_name="tool-a", + prompt_messages=[UserPromptMessage(content="hello")], + caller_user_id="runtime-user", + ) def test_get_max_tokens_returns_value(): - tool = _build_tool() - with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096): + tool = _build_tool(user_id="runtime-user") + with patch( + "core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096 + ) as mock_get: assert tool.get_max_tokens() == 4096 + mock_get.assert_called_once_with(tenant_id="tenant-1", user_id="runtime-user") def test_get_prompt_tokens_returns_value(): - tool = _build_tool() - with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7): + tool = _build_tool(user_id="runtime-user") + with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7) as mock_calculate: assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7 + mock_calculate.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="hello")], + user_id="runtime-user", + ) + + +def test_get_prompt_tokens_falls_back_to_tenant_scope_when_runtime_user_id_missing(): + tool = _build_tool() + + with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7) as mock_calculate: + assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7 + + mock_calculate.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="hello")], + user_id=None, + ) def test_runtime_none_raises(): diff --git a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py index 3012e64ebb..a3b03fefd6 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py @@ -1,6 +1,8 @@ from __future__ import annotations +import calendar import math +from datetime import date from types import SimpleNamespace import pytest @@ -98,7 +100,13 @@ def test_timezone_conversion_tool(): def test_weekday_tool(): weekday_tool = _build_builtin_tool(WeekdayTool) valid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 1, "day": 1}))[0].message.text - assert "January 1, 2024" in valid + expected_date = date(2024, 1, 1) + expected_message = ( + f"{calendar.month_name[expected_date.month]} " + f"{expected_date.day}, {expected_date.year} " + f"is {calendar.day_name[expected_date.weekday()]}." + ) + assert valid == expected_message invalid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 2, "day": 31}))[ 0 ].message.text diff --git a/api/tests/unit_tests/core/tools/test_tool_manager.py b/api/tests/unit_tests/core/tools/test_tool_manager.py index 0f73e22654..844bc01e29 100644 --- a/api/tests/unit_tests/core/tools/test_tool_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_manager.py @@ -15,6 +15,7 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( ApiProviderAuthType, + ToolInvokeFrom, ToolParameter, ToolProviderType, ) @@ -421,7 +422,7 @@ def test_get_agent_runtime_apply_runtime_parameters(): tool_runtime = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) tool_runtime.get_merged_runtime_parameters = Mock(return_value=[parameter]) - with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime): + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime) as mock_get_tool_runtime: with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "hello"}): manager = Mock() manager.decrypt_tool_parameters.return_value = {"query": "decrypted"} @@ -437,12 +438,23 @@ def test_get_agent_runtime_apply_runtime_parameters(): tenant_id="tenant-1", app_id="app-1", agent_tool=agent_tool, + user_id="user-1", invoke_from=InvokeFrom.DEBUGGER, variable_pool=None, ) assert result is tool_runtime assert tool_runtime.runtime.runtime_parameters["query"] == "decrypted" + mock_get_tool_runtime.assert_called_once_with( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + user_id="user-1", + invoke_from=InvokeFrom.DEBUGGER, + tool_invoke_from=ToolInvokeFrom.AGENT, + credential_id=None, + ) def test_get_workflow_runtime_apply_runtime_parameters(): @@ -463,7 +475,7 @@ def test_get_workflow_runtime_apply_runtime_parameters(): ) tool_runtime2 = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) tool_runtime2.get_merged_runtime_parameters = Mock(return_value=[parameter]) - with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2): + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2) as mock_get_tool_runtime: with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "workflow"}): manager = Mock() manager.decrypt_tool_parameters.return_value = {"query": "workflow-dec"} @@ -473,12 +485,23 @@ def test_get_workflow_runtime_apply_runtime_parameters(): app_id="app-1", node_id="node-1", workflow_tool=workflow_tool, + user_id="user-1", invoke_from=InvokeFrom.DEBUGGER, variable_pool=None, ) assert workflow_result is tool_runtime2 assert tool_runtime2.runtime.runtime_parameters["query"] == "workflow-dec" + mock_get_tool_runtime.assert_called_once_with( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + user_id="user-1", + invoke_from=InvokeFrom.DEBUGGER, + tool_invoke_from=ToolInvokeFrom.WORKFLOW, + credential_id=None, + ) def test_get_agent_runtime_raises_when_runtime_missing(): @@ -520,17 +543,28 @@ def test_get_tool_runtime_from_plugin_only_uses_form_parameters(): tool_entity = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) tool_entity.get_merged_runtime_parameters = Mock(return_value=[form_param, llm_param]) - with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity): + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity) as mock_get_tool_runtime: result = ToolManager.get_tool_runtime_from_plugin( tool_type=ToolProviderType.API, tenant_id="tenant-1", provider="api-1", tool_name="search", tool_parameters={"q": "hello", "llm": "ignore"}, + user_id="user-1", ) assert result is tool_entity assert tool_entity.runtime.runtime_parameters == {"q": "hello"} + mock_get_tool_runtime.assert_called_once_with( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + tool_invoke_from=ToolInvokeFrom.PLUGIN, + credential_id=None, + ) def test_hardcoded_provider_icon_success(): diff --git a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py index f36993a3b9..8aba05ab4c 100644 --- a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py +++ b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py @@ -60,20 +60,23 @@ def test_get_max_llm_context_tokens_branches(model_instance, expected, error_mat manager = Mock() manager.get_default_model_instance.return_value = model_instance - with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: if error_match: with pytest.raises(InvokeModelError, match=error_match): - ModelInvocationUtils.get_max_llm_context_tokens("tenant") + ModelInvocationUtils.get_max_llm_context_tokens("tenant", user_id="user-1") else: - assert ModelInvocationUtils.get_max_llm_context_tokens("tenant") == expected + assert ModelInvocationUtils.get_max_llm_context_tokens("tenant", user_id="user-1") == expected + + mock_factory.assert_called_once_with(tenant_id="tenant", user_id="user-1") def test_calculate_tokens_handles_missing_model(): manager = Mock() manager.get_default_model_instance.return_value = None - with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: with pytest.raises(InvokeModelError, match="Model not found"): ModelInvocationUtils.calculate_tokens("tenant", []) + mock_factory.assert_called_once_with(tenant_id="tenant", user_id=None) def test_invoke_success_and_error_mappings(): @@ -98,7 +101,7 @@ def test_invoke_success_and_error_mappings(): db_mock = SimpleNamespace(session=Mock()) - with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke): with patch("core.tools.utils.model_invocation_utils.db", db_mock): response = ModelInvocationUtils.invoke( @@ -107,11 +110,13 @@ def test_invoke_success_and_error_mappings(): tool_type="builtin", tool_name="tool-a", prompt_messages=[], + caller_user_id="caller-1", ) assert response.message.content == "ok" assert db_mock.session.add.call_count == 1 assert db_mock.session.commit.call_count == 2 + mock_factory.assert_called_once_with(tenant_id="tenant", user_id="caller-1") @pytest.mark.parametrize( @@ -145,7 +150,7 @@ def test_invoke_error_mappings(exc, expected): db_mock = SimpleNamespace(session=Mock()) - with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke): with patch("core.tools.utils.model_invocation_utils.db", db_mock): with pytest.raises(InvokeModelError, match=expected): @@ -156,3 +161,4 @@ def test_invoke_error_mappings(exc, expected): tool_name="tool-a", prompt_messages=[], ) + mock_factory.assert_called_once_with(tenant_id="tenant", user_id="u1") diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py new file mode 100644 index 0000000000..fea5e24cf6 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py @@ -0,0 +1,49 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport +from dify_graph.model_runtime.entities.model_entities import ModelType + + +def test_fetch_model_reuses_single_model_assembly(): + provider_configuration = SimpleNamespace( + get_current_credentials=Mock(return_value={"api_key": "x"}), + provider=SimpleNamespace(provider="openai"), + ) + model_type_instance = SimpleNamespace(get_model_schema=Mock(return_value="schema")) + provider_model_bundle = SimpleNamespace( + configuration=provider_configuration, + model_type_instance=model_type_instance, + ) + model_instance = Mock() + assembly = SimpleNamespace( + provider_manager=Mock(), + model_manager=Mock(), + ) + assembly.provider_manager.get_provider_model_bundle.return_value = provider_model_bundle + assembly.model_manager.get_model_instance.return_value = model_instance + + with patch( + "core.workflow.nodes.agent.runtime_support.create_plugin_model_assembly", + return_value=assembly, + ) as mock_assembly: + resolved_instance, resolved_schema = AgentRuntimeSupport().fetch_model( + tenant_id="tenant-1", + user_id="user-1", + value={"provider": "openai", "model": "gpt-4o-mini", "model_type": "llm"}, + ) + + assert resolved_instance is model_instance + assert resolved_schema == "schema" + mock_assembly.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + assembly.provider_manager.get_provider_model_bundle.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=ModelType.LLM, + ) + assembly.model_manager.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + ) diff --git a/api/tests/unit_tests/libs/test_login.py b/api/tests/unit_tests/libs/test_login.py index a94ba0c00b..22e58aa8fc 100644 --- a/api/tests/unit_tests/libs/test_login.py +++ b/api/tests/unit_tests/libs/test_login.py @@ -1,10 +1,14 @@ -from unittest.mock import MagicMock, patch +from types import SimpleNamespace +from unittest.mock import MagicMock import pytest from flask import Flask, g from flask_login import LoginManager, UserMixin +from pytest_mock import MockerFixture -from libs.login import _get_user, current_user, login_required +import libs.login as login_module +from libs.login import current_user +from models.account import Account class MockUser(UserMixin): @@ -15,229 +19,214 @@ class MockUser(UserMixin): self._is_authenticated = is_authenticated @property - def is_authenticated(self): + def is_authenticated(self) -> bool: return self._is_authenticated -def mock_csrf_check(*args, **kwargs): - return +@pytest.fixture +def login_app() -> Flask: + app = Flask(__name__) + app.config["TESTING"] = True + + login_manager = LoginManager() + login_manager.init_app(app) + login_manager.unauthorized = MagicMock(return_value="Unauthorized") + + @login_manager.user_loader + def load_user(_user_id: str): + return None + + return app + + +@pytest.fixture +def csrf_check(mocker: MockerFixture) -> MagicMock: + return mocker.patch.object(login_module, "check_csrf_token") + + +@pytest.fixture +def ensure_sync_spy(login_app: Flask, mocker: MockerFixture) -> MagicMock: + def _ensure_sync(func): + return lambda *args, **kwargs: func(*args, **kwargs) + + return mocker.patch.object(login_app, "ensure_sync", side_effect=_ensure_sync) class TestLoginRequired: """Test cases for login_required decorator.""" - @pytest.fixture - @patch("libs.login.check_csrf_token", mock_csrf_check) - def setup_app(self, app: Flask): - """Set up Flask app with login manager.""" - # Initialize login manager - login_manager = LoginManager() - login_manager.init_app(app) - - # Mock unauthorized handler - login_manager.unauthorized = MagicMock(return_value="Unauthorized") - - # Add a dummy user loader to prevent exceptions - @login_manager.user_loader - def load_user(user_id): - return None - - return app - - @patch("libs.login.check_csrf_token", mock_csrf_check) - def test_authenticated_user_can_access_protected_view(self, setup_app: Flask): + def test_authenticated_user_can_access_protected_view( + self, login_app: Flask, csrf_check: MagicMock, ensure_sync_spy: MagicMock, mocker: MockerFixture + ): """Test that authenticated users can access protected views.""" - @login_required + @login_module.login_required def protected_view(): return "Protected content" - with setup_app.test_request_context(): - # Mock authenticated user - mock_user = MockUser("test_user", is_authenticated=True) - with patch("libs.login._get_user", return_value=mock_user, autospec=True): - result = protected_view() - assert result == "Protected content" + mock_user = MockUser("test_user", is_authenticated=True) + get_user = mocker.patch.object(login_module, "_get_user", return_value=mock_user) - @patch("libs.login.check_csrf_token", mock_csrf_check) - def test_unauthenticated_user_cannot_access_protected_view(self, setup_app: Flask): - """Test that unauthenticated users are redirected.""" + with login_app.test_request_context(): + result = protected_view() + csrf_check.assert_called_once() + assert csrf_check.call_args.args[0].method == "GET" + assert csrf_check.call_args.args[1] == "test_user" - @login_required + assert result == "Protected content" + get_user.assert_called_once_with() + ensure_sync_spy.assert_called_once_with(protected_view.__wrapped__) + login_app.login_manager.unauthorized.assert_not_called() + + @pytest.mark.parametrize( + ("resolved_user", "description"), + [ + pytest.param(None, "missing user", id="missing-user"), + pytest.param(MockUser("test_user", is_authenticated=False), "unauthenticated user", id="unauthenticated"), + ], + ) + def test_unauthorized_access_returns_login_manager_response( + self, + login_app: Flask, + csrf_check: MagicMock, + ensure_sync_spy: MagicMock, + mocker: MockerFixture, + resolved_user: MockUser | None, + description: str, + ): + """Test that missing or unauthenticated users are redirected.""" + + @login_module.login_required def protected_view(): return "Protected content" - with setup_app.test_request_context(): - # Mock unauthenticated user - mock_user = MockUser("test_user", is_authenticated=False) - with patch("libs.login._get_user", return_value=mock_user, autospec=True): - result = protected_view() - assert result == "Unauthorized" - setup_app.login_manager.unauthorized.assert_called_once() + get_user = mocker.patch.object(login_module, "_get_user", return_value=resolved_user) - @patch("libs.login.check_csrf_token", mock_csrf_check) - def test_login_disabled_allows_unauthenticated_access(self, setup_app: Flask): - """Test that LOGIN_DISABLED config bypasses authentication.""" + with login_app.test_request_context(): + result = protected_view() - @login_required + assert result == "Unauthorized", description + get_user.assert_called_once_with() + login_app.login_manager.unauthorized.assert_called_once_with() + csrf_check.assert_not_called() + ensure_sync_spy.assert_not_called() + + @pytest.mark.parametrize( + ("method", "login_disabled"), + [ + pytest.param("OPTIONS", False, id="options"), + pytest.param("GET", True, id="login-disabled"), + ], + ) + def test_bypass_paths_skip_authentication_and_csrf( + self, + login_app: Flask, + csrf_check: MagicMock, + ensure_sync_spy: MagicMock, + monkeypatch: pytest.MonkeyPatch, + mocker: MockerFixture, + method: str, + login_disabled: bool, + ): + """Test that bypass conditions skip auth lookup, CSRF, and unauthorized handling.""" + + @login_module.login_required def protected_view(): return "Protected content" - with setup_app.test_request_context(): - # Mock unauthenticated user and LOGIN_DISABLED - mock_user = MockUser("test_user", is_authenticated=False) - with patch("libs.login._get_user", return_value=mock_user, autospec=True): - with patch("libs.login.dify_config", autospec=True) as mock_config: - mock_config.LOGIN_DISABLED = True + get_user = mocker.patch.object(login_module, "_get_user") + monkeypatch.setattr(login_module.dify_config, "LOGIN_DISABLED", login_disabled) - result = protected_view() - assert result == "Protected content" - # Ensure unauthorized was not called - setup_app.login_manager.unauthorized.assert_not_called() + with login_app.test_request_context(method=method): + result = protected_view() - @patch("libs.login.check_csrf_token", mock_csrf_check) - def test_options_request_bypasses_authentication(self, setup_app: Flask): - """Test that OPTIONS requests are exempt from authentication.""" - - @login_required - def protected_view(): - return "Protected content" - - with setup_app.test_request_context(method="OPTIONS"): - # Mock unauthenticated user - mock_user = MockUser("test_user", is_authenticated=False) - with patch("libs.login._get_user", return_value=mock_user, autospec=True): - result = protected_view() - assert result == "Protected content" - # Ensure unauthorized was not called - setup_app.login_manager.unauthorized.assert_not_called() - - @patch("libs.login.check_csrf_token", mock_csrf_check) - def test_flask_2_compatibility(self, setup_app: Flask): - """Test Flask 2.x compatibility with ensure_sync.""" - - @login_required - def protected_view(): - return "Protected content" - - # Mock Flask 2.x ensure_sync - setup_app.ensure_sync = MagicMock(return_value=lambda: "Synced content") - - with setup_app.test_request_context(): - mock_user = MockUser("test_user", is_authenticated=True) - with patch("libs.login._get_user", return_value=mock_user, autospec=True): - result = protected_view() - assert result == "Synced content" - setup_app.ensure_sync.assert_called_once() - - @patch("libs.login.check_csrf_token", mock_csrf_check) - def test_flask_1_compatibility(self, setup_app: Flask): - """Test Flask 1.x compatibility without ensure_sync.""" - - @login_required - def protected_view(): - return "Protected content" - - # Remove ensure_sync to simulate Flask 1.x - if hasattr(setup_app, "ensure_sync"): - del setup_app.ensure_sync - - with setup_app.test_request_context(): - mock_user = MockUser("test_user", is_authenticated=True) - with patch("libs.login._get_user", return_value=mock_user, autospec=True): - result = protected_view() - assert result == "Protected content" + assert result == "Protected content" + get_user.assert_not_called() + ensure_sync_spy.assert_called_once_with(protected_view.__wrapped__) + csrf_check.assert_not_called() + login_app.login_manager.unauthorized.assert_not_called() class TestGetUser: """Test cases for _get_user function.""" - def test_get_user_returns_user_from_g(self, app: Flask): + def test_get_user_returns_user_from_g(self, login_app: Flask): """Test that _get_user returns user from g._login_user.""" mock_user = MockUser("test_user") - with app.test_request_context(): + with login_app.test_request_context(): g._login_user = mock_user - user = _get_user() + user = login_module._get_user() assert user == mock_user assert user.id == "test_user" - def test_get_user_loads_user_if_not_in_g(self, app: Flask): + def test_get_user_loads_user_if_not_in_g(self, login_app: Flask): """Test that _get_user loads user if not already in g.""" mock_user = MockUser("test_user") - # Mock login manager - login_manager = MagicMock() - login_manager._load_user = MagicMock() - app.login_manager = login_manager + def _load_user() -> None: + g._login_user = mock_user - with app.test_request_context(): - # Simulate _load_user setting g._login_user - def side_effect(): - g._login_user = mock_user + login_app.login_manager._load_user = MagicMock(side_effect=_load_user) - login_manager._load_user.side_effect = side_effect + with login_app.test_request_context(): + user = login_module._get_user() - user = _get_user() - assert user == mock_user - login_manager._load_user.assert_called_once() + assert user == mock_user + login_app.login_manager._load_user.assert_called_once_with() - def test_get_user_returns_none_without_request_context(self, app: Flask): + def test_get_user_returns_none_without_request_context(self): """Test that _get_user returns None outside request context.""" - # Outside of request context - user = _get_user() + user = login_module._get_user() assert user is None class TestCurrentUser: """Test cases for current_user proxy.""" - def test_current_user_proxy_returns_authenticated_user(self, app: Flask): + def test_current_user_proxy_returns_authenticated_user(self, login_app: Flask, mocker: MockerFixture): """Test that current_user proxy returns authenticated user.""" mock_user = MockUser("test_user", is_authenticated=True) + mocker.patch.object(login_module, "_get_user", return_value=mock_user) - with app.test_request_context(): - with patch("libs.login._get_user", return_value=mock_user, autospec=True): - assert current_user.id == "test_user" - assert current_user.is_authenticated is True + with login_app.test_request_context(): + assert current_user.id == "test_user" + assert current_user.is_authenticated is True - def test_current_user_proxy_returns_none_when_no_user(self, app: Flask): + def test_current_user_proxy_raises_attribute_error_when_no_user(self, login_app: Flask, mocker: MockerFixture): """Test that current_user proxy handles None user.""" - with app.test_request_context(): - with patch("libs.login._get_user", return_value=None, autospec=True): - # When _get_user returns None, accessing attributes should fail - # or current_user should evaluate to falsy - try: - # Try to access an attribute that would exist on a real user - _ = current_user.id - pytest.fail("Should have raised AttributeError") - except AttributeError: - # This is expected when current_user is None - pass + mocker.patch.object(login_module, "_get_user", return_value=None) - def test_current_user_proxy_thread_safety(self, app: Flask): - """Test that current_user proxy is thread-safe.""" - import threading + with login_app.test_request_context(): + with pytest.raises(AttributeError): + _ = current_user.id - results = {} - def check_user_in_thread(user_id: str, index: int): - with app.test_request_context(): - mock_user = MockUser(user_id) - with patch("libs.login._get_user", return_value=mock_user, autospec=True): - results[index] = current_user.id +class TestCurrentAccountWithTenant: + """Test cases for current_account_with_tenant helper.""" - # Create multiple threads with different users - threads = [] - for i in range(5): - thread = threading.Thread(target=check_user_in_thread, args=(f"user_{i}", i)) - threads.append(thread) - thread.start() + def test_returns_account_and_tenant_id(self, mocker: MockerFixture): + account = Account(name="Test User", email="test@example.com") + account._current_tenant = SimpleNamespace(id="tenant-123") + current_user_proxy = MagicMock() + current_user_proxy._get_current_object.return_value = account + mocker.patch.object(login_module, "current_user", new=current_user_proxy) - # Wait for all threads to complete - for thread in threads: - thread.join() + user, tenant_id = login_module.current_account_with_tenant() - # Verify each thread got its own user - for i in range(5): - assert results[i] == f"user_{i}" + assert user is account + assert tenant_id == "tenant-123" + current_user_proxy._get_current_object.assert_called_once_with() + + def test_raises_when_current_user_is_not_account(self, mocker: MockerFixture): + mocker.patch.object(login_module, "current_user", new=MockUser("test_user")) + + with pytest.raises(ValueError, match="current_user must be an Account instance"): + login_module.current_account_with_tenant() + + def test_raises_when_account_has_no_tenant(self, mocker: MockerFixture): + account = Account(name="Test User", email="test@example.com") + mocker.patch.object(login_module, "current_user", new=account) + + with pytest.raises(AssertionError, match="tenant information should be loaded"): + login_module.current_account_with_tenant() diff --git a/api/tests/unit_tests/services/test_model_load_balancing_service.py b/api/tests/unit_tests/services/test_model_load_balancing_service.py index fea2177f3e..362209380e 100644 --- a/api/tests/unit_tests/services/test_model_load_balancing_service.py +++ b/api/tests/unit_tests/services/test_model_load_balancing_service.py @@ -70,8 +70,11 @@ def service(mocker: MockerFixture) -> ModelLoadBalancingService: # Arrange provider_manager = MagicMock() mocker.patch("services.model_load_balancing_service.create_plugin_provider_manager", return_value=provider_manager) + model_assembly = SimpleNamespace(provider_manager=provider_manager, model_provider_factory=MagicMock()) + mocker.patch("services.model_load_balancing_service.create_plugin_model_assembly", return_value=model_assembly) svc = ModelLoadBalancingService() svc.provider_manager = provider_manager + svc.model_assembly = model_assembly svc._get_provider_manager = lambda _tenant_id: provider_manager # type: ignore[method-assign] return svc @@ -667,6 +670,9 @@ def test_validate_load_balancing_credentials_should_delegate_to_custom_validate_ assert mock_validate.call_count == 2 assert mock_validate.call_args_list[0].kwargs["load_balancing_model_config"] is existing_config assert mock_validate.call_args_list[1].kwargs["load_balancing_model_config"] is None + shared_model_provider_factory = service.model_assembly.model_provider_factory + assert mock_validate.call_args_list[0].kwargs["model_provider_factory"] is shared_model_provider_factory + assert mock_validate.call_args_list[1].kwargs["model_provider_factory"] is shared_model_provider_factory def test_custom_credentials_validate_should_replace_hidden_secret_with_original_value_and_encrypt( @@ -709,10 +715,6 @@ def test_custom_credentials_validate_should_handle_invalid_original_json_and_val load_balancing_model_config = _load_balancing_model_config(encrypted_config="not-json") mock_factory = MagicMock() mock_factory.model_credentials_validate.return_value = {"api_key": "validated"} - mocker.patch( - "services.model_load_balancing_service.create_plugin_model_provider_factory", - return_value=mock_factory, - ) mock_encrypt = mocker.patch( "services.model_load_balancing_service.encrypter.encrypt_token", side_effect=lambda tenant_id, value: f"enc:{value}", @@ -726,6 +728,7 @@ def test_custom_credentials_validate_should_handle_invalid_original_json_and_val model="gpt-4o-mini", credentials={"api_key": "plain"}, load_balancing_model_config=load_balancing_model_config, + model_provider_factory=mock_factory, validate=True, ) @@ -744,10 +747,6 @@ def test_custom_credentials_validate_should_validate_with_provider_schema_when_m provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) mock_factory = MagicMock() mock_factory.provider_credentials_validate.return_value = {"api_key": "provider-validated"} - mocker.patch( - "services.model_load_balancing_service.create_plugin_model_provider_factory", - return_value=mock_factory, - ) mocker.patch( "services.model_load_balancing_service.encrypter.encrypt_token", side_effect=lambda tenant_id, value: f"enc:{value}", @@ -760,6 +759,7 @@ def test_custom_credentials_validate_should_validate_with_provider_schema_when_m model_type=ModelType.LLM, model="gpt-4o-mini", credentials={"api_key": "plain"}, + model_provider_factory=mock_factory, validate=True, ) diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index dde062bc7c..60e110a72d 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -25,6 +25,7 @@ from dify_graph.enums import ( ) from dify_graph.errors import WorkflowNodeRunFailedError from dify_graph.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent +from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.node_events import NodeRunResult from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig from dify_graph.variables.input_entities import VariableEntityType @@ -1545,7 +1546,10 @@ class TestWorkflowServiceCredentialValidation: def test_validate_llm_model_config_should_raise_value_error_on_failure(self, service: WorkflowService) -> None: """If ModelManager raises any exception it must be wrapped into ValueError.""" # Arrange - with patch("core.model_manager.ModelManager.get_model_instance", side_effect=RuntimeError("no key")): + assembly = MagicMock() + assembly.model_manager.get_model_instance.side_effect = RuntimeError("no key") + + with patch("services.workflow_service.create_plugin_model_assembly", return_value=assembly): # Act + Assert with pytest.raises(ValueError, match="Failed to validate LLM model configuration"): service._validate_llm_model_config("tenant-1", "openai", "gpt-4") @@ -1558,30 +1562,30 @@ class TestWorkflowServiceCredentialValidation: mock_configs = MagicMock() mock_configs.get_models.return_value = [mock_model] + assembly = MagicMock() + assembly.provider_manager.get_configurations.return_value = mock_configs - with ( - patch("core.model_manager.ModelManager.get_model_instance"), - patch("core.provider_manager.ProviderManager") as mock_pm_cls, - ): - mock_pm_cls.return_value.get_configurations.return_value = mock_configs - + with patch("services.workflow_service.create_plugin_model_assembly", return_value=assembly): # Act service._validate_llm_model_config("tenant-1", "openai", "gpt-4") # Assert mock_model.raise_for_status.assert_called_once() + assembly.model_manager.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=ModelType.LLM, + model="gpt-4", + ) def test_validate_llm_model_config_model_not_found(self, service: WorkflowService) -> None: """Test ValueError when model is not found in provider configurations.""" mock_configs = MagicMock() mock_configs.get_models.return_value = [] # No models + assembly = MagicMock() + assembly.provider_manager.get_configurations.return_value = mock_configs - with ( - patch("core.model_manager.ModelManager.get_model_instance"), - patch("core.provider_manager.ProviderManager") as mock_pm_cls, - ): - mock_pm_cls.return_value.get_configurations.return_value = mock_configs - + with patch("services.workflow_service.create_plugin_model_assembly", return_value=assembly): # Act + Assert with pytest.raises(ValueError, match="Model gpt-4 not found for provider openai"): service._validate_llm_model_config("tenant-1", "openai", "gpt-4")