mirror of https://github.com/langgenius/dify.git
224 lines
8.4 KiB
Python
224 lines
8.4 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
from collections import defaultdict
|
|
from collections.abc import Mapping, Sequence
|
|
from copy import deepcopy
|
|
from typing import Annotated, Any, Union, cast
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from dify_graph.file import File, FileAttribute, file_manager
|
|
from dify_graph.variables import Segment, SegmentGroup, VariableBase, build_segment, segment_to_variable
|
|
from dify_graph.variables.consts import SELECTORS_LENGTH
|
|
from dify_graph.variables.segments import FileSegment, ObjectSegment
|
|
from dify_graph.variables.variables import Variable
|
|
|
|
# Union of the raw (non-Segment) Python value types; these are the same types that
# `VariablePool.add` documents as convertible to a Segment.
# NOTE(review): not referenced elsewhere in the visible part of this file.
VariableValue = Union[str, int, float, dict[str, object], list[object], File]
|
|
|
|
# Matches template placeholders of the form `{{#first.second[.more...]#}}`.
# - The first path component is 1-50 characters from [a-zA-Z0-9_].
# - It is followed by 1-10 dot-separated components, each starting with a letter or
#   underscore and at most 30 characters long.
# The single capture group holds the dotted path without the `{{#` / `#}}` delimiters.
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
|
|
|
|
|
|
class VariablePool(BaseModel):
    # Pool of variables addressed by selectors of the form [node_id, variable_name, ...].
    #
    # `variable_dictionary` is a two-level mapping used to look variables up by selector:
    # the first element of the selector (the node id) is the first-level key and the
    # second element (the variable name) is the second-level key. Any further selector
    # elements are not part of the storage key; `get` uses them to navigate into the
    # stored value.
    variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field(
        description="Variables mapping",
        default=defaultdict(dict),
    )

    def add(self, selector: Sequence[str], value: Any, /) -> None:
        """
        Add a variable to the variable pool.

        The value is stored under ``[node_id, variable_name]``. Values that are not
        already variables are converted first: a Segment is wrapped with
        ``segment_to_variable``; anything else goes through ``build_segment`` and is
        then wrapped.

        Args:
            selector: A sequence of exactly ``SELECTORS_LENGTH`` elements,
                ``[node_id, variable_name]``.
            value: The value to store. Can be a Variable, a Segment, or any value
                accepted by ``build_segment``.

        Raises:
            ValueError: If the selector does not have exactly ``SELECTORS_LENGTH`` elements.

        Note:
            While non-Segment values are currently accepted and automatically
            converted, it's recommended to pass Segment or Variable objects directly.
        """
        if len(selector) != SELECTORS_LENGTH:
            raise ValueError(
                f"Invalid selector: expected {SELECTORS_LENGTH} elements (node_id, variable_name), "
                f"got {len(selector)} elements"
            )

        # The VariableBase check must come first so existing variables are stored as-is.
        if isinstance(value, VariableBase):
            variable = value
        else:
            segment = value if isinstance(value, Segment) else build_segment(value)
            variable = segment_to_variable(segment=segment, selector=selector)

        node_id, name = self._selector_to_keys(selector)
        # Based on the definition of `Variable`,
        # `VariableBase` instances can be safely used as `Variable` since they are compatible.
        self.variable_dictionary[node_id][name] = cast(Variable, variable)

    @classmethod
    def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
        """Return the ``(node_id, variable_name)`` storage keys, i.e. the first two selector elements."""
        return selector[0], selector[1]

    def _has(self, selector: Sequence[str]) -> bool:
        """Return True if a variable is stored under the first two elements of ``selector``."""
        # Too-short selectors cannot address a variable; report "absent" instead of
        # raising IndexError, consistent with `get`.
        if len(selector) < SELECTORS_LENGTH:
            return False
        node_id, name = self._selector_to_keys(selector)
        # `.get` (unlike indexing) never inserts an entry into the defaultdict.
        node_map = self.variable_dictionary.get(node_id)
        return node_map is not None and name in node_map

    def get(self, selector: Sequence[str], /) -> Segment | None:
        """
        Retrieve a variable's value from the pool as a Segment.

        Supports both plain selectors ``[node_id, variable_name]`` and extended
        selectors whose additional elements navigate into the stored value.

        Args:
            selector: A sequence with at least ``SELECTORS_LENGTH`` elements:
                - ``[node_id, variable_name]``: returns the stored segment itself.
                - ``[node_id, variable_name, attr, ...]``: for a FileSegment, ``attr``
                  (only the third element is consulted) is looked up as a
                  ``FileAttribute``; otherwise each extra element is used as a key
                  into nested dictionaries / ObjectSegments.

        Returns:
            The Segment addressed by the selector, or None when the selector is too
            short, the variable does not exist, the file attribute name is not a
            valid ``FileAttribute`` value, or a nested key is missing.
        """
        if len(selector) < SELECTORS_LENGTH:
            return None

        node_id, name = self._selector_to_keys(selector)
        node_map = self.variable_dictionary.get(node_id)
        if node_map is None:
            return None

        segment: Segment | None = node_map.get(name)
        if segment is None:
            return None

        if len(selector) == SELECTORS_LENGTH:
            return segment

        if isinstance(segment, FileSegment):
            attr = selector[2]
            # `attr in FileAttribute` is only supported from Python 3.12 on, so compare
            # against the enum values explicitly.
            if attr not in {item.value for item in FileAttribute}:
                return None
            attr_value = file_manager.get_attr(file=segment.value, attr=FileAttribute(attr))
            return build_segment(attr_value)

        # Navigate through nested attributes, unwrapping ObjectSegments at every level.
        result: Any = segment
        for attr in selector[2:]:
            result = self._extract_value(result)
            result = self._get_nested_attribute(result, attr)
            if result is None:
                return None

        # Return result as Segment
        return result if isinstance(result, Segment) else build_segment(result)

    def _extract_value(self, obj: Any):
        """Return ``obj.value`` when ``obj`` is an ObjectSegment, otherwise ``obj`` unchanged."""
        return obj.value if isinstance(obj, ObjectSegment) else obj

    def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Segment | None:
        """
        Look up ``attr`` in a dictionary and wrap the value as a Segment.

        Args:
            obj: The object to search; anything that is not a ``dict`` yields None.
            attr: The key to look up.

        Returns:
            The Segment built from ``obj[attr]`` if ``obj`` is a dict containing the
            key, otherwise None.
        """
        if not isinstance(obj, dict) or attr not in obj:
            return None
        return build_segment(obj[attr])

    def remove(self, selector: Sequence[str], /) -> None:
        """
        Remove variables from the variable pool based on the given selector.

        Args:
            selector: An empty selector is a no-op. A single-element selector
                ``[node_id]`` resets that node to an empty mapping. Otherwise the
                variable stored under the first two elements is removed if present;
                any further elements are ignored.
        """
        if not selector:
            return
        if len(selector) == 1:
            self.variable_dictionary[selector[0]] = {}
            return

        node_id, name = self._selector_to_keys(selector)
        # Use `.get` rather than indexing: indexing the defaultdict would insert an
        # empty entry for a node that was never in the pool.
        node_map = self.variable_dictionary.get(node_id)
        if node_map is not None:
            node_map.pop(name, None)

    def convert_template(self, template: str, /) -> SegmentGroup:
        """
        Split ``template`` on ``VARIABLE_PATTERN`` placeholders and build a SegmentGroup.

        Each placeholder whose selector resolves via ``get`` contributes the resolved
        segment. Literal text, and placeholders that do not resolve (emitted as their
        bare dotted path), go through ``build_segment``. Empty pieces are skipped.
        """
        # The pattern has exactly one capture group, so `split` alternates:
        # even indices are literal text, odd indices are captured selector paths.
        parts = VARIABLE_PATTERN.split(template)
        segments: list[Segment] = []
        for index, part in enumerate(parts):
            if not part:
                continue
            # Only captured placeholders are resolved; literal text that merely looks
            # like `node.var` must stay literal.
            if index % 2 == 1:
                variable = self.get(part.split("."))
                if variable is not None:
                    segments.append(variable)
                    continue
            segments.append(build_segment(part))
        return SegmentGroup(value=segments)

    def get_file(self, selector: Sequence[str], /) -> FileSegment | None:
        """Return the segment for ``selector`` if it is a FileSegment, otherwise None."""
        segment = self.get(selector)
        return segment if isinstance(segment, FileSegment) else None

    def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]:
        """Return a copy of all variables stored under the given node prefix."""
        nodes = self.variable_dictionary.get(prefix)
        if not nodes:
            return {}
        # Deep-copy so callers cannot mutate values held by the pool.
        return {key: deepcopy(variable.value) for key, variable in nodes.items()}

    def flatten(self, *, unprefixed_node_id: str | None = None) -> Mapping[str, object]:
        """
        Return a selector-style snapshot of the entire variable pool.

        Keys are ``"{node_id}.{name}"``, except for variables of ``unprefixed_node_id``
        which are keyed by their bare name. Values are deep copies.
        """
        return {
            (name if node_id == unprefixed_node_id else f"{node_id}.{name}"): deepcopy(variable.value)
            for node_id, variables in self.variable_dictionary.items()
            for name, variable in variables.items()
        }

    @classmethod
    def empty(cls) -> VariablePool:
        """Create an empty variable pool."""
        return cls()
|