fix(model): preserve caller-scoped runtime assembly

This commit is contained in:
WH-2099 2026-03-24 17:48:30 +08:00
parent 57f9053b3a
commit 0b8a5e83db
No known key found for this signature in database
28 changed files with 645 additions and 264 deletions

View File

@ -88,6 +88,7 @@ class ModelConfigResource(Resource):
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
user_id=current_user.id,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_tenant_id,
@ -127,6 +128,7 @@ class ModelConfigResource(Resource):
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
user_id=current_user.id,
)
except Exception:
continue

View File

@ -140,6 +140,7 @@ class BaseAgentRunner(AppRunner):
tenant_id=self.tenant_id,
app_id=self.app_config.app_id,
agent_tool=tool,
user_id=self.user_id,
invoke_from=self.application_generate_entity.invoke_from,
)
assert tool_entity.entity.description

View File

@ -2,7 +2,7 @@ from collections.abc import Mapping
from typing import Any
from core.app.app_config.entities import ModelConfigEntity
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from models.model import AppModelConfigDict
from models.provider_ids import ModelProviderID
@ -53,9 +53,12 @@ class ModelConfigManager:
if not isinstance(config["model"], dict):
raise ValueError("model must be of object type")
# Keep provider discovery and provider-backed model listing on the same
# request-scoped runtime so caller scope and provider caches stay aligned.
assembly = create_plugin_model_assembly(tenant_id=tenant_id)
# model.provider
model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)
provider_entities = model_provider_factory.get_providers()
provider_entities = assembly.model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities]
if "provider" not in config["model"]:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
@ -70,8 +73,7 @@ class ModelConfigManager:
if "name" not in config["model"]:
raise ValueError("model.name is required")
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
models = provider_manager.get_configurations(tenant_id).get_models(
models = assembly.provider_manager.get_configurations(tenant_id).get_models(
provider=config["model"]["provider"], model_type=ModelType.LLM
)

View File

@ -105,8 +105,8 @@ class ProviderConfiguration(BaseModel):
"""Attach the already-composed runtime for request-bound call chains."""
self._bound_model_runtime = model_runtime
def _get_model_provider_factory(self) -> ModelProviderFactory:
"""Reuse a bound runtime when available, else fall back to tenant scope."""
def get_model_provider_factory(self) -> ModelProviderFactory:
"""Return a provider factory that preserves any request-bound runtime."""
if self._bound_model_runtime is not None:
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
@ -362,7 +362,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id, token=original_credentials[key]
)
model_provider_factory = self._get_model_provider_factory()
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, credentials=credentials
)
@ -921,7 +921,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id, token=original_credentials[key]
)
model_provider_factory = self._get_model_provider_factory()
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
@ -1407,7 +1407,7 @@ class ProviderConfiguration(BaseModel):
:param model_type: model type
:return:
"""
model_provider_factory = self._get_model_provider_factory()
model_provider_factory = self.get_model_provider_factory()
# Get model instance of LLM
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
@ -1416,7 +1416,7 @@ class ProviderConfiguration(BaseModel):
"""
Get model schema
"""
model_provider_factory = self._get_model_provider_factory()
model_provider_factory = self.get_model_provider_factory()
return model_provider_factory.get_model_schema(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
@ -1518,7 +1518,7 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_provider_factory = self._get_model_provider_factory()
model_provider_factory = self.get_model_provider_factory()
provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
model_types: list[ModelType] = []

View File

@ -277,18 +277,22 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
}
@classmethod
def get_system_model_max_tokens(cls, tenant_id: str) -> int:
def get_system_model_max_tokens(cls, tenant_id: str, user_id: str | None = None) -> int:
"""
get system model max tokens
"""
return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id)
return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id, user_id=user_id)
@classmethod
def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage], user_id: str | None = None) -> int:
"""
get prompt tokens
"""
return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages)
return ModelInvocationUtils.calculate_tokens(
tenant_id=tenant_id,
prompt_messages=prompt_messages,
user_id=user_id,
)
@classmethod
def invoke_system_model(
@ -306,6 +310,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
tool_type=ToolProviderType.PLUGIN,
tool_name="plugin",
prompt_messages=prompt_messages,
caller_user_id=user_id,
)
@classmethod
@ -313,7 +318,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke summary
"""
max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id)
max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id, user_id=user_id)
content = payload.text
SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
@ -332,6 +337,7 @@ Here is the extra instruction you need to follow:
cls.get_prompt_tokens(
tenant_id=tenant.id,
prompt_messages=[UserPromptMessage(content=content)],
user_id=user_id,
)
< max_tokens * 0.6
):
@ -344,6 +350,7 @@ Here is the extra instruction you need to follow:
SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
UserPromptMessage(content=content),
],
user_id=user_id,
)
def summarize(content: str) -> str:
@ -401,6 +408,7 @@ Here is the extra instruction you need to follow:
cls.get_prompt_tokens(
tenant_id=tenant.id,
prompt_messages=[UserPromptMessage(content=result)],
user_id=user_id,
)
> max_tokens * 0.7
):

View File

@ -31,7 +31,13 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
# get tool runtime
try:
tool_runtime = ToolManager.get_tool_runtime_from_plugin(
tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id
tool_type,
tenant_id,
provider,
tool_name,
tool_parameters,
user_id=user_id,
credential_id=credential_id,
)
response = ToolEngine.generic_invoke(
tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1

View File

@ -3,12 +3,64 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from core.plugin.impl.model import PluginModelClient
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
if TYPE_CHECKING:
from core.model_manager import ModelManager
from core.plugin.impl.model_runtime import PluginModelRuntime
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
class PluginModelAssembly:
    """Lazily build model-layer views that all share one plugin runtime.

    The first property access creates the tenant/user-scoped runtime; every
    later view (provider factory, provider manager, model manager) is layered
    on that same runtime instance, so request-scoped caches stay aligned
    instead of being split across independently created runtimes.
    """

    # Public scope identifiers the assembly was created for.
    tenant_id: str
    user_id: str | None

    def __init__(self, *, tenant_id: str, user_id: str | None = None) -> None:
        self.tenant_id = tenant_id
        self.user_id = user_id
        # Per-view caches; each is populated exactly once, on first access.
        self._runtime_cache: PluginModelRuntime | None = None
        self._factory_cache: ModelProviderFactory | None = None
        self._providers_cache: ProviderManager | None = None
        self._models_cache: ModelManager | None = None

    @property
    def model_runtime(self) -> PluginModelRuntime:
        """The single plugin runtime every view in this assembly is bound to."""
        if self._runtime_cache is None:
            self._runtime_cache = create_plugin_model_runtime(tenant_id=self.tenant_id, user_id=self.user_id)
        return self._runtime_cache

    @property
    def model_provider_factory(self) -> ModelProviderFactory:
        """Provider factory view over the shared runtime."""
        if self._factory_cache is None:
            self._factory_cache = ModelProviderFactory(model_runtime=self.model_runtime)
        return self._factory_cache

    @property
    def provider_manager(self) -> ProviderManager:
        """Provider manager view over the shared runtime.

        The import is local to avoid a module-level import cycle.
        """
        if self._providers_cache is None:
            from core.provider_manager import ProviderManager

            self._providers_cache = ProviderManager(model_runtime=self.model_runtime)
        return self._providers_cache

    @property
    def model_manager(self) -> ModelManager:
        """Model manager layered on this assembly's provider manager.

        The import is local to avoid a module-level import cycle.
        """
        if self._models_cache is None:
            from core.model_manager import ModelManager

            self._models_cache = ModelManager(provider_manager=self.provider_manager)
        return self._models_cache
def create_plugin_model_assembly(*, tenant_id: str, user_id: str | None = None) -> PluginModelAssembly:
    """Build a request-scoped assembly whose model views all share one plugin runtime.

    :param tenant_id: tenant the underlying runtime is scoped to
    :param user_id: optional caller identity bound to the runtime
    :return: a fresh, lazily-initialized ``PluginModelAssembly``
    """
    assembly = PluginModelAssembly(tenant_id=tenant_id, user_id=user_id)
    return assembly
def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -> PluginModelRuntime:
@ -24,22 +76,14 @@ def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -
def create_plugin_model_provider_factory(*, tenant_id: str, user_id: str | None = None) -> ModelProviderFactory:
"""Create a tenant-bound model provider factory for service flows."""
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
return ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id=tenant_id, user_id=user_id))
return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_provider_factory
def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None) -> ProviderManager:
"""Create a tenant-bound provider manager for service flows."""
from core.provider_manager import ProviderManager
return ProviderManager(model_runtime=create_plugin_model_runtime(tenant_id=tenant_id, user_id=user_id))
return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).provider_manager
def create_plugin_model_manager(*, tenant_id: str, user_id: str | None = None) -> ModelManager:
"""Create a tenant-bound model manager for service flows."""
from core.model_manager import ModelManager
return ModelManager(
provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id),
)
return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_manager

View File

@ -9,10 +9,14 @@ from core.tools.entities.tool_entities import ToolInvokeFrom
class ToolRuntime(BaseModel):
"""
Meta data of a tool call processing
Meta data of a tool call processing.
``user_id`` is optional so read-only tooling flows can stay tenant-scoped,
while execution paths may bind caller identity for model runtime lookups.
"""
tenant_id: str
user_id: str | None = None
tool_id: str | None = None
invoke_from: InvokeFrom | None = None
tool_invoke_from: ToolInvokeFrom | None = None

View File

@ -53,6 +53,7 @@ class BuiltinTool(Tool):
tool_type=ToolProviderType.BUILT_IN,
tool_name=self.entity.identity.name,
prompt_messages=prompt_messages,
caller_user_id=self.runtime.user_id,
)
def tool_provider_type(self) -> ToolProviderType:
@ -69,6 +70,7 @@ class BuiltinTool(Tool):
return ModelInvocationUtils.get_max_llm_context_tokens(
tenant_id=self.runtime.tenant_id or "",
user_id=self.runtime.user_id,
)
def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
@ -82,7 +84,9 @@ class BuiltinTool(Tool):
raise ValueError("runtime is required")
return ModelInvocationUtils.calculate_tokens(
tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages
tenant_id=self.runtime.tenant_id or "",
prompt_messages=prompt_messages,
user_id=self.runtime.user_id,
)
def summary(self, user_id: str, content: str) -> str:

View File

@ -184,6 +184,7 @@ class ToolManager:
provider_id: str,
tool_name: str,
tenant_id: str,
user_id: str | None = None,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
credential_id: str | None = None,
@ -195,6 +196,7 @@ class ToolManager:
:param provider_id: the id of the provider
:param tool_name: the name of the tool
:param tenant_id: the tenant id
:param user_id: the caller id bound to runtime-scoped model/tool lookups
:param invoke_from: invoke from
:param tool_invoke_from: the tool invoke from
:param credential_id: the credential id
@ -213,6 +215,7 @@ class ToolManager:
return builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
user_id=user_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
@ -321,6 +324,7 @@ class ToolManager:
return builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
user_id=user_id,
credentials=dict(decrypted_credentials),
credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={},
@ -338,6 +342,7 @@ class ToolManager:
return api_provider.get_tool(tool_name).fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
user_id=user_id,
credentials=dict(encrypter.decrypt(credentials)),
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
@ -361,6 +366,7 @@ class ToolManager:
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
user_id=user_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
@ -369,9 +375,21 @@ class ToolManager:
elif provider_type == ToolProviderType.APP:
raise NotImplementedError("app provider not implemented")
elif provider_type == ToolProviderType.PLUGIN:
return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
runtime = getattr(plugin_tool, "runtime", None)
if runtime is not None:
runtime.user_id = user_id
runtime.invoke_from = invoke_from
runtime.tool_invoke_from = tool_invoke_from
return plugin_tool
elif provider_type == ToolProviderType.MCP:
return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
runtime = getattr(mcp_tool, "runtime", None)
if runtime is not None:
runtime.user_id = user_id
runtime.invoke_from = invoke_from
runtime.tool_invoke_from = tool_invoke_from
return mcp_tool
else:
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
@ -381,6 +399,7 @@ class ToolManager:
tenant_id: str,
app_id: str,
agent_tool: AgentToolEntity,
user_id: str | None = None,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional["VariablePool"] = None,
) -> Tool:
@ -392,6 +411,7 @@ class ToolManager:
provider_id=agent_tool.provider_id,
tool_name=agent_tool.tool_name,
tenant_id=tenant_id,
user_id=user_id,
invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.AGENT,
credential_id=agent_tool.credential_id,
@ -423,6 +443,7 @@ class ToolManager:
app_id: str,
node_id: str,
workflow_tool: WorkflowToolRuntimeSpec,
user_id: str | None = None,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional["VariablePool"] = None,
) -> Tool:
@ -435,6 +456,7 @@ class ToolManager:
provider_id=workflow_tool.provider_id,
tool_name=workflow_tool.tool_name,
tenant_id=tenant_id,
user_id=user_id,
invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
credential_id=workflow_tool.credential_id,
@ -467,6 +489,7 @@ class ToolManager:
provider: str,
tool_name: str,
tool_parameters: dict[str, Any],
user_id: str | None = None,
credential_id: str | None = None,
) -> Tool:
"""
@ -477,6 +500,7 @@ class ToolManager:
provider_id=provider,
tool_name=tool_name,
tenant_id=tenant_id,
user_id=user_id,
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.PLUGIN,
credential_id=credential_id,

View File

@ -34,11 +34,12 @@ class ModelInvocationUtils:
@staticmethod
def get_max_llm_context_tokens(
tenant_id: str,
user_id: str | None = None,
) -> int:
"""
get max llm context tokens of the model
"""
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
@ -60,13 +61,13 @@ class ModelInvocationUtils:
return max_tokens
@staticmethod
def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage], user_id: str | None = None) -> int:
"""
calculate tokens from prompt messages and model parameters
"""
# get model instance
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM)
if not model_instance:
@ -79,7 +80,12 @@ class ModelInvocationUtils:
@staticmethod
def invoke(
user_id: str, tenant_id: str, tool_type: ToolProviderType, tool_name: str, prompt_messages: list[PromptMessage]
user_id: str,
tenant_id: str,
tool_type: ToolProviderType,
tool_name: str,
prompt_messages: list[PromptMessage],
caller_user_id: str | None = None,
) -> LLMResult:
"""
invoke model with parameters in user's own context
@ -93,7 +99,7 @@ class ModelInvocationUtils:
"""
# get model manager
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=caller_user_id or user_id)
# get model instance
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,

View File

@ -337,6 +337,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
self._run_context.app_id,
node_id,
self._build_tool_runtime_spec(node_data),
self._run_context.user_id,
self._run_context.invoke_from,
variable_pool,
)

View File

@ -12,9 +12,9 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_manager import ModelInstance
from core.plugin.entities.request import InvokeCredentials
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType
from core.tools.tool_manager import ToolManager
from core.workflow.system_variables import SystemVariableKey, get_system_text
@ -141,6 +141,7 @@ class AgentRuntimeSupport:
tenant_id,
app_id,
entity,
user_id,
invoke_from,
runtime_variable_pool,
)
@ -242,8 +243,8 @@ class AgentRuntimeSupport:
user_id: str,
value: dict[str, Any],
) -> tuple[ModelInstance, AIModelEntity | None]:
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id)
provider_model_bundle = provider_manager.get_provider_model_bundle(
assembly = create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id)
provider_model_bundle = assembly.provider_manager.get_provider_model_bundle(
tenant_id=tenant_id,
provider=value.get("provider", ""),
model_type=ModelType.LLM,
@ -255,7 +256,7 @@ class AgentRuntimeSupport:
)
provider_name = provider_model_bundle.configuration.provider.provider
model_type_instance = provider_model_bundle.model_type_instance
model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
model_instance = assembly.model_manager.get_model_instance(
tenant_id=tenant_id,
provider=provider_name,
model_type=ModelType(value.get("model_type", "")),

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
from flask import current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS
@ -34,8 +34,6 @@ def current_account_with_tenant():
return user, user.current_tenant_id
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")

View File

@ -205,6 +205,7 @@ class AppService:
tenant_id=current_user.current_tenant_id,
app_id=app.id,
agent_tool=agent_tool_entity,
user_id=current_user.id,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,

View File

@ -10,13 +10,14 @@ from core.entities.provider_configuration import ProviderConfiguration
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_manager import LBModelManager
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.entities.provider_entities import (
ModelCredentialSchema,
ProviderCredentialSchema,
)
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.enums import CredentialSourceType
@ -223,8 +224,8 @@ class ModelLoadBalancingService:
:param config_id: load balancing config id
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id)
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
provider_configurations = provider_manager.get_configurations(tenant_id)
# Get provider configuration
provider_configuration = provider_configurations.get(provider)
@ -496,8 +497,8 @@ class ModelLoadBalancingService:
:param config_id: load balancing config id
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id)
assembly = create_plugin_model_assembly(tenant_id=tenant_id)
provider_configurations = assembly.provider_manager.get_configurations(tenant_id)
# Get provider configuration
provider_configuration = provider_configurations.get(provider)
@ -533,6 +534,7 @@ class ModelLoadBalancingService:
model=model,
credentials=credentials,
load_balancing_model_config=load_balancing_model_config,
model_provider_factory=assembly.model_provider_factory,
)
def _custom_credentials_validate(
@ -543,6 +545,7 @@ class ModelLoadBalancingService:
model: str,
credentials: dict,
load_balancing_model_config: LoadBalancingModelConfig | None = None,
model_provider_factory: ModelProviderFactory | None = None,
validate: bool = True,
):
"""
@ -553,6 +556,7 @@ class ModelLoadBalancingService:
:param model: model name
:param credentials: credentials
:param load_balancing_model_config: load balancing model config
:param model_provider_factory: model provider factory sharing the active runtime
:param validate: validate credentials
:return:
"""
@ -582,7 +586,8 @@ class ModelLoadBalancingService:
credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key])
if validate:
model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)
if model_provider_factory is None:
model_provider_factory = provider_configuration.get_model_provider_factory()
if isinstance(credential_schemas, ModelCredentialSchema):
credentials = model_provider_factory.model_credentials_validate(
provider=provider_configuration.provider.provider,

View File

@ -13,7 +13,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.app.file_access import DatabaseFileAccessController
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
from core.trigger.constants import is_trigger_node_type
@ -491,12 +491,15 @@ class WorkflowService:
:raises ValueError: If the model configuration is invalid or credentials fail policy checks
"""
try:
from core.model_manager import ModelManager
from dify_graph.model_runtime.entities.model_entities import ModelType
# Model instance resolution and provider status lookup must reuse the
# same request-scoped runtime so validation does not silently split
# provider discovery and credential reads across different caches.
assembly = create_plugin_model_assembly(tenant_id=tenant_id)
# Get model instance to validate provider+model combination
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_manager.get_model_instance(
assembly.model_manager.get_model_instance(
tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name
)
@ -505,8 +508,7 @@ class WorkflowService:
# If it fails, an exception will be raised
# Additionally, check the model status to ensure it's ACTIVE
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
provider_configurations = provider_manager.get_configurations(tenant_id)
provider_configurations = assembly.provider_manager.get_configurations(tenant_id)
models = provider_configurations.get_models(provider=provider, model_type=ModelType.LLM)
target_model = None

View File

@ -0,0 +1,57 @@
from types import SimpleNamespace
from unittest.mock import patch
from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager
from core.app.app_config.entities import ModelConfigEntity
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
from models.provider_ids import ModelProviderID
def test_validate_and_set_defaults_reuses_single_model_assembly():
    """validate_and_set_defaults should resolve providers and models via one shared assembly."""
    qualified_provider = str(ModelProviderID("openai"))
    fake_model = SimpleNamespace(model="gpt-4o-mini", model_properties={ModelPropertyKey.MODE: "chat"})
    fake_configurations = SimpleNamespace(get_models=lambda **kwargs: [fake_model])
    fake_assembly = SimpleNamespace(
        model_provider_factory=SimpleNamespace(get_providers=lambda: [SimpleNamespace(provider=qualified_provider)]),
        provider_manager=SimpleNamespace(get_configurations=lambda tenant_id: fake_configurations),
    )
    payload = {
        "model": {
            "provider": "openai",
            "name": "gpt-4o-mini",
            "completion_params": {"stop": []},
        }
    }

    patch_target = "core.app.app_config.easy_ui_based_app.model_config.manager.create_plugin_model_assembly"
    with patch(patch_target, return_value=fake_assembly) as mock_assembly:
        result, keys = ModelConfigManager.validate_and_set_defaults("tenant-1", payload)

    assert result["model"]["provider"] == qualified_provider
    assert result["model"]["mode"] == "chat"
    assert keys == ["model"]
    # Exactly one assembly means provider discovery and model listing shared a runtime.
    mock_assembly.assert_called_once_with(tenant_id="tenant-1")
def test_convert_keeps_model_config_shape():
    """convert should map the raw dict into a ModelConfigEntity, splitting stop words out of params."""
    raw = {
        "model": {
            "provider": "openai",
            "name": "gpt-4o-mini",
            "mode": "chat",
            "completion_params": {"temperature": 0.3, "stop": ["END"]},
        }
    }

    entity = ModelConfigManager.convert(raw)

    expected = ModelConfigEntity(
        provider="openai",
        model="gpt-4o-mini",
        mode="chat",
        parameters={"temperature": 0.3},
        stop=["END"],
    )
    assert entity == expected

View File

@ -0,0 +1,36 @@
from unittest.mock import Mock, patch
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
def test_plugin_model_assembly_reuses_single_runtime_across_views():
    """Each view must be constructed once, and all of them on the same runtime instance."""
    fake_runtime = Mock(name="runtime")
    fake_factory = Mock(name="provider_factory")
    fake_providers = Mock(name="provider_manager")
    fake_models = Mock(name="model_manager")

    with (
        patch(
            "core.plugin.impl.model_runtime_factory.create_plugin_model_runtime",
            return_value=fake_runtime,
        ) as runtime_factory,
        patch(
            "core.plugin.impl.model_runtime_factory.ModelProviderFactory",
            return_value=fake_factory,
        ) as factory_cls,
        patch("core.provider_manager.ProviderManager", return_value=fake_providers) as providers_cls,
        patch("core.model_manager.ModelManager", return_value=fake_models) as models_cls,
    ):
        assembly = create_plugin_model_assembly(tenant_id="tenant-1", user_id="user-1")

        # Access every view twice: identity must hold and each constructor must run once.
        for _ in range(2):
            assert assembly.model_provider_factory is fake_factory
            assert assembly.provider_manager is fake_providers
            assert assembly.model_manager is fake_models

        runtime_factory.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
        factory_cls.assert_called_once_with(model_runtime=fake_runtime)
        providers_cls.assert_called_once_with(model_runtime=fake_runtime)
        models_cls.assert_called_once_with(provider_manager=fake_providers)

View File

@ -0,0 +1,61 @@
from types import SimpleNamespace
from unittest.mock import patch
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
from core.plugin.entities.request import RequestInvokeSummary
from dify_graph.model_runtime.entities.message_entities import UserPromptMessage
def test_system_model_helpers_forward_user_id():
    """Both token helpers must pass the caller's user_id through to ModelInvocationUtils."""
    messages = [UserPromptMessage(content="hello")]

    with (
        patch(
            "core.plugin.backwards_invocation.model.ModelInvocationUtils.get_max_llm_context_tokens",
            return_value=4096,
        ) as mock_max_tokens,
        patch(
            "core.plugin.backwards_invocation.model.ModelInvocationUtils.calculate_tokens",
            return_value=7,
        ) as mock_prompt_tokens,
    ):
        max_tokens = PluginModelBackwardsInvocation.get_system_model_max_tokens("tenant-1", user_id="user-1")
        prompt_tokens = PluginModelBackwardsInvocation.get_prompt_tokens(
            "tenant-1",
            messages,
            user_id="user-1",
        )

    assert max_tokens == 4096
    assert prompt_tokens == 7
    mock_max_tokens.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
    mock_prompt_tokens.assert_called_once_with(
        tenant_id="tenant-1",
        prompt_messages=messages,
        user_id="user-1",
    )
def test_invoke_summary_uses_same_user_scope_for_token_helpers():
    """When the text fits the token budget, invoke_summary returns it unchanged and
    scopes both token helpers to the invoking user."""
    tenant = SimpleNamespace(id="tenant-1")
    payload = RequestInvokeSummary(text="short", instruction="keep it concise")

    with (
        patch.object(
            PluginModelBackwardsInvocation,
            "get_system_model_max_tokens",
            return_value=100,
        ) as mock_max,
        patch.object(
            PluginModelBackwardsInvocation,
            "get_prompt_tokens",
            return_value=10,
        ) as mock_prompt,
    ):
        summary = PluginModelBackwardsInvocation.invoke_summary("user-1", tenant, payload)

    assert summary == "short"
    mock_max.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
    mock_prompt.assert_called_once_with(
        tenant_id="tenant-1",
        prompt_messages=[UserPromptMessage(content="short")],
        user_id="user-1",
    )

View File

@ -27,12 +27,12 @@ class _BuiltinDummyTool(BuiltinTool):
yield self.create_text_message("ok")
def _build_tool() -> _BuiltinDummyTool:
def _build_tool(user_id: str | None = None) -> _BuiltinDummyTool:
entity = ToolEntity(
identity=ToolIdentity(author="author", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"),
parameters=[],
)
runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER)
runtime = ToolRuntime(tenant_id="tenant-1", user_id=user_id, invoke_from=InvokeFrom.DEBUGGER)
return _BuiltinDummyTool(provider="provider-a", entity=entity, runtime=runtime)
@ -45,7 +45,7 @@ def test_builtin_tool_fork_and_provider_type():
def test_invoke_model_calls_model_invocation_utils_invoke():
tool = _build_tool()
tool = _build_tool(user_id="runtime-user")
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.invoke", return_value="result") as mock_invoke:
assert (
tool.invoke_model(
@ -55,19 +55,47 @@ def test_invoke_model_calls_model_invocation_utils_invoke():
)
== "result"
)
mock_invoke.assert_called_once()
mock_invoke.assert_called_once_with(
user_id="u1",
tenant_id="tenant-1",
tool_type=ToolProviderType.BUILT_IN,
tool_name="tool-a",
prompt_messages=[UserPromptMessage(content="hello")],
caller_user_id="runtime-user",
)
def test_get_max_tokens_returns_value():
tool = _build_tool()
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096):
tool = _build_tool(user_id="runtime-user")
with patch(
"core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096
) as mock_get:
assert tool.get_max_tokens() == 4096
mock_get.assert_called_once_with(tenant_id="tenant-1", user_id="runtime-user")
def test_get_prompt_tokens_returns_value():
tool = _build_tool()
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7):
tool = _build_tool(user_id="runtime-user")
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7) as mock_calculate:
assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7
mock_calculate.assert_called_once_with(
tenant_id="tenant-1",
prompt_messages=[UserPromptMessage(content="hello")],
user_id="runtime-user",
)
def test_get_prompt_tokens_falls_back_to_tenant_scope_when_runtime_user_id_missing():
tool = _build_tool()
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7) as mock_calculate:
assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7
mock_calculate.assert_called_once_with(
tenant_id="tenant-1",
prompt_messages=[UserPromptMessage(content="hello")],
user_id=None,
)
def test_runtime_none_raises():

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import calendar
import math
from datetime import date
from types import SimpleNamespace
import pytest
@ -98,7 +100,13 @@ def test_timezone_conversion_tool():
def test_weekday_tool():
weekday_tool = _build_builtin_tool(WeekdayTool)
valid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 1, "day": 1}))[0].message.text
assert "January 1, 2024" in valid
expected_date = date(2024, 1, 1)
expected_message = (
f"{calendar.month_name[expected_date.month]} "
f"{expected_date.day}, {expected_date.year} "
f"is {calendar.day_name[expected_date.weekday()]}."
)
assert valid == expected_message
invalid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 2, "day": 31}))[
0
].message.text

View File

@ -15,6 +15,7 @@ from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolInvokeFrom,
ToolParameter,
ToolProviderType,
)
@ -421,7 +422,7 @@ def test_get_agent_runtime_apply_runtime_parameters():
tool_runtime = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={}))
tool_runtime.get_merged_runtime_parameters = Mock(return_value=[parameter])
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime):
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime) as mock_get_tool_runtime:
with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "hello"}):
manager = Mock()
manager.decrypt_tool_parameters.return_value = {"query": "decrypted"}
@ -437,12 +438,23 @@ def test_get_agent_runtime_apply_runtime_parameters():
tenant_id="tenant-1",
app_id="app-1",
agent_tool=agent_tool,
user_id="user-1",
invoke_from=InvokeFrom.DEBUGGER,
variable_pool=None,
)
assert result is tool_runtime
assert tool_runtime.runtime.runtime_parameters["query"] == "decrypted"
mock_get_tool_runtime.assert_called_once_with(
provider_type=ToolProviderType.API,
provider_id="api-1",
tool_name="search",
tenant_id="tenant-1",
user_id="user-1",
invoke_from=InvokeFrom.DEBUGGER,
tool_invoke_from=ToolInvokeFrom.AGENT,
credential_id=None,
)
def test_get_workflow_runtime_apply_runtime_parameters():
@ -463,7 +475,7 @@ def test_get_workflow_runtime_apply_runtime_parameters():
)
tool_runtime2 = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={}))
tool_runtime2.get_merged_runtime_parameters = Mock(return_value=[parameter])
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2):
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2) as mock_get_tool_runtime:
with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "workflow"}):
manager = Mock()
manager.decrypt_tool_parameters.return_value = {"query": "workflow-dec"}
@ -473,12 +485,23 @@ def test_get_workflow_runtime_apply_runtime_parameters():
app_id="app-1",
node_id="node-1",
workflow_tool=workflow_tool,
user_id="user-1",
invoke_from=InvokeFrom.DEBUGGER,
variable_pool=None,
)
assert workflow_result is tool_runtime2
assert tool_runtime2.runtime.runtime_parameters["query"] == "workflow-dec"
mock_get_tool_runtime.assert_called_once_with(
provider_type=ToolProviderType.API,
provider_id="api-1",
tool_name="search",
tenant_id="tenant-1",
user_id="user-1",
invoke_from=InvokeFrom.DEBUGGER,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
credential_id=None,
)
def test_get_agent_runtime_raises_when_runtime_missing():
@ -520,17 +543,28 @@ def test_get_tool_runtime_from_plugin_only_uses_form_parameters():
tool_entity = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={}))
tool_entity.get_merged_runtime_parameters = Mock(return_value=[form_param, llm_param])
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity):
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity) as mock_get_tool_runtime:
result = ToolManager.get_tool_runtime_from_plugin(
tool_type=ToolProviderType.API,
tenant_id="tenant-1",
provider="api-1",
tool_name="search",
tool_parameters={"q": "hello", "llm": "ignore"},
user_id="user-1",
)
assert result is tool_entity
assert tool_entity.runtime.runtime_parameters == {"q": "hello"}
mock_get_tool_runtime.assert_called_once_with(
provider_type=ToolProviderType.API,
provider_id="api-1",
tool_name="search",
tenant_id="tenant-1",
user_id="user-1",
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.PLUGIN,
credential_id=None,
)
def test_hardcoded_provider_icon_success():

View File

@ -60,20 +60,23 @@ def test_get_max_llm_context_tokens_branches(model_instance, expected, error_mat
manager = Mock()
manager.get_default_model_instance.return_value = model_instance
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager):
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory:
if error_match:
with pytest.raises(InvokeModelError, match=error_match):
ModelInvocationUtils.get_max_llm_context_tokens("tenant")
ModelInvocationUtils.get_max_llm_context_tokens("tenant", user_id="user-1")
else:
assert ModelInvocationUtils.get_max_llm_context_tokens("tenant") == expected
assert ModelInvocationUtils.get_max_llm_context_tokens("tenant", user_id="user-1") == expected
mock_factory.assert_called_once_with(tenant_id="tenant", user_id="user-1")
def test_calculate_tokens_handles_missing_model():
manager = Mock()
manager.get_default_model_instance.return_value = None
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager):
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory:
with pytest.raises(InvokeModelError, match="Model not found"):
ModelInvocationUtils.calculate_tokens("tenant", [])
mock_factory.assert_called_once_with(tenant_id="tenant", user_id=None)
def test_invoke_success_and_error_mappings():
@ -98,7 +101,7 @@ def test_invoke_success_and_error_mappings():
db_mock = SimpleNamespace(session=Mock())
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager):
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory:
with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke):
with patch("core.tools.utils.model_invocation_utils.db", db_mock):
response = ModelInvocationUtils.invoke(
@ -107,11 +110,13 @@ def test_invoke_success_and_error_mappings():
tool_type="builtin",
tool_name="tool-a",
prompt_messages=[],
caller_user_id="caller-1",
)
assert response.message.content == "ok"
assert db_mock.session.add.call_count == 1
assert db_mock.session.commit.call_count == 2
mock_factory.assert_called_once_with(tenant_id="tenant", user_id="caller-1")
@pytest.mark.parametrize(
@ -145,7 +150,7 @@ def test_invoke_error_mappings(exc, expected):
db_mock = SimpleNamespace(session=Mock())
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager):
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory:
with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke):
with patch("core.tools.utils.model_invocation_utils.db", db_mock):
with pytest.raises(InvokeModelError, match=expected):
@ -156,3 +161,4 @@ def test_invoke_error_mappings(exc, expected):
tool_name="tool-a",
prompt_messages=[],
)
mock_factory.assert_called_once_with(tenant_id="tenant", user_id="u1")

View File

@ -0,0 +1,49 @@
from types import SimpleNamespace
from unittest.mock import Mock, patch
from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
from dify_graph.model_runtime.entities.model_entities import ModelType
def test_fetch_model_reuses_single_model_assembly():
provider_configuration = SimpleNamespace(
get_current_credentials=Mock(return_value={"api_key": "x"}),
provider=SimpleNamespace(provider="openai"),
)
model_type_instance = SimpleNamespace(get_model_schema=Mock(return_value="schema"))
provider_model_bundle = SimpleNamespace(
configuration=provider_configuration,
model_type_instance=model_type_instance,
)
model_instance = Mock()
assembly = SimpleNamespace(
provider_manager=Mock(),
model_manager=Mock(),
)
assembly.provider_manager.get_provider_model_bundle.return_value = provider_model_bundle
assembly.model_manager.get_model_instance.return_value = model_instance
with patch(
"core.workflow.nodes.agent.runtime_support.create_plugin_model_assembly",
return_value=assembly,
) as mock_assembly:
resolved_instance, resolved_schema = AgentRuntimeSupport().fetch_model(
tenant_id="tenant-1",
user_id="user-1",
value={"provider": "openai", "model": "gpt-4o-mini", "model_type": "llm"},
)
assert resolved_instance is model_instance
assert resolved_schema == "schema"
mock_assembly.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
assembly.provider_manager.get_provider_model_bundle.assert_called_once_with(
tenant_id="tenant-1",
provider="openai",
model_type=ModelType.LLM,
)
assembly.model_manager.get_model_instance.assert_called_once_with(
tenant_id="tenant-1",
provider="openai",
model_type=ModelType.LLM,
model="gpt-4o-mini",
)

View File

@ -1,10 +1,14 @@
from unittest.mock import MagicMock, patch
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from flask import Flask, g
from flask_login import LoginManager, UserMixin
from pytest_mock import MockerFixture
from libs.login import _get_user, current_user, login_required
import libs.login as login_module
from libs.login import current_user
from models.account import Account
class MockUser(UserMixin):
@ -15,229 +19,214 @@ class MockUser(UserMixin):
self._is_authenticated = is_authenticated
@property
def is_authenticated(self):
def is_authenticated(self) -> bool:
return self._is_authenticated
def mock_csrf_check(*args, **kwargs):
return
@pytest.fixture
def login_app() -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
login_manager = LoginManager()
login_manager.init_app(app)
login_manager.unauthorized = MagicMock(return_value="Unauthorized")
@login_manager.user_loader
def load_user(_user_id: str):
return None
return app
@pytest.fixture
def csrf_check(mocker: MockerFixture) -> MagicMock:
return mocker.patch.object(login_module, "check_csrf_token")
@pytest.fixture
def ensure_sync_spy(login_app: Flask, mocker: MockerFixture) -> MagicMock:
def _ensure_sync(func):
return lambda *args, **kwargs: func(*args, **kwargs)
return mocker.patch.object(login_app, "ensure_sync", side_effect=_ensure_sync)
class TestLoginRequired:
"""Test cases for login_required decorator."""
@pytest.fixture
@patch("libs.login.check_csrf_token", mock_csrf_check)
def setup_app(self, app: Flask):
"""Set up Flask app with login manager."""
# Initialize login manager
login_manager = LoginManager()
login_manager.init_app(app)
# Mock unauthorized handler
login_manager.unauthorized = MagicMock(return_value="Unauthorized")
# Add a dummy user loader to prevent exceptions
@login_manager.user_loader
def load_user(user_id):
return None
return app
@patch("libs.login.check_csrf_token", mock_csrf_check)
def test_authenticated_user_can_access_protected_view(self, setup_app: Flask):
def test_authenticated_user_can_access_protected_view(
self, login_app: Flask, csrf_check: MagicMock, ensure_sync_spy: MagicMock, mocker: MockerFixture
):
"""Test that authenticated users can access protected views."""
@login_required
@login_module.login_required
def protected_view():
return "Protected content"
with setup_app.test_request_context():
# Mock authenticated user
mock_user = MockUser("test_user", is_authenticated=True)
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
result = protected_view()
assert result == "Protected content"
mock_user = MockUser("test_user", is_authenticated=True)
get_user = mocker.patch.object(login_module, "_get_user", return_value=mock_user)
@patch("libs.login.check_csrf_token", mock_csrf_check)
def test_unauthenticated_user_cannot_access_protected_view(self, setup_app: Flask):
"""Test that unauthenticated users are redirected."""
with login_app.test_request_context():
result = protected_view()
csrf_check.assert_called_once()
assert csrf_check.call_args.args[0].method == "GET"
assert csrf_check.call_args.args[1] == "test_user"
@login_required
assert result == "Protected content"
get_user.assert_called_once_with()
ensure_sync_spy.assert_called_once_with(protected_view.__wrapped__)
login_app.login_manager.unauthorized.assert_not_called()
@pytest.mark.parametrize(
("resolved_user", "description"),
[
pytest.param(None, "missing user", id="missing-user"),
pytest.param(MockUser("test_user", is_authenticated=False), "unauthenticated user", id="unauthenticated"),
],
)
def test_unauthorized_access_returns_login_manager_response(
self,
login_app: Flask,
csrf_check: MagicMock,
ensure_sync_spy: MagicMock,
mocker: MockerFixture,
resolved_user: MockUser | None,
description: str,
):
"""Test that missing or unauthenticated users are redirected."""
@login_module.login_required
def protected_view():
return "Protected content"
with setup_app.test_request_context():
# Mock unauthenticated user
mock_user = MockUser("test_user", is_authenticated=False)
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
result = protected_view()
assert result == "Unauthorized"
setup_app.login_manager.unauthorized.assert_called_once()
get_user = mocker.patch.object(login_module, "_get_user", return_value=resolved_user)
@patch("libs.login.check_csrf_token", mock_csrf_check)
def test_login_disabled_allows_unauthenticated_access(self, setup_app: Flask):
"""Test that LOGIN_DISABLED config bypasses authentication."""
with login_app.test_request_context():
result = protected_view()
@login_required
assert result == "Unauthorized", description
get_user.assert_called_once_with()
login_app.login_manager.unauthorized.assert_called_once_with()
csrf_check.assert_not_called()
ensure_sync_spy.assert_not_called()
@pytest.mark.parametrize(
("method", "login_disabled"),
[
pytest.param("OPTIONS", False, id="options"),
pytest.param("GET", True, id="login-disabled"),
],
)
def test_bypass_paths_skip_authentication_and_csrf(
self,
login_app: Flask,
csrf_check: MagicMock,
ensure_sync_spy: MagicMock,
monkeypatch: pytest.MonkeyPatch,
mocker: MockerFixture,
method: str,
login_disabled: bool,
):
"""Test that bypass conditions skip auth lookup, CSRF, and unauthorized handling."""
@login_module.login_required
def protected_view():
return "Protected content"
with setup_app.test_request_context():
# Mock unauthenticated user and LOGIN_DISABLED
mock_user = MockUser("test_user", is_authenticated=False)
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
with patch("libs.login.dify_config", autospec=True) as mock_config:
mock_config.LOGIN_DISABLED = True
get_user = mocker.patch.object(login_module, "_get_user")
monkeypatch.setattr(login_module.dify_config, "LOGIN_DISABLED", login_disabled)
result = protected_view()
assert result == "Protected content"
# Ensure unauthorized was not called
setup_app.login_manager.unauthorized.assert_not_called()
with login_app.test_request_context(method=method):
result = protected_view()
@patch("libs.login.check_csrf_token", mock_csrf_check)
def test_options_request_bypasses_authentication(self, setup_app: Flask):
"""Test that OPTIONS requests are exempt from authentication."""
@login_required
def protected_view():
return "Protected content"
with setup_app.test_request_context(method="OPTIONS"):
# Mock unauthenticated user
mock_user = MockUser("test_user", is_authenticated=False)
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
result = protected_view()
assert result == "Protected content"
# Ensure unauthorized was not called
setup_app.login_manager.unauthorized.assert_not_called()
@patch("libs.login.check_csrf_token", mock_csrf_check)
def test_flask_2_compatibility(self, setup_app: Flask):
"""Test Flask 2.x compatibility with ensure_sync."""
@login_required
def protected_view():
return "Protected content"
# Mock Flask 2.x ensure_sync
setup_app.ensure_sync = MagicMock(return_value=lambda: "Synced content")
with setup_app.test_request_context():
mock_user = MockUser("test_user", is_authenticated=True)
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
result = protected_view()
assert result == "Synced content"
setup_app.ensure_sync.assert_called_once()
@patch("libs.login.check_csrf_token", mock_csrf_check)
def test_flask_1_compatibility(self, setup_app: Flask):
"""Test Flask 1.x compatibility without ensure_sync."""
@login_required
def protected_view():
return "Protected content"
# Remove ensure_sync to simulate Flask 1.x
if hasattr(setup_app, "ensure_sync"):
del setup_app.ensure_sync
with setup_app.test_request_context():
mock_user = MockUser("test_user", is_authenticated=True)
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
result = protected_view()
assert result == "Protected content"
assert result == "Protected content"
get_user.assert_not_called()
ensure_sync_spy.assert_called_once_with(protected_view.__wrapped__)
csrf_check.assert_not_called()
login_app.login_manager.unauthorized.assert_not_called()
class TestGetUser:
"""Test cases for _get_user function."""
def test_get_user_returns_user_from_g(self, app: Flask):
def test_get_user_returns_user_from_g(self, login_app: Flask):
"""Test that _get_user returns user from g._login_user."""
mock_user = MockUser("test_user")
with app.test_request_context():
with login_app.test_request_context():
g._login_user = mock_user
user = _get_user()
user = login_module._get_user()
assert user == mock_user
assert user.id == "test_user"
def test_get_user_loads_user_if_not_in_g(self, app: Flask):
def test_get_user_loads_user_if_not_in_g(self, login_app: Flask):
"""Test that _get_user loads user if not already in g."""
mock_user = MockUser("test_user")
# Mock login manager
login_manager = MagicMock()
login_manager._load_user = MagicMock()
app.login_manager = login_manager
def _load_user() -> None:
g._login_user = mock_user
with app.test_request_context():
# Simulate _load_user setting g._login_user
def side_effect():
g._login_user = mock_user
login_app.login_manager._load_user = MagicMock(side_effect=_load_user)
login_manager._load_user.side_effect = side_effect
with login_app.test_request_context():
user = login_module._get_user()
user = _get_user()
assert user == mock_user
login_manager._load_user.assert_called_once()
assert user == mock_user
login_app.login_manager._load_user.assert_called_once_with()
def test_get_user_returns_none_without_request_context(self, app: Flask):
def test_get_user_returns_none_without_request_context(self):
"""Test that _get_user returns None outside request context."""
# Outside of request context
user = _get_user()
user = login_module._get_user()
assert user is None
class TestCurrentUser:
"""Test cases for current_user proxy."""
def test_current_user_proxy_returns_authenticated_user(self, app: Flask):
def test_current_user_proxy_returns_authenticated_user(self, login_app: Flask, mocker: MockerFixture):
"""Test that current_user proxy returns authenticated user."""
mock_user = MockUser("test_user", is_authenticated=True)
mocker.patch.object(login_module, "_get_user", return_value=mock_user)
with app.test_request_context():
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
assert current_user.id == "test_user"
assert current_user.is_authenticated is True
with login_app.test_request_context():
assert current_user.id == "test_user"
assert current_user.is_authenticated is True
def test_current_user_proxy_returns_none_when_no_user(self, app: Flask):
def test_current_user_proxy_raises_attribute_error_when_no_user(self, login_app: Flask, mocker: MockerFixture):
"""Test that current_user proxy handles None user."""
with app.test_request_context():
with patch("libs.login._get_user", return_value=None, autospec=True):
# When _get_user returns None, accessing attributes should fail
# or current_user should evaluate to falsy
try:
# Try to access an attribute that would exist on a real user
_ = current_user.id
pytest.fail("Should have raised AttributeError")
except AttributeError:
# This is expected when current_user is None
pass
mocker.patch.object(login_module, "_get_user", return_value=None)
def test_current_user_proxy_thread_safety(self, app: Flask):
"""Test that current_user proxy is thread-safe."""
import threading
with login_app.test_request_context():
with pytest.raises(AttributeError):
_ = current_user.id
results = {}
def check_user_in_thread(user_id: str, index: int):
with app.test_request_context():
mock_user = MockUser(user_id)
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
results[index] = current_user.id
class TestCurrentAccountWithTenant:
"""Test cases for current_account_with_tenant helper."""
# Create multiple threads with different users
threads = []
for i in range(5):
thread = threading.Thread(target=check_user_in_thread, args=(f"user_{i}", i))
threads.append(thread)
thread.start()
def test_returns_account_and_tenant_id(self, mocker: MockerFixture):
account = Account(name="Test User", email="test@example.com")
account._current_tenant = SimpleNamespace(id="tenant-123")
current_user_proxy = MagicMock()
current_user_proxy._get_current_object.return_value = account
mocker.patch.object(login_module, "current_user", new=current_user_proxy)
# Wait for all threads to complete
for thread in threads:
thread.join()
user, tenant_id = login_module.current_account_with_tenant()
# Verify each thread got its own user
for i in range(5):
assert results[i] == f"user_{i}"
assert user is account
assert tenant_id == "tenant-123"
current_user_proxy._get_current_object.assert_called_once_with()
def test_raises_when_current_user_is_not_account(self, mocker: MockerFixture):
mocker.patch.object(login_module, "current_user", new=MockUser("test_user"))
with pytest.raises(ValueError, match="current_user must be an Account instance"):
login_module.current_account_with_tenant()
def test_raises_when_account_has_no_tenant(self, mocker: MockerFixture):
account = Account(name="Test User", email="test@example.com")
mocker.patch.object(login_module, "current_user", new=account)
with pytest.raises(AssertionError, match="tenant information should be loaded"):
login_module.current_account_with_tenant()

View File

@ -70,8 +70,11 @@ def service(mocker: MockerFixture) -> ModelLoadBalancingService:
# Arrange
provider_manager = MagicMock()
mocker.patch("services.model_load_balancing_service.create_plugin_provider_manager", return_value=provider_manager)
model_assembly = SimpleNamespace(provider_manager=provider_manager, model_provider_factory=MagicMock())
mocker.patch("services.model_load_balancing_service.create_plugin_model_assembly", return_value=model_assembly)
svc = ModelLoadBalancingService()
svc.provider_manager = provider_manager
svc.model_assembly = model_assembly
svc._get_provider_manager = lambda _tenant_id: provider_manager # type: ignore[method-assign]
return svc
@ -667,6 +670,9 @@ def test_validate_load_balancing_credentials_should_delegate_to_custom_validate_
assert mock_validate.call_count == 2
assert mock_validate.call_args_list[0].kwargs["load_balancing_model_config"] is existing_config
assert mock_validate.call_args_list[1].kwargs["load_balancing_model_config"] is None
shared_model_provider_factory = service.model_assembly.model_provider_factory
assert mock_validate.call_args_list[0].kwargs["model_provider_factory"] is shared_model_provider_factory
assert mock_validate.call_args_list[1].kwargs["model_provider_factory"] is shared_model_provider_factory
def test_custom_credentials_validate_should_replace_hidden_secret_with_original_value_and_encrypt(
@ -709,10 +715,6 @@ def test_custom_credentials_validate_should_handle_invalid_original_json_and_val
load_balancing_model_config = _load_balancing_model_config(encrypted_config="not-json")
mock_factory = MagicMock()
mock_factory.model_credentials_validate.return_value = {"api_key": "validated"}
mocker.patch(
"services.model_load_balancing_service.create_plugin_model_provider_factory",
return_value=mock_factory,
)
mock_encrypt = mocker.patch(
"services.model_load_balancing_service.encrypter.encrypt_token",
side_effect=lambda tenant_id, value: f"enc:{value}",
@ -726,6 +728,7 @@ def test_custom_credentials_validate_should_handle_invalid_original_json_and_val
model="gpt-4o-mini",
credentials={"api_key": "plain"},
load_balancing_model_config=load_balancing_model_config,
model_provider_factory=mock_factory,
validate=True,
)
@ -744,10 +747,6 @@ def test_custom_credentials_validate_should_validate_with_provider_schema_when_m
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
mock_factory = MagicMock()
mock_factory.provider_credentials_validate.return_value = {"api_key": "provider-validated"}
mocker.patch(
"services.model_load_balancing_service.create_plugin_model_provider_factory",
return_value=mock_factory,
)
mocker.patch(
"services.model_load_balancing_service.encrypter.encrypt_token",
side_effect=lambda tenant_id, value: f"enc:{value}",
@ -760,6 +759,7 @@ def test_custom_credentials_validate_should_validate_with_provider_schema_when_m
model_type=ModelType.LLM,
model="gpt-4o-mini",
credentials={"api_key": "plain"},
model_provider_factory=mock_factory,
validate=True,
)

View File

@ -25,6 +25,7 @@ from dify_graph.enums import (
)
from dify_graph.errors import WorkflowNodeRunFailedError
from dify_graph.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig
from dify_graph.variables.input_entities import VariableEntityType
@ -1545,7 +1546,10 @@ class TestWorkflowServiceCredentialValidation:
def test_validate_llm_model_config_should_raise_value_error_on_failure(self, service: WorkflowService) -> None:
"""If ModelManager raises any exception it must be wrapped into ValueError."""
# Arrange
with patch("core.model_manager.ModelManager.get_model_instance", side_effect=RuntimeError("no key")):
assembly = MagicMock()
assembly.model_manager.get_model_instance.side_effect = RuntimeError("no key")
with patch("services.workflow_service.create_plugin_model_assembly", return_value=assembly):
# Act + Assert
with pytest.raises(ValueError, match="Failed to validate LLM model configuration"):
service._validate_llm_model_config("tenant-1", "openai", "gpt-4")
@ -1558,30 +1562,30 @@ class TestWorkflowServiceCredentialValidation:
mock_configs = MagicMock()
mock_configs.get_models.return_value = [mock_model]
assembly = MagicMock()
assembly.provider_manager.get_configurations.return_value = mock_configs
with (
patch("core.model_manager.ModelManager.get_model_instance"),
patch("core.provider_manager.ProviderManager") as mock_pm_cls,
):
mock_pm_cls.return_value.get_configurations.return_value = mock_configs
with patch("services.workflow_service.create_plugin_model_assembly", return_value=assembly):
# Act
service._validate_llm_model_config("tenant-1", "openai", "gpt-4")
# Assert
mock_model.raise_for_status.assert_called_once()
assembly.model_manager.get_model_instance.assert_called_once_with(
tenant_id="tenant-1",
provider="openai",
model_type=ModelType.LLM,
model="gpt-4",
)
def test_validate_llm_model_config_model_not_found(self, service: WorkflowService) -> None:
"""Test ValueError when model is not found in provider configurations."""
mock_configs = MagicMock()
mock_configs.get_models.return_value = [] # No models
assembly = MagicMock()
assembly.provider_manager.get_configurations.return_value = mock_configs
with (
patch("core.model_manager.ModelManager.get_model_instance"),
patch("core.provider_manager.ProviderManager") as mock_pm_cls,
):
mock_pm_cls.return_value.get_configurations.return_value = mock_configs
with patch("services.workflow_service.create_plugin_model_assembly", return_value=assembly):
# Act + Assert
with pytest.raises(ValueError, match="Model gpt-4 not found for provider openai"):
service._validate_llm_model_config("tenant-1", "openai", "gpt-4")