237 lines
8.0 KiB
Python
237 lines
8.0 KiB
Python
from __future__ import annotations
|
|
|
|
from collections import defaultdict
|
|
from collections.abc import Mapping, Sequence
|
|
from enum import StrEnum
|
|
from typing import Any, Protocol, cast
|
|
from uuid import uuid4
|
|
|
|
from dify_graph.enums import BuiltinNodeTypes
|
|
from dify_graph.variables import build_segment, segment_to_variable
|
|
from dify_graph.variables.segments import Segment
|
|
from dify_graph.variables.variables import RAGPipelineVariableInput, Variable
|
|
|
|
from .variable_prefixes import (
|
|
CONVERSATION_VARIABLE_NODE_ID,
|
|
ENVIRONMENT_VARIABLE_NODE_ID,
|
|
RAG_PIPELINE_VARIABLE_NODE_ID,
|
|
SYSTEM_VARIABLE_NODE_ID,
|
|
)
|
|
|
|
|
|
class SystemVariableKey(StrEnum):
|
|
QUERY = "query"
|
|
FILES = "files"
|
|
CONVERSATION_ID = "conversation_id"
|
|
USER_ID = "user_id"
|
|
DIALOGUE_COUNT = "dialogue_count"
|
|
APP_ID = "app_id"
|
|
WORKFLOW_ID = "workflow_id"
|
|
WORKFLOW_EXECUTION_ID = "workflow_run_id"
|
|
TIMESTAMP = "timestamp"
|
|
DOCUMENT_ID = "document_id"
|
|
ORIGINAL_DOCUMENT_ID = "original_document_id"
|
|
BATCH = "batch"
|
|
DATASET_ID = "dataset_id"
|
|
DATASOURCE_TYPE = "datasource_type"
|
|
DATASOURCE_INFO = "datasource_info"
|
|
INVOKE_FROM = "invoke_from"
|
|
|
|
|
|
class _VariablePoolReader(Protocol):
|
|
def get(self, selector: Sequence[str], /) -> Segment | None: ...
|
|
|
|
def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]: ...
|
|
|
|
|
|
class _VariablePoolWriter(_VariablePoolReader, Protocol):
|
|
def add(self, selector: Sequence[str], value: object, /) -> None: ...
|
|
|
|
|
|
class _VariableLoader(Protocol):
|
|
def load_variables(self, selectors: list[list[str]]) -> Sequence[object]: ...
|
|
|
|
|
|
def system_variable_name(key: str | SystemVariableKey) -> str:
|
|
return key.value if isinstance(key, SystemVariableKey) else key
|
|
|
|
|
|
def system_variable_selector(key: str | SystemVariableKey) -> tuple[str, str]:
|
|
return SYSTEM_VARIABLE_NODE_ID, system_variable_name(key)
|
|
|
|
|
|
def _normalize_system_variable_values(values: Mapping[str, Any] | None = None, /, **kwargs: Any) -> dict[str, Any]:
|
|
raw_values = dict(values or {})
|
|
raw_values.update(kwargs)
|
|
|
|
workflow_execution_id = raw_values.pop("workflow_execution_id", None)
|
|
if workflow_execution_id is not None and SystemVariableKey.WORKFLOW_EXECUTION_ID.value not in raw_values:
|
|
raw_values[SystemVariableKey.WORKFLOW_EXECUTION_ID.value] = workflow_execution_id
|
|
|
|
normalized: dict[str, Any] = {}
|
|
for key, value in raw_values.items():
|
|
if value is None:
|
|
continue
|
|
normalized[system_variable_name(key)] = value
|
|
|
|
normalized.setdefault(SystemVariableKey.FILES.value, [])
|
|
return normalized
|
|
|
|
|
|
def build_system_variables(values: Mapping[str, Any] | None = None, /, **kwargs: Any) -> list[Variable]:
|
|
normalized = _normalize_system_variable_values(values, **kwargs)
|
|
|
|
return [
|
|
cast(
|
|
Variable,
|
|
segment_to_variable(
|
|
segment=build_segment(value),
|
|
selector=system_variable_selector(key),
|
|
name=key,
|
|
),
|
|
)
|
|
for key, value in normalized.items()
|
|
]
|
|
|
|
|
|
def default_system_variables() -> list[Variable]:
|
|
return build_system_variables(workflow_run_id=str(uuid4()))
|
|
|
|
|
|
def system_variables_to_mapping(system_variables: Sequence[Variable]) -> dict[str, Any]:
|
|
return {variable.name: variable.value for variable in system_variables}
|
|
|
|
|
|
def _with_selector(variable: Variable, node_id: str) -> Variable:
|
|
selector = [node_id, variable.name]
|
|
if list(variable.selector) == selector:
|
|
return variable
|
|
return variable.model_copy(update={"selector": selector})
|
|
|
|
|
|
def build_bootstrap_variables(
|
|
*,
|
|
system_variables: Sequence[Variable] = (),
|
|
environment_variables: Sequence[Variable] = (),
|
|
conversation_variables: Sequence[Variable] = (),
|
|
rag_pipeline_variables: Sequence[RAGPipelineVariableInput] = (),
|
|
) -> list[Variable]:
|
|
variables = [
|
|
*(_with_selector(variable, SYSTEM_VARIABLE_NODE_ID) for variable in system_variables),
|
|
*(_with_selector(variable, ENVIRONMENT_VARIABLE_NODE_ID) for variable in environment_variables),
|
|
*(_with_selector(variable, CONVERSATION_VARIABLE_NODE_ID) for variable in conversation_variables),
|
|
]
|
|
|
|
rag_pipeline_variables_map: defaultdict[str, dict[str, Any]] = defaultdict(dict)
|
|
for rag_var in rag_pipeline_variables:
|
|
node_id = rag_var.variable.belong_to_node_id
|
|
key = rag_var.variable.variable
|
|
rag_pipeline_variables_map[node_id][key] = rag_var.value
|
|
|
|
for node_id, value in rag_pipeline_variables_map.items():
|
|
variables.append(
|
|
cast(
|
|
Variable,
|
|
segment_to_variable(
|
|
segment=build_segment(value),
|
|
selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id),
|
|
name=node_id,
|
|
),
|
|
)
|
|
)
|
|
|
|
return variables
|
|
|
|
|
|
def get_system_segment(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> Segment | None:
|
|
return variable_pool.get(system_variable_selector(key))
|
|
|
|
|
|
def get_system_value(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> Any:
|
|
segment = get_system_segment(variable_pool, key)
|
|
return None if segment is None else segment.value
|
|
|
|
|
|
def get_system_text(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> str | None:
|
|
segment = get_system_segment(variable_pool, key)
|
|
if segment is None:
|
|
return None
|
|
text = getattr(segment, "text", None)
|
|
return text if isinstance(text, str) else None
|
|
|
|
|
|
def get_all_system_variables(variable_pool: _VariablePoolReader) -> Mapping[str, object]:
|
|
return variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)
|
|
|
|
|
|
_MEMORY_BOOTSTRAP_NODE_TYPES = frozenset(
|
|
(
|
|
BuiltinNodeTypes.LLM,
|
|
BuiltinNodeTypes.QUESTION_CLASSIFIER,
|
|
BuiltinNodeTypes.PARAMETER_EXTRACTOR,
|
|
)
|
|
)
|
|
|
|
|
|
def get_node_creation_preload_selectors(
|
|
*,
|
|
node_type: str,
|
|
node_data: object,
|
|
) -> tuple[tuple[str, str], ...]:
|
|
"""Return selectors that must exist before node construction begins."""
|
|
|
|
if node_type not in _MEMORY_BOOTSTRAP_NODE_TYPES or getattr(node_data, "memory", None) is None:
|
|
return ()
|
|
|
|
return (system_variable_selector(SystemVariableKey.CONVERSATION_ID),)
|
|
|
|
|
|
def preload_node_creation_variables(
|
|
*,
|
|
variable_loader: _VariableLoader,
|
|
variable_pool: _VariablePoolWriter,
|
|
selectors: Sequence[Sequence[str]],
|
|
) -> None:
|
|
"""Load constructor-time variables before node or graph creation."""
|
|
|
|
seen_selectors: set[tuple[str, ...]] = set()
|
|
selectors_to_load: list[list[str]] = []
|
|
for selector in selectors:
|
|
normalized_selector = tuple(selector)
|
|
if len(normalized_selector) < 2:
|
|
raise ValueError(f"Invalid preload selector: {selector}")
|
|
if normalized_selector in seen_selectors:
|
|
continue
|
|
seen_selectors.add(normalized_selector)
|
|
if variable_pool.get(normalized_selector) is None:
|
|
selectors_to_load.append(list(normalized_selector))
|
|
|
|
loaded_variables = variable_loader.load_variables(selectors_to_load)
|
|
for variable in loaded_variables:
|
|
raw_selector = getattr(variable, "selector", ())
|
|
loaded_selector = list(raw_selector)
|
|
if len(loaded_selector) < 2:
|
|
raise ValueError(f"Invalid loaded variable selector: {raw_selector}")
|
|
variable_pool.add(loaded_selector[:2], variable)
|
|
|
|
|
|
def inject_default_system_variable_mappings(
|
|
*,
|
|
node_id: str,
|
|
node_type: str,
|
|
node_data: object,
|
|
variable_mapping: Mapping[str, Sequence[str]],
|
|
) -> Mapping[str, Sequence[str]]:
|
|
"""Add workflow-owned implicit sys mappings that `dify_graph` should not know about."""
|
|
|
|
if node_type != BuiltinNodeTypes.LLM or getattr(node_data, "memory", None) is None:
|
|
return variable_mapping
|
|
|
|
query_mapping_key = f"{node_id}.#sys.query#"
|
|
if query_mapping_key in variable_mapping:
|
|
return variable_mapping
|
|
|
|
augmented_mapping = dict(variable_mapping)
|
|
augmented_mapping[query_mapping_key] = system_variable_selector(SystemVariableKey.QUERY)
|
|
return augmented_mapping
|