# dify/api/dify_graph/nodes/llm/node.py

from __future__ import annotations
import base64
import io
import json
import logging
import mimetypes
import os
import re
import time
from collections.abc import Generator, Mapping, Sequence
from functools import reduce
from pathlib import PurePosixPath
from typing import TYPE_CHECKING, Any, Literal, cast
from sqlalchemy import select
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.file_ref import (
adapt_schema_for_sandbox_file_paths,
convert_sandbox_file_paths_in_output,
detect_file_path_fields,
)
from core.llm_generator.output_parser.structured_output import (
invoke_llm_with_structured_output,
)
from core.model_manager import ModelInstance, ModelManager
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.tools.signature import sign_tool_file, sign_upload_file
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.tool_entities import ToolCallResult
from dify_graph.enums import (
BuiltinNodeTypes,
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.file import File, FileTransferMethod, FileType, file_manager
from dify_graph.model_runtime.entities import (
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
TextPromptMessageContent,
)
from dify_graph.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
LLMStructuredOutput,
LLMUsage,
)
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageContentUnionTypes,
PromptMessageRole,
SystemPromptMessage,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import (
ModelFeature,
ModelPropertyKey,
ModelType,
)
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import (
AgentLogEvent,
ModelInvokeCompletedEvent,
NodeEventBase,
NodeRunResult,
RunRetrieverResourceEvent,
StreamChunkEvent,
StreamCompletedEvent,
ThoughtChunkEvent,
ToolCallChunkEvent,
ToolResultChunkEvent,
)
from dify_graph.node_events.node import ChunkType, ThoughtEndChunkEvent, ThoughtStartChunkEvent
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.runtime import VariablePool
from dify_graph.variables import (
ArrayFileSegment,
ArrayPromptMessageSegment,
ArraySegment,
FileSegment,
NoneSegment,
ObjectSegment,
StringSegment,
)
from extensions.ext_database import db
from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile
from . import llm_utils
from .entities import (
AgentContext,
AggregatedResult,
LLMGenerationData,
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
LLMTraceSegment,
ModelConfig,
ModelTraceSegment,
PromptMessageContext,
StreamBuffers,
ThinkTagStreamParser,
ToolLogPayload,
ToolOutputState,
ToolTraceSegment,
TraceState,
)
from .exc import (
InvalidContextStructureError,
InvalidVariableTypeError,
LLMNodeError,
MemoryRolePrefixRequiredError,
ModelNotExistError,
NoPromptFoundError,
TemplateTypeNotSupportError,
VariableNotFoundError,
)
from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from core.agent.entities import AgentLog, AgentResult
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.base import BaseMemory
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.sandbox import Sandbox
from core.skill.entities.skill_bundle import SkillBundle
from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
from core.tools.__base.tool import Tool
from dify_graph.file.models import File
from dify_graph.runtime import GraphRuntimeState
logger = logging.getLogger(__name__)
class LLMNode(Node[LLMNodeData]):
node_type = BuiltinNodeTypes.LLM
# Compiled regex for extracting <think> blocks (with compatibility for attributes)
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
_llm_file_saver: LLMFileSaver
def __init__(
    self,
    id: str,
    config: NodeConfigDict,
    graph_init_params: GraphInitParams,
    graph_runtime_state: GraphRuntimeState,
    *,
    http_client: HttpClientProtocol,
    credentials_provider: object | None = None,
    model_factory: object | None = None,
    model_instance: object | None = None,
    template_renderer: object | None = None,
    memory: object | None = None,
    llm_file_saver: LLMFileSaver | None = None,
):
    """Initialize the LLM node.

    When *llm_file_saver* is not supplied, a default ``FileSaverImpl`` is
    constructed from the node's Dify context (user id / tenant id) and the
    given *http_client*.
    """
    super().__init__(
        id=id,
        config=config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    # Files produced during generation (multimodal outputs, sandbox files).
    self._file_outputs: list[File] = []
    saver = llm_file_saver
    if saver is None:
        ctx = self.require_dify_context()
        saver = FileSaverImpl(
            user_id=ctx.user_id,
            tenant_id=ctx.tenant_id,
            http_client=http_client,
        )
    self._llm_file_saver = saver
@classmethod
def version(cls) -> str:
    """Return the implementation version string of this node type."""
    return "1"
def _run(self) -> Generator:
    """Execute the LLM node.

    Orchestrates the full node lifecycle: resolve inputs/files/context from
    the variable pool, build prompt messages, dispatch the model invocation
    (sandbox / tool-calling / plain), stream chunk events to the caller, and
    finally yield a ``StreamCompletedEvent`` carrying the node run result.
    All errors are converted into a FAILED ``StreamCompletedEvent`` rather
    than propagated.
    """
    from core.sandbox.bash.session import MAX_OUTPUT_FILES
    node_inputs: dict[str, Any] = {}
    process_data: dict[str, Any] = {}
    clean_text = ""
    usage = LLMUsage.empty_usage()
    finish_reason = None
    reasoning_content = ""  # Initialize as empty string for consistency
    # NOTE(review): duplicate of the `clean_text = ""` initialization above.
    clean_text = ""  # Initialize clean_text to avoid UnboundLocalError
    variable_pool = self.graph_runtime_state.variable_pool
    try:
        # Parse prompt template to separate static messages and context references
        prompt_template = self.node_data.prompt_template
        static_messages, context_refs, template_order = self._parse_prompt_template()
        # fetch variables and fetch values from variable pool
        inputs = self._fetch_inputs(node_data=self.node_data)
        # fetch jinja2 inputs
        jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)
        # merge inputs (jinja2 values take precedence on key collision)
        inputs.update(jinja_inputs)
        # fetch files (vision attachments), only when vision is enabled
        files = (
            llm_utils.fetch_files(
                variable_pool=variable_pool,
                selector=self.node_data.vision.configs.variable_selector,
            )
            if self.node_data.vision.enabled
            else []
        )
        if files:
            node_inputs["#files#"] = [file.to_dict() for file in files]
        # fetch context value; _fetch_context yields retriever-resource events
        # that are forwarded to the caller while capturing the last context.
        generator = self._fetch_context(node_data=self.node_data)
        context = None
        context_files: list[File] = []
        for event in generator:
            context = event.context
            context_files = event.context_files or []
            yield event
        if context:
            node_inputs["#context#"] = context
        if context_files:
            node_inputs["#context_files#"] = [file.model_dump() for file in context_files]
        # fetch model config
        model_instance, model_config = LLMNode._fetch_model_config(
            node_data_model=self.node_data.model,
            tenant_id=self.tenant_id,
        )
        # fetch memory
        memory = llm_utils.fetch_memory(
            variable_pool=variable_pool,
            app_id=self.app_id,
            tenant_id=self.tenant_id,
            node_data_memory=self.node_data.memory,
            model_instance=model_instance,
            node_id=self._node_id,
        )
        # Resolve the user query: explicit memory query template first,
        # falling back to the system query variable.
        query: str | None = None
        if self.node_data.memory:
            query = self.node_data.memory.query_prompt_template
            if not query and (
                query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
            ):
                query = query_variable.text
        # Get prompt messages
        prompt_messages: Sequence[PromptMessage]
        stop: Sequence[str] | None
        if isinstance(prompt_template, list) and context_refs:
            # Template interleaves static messages with array[message] refs.
            prompt_messages, stop = self._build_prompt_messages_with_context(
                context_refs=context_refs,
                template_order=template_order,
                static_messages=static_messages,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config,
                context_files=context_files,
            )
        else:
            prompt_messages, stop = LLMNode.fetch_prompt_messages(
                sys_query=query,
                sys_files=files,
                context=context,
                memory=memory,
                model_config=model_config,
                prompt_template=cast(
                    Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
                    self.node_data.prompt_template,
                ),
                memory_config=self.node_data.memory,
                vision_enabled=self.node_data.vision.enabled,
                vision_detail=self.node_data.vision.configs.detail,
                variable_pool=variable_pool,
                jinja2_variables=self.node_data.prompt_config.jinja2_variables,
                tenant_id=self.tenant_id,
                context_files=context_files,
                sandbox=self.graph_runtime_state.sandbox,
            )
        # Variables for outputs
        generation_data: LLMGenerationData | None = None
        structured_output: LLMStructuredOutput | None = None
        structured_output_schema: Mapping[str, Any] | None
        structured_output_file_paths: list[str] = []
        if self.node_data.structured_output_enabled:
            if not self.node_data.structured_output:
                raise LLMNodeError("structured_output_enabled is True but structured_output is not set")
            raw_schema = LLMNode.fetch_structured_output_schema(structured_output=self.node_data.structured_output)
            if self.node_data.computer_use:
                raise LLMNodeError("Structured output is not supported in computer use mode.")
            else:
                if detect_file_path_fields(raw_schema):
                    # Schema references sandbox file paths; rewrite it so the
                    # model emits plain strings that are resolved to Files later.
                    sandbox = self.graph_runtime_state.sandbox
                    if not sandbox:
                        raise LLMNodeError("Structured output file paths are only supported in sandbox mode.")
                    structured_output_schema, structured_output_file_paths = adapt_schema_for_sandbox_file_paths(
                        raw_schema
                    )
                else:
                    structured_output_schema = raw_schema
        else:
            structured_output_schema = None
        # Dispatch: sandbox (computer use) > tool calling > plain invocation.
        if self.node_data.computer_use:
            sandbox = self.graph_runtime_state.sandbox
            if not sandbox:
                raise LLMNodeError("computer use is enabled but no sandbox found")
            tool_dependencies: ToolDependencies | None = self._extract_tool_dependencies()
            generator = self._invoke_llm_with_sandbox(
                sandbox=sandbox,
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                stop=stop,
                variable_pool=variable_pool,
                tool_dependencies=tool_dependencies,
            )
        elif self.tool_call_enabled:
            generator = self._invoke_llm_with_tools(
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                stop=stop,
                files=files,
                variable_pool=variable_pool,
                node_inputs=node_inputs,
                process_data=process_data,
            )
        else:
            # Use traditional LLM invocation
            generator = LLMNode.invoke_llm(
                node_data_model=self._node_data.model,
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                stop=stop,
                user_id=self.user_id,
                structured_output_schema=structured_output_schema,
                allow_file_path=bool(structured_output_file_paths),
                file_saver=self._llm_file_saver,
                file_outputs=self._file_outputs,
                node_id=self._node_id,
                node_type=self.node_type,
                reasoning_format=self._node_data.reasoning_format,
            )
        # Drain the invocation generator, re-yielding its events; the final
        # aggregated values come back through the generator's return value.
        (
            clean_text,
            reasoning_content,
            generation_reasoning_content,
            generation_clean_content,
            usage,
            finish_reason,
            structured_output,
            generation_data,
        ) = yield from self._stream_llm_events(generator, model_instance=model_instance)
        if structured_output and structured_output_file_paths:
            # Convert sandbox path strings in the structured output into Files.
            sandbox = self.graph_runtime_state.sandbox
            if not sandbox:
                raise LLMNodeError("Structured output file paths are only supported in sandbox mode.")
            structured_output_value = structured_output.structured_output
            if structured_output_value is None:
                raise LLMNodeError("Structured output is empty")
            resolved_count = 0
            def resolve_file(path: str) -> File:
                # Cap the number of files resolved from structured output.
                nonlocal resolved_count
                if resolved_count >= MAX_OUTPUT_FILES:
                    raise LLMNodeError("Structured output files exceed the sandbox output limit")
                resolved_count += 1
                return self._resolve_sandbox_file_path(sandbox=sandbox, path=path)
            converted_output, structured_output_files = convert_sandbox_file_paths_in_output(
                output=structured_output_value,
                file_path_fields=structured_output_file_paths,
                file_resolver=resolve_file,
            )
            structured_output = LLMStructuredOutput(structured_output=converted_output)
            if structured_output_files:
                self._file_outputs.extend(structured_output_files)
        # Extract variables from generation_data if available
        if generation_data:
            clean_text = generation_data.text
            reasoning_content = ""
            usage = generation_data.usage
            finish_reason = generation_data.finish_reason
        # Unified process_data building
        process_data = {
            "model_mode": model_config.mode,
            "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
                model_mode=model_config.mode, prompt_messages=prompt_messages
            ),
            "usage": jsonable_encoder(usage),
            "finish_reason": finish_reason,
            "model_provider": model_config.provider,
            "model_name": model_config.model,
        }
        if self.tool_call_enabled and self._node_data.tools:
            process_data["tools"] = [
                {
                    "type": tool.type.value if hasattr(tool.type, "value") else tool.type,
                    "provider_name": tool.provider_name,
                    "tool_name": tool.tool_name,
                }
                for tool in self._node_data.tools
                if tool.enabled
            ]
        is_sandbox = self.graph_runtime_state.sandbox is not None
        outputs = self._build_outputs(
            is_sandbox=is_sandbox,
            clean_text=clean_text,
            reasoning_content=reasoning_content,
            generation_reasoning_content=generation_reasoning_content,
            generation_clean_content=generation_clean_content,
            usage=usage,
            finish_reason=finish_reason,
            prompt_messages=prompt_messages,
            generation_data=generation_data,
            structured_output=structured_output,
        )
        # Send final chunk event to indicate streaming is complete
        # For tool calls and sandbox, final events are already sent in _process_tool_outputs
        if not self.tool_call_enabled and not self.node_data.computer_use:
            yield StreamChunkEvent(
                selector=[self._node_id, "text"],
                chunk="",
                is_final=True,
            )
            yield StreamChunkEvent(
                selector=[self._node_id, "generation", "content"],
                chunk="",
                is_final=True,
            )
            yield ThoughtChunkEvent(
                selector=[self._node_id, "generation", "thought"],
                chunk="",
                is_final=True,
            )
        metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
            WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
            WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
            WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
        }
        if generation_data and generation_data.trace:
            metadata[WorkflowNodeExecutionMetadataKey.LLM_TRACE] = [
                segment.model_dump() for segment in generation_data.trace
            ]
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=node_inputs,
                process_data=process_data,
                outputs=outputs,
                metadata=metadata,
                llm_usage=usage,
            )
        )
    except ValueError as e:
        # Expected validation failures: report without a stack trace.
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                inputs=node_inputs,
                process_data=process_data,
                error_type=type(e).__name__,
                llm_usage=usage,
            )
        )
    except Exception as e:
        # Unexpected failure: log with traceback, then fail the node.
        logger.exception("error while executing llm node")
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                inputs=node_inputs,
                process_data=process_data,
                error_type=type(e).__name__,
                llm_usage=usage,
            )
        )
def _build_outputs(
    self,
    *,
    is_sandbox: bool,
    clean_text: str,
    reasoning_content: str,
    generation_reasoning_content: str,
    generation_clean_content: str,
    usage: LLMUsage,
    finish_reason: str | None,
    prompt_messages: Sequence[PromptMessage],
    generation_data: LLMGenerationData | None,
    structured_output: LLMStructuredOutput | None,
) -> dict[str, Any]:
    """Assemble the node's output mapping.

    The output shape depends on the runtime mode:
    - **Classical** (is_sandbox=False): top-level ``text`` and
      ``reasoning_content`` are included for backward compatibility.
    - **Sandbox** (is_sandbox=True): those top-level keys are omitted since
      they duplicate fields inside ``generation``.
    The ``generation`` key always holds the structured representation
    (content, reasoning, tool_calls, sequence) in either mode.
    """
    # Fields common to both runtime modes.
    result: dict[str, Any] = {
        "usage": jsonable_encoder(usage),
        "finish_reason": finish_reason,
        "context": llm_utils.build_context(prompt_messages, clean_text, generation_data),
    }
    if not is_sandbox:
        # Legacy top-level fields, kept so existing consumers don't break.
        result["text"] = clean_text
        result["reasoning_content"] = reasoning_content
    if generation_data:
        result["generation"] = {
            "content": generation_data.text,
            "reasoning_content": generation_data.reasoning_contents,
            "tool_calls": [self._serialize_tool_call(tc) for tc in generation_data.tool_calls],
            "sequence": generation_data.sequence,
        }
        # Merge node-level file outputs in, de-duplicating by file id.
        files_to_output = list(generation_data.files)
        if self._file_outputs:
            seen_ids = {f.id for f in files_to_output}
            files_to_output.extend(f for f in self._file_outputs if f.id not in seen_ids)
    else:
        # Synthesize a generation record from the flat text/reasoning values.
        reasoning = generation_reasoning_content or reasoning_content
        content = generation_clean_content or clean_text
        ordering: list[dict[str, Any]] = []
        if reasoning:
            ordering = [
                {"type": "reasoning", "index": 0},
                {"type": "content", "start": 0, "end": len(content)},
            ]
        result["generation"] = {
            "content": content,
            "reasoning_content": [reasoning] if reasoning else [],
            "tool_calls": [],
            "sequence": ordering,
        }
        files_to_output = self._file_outputs
    if files_to_output:
        result["files"] = ArrayFileSegment(value=files_to_output)
    if structured_output:
        result["structured_output"] = structured_output.structured_output
    return result
@staticmethod
def invoke_llm(
    *,
    node_data_model: ModelConfig,
    model_instance: ModelInstance,
    prompt_messages: Sequence[PromptMessage],
    stop: Sequence[str] | None = None,
    user_id: str,
    structured_output_schema: Mapping[str, Any] | None,
    allow_file_path: bool = False,
    file_saver: LLMFileSaver,
    file_outputs: list[File],
    node_id: str,
    node_type: NodeType,
    reasoning_format: Literal["separated", "tagged"] = "tagged",
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
    """Invoke the model (structured-output or plain streaming) and delegate
    result handling to :meth:`handle_invoke_result`.

    :param structured_output_schema: JSON schema enabling structured output;
        when None a plain streaming invocation is used.
    :raises ValueError: if the model schema cannot be resolved.
    """
    model_schema = model_instance.model_type_instance.get_model_schema(
        node_data_model.name, model_instance.credentials
    )
    if not model_schema:
        raise ValueError(f"Model schema not found for {node_data_model.name}")
    invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None]
    # Start the latency clock once, immediately before dispatching the request
    # (previously this assignment was duplicated in both branches).
    request_start_time = time.perf_counter()
    if structured_output_schema:
        invoke_result = invoke_llm_with_structured_output(
            provider=model_instance.provider,
            model_schema=model_schema,
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            json_schema=structured_output_schema,
            model_parameters=node_data_model.completion_params,
            stop=list(stop or []),
            user=user_id,
            allow_file_path=allow_file_path,
        )
    else:
        invoke_result = model_instance.invoke_llm(
            prompt_messages=list(prompt_messages),
            model_parameters=node_data_model.completion_params,
            stop=list(stop or []),
            stream=True,
            user=user_id,
        )
    return LLMNode.handle_invoke_result(
        invoke_result=invoke_result,
        file_saver=file_saver,
        file_outputs=file_outputs,
        node_id=node_id,
        node_type=node_type,
        reasoning_format=reasoning_format,
        request_start_time=request_start_time,
    )
@staticmethod
def handle_invoke_result(
    *,
    invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
    file_saver: LLMFileSaver,
    file_outputs: list[File],
    node_id: str,
    node_type: NodeType,
    reasoning_format: Literal["separated", "tagged"] = "tagged",
    request_start_time: float | None = None,
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
    """Translate a model invocation result into node stream events.

    Handles both blocking (``LLMResult``) and streaming (generator) results.
    For streaming, text chunks are forwarded as ``StreamChunkEvent``s while a
    ``ThinkTagStreamParser`` splits out <think> segments into thought events.
    Ends by yielding a ``ModelInvokeCompletedEvent`` with aggregated text,
    usage (including latency metrics) and any collected structured output.

    :raises LLMNodeError: if structured-output parsing fails mid-stream.
    """
    # For blocking mode
    if isinstance(invoke_result, LLMResult):
        duration = None
        if request_start_time is not None:
            duration = time.perf_counter() - request_start_time
            invoke_result.usage.latency = round(duration, 3)
        event = LLMNode.handle_blocking_result(
            invoke_result=invoke_result,
            saver=file_saver,
            file_outputs=file_outputs,
            reasoning_format=reasoning_format,
            request_latency=duration,
        )
        yield event
        return
    # For streaming mode
    model = ""
    prompt_messages: list[PromptMessage] = []
    usage = LLMUsage.empty_usage()
    finish_reason = None
    full_text_buffer = io.StringIO()
    think_parser = ThinkTagStreamParser()
    reasoning_chunks: list[str] = []
    # Initialize streaming metrics tracking
    start_time = request_start_time if request_start_time is not None else time.perf_counter()
    first_token_time = None
    has_content = False
    collected_structured_output = None
    try:
        for result in invoke_result:
            if isinstance(result, LLMResultChunkWithStructuredOutput):
                # Keep the latest structured output and forward the chunk as-is.
                if result.structured_output is not None:
                    collected_structured_output = dict(result.structured_output)
                yield result
            if isinstance(result, LLMResultChunk):
                contents = result.delta.message.content
                for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
                    contents=contents,
                    file_saver=file_saver,
                    file_outputs=file_outputs,
                ):
                    # Record time-to-first-token on the first non-empty part.
                    if text_part and not has_content:
                        first_token_time = time.perf_counter()
                        has_content = True
                    full_text_buffer.write(text_part)
                    yield StreamChunkEvent(
                        selector=[node_id, "text"],
                        chunk=text_part,
                        is_final=False,
                    )
                    # Route the same text through the <think> parser to split
                    # thought segments from visible content.
                    for kind, segment in think_parser.process(text_part):
                        if not segment:
                            # Empty segments are only meaningful as start/end markers.
                            if kind not in {"thought_start", "thought_end"}:
                                continue
                        if kind == "thought_start":
                            yield ThoughtStartChunkEvent(
                                selector=[node_id, "generation", "thought"],
                                chunk="",
                                is_final=False,
                            )
                        elif kind == "thought":
                            reasoning_chunks.append(segment)
                            yield ThoughtChunkEvent(
                                selector=[node_id, "generation", "thought"],
                                chunk=segment,
                                is_final=False,
                            )
                        elif kind == "thought_end":
                            yield ThoughtEndChunkEvent(
                                selector=[node_id, "generation", "thought"],
                                chunk="",
                                is_final=False,
                            )
                        else:
                            yield StreamChunkEvent(
                                selector=[node_id, "generation", "content"],
                                chunk=segment,
                                is_final=False,
                            )
                # Capture model name / prompts / usage / finish reason from the
                # first chunk that carries them.
                if not model and result.model:
                    model = result.model
                if len(prompt_messages) == 0:
                    prompt_messages = list(result.prompt_messages)
                if usage.prompt_tokens == 0 and result.delta.usage:
                    usage = result.delta.usage
                if finish_reason is None and result.delta.finish_reason:
                    finish_reason = result.delta.finish_reason
    except OutputParserError as e:
        raise LLMNodeError(f"Failed to parse structured output: {e}")
    # Flush whatever the think parser still buffers after the stream ends.
    for kind, segment in think_parser.flush():
        if not segment and kind not in {"thought_start", "thought_end"}:
            continue
        if kind == "thought_start":
            yield ThoughtStartChunkEvent(
                selector=[node_id, "generation", "thought"],
                chunk="",
                is_final=False,
            )
        elif kind == "thought":
            reasoning_chunks.append(segment)
            yield ThoughtChunkEvent(
                selector=[node_id, "generation", "thought"],
                chunk=segment,
                is_final=False,
            )
        elif kind == "thought_end":
            yield ThoughtEndChunkEvent(
                selector=[node_id, "generation", "thought"],
                chunk="",
                is_final=False,
            )
        else:
            yield StreamChunkEvent(
                selector=[node_id, "generation", "content"],
                chunk=segment,
                is_final=False,
            )
    full_text = full_text_buffer.getvalue()
    if reasoning_format == "tagged":
        # Tagged mode keeps <think> blocks inline; reasoning comes from the parser.
        clean_text = full_text
        reasoning_content = "".join(reasoning_chunks)
    else:
        clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
        if reasoning_chunks and not reasoning_content:
            reasoning_content = "".join(reasoning_chunks)
    # Calculate streaming metrics
    end_time = time.perf_counter()
    total_duration = end_time - start_time
    usage.latency = round(total_duration, 3)
    if has_content and first_token_time:
        gen_ai_server_time_to_first_token = first_token_time - start_time
        llm_streaming_time_to_generate = end_time - first_token_time
        usage.time_to_first_token = round(gen_ai_server_time_to_first_token, 3)
        usage.time_to_generate = round(llm_streaming_time_to_generate, 3)
    yield ModelInvokeCompletedEvent(
        text=clean_text if reasoning_format == "separated" else full_text,
        usage=usage,
        finish_reason=finish_reason,
        reasoning_content=reasoning_content,
        structured_output=collected_structured_output,
    )
@staticmethod
def _image_file_to_markdown(file: File, /):
    """Render *file* as an inline markdown image pointing at its generated URL."""
    return f"![]({file.generate_url()})"
@classmethod
def _split_reasoning(
    cls, text: str, reasoning_format: Literal["separated", "tagged"] = "tagged"
) -> tuple[str, str]:
    """Split ``<think>`` blocks out of *text*.

    In "tagged" mode the text is returned untouched with empty reasoning.
    In "separated" mode every ``<think>...</think>`` body is stripped from
    the visible text and the bodies are joined with newlines as reasoning.
    """
    if reasoning_format == "tagged":
        return text, ""
    thoughts = [match.strip() for match in cls._THINK_PATTERN.findall(text)]
    remainder = cls._THINK_PATTERN.sub("", text)
    # Collapse runs of blank lines left behind by removed think blocks.
    remainder = re.sub(r"\n\s*\n", "\n\n", remainder).strip()
    return remainder, "\n".join(thoughts)
def _transform_chat_messages(
    self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
    """Promote jinja2-edited text onto ``text`` for each message (in place).

    Accepts either a completion-style template or a sequence of chat
    messages; the same (mutated) object is returned.
    """
    if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
        targets: list[Any] = [messages]
    else:
        targets = list(messages)
    for entry in targets:
        if entry.edition_type == "jinja2" and entry.jinja2_text:
            entry.text = entry.jinja2_text
    return messages
def _parse_prompt_template(
    self,
) -> tuple[list[LLMNodeChatModelMessage], list[PromptMessageContext], list[tuple[int, str]]]:
    """Partition the prompt template into static messages and context refs.

    Returns ``(static_messages, context_refs, template_order)`` where
    *template_order* records, for every original position, whether that
    entry was "static" or "context". Side effect: static messages are run
    through jinja2 text promotion and written back onto ``node_data``.
    """
    template = self.node_data.prompt_template
    statics: list[LLMNodeChatModelMessage] = []
    contexts: list[PromptMessageContext] = []
    order: list[tuple[int, str]] = []
    if isinstance(template, list):
        for position, entry in enumerate(template):
            if isinstance(entry, PromptMessageContext):
                contexts.append(entry)
                order.append((position, "context"))
            else:
                statics.append(entry)
                order.append((position, "static"))
    if statics:
        self.node_data.prompt_template = self._transform_chat_messages(statics)
    return statics, contexts, order
def _build_prompt_messages_with_context(
    self,
    *,
    context_refs: list[PromptMessageContext],
    template_order: list[tuple[int, str]],
    static_messages: list[LLMNodeChatModelMessage],
    query: str | None,
    files: Sequence[File],
    context: str | None,
    memory: BaseMemory | None,
    model_config: ModelConfigWithCredentialsEntity,
    context_files: list[File],
) -> tuple[list[PromptMessage], Sequence[str] | None]:
    """Build prompt messages for templates that interleave static messages
    with array[message] context references, preserving template order.

    After the interleaved template is expanded, memory messages, the user
    query, and vision files are appended, then unsupported content is
    filtered out.

    :raises VariableNotFoundError: if a context ref selector is absent.
    :raises InvalidVariableTypeError: if a context ref is not array[message].
    :returns: the combined message list and the model's stop sequences.
    """
    variable_pool = self.graph_runtime_state.variable_pool
    combined_messages: list[PromptMessage] = []
    context_idx = 0
    static_idx = 0
    # Walk the template in its original order, consuming from the two
    # partitioned lists as dictated by each entry's recorded type.
    for _, type_ in template_order:
        if type_ == "context":
            ctx_ref = context_refs[context_idx]
            ctx_var = variable_pool.get(ctx_ref.value_selector)
            if ctx_var is None:
                raise VariableNotFoundError(f"Variable {'.'.join(ctx_ref.value_selector)} not found")
            if not isinstance(ctx_var, ArrayPromptMessageSegment):
                raise InvalidVariableTypeError(f"Variable {'.'.join(ctx_ref.value_selector)} is not array[message]")
            restored_messages = llm_utils.restore_multimodal_content_in_messages(ctx_var.value)
            combined_messages.extend(restored_messages)
            context_idx += 1
        else:
            static_msg = static_messages[static_idx]
            processed_msgs = LLMNode.handle_list_messages(
                messages=[static_msg],
                context=context,
                jinja2_variables=self.node_data.prompt_config.jinja2_variables or [],
                variable_pool=variable_pool,
                vision_detail_config=self.node_data.vision.configs.detail,
                sandbox=self.graph_runtime_state.sandbox,
            )
            combined_messages.extend(processed_msgs)
            static_idx += 1
    # Conversation memory goes after the template-derived messages.
    memory_messages = _handle_memory_chat_mode(
        memory=memory,
        memory_config=self.node_data.memory,
        model_config=model_config,
    )
    combined_messages.extend(memory_messages)
    if query:
        # Append the current user query as a basic user message.
        query_message = LLMNodeChatModelMessage(
            text=query,
            role=PromptMessageRole.USER,
            edition_type="basic",
        )
        query_msgs = LLMNode.handle_list_messages(
            messages=[query_message],
            context="",
            jinja2_variables=[],
            variable_pool=variable_pool,
            vision_detail_config=self.node_data.vision.configs.detail,
        )
        combined_messages.extend(query_msgs)
    combined_messages = self._append_files_to_messages(
        messages=combined_messages,
        sys_files=files,
        context_files=context_files,
        model_config=model_config,
    )
    # Drop content the model cannot accept and empty messages.
    combined_messages = self._filter_messages(combined_messages, model_config)
    stop = self._get_stop_sequences(model_config)
    return combined_messages, stop
def _append_files_to_messages(
    self,
    *,
    messages: list[PromptMessage],
    sys_files: Sequence[File],
    context_files: list[File],
    model_config: ModelConfigWithCredentialsEntity,
) -> list[PromptMessage]:
    """Attach vision file contents to the message list (mutates *messages*).

    For each non-empty file group — system files first, then context files,
    matching the original ordering — the file prompt contents are prepended
    to the trailing user message when it already carries list content;
    otherwise a new user message holding the files is appended. No-op when
    vision is disabled.
    """
    if not self.node_data.vision.enabled:
        return messages
    vision_detail = self.node_data.vision.configs.detail
    # The two groups were previously handled by duplicated branches; the
    # loop preserves the exact order and the prepend/append behavior.
    for file_group in (sys_files, context_files):
        if not file_group:
            continue
        file_prompts = [
            file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
            for file in file_group
        ]
        last = messages[-1] if messages else None
        if isinstance(last, UserPromptMessage) and isinstance(last.content, list):
            messages[-1] = UserPromptMessage(content=file_prompts + last.content)
        else:
            messages.append(UserPromptMessage(content=file_prompts))
    return messages
def _filter_messages(
    self, messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> list[PromptMessage]:
    """Drop message content parts the model cannot accept, then drop empty
    messages.

    :raises NoPromptFoundError: when no message survives filtering.
    """
    feature_requirements = {
        PromptMessageContentType.IMAGE: ModelFeature.VISION,
        PromptMessageContentType.DOCUMENT: ModelFeature.DOCUMENT,
        PromptMessageContentType.VIDEO: ModelFeature.VIDEO,
        PromptMessageContentType.AUDIO: ModelFeature.AUDIO,
    }
    features = model_config.model_schema.features
    kept: list[PromptMessage] = []
    for message in messages:
        if isinstance(message.content, list):
            surviving: list[PromptMessageContentUnionTypes] = []
            for part in message.content:
                if not features:
                    # With no declared features, only plain text passes.
                    if part.type == PromptMessageContentType.TEXT:
                        surviving.append(part)
                    continue
                needed = feature_requirements.get(part.type)
                if needed is not None and needed not in features:
                    continue
                surviving.append(part)
            # A lone text part collapses back to a plain string.
            if len(surviving) == 1 and surviving[0].type == PromptMessageContentType.TEXT:
                message.content = surviving[0].data
            else:
                message.content = surviving
        if not message.is_empty():
            kept.append(message)
    if not kept:
        raise NoPromptFoundError(
            "No prompt found in the LLM configuration. "
            "Please ensure a prompt is properly configured before proceeding."
        )
    return kept
def _get_stop_sequences(self, model_config: ModelConfigWithCredentialsEntity) -> Sequence[str] | None:
    """Return the stop sequences configured on the model, if any."""
    return model_config.stop
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
    """Resolve jinja2 template variables from the variable pool into strings.

    Array segments are rendered item by item (one per line); object segments
    are JSON-encoded, with retrieval-style items represented by their
    ``content`` field; everything else uses its text form.

    :raises VariableNotFoundError: if a declared variable is absent.
    """

    def stringify_mapping(data: Mapping[str, Any]) -> str:
        # Retrieval items carry their text under "content"; prefer it.
        if "metadata" in data and "_source" in data["metadata"] and "content" in data:
            return str(data["content"])
        try:
            return json.dumps(data, ensure_ascii=False)
        except Exception:
            return str(data)

    resolved: dict[str, Any] = {}
    if not node_data.prompt_config:
        return resolved
    for selector in node_data.prompt_config.jinja2_variables or []:
        segment = self.graph_runtime_state.variable_pool.get(selector.value_selector)
        if segment is None:
            raise VariableNotFoundError(f"Variable {selector.variable} not found")
        if isinstance(segment, ArraySegment):
            pieces: list[str] = []
            for entry in segment.value:
                rendered = stringify_mapping(entry) if isinstance(entry, dict) else str(entry)
                pieces.append(rendered + "\n")
            text_value = "".join(pieces).strip()
        elif isinstance(segment, ObjectSegment):
            text_value = stringify_mapping(segment.value)
        else:
            text_value = segment.text
        resolved[selector.variable] = text_value
    return resolved
def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]:
    """Collect the values of all variables referenced by the prompt template
    and, if configured, by the memory query template.

    ``NoneSegment`` values resolve to an empty string for prompt variables
    and are skipped for memory-query variables (matching prior intent).

    :raises VariableNotFoundError: if a referenced selector is absent from
        the variable pool.
    """
    inputs: dict[str, Any] = {}
    prompt_template = node_data.prompt_template
    variable_selectors = []
    if isinstance(prompt_template, list):
        for prompt in prompt_template:
            variable_template_parser = VariableTemplateParser(template=prompt.text)
            variable_selectors.extend(variable_template_parser.extract_variable_selectors())
    elif isinstance(prompt_template, CompletionModelPromptTemplate):
        variable_template_parser = VariableTemplateParser(template=prompt_template.text)
        variable_selectors = variable_template_parser.extract_variable_selectors()
    for variable_selector in variable_selectors:
        variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
        if variable is None:
            raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
        if isinstance(variable, NoneSegment):
            # Fix: previously this "" placeholder was immediately overwritten
            # because the branch did not skip the to_object() assignment below.
            inputs[variable_selector.variable] = ""
            continue
        inputs[variable_selector.variable] = variable.to_object()
    memory = node_data.memory
    if memory and memory.query_prompt_template:
        query_variable_selectors = VariableTemplateParser(
            template=memory.query_prompt_template
        ).extract_variable_selectors()
        for variable_selector in query_variable_selectors:
            variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
            if variable is None:
                raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
            if isinstance(variable, NoneSegment):
                continue
            inputs[variable_selector.variable] = variable.to_object()
    return inputs
    def _fetch_context(self, node_data: LLMNodeData):
        """Yield a RunRetrieverResourceEvent carrying the rendered context text.

        Generator that yields nothing when context is disabled, has no
        selector, or the selector resolves to a falsy value. String contexts
        pass through verbatim; array contexts are concatenated line by line,
        and knowledge-retrieval items additionally contribute citation
        metadata plus any segment attachments as context files.
        """
        # Context must be both enabled and bound to a variable selector.
        if not node_data.context.enabled:
            return
        if not node_data.context.variable_selector:
            return
        context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector)
        if context_value_variable:
            if isinstance(context_value_variable, StringSegment):
                # Plain string context: no citations, no attached files.
                yield RunRetrieverResourceEvent(
                    retriever_resources=[], context=context_value_variable.value, context_files=[]
                )
            elif isinstance(context_value_variable, ArraySegment):
                context_str = ""
                original_retriever_resource: list[RetrievalSourceMetadata] = []
                context_files: list[File] = []
                for item in context_value_variable.value:
                    if isinstance(item, str):
                        context_str += item + "\n"
                    else:
                        # Non-string items must be retrieval dicts with "content".
                        if "content" not in item:
                            raise InvalidContextStructureError(f"Invalid context structure: {item}")
                        # Prepend the summary (when present) before the content.
                        if item.get("summary"):
                            context_str += item["summary"] + "\n"
                        context_str += item["content"] + "\n"
                        retriever_resource = self._convert_to_original_retriever_resource(item)
                        if retriever_resource:
                            original_retriever_resource.append(retriever_resource)
                            # Load files attached to the retrieved segment so they
                            # can be passed to the model alongside the text.
                            attachments_with_bindings = db.session.execute(
                                select(SegmentAttachmentBinding, UploadFile)
                                .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
                                .where(
                                    SegmentAttachmentBinding.segment_id == retriever_resource.segment_id,
                                )
                            ).all()
                            if attachments_with_bindings:
                                for _, upload_file in attachments_with_bindings:
                                    # NOTE(review): attachments are always tagged
                                    # FileType.IMAGE regardless of mime type —
                                    # confirm non-image attachments cannot occur here.
                                    attachment_info = File(
                                        id=upload_file.id,
                                        filename=upload_file.name,
                                        extension="." + upload_file.extension,
                                        mime_type=upload_file.mime_type,
                                        tenant_id=self.tenant_id,
                                        type=FileType.IMAGE,
                                        transfer_method=FileTransferMethod.LOCAL_FILE,
                                        remote_url=upload_file.source_url,
                                        related_id=upload_file.id,
                                        size=upload_file.size,
                                        storage_key=upload_file.key,
                                        url=sign_upload_file(upload_file.id, upload_file.extension),
                                    )
                                    context_files.append(attachment_info)
                yield RunRetrieverResourceEvent(
                    retriever_resources=[r.model_dump() for r in original_retriever_resource],
                    context=context_str.strip(),
                    context_files=context_files,
                )
def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None:
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
if (
"metadata" in context_dict
and "_source" in context_dict["metadata"]
and context_dict["metadata"]["_source"] == "knowledge"
):
metadata = context_dict.get("metadata", {})
source = RetrievalSourceMetadata(
position=metadata.get("position"),
dataset_id=metadata.get("dataset_id"),
dataset_name=metadata.get("dataset_name"),
document_id=metadata.get("document_id"),
document_name=metadata.get("document_name"),
data_source_type=metadata.get("data_source_type"),
segment_id=metadata.get("segment_id"),
retriever_from=metadata.get("retriever_from"),
score=metadata.get("score"),
hit_count=metadata.get("segment_hit_count"),
word_count=metadata.get("segment_word_count"),
segment_position=metadata.get("segment_position"),
index_node_hash=metadata.get("segment_index_node_hash"),
content=context_dict.get("content"),
page=metadata.get("page"),
doc_metadata=metadata.get("doc_metadata"),
files=context_dict.get("files"),
summary=context_dict.get("summary"),
)
return source
return None
    @staticmethod
    def _fetch_model_config(
        *,
        node_data_model: ModelConfig,
        tenant_id: str,
    ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
        """Resolve the model instance and credentialed config for this node.

        Side effect: writes the resolved completion parameters back onto
        ``node_data_model``. Raises ModelNotExistError when the provider
        exposes no schema for the requested model.
        """
        model, model_config_with_cred = llm_utils.fetch_model_config(
            tenant_id=tenant_id, node_data_model=node_data_model
        )
        # Snapshot the resolved parameters across the schema lookup and
        # restore them afterwards — presumably defensive against
        # get_model_schema mutating the config; confirm before simplifying.
        completion_params = model_config_with_cred.parameters
        model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
        if not model_schema:
            raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
        model_config_with_cred.parameters = completion_params
        node_data_model.completion_params = completion_params
        return model, model_config_with_cred
    @staticmethod
    def fetch_prompt_messages(
        *,
        sys_query: str | None = None,
        sys_files: Sequence[File],
        context: str | None = None,
        memory: BaseMemory | None = None,
        model_config: ModelConfigWithCredentialsEntity,
        prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
        memory_config: MemoryConfig | None = None,
        vision_enabled: bool = False,
        vision_detail: ImagePromptMessageContent.DETAIL,
        variable_pool: VariablePool,
        jinja2_variables: Sequence[VariableSelector],
        tenant_id: str,
        context_files: list[File] | None = None,
        sandbox: Sandbox | None = None,
    ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
        """Assemble the final prompt messages for the model invocation.

        Supports chat-style templates (list of messages) and completion-style
        templates (single text). Injects conversation memory, the current user
        query, vision files and context files, then removes content types the
        model's feature flags do not support.

        Returns (prompt_messages, stop_sequences).
        Raises TemplateTypeNotSupportError for unknown template types,
        NoPromptFoundError when filtering leaves no usable message, and
        ModelNotExistError when the model has no schema.
        """
        prompt_messages: list[PromptMessage] = []
        if isinstance(prompt_template, list):
            # Chat-model path: render the configured message list first.
            prompt_messages.extend(
                LLMNode.handle_list_messages(
                    messages=prompt_template,
                    context=context,
                    jinja2_variables=jinja2_variables,
                    variable_pool=variable_pool,
                    vision_detail_config=vision_detail,
                    sandbox=sandbox,
                )
            )
            # Conversation history goes after the configured messages and
            # before the current user query.
            memory_messages = _handle_memory_chat_mode(
                memory=memory,
                memory_config=memory_config,
                model_config=model_config,
            )
            prompt_messages.extend(memory_messages)
            if sys_query:
                # Append the current query as a plain (basic-edition) user message.
                message = LLMNodeChatModelMessage(
                    text=sys_query,
                    role=PromptMessageRole.USER,
                    edition_type="basic",
                )
                prompt_messages.extend(
                    LLMNode.handle_list_messages(
                        messages=[message],
                        context="",
                        jinja2_variables=[],
                        variable_pool=variable_pool,
                        vision_detail_config=vision_detail,
                    )
                )
        elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
            # Completion-model path: a single rendered template message.
            prompt_messages.extend(
                _handle_completion_template(
                    template=prompt_template,
                    context=context,
                    jinja2_variables=jinja2_variables,
                    variable_pool=variable_pool,
                )
            )
            memory_text = _handle_memory_completion_mode(
                memory=memory,
                memory_config=memory_config,
                model_config=model_config,
            )
            # Splice memory into the first message: replace the #histories#
            # placeholder when present, otherwise prepend the memory text.
            prompt_content = prompt_messages[0].content
            prompt_content_type = type(prompt_content)
            if prompt_content_type == str:
                prompt_content = str(prompt_content)
                if "#histories#" in prompt_content:
                    prompt_content = prompt_content.replace("#histories#", memory_text)
                else:
                    prompt_content = memory_text + "\n" + prompt_content
                prompt_messages[0].content = prompt_content
            elif prompt_content_type == list:
                prompt_content = prompt_content if isinstance(prompt_content, list) else []
                for content_item in prompt_content:
                    if content_item.type == PromptMessageContentType.TEXT:
                        if "#histories#" in content_item.data:
                            content_item.data = content_item.data.replace("#histories#", memory_text)
                        else:
                            content_item.data = memory_text + "\n" + content_item.data
            else:
                raise ValueError("Invalid prompt content type")
            if sys_query:
                if prompt_content_type == str:
                    # String content: substitute the #sys.query# placeholder.
                    prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
                    prompt_messages[0].content = prompt_content
                elif prompt_content_type == list:
                    # NOTE(review): the list branch prepends the query instead of
                    # replacing a #sys.query# placeholder — confirm this asymmetry
                    # with the string branch is intentional.
                    prompt_content = prompt_content if isinstance(prompt_content, list) else []
                    for content_item in prompt_content:
                        if content_item.type == PromptMessageContentType.TEXT:
                            content_item.data = sys_query + "\n" + content_item.data
                else:
                    raise ValueError("Invalid prompt content type")
        else:
            raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
        # Attach user-uploaded files when vision is enabled, merging them into
        # the trailing user message when it already holds a content list.
        if vision_enabled and sys_files:
            file_prompts = []
            for file in sys_files:
                file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
                file_prompts.append(file_prompt)
            if (
                len(prompt_messages) > 0
                and isinstance(prompt_messages[-1], UserPromptMessage)
                and isinstance(prompt_messages[-1].content, list)
            ):
                prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
            else:
                prompt_messages.append(UserPromptMessage(content=file_prompts))
        # Same merge strategy for files attached to the retrieved context.
        if vision_enabled and context_files:
            file_prompts = []
            for file in context_files:
                file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
                file_prompts.append(file_prompt)
            if (
                len(prompt_messages) > 0
                and isinstance(prompt_messages[-1], UserPromptMessage)
                and isinstance(prompt_messages[-1].content, list)
            ):
                prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
            else:
                prompt_messages.append(UserPromptMessage(content=file_prompts))
        # Drop content items the model cannot consume, based on its feature
        # flags; a message reduced to a single text item is flattened to str.
        filtered_prompt_messages = []
        for prompt_message in prompt_messages:
            if isinstance(prompt_message.content, list):
                prompt_message_content: list[PromptMessageContentUnionTypes] = []
                for content_item in prompt_message.content:
                    # No feature flags at all: only plain text is allowed.
                    if not model_config.model_schema.features:
                        if content_item.type != PromptMessageContentType.TEXT:
                            continue
                        prompt_message_content.append(content_item)
                        continue
                    # Skip media types the model does not declare support for.
                    if (
                        (
                            content_item.type == PromptMessageContentType.IMAGE
                            and ModelFeature.VISION not in model_config.model_schema.features
                        )
                        or (
                            content_item.type == PromptMessageContentType.DOCUMENT
                            and ModelFeature.DOCUMENT not in model_config.model_schema.features
                        )
                        or (
                            content_item.type == PromptMessageContentType.VIDEO
                            and ModelFeature.VIDEO not in model_config.model_schema.features
                        )
                        or (
                            content_item.type == PromptMessageContentType.AUDIO
                            and ModelFeature.AUDIO not in model_config.model_schema.features
                        )
                    ):
                        continue
                    prompt_message_content.append(content_item)
                if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
                    prompt_message.content = prompt_message_content[0].data
                else:
                    prompt_message.content = prompt_message_content
            if prompt_message.is_empty():
                continue
            filtered_prompt_messages.append(prompt_message)
        if len(filtered_prompt_messages) == 0:
            raise NoPromptFoundError(
                "No prompt found in the LLM configuration. "
                "Please ensure a prompt is properly configured before proceeding."
            )
        # Final sanity check that the model actually exists for this tenant.
        model = ModelManager().get_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
            provider=model_config.provider,
            model=model_config.model,
        )
        model_schema = model.model_type_instance.get_model_schema(
            model=model_config.model,
            credentials=model.credentials,
        )
        if not model_schema:
            raise ModelNotExistError(f"Model {model_config.model} not exist.")
        return filtered_prompt_messages, model_config.stop
    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: LLMNodeData,
    ) -> Mapping[str, Sequence[str]]:
        """Map every variable this node consumes to its value selector.

        Keys are namespaced as "<node_id>.<variable>"; well-known inputs use
        the reserved "#context#", "#files#" and "#sys.query#" names.
        Raises InvalidVariableTypeError for unknown prompt template types.
        """
        _ = graph_config
        typed_node_data = node_data
        prompt_template = typed_node_data.prompt_template
        variable_selectors = []
        prompt_context_selectors: list[Sequence[str]] = []
        if isinstance(prompt_template, list):
            # Chat templates mix context references and text messages.
            for item in prompt_template:
                if isinstance(item, PromptMessageContext):
                    # Selector must have at least node-id + variable components.
                    if len(item.value_selector) >= 2:
                        prompt_context_selectors.append(item.value_selector)
                elif isinstance(item, LLMNodeChatModelMessage):
                    variable_template_parser = VariableTemplateParser(template=item.text)
                    variable_selectors.extend(variable_template_parser.extract_variable_selectors())
        elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
            # jinja2 templates are resolved via prompt_config.jinja2_variables
            # below, not by scanning the raw text.
            if prompt_template.edition_type != "jinja2":
                variable_template_parser = VariableTemplateParser(template=prompt_template.text)
                variable_selectors = variable_template_parser.extract_variable_selectors()
        else:
            raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
        variable_mapping: dict[str, Any] = {}
        for variable_selector in variable_selectors:
            variable_mapping[variable_selector.variable] = variable_selector.value_selector
        # Context references are keyed by their "#a.b#" selector notation.
        for context_selector in prompt_context_selectors:
            variable_key = f"#{'.'.join(context_selector)}#"
            variable_mapping[variable_key] = list(context_selector)
        memory = typed_node_data.memory
        if memory and memory.query_prompt_template:
            query_variable_selectors = VariableTemplateParser(
                template=memory.query_prompt_template
            ).extract_variable_selectors()
            for variable_selector in query_variable_selectors:
                variable_mapping[variable_selector.variable] = variable_selector.value_selector
        if typed_node_data.context.enabled:
            variable_mapping["#context#"] = typed_node_data.context.variable_selector
        if typed_node_data.vision.enabled:
            variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector
        if typed_node_data.memory:
            variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY]
        if typed_node_data.prompt_config:
            # jinja2 variables only apply when at least one message (or the
            # completion template itself) uses the jinja2 edition.
            enable_jinja = False
            if isinstance(prompt_template, list):
                for item in prompt_template:
                    if isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
                        enable_jinja = True
                        break
            else:
                enable_jinja = True
            if enable_jinja:
                for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
                    variable_mapping[variable_selector.variable] = variable_selector.value_selector
        # Namespace every key with this node's id.
        variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
        return variable_mapping
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "llm",
"config": {
"prompt_templates": {
"chat_model": {
"prompts": [
{"role": "system", "text": "You are a helpful AI assistant.", "edition_type": "basic"}
]
},
"completion_model": {
"conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
"prompt": {
"text": "Here are the chat histories between human and assistant, inside "
"<histories></histories> XML tags.\n\n<histories>\n{{"
"#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
"edition_type": "basic",
},
"stop": ["Human:"],
},
}
},
}
    @staticmethod
    def handle_list_messages(
        *,
        messages: Sequence[LLMNodeChatModelMessage],
        context: str | None,
        jinja2_variables: Sequence[VariableSelector],
        variable_pool: VariablePool,
        vision_detail_config: ImagePromptMessageContent.DETAIL,
        sandbox: Sandbox | None = None,
    ) -> Sequence[PromptMessage]:
        """Render chat-style template messages into prompt messages.

        "jinja2"-edition messages are rendered via the jinja2 variables;
        "basic" messages go through {#context#} substitution and the variable
        pool's template conversion, with any File/ArrayFile variables turned
        into separate file-content messages. When the sandbox carries a skill
        bundle, rendered text is additionally run through the skill document
        assembler.
        """
        from core.sandbox.entities.config import AppAssets
        from core.skill.assembler import SkillDocumentAssembler
        from core.skill.constants import SkillAttrs
        from core.skill.entities.skill_document import SkillDocument
        from core.skill.entities.skill_metadata import SkillMetadata
        prompt_messages: list[PromptMessage] = []
        bundle: SkillBundle | None = None
        if sandbox:
            bundle = sandbox.attrs.get(SkillAttrs.BUNDLE)
        for message in messages:
            if message.edition_type == "jinja2":
                result_text = _render_jinja2_message(
                    template=message.jinja2_text or "",
                    jinja2_variables=jinja2_variables,
                    variable_pool=variable_pool,
                )
                # Post-process rendered text through the skill bundle, if any.
                if bundle is not None:
                    skill_entry = SkillDocumentAssembler(bundle).assemble_document(
                        document=SkillDocument(
                            skill_id="anonymous",
                            content=result_text,
                            metadata=SkillMetadata.model_validate(message.metadata or {}),
                        ),
                        base_path=AppAssets.PATH,
                    )
                    result_text = skill_entry.content
                prompt_message = _combine_message_content_with_role(
                    contents=[TextPromptMessageContent(data=result_text)], role=message.role
                )
                prompt_messages.append(prompt_message)
            else:
                # Basic edition: inline the retrieved context, then resolve
                # {{#...#}} variables through the variable pool.
                if context:
                    template = message.text.replace("{#context#}", context)
                else:
                    template = message.text
                segment_group = variable_pool.convert_template(template)
                # Collect file-typed variables as separate prompt contents.
                file_contents = []
                for segment in segment_group.value:
                    if isinstance(segment, ArrayFileSegment):
                        for file in segment.value:
                            if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
                                file_content = file_manager.to_prompt_message_content(
                                    file, image_detail_config=vision_detail_config
                                )
                                file_contents.append(file_content)
                    elif isinstance(segment, FileSegment):
                        file = segment.value
                        if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
                            file_content = file_manager.to_prompt_message_content(
                                file, image_detail_config=vision_detail_config
                            )
                            file_contents.append(file_content)
                plain_text = segment_group.text
                if plain_text and bundle is not None:
                    skill_entry = SkillDocumentAssembler(bundle).assemble_document(
                        document=SkillDocument(
                            skill_id="anonymous",
                            content=plain_text,
                            metadata=SkillMetadata.model_validate(message.metadata or {}),
                        ),
                        base_path=AppAssets.PATH,
                    )
                    plain_text = skill_entry.content
                # Text first, then files — each as its own message.
                if plain_text:
                    prompt_message = _combine_message_content_with_role(
                        contents=[TextPromptMessageContent(data=plain_text)], role=message.role
                    )
                    prompt_messages.append(prompt_message)
                if file_contents:
                    prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
                    prompt_messages.append(prompt_message)
        return prompt_messages
@staticmethod
def handle_blocking_result(
*,
invoke_result: LLMResult | LLMResultWithStructuredOutput,
saver: LLMFileSaver,
file_outputs: list[File],
reasoning_format: Literal["separated", "tagged"] = "tagged",
request_latency: float | None = None,
) -> ModelInvokeCompletedEvent:
buffer = io.StringIO()
for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
contents=invoke_result.message.content,
file_saver=saver,
file_outputs=file_outputs,
):
buffer.write(text_part)
full_text = buffer.getvalue()
if reasoning_format == "tagged":
clean_text = full_text
reasoning_content = ""
else:
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
event = ModelInvokeCompletedEvent(
text=clean_text if reasoning_format == "separated" else full_text,
usage=invoke_result.usage,
finish_reason=None,
reasoning_content=reasoning_content,
structured_output=getattr(invoke_result, "structured_output", None),
)
if request_latency is not None:
event.usage.latency = round(request_latency, 3)
return event
@staticmethod
def save_multimodal_image_output(
*,
content: ImagePromptMessageContent,
file_saver: LLMFileSaver,
) -> File:
if content.url != "":
saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE)
else:
saved_file = file_saver.save_binary_string(
data=base64.b64decode(content.base64_data),
mime_type=content.mime_type,
file_type=FileType.IMAGE,
)
return saved_file
@staticmethod
def _normalize_sandbox_file_path(path: str) -> str:
raw = path.strip()
if not raw:
raise LLMNodeError("Sandbox file path must not be empty")
sandbox_path = PurePosixPath(raw)
if any(part == ".." for part in sandbox_path.parts):
raise LLMNodeError("Sandbox file path must not contain '..'")
normalized = str(sandbox_path)
if normalized in {".", ""}:
raise LLMNodeError("Sandbox file path is invalid")
return normalized
    def _resolve_sandbox_file_path(self, *, sandbox: Sandbox, path: str) -> File:
        """Download a file from the sandbox VM and register it as a tool file.

        Returns a File record pointing at the stored tool file.
        Raises LLMNodeError when the path is invalid, the file cannot be
        downloaded, or it exceeds MAX_OUTPUT_FILE_SIZE.
        """
        from core.sandbox.bash.session import MAX_OUTPUT_FILE_SIZE
        from core.tools.tool_file_manager import ToolFileManager
        normalized_path = self._normalize_sandbox_file_path(path)
        filename = os.path.basename(normalized_path)
        if not filename:
            # e.g. a path ending in "/" resolves to no basename.
            raise LLMNodeError("Sandbox file path must point to a file")
        try:
            file_content = sandbox.vm.download_file(normalized_path)
        except Exception as exc:
            raise LLMNodeError(f"Sandbox file not found: {normalized_path}") from exc
        file_binary = file_content.getvalue()
        if len(file_binary) > MAX_OUTPUT_FILE_SIZE:
            raise LLMNodeError(f"Sandbox file exceeds size limit: {normalized_path}")
        # Fall back to a generic binary mime type when the extension is unknown.
        mime_type, _ = mimetypes.guess_type(filename)
        if not mime_type:
            mime_type = "application/octet-stream"
        tool_file_manager = ToolFileManager()
        tool_file = tool_file_manager.create_file_by_raw(
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            conversation_id=None,
            file_binary=file_binary,
            mimetype=mime_type,
            filename=filename,
        )
        # Extensionless files are stored as ".bin".
        extension = os.path.splitext(filename)[1] if "." in filename else ".bin"
        url = sign_tool_file(tool_file.id, extension)
        file_type = self._get_file_type_from_mime(mime_type)
        return File(
            id=tool_file.id,
            tenant_id=self.tenant_id,
            type=file_type,
            transfer_method=FileTransferMethod.TOOL_FILE,
            filename=filename,
            extension=extension,
            mime_type=mime_type,
            size=len(file_binary),
            related_id=tool_file.id,
            url=url,
            storage_key=tool_file.file_key,
        )
@staticmethod
def _get_file_type_from_mime(mime_type: str) -> FileType:
if mime_type.startswith("image/"):
return FileType.IMAGE
if mime_type.startswith("video/"):
return FileType.VIDEO
if mime_type.startswith("audio/"):
return FileType.AUDIO
if "text" in mime_type or "pdf" in mime_type:
return FileType.DOCUMENT
return FileType.CUSTOM
@staticmethod
def fetch_structured_output_schema(
*,
structured_output: Mapping[str, Any],
) -> dict[str, Any]:
if not structured_output:
raise LLMNodeError("Please provide a valid structured output schema")
structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
if not structured_output_schema:
raise LLMNodeError("Please provide a valid structured output schema")
try:
schema = json.loads(structured_output_schema)
if not isinstance(schema, dict):
raise LLMNodeError("structured_output_schema must be a JSON object")
return schema
except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format")
@staticmethod
def _save_multimodal_output_and_convert_result_to_markdown(
*,
contents: str | list[PromptMessageContentUnionTypes] | None,
file_saver: LLMFileSaver,
file_outputs: list[File],
) -> Generator[str, None, None]:
if contents is None:
yield from []
return
if isinstance(contents, str):
yield contents
else:
for item in contents:
if isinstance(item, TextPromptMessageContent):
yield item.data
elif isinstance(item, ImagePromptMessageContent):
file = LLMNode.save_multimodal_image_output(
content=item,
file_saver=file_saver,
)
file_outputs.append(file)
yield LLMNode._image_file_to_markdown(file)
else:
logger.warning("unknown item type encountered, type=%s", type(item))
yield str(item)
@property
def retry(self) -> bool:
return self.node_data.retry_config.retry_enabled
@property
def tool_call_enabled(self) -> bool:
return (
self.node_data.tools is not None
and len(self.node_data.tools) > 0
and all(tool.enabled for tool in self.node_data.tools)
)
    def _stream_llm_events(
        self,
        generator: Generator[NodeEventBase | LLMStructuredOutput, None, LLMGenerationData | None],
        *,
        model_instance: ModelInstance,
    ) -> Generator[
        NodeEventBase,
        None,
        tuple[
            str,
            str,
            str,
            str,
            LLMUsage,
            str | None,
            LLMStructuredOutput | None,
            LLMGenerationData | None,
        ],
    ]:
        """Relay streaming LLM events and accumulate the final invocation state.

        Forwards chunk/thought events to the caller as they arrive; on
        ModelInvokeCompletedEvent captures text, usage, finish reason and
        reasoning (split according to the node's reasoning_format), deducts
        quota, and keeps draining the generator so its return value (the
        LLMGenerationData) can be captured from StopIteration.

        Returns (clean_text, reasoning_content, generation_reasoning_content,
        generation_clean_content, usage, finish_reason, structured_output,
        generation_data).
        """
        clean_text = ""
        reasoning_content = ""
        generation_reasoning_content = ""
        generation_clean_content = ""
        usage = LLMUsage.empty_usage()
        finish_reason: str | None = None
        structured_output: LLMStructuredOutput | None = None
        generation_data: LLMGenerationData | None = None
        completed = False
        # Manual next() loop (instead of `for`) so the generator's return
        # value can be read off StopIteration.value.
        while True:
            try:
                event = next(generator)
            except StopIteration as exc:
                if isinstance(exc.value, LLMGenerationData):
                    generation_data = exc.value
                break
            # After completion we only drain for the return value; any further
            # events are ignored.
            if completed:
                continue
            match event:
                case StreamChunkEvent() | ThoughtChunkEvent():
                    yield event
                case ModelInvokeCompletedEvent(
                    text=text,
                    usage=usage_event,
                    finish_reason=finish_reason_event,
                    reasoning_content=reasoning_event,
                    structured_output=structured_raw,
                ):
                    clean_text = text
                    usage = usage_event
                    finish_reason = finish_reason_event
                    reasoning_content = reasoning_event or ""
                    generation_reasoning_content = reasoning_content
                    generation_clean_content = clean_text
                    # "tagged" keeps reasoning inline in clean_text but still
                    # derives a separated view for the generation record;
                    # other formats rewrite clean_text itself.
                    if self.node_data.reasoning_format == "tagged":
                        generation_clean_content, generation_reasoning_content = LLMNode._split_reasoning(
                            clean_text, reasoning_format="separated"
                        )
                    else:
                        clean_text, generation_reasoning_content = LLMNode._split_reasoning(
                            clean_text, self.node_data.reasoning_format
                        )
                        generation_clean_content = clean_text
                    structured_output = (
                        LLMStructuredOutput(structured_output=structured_raw) if structured_raw else None
                    )
                    llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                    completed = True
                case LLMStructuredOutput():
                    structured_output = event
                case _:
                    continue
        return (
            clean_text,
            reasoning_content,
            generation_reasoning_content,
            generation_clean_content,
            usage,
            finish_reason,
            structured_output,
            generation_data,
        )
def _extract_disabled_tools(self) -> dict[str, ToolDependency]:
from core.skill.entities.tool_dependencies import ToolDependency
tools = [
ToolDependency(type=tool.type, provider=tool.provider, tool_name=tool.tool_name)
for tool in self.node_data.tool_settings
if not tool.enabled
]
return {tool.tool_id(): tool for tool in tools}
    def _extract_tool_dependencies(self) -> ToolDependencies | None:
        """Collect tool dependencies declared by the prompt's skill documents.

        Assembles each chat message through the sandbox skill bundle, merges
        their tool dependency sets, and marks tools disabled in the node
        settings. Returns None when no chat messages were assembled.
        Raises LLMNodeError when no sandbox is attached.
        """
        from core.sandbox.entities.config import AppAssets
        from core.skill.assembler import SkillDocumentAssembler
        from core.skill.constants import SkillAttrs
        from core.skill.entities.skill_document import SkillDocument
        from core.skill.entities.skill_metadata import SkillMetadata
        sandbox = self.graph_runtime_state.sandbox
        if not sandbox:
            raise LLMNodeError("Sandbox not found")
        bundle = sandbox.attrs.get(SkillAttrs.BUNDLE)
        tool_deps_list: list[ToolDependencies] = []
        for prompt in self.node_data.prompt_template:
            if isinstance(prompt, LLMNodeChatModelMessage):
                skill_entry = SkillDocumentAssembler(bundle).assemble_document(
                    document=SkillDocument(
                        skill_id="anonymous",
                        content=prompt.text,
                        metadata=SkillMetadata.model_validate(prompt.metadata or {}),
                    ),
                    base_path=AppAssets.PATH,
                )
                tool_deps_list.append(skill_entry.tools)
        if len(tool_deps_list) == 0:
            return None
        # Merge per-message dependency sets, then apply the node's disable list.
        disabled_tools = self._extract_disabled_tools()
        tool_dependencies = reduce(lambda x, y: x.merge(y), tool_deps_list)
        for tool in tool_dependencies.dependencies:
            if tool.tool_id() in disabled_tools:
                tool.enabled = False
        return tool_dependencies
    def _invoke_llm_with_tools(
        self,
        model_instance: ModelInstance,
        prompt_messages: Sequence[PromptMessage],
        stop: Sequence[str] | None,
        files: Sequence[File],
        variable_pool: VariablePool,
        node_inputs: dict[str, Any],
        process_data: dict[str, Any],
    ) -> Generator[NodeEventBase, None, LLMGenerationData]:
        """Run the LLM through an agent strategy so it can call the node's tools.

        Yields streaming node events and returns the final LLMGenerationData.
        NOTE(review): ``files``, ``node_inputs`` and ``process_data`` are
        accepted but unused in this body — confirm whether they are kept for
        interface compatibility.
        """
        from core.agent.entities import ExecutionContext
        from core.agent.patterns import StrategyFactory
        model_features = self._get_model_features(model_instance)
        tool_instances = self._prepare_tool_instances(variable_pool)
        prompt_files = self._extract_prompt_files(variable_pool)
        # The strategy encapsulates the tool-calling loop (max 10 iterations
        # unless configured otherwise).
        strategy = StrategyFactory.create_strategy(
            model_features=model_features,
            model_instance=model_instance,
            tools=tool_instances,
            files=prompt_files,
            max_iterations=self._node_data.max_iterations or 10,
            context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
        )
        outputs = strategy.run(
            prompt_messages=list(prompt_messages),
            model_parameters=self._node_data.model.completion_params,
            stop=list(stop or []),
            stream=True,
        )
        result = yield from self._process_tool_outputs(outputs)
        return result
    def _invoke_llm_with_sandbox(
        self,
        sandbox: Sandbox,
        model_instance: ModelInstance,
        prompt_messages: Sequence[PromptMessage],
        stop: Sequence[str] | None,
        variable_pool: VariablePool,
        tool_dependencies: ToolDependencies | None,
    ) -> Generator[NodeEventBase, None, LLMGenerationData]:
        """Run the LLM with a sandbox bash tool via function calling.

        Opens a SandboxBashSession scoped to this node, lets the model drive
        the bash tool (up to max_iterations, default 100), then collects any
        files the session produced, de-duplicated into self._file_outputs.
        Yields streaming node events and returns the final LLMGenerationData.
        Raises LLMNodeError when the session yields no result.
        """
        from core.agent.entities import AgentEntity, ExecutionContext
        from core.agent.patterns import StrategyFactory
        from core.sandbox.bash.session import SandboxBashSession
        result: LLMGenerationData | None = None
        with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session:
            prompt_files = self._extract_prompt_files(variable_pool)
            model_features = self._get_model_features(model_instance)
            strategy = StrategyFactory.create_strategy(
                model_features=model_features,
                model_instance=model_instance,
                tools=[session.bash_tool],
                files=prompt_files,
                max_iterations=self._node_data.max_iterations or 100,
                agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
                context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
            )
            outputs = strategy.run(
                prompt_messages=list(prompt_messages),
                model_parameters=self._node_data.model.completion_params,
                stop=list(stop or []),
                stream=True,
            )
            result = yield from self._process_tool_outputs(outputs)
            # Merge session output files, skipping ids we already hold.
            collected_files = session.collect_output_files()
            if collected_files:
                existing_ids = {f.id for f in self._file_outputs}
                self._file_outputs.extend(f for f in collected_files if f.id not in existing_ids)
        if result is None:
            raise LLMNodeError("SandboxSession exited unexpectedly")
        return result
def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]:
try:
model_type_instance = model_instance.model_type_instance
model_schema = model_type_instance.get_model_schema(
model_instance.model_name,
model_instance.credentials,
)
return model_schema.features if model_schema and model_schema.features else []
except Exception:
logger.warning("Failed to get model schema, assuming no special features")
return []
    def _prepare_tool_instances(self, variable_pool: VariablePool) -> list[Tool]:
        """Instantiate runtime tools for every tool configured on the node.

        Tools that fail to load are logged and skipped (best-effort), so the
        returned list may be shorter than the configured set.
        """
        from core.agent.entities import AgentToolEntity
        from core.tools.tool_manager import ToolManager
        tool_instances = []
        if self._node_data.tools:
            for tool in self._node_data.tools:
                try:
                    # Unwrap settings of the shape {"value": {"type":..., "value":...}}
                    # down to the inner envelope — presumably normalizing the
                    # settings format; confirm against the settings schema.
                    processed_settings = {}
                    for key, value in tool.settings.items():
                        if isinstance(value, dict) and "value" in value and isinstance(value["value"], dict):
                            if "type" in value["value"] and "value" in value["value"]:
                                processed_settings[key] = value["value"]
                            else:
                                processed_settings[key] = value
                        else:
                            processed_settings[key] = value
                    # Settings override parameters on key collision.
                    merged_parameters = {**tool.parameters, **processed_settings}
                    agent_tool = AgentToolEntity(
                        provider_id=tool.provider_name,
                        provider_type=tool.type,
                        tool_name=tool.tool_name,
                        tool_parameters=merged_parameters,
                        plugin_unique_identifier=tool.plugin_unique_identifier,
                        credential_id=tool.credential_id,
                    )
                    tool_runtime = ToolManager.get_agent_tool_runtime(
                        tenant_id=self.tenant_id,
                        app_id=self.app_id,
                        agent_tool=agent_tool,
                        invoke_from=self.invoke_from,
                        variable_pool=variable_pool,
                    )
                    # A user-provided description overrides the tool's own LLM prompt.
                    if tool.extra.get("description") and tool_runtime.entity.description:
                        tool_runtime.entity.description.llm = (
                            tool.extra.get("description") or tool_runtime.entity.description.llm
                        )
                    tool_instances.append(tool_runtime)
                except Exception as e:
                    logger.warning("Failed to load tool %s: %s", tool, str(e))
                    continue
        return tool_instances
def _extract_prompt_files(self, variable_pool: VariablePool) -> list[File]:
from dify_graph.variables.variables import ArrayFileVariable, FileVariable
files: list[File] = []
if isinstance(self._node_data.prompt_template, list):
for message in self._node_data.prompt_template:
if message.text:
parser = VariableTemplateParser(message.text)
variable_selectors = parser.extract_variable_selectors()
for variable_selector in variable_selectors:
variable = variable_pool.get(variable_selector.value_selector)
if isinstance(variable, FileVariable) and variable.value:
files.append(variable.value)
elif isinstance(variable, ArrayFileVariable) and variable.value:
files.extend(variable.value)
return files
@staticmethod
def _serialize_tool_call(tool_call: ToolCallResult) -> dict[str, Any]:
def _file_to_ref(file: File) -> str | None:
return file.id or file.related_id
files = []
for file in tool_call.files or []:
ref = _file_to_ref(file)
if ref:
files.append(ref)
return {
"id": tool_call.id,
"name": tool_call.name,
"arguments": tool_call.arguments,
"output": tool_call.output,
"files": files,
"status": tool_call.status.value if hasattr(tool_call.status, "value") else tool_call.status,
"elapsed_time": tool_call.elapsed_time,
}
def _generate_model_provider_icon_url(self, provider: str, dark: bool = False) -> str | None:
from yarl import URL
from configs import dify_config
icon_type = "icon_small_dark" if dark else "icon_small"
try:
return str(
URL(dify_config.CONSOLE_API_URL or "/")
/ "console"
/ "api"
/ "workspaces"
/ self.tenant_id
/ "model-providers"
/ provider
/ icon_type
/ "en_US"
)
except Exception:
return None
def _emit_model_start(self, trace_state: TraceState) -> Generator[NodeEventBase, None, None]:
if trace_state.model_start_emitted:
return
trace_state.model_start_emitted = True
if trace_state.model_segment_start_time is None:
trace_state.model_segment_start_time = time.perf_counter()
provider = self._node_data.model.provider
yield StreamChunkEvent(
selector=[self._node_id, "generation", "model_start"],
chunk="",
chunk_type=ChunkType.MODEL_START,
is_final=False,
model_provider=provider,
model_name=self._node_data.model.name,
model_icon=self._generate_model_provider_icon_url(provider),
model_icon_dark=self._generate_model_provider_icon_url(provider, dark=True),
)
def _flush_model_segment(
    self,
    buffers: StreamBuffers,
    trace_state: TraceState,
    error: str | None = None,
) -> Generator[NodeEventBase, None, None]:
    """Close the in-progress model segment and emit its MODEL_END event.

    Collects the buffered thought/content/tool-call data into an
    LLMTraceSegment, appends it to the trace, then resets the buffers and
    per-segment trace state. No-op when nothing has been buffered.
    """
    if not (buffers.pending_thought or buffers.pending_content or buffers.pending_tool_calls):
        return
    started_at = trace_state.model_segment_start_time
    elapsed = time.perf_counter() - started_at if started_at else 0.0
    segment_usage = trace_state.pending_usage
    model_config = self._node_data.model
    segment = LLMTraceSegment(
        type="model",
        duration=elapsed,
        usage=segment_usage,
        output=ModelTraceSegment(
            text="".join(buffers.pending_content) if buffers.pending_content else None,
            reasoning="".join(buffers.pending_thought) if buffers.pending_thought else None,
            tool_calls=list(buffers.pending_tool_calls),
        ),
        provider=model_config.provider,
        name=model_config.name,
        icon=self._generate_model_provider_icon_url(model_config.provider),
        icon_dark=self._generate_model_provider_icon_url(model_config.provider, dark=True),
        error=error,
        status="error" if error else "success",
    )
    trace_state.trace_segments.append(segment)
    yield StreamChunkEvent(
        selector=[self._node_id, "generation", "model_end"],
        chunk="",
        chunk_type=ChunkType.MODEL_END,
        is_final=False,
        model_usage=segment_usage,
        model_duration=elapsed,
    )
    # Reset per-segment buffers and state so the next segment starts clean.
    buffers.pending_thought.clear()
    buffers.pending_content.clear()
    buffers.pending_tool_calls.clear()
    trace_state.model_segment_start_time = None
    trace_state.model_start_emitted = False
    trace_state.pending_usage = None
def _handle_agent_log_output(
    self, output: AgentLog, buffers: StreamBuffers, trace_state: TraceState, agent_context: AgentContext
) -> Generator[NodeEventBase, None, None]:
    """Translate one AgentLog into stream events and trace bookkeeping.

    Upserts the log into ``agent_context.agent_logs`` (keyed by message id),
    captures LLM usage from successful THOUGHT logs, emits a tool-call chunk
    event when a TOOL_CALL starts, and — when a TOOL_CALL finishes — flushes
    the pending model segment, records a tool trace segment, and emits the
    tool result event.
    """
    from core.agent.entities import AgentLog
    payload = ToolLogPayload.from_log(output)
    agent_log_event = AgentLogEvent(
        message_id=output.id,
        label=output.label,
        node_execution_id=self.id,
        parent_id=output.parent_id,
        error=output.error,
        status=output.status.value,
        data=output.data,
        metadata={k.value: v for k, v in output.metadata.items()},
        node_id=self._node_id,
    )
    # Upsert: update an existing entry with the same message id in place,
    # otherwise append a new one (for/else).
    for log in agent_context.agent_logs:
        if log.message_id == agent_log_event.message_id:
            log.data = agent_log_event.data
            log.status = agent_log_event.status
            log.error = agent_log_event.error
            log.label = agent_log_event.label
            log.metadata = agent_log_event.metadata
            break
    else:
        agent_context.agent_logs.append(agent_log_event)
    if output.log_type == AgentLog.LogType.THOUGHT and output.status == AgentLog.LogStatus.SUCCESS:
        # Stash the usage so the next model-segment flush can attach it.
        llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE) if output.metadata else None
        if llm_usage:
            trace_state.pending_usage = llm_usage
    if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.START:
        yield from self._emit_model_start(trace_state)
        tool_name = payload.tool_name
        tool_call_id = payload.tool_call_id
        tool_arguments = json.dumps(payload.tool_args or {})
        tool_icon = output.metadata.get(AgentLog.LogMetadata.ICON) if output.metadata else None
        tool_icon_dark = output.metadata.get(AgentLog.LogMetadata.ICON_DARK) if output.metadata else None
        # Assign a stable first-seen ordering index for this tool call id.
        if tool_call_id and tool_call_id not in trace_state.tool_call_index_map:
            trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map)
        buffers.pending_tool_calls.append(ToolCall(id=tool_call_id, name=tool_name, arguments=tool_arguments))
        yield ToolCallChunkEvent(
            selector=[self._node_id, "generation", "tool_calls"],
            chunk=tool_arguments,
            tool_call=ToolCall(
                id=tool_call_id,
                name=tool_name,
                arguments=tool_arguments,
                icon=tool_icon,
                icon_dark=tool_icon_dark,
            ),
            is_final=False,
        )
    if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START:
        # Tool call finished (success or error): close the model segment
        # that preceded it before recording the tool result.
        tool_name = payload.tool_name
        tool_output = payload.tool_output
        tool_call_id = payload.tool_call_id
        tool_files = payload.files if isinstance(payload.files, list) else []
        tool_error = payload.tool_error
        tool_arguments = json.dumps(payload.tool_args or {})
        if tool_call_id and tool_call_id not in trace_state.tool_call_index_map:
            trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map)
        yield from self._flush_model_segment(buffers, trace_state)
        # Resolve the error message: explicit log error wins, then the
        # payload error, then any "error" carried in the payload meta.
        if output.status == AgentLog.LogStatus.ERROR:
            tool_error = output.error or payload.tool_error
            if not tool_error and payload.meta:
                tool_error = payload.meta.get("error")
        else:
            if payload.meta:
                meta_error = payload.meta.get("error")
                if meta_error:
                    tool_error = meta_error
        elapsed_time = output.metadata.get(AgentLog.LogMetadata.ELAPSED_TIME) if output.metadata else None
        tool_provider = output.metadata.get(AgentLog.LogMetadata.PROVIDER) if output.metadata else None
        tool_icon = output.metadata.get(AgentLog.LogMetadata.ICON) if output.metadata else None
        tool_icon_dark = output.metadata.get(AgentLog.LogMetadata.ICON_DARK) if output.metadata else None
        result_str = str(tool_output) if tool_output is not None else None
        tool_status: Literal["success", "error"] = "error" if tool_error else "success"
        tool_call_segment = LLMTraceSegment(
            type="tool",
            duration=elapsed_time or 0.0,
            usage=None,
            output=ToolTraceSegment(
                id=tool_call_id,
                name=tool_name,
                arguments=tool_arguments,
                output=result_str,
            ),
            provider=tool_provider,
            name=tool_name,
            icon=tool_icon,
            icon_dark=tool_icon_dark,
            error=str(tool_error) if tool_error else None,
            status=tool_status,
        )
        trace_state.trace_segments.append(tool_call_segment)
        if tool_call_id:
            trace_state.tool_trace_map[tool_call_id] = tool_call_segment
        # Timing for the next model segment starts once the tool is done.
        trace_state.model_segment_start_time = time.perf_counter()
        yield ToolResultChunkEvent(
            selector=[self._node_id, "generation", "tool_results"],
            chunk=result_str or "",
            tool_result=ToolResult(
                id=tool_call_id,
                name=tool_name,
                output=result_str,
                files=tool_files,
                status=ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS,
                elapsed_time=elapsed_time,
                icon=tool_icon,
                icon_dark=tool_icon_dark,
                provider=tool_provider,
            ),
            is_final=False,
        )
        # Roll the reasoning accumulated this turn into the per-turn list.
        if buffers.current_turn_reasoning:
            buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning))
            buffers.current_turn_reasoning.clear()
def _handle_llm_chunk_output(
    self, output: LLMResultChunk, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult
) -> Generator[NodeEventBase, None, None]:
    """Stream one LLM delta chunk through the think-tag parser into events.

    The chunk text is split by ``buffers.think_parser`` into thought and
    content segments; thought segments are buffered and emitted on the
    "thought" selector, content segments are appended to the aggregate text
    and emitted on both the node "text" and "generation/content" selectors.
    Usage and finish reason from the delta are folded into ``aggregate``.
    """
    message = output.delta.message
    if message and message.content:
        chunk_text = message.content
        # Content may be a list of content parts; join their text payloads.
        if isinstance(chunk_text, list):
            chunk_text = "".join(getattr(content, "data", str(content)) for content in chunk_text)
        else:
            chunk_text = str(chunk_text)
        for kind, segment in buffers.think_parser.process(chunk_text):
            # Boundary markers may carry empty text; only skip empty
            # segments for the non-boundary kinds.
            if not segment and kind not in {"thought_start", "thought_end"}:
                continue
            yield from self._emit_model_start(trace_state)
            if kind == "thought_start":
                yield ThoughtStartChunkEvent(
                    selector=[self._node_id, "generation", "thought"],
                    chunk="",
                    is_final=False,
                )
            elif kind == "thought":
                buffers.current_turn_reasoning.append(segment)
                buffers.pending_thought.append(segment)
                yield ThoughtChunkEvent(
                    selector=[self._node_id, "generation", "thought"],
                    chunk=segment,
                    is_final=False,
                )
            elif kind == "thought_end":
                yield ThoughtEndChunkEvent(
                    selector=[self._node_id, "generation", "thought"],
                    chunk="",
                    is_final=False,
                )
            else:
                # Plain content: accumulate and fan out to both selectors.
                aggregate.text += segment
                buffers.pending_content.append(segment)
                yield StreamChunkEvent(
                    selector=[self._node_id, "text"],
                    chunk=segment,
                    is_final=False,
                )
                yield StreamChunkEvent(
                    selector=[self._node_id, "generation", "content"],
                    chunk=segment,
                    is_final=False,
                )
    if output.delta.usage:
        self._accumulate_usage(aggregate.usage, output.delta.usage)
    if output.delta.finish_reason:
        aggregate.finish_reason = output.delta.finish_reason
def _flush_remaining_stream(
    self, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult
) -> Generator[NodeEventBase, None, None]:
    """Drain buffered parser state after the LLM output is exhausted.

    Flushes any text still held by the think parser (mirroring the per-chunk
    event fan-out of `_handle_llm_chunk_output`), rolls the last turn's
    reasoning into the per-turn list, defaults the pending usage to the
    aggregate usage, and closes the final model segment.
    """
    for kind, segment in buffers.think_parser.flush():
        # Boundary markers may carry empty text; only skip empty segments
        # for the non-boundary kinds.
        if not segment and kind not in {"thought_start", "thought_end"}:
            continue
        yield from self._emit_model_start(trace_state)
        if kind == "thought_start":
            yield ThoughtStartChunkEvent(
                selector=[self._node_id, "generation", "thought"],
                chunk="",
                is_final=False,
            )
        elif kind == "thought":
            buffers.current_turn_reasoning.append(segment)
            buffers.pending_thought.append(segment)
            yield ThoughtChunkEvent(
                selector=[self._node_id, "generation", "thought"],
                chunk=segment,
                is_final=False,
            )
        elif kind == "thought_end":
            yield ThoughtEndChunkEvent(
                selector=[self._node_id, "generation", "thought"],
                chunk="",
                is_final=False,
            )
        else:
            # Plain content: accumulate and fan out to both selectors.
            aggregate.text += segment
            buffers.pending_content.append(segment)
            yield StreamChunkEvent(
                selector=[self._node_id, "text"],
                chunk=segment,
                is_final=False,
            )
            yield StreamChunkEvent(
                selector=[self._node_id, "generation", "content"],
                chunk=segment,
                is_final=False,
            )
    if buffers.current_turn_reasoning:
        buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning))
    # If no THOUGHT log supplied usage, attribute the aggregate usage to
    # the final model segment.
    if trace_state.pending_usage is None:
        trace_state.pending_usage = aggregate.usage
    yield from self._flush_model_segment(buffers, trace_state)
def _close_streams(self) -> Generator[NodeEventBase, None, None]:
    """Emit a terminal (is_final=True) event on every generation stream.

    Closes, in order: node text, generation content, thought, tool calls,
    tool results, model_start and model_end channels.
    """
    node_id = self._node_id
    yield StreamChunkEvent(selector=[node_id, "text"], chunk="", is_final=True)
    yield StreamChunkEvent(selector=[node_id, "generation", "content"], chunk="", is_final=True)
    yield ThoughtChunkEvent(selector=[node_id, "generation", "thought"], chunk="", is_final=True)
    # Tool-call/result channels require a placeholder payload to close.
    yield ToolCallChunkEvent(
        selector=[node_id, "generation", "tool_calls"],
        chunk="",
        tool_call=ToolCall(id="", name="", arguments=""),
        is_final=True,
    )
    yield ToolResultChunkEvent(
        selector=[node_id, "generation", "tool_results"],
        chunk="",
        tool_result=ToolResult(id="", name="", output="", files=[], status=ToolResultStatus.SUCCESS),
        is_final=True,
    )
    for channel in ("model_start", "model_end"):
        yield StreamChunkEvent(selector=[node_id, "generation", channel], chunk="", is_final=True)
def _build_generation_data(
    self,
    trace_state: TraceState,
    agent_context: AgentContext,
    aggregate: AggregatedResult,
    buffers: StreamBuffers,
) -> LLMGenerationData:
    """Assemble the final LLMGenerationData from the accumulated run state.

    Builds an ordered ``sequence`` describing how reasoning, content spans
    and tool calls interleave, reconstructs finished tool calls from the
    collected agent logs, and packages them with the aggregate text, usage,
    finish reason, files and trace segments.
    """
    from core.agent.entities import AgentLog
    sequence: list[dict[str, Any]] = []
    reasoning_index = 0
    content_position = 0
    tool_call_seen_index: dict[str, int] = {}
    # NOTE(review): segments created in this chunk use type "model"/"tool";
    # the "thought"/"content"/"tool_call" cases below presumably match
    # segments produced elsewhere in this file — confirm against the other
    # LLMTraceSegment producers.
    for trace_segment in trace_state.trace_segments:
        if trace_segment.type == "thought":
            sequence.append({"type": "reasoning", "index": reasoning_index})
            reasoning_index += 1
        elif trace_segment.type == "content":
            # Content entries are encoded as [start, end) offsets into the
            # aggregate text.
            segment_text = trace_segment.text or ""
            start = content_position
            end = start + len(segment_text)
            sequence.append({"type": "content", "start": start, "end": end})
            content_position = end
        elif trace_segment.type == "tool_call":
            # Deduplicate by tool id, keeping the first-seen index.
            tool_id = trace_segment.tool_call.id if trace_segment.tool_call and trace_segment.tool_call.id else ""
            if tool_id not in tool_call_seen_index:
                tool_call_seen_index[tool_id] = len(tool_call_seen_index)
            sequence.append({"type": "tool_call", "index": tool_call_seen_index[tool_id]})
    tool_calls_for_generation: list[ToolCallResult] = []
    for log in agent_context.agent_logs:
        payload = ToolLogPayload.from_mapping(log.data or {})
        tool_call_id = payload.tool_call_id
        # Skip logs without a tool call id and START logs (only finished
        # calls are reported).
        if not tool_call_id or log.status == AgentLog.LogStatus.START.value:
            continue
        tool_args = payload.tool_args
        log_error = payload.tool_error
        log_output = payload.tool_output
        result_text = log_output or log_error or ""
        status = ToolResultStatus.ERROR if log_error else ToolResultStatus.SUCCESS
        tool_calls_for_generation.append(
            ToolCallResult(
                id=tool_call_id,
                name=payload.tool_name,
                arguments=json.dumps(tool_args) if tool_args else "",
                output=result_text,
                status=status,
                elapsed_time=log.metadata.get(AgentLog.LogMetadata.ELAPSED_TIME) if log.metadata else None,
            )
        )
    # Order tool calls by first-seen stream order; unknown ids sort last.
    tool_calls_for_generation.sort(
        key=lambda item: trace_state.tool_call_index_map.get(item.id or "", len(trace_state.tool_call_index_map))
    )
    return LLMGenerationData(
        text=aggregate.text,
        reasoning_contents=buffers.reasoning_per_turn,
        tool_calls=tool_calls_for_generation,
        sequence=sequence,
        usage=aggregate.usage,
        finish_reason=aggregate.finish_reason,
        files=aggregate.files,
        trace=trace_state.trace_segments,
    )
def _process_tool_outputs(
    self,
    outputs: Generator[LLMResultChunk | AgentLog, None, AgentResult],
) -> Generator[NodeEventBase, None, LLMGenerationData]:
    """Drive the agent/LLM output generator and fan items out to handlers.

    Yields stream events for each AgentLog or LLMResultChunk, captures the
    generator's AgentResult return value, folds it into the aggregate, then
    flushes the remaining stream state and closes all streams.

    Returns:
        LLMGenerationData assembled from the accumulated trace, agent logs,
        aggregate result and stream buffers.
    """
    from core.agent.entities import AgentLog, AgentResult
    state = ToolOutputState()
    # BUG FIX: the previous `for output in outputs:` loop consumed the
    # generator's StopIteration internally, so the surrounding
    # `except StopIteration` never fired and the AgentResult return value
    # was silently lost. Drive the generator with next() to capture it.
    while True:
        try:
            output = next(outputs)
        except StopIteration as exception:
            if isinstance(getattr(exception, "value", None), AgentResult):
                state.agent.agent_result = exception.value
            break
        if isinstance(output, AgentLog):
            yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent)
        else:
            yield from self._handle_llm_chunk_output(output, state.stream, state.trace, state.aggregate)
    # The final AgentResult, when present, supersedes the streamed
    # aggregate where it carries values.
    if state.agent.agent_result:
        state.aggregate.text = state.agent.agent_result.text or state.aggregate.text
        state.aggregate.files = state.agent.agent_result.files
        if state.agent.agent_result.usage:
            state.aggregate.usage = state.agent.agent_result.usage
        if state.agent.agent_result.finish_reason:
            state.aggregate.finish_reason = state.agent.agent_result.finish_reason
    yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate)
    yield from self._close_streams()
    return self._build_generation_data(state.trace, state.agent, state.aggregate, state.stream)
def _accumulate_usage(self, total_usage: LLMUsage, delta_usage: LLMUsage) -> None:
total_usage.prompt_tokens += delta_usage.prompt_tokens
total_usage.completion_tokens += delta_usage.completion_tokens
total_usage.total_tokens += delta_usage.total_tokens
total_usage.prompt_price += delta_usage.prompt_price
total_usage.completion_price += delta_usage.completion_price
total_usage.total_price += delta_usage.total_price
def _combine_message_content_with_role(
    *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
):
    """Wrap ``contents`` in the prompt-message class matching ``role``.

    Raises:
        NotImplementedError: for roles without a corresponding message class.
    """
    message_class_by_role = {
        PromptMessageRole.USER: UserPromptMessage,
        PromptMessageRole.ASSISTANT: AssistantPromptMessage,
        PromptMessageRole.SYSTEM: SystemPromptMessage,
    }
    message_class = message_class_by_role.get(role)
    if message_class is None:
        raise NotImplementedError(f"Role {role} is not supported")
    return message_class(content=contents)
def _render_jinja2_message(
    *,
    template: str,
    jinja2_variables: Sequence[VariableSelector],
    variable_pool: VariablePool,
):
    """Render a Jinja2 template through the sandboxed code executor.

    Variable selectors are resolved against the variable pool; selectors
    that resolve to nothing render as empty strings. An empty template
    short-circuits to "".
    """
    from core.helper.code_executor import CodeExecutor, CodeLanguage
    if not template:
        return ""
    resolved_inputs = {
        selector.variable: (value.to_object() if (value := variable_pool.get(selector.value_selector)) else "")
        for selector in jinja2_variables
    }
    response = CodeExecutor.execute_workflow_code_template(
        language=CodeLanguage.JINJA2,
        code=template,
        inputs=resolved_inputs,
    )
    return response["result"]
def _calculate_rest_token(
    *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:
    """Estimate how many context tokens remain for generation.

    Falls back to 2000 when the model schema declares no context size.
    Otherwise returns context size minus the configured max_tokens and the
    tokens already consumed by ``prompt_messages``, clamped at zero.
    """
    context_size = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
    if not context_size:
        return 2000
    model_instance = ModelInstance(
        provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
    )
    prompt_token_count = model_instance.get_llm_num_tokens(prompt_messages)
    configured_max_tokens = 0
    # Scan all rules without breaking: the last matching rule wins, as in
    # the original behavior.
    for rule in model_config.model_schema.parameter_rules:
        if rule.name == "max_tokens" or (rule.use_template and rule.use_template == "max_tokens"):
            configured_max_tokens = (
                model_config.parameters.get(rule.name)
                or model_config.parameters.get(str(rule.use_template))
                or 0
            )
    return max(context_size - configured_max_tokens - prompt_token_count, 0)
def _handle_memory_chat_mode(
    *,
    memory: BaseMemory | None,
    memory_config: MemoryConfig | None,
    model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
    """Fetch conversation history as prompt messages for chat models.

    Returns an empty sequence when memory or its configuration is absent.
    """
    if not memory or not memory_config:
        return []
    token_budget = _calculate_rest_token(prompt_messages=[], model_config=model_config)
    window = memory_config.window
    return memory.get_history_prompt_messages(
        max_token_limit=token_budget,
        message_limit=window.size if window.enabled else None,
    )
def _handle_memory_completion_mode(
    *,
    memory: BaseMemory | None,
    memory_config: MemoryConfig | None,
    model_config: ModelConfigWithCredentialsEntity,
) -> str:
    """Fetch conversation history as plain text for completion models.

    Returns "" when memory or its configuration is absent.

    Raises:
        MemoryRolePrefixRequiredError: when the config lacks the role
            prefixes needed to render completion-style history.
    """
    if not memory or not memory_config:
        return ""
    token_budget = _calculate_rest_token(prompt_messages=[], model_config=model_config)
    role_prefix = memory_config.role_prefix
    if not role_prefix:
        raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
    window = memory_config.window
    return memory.get_history_prompt_text(
        max_token_limit=token_budget,
        message_limit=window.size if window.enabled else None,
        human_prefix=role_prefix.user,
        ai_prefix=role_prefix.assistant,
    )
def _handle_completion_template(
    *,
    template: LLMNodeCompletionModelPromptTemplate,
    context: str | None,
    jinja2_variables: Sequence[VariableSelector],
    variable_pool: VariablePool,
) -> Sequence[PromptMessage]:
    """Render a completion-model prompt template into one user message.

    Jinja2-edition templates go through the sandboxed renderer; plain
    templates get "{#context#}" substituted (when context is given) and
    are then resolved against the variable pool.
    """
    if template.edition_type == "jinja2":
        rendered_text = _render_jinja2_message(
            template=template.jinja2_text or "",
            jinja2_variables=jinja2_variables,
            variable_pool=variable_pool,
        )
    else:
        raw_text = template.text.replace("{#context#}", context) if context else template.text
        rendered_text = variable_pool.convert_template(raw_text).text
    user_message = _combine_message_content_with_role(
        contents=[TextPromptMessageContent(data=rendered_text)], role=PromptMessageRole.USER
    )
    return [user_message]