refactor: llm decouple code executor module (#33400)

Co-authored-by: Byron.wang <byron@dify.ai>
This commit is contained in:
wangxiaolei 2026-03-16 10:06:14 +08:00 committed by GitHub
parent a6163f80d1
commit 6ef69ff880
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 603 additions and 414 deletions

View File

@ -103,7 +103,6 @@ ignore_imports =
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
dify_graph.nodes.llm.node -> core.helper.code_executor
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
dify_graph.nodes.llm.node -> core.model_manager

View File

@ -45,6 +45,7 @@ from dify_graph.nodes.document_extractor import UnstructuredApiConfig
from dify_graph.nodes.http_request import build_http_request_config
from dify_graph.nodes.llm.entities import LLMNodeData
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from dify_graph.nodes.llm.protocols import TemplateRenderer
from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
from dify_graph.nodes.template_transform.template_renderer import (
@ -228,6 +229,16 @@ class DefaultWorkflowCodeExecutor:
return isinstance(error, CodeExecutionError)
class DefaultLLMTemplateRenderer(TemplateRenderer):
def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str:
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2,
code=template,
inputs=inputs,
)
return str(result.get("result", ""))
@final
class DifyNodeFactory(NodeFactory):
"""
@ -254,6 +265,7 @@ class DifyNodeFactory(NodeFactory):
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor)
self._llm_template_renderer: TemplateRenderer = DefaultLLMTemplateRenderer()
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
self._http_request_http_client = ssrf_proxy
self._http_request_tool_file_manager_factory = ToolFileManager
@ -391,6 +403,8 @@ class DifyNodeFactory(NodeFactory):
model_instance=model_instance,
),
}
if validated_node_data.type in {BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
node_init_kwargs["template_renderer"] = self._llm_template_renderer
if include_http_client:
node_init_kwargs["http_client"] = self._http_request_http_client
return node_init_kwargs

View File

@ -1,34 +1,53 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import cast
from typing import Any, cast
from core.model_manager import ModelInstance
from dify_graph.file import FileType, file_manager
from dify_graph.file.models import File
from dify_graph.model_runtime.entities import PromptMessageRole
from dify_graph.model_runtime.entities.message_entities import (
from dify_graph.model_runtime.entities import (
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageRole,
TextPromptMessageContent,
)
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageContentUnionTypes,
SystemPromptMessage,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.runtime import VariablePool
from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
from dify_graph.variables import ArrayFileSegment, FileSegment
from dify_graph.variables.segments import ArrayAnySegment, NoneSegment
from .exc import InvalidVariableTypeError
from .entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, MemoryConfig
from .exc import (
InvalidVariableTypeError,
MemoryRolePrefixRequiredError,
NoPromptFoundError,
TemplateTypeNotSupportError,
)
from .protocols import TemplateRenderer
def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
model_instance.model_name,
model_instance.credentials,
dict(model_instance.credentials),
)
if not model_schema:
raise ValueError(f"Model schema not found for {model_instance.model_name}")
return model_schema
def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence[File]:
variable = variable_pool.get(selector)
if variable is None:
return []
@ -89,3 +108,366 @@ def fetch_memory_text(
human_prefix=human_prefix,
ai_prefix=ai_prefix,
)
def fetch_prompt_messages(
*,
sys_query: str | None = None,
sys_files: Sequence[File],
context: str | None = None,
memory: PromptMessageMemory | None = None,
model_instance: ModelInstance,
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
stop: Sequence[str] | None = None,
memory_config: MemoryConfig | None = None,
vision_enabled: bool = False,
vision_detail: ImagePromptMessageContent.DETAIL,
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
context_files: list[File] | None = None,
template_renderer: TemplateRenderer | None = None,
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
model_schema = fetch_model_schema(model_instance=model_instance)
if isinstance(prompt_template, list):
prompt_messages.extend(
handle_list_messages(
messages=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
vision_detail_config=vision_detail,
template_renderer=template_renderer,
)
)
prompt_messages.extend(
handle_memory_chat_mode(
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
)
)
if sys_query:
prompt_messages.extend(
handle_list_messages(
messages=[
LLMNodeChatModelMessage(
text=sys_query,
role=PromptMessageRole.USER,
edition_type="basic",
)
],
context="",
jinja2_variables=[],
variable_pool=variable_pool,
vision_detail_config=vision_detail,
template_renderer=template_renderer,
)
)
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
prompt_messages.extend(
handle_completion_template(
template=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
template_renderer=template_renderer,
)
)
memory_text = handle_memory_completion_mode(
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
)
prompt_content = prompt_messages[0].content
if isinstance(prompt_content, str):
prompt_content = str(prompt_content)
if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text)
else:
prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content
elif isinstance(prompt_content, list):
for content_item in prompt_content:
if isinstance(content_item, TextPromptMessageContent):
if "#histories#" in content_item.data:
content_item.data = content_item.data.replace("#histories#", memory_text)
else:
content_item.data = memory_text + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
if sys_query:
if isinstance(prompt_content, str):
prompt_messages[0].content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
elif isinstance(prompt_content, list):
for content_item in prompt_content:
if isinstance(content_item, TextPromptMessageContent):
content_item.data = sys_query + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
else:
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
_append_file_prompts(
prompt_messages=prompt_messages,
files=sys_files,
vision_enabled=vision_enabled,
vision_detail=vision_detail,
)
_append_file_prompts(
prompt_messages=prompt_messages,
files=context_files or [],
vision_enabled=vision_enabled,
vision_detail=vision_detail,
)
filtered_prompt_messages: list[PromptMessage] = []
for prompt_message in prompt_messages:
if isinstance(prompt_message.content, list):
prompt_message_content: list[PromptMessageContentUnionTypes] = []
for content_item in prompt_message.content:
if not model_schema.features:
if content_item.type == PromptMessageContentType.TEXT:
prompt_message_content.append(content_item)
continue
if (
(
content_item.type == PromptMessageContentType.IMAGE
and ModelFeature.VISION not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.DOCUMENT
and ModelFeature.DOCUMENT not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.VIDEO
and ModelFeature.VIDEO not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.AUDIO
and ModelFeature.AUDIO not in model_schema.features
)
):
continue
prompt_message_content.append(content_item)
if prompt_message_content:
prompt_message.content = prompt_message_content
filtered_prompt_messages.append(prompt_message)
elif not prompt_message.is_empty():
filtered_prompt_messages.append(prompt_message)
if len(filtered_prompt_messages) == 0:
raise NoPromptFoundError(
"No prompt found in the LLM configuration. Please ensure a prompt is properly configured before proceeding."
)
return filtered_prompt_messages, stop
def handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: str | None,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
template_renderer: TemplateRenderer | None = None,
) -> Sequence[PromptMessage]:
prompt_messages: list[PromptMessage] = []
for message in messages:
if message.edition_type == "jinja2":
result_text = render_jinja2_message(
template=message.jinja2_text or "",
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
template_renderer=template_renderer,
)
prompt_messages.append(
combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)],
role=message.role,
)
)
continue
template = message.text.replace("{#context#}", context) if context else message.text
segment_group = variable_pool.convert_template(template)
file_contents: list[PromptMessageContentUnionTypes] = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config)
)
elif isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config)
)
if segment_group.text:
prompt_messages.append(
combine_message_content_with_role(
contents=[TextPromptMessageContent(data=segment_group.text)],
role=message.role,
)
)
if file_contents:
prompt_messages.append(combine_message_content_with_role(contents=file_contents, role=message.role))
return prompt_messages
def render_jinja2_message(
*,
template: str,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
template_renderer: TemplateRenderer | None = None,
) -> str:
if not template:
return ""
if template_renderer is None:
raise ValueError("template_renderer is required for jinja2 prompt rendering")
jinja2_inputs: dict[str, Any] = {}
for jinja2_variable in jinja2_variables:
variable = variable_pool.get(jinja2_variable.value_selector)
jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
return template_renderer.render_jinja2(template=template, inputs=jinja2_inputs)
def handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,
context: str | None,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
template_renderer: TemplateRenderer | None = None,
) -> Sequence[PromptMessage]:
if template.edition_type == "jinja2":
result_text = render_jinja2_message(
template=template.jinja2_text or "",
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
template_renderer=template_renderer,
)
else:
template_text = template.text.replace("{#context#}", context) if context else template.text
result_text = variable_pool.convert_template(template_text).text
return [
combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)],
role=PromptMessageRole.USER,
)
]
def combine_message_content_with_role(
*,
contents: str | list[PromptMessageContentUnionTypes] | None = None,
role: PromptMessageRole,
) -> PromptMessage:
match role:
case PromptMessageRole.USER:
return UserPromptMessage(content=contents)
case PromptMessageRole.ASSISTANT:
return AssistantPromptMessage(content=contents)
case PromptMessageRole.SYSTEM:
return SystemPromptMessage(content=contents)
case _:
raise NotImplementedError(f"Role {role} is not supported")
def calculate_rest_token(*, prompt_messages: list[PromptMessage], model_instance: ModelInstance) -> int:
rest_tokens = 2000
runtime_model_schema = fetch_model_schema(model_instance=model_instance)
runtime_model_parameters = model_instance.parameters
model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in runtime_model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
runtime_model_parameters.get(parameter_rule.name)
or runtime_model_parameters.get(str(parameter_rule.use_template))
or 0
)
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
return rest_tokens
def handle_memory_chat_mode(
*,
memory: PromptMessageMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
) -> Sequence[PromptMessage]:
if not memory or not memory_config:
return []
rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance)
return memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
def handle_memory_completion_mode(
*,
memory: PromptMessageMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
) -> str:
if not memory or not memory_config:
return ""
rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
return fetch_memory_text(
memory=memory,
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
def _append_file_prompts(
*,
prompt_messages: list[PromptMessage],
files: Sequence[File],
vision_enabled: bool,
vision_detail: ImagePromptMessageContent.DETAIL,
) -> None:
if not vision_enabled or not files:
return
file_prompts = [file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in files]
if (
prompt_messages
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
existing_contents = prompt_messages[-1].content
assert isinstance(existing_contents, list)
prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))

View File

@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import select
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelInstance
@ -28,11 +27,10 @@ from dify_graph.enums import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.file import File, FileTransferMethod, FileType, file_manager
from dify_graph.file import File, FileTransferMethod, FileType
from dify_graph.model_runtime.entities import (
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
TextPromptMessageContent,
)
from dify_graph.model_runtime.entities.llm_entities import (
@ -43,14 +41,7 @@ from dify_graph.model_runtime.entities.llm_entities import (
LLMStructuredOutput,
LLMUsage,
)
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageContentUnionTypes,
PromptMessageRole,
SystemPromptMessage,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import (
@ -64,13 +55,12 @@ from dify_graph.node_events import (
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.runtime import VariablePool
from dify_graph.variables import (
ArrayFileSegment,
ArraySegment,
FileSegment,
NoneSegment,
ObjectSegment,
StringSegment,
@ -89,9 +79,6 @@ from .exc import (
InvalidContextStructureError,
InvalidVariableTypeError,
LLMNodeError,
MemoryRolePrefixRequiredError,
NoPromptFoundError,
TemplateTypeNotSupportError,
VariableNotFoundError,
)
from .file_saver import FileSaverImpl, LLMFileSaver
@ -118,6 +105,7 @@ class LLMNode(Node[LLMNodeData]):
_model_factory: ModelFactory
_model_instance: ModelInstance
_memory: PromptMessageMemory | None
_template_renderer: TemplateRenderer
def __init__(
self,
@ -130,6 +118,7 @@ class LLMNode(Node[LLMNodeData]):
model_factory: ModelFactory,
model_instance: ModelInstance,
http_client: HttpClientProtocol,
template_renderer: TemplateRenderer,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
):
@ -146,6 +135,7 @@ class LLMNode(Node[LLMNodeData]):
self._model_factory = model_factory
self._model_instance = model_instance
self._memory = memory
self._template_renderer = template_renderer
if llm_file_saver is None:
dify_ctx = self.require_dify_context()
@ -240,6 +230,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
context_files=context_files,
template_renderer=self._template_renderer,
)
# handle invoke result
@ -773,182 +764,24 @@ class LLMNode(Node[LLMNodeData]):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
context_files: list[File] | None = None,
template_renderer: TemplateRenderer | None = None,
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
if isinstance(prompt_template, list):
# For chat model
prompt_messages.extend(
LLMNode.handle_list_messages(
messages=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
vision_detail_config=vision_detail,
)
)
# Get memory messages for chat mode
memory_messages = _handle_memory_chat_mode(
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
)
# Extend prompt_messages with memory messages
prompt_messages.extend(memory_messages)
# Add current query to the prompt messages
if sys_query:
message = LLMNodeChatModelMessage(
text=sys_query,
role=PromptMessageRole.USER,
edition_type="basic",
)
prompt_messages.extend(
LLMNode.handle_list_messages(
messages=[message],
context="",
jinja2_variables=[],
variable_pool=variable_pool,
vision_detail_config=vision_detail,
)
)
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
# For completion model
prompt_messages.extend(
_handle_completion_template(
template=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
)
# Get memory text for completion model
memory_text = _handle_memory_completion_mode(
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
)
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
# For issue #11247 - Check if prompt content is a string or a list
if isinstance(prompt_content, str):
prompt_content = str(prompt_content)
if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text)
else:
prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content
elif isinstance(prompt_content, list):
for content_item in prompt_content:
if isinstance(content_item, TextPromptMessageContent):
if "#histories#" in content_item.data:
content_item.data = content_item.data.replace("#histories#", memory_text)
else:
content_item.data = memory_text + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
# Add current query to the prompt message
if sys_query:
if isinstance(prompt_content, str):
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
prompt_messages[0].content = prompt_content
elif isinstance(prompt_content, list):
for content_item in prompt_content:
if isinstance(content_item, TextPromptMessageContent):
content_item.data = sys_query + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
else:
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
# The sys_files will be deprecated later
if vision_enabled and sys_files:
file_prompts = []
for file in sys_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
# If last prompt is a user prompt, add files into its contents,
# otherwise append a new user prompt
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# The context_files
if vision_enabled and context_files:
file_prompts = []
for file in context_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
# If last prompt is a user prompt, add files into its contents,
# otherwise append a new user prompt
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# Remove empty messages and filter unsupported content
filtered_prompt_messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message.content, list):
prompt_message_content: list[PromptMessageContentUnionTypes] = []
for content_item in prompt_message.content:
# Skip content if features are not defined
if not model_schema.features:
if content_item.type != PromptMessageContentType.TEXT:
continue
prompt_message_content.append(content_item)
continue
# Skip content if corresponding feature is not supported
if (
(
content_item.type == PromptMessageContentType.IMAGE
and ModelFeature.VISION not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.DOCUMENT
and ModelFeature.DOCUMENT not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.VIDEO
and ModelFeature.VIDEO not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.AUDIO
and ModelFeature.AUDIO not in model_schema.features
)
):
continue
prompt_message_content.append(content_item)
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
prompt_message.content = prompt_message_content[0].data
else:
prompt_message.content = prompt_message_content
if prompt_message.is_empty():
continue
filtered_prompt_messages.append(prompt_message)
if len(filtered_prompt_messages) == 0:
raise NoPromptFoundError(
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
return filtered_prompt_messages, stop
return llm_utils.fetch_prompt_messages(
sys_query=sys_query,
sys_files=sys_files,
context=context,
memory=memory,
model_instance=model_instance,
prompt_template=prompt_template,
stop=stop,
memory_config=memory_config,
vision_enabled=vision_enabled,
vision_detail=vision_detail,
variable_pool=variable_pool,
jinja2_variables=jinja2_variables,
context_files=context_files,
template_renderer=template_renderer,
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
@ -1048,59 +881,16 @@ class LLMNode(Node[LLMNodeData]):
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
template_renderer: TemplateRenderer | None = None,
) -> Sequence[PromptMessage]:
prompt_messages: list[PromptMessage] = []
for message in messages:
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)], role=message.role
)
prompt_messages.append(prompt_message)
else:
# Get segment group from basic message
if context:
template = message.text.replace("{#context#}", context)
else:
template = message.text
segment_group = variable_pool.convert_template(template)
# Process segments for images
file_contents = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
elif isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
# Create message with text from all segments
plain_text = segment_group.text
if plain_text:
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=plain_text)], role=message.role
)
prompt_messages.append(prompt_message)
if file_contents:
# Create message with image contents
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
prompt_messages.append(prompt_message)
return prompt_messages
return llm_utils.handle_list_messages(
messages=messages,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
vision_detail_config=vision_detail_config,
template_renderer=template_renderer,
)
@staticmethod
def handle_blocking_result(
@ -1239,152 +1029,3 @@ class LLMNode(Node[LLMNodeData]):
@property
def model_instance(self) -> ModelInstance:
return self._model_instance
def _combine_message_content_with_role(
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
):
match role:
case PromptMessageRole.USER:
return UserPromptMessage(content=contents)
case PromptMessageRole.ASSISTANT:
return AssistantPromptMessage(content=contents)
case PromptMessageRole.SYSTEM:
return SystemPromptMessage(content=contents)
case _:
raise NotImplementedError(f"Role {role} is not supported")
def _render_jinja2_message(
*,
template: str,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
):
if not template:
return ""
jinja2_inputs = {}
for jinja2_variable in jinja2_variables:
variable = variable_pool.get(jinja2_variable.value_selector)
jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
code_execute_resp = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2,
code=template,
inputs=jinja2_inputs,
)
result_text = code_execute_resp["result"]
return result_text
def _calculate_rest_token(
*,
prompt_messages: list[PromptMessage],
model_instance: ModelInstance,
) -> int:
rest_tokens = 2000
runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
runtime_model_parameters = model_instance.parameters
model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in runtime_model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
runtime_model_parameters.get(parameter_rule.name)
or runtime_model_parameters.get(str(parameter_rule.use_template))
or 0
)
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
return rest_tokens
def _handle_memory_chat_mode(
*,
memory: PromptMessageMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
) -> Sequence[PromptMessage]:
memory_messages: Sequence[PromptMessage] = []
# Get messages from memory for chat model
if memory and memory_config:
rest_tokens = _calculate_rest_token(
prompt_messages=[],
model_instance=model_instance,
)
memory_messages = memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
return memory_messages
def _handle_memory_completion_mode(
*,
memory: PromptMessageMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
) -> str:
memory_text = ""
# Get history text from memory for completion model
if memory and memory_config:
rest_tokens = _calculate_rest_token(
prompt_messages=[],
model_instance=model_instance,
)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_text = llm_utils.fetch_memory_text(
memory=memory,
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
return memory_text
def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,
context: str | None,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
) -> Sequence[PromptMessage]:
"""Handle completion template processing outside of LLMNode class.
Args:
template: The completion model prompt template
context: Optional context string
jinja2_variables: Variables for jinja2 template rendering
variable_pool: Variable pool for template conversion
Returns:
Sequence of prompt messages
"""
prompt_messages = []
if template.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=template.jinja2_text or "",
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
else:
if context:
template_text = template.text.replace("{#context#}", context)
else:
template_text = template.text
result_text = variable_pool.convert_template(template_text).text
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER
)
prompt_messages.append(prompt_message)
return prompt_messages

View File

@ -1,5 +1,6 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Protocol
from core.model_manager import ModelInstance
@ -19,3 +20,11 @@ class ModelFactory(Protocol):
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
"""Create a model instance that is ready for schema lookup and invocation."""
...
class TemplateRenderer(Protocol):
"""Port for rendering prompt templates used by LLM-compatible nodes."""
def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str:
"""Render the given Jinja2 template into plain text."""
...

View File

@ -28,7 +28,7 @@ from dify_graph.nodes.llm import (
llm_utils,
)
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.protocols import HttpClientProtocol
from libs.json_in_md_parser import parse_and_check_json_markdown
@ -59,6 +59,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
_model_factory: "ModelFactory"
_model_instance: ModelInstance
_memory: PromptMessageMemory | None
_template_renderer: TemplateRenderer
def __init__(
self,
@ -71,6 +72,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
model_factory: "ModelFactory",
model_instance: ModelInstance,
http_client: HttpClientProtocol,
template_renderer: TemplateRenderer,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
):
@ -87,6 +89,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self._model_factory = model_factory
self._model_instance = model_instance
self._memory = memory
self._template_renderer = template_renderer
if llm_file_saver is None:
dify_ctx = self.require_dify_context()
@ -142,7 +145,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
# If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
# two consecutive user prompts will be generated, causing model's error.
# To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_messages, stop = llm_utils.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query="",
memory=memory,
@ -153,6 +156,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=[],
template_renderer=self._template_renderer,
)
result_text = ""
@ -287,7 +291,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
prompt_messages, _ = LLMNode.fetch_prompt_messages(
prompt_messages, _ = llm_utils.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query="",
sys_files=[],
@ -300,6 +304,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[],
template_renderer=self._template_renderer,
)
rest_tokens = 2000

View File

@ -10,7 +10,7 @@ from core.model_manager import ModelInstance
from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.node_events import StreamCompletedEvent
from dify_graph.nodes.llm.node import LLMNode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
@ -75,6 +75,7 @@ def init_llm_node(config: dict) -> LLMNode:
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(spec=ModelInstance),
template_renderer=MagicMock(spec=TemplateRenderer),
http_client=MagicMock(spec=HttpClientProtocol),
)
@ -158,7 +159,7 @@ def test_execute_llm():
return mock_model_instance
# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_1(**_kwargs):
def mock_fetch_prompt_messages_1(*_args, **_kwargs):
from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
return [

View File

@ -20,7 +20,7 @@ from dify_graph.nodes.code import CodeNode
from dify_graph.nodes.document_extractor import DocumentExtractorNode
from dify_graph.nodes.http_request import HttpRequestNode
from dify_graph.nodes.llm import LLMNode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.parameter_extractor import ParameterExtractorNode
from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol
from dify_graph.nodes.question_classifier import QuestionClassifierNode
@ -68,6 +68,8 @@ class MockNodeMixin:
kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))
# LLM-like nodes now require an http_client; provide a mock by default for tests.
kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol))
if isinstance(self, (LLMNode, QuestionClassifierNode)):
kwargs.setdefault("template_renderer", MagicMock(spec=TemplateRenderer))
# Ensure TemplateTransformNode receives a renderer now required by constructor
if isinstance(self, TemplateTransformNode):

View File

@ -34,8 +34,8 @@ from dify_graph.nodes.llm.entities import (
VisionConfigOptions,
)
from dify_graph.nodes.llm.file_saver import LLMFileSaver
from dify_graph.nodes.llm.node import LLMNode, _handle_memory_completion_mode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.node import LLMNode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
@ -107,6 +107,7 @@ def llm_node(
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
mock_template_renderer = mock.MagicMock(spec=TemplateRenderer)
node_config = {
"id": "1",
"data": llm_node_data.model_dump(),
@ -121,6 +122,7 @@ def llm_node(
model_factory=mock_model_factory,
model_instance=mock.MagicMock(spec=ModelInstance),
llm_file_saver=mock_file_saver,
template_renderer=mock_template_renderer,
http_client=http_client,
)
return node
@ -590,6 +592,33 @@ def test_handle_list_messages_basic(llm_node):
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]
def test_handle_list_messages_jinja2_uses_template_renderer(llm_node):
llm_node._template_renderer.render_jinja2.return_value = "Hello, world"
messages = [
LLMNodeChatModelMessage(
text="",
jinja2_text="Hello, {{ name }}",
role=PromptMessageRole.USER,
edition_type="jinja2",
)
]
result = llm_node.handle_list_messages(
messages=messages,
context=None,
jinja2_variables=[],
variable_pool=llm_node.graph_runtime_state.variable_pool,
vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
template_renderer=llm_node._template_renderer,
)
assert result == [UserPromptMessage(content=[TextPromptMessageContent(data="Hello, world")])]
llm_node._template_renderer.render_jinja2.assert_called_once_with(
template="Hello, {{ name }}",
inputs={},
)
def test_handle_memory_completion_mode_uses_prompt_message_interface():
memory = mock.MagicMock(spec=MockTokenBufferMemory)
memory.get_history_prompt_messages.return_value = [
@ -613,8 +642,8 @@ def test_handle_memory_completion_mode_uses_prompt_message_interface():
window=MemoryConfig.WindowConfig(enabled=True, size=3),
)
with mock.patch("dify_graph.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token:
memory_text = _handle_memory_completion_mode(
with mock.patch("dify_graph.nodes.llm.llm_utils.calculate_rest_token", return_value=2000) as mock_rest_token:
memory_text = llm_utils.handle_memory_completion_mode(
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
@ -630,6 +659,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
mock_template_renderer = mock.MagicMock(spec=TemplateRenderer)
node_config = {
"id": "1",
"data": llm_node_data.model_dump(),
@ -644,6 +674,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
model_factory=mock_model_factory,
model_instance=mock.MagicMock(spec=ModelInstance),
llm_file_saver=mock_file_saver,
template_renderer=mock_template_renderer,
http_client=http_client,
)
return node, mock_file_saver

View File

@ -1,5 +1,14 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
from dify_graph.model_runtime.entities import ImagePromptMessageContent
from dify_graph.nodes.question_classifier import QuestionClassifierNodeData
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.nodes.question_classifier import (
QuestionClassifierNode,
QuestionClassifierNodeData,
)
from tests.workflow_test_utils import build_test_graph_init_params
def test_init_question_classifier_node_data():
@ -65,3 +74,52 @@ def test_init_question_classifier_node_data_without_vision_config():
assert node_data.vision.enabled == False
assert node_data.vision.configs.variable_selector == ["sys", "files"]
assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH
def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(monkeypatch):
node_data = QuestionClassifierNodeData.model_validate(
{
"title": "test classifier node",
"query_variable_selector": ["id", "name"],
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
"classes": [{"id": "1", "name": "class 1"}],
"instruction": "This is a test instruction",
}
)
template_renderer = MagicMock(spec=TemplateRenderer)
node = QuestionClassifierNode(
id="node-id",
config={"id": "node-id", "data": node_data.model_dump(mode="json")},
graph_init_params=build_test_graph_init_params(
workflow_id="workflow-id",
graph_config={},
tenant_id="tenant-id",
app_id="app-id",
user_id="user-id",
),
graph_runtime_state=SimpleNamespace(variable_pool=MagicMock()),
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(),
http_client=MagicMock(spec=HttpClientProtocol),
llm_file_saver=MagicMock(),
template_renderer=template_renderer,
)
fetch_prompt_messages = MagicMock(return_value=([], None))
monkeypatch.setattr(
"dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_prompt_messages",
fetch_prompt_messages,
)
monkeypatch.setattr(
"dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_model_schema",
MagicMock(return_value=SimpleNamespace(model_properties={}, parameter_rules=[])),
)
node._calculate_rest_token(
node_data=node_data,
query="hello",
model_instance=MagicMock(stop=(), parameters={}),
context="",
)
assert fetch_prompt_messages.call_args.kwargs["template_renderer"] is template_renderer

View File

@ -140,6 +140,29 @@ class TestDefaultWorkflowCodeExecutor:
assert executor.is_execution_error(RuntimeError("boom")) is False
class TestDefaultLLMTemplateRenderer:
def test_render_jinja2_delegates_to_code_executor(self, monkeypatch):
renderer = node_factory.DefaultLLMTemplateRenderer()
execute_workflow_code_template = MagicMock(return_value={"result": "hello world"})
monkeypatch.setattr(
node_factory.CodeExecutor,
"execute_workflow_code_template",
execute_workflow_code_template,
)
result = renderer.render_jinja2(
template="Hello {{ name }}",
inputs={"name": "world"},
)
assert result == "hello world"
execute_workflow_code_template.assert_called_once_with(
language=CodeLanguage.JINJA2,
code="Hello {{ name }}",
inputs={"name": "world"},
)
class TestDifyNodeFactoryInit:
def test_init_builds_default_dependencies(self):
graph_init_params = SimpleNamespace(run_context={"context": "value"})
@ -150,6 +173,7 @@ class TestDifyNodeFactoryInit:
http_request_config = sentinel.http_request_config
credentials_provider = sentinel.credentials_provider
model_factory = sentinel.model_factory
llm_template_renderer = sentinel.llm_template_renderer
with (
patch.object(
@ -172,6 +196,11 @@ class TestDifyNodeFactoryInit:
"build_http_request_config",
return_value=http_request_config,
),
patch.object(
node_factory,
"DefaultLLMTemplateRenderer",
return_value=llm_template_renderer,
) as llm_renderer_factory,
patch.object(
node_factory,
"build_dify_model_access",
@ -186,11 +215,14 @@ class TestDifyNodeFactoryInit:
resolve_dify_context.assert_called_once_with(graph_init_params.run_context)
build_dify_model_access.assert_called_once_with("tenant-id")
renderer_factory.assert_called_once()
llm_renderer_factory.assert_called_once()
assert renderer_factory.call_args.kwargs["code_executor"] is factory._code_executor
assert factory.graph_init_params is graph_init_params
assert factory.graph_runtime_state is graph_runtime_state
assert factory._dify_context is dify_context
assert factory._template_renderer is template_renderer
assert factory._llm_template_renderer is llm_template_renderer
assert factory._document_extractor_unstructured_api_config is unstructured_api_config
assert factory._http_request_config is http_request_config
assert factory._llm_credentials_provider is credentials_provider
@ -242,6 +274,7 @@ class TestDifyNodeFactoryCreateNode:
factory._code_executor = sentinel.code_executor
factory._code_limits = sentinel.code_limits
factory._template_renderer = sentinel.template_renderer
factory._llm_template_renderer = sentinel.llm_template_renderer
factory._template_transform_max_output_length = 2048
factory._http_request_http_client = sentinel.http_client
factory._http_request_tool_file_manager_factory = sentinel.tool_file_manager_factory
@ -378,8 +411,22 @@ class TestDifyNodeFactoryCreateNode:
@pytest.mark.parametrize(
("node_type", "constructor_name", "expected_extra_kwargs"),
[
(BuiltinNodeTypes.LLM, "LLMNode", {"http_client": sentinel.http_client}),
(BuiltinNodeTypes.QUESTION_CLASSIFIER, "QuestionClassifierNode", {"http_client": sentinel.http_client}),
(
BuiltinNodeTypes.LLM,
"LLMNode",
{
"http_client": sentinel.http_client,
"template_renderer": sentinel.llm_template_renderer,
},
),
(
BuiltinNodeTypes.QUESTION_CLASSIFIER,
"QuestionClassifierNode",
{
"http_client": sentinel.http_client,
"template_renderer": sentinel.llm_template_renderer,
},
),
(BuiltinNodeTypes.PARAMETER_EXTRACTOR, "ParameterExtractorNode", {}),
],
)