feat(skill): tool switcher for llm node

- Added an `enabled` field to `DifyCliToolConfig` and `ToolDependency` to manage tool activation status.
- Updated `DifyCliConfig` to handle tool dependencies more effectively, ensuring only enabled tools are processed.
- Refactored `SkillCompiler` to utilize `tool_id` for better identification of tools and improved handling of disabled tools.
- Introduced a new method `_extract_disabled_tools` in `LLMNode` to streamline the extraction of disabled tools from node data.
- Enhanced metadata parsing to account for tool enablement, improving overall tool management.
This commit is contained in:
Harry 2026-01-29 01:21:09 +08:00
parent 23ee9e618b
commit 0495dc5085
9 changed files with 91 additions and 41 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -62,6 +62,7 @@ class DifyCliEnvConfig(BaseModel):
class DifyCliToolConfig(BaseModel):
provider_type: str
enabled: bool = True
identity: dict[str, Any]
description: dict[str, Any]
parameters: list[dict[str, Any]]
@ -81,8 +82,8 @@ class DifyCliToolConfig(BaseModel):
@classmethod
def create_from_tool(cls, tool: Tool) -> DifyCliToolConfig:
return cls(
provider_type=cls.transform_provider_type(tool.tool_provider_type()),
identity=to_json(tool.entity.identity),
provider_type=cls.transform_provider_type(tool.tool_provider_type()),
description=to_json(tool.entity.description),
parameters=[cls.transform_parameter(parameter) for parameter in tool.entity.parameters],
)
@ -137,15 +138,18 @@ class DifyCliConfig(BaseModel):
cli_api_url = dify_config.CLI_API_URL
tools: list[Tool] = []
tools: list[DifyCliToolConfig] = []
for dependency in tool_deps.dependencies:
tool = ToolManager.get_tool_runtime(
tenant_id=tenant_id,
provider_type=dependency.type,
provider_id=dependency.provider,
tool_name=dependency.tool_name,
invoke_from=InvokeFrom.AGENT,
tool = DifyCliToolConfig.create_from_tool(
ToolManager.get_tool_runtime(
tenant_id=tenant_id,
provider_type=dependency.type,
provider_id=dependency.provider,
tool_name=dependency.tool_name,
invoke_from=InvokeFrom.AGENT,
)
)
tool.enabled = dependency.enabled
tools.append(tool)
return cls(
@ -156,7 +160,7 @@ class DifyCliConfig(BaseModel):
cli_api_secret=session.secret,
),
tool_references=[DifyCliToolReference.create_from_tool_reference(ref) for ref in tool_deps.references],
tools=[DifyCliToolConfig.create_from_tool(tool) for tool in tools],
tools=tools,
)

View File

@ -22,6 +22,10 @@ class ToolConfiguration(BaseModel):
return {field.id: field.value for field in self.fields if field.value is not None}
def create_tool_id(provider: str, tool_name: str) -> str:
return f"{provider}.{tool_name}"
class ToolReference(BaseModel):
model_config = ConfigDict(extra="forbid")
@ -33,6 +37,12 @@ class ToolReference(BaseModel):
credential_id: str | None = None
configuration: ToolConfiguration | None = None
def reference_id(self) -> str:
return f"{self.provider}.{self.tool_name}.{self.uuid}"
def tool_id(self) -> str:
return f"{self.provider}.{self.tool_name}"
class FileReference(BaseModel):
model_config = ConfigDict(extra="forbid")

View File

@ -10,6 +10,10 @@ class ToolDependency(BaseModel):
type: ToolProviderType
provider: str
tool_name: str
enabled: bool = True
def tool_id(self) -> str:
return f"{self.provider}.{self.tool_name}"
class ToolDependencies(BaseModel):
@ -57,3 +61,18 @@ class ToolDependencies(BaseModel):
dependencies=list(dep_map.values()),
references=list(ref_map.values()),
)
def remove_tools(self, tools: list[ToolDependency]) -> "ToolDependencies":
tool_keys = {f"{tool.provider}.{tool.tool_name}" for tool in tools}
return ToolDependencies(
dependencies=[
dependency
for dependency in self.dependencies
if f"{dependency.provider}.{dependency.tool_name}" not in tool_keys
],
references=[
reference
for reference in self.references
if f"{reference.provider}.{reference.tool_name}" not in tool_keys
],
)

View File

@ -9,7 +9,13 @@ from core.skill.entities.asset_references import AssetReferences
from core.skill.entities.skill_bundle import SkillBundle
from core.skill.entities.skill_bundle_entry import SkillBundleEntry, SourceInfo
from core.skill.entities.skill_document import SkillDocument
from core.skill.entities.skill_metadata import FileReference, SkillMetadata, ToolConfiguration, ToolReference
from core.skill.entities.skill_metadata import (
FileReference,
SkillMetadata,
ToolConfiguration,
ToolReference,
create_tool_id,
)
from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
from core.tools.entities.tool_entities import ToolProviderType
@ -164,7 +170,7 @@ class SkillCompiler:
direct_refs: dict[str, ToolReference] = {}
for tool_ref in metadata.tools.values():
key = f"{tool_ref.provider}.{tool_ref.tool_name}"
key = tool_ref.tool_id()
if key not in direct_tools:
direct_tools[key] = ToolDependency(
type=tool_ref.type,
@ -207,9 +213,7 @@ class SkillCompiler:
continue
# Collect current tools and files
current_tools: dict[str, ToolDependency] = {
f"{d.provider}.{d.tool_name}": d for d in entry.tools.dependencies
}
current_tools: dict[str, ToolDependency] = {d.tool_id(): d for d in entry.tools.dependencies}
current_refs: dict[str, ToolReference] = {r.uuid: r for r in entry.tools.references}
current_files: dict[str, FileReference] = {f.asset_id: f for f in entry.files.references}
@ -224,7 +228,7 @@ class SkillCompiler:
continue
for tool_dep in dep_entry.tools.dependencies:
key = f"{tool_dep.provider}.{tool_dep.tool_name}"
key = tool_dep.tool_id()
if key not in current_tools:
current_tools[key] = tool_dep
@ -324,7 +328,7 @@ class SkillCompiler:
all_refs: dict[str, ToolReference] = {}
for tool_ref in metadata.tools.values():
key = f"{tool_ref.provider}.{tool_ref.tool_name}"
key = tool_ref.tool_id()
if key not in all_tools:
all_tools[key] = ToolDependency(
type=tool_ref.type,
@ -335,7 +339,7 @@ class SkillCompiler:
for dep in direct_deps:
for tool_dep in dep.tools.dependencies:
key = f"{tool_dep.provider}.{tool_dep.tool_name}"
key = tool_dep.tool_id()
if key not in all_tools:
all_tools[key] = tool_dep
@ -359,7 +363,7 @@ class SkillCompiler:
tool_id = match.group(1)
tool_ref: ToolReference | None = metadata.tools.get(tool_id)
if not tool_ref:
return f"[Tool not found: {tool_id}]"
return f"[Tool not found or disabled: {tool_id}]"
if not tool_ref.enabled:
return ""
return self._tool_resolver.resolve(tool_ref)
@ -372,7 +376,7 @@ class SkillCompiler:
tool_id = tool_match.group(1)
tool_ref: ToolReference | None = metadata.tools.get(tool_id)
if not tool_ref:
enabled_renders.append(f"[Tool not found: {tool_id}]")
enabled_renders.append(f"[Tool not found or disabled: {tool_id}]")
continue
if not tool_ref.enabled:
continue
@ -387,32 +391,32 @@ class SkillCompiler:
content = self._config.tool_pattern.sub(replace_tool, content)
return content
def _parse_metadata(self, content: str, raw_metadata: Mapping[str, Any]) -> SkillMetadata:
def _parse_metadata(
self, content: str, raw_metadata: Mapping[str, Any], disabled_tools: list[ToolDependency] = []
) -> SkillMetadata:
tools_raw = dict(raw_metadata.get("tools", {}))
tools: dict[str, ToolReference] = {}
disabled_tools_set = {tool.tool_id() for tool in disabled_tools}
tool_iter = re.finditer(r"§\[tool\]\.\[([^\]]+)\]\.\[([^\]]+)\]\.\[([^\]]+)\", content)
for match in tool_iter:
provider, name, uuid = match.group(1), match.group(2), match.group(3)
if uuid in tools_raw:
meta = tools_raw[uuid]
if isinstance(meta, ToolReference):
tools[uuid] = meta
elif isinstance(meta, dict):
meta_dict = cast(dict[str, Any], meta)
tool_type_str = cast(str | None, meta_dict.get("type"))
if tool_type_str:
tools[uuid] = ToolReference(
uuid=uuid,
type=ToolProviderType.value_of(tool_type_str),
provider=provider,
tool_name=name,
enabled=cast(bool, meta_dict.get("enabled", True)),
credential_id=cast(str | None, meta_dict.get("credential_id")),
configuration=ToolConfiguration.model_validate(meta_dict.get("configuration", {}))
if meta_dict.get("configuration")
else None,
)
meta_dict = cast(dict[str, Any], meta)
type = cast(str, meta_dict.get("type"))
if create_tool_id(provider, name) in disabled_tools_set:
continue
tools[uuid] = ToolReference(
uuid=uuid,
type=ToolProviderType.value_of(type),
provider=provider,
tool_name=name,
enabled=cast(bool, meta_dict.get("enabled", True)),
credential_id=cast(str | None, meta_dict.get("credential_id")),
configuration=ToolConfiguration.model_validate(meta_dict.get("configuration", {}))
if meta_dict.get("configuration")
else None,
)
parsed_files: list[FileReference] = []
file_iter = re.finditer(r"§\[file\]\.\[([^\]]+)\]\.\[([^\]]+)\", content)

View File

@ -59,7 +59,7 @@ from core.sandbox.entities.config import AppAssets
from core.skill.constants import SkillAttrs
from core.skill.entities.skill_bundle import SkillBundle
from core.skill.entities.skill_document import SkillDocument
from core.skill.entities.tool_dependencies import ToolDependencies
from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
from core.skill.skill_compiler import SkillCompiler
from core.tools.__base.tool import Tool
from core.tools.signature import sign_upload_file
@ -301,7 +301,7 @@ class LLMNode(Node[LLMNodeData]):
sandbox = self.graph_runtime_state.sandbox
if not sandbox:
raise LLMNodeError("computer use is enabled but no sandbox found")
tool_dependencies = self._extract_tool_dependencies()
tool_dependencies: ToolDependencies | None = self._extract_tool_dependencies()
generator = self._invoke_llm_with_sandbox(
sandbox=sandbox,
model_instance=model_instance,
@ -1822,6 +1822,14 @@ class LLMNode(Node[LLMNodeData]):
generation_data,
)
def _extract_disabled_tools(self) -> dict[str, ToolDependency]:
tools = [
ToolDependency(type=tool.type, provider=tool.provider, tool_name=tool.tool_name)
for tool in self.node_data.tool_settings
if not tool.enabled
]
return {tool.tool_id(): tool for tool in tools}
def _extract_tool_dependencies(self) -> ToolDependencies | None:
"""Extract tool artifact from prompt template."""
@ -1845,7 +1853,12 @@ class LLMNode(Node[LLMNodeData]):
if len(tool_deps_list) == 0:
return None
return reduce(lambda x, y: x.merge(y), tool_deps_list)
disabled_tools = self._extract_disabled_tools()
tool_dependencies = reduce(lambda x, y: x.merge(y), tool_deps_list)
for tool in tool_dependencies.dependencies:
if tool.tool_id() in disabled_tools:
tool.enabled = False
return tool_dependencies
def _invoke_llm_with_tools(
self,