mirror of https://github.com/langgenius/dify.git
fix(model): preserve caller-scoped runtime assembly
This commit is contained in:
parent
57f9053b3a
commit
0b8a5e83db
|
|
@ -88,6 +88,7 @@ class ModelConfigResource(Resource):
|
|||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_tenant_id,
|
||||
|
|
@ -127,6 +128,7 @@ class ModelConfigResource(Resource):
|
|||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -140,6 +140,7 @@ class BaseAgentRunner(AppRunner):
|
|||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_config.app_id,
|
||||
agent_tool=tool,
|
||||
user_id=self.user_id,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
)
|
||||
assert tool_entity.entity.description
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from collections.abc import Mapping
|
|||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from models.model import AppModelConfigDict
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
|
@ -53,9 +53,12 @@ class ModelConfigManager:
|
|||
if not isinstance(config["model"], dict):
|
||||
raise ValueError("model must be of object type")
|
||||
|
||||
# Keep provider discovery and provider-backed model listing on the same
|
||||
# request-scoped runtime so caller scope and provider caches stay aligned.
|
||||
assembly = create_plugin_model_assembly(tenant_id=tenant_id)
|
||||
|
||||
# model.provider
|
||||
model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)
|
||||
provider_entities = model_provider_factory.get_providers()
|
||||
provider_entities = assembly.model_provider_factory.get_providers()
|
||||
model_provider_names = [provider.provider for provider in provider_entities]
|
||||
if "provider" not in config["model"]:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
|
@ -70,8 +73,7 @@ class ModelConfigManager:
|
|||
if "name" not in config["model"]:
|
||||
raise ValueError("model.name is required")
|
||||
|
||||
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
|
||||
models = provider_manager.get_configurations(tenant_id).get_models(
|
||||
models = assembly.provider_manager.get_configurations(tenant_id).get_models(
|
||||
provider=config["model"]["provider"], model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -105,8 +105,8 @@ class ProviderConfiguration(BaseModel):
|
|||
"""Attach the already-composed runtime for request-bound call chains."""
|
||||
self._bound_model_runtime = model_runtime
|
||||
|
||||
def _get_model_provider_factory(self) -> ModelProviderFactory:
|
||||
"""Reuse a bound runtime when available, else fall back to tenant scope."""
|
||||
def get_model_provider_factory(self) -> ModelProviderFactory:
|
||||
"""Return a provider factory that preserves any request-bound runtime."""
|
||||
if self._bound_model_runtime is not None:
|
||||
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
|
||||
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
|
||||
|
|
@ -362,7 +362,7 @@ class ProviderConfiguration(BaseModel):
|
|||
tenant_id=self.tenant_id, token=original_credentials[key]
|
||||
)
|
||||
|
||||
model_provider_factory = self._get_model_provider_factory()
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
validated_credentials = model_provider_factory.provider_credentials_validate(
|
||||
provider=self.provider.provider, credentials=credentials
|
||||
)
|
||||
|
|
@ -921,7 +921,7 @@ class ProviderConfiguration(BaseModel):
|
|||
tenant_id=self.tenant_id, token=original_credentials[key]
|
||||
)
|
||||
|
||||
model_provider_factory = self._get_model_provider_factory()
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
validated_credentials = model_provider_factory.model_credentials_validate(
|
||||
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
|
|
@ -1407,7 +1407,7 @@ class ProviderConfiguration(BaseModel):
|
|||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
model_provider_factory = self._get_model_provider_factory()
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
|
||||
# Get model instance of LLM
|
||||
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
|
||||
|
|
@ -1416,7 +1416,7 @@ class ProviderConfiguration(BaseModel):
|
|||
"""
|
||||
Get model schema
|
||||
"""
|
||||
model_provider_factory = self._get_model_provider_factory()
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
return model_provider_factory.get_model_schema(
|
||||
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
|
|
@ -1518,7 +1518,7 @@ class ProviderConfiguration(BaseModel):
|
|||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
model_provider_factory = self._get_model_provider_factory()
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
|
||||
|
||||
model_types: list[ModelType] = []
|
||||
|
|
|
|||
|
|
@ -277,18 +277,22 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def get_system_model_max_tokens(cls, tenant_id: str) -> int:
|
||||
def get_system_model_max_tokens(cls, tenant_id: str, user_id: str | None = None) -> int:
|
||||
"""
|
||||
get system model max tokens
|
||||
"""
|
||||
return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id)
|
||||
return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id, user_id=user_id)
|
||||
|
||||
@classmethod
|
||||
def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
|
||||
def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage], user_id: str | None = None) -> int:
|
||||
"""
|
||||
get prompt tokens
|
||||
"""
|
||||
return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages)
|
||||
return ModelInvocationUtils.calculate_tokens(
|
||||
tenant_id=tenant_id,
|
||||
prompt_messages=prompt_messages,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def invoke_system_model(
|
||||
|
|
@ -306,6 +310,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
|||
tool_type=ToolProviderType.PLUGIN,
|
||||
tool_name="plugin",
|
||||
prompt_messages=prompt_messages,
|
||||
caller_user_id=user_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -313,7 +318,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
|||
"""
|
||||
invoke summary
|
||||
"""
|
||||
max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id)
|
||||
max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id, user_id=user_id)
|
||||
content = payload.text
|
||||
|
||||
SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
|
||||
|
|
@ -332,6 +337,7 @@ Here is the extra instruction you need to follow:
|
|||
cls.get_prompt_tokens(
|
||||
tenant_id=tenant.id,
|
||||
prompt_messages=[UserPromptMessage(content=content)],
|
||||
user_id=user_id,
|
||||
)
|
||||
< max_tokens * 0.6
|
||||
):
|
||||
|
|
@ -344,6 +350,7 @@ Here is the extra instruction you need to follow:
|
|||
SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
|
||||
UserPromptMessage(content=content),
|
||||
],
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
def summarize(content: str) -> str:
|
||||
|
|
@ -401,6 +408,7 @@ Here is the extra instruction you need to follow:
|
|||
cls.get_prompt_tokens(
|
||||
tenant_id=tenant.id,
|
||||
prompt_messages=[UserPromptMessage(content=result)],
|
||||
user_id=user_id,
|
||||
)
|
||||
> max_tokens * 0.7
|
||||
):
|
||||
|
|
|
|||
|
|
@ -31,7 +31,13 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
|
|||
# get tool runtime
|
||||
try:
|
||||
tool_runtime = ToolManager.get_tool_runtime_from_plugin(
|
||||
tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id
|
||||
tool_type,
|
||||
tenant_id,
|
||||
provider,
|
||||
tool_name,
|
||||
tool_parameters,
|
||||
user_id=user_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
response = ToolEngine.generic_invoke(
|
||||
tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1
|
||||
|
|
|
|||
|
|
@ -3,12 +3,64 @@ from __future__ import annotations
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.model_manager import ModelManager
|
||||
from core.plugin.impl.model_runtime import PluginModelRuntime
|
||||
from core.provider_manager import ProviderManager
|
||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
|
||||
|
||||
class PluginModelAssembly:
|
||||
"""Compose request-scoped model views on top of a single plugin runtime."""
|
||||
|
||||
tenant_id: str
|
||||
user_id: str | None
|
||||
_model_runtime: PluginModelRuntime | None
|
||||
_model_provider_factory: ModelProviderFactory | None
|
||||
_provider_manager: ProviderManager | None
|
||||
_model_manager: ModelManager | None
|
||||
|
||||
def __init__(self, *, tenant_id: str, user_id: str | None = None) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
self._model_runtime = None
|
||||
self._model_provider_factory = None
|
||||
self._provider_manager = None
|
||||
self._model_manager = None
|
||||
|
||||
@property
|
||||
def model_runtime(self) -> PluginModelRuntime:
|
||||
if self._model_runtime is None:
|
||||
self._model_runtime = create_plugin_model_runtime(tenant_id=self.tenant_id, user_id=self.user_id)
|
||||
return self._model_runtime
|
||||
|
||||
@property
|
||||
def model_provider_factory(self) -> ModelProviderFactory:
|
||||
if self._model_provider_factory is None:
|
||||
self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime)
|
||||
return self._model_provider_factory
|
||||
|
||||
@property
|
||||
def provider_manager(self) -> ProviderManager:
|
||||
if self._provider_manager is None:
|
||||
from core.provider_manager import ProviderManager
|
||||
|
||||
self._provider_manager = ProviderManager(model_runtime=self.model_runtime)
|
||||
return self._provider_manager
|
||||
|
||||
@property
|
||||
def model_manager(self) -> ModelManager:
|
||||
if self._model_manager is None:
|
||||
from core.model_manager import ModelManager
|
||||
|
||||
self._model_manager = ModelManager(provider_manager=self.provider_manager)
|
||||
return self._model_manager
|
||||
|
||||
|
||||
def create_plugin_model_assembly(*, tenant_id: str, user_id: str | None = None) -> PluginModelAssembly:
|
||||
"""Create a request-scoped assembly that shares one plugin runtime across model views."""
|
||||
return PluginModelAssembly(tenant_id=tenant_id, user_id=user_id)
|
||||
|
||||
|
||||
def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -> PluginModelRuntime:
|
||||
|
|
@ -24,22 +76,14 @@ def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -
|
|||
|
||||
def create_plugin_model_provider_factory(*, tenant_id: str, user_id: str | None = None) -> ModelProviderFactory:
|
||||
"""Create a tenant-bound model provider factory for service flows."""
|
||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
|
||||
return ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id=tenant_id, user_id=user_id))
|
||||
return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_provider_factory
|
||||
|
||||
|
||||
def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None) -> ProviderManager:
|
||||
"""Create a tenant-bound provider manager for service flows."""
|
||||
from core.provider_manager import ProviderManager
|
||||
|
||||
return ProviderManager(model_runtime=create_plugin_model_runtime(tenant_id=tenant_id, user_id=user_id))
|
||||
return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).provider_manager
|
||||
|
||||
|
||||
def create_plugin_model_manager(*, tenant_id: str, user_id: str | None = None) -> ModelManager:
|
||||
"""Create a tenant-bound model manager for service flows."""
|
||||
from core.model_manager import ModelManager
|
||||
|
||||
return ModelManager(
|
||||
provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id),
|
||||
)
|
||||
return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_manager
|
||||
|
|
|
|||
|
|
@ -9,10 +9,14 @@ from core.tools.entities.tool_entities import ToolInvokeFrom
|
|||
|
||||
class ToolRuntime(BaseModel):
|
||||
"""
|
||||
Meta data of a tool call processing
|
||||
Meta data of a tool call processing.
|
||||
|
||||
``user_id`` is optional so read-only tooling flows can stay tenant-scoped,
|
||||
while execution paths may bind caller identity for model runtime lookups.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
user_id: str | None = None
|
||||
tool_id: str | None = None
|
||||
invoke_from: InvokeFrom | None = None
|
||||
tool_invoke_from: ToolInvokeFrom | None = None
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ class BuiltinTool(Tool):
|
|||
tool_type=ToolProviderType.BUILT_IN,
|
||||
tool_name=self.entity.identity.name,
|
||||
prompt_messages=prompt_messages,
|
||||
caller_user_id=self.runtime.user_id,
|
||||
)
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
|
|
@ -69,6 +70,7 @@ class BuiltinTool(Tool):
|
|||
|
||||
return ModelInvocationUtils.get_max_llm_context_tokens(
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
user_id=self.runtime.user_id,
|
||||
)
|
||||
|
||||
def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
|
||||
|
|
@ -82,7 +84,9 @@ class BuiltinTool(Tool):
|
|||
raise ValueError("runtime is required")
|
||||
|
||||
return ModelInvocationUtils.calculate_tokens(
|
||||
tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
prompt_messages=prompt_messages,
|
||||
user_id=self.runtime.user_id,
|
||||
)
|
||||
|
||||
def summary(self, user_id: str, content: str) -> str:
|
||||
|
|
|
|||
|
|
@ -184,6 +184,7 @@ class ToolManager:
|
|||
provider_id: str,
|
||||
tool_name: str,
|
||||
tenant_id: str,
|
||||
user_id: str | None = None,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
|
||||
credential_id: str | None = None,
|
||||
|
|
@ -195,6 +196,7 @@ class ToolManager:
|
|||
:param provider_id: the id of the provider
|
||||
:param tool_name: the name of the tool
|
||||
:param tenant_id: the tenant id
|
||||
:param user_id: the caller id bound to runtime-scoped model/tool lookups
|
||||
:param invoke_from: invoke from
|
||||
:param tool_invoke_from: the tool invoke from
|
||||
:param credential_id: the credential id
|
||||
|
|
@ -213,6 +215,7 @@ class ToolManager:
|
|||
return builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
|
|
@ -321,6 +324,7 @@ class ToolManager:
|
|||
return builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials=dict(decrypted_credentials),
|
||||
credential_type=CredentialType.of(builtin_provider.credential_type),
|
||||
runtime_parameters={},
|
||||
|
|
@ -338,6 +342,7 @@ class ToolManager:
|
|||
return api_provider.get_tool(tool_name).fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials=dict(encrypter.decrypt(credentials)),
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
|
|
@ -361,6 +366,7 @@ class ToolManager:
|
|||
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
credentials={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
|
|
@ -369,9 +375,21 @@ class ToolManager:
|
|||
elif provider_type == ToolProviderType.APP:
|
||||
raise NotImplementedError("app provider not implemented")
|
||||
elif provider_type == ToolProviderType.PLUGIN:
|
||||
return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
|
||||
plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
|
||||
runtime = getattr(plugin_tool, "runtime", None)
|
||||
if runtime is not None:
|
||||
runtime.user_id = user_id
|
||||
runtime.invoke_from = invoke_from
|
||||
runtime.tool_invoke_from = tool_invoke_from
|
||||
return plugin_tool
|
||||
elif provider_type == ToolProviderType.MCP:
|
||||
return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
|
||||
mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
|
||||
runtime = getattr(mcp_tool, "runtime", None)
|
||||
if runtime is not None:
|
||||
runtime.user_id = user_id
|
||||
runtime.invoke_from = invoke_from
|
||||
runtime.tool_invoke_from = tool_invoke_from
|
||||
return mcp_tool
|
||||
else:
|
||||
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
|
||||
|
||||
|
|
@ -381,6 +399,7 @@ class ToolManager:
|
|||
tenant_id: str,
|
||||
app_id: str,
|
||||
agent_tool: AgentToolEntity,
|
||||
user_id: str | None = None,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
variable_pool: Optional["VariablePool"] = None,
|
||||
) -> Tool:
|
||||
|
|
@ -392,6 +411,7 @@ class ToolManager:
|
|||
provider_id=agent_tool.provider_id,
|
||||
tool_name=agent_tool.tool_name,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.AGENT,
|
||||
credential_id=agent_tool.credential_id,
|
||||
|
|
@ -423,6 +443,7 @@ class ToolManager:
|
|||
app_id: str,
|
||||
node_id: str,
|
||||
workflow_tool: WorkflowToolRuntimeSpec,
|
||||
user_id: str | None = None,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
variable_pool: Optional["VariablePool"] = None,
|
||||
) -> Tool:
|
||||
|
|
@ -435,6 +456,7 @@ class ToolManager:
|
|||
provider_id=workflow_tool.provider_id,
|
||||
tool_name=workflow_tool.tool_name,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
credential_id=workflow_tool.credential_id,
|
||||
|
|
@ -467,6 +489,7 @@ class ToolManager:
|
|||
provider: str,
|
||||
tool_name: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
user_id: str | None = None,
|
||||
credential_id: str | None = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
|
|
@ -477,6 +500,7 @@ class ToolManager:
|
|||
provider_id=provider,
|
||||
tool_name=tool_name,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
tool_invoke_from=ToolInvokeFrom.PLUGIN,
|
||||
credential_id=credential_id,
|
||||
|
|
|
|||
|
|
@ -34,11 +34,12 @@ class ModelInvocationUtils:
|
|||
@staticmethod
|
||||
def get_max_llm_context_tokens(
|
||||
tenant_id: str,
|
||||
user_id: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
get max llm context tokens of the model
|
||||
"""
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
|
|
@ -60,13 +61,13 @@ class ModelInvocationUtils:
|
|||
return max_tokens
|
||||
|
||||
@staticmethod
|
||||
def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
|
||||
def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage], user_id: str | None = None) -> int:
|
||||
"""
|
||||
calculate tokens from prompt messages and model parameters
|
||||
"""
|
||||
|
||||
# get model instance
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
|
||||
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM)
|
||||
|
||||
if not model_instance:
|
||||
|
|
@ -79,7 +80,12 @@ class ModelInvocationUtils:
|
|||
|
||||
@staticmethod
|
||||
def invoke(
|
||||
user_id: str, tenant_id: str, tool_type: ToolProviderType, tool_name: str, prompt_messages: list[PromptMessage]
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
tool_type: ToolProviderType,
|
||||
tool_name: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
caller_user_id: str | None = None,
|
||||
) -> LLMResult:
|
||||
"""
|
||||
invoke model with parameters in user's own context
|
||||
|
|
@ -93,7 +99,7 @@ class ModelInvocationUtils:
|
|||
"""
|
||||
|
||||
# get model manager
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=caller_user_id or user_id)
|
||||
# get model instance
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
|
|
|
|||
|
|
@ -337,6 +337,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
|
|||
self._run_context.app_id,
|
||||
node_id,
|
||||
self._build_tool_runtime_spec(node_data),
|
||||
self._run_context.user_id,
|
||||
self._run_context.invoke_from,
|
||||
variable_pool,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -12,9 +12,9 @@ from sqlalchemy.orm import Session
|
|||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_manager import ModelInstance
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
|
||||
from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
|
|
@ -141,6 +141,7 @@ class AgentRuntimeSupport:
|
|||
tenant_id,
|
||||
app_id,
|
||||
entity,
|
||||
user_id,
|
||||
invoke_from,
|
||||
runtime_variable_pool,
|
||||
)
|
||||
|
|
@ -242,8 +243,8 @@ class AgentRuntimeSupport:
|
|||
user_id: str,
|
||||
value: dict[str, Any],
|
||||
) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id)
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
assembly = create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id)
|
||||
provider_model_bundle = assembly.provider_manager.get_provider_model_bundle(
|
||||
tenant_id=tenant_id,
|
||||
provider=value.get("provider", ""),
|
||||
model_type=ModelType.LLM,
|
||||
|
|
@ -255,7 +256,7 @@ class AgentRuntimeSupport:
|
|||
)
|
||||
provider_name = provider_model_bundle.configuration.provider.provider
|
||||
model_type_instance = provider_model_bundle.model_type_instance
|
||||
model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance(
|
||||
model_instance = assembly.model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_name,
|
||||
model_type=ModelType(value.get("model_type", "")),
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
|
||||
|
||||
from flask import current_app, g, has_request_context, request
|
||||
from flask_login.config import EXEMPT_METHODS
|
||||
|
|
@ -34,8 +34,6 @@ def current_account_with_tenant():
|
|||
return user, user.current_tenant_id
|
||||
|
||||
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
|
|
|||
|
|
@ -205,6 +205,7 @@ class AppService:
|
|||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
|
|
|
|||
|
|
@ -10,13 +10,14 @@ from core.entities.provider_configuration import ProviderConfiguration
|
|||
from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
from core.model_manager import LBModelManager
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager
|
||||
from core.provider_manager import ProviderManager
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.entities.provider_entities import (
|
||||
ModelCredentialSchema,
|
||||
ProviderCredentialSchema,
|
||||
)
|
||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CredentialSourceType
|
||||
|
|
@ -223,8 +224,8 @@ class ModelLoadBalancingService:
|
|||
:param config_id: load balancing config id
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id)
|
||||
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
|
||||
provider_configurations = provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
|
|
@ -496,8 +497,8 @@ class ModelLoadBalancingService:
|
|||
:param config_id: load balancing config id
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id)
|
||||
assembly = create_plugin_model_assembly(tenant_id=tenant_id)
|
||||
provider_configurations = assembly.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
|
|
@ -533,6 +534,7 @@ class ModelLoadBalancingService:
|
|||
model=model,
|
||||
credentials=credentials,
|
||||
load_balancing_model_config=load_balancing_model_config,
|
||||
model_provider_factory=assembly.model_provider_factory,
|
||||
)
|
||||
|
||||
def _custom_credentials_validate(
|
||||
|
|
@ -543,6 +545,7 @@ class ModelLoadBalancingService:
|
|||
model: str,
|
||||
credentials: dict,
|
||||
load_balancing_model_config: LoadBalancingModelConfig | None = None,
|
||||
model_provider_factory: ModelProviderFactory | None = None,
|
||||
validate: bool = True,
|
||||
):
|
||||
"""
|
||||
|
|
@ -553,6 +556,7 @@ class ModelLoadBalancingService:
|
|||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param load_balancing_model_config: load balancing model config
|
||||
:param model_provider_factory: model provider factory sharing the active runtime
|
||||
:param validate: validate credentials
|
||||
:return:
|
||||
"""
|
||||
|
|
@ -582,7 +586,8 @@ class ModelLoadBalancingService:
|
|||
credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key])
|
||||
|
||||
if validate:
|
||||
model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)
|
||||
if model_provider_factory is None:
|
||||
model_provider_factory = provider_configuration.get_model_provider_factory()
|
||||
if isinstance(credential_schemas, ModelCredentialSchema):
|
||||
credentials = model_provider_factory.model_credentials_validate(
|
||||
provider=provider_configuration.provider.provider,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
|||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
|
||||
from core.trigger.constants import is_trigger_node_type
|
||||
|
|
@ -491,12 +491,15 @@ class WorkflowService:
|
|||
:raises ValueError: If the model configuration is invalid or credentials fail policy checks
|
||||
"""
|
||||
try:
|
||||
from core.model_manager import ModelManager
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
# Model instance resolution and provider status lookup must reuse the
|
||||
# same request-scoped runtime so validation does not silently split
|
||||
# provider discovery and credential reads across different caches.
|
||||
assembly = create_plugin_model_assembly(tenant_id=tenant_id)
|
||||
|
||||
# Get model instance to validate provider+model combination
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
model_manager.get_model_instance(
|
||||
assembly.model_manager.get_model_instance(
|
||||
tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name
|
||||
)
|
||||
|
||||
|
|
@ -505,8 +508,7 @@ class WorkflowService:
|
|||
# If it fails, an exception will be raised
|
||||
|
||||
# Additionally, check the model status to ensure it's ACTIVE
|
||||
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
|
||||
provider_configurations = provider_manager.get_configurations(tenant_id)
|
||||
provider_configurations = assembly.provider_manager.get_configurations(tenant_id)
|
||||
models = provider_configurations.get_models(provider=provider, model_type=ModelType.LLM)
|
||||
|
||||
target_model = None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,57 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager
|
||||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
|
||||
def test_validate_and_set_defaults_reuses_single_model_assembly():
|
||||
provider_name = str(ModelProviderID("openai"))
|
||||
provider_entity = SimpleNamespace(provider=provider_name)
|
||||
model = SimpleNamespace(model="gpt-4o-mini", model_properties={ModelPropertyKey.MODE: "chat"})
|
||||
provider_configurations = SimpleNamespace(get_models=lambda **kwargs: [model])
|
||||
assembly = SimpleNamespace(
|
||||
model_provider_factory=SimpleNamespace(get_providers=lambda: [provider_entity]),
|
||||
provider_manager=SimpleNamespace(get_configurations=lambda tenant_id: provider_configurations),
|
||||
)
|
||||
config = {
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4o-mini",
|
||||
"completion_params": {"stop": []},
|
||||
}
|
||||
}
|
||||
|
||||
with patch(
|
||||
"core.app.app_config.easy_ui_based_app.model_config.manager.create_plugin_model_assembly",
|
||||
return_value=assembly,
|
||||
) as mock_assembly:
|
||||
result, keys = ModelConfigManager.validate_and_set_defaults("tenant-1", config)
|
||||
|
||||
assert result["model"]["provider"] == provider_name
|
||||
assert result["model"]["mode"] == "chat"
|
||||
assert keys == ["model"]
|
||||
mock_assembly.assert_called_once_with(tenant_id="tenant-1")
|
||||
|
||||
|
||||
def test_convert_keeps_model_config_shape():
|
||||
config = {
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4o-mini",
|
||||
"mode": "chat",
|
||||
"completion_params": {"temperature": 0.3, "stop": ["END"]},
|
||||
}
|
||||
}
|
||||
|
||||
result = ModelConfigManager.convert(config)
|
||||
|
||||
assert result == ModelConfigEntity(
|
||||
provider="openai",
|
||||
model="gpt-4o-mini",
|
||||
mode="chat",
|
||||
parameters={"temperature": 0.3},
|
||||
stop=["END"],
|
||||
)
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
|
||||
|
||||
|
||||
def test_plugin_model_assembly_reuses_single_runtime_across_views():
|
||||
runtime = Mock(name="runtime")
|
||||
provider_factory = Mock(name="provider_factory")
|
||||
provider_manager = Mock(name="provider_manager")
|
||||
model_manager = Mock(name="model_manager")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.plugin.impl.model_runtime_factory.create_plugin_model_runtime",
|
||||
return_value=runtime,
|
||||
) as mock_runtime_factory,
|
||||
patch(
|
||||
"core.plugin.impl.model_runtime_factory.ModelProviderFactory",
|
||||
return_value=provider_factory,
|
||||
) as mock_provider_factory_cls,
|
||||
patch("core.provider_manager.ProviderManager", return_value=provider_manager) as mock_provider_manager_cls,
|
||||
patch("core.model_manager.ModelManager", return_value=model_manager) as mock_model_manager_cls,
|
||||
):
|
||||
assembly = create_plugin_model_assembly(tenant_id="tenant-1", user_id="user-1")
|
||||
|
||||
assert assembly.model_provider_factory is provider_factory
|
||||
assert assembly.provider_manager is provider_manager
|
||||
assert assembly.model_manager is model_manager
|
||||
assert assembly.model_provider_factory is provider_factory
|
||||
assert assembly.provider_manager is provider_manager
|
||||
assert assembly.model_manager is model_manager
|
||||
|
||||
mock_runtime_factory.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
|
||||
mock_provider_factory_cls.assert_called_once_with(model_runtime=runtime)
|
||||
mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime)
|
||||
mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager)
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
|
||||
from core.plugin.entities.request import RequestInvokeSummary
|
||||
from dify_graph.model_runtime.entities.message_entities import UserPromptMessage
|
||||
|
||||
|
||||
def test_system_model_helpers_forward_user_id():
|
||||
with (
|
||||
patch(
|
||||
"core.plugin.backwards_invocation.model.ModelInvocationUtils.get_max_llm_context_tokens",
|
||||
return_value=4096,
|
||||
) as mock_max_tokens,
|
||||
patch(
|
||||
"core.plugin.backwards_invocation.model.ModelInvocationUtils.calculate_tokens",
|
||||
return_value=7,
|
||||
) as mock_prompt_tokens,
|
||||
):
|
||||
assert PluginModelBackwardsInvocation.get_system_model_max_tokens("tenant-1", user_id="user-1") == 4096
|
||||
assert (
|
||||
PluginModelBackwardsInvocation.get_prompt_tokens(
|
||||
"tenant-1",
|
||||
[UserPromptMessage(content="hello")],
|
||||
user_id="user-1",
|
||||
)
|
||||
== 7
|
||||
)
|
||||
|
||||
mock_max_tokens.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
|
||||
mock_prompt_tokens.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
prompt_messages=[UserPromptMessage(content="hello")],
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_summary_uses_same_user_scope_for_token_helpers():
|
||||
tenant = SimpleNamespace(id="tenant-1")
|
||||
payload = RequestInvokeSummary(text="short", instruction="keep it concise")
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
PluginModelBackwardsInvocation,
|
||||
"get_system_model_max_tokens",
|
||||
return_value=100,
|
||||
) as mock_max_tokens,
|
||||
patch.object(
|
||||
PluginModelBackwardsInvocation,
|
||||
"get_prompt_tokens",
|
||||
return_value=10,
|
||||
) as mock_prompt_tokens,
|
||||
):
|
||||
assert PluginModelBackwardsInvocation.invoke_summary("user-1", tenant, payload) == "short"
|
||||
|
||||
mock_max_tokens.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
|
||||
mock_prompt_tokens.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
prompt_messages=[UserPromptMessage(content="short")],
|
||||
user_id="user-1",
|
||||
)
|
||||
|
|
@ -27,12 +27,12 @@ class _BuiltinDummyTool(BuiltinTool):
|
|||
yield self.create_text_message("ok")
|
||||
|
||||
|
||||
def _build_tool() -> _BuiltinDummyTool:
|
||||
def _build_tool(user_id: str | None = None) -> _BuiltinDummyTool:
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="author", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"),
|
||||
parameters=[],
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER)
|
||||
runtime = ToolRuntime(tenant_id="tenant-1", user_id=user_id, invoke_from=InvokeFrom.DEBUGGER)
|
||||
return _BuiltinDummyTool(provider="provider-a", entity=entity, runtime=runtime)
|
||||
|
||||
|
||||
|
|
@ -45,7 +45,7 @@ def test_builtin_tool_fork_and_provider_type():
|
|||
|
||||
|
||||
def test_invoke_model_calls_model_invocation_utils_invoke():
|
||||
tool = _build_tool()
|
||||
tool = _build_tool(user_id="runtime-user")
|
||||
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.invoke", return_value="result") as mock_invoke:
|
||||
assert (
|
||||
tool.invoke_model(
|
||||
|
|
@ -55,19 +55,47 @@ def test_invoke_model_calls_model_invocation_utils_invoke():
|
|||
)
|
||||
== "result"
|
||||
)
|
||||
mock_invoke.assert_called_once()
|
||||
mock_invoke.assert_called_once_with(
|
||||
user_id="u1",
|
||||
tenant_id="tenant-1",
|
||||
tool_type=ToolProviderType.BUILT_IN,
|
||||
tool_name="tool-a",
|
||||
prompt_messages=[UserPromptMessage(content="hello")],
|
||||
caller_user_id="runtime-user",
|
||||
)
|
||||
|
||||
|
||||
def test_get_max_tokens_returns_value():
|
||||
tool = _build_tool()
|
||||
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096):
|
||||
tool = _build_tool(user_id="runtime-user")
|
||||
with patch(
|
||||
"core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096
|
||||
) as mock_get:
|
||||
assert tool.get_max_tokens() == 4096
|
||||
mock_get.assert_called_once_with(tenant_id="tenant-1", user_id="runtime-user")
|
||||
|
||||
|
||||
def test_get_prompt_tokens_returns_value():
|
||||
tool = _build_tool()
|
||||
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7):
|
||||
tool = _build_tool(user_id="runtime-user")
|
||||
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7) as mock_calculate:
|
||||
assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7
|
||||
mock_calculate.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
prompt_messages=[UserPromptMessage(content="hello")],
|
||||
user_id="runtime-user",
|
||||
)
|
||||
|
||||
|
||||
def test_get_prompt_tokens_falls_back_to_tenant_scope_when_runtime_user_id_missing():
|
||||
tool = _build_tool()
|
||||
|
||||
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7) as mock_calculate:
|
||||
assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7
|
||||
|
||||
mock_calculate.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
prompt_messages=[UserPromptMessage(content="hello")],
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
|
||||
def test_runtime_none_raises():
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import calendar
|
||||
import math
|
||||
from datetime import date
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
|
@ -98,7 +100,13 @@ def test_timezone_conversion_tool():
|
|||
def test_weekday_tool():
|
||||
weekday_tool = _build_builtin_tool(WeekdayTool)
|
||||
valid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 1, "day": 1}))[0].message.text
|
||||
assert "January 1, 2024" in valid
|
||||
expected_date = date(2024, 1, 1)
|
||||
expected_message = (
|
||||
f"{calendar.month_name[expected_date.month]} "
|
||||
f"{expected_date.day}, {expected_date.year} "
|
||||
f"is {calendar.day_name[expected_date.weekday()]}."
|
||||
)
|
||||
assert valid == expected_message
|
||||
invalid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 2, "day": 31}))[
|
||||
0
|
||||
].message.text
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from core.plugin.entities.plugin_daemon import CredentialType
|
|||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
|
|
@ -421,7 +422,7 @@ def test_get_agent_runtime_apply_runtime_parameters():
|
|||
tool_runtime = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={}))
|
||||
tool_runtime.get_merged_runtime_parameters = Mock(return_value=[parameter])
|
||||
|
||||
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime):
|
||||
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime) as mock_get_tool_runtime:
|
||||
with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "hello"}):
|
||||
manager = Mock()
|
||||
manager.decrypt_tool_parameters.return_value = {"query": "decrypted"}
|
||||
|
|
@ -437,12 +438,23 @@ def test_get_agent_runtime_apply_runtime_parameters():
|
|||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
agent_tool=agent_tool,
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
variable_pool=None,
|
||||
)
|
||||
|
||||
assert result is tool_runtime
|
||||
assert tool_runtime.runtime.runtime_parameters["query"] == "decrypted"
|
||||
mock_get_tool_runtime.assert_called_once_with(
|
||||
provider_type=ToolProviderType.API,
|
||||
provider_id="api-1",
|
||||
tool_name="search",
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from=ToolInvokeFrom.AGENT,
|
||||
credential_id=None,
|
||||
)
|
||||
|
||||
|
||||
def test_get_workflow_runtime_apply_runtime_parameters():
|
||||
|
|
@ -463,7 +475,7 @@ def test_get_workflow_runtime_apply_runtime_parameters():
|
|||
)
|
||||
tool_runtime2 = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={}))
|
||||
tool_runtime2.get_merged_runtime_parameters = Mock(return_value=[parameter])
|
||||
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2):
|
||||
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2) as mock_get_tool_runtime:
|
||||
with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "workflow"}):
|
||||
manager = Mock()
|
||||
manager.decrypt_tool_parameters.return_value = {"query": "workflow-dec"}
|
||||
|
|
@ -473,12 +485,23 @@ def test_get_workflow_runtime_apply_runtime_parameters():
|
|||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
workflow_tool=workflow_tool,
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
variable_pool=None,
|
||||
)
|
||||
|
||||
assert workflow_result is tool_runtime2
|
||||
assert tool_runtime2.runtime.runtime_parameters["query"] == "workflow-dec"
|
||||
mock_get_tool_runtime.assert_called_once_with(
|
||||
provider_type=ToolProviderType.API,
|
||||
provider_id="api-1",
|
||||
tool_name="search",
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
credential_id=None,
|
||||
)
|
||||
|
||||
|
||||
def test_get_agent_runtime_raises_when_runtime_missing():
|
||||
|
|
@ -520,17 +543,28 @@ def test_get_tool_runtime_from_plugin_only_uses_form_parameters():
|
|||
tool_entity = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={}))
|
||||
tool_entity.get_merged_runtime_parameters = Mock(return_value=[form_param, llm_param])
|
||||
|
||||
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity):
|
||||
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity) as mock_get_tool_runtime:
|
||||
result = ToolManager.get_tool_runtime_from_plugin(
|
||||
tool_type=ToolProviderType.API,
|
||||
tenant_id="tenant-1",
|
||||
provider="api-1",
|
||||
tool_name="search",
|
||||
tool_parameters={"q": "hello", "llm": "ignore"},
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert result is tool_entity
|
||||
assert tool_entity.runtime.runtime_parameters == {"q": "hello"}
|
||||
mock_get_tool_runtime.assert_called_once_with(
|
||||
provider_type=ToolProviderType.API,
|
||||
provider_id="api-1",
|
||||
tool_name="search",
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
tool_invoke_from=ToolInvokeFrom.PLUGIN,
|
||||
credential_id=None,
|
||||
)
|
||||
|
||||
|
||||
def test_hardcoded_provider_icon_success():
|
||||
|
|
|
|||
|
|
@ -60,20 +60,23 @@ def test_get_max_llm_context_tokens_branches(model_instance, expected, error_mat
|
|||
manager = Mock()
|
||||
manager.get_default_model_instance.return_value = model_instance
|
||||
|
||||
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager):
|
||||
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory:
|
||||
if error_match:
|
||||
with pytest.raises(InvokeModelError, match=error_match):
|
||||
ModelInvocationUtils.get_max_llm_context_tokens("tenant")
|
||||
ModelInvocationUtils.get_max_llm_context_tokens("tenant", user_id="user-1")
|
||||
else:
|
||||
assert ModelInvocationUtils.get_max_llm_context_tokens("tenant") == expected
|
||||
assert ModelInvocationUtils.get_max_llm_context_tokens("tenant", user_id="user-1") == expected
|
||||
|
||||
mock_factory.assert_called_once_with(tenant_id="tenant", user_id="user-1")
|
||||
|
||||
|
||||
def test_calculate_tokens_handles_missing_model():
|
||||
manager = Mock()
|
||||
manager.get_default_model_instance.return_value = None
|
||||
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager):
|
||||
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory:
|
||||
with pytest.raises(InvokeModelError, match="Model not found"):
|
||||
ModelInvocationUtils.calculate_tokens("tenant", [])
|
||||
mock_factory.assert_called_once_with(tenant_id="tenant", user_id=None)
|
||||
|
||||
|
||||
def test_invoke_success_and_error_mappings():
|
||||
|
|
@ -98,7 +101,7 @@ def test_invoke_success_and_error_mappings():
|
|||
|
||||
db_mock = SimpleNamespace(session=Mock())
|
||||
|
||||
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager):
|
||||
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory:
|
||||
with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke):
|
||||
with patch("core.tools.utils.model_invocation_utils.db", db_mock):
|
||||
response = ModelInvocationUtils.invoke(
|
||||
|
|
@ -107,11 +110,13 @@ def test_invoke_success_and_error_mappings():
|
|||
tool_type="builtin",
|
||||
tool_name="tool-a",
|
||||
prompt_messages=[],
|
||||
caller_user_id="caller-1",
|
||||
)
|
||||
|
||||
assert response.message.content == "ok"
|
||||
assert db_mock.session.add.call_count == 1
|
||||
assert db_mock.session.commit.call_count == 2
|
||||
mock_factory.assert_called_once_with(tenant_id="tenant", user_id="caller-1")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
@ -145,7 +150,7 @@ def test_invoke_error_mappings(exc, expected):
|
|||
|
||||
db_mock = SimpleNamespace(session=Mock())
|
||||
|
||||
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager):
|
||||
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory:
|
||||
with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke):
|
||||
with patch("core.tools.utils.model_invocation_utils.db", db_mock):
|
||||
with pytest.raises(InvokeModelError, match=expected):
|
||||
|
|
@ -156,3 +161,4 @@ def test_invoke_error_mappings(exc, expected):
|
|||
tool_name="tool-a",
|
||||
prompt_messages=[],
|
||||
)
|
||||
mock_factory.assert_called_once_with(tenant_id="tenant", user_id="u1")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,49 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
|
||||
def test_fetch_model_reuses_single_model_assembly():
|
||||
provider_configuration = SimpleNamespace(
|
||||
get_current_credentials=Mock(return_value={"api_key": "x"}),
|
||||
provider=SimpleNamespace(provider="openai"),
|
||||
)
|
||||
model_type_instance = SimpleNamespace(get_model_schema=Mock(return_value="schema"))
|
||||
provider_model_bundle = SimpleNamespace(
|
||||
configuration=provider_configuration,
|
||||
model_type_instance=model_type_instance,
|
||||
)
|
||||
model_instance = Mock()
|
||||
assembly = SimpleNamespace(
|
||||
provider_manager=Mock(),
|
||||
model_manager=Mock(),
|
||||
)
|
||||
assembly.provider_manager.get_provider_model_bundle.return_value = provider_model_bundle
|
||||
assembly.model_manager.get_model_instance.return_value = model_instance
|
||||
|
||||
with patch(
|
||||
"core.workflow.nodes.agent.runtime_support.create_plugin_model_assembly",
|
||||
return_value=assembly,
|
||||
) as mock_assembly:
|
||||
resolved_instance, resolved_schema = AgentRuntimeSupport().fetch_model(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
value={"provider": "openai", "model": "gpt-4o-mini", "model_type": "llm"},
|
||||
)
|
||||
|
||||
assert resolved_instance is model_instance
|
||||
assert resolved_schema == "schema"
|
||||
mock_assembly.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
|
||||
assembly.provider_manager.get_provider_model_bundle.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
provider="openai",
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
assembly.model_manager.get_model_instance.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
provider="openai",
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o-mini",
|
||||
)
|
||||
|
|
@ -1,10 +1,14 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
from flask_login import LoginManager, UserMixin
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from libs.login import _get_user, current_user, login_required
|
||||
import libs.login as login_module
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
|
||||
|
||||
class MockUser(UserMixin):
|
||||
|
|
@ -15,229 +19,214 @@ class MockUser(UserMixin):
|
|||
self._is_authenticated = is_authenticated
|
||||
|
||||
@property
|
||||
def is_authenticated(self):
|
||||
def is_authenticated(self) -> bool:
|
||||
return self._is_authenticated
|
||||
|
||||
|
||||
def mock_csrf_check(*args, **kwargs):
|
||||
return
|
||||
@pytest.fixture
|
||||
def login_app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
|
||||
login_manager = LoginManager()
|
||||
login_manager.init_app(app)
|
||||
login_manager.unauthorized = MagicMock(return_value="Unauthorized")
|
||||
|
||||
@login_manager.user_loader
|
||||
def load_user(_user_id: str):
|
||||
return None
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def csrf_check(mocker: MockerFixture) -> MagicMock:
|
||||
return mocker.patch.object(login_module, "check_csrf_token")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ensure_sync_spy(login_app: Flask, mocker: MockerFixture) -> MagicMock:
|
||||
def _ensure_sync(func):
|
||||
return lambda *args, **kwargs: func(*args, **kwargs)
|
||||
|
||||
return mocker.patch.object(login_app, "ensure_sync", side_effect=_ensure_sync)
|
||||
|
||||
|
||||
class TestLoginRequired:
|
||||
"""Test cases for login_required decorator."""
|
||||
|
||||
@pytest.fixture
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def setup_app(self, app: Flask):
|
||||
"""Set up Flask app with login manager."""
|
||||
# Initialize login manager
|
||||
login_manager = LoginManager()
|
||||
login_manager.init_app(app)
|
||||
|
||||
# Mock unauthorized handler
|
||||
login_manager.unauthorized = MagicMock(return_value="Unauthorized")
|
||||
|
||||
# Add a dummy user loader to prevent exceptions
|
||||
@login_manager.user_loader
|
||||
def load_user(user_id):
|
||||
return None
|
||||
|
||||
return app
|
||||
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def test_authenticated_user_can_access_protected_view(self, setup_app: Flask):
|
||||
def test_authenticated_user_can_access_protected_view(
|
||||
self, login_app: Flask, csrf_check: MagicMock, ensure_sync_spy: MagicMock, mocker: MockerFixture
|
||||
):
|
||||
"""Test that authenticated users can access protected views."""
|
||||
|
||||
@login_required
|
||||
@login_module.login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
with setup_app.test_request_context():
|
||||
# Mock authenticated user
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
get_user = mocker.patch.object(login_module, "_get_user", return_value=mock_user)
|
||||
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def test_unauthenticated_user_cannot_access_protected_view(self, setup_app: Flask):
|
||||
"""Test that unauthenticated users are redirected."""
|
||||
with login_app.test_request_context():
|
||||
result = protected_view()
|
||||
csrf_check.assert_called_once()
|
||||
assert csrf_check.call_args.args[0].method == "GET"
|
||||
assert csrf_check.call_args.args[1] == "test_user"
|
||||
|
||||
@login_required
|
||||
assert result == "Protected content"
|
||||
get_user.assert_called_once_with()
|
||||
ensure_sync_spy.assert_called_once_with(protected_view.__wrapped__)
|
||||
login_app.login_manager.unauthorized.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("resolved_user", "description"),
|
||||
[
|
||||
pytest.param(None, "missing user", id="missing-user"),
|
||||
pytest.param(MockUser("test_user", is_authenticated=False), "unauthenticated user", id="unauthenticated"),
|
||||
],
|
||||
)
|
||||
def test_unauthorized_access_returns_login_manager_response(
|
||||
self,
|
||||
login_app: Flask,
|
||||
csrf_check: MagicMock,
|
||||
ensure_sync_spy: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
resolved_user: MockUser | None,
|
||||
description: str,
|
||||
):
|
||||
"""Test that missing or unauthenticated users are redirected."""
|
||||
|
||||
@login_module.login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
with setup_app.test_request_context():
|
||||
# Mock unauthenticated user
|
||||
mock_user = MockUser("test_user", is_authenticated=False)
|
||||
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
|
||||
result = protected_view()
|
||||
assert result == "Unauthorized"
|
||||
setup_app.login_manager.unauthorized.assert_called_once()
|
||||
get_user = mocker.patch.object(login_module, "_get_user", return_value=resolved_user)
|
||||
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def test_login_disabled_allows_unauthenticated_access(self, setup_app: Flask):
|
||||
"""Test that LOGIN_DISABLED config bypasses authentication."""
|
||||
with login_app.test_request_context():
|
||||
result = protected_view()
|
||||
|
||||
@login_required
|
||||
assert result == "Unauthorized", description
|
||||
get_user.assert_called_once_with()
|
||||
login_app.login_manager.unauthorized.assert_called_once_with()
|
||||
csrf_check.assert_not_called()
|
||||
ensure_sync_spy.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("method", "login_disabled"),
|
||||
[
|
||||
pytest.param("OPTIONS", False, id="options"),
|
||||
pytest.param("GET", True, id="login-disabled"),
|
||||
],
|
||||
)
|
||||
def test_bypass_paths_skip_authentication_and_csrf(
|
||||
self,
|
||||
login_app: Flask,
|
||||
csrf_check: MagicMock,
|
||||
ensure_sync_spy: MagicMock,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mocker: MockerFixture,
|
||||
method: str,
|
||||
login_disabled: bool,
|
||||
):
|
||||
"""Test that bypass conditions skip auth lookup, CSRF, and unauthorized handling."""
|
||||
|
||||
@login_module.login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
with setup_app.test_request_context():
|
||||
# Mock unauthenticated user and LOGIN_DISABLED
|
||||
mock_user = MockUser("test_user", is_authenticated=False)
|
||||
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
|
||||
with patch("libs.login.dify_config", autospec=True) as mock_config:
|
||||
mock_config.LOGIN_DISABLED = True
|
||||
get_user = mocker.patch.object(login_module, "_get_user")
|
||||
monkeypatch.setattr(login_module.dify_config, "LOGIN_DISABLED", login_disabled)
|
||||
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
# Ensure unauthorized was not called
|
||||
setup_app.login_manager.unauthorized.assert_not_called()
|
||||
with login_app.test_request_context(method=method):
|
||||
result = protected_view()
|
||||
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def test_options_request_bypasses_authentication(self, setup_app: Flask):
|
||||
"""Test that OPTIONS requests are exempt from authentication."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
with setup_app.test_request_context(method="OPTIONS"):
|
||||
# Mock unauthenticated user
|
||||
mock_user = MockUser("test_user", is_authenticated=False)
|
||||
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
# Ensure unauthorized was not called
|
||||
setup_app.login_manager.unauthorized.assert_not_called()
|
||||
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def test_flask_2_compatibility(self, setup_app: Flask):
|
||||
"""Test Flask 2.x compatibility with ensure_sync."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
# Mock Flask 2.x ensure_sync
|
||||
setup_app.ensure_sync = MagicMock(return_value=lambda: "Synced content")
|
||||
|
||||
with setup_app.test_request_context():
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
|
||||
result = protected_view()
|
||||
assert result == "Synced content"
|
||||
setup_app.ensure_sync.assert_called_once()
|
||||
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def test_flask_1_compatibility(self, setup_app: Flask):
|
||||
"""Test Flask 1.x compatibility without ensure_sync."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
# Remove ensure_sync to simulate Flask 1.x
|
||||
if hasattr(setup_app, "ensure_sync"):
|
||||
del setup_app.ensure_sync
|
||||
|
||||
with setup_app.test_request_context():
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
assert result == "Protected content"
|
||||
get_user.assert_not_called()
|
||||
ensure_sync_spy.assert_called_once_with(protected_view.__wrapped__)
|
||||
csrf_check.assert_not_called()
|
||||
login_app.login_manager.unauthorized.assert_not_called()
|
||||
|
||||
|
||||
class TestGetUser:
|
||||
"""Test cases for _get_user function."""
|
||||
|
||||
def test_get_user_returns_user_from_g(self, app: Flask):
|
||||
def test_get_user_returns_user_from_g(self, login_app: Flask):
|
||||
"""Test that _get_user returns user from g._login_user."""
|
||||
mock_user = MockUser("test_user")
|
||||
|
||||
with app.test_request_context():
|
||||
with login_app.test_request_context():
|
||||
g._login_user = mock_user
|
||||
user = _get_user()
|
||||
user = login_module._get_user()
|
||||
assert user == mock_user
|
||||
assert user.id == "test_user"
|
||||
|
||||
def test_get_user_loads_user_if_not_in_g(self, app: Flask):
|
||||
def test_get_user_loads_user_if_not_in_g(self, login_app: Flask):
|
||||
"""Test that _get_user loads user if not already in g."""
|
||||
mock_user = MockUser("test_user")
|
||||
|
||||
# Mock login manager
|
||||
login_manager = MagicMock()
|
||||
login_manager._load_user = MagicMock()
|
||||
app.login_manager = login_manager
|
||||
def _load_user() -> None:
|
||||
g._login_user = mock_user
|
||||
|
||||
with app.test_request_context():
|
||||
# Simulate _load_user setting g._login_user
|
||||
def side_effect():
|
||||
g._login_user = mock_user
|
||||
login_app.login_manager._load_user = MagicMock(side_effect=_load_user)
|
||||
|
||||
login_manager._load_user.side_effect = side_effect
|
||||
with login_app.test_request_context():
|
||||
user = login_module._get_user()
|
||||
|
||||
user = _get_user()
|
||||
assert user == mock_user
|
||||
login_manager._load_user.assert_called_once()
|
||||
assert user == mock_user
|
||||
login_app.login_manager._load_user.assert_called_once_with()
|
||||
|
||||
def test_get_user_returns_none_without_request_context(self, app: Flask):
|
||||
def test_get_user_returns_none_without_request_context(self):
|
||||
"""Test that _get_user returns None outside request context."""
|
||||
# Outside of request context
|
||||
user = _get_user()
|
||||
user = login_module._get_user()
|
||||
assert user is None
|
||||
|
||||
|
||||
class TestCurrentUser:
|
||||
"""Test cases for current_user proxy."""
|
||||
|
||||
def test_current_user_proxy_returns_authenticated_user(self, app: Flask):
|
||||
def test_current_user_proxy_returns_authenticated_user(self, login_app: Flask, mocker: MockerFixture):
|
||||
"""Test that current_user proxy returns authenticated user."""
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
mocker.patch.object(login_module, "_get_user", return_value=mock_user)
|
||||
|
||||
with app.test_request_context():
|
||||
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
|
||||
assert current_user.id == "test_user"
|
||||
assert current_user.is_authenticated is True
|
||||
with login_app.test_request_context():
|
||||
assert current_user.id == "test_user"
|
||||
assert current_user.is_authenticated is True
|
||||
|
||||
def test_current_user_proxy_returns_none_when_no_user(self, app: Flask):
|
||||
def test_current_user_proxy_raises_attribute_error_when_no_user(self, login_app: Flask, mocker: MockerFixture):
|
||||
"""Test that current_user proxy handles None user."""
|
||||
with app.test_request_context():
|
||||
with patch("libs.login._get_user", return_value=None, autospec=True):
|
||||
# When _get_user returns None, accessing attributes should fail
|
||||
# or current_user should evaluate to falsy
|
||||
try:
|
||||
# Try to access an attribute that would exist on a real user
|
||||
_ = current_user.id
|
||||
pytest.fail("Should have raised AttributeError")
|
||||
except AttributeError:
|
||||
# This is expected when current_user is None
|
||||
pass
|
||||
mocker.patch.object(login_module, "_get_user", return_value=None)
|
||||
|
||||
def test_current_user_proxy_thread_safety(self, app: Flask):
|
||||
"""Test that current_user proxy is thread-safe."""
|
||||
import threading
|
||||
with login_app.test_request_context():
|
||||
with pytest.raises(AttributeError):
|
||||
_ = current_user.id
|
||||
|
||||
results = {}
|
||||
|
||||
def check_user_in_thread(user_id: str, index: int):
|
||||
with app.test_request_context():
|
||||
mock_user = MockUser(user_id)
|
||||
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
|
||||
results[index] = current_user.id
|
||||
class TestCurrentAccountWithTenant:
|
||||
"""Test cases for current_account_with_tenant helper."""
|
||||
|
||||
# Create multiple threads with different users
|
||||
threads = []
|
||||
for i in range(5):
|
||||
thread = threading.Thread(target=check_user_in_thread, args=(f"user_{i}", i))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
def test_returns_account_and_tenant_id(self, mocker: MockerFixture):
|
||||
account = Account(name="Test User", email="test@example.com")
|
||||
account._current_tenant = SimpleNamespace(id="tenant-123")
|
||||
current_user_proxy = MagicMock()
|
||||
current_user_proxy._get_current_object.return_value = account
|
||||
mocker.patch.object(login_module, "current_user", new=current_user_proxy)
|
||||
|
||||
# Wait for all threads to complete
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
user, tenant_id = login_module.current_account_with_tenant()
|
||||
|
||||
# Verify each thread got its own user
|
||||
for i in range(5):
|
||||
assert results[i] == f"user_{i}"
|
||||
assert user is account
|
||||
assert tenant_id == "tenant-123"
|
||||
current_user_proxy._get_current_object.assert_called_once_with()
|
||||
|
||||
def test_raises_when_current_user_is_not_account(self, mocker: MockerFixture):
|
||||
mocker.patch.object(login_module, "current_user", new=MockUser("test_user"))
|
||||
|
||||
with pytest.raises(ValueError, match="current_user must be an Account instance"):
|
||||
login_module.current_account_with_tenant()
|
||||
|
||||
def test_raises_when_account_has_no_tenant(self, mocker: MockerFixture):
|
||||
account = Account(name="Test User", email="test@example.com")
|
||||
mocker.patch.object(login_module, "current_user", new=account)
|
||||
|
||||
with pytest.raises(AssertionError, match="tenant information should be loaded"):
|
||||
login_module.current_account_with_tenant()
|
||||
|
|
|
|||
|
|
@ -70,8 +70,11 @@ def service(mocker: MockerFixture) -> ModelLoadBalancingService:
|
|||
# Arrange
|
||||
provider_manager = MagicMock()
|
||||
mocker.patch("services.model_load_balancing_service.create_plugin_provider_manager", return_value=provider_manager)
|
||||
model_assembly = SimpleNamespace(provider_manager=provider_manager, model_provider_factory=MagicMock())
|
||||
mocker.patch("services.model_load_balancing_service.create_plugin_model_assembly", return_value=model_assembly)
|
||||
svc = ModelLoadBalancingService()
|
||||
svc.provider_manager = provider_manager
|
||||
svc.model_assembly = model_assembly
|
||||
svc._get_provider_manager = lambda _tenant_id: provider_manager # type: ignore[method-assign]
|
||||
return svc
|
||||
|
||||
|
|
@ -667,6 +670,9 @@ def test_validate_load_balancing_credentials_should_delegate_to_custom_validate_
|
|||
assert mock_validate.call_count == 2
|
||||
assert mock_validate.call_args_list[0].kwargs["load_balancing_model_config"] is existing_config
|
||||
assert mock_validate.call_args_list[1].kwargs["load_balancing_model_config"] is None
|
||||
shared_model_provider_factory = service.model_assembly.model_provider_factory
|
||||
assert mock_validate.call_args_list[0].kwargs["model_provider_factory"] is shared_model_provider_factory
|
||||
assert mock_validate.call_args_list[1].kwargs["model_provider_factory"] is shared_model_provider_factory
|
||||
|
||||
|
||||
def test_custom_credentials_validate_should_replace_hidden_secret_with_original_value_and_encrypt(
|
||||
|
|
@ -709,10 +715,6 @@ def test_custom_credentials_validate_should_handle_invalid_original_json_and_val
|
|||
load_balancing_model_config = _load_balancing_model_config(encrypted_config="not-json")
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.model_credentials_validate.return_value = {"api_key": "validated"}
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.create_plugin_model_provider_factory",
|
||||
return_value=mock_factory,
|
||||
)
|
||||
mock_encrypt = mocker.patch(
|
||||
"services.model_load_balancing_service.encrypter.encrypt_token",
|
||||
side_effect=lambda tenant_id, value: f"enc:{value}",
|
||||
|
|
@ -726,6 +728,7 @@ def test_custom_credentials_validate_should_handle_invalid_original_json_and_val
|
|||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "plain"},
|
||||
load_balancing_model_config=load_balancing_model_config,
|
||||
model_provider_factory=mock_factory,
|
||||
validate=True,
|
||||
)
|
||||
|
||||
|
|
@ -744,10 +747,6 @@ def test_custom_credentials_validate_should_validate_with_provider_schema_when_m
|
|||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.provider_credentials_validate.return_value = {"api_key": "provider-validated"}
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.create_plugin_model_provider_factory",
|
||||
return_value=mock_factory,
|
||||
)
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.encrypter.encrypt_token",
|
||||
side_effect=lambda tenant_id, value: f"enc:{value}",
|
||||
|
|
@ -760,6 +759,7 @@ def test_custom_credentials_validate_should_validate_with_provider_schema_when_m
|
|||
model_type=ModelType.LLM,
|
||||
model="gpt-4o-mini",
|
||||
credentials={"api_key": "plain"},
|
||||
model_provider_factory=mock_factory,
|
||||
validate=True,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from dify_graph.enums import (
|
|||
)
|
||||
from dify_graph.errors import WorkflowNodeRunFailedError
|
||||
from dify_graph.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig
|
||||
from dify_graph.variables.input_entities import VariableEntityType
|
||||
|
|
@ -1545,7 +1546,10 @@ class TestWorkflowServiceCredentialValidation:
|
|||
def test_validate_llm_model_config_should_raise_value_error_on_failure(self, service: WorkflowService) -> None:
|
||||
"""If ModelManager raises any exception it must be wrapped into ValueError."""
|
||||
# Arrange
|
||||
with patch("core.model_manager.ModelManager.get_model_instance", side_effect=RuntimeError("no key")):
|
||||
assembly = MagicMock()
|
||||
assembly.model_manager.get_model_instance.side_effect = RuntimeError("no key")
|
||||
|
||||
with patch("services.workflow_service.create_plugin_model_assembly", return_value=assembly):
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Failed to validate LLM model configuration"):
|
||||
service._validate_llm_model_config("tenant-1", "openai", "gpt-4")
|
||||
|
|
@ -1558,30 +1562,30 @@ class TestWorkflowServiceCredentialValidation:
|
|||
|
||||
mock_configs = MagicMock()
|
||||
mock_configs.get_models.return_value = [mock_model]
|
||||
assembly = MagicMock()
|
||||
assembly.provider_manager.get_configurations.return_value = mock_configs
|
||||
|
||||
with (
|
||||
patch("core.model_manager.ModelManager.get_model_instance"),
|
||||
patch("core.provider_manager.ProviderManager") as mock_pm_cls,
|
||||
):
|
||||
mock_pm_cls.return_value.get_configurations.return_value = mock_configs
|
||||
|
||||
with patch("services.workflow_service.create_plugin_model_assembly", return_value=assembly):
|
||||
# Act
|
||||
service._validate_llm_model_config("tenant-1", "openai", "gpt-4")
|
||||
|
||||
# Assert
|
||||
mock_model.raise_for_status.assert_called_once()
|
||||
assembly.model_manager.get_model_instance.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
provider="openai",
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4",
|
||||
)
|
||||
|
||||
def test_validate_llm_model_config_model_not_found(self, service: WorkflowService) -> None:
|
||||
"""Test ValueError when model is not found in provider configurations."""
|
||||
mock_configs = MagicMock()
|
||||
mock_configs.get_models.return_value = [] # No models
|
||||
assembly = MagicMock()
|
||||
assembly.provider_manager.get_configurations.return_value = mock_configs
|
||||
|
||||
with (
|
||||
patch("core.model_manager.ModelManager.get_model_instance"),
|
||||
patch("core.provider_manager.ProviderManager") as mock_pm_cls,
|
||||
):
|
||||
mock_pm_cls.return_value.get_configurations.return_value = mock_configs
|
||||
|
||||
with patch("services.workflow_service.create_plugin_model_assembly", return_value=assembly):
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Model gpt-4 not found for provider openai"):
|
||||
service._validate_llm_model_config("tenant-1", "openai", "gpt-4")
|
||||
|
|
|
|||
Loading…
Reference in New Issue