mirror of https://github.com/langgenius/dify.git
feat(skill): tool switcher for llm node
- Added an `enabled` field to `DifyCliToolConfig` and `ToolDependency` to manage tool activation status. - Updated `DifyCliConfig` to handle tool dependencies more effectively, ensuring only enabled tools are processed. - Refactored `SkillCompiler` to utilize `tool_id` for better identification of tools and improved handling of disabled tools. - Introduced a new method `_extract_disabled_tools` in `LLMNode` to streamline the extraction of disabled tools from node data. - Enhanced metadata parsing to account for tool enablement, improving overall tool management.
This commit is contained in:
parent
23ee9e618b
commit
0495dc5085
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -62,6 +62,7 @@ class DifyCliEnvConfig(BaseModel):
|
|||
|
||||
class DifyCliToolConfig(BaseModel):
|
||||
provider_type: str
|
||||
enabled: bool = True
|
||||
identity: dict[str, Any]
|
||||
description: dict[str, Any]
|
||||
parameters: list[dict[str, Any]]
|
||||
|
|
@ -81,8 +82,8 @@ class DifyCliToolConfig(BaseModel):
|
|||
@classmethod
|
||||
def create_from_tool(cls, tool: Tool) -> DifyCliToolConfig:
|
||||
return cls(
|
||||
provider_type=cls.transform_provider_type(tool.tool_provider_type()),
|
||||
identity=to_json(tool.entity.identity),
|
||||
provider_type=cls.transform_provider_type(tool.tool_provider_type()),
|
||||
description=to_json(tool.entity.description),
|
||||
parameters=[cls.transform_parameter(parameter) for parameter in tool.entity.parameters],
|
||||
)
|
||||
|
|
@ -137,15 +138,18 @@ class DifyCliConfig(BaseModel):
|
|||
|
||||
cli_api_url = dify_config.CLI_API_URL
|
||||
|
||||
tools: list[Tool] = []
|
||||
tools: list[DifyCliToolConfig] = []
|
||||
for dependency in tool_deps.dependencies:
|
||||
tool = ToolManager.get_tool_runtime(
|
||||
tenant_id=tenant_id,
|
||||
provider_type=dependency.type,
|
||||
provider_id=dependency.provider,
|
||||
tool_name=dependency.tool_name,
|
||||
invoke_from=InvokeFrom.AGENT,
|
||||
tool = DifyCliToolConfig.create_from_tool(
|
||||
ToolManager.get_tool_runtime(
|
||||
tenant_id=tenant_id,
|
||||
provider_type=dependency.type,
|
||||
provider_id=dependency.provider,
|
||||
tool_name=dependency.tool_name,
|
||||
invoke_from=InvokeFrom.AGENT,
|
||||
)
|
||||
)
|
||||
tool.enabled = dependency.enabled
|
||||
tools.append(tool)
|
||||
|
||||
return cls(
|
||||
|
|
@ -156,7 +160,7 @@ class DifyCliConfig(BaseModel):
|
|||
cli_api_secret=session.secret,
|
||||
),
|
||||
tool_references=[DifyCliToolReference.create_from_tool_reference(ref) for ref in tool_deps.references],
|
||||
tools=[DifyCliToolConfig.create_from_tool(tool) for tool in tools],
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,10 @@ class ToolConfiguration(BaseModel):
|
|||
return {field.id: field.value for field in self.fields if field.value is not None}
|
||||
|
||||
|
||||
def create_tool_id(provider: str, tool_name: str) -> str:
|
||||
return f"{provider}.{tool_name}"
|
||||
|
||||
|
||||
class ToolReference(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
|
@ -33,6 +37,12 @@ class ToolReference(BaseModel):
|
|||
credential_id: str | None = None
|
||||
configuration: ToolConfiguration | None = None
|
||||
|
||||
def reference_id(self) -> str:
|
||||
return f"{self.provider}.{self.tool_name}.{self.uuid}"
|
||||
|
||||
def tool_id(self) -> str:
|
||||
return f"{self.provider}.{self.tool_name}"
|
||||
|
||||
|
||||
class FileReference(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
|
|
|||
|
|
@ -10,6 +10,10 @@ class ToolDependency(BaseModel):
|
|||
type: ToolProviderType
|
||||
provider: str
|
||||
tool_name: str
|
||||
enabled: bool = True
|
||||
|
||||
def tool_id(self) -> str:
|
||||
return f"{self.provider}.{self.tool_name}"
|
||||
|
||||
|
||||
class ToolDependencies(BaseModel):
|
||||
|
|
@ -57,3 +61,18 @@ class ToolDependencies(BaseModel):
|
|||
dependencies=list(dep_map.values()),
|
||||
references=list(ref_map.values()),
|
||||
)
|
||||
|
||||
def remove_tools(self, tools: list[ToolDependency]) -> "ToolDependencies":
|
||||
tool_keys = {f"{tool.provider}.{tool.tool_name}" for tool in tools}
|
||||
return ToolDependencies(
|
||||
dependencies=[
|
||||
dependency
|
||||
for dependency in self.dependencies
|
||||
if f"{dependency.provider}.{dependency.tool_name}" not in tool_keys
|
||||
],
|
||||
references=[
|
||||
reference
|
||||
for reference in self.references
|
||||
if f"{reference.provider}.{reference.tool_name}" not in tool_keys
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,13 @@ from core.skill.entities.asset_references import AssetReferences
|
|||
from core.skill.entities.skill_bundle import SkillBundle
|
||||
from core.skill.entities.skill_bundle_entry import SkillBundleEntry, SourceInfo
|
||||
from core.skill.entities.skill_document import SkillDocument
|
||||
from core.skill.entities.skill_metadata import FileReference, SkillMetadata, ToolConfiguration, ToolReference
|
||||
from core.skill.entities.skill_metadata import (
|
||||
FileReference,
|
||||
SkillMetadata,
|
||||
ToolConfiguration,
|
||||
ToolReference,
|
||||
create_tool_id,
|
||||
)
|
||||
from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
|
||||
|
|
@ -164,7 +170,7 @@ class SkillCompiler:
|
|||
direct_refs: dict[str, ToolReference] = {}
|
||||
|
||||
for tool_ref in metadata.tools.values():
|
||||
key = f"{tool_ref.provider}.{tool_ref.tool_name}"
|
||||
key = tool_ref.tool_id()
|
||||
if key not in direct_tools:
|
||||
direct_tools[key] = ToolDependency(
|
||||
type=tool_ref.type,
|
||||
|
|
@ -207,9 +213,7 @@ class SkillCompiler:
|
|||
continue
|
||||
|
||||
# Collect current tools and files
|
||||
current_tools: dict[str, ToolDependency] = {
|
||||
f"{d.provider}.{d.tool_name}": d for d in entry.tools.dependencies
|
||||
}
|
||||
current_tools: dict[str, ToolDependency] = {d.tool_id(): d for d in entry.tools.dependencies}
|
||||
current_refs: dict[str, ToolReference] = {r.uuid: r for r in entry.tools.references}
|
||||
current_files: dict[str, FileReference] = {f.asset_id: f for f in entry.files.references}
|
||||
|
||||
|
|
@ -224,7 +228,7 @@ class SkillCompiler:
|
|||
continue
|
||||
|
||||
for tool_dep in dep_entry.tools.dependencies:
|
||||
key = f"{tool_dep.provider}.{tool_dep.tool_name}"
|
||||
key = tool_dep.tool_id()
|
||||
if key not in current_tools:
|
||||
current_tools[key] = tool_dep
|
||||
|
||||
|
|
@ -324,7 +328,7 @@ class SkillCompiler:
|
|||
all_refs: dict[str, ToolReference] = {}
|
||||
|
||||
for tool_ref in metadata.tools.values():
|
||||
key = f"{tool_ref.provider}.{tool_ref.tool_name}"
|
||||
key = tool_ref.tool_id()
|
||||
if key not in all_tools:
|
||||
all_tools[key] = ToolDependency(
|
||||
type=tool_ref.type,
|
||||
|
|
@ -335,7 +339,7 @@ class SkillCompiler:
|
|||
|
||||
for dep in direct_deps:
|
||||
for tool_dep in dep.tools.dependencies:
|
||||
key = f"{tool_dep.provider}.{tool_dep.tool_name}"
|
||||
key = tool_dep.tool_id()
|
||||
if key not in all_tools:
|
||||
all_tools[key] = tool_dep
|
||||
|
||||
|
|
@ -359,7 +363,7 @@ class SkillCompiler:
|
|||
tool_id = match.group(1)
|
||||
tool_ref: ToolReference | None = metadata.tools.get(tool_id)
|
||||
if not tool_ref:
|
||||
return f"[Tool not found: {tool_id}]"
|
||||
return f"[Tool not found or disabled: {tool_id}]"
|
||||
if not tool_ref.enabled:
|
||||
return ""
|
||||
return self._tool_resolver.resolve(tool_ref)
|
||||
|
|
@ -372,7 +376,7 @@ class SkillCompiler:
|
|||
tool_id = tool_match.group(1)
|
||||
tool_ref: ToolReference | None = metadata.tools.get(tool_id)
|
||||
if not tool_ref:
|
||||
enabled_renders.append(f"[Tool not found: {tool_id}]")
|
||||
enabled_renders.append(f"[Tool not found or disabled: {tool_id}]")
|
||||
continue
|
||||
if not tool_ref.enabled:
|
||||
continue
|
||||
|
|
@ -387,32 +391,32 @@ class SkillCompiler:
|
|||
content = self._config.tool_pattern.sub(replace_tool, content)
|
||||
return content
|
||||
|
||||
def _parse_metadata(self, content: str, raw_metadata: Mapping[str, Any]) -> SkillMetadata:
|
||||
def _parse_metadata(
|
||||
self, content: str, raw_metadata: Mapping[str, Any], disabled_tools: list[ToolDependency] = []
|
||||
) -> SkillMetadata:
|
||||
tools_raw = dict(raw_metadata.get("tools", {}))
|
||||
tools: dict[str, ToolReference] = {}
|
||||
|
||||
disabled_tools_set = {tool.tool_id() for tool in disabled_tools}
|
||||
tool_iter = re.finditer(r"§\[tool\]\.\[([^\]]+)\]\.\[([^\]]+)\]\.\[([^\]]+)\]§", content)
|
||||
for match in tool_iter:
|
||||
provider, name, uuid = match.group(1), match.group(2), match.group(3)
|
||||
if uuid in tools_raw:
|
||||
meta = tools_raw[uuid]
|
||||
if isinstance(meta, ToolReference):
|
||||
tools[uuid] = meta
|
||||
elif isinstance(meta, dict):
|
||||
meta_dict = cast(dict[str, Any], meta)
|
||||
tool_type_str = cast(str | None, meta_dict.get("type"))
|
||||
if tool_type_str:
|
||||
tools[uuid] = ToolReference(
|
||||
uuid=uuid,
|
||||
type=ToolProviderType.value_of(tool_type_str),
|
||||
provider=provider,
|
||||
tool_name=name,
|
||||
enabled=cast(bool, meta_dict.get("enabled", True)),
|
||||
credential_id=cast(str | None, meta_dict.get("credential_id")),
|
||||
configuration=ToolConfiguration.model_validate(meta_dict.get("configuration", {}))
|
||||
if meta_dict.get("configuration")
|
||||
else None,
|
||||
)
|
||||
meta_dict = cast(dict[str, Any], meta)
|
||||
type = cast(str, meta_dict.get("type"))
|
||||
if create_tool_id(provider, name) in disabled_tools_set:
|
||||
continue
|
||||
tools[uuid] = ToolReference(
|
||||
uuid=uuid,
|
||||
type=ToolProviderType.value_of(type),
|
||||
provider=provider,
|
||||
tool_name=name,
|
||||
enabled=cast(bool, meta_dict.get("enabled", True)),
|
||||
credential_id=cast(str | None, meta_dict.get("credential_id")),
|
||||
configuration=ToolConfiguration.model_validate(meta_dict.get("configuration", {}))
|
||||
if meta_dict.get("configuration")
|
||||
else None,
|
||||
)
|
||||
|
||||
parsed_files: list[FileReference] = []
|
||||
file_iter = re.finditer(r"§\[file\]\.\[([^\]]+)\]\.\[([^\]]+)\]§", content)
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ from core.sandbox.entities.config import AppAssets
|
|||
from core.skill.constants import SkillAttrs
|
||||
from core.skill.entities.skill_bundle import SkillBundle
|
||||
from core.skill.entities.skill_document import SkillDocument
|
||||
from core.skill.entities.tool_dependencies import ToolDependencies
|
||||
from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
|
||||
from core.skill.skill_compiler import SkillCompiler
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.signature import sign_upload_file
|
||||
|
|
@ -301,7 +301,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||
sandbox = self.graph_runtime_state.sandbox
|
||||
if not sandbox:
|
||||
raise LLMNodeError("computer use is enabled but no sandbox found")
|
||||
tool_dependencies = self._extract_tool_dependencies()
|
||||
tool_dependencies: ToolDependencies | None = self._extract_tool_dependencies()
|
||||
generator = self._invoke_llm_with_sandbox(
|
||||
sandbox=sandbox,
|
||||
model_instance=model_instance,
|
||||
|
|
@ -1822,6 +1822,14 @@ class LLMNode(Node[LLMNodeData]):
|
|||
generation_data,
|
||||
)
|
||||
|
||||
def _extract_disabled_tools(self) -> dict[str, ToolDependency]:
|
||||
tools = [
|
||||
ToolDependency(type=tool.type, provider=tool.provider, tool_name=tool.tool_name)
|
||||
for tool in self.node_data.tool_settings
|
||||
if not tool.enabled
|
||||
]
|
||||
return {tool.tool_id(): tool for tool in tools}
|
||||
|
||||
def _extract_tool_dependencies(self) -> ToolDependencies | None:
|
||||
"""Extract tool artifact from prompt template."""
|
||||
|
||||
|
|
@ -1845,7 +1853,12 @@ class LLMNode(Node[LLMNodeData]):
|
|||
if len(tool_deps_list) == 0:
|
||||
return None
|
||||
|
||||
return reduce(lambda x, y: x.merge(y), tool_deps_list)
|
||||
disabled_tools = self._extract_disabled_tools()
|
||||
tool_dependencies = reduce(lambda x, y: x.merge(y), tool_deps_list)
|
||||
for tool in tool_dependencies.dependencies:
|
||||
if tool.tool_id() in disabled_tools:
|
||||
tool.enabled = False
|
||||
return tool_dependencies
|
||||
|
||||
def _invoke_llm_with_tools(
|
||||
self,
|
||||
|
|
|
|||
Loading…
Reference in New Issue