mirror of https://github.com/langgenius/dify.git
refactor: llm decouple code executor module (#33400)
Co-authored-by: Byron.wang <byron@dify.ai>
This commit is contained in:
parent
a6163f80d1
commit
6ef69ff880
|
|
@ -103,7 +103,6 @@ ignore_imports =
|
||||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||||
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
|
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||||
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||||
dify_graph.nodes.llm.node -> core.helper.code_executor
|
|
||||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
|
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
|
||||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
|
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
|
||||||
dify_graph.nodes.llm.node -> core.model_manager
|
dify_graph.nodes.llm.node -> core.model_manager
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,7 @@ from dify_graph.nodes.document_extractor import UnstructuredApiConfig
|
||||||
from dify_graph.nodes.http_request import build_http_request_config
|
from dify_graph.nodes.http_request import build_http_request_config
|
||||||
from dify_graph.nodes.llm.entities import LLMNodeData
|
from dify_graph.nodes.llm.entities import LLMNodeData
|
||||||
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
||||||
|
from dify_graph.nodes.llm.protocols import TemplateRenderer
|
||||||
from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||||
from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
|
from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||||
from dify_graph.nodes.template_transform.template_renderer import (
|
from dify_graph.nodes.template_transform.template_renderer import (
|
||||||
|
|
@ -228,6 +229,16 @@ class DefaultWorkflowCodeExecutor:
|
||||||
return isinstance(error, CodeExecutionError)
|
return isinstance(error, CodeExecutionError)
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultLLMTemplateRenderer(TemplateRenderer):
|
||||||
|
def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str:
|
||||||
|
result = CodeExecutor.execute_workflow_code_template(
|
||||||
|
language=CodeLanguage.JINJA2,
|
||||||
|
code=template,
|
||||||
|
inputs=inputs,
|
||||||
|
)
|
||||||
|
return str(result.get("result", ""))
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
class DifyNodeFactory(NodeFactory):
|
class DifyNodeFactory(NodeFactory):
|
||||||
"""
|
"""
|
||||||
|
|
@ -254,6 +265,7 @@ class DifyNodeFactory(NodeFactory):
|
||||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||||
)
|
)
|
||||||
self._template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor)
|
self._template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor)
|
||||||
|
self._llm_template_renderer: TemplateRenderer = DefaultLLMTemplateRenderer()
|
||||||
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||||
self._http_request_http_client = ssrf_proxy
|
self._http_request_http_client = ssrf_proxy
|
||||||
self._http_request_tool_file_manager_factory = ToolFileManager
|
self._http_request_tool_file_manager_factory = ToolFileManager
|
||||||
|
|
@ -391,6 +403,8 @@ class DifyNodeFactory(NodeFactory):
|
||||||
model_instance=model_instance,
|
model_instance=model_instance,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
if validated_node_data.type in {BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
|
||||||
|
node_init_kwargs["template_renderer"] = self._llm_template_renderer
|
||||||
if include_http_client:
|
if include_http_client:
|
||||||
node_init_kwargs["http_client"] = self._http_request_http_client
|
node_init_kwargs["http_client"] = self._http_request_http_client
|
||||||
return node_init_kwargs
|
return node_init_kwargs
|
||||||
|
|
|
||||||
|
|
@ -1,34 +1,53 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
|
from dify_graph.file import FileType, file_manager
|
||||||
from dify_graph.file.models import File
|
from dify_graph.file.models import File
|
||||||
from dify_graph.model_runtime.entities import PromptMessageRole
|
from dify_graph.model_runtime.entities import (
|
||||||
from dify_graph.model_runtime.entities.message_entities import (
|
|
||||||
ImagePromptMessageContent,
|
ImagePromptMessageContent,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
|
PromptMessageContentType,
|
||||||
|
PromptMessageRole,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
)
|
)
|
||||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
|
from dify_graph.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
PromptMessageContentUnionTypes,
|
||||||
|
SystemPromptMessage,
|
||||||
|
UserPromptMessage,
|
||||||
|
)
|
||||||
|
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey
|
||||||
from dify_graph.model_runtime.memory import PromptMessageMemory
|
from dify_graph.model_runtime.memory import PromptMessageMemory
|
||||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
|
from dify_graph.nodes.base.entities import VariableSelector
|
||||||
from dify_graph.runtime import VariablePool
|
from dify_graph.runtime import VariablePool
|
||||||
from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
|
from dify_graph.variables import ArrayFileSegment, FileSegment
|
||||||
|
from dify_graph.variables.segments import ArrayAnySegment, NoneSegment
|
||||||
|
|
||||||
from .exc import InvalidVariableTypeError
|
from .entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, MemoryConfig
|
||||||
|
from .exc import (
|
||||||
|
InvalidVariableTypeError,
|
||||||
|
MemoryRolePrefixRequiredError,
|
||||||
|
NoPromptFoundError,
|
||||||
|
TemplateTypeNotSupportError,
|
||||||
|
)
|
||||||
|
from .protocols import TemplateRenderer
|
||||||
|
|
||||||
|
|
||||||
def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
|
def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
|
||||||
model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
|
model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
|
||||||
model_instance.model_name,
|
model_instance.model_name,
|
||||||
model_instance.credentials,
|
dict(model_instance.credentials),
|
||||||
)
|
)
|
||||||
if not model_schema:
|
if not model_schema:
|
||||||
raise ValueError(f"Model schema not found for {model_instance.model_name}")
|
raise ValueError(f"Model schema not found for {model_instance.model_name}")
|
||||||
return model_schema
|
return model_schema
|
||||||
|
|
||||||
|
|
||||||
def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
|
def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence[File]:
|
||||||
variable = variable_pool.get(selector)
|
variable = variable_pool.get(selector)
|
||||||
if variable is None:
|
if variable is None:
|
||||||
return []
|
return []
|
||||||
|
|
@ -89,3 +108,366 @@ def fetch_memory_text(
|
||||||
human_prefix=human_prefix,
|
human_prefix=human_prefix,
|
||||||
ai_prefix=ai_prefix,
|
ai_prefix=ai_prefix,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_prompt_messages(
|
||||||
|
*,
|
||||||
|
sys_query: str | None = None,
|
||||||
|
sys_files: Sequence[File],
|
||||||
|
context: str | None = None,
|
||||||
|
memory: PromptMessageMemory | None = None,
|
||||||
|
model_instance: ModelInstance,
|
||||||
|
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
|
||||||
|
stop: Sequence[str] | None = None,
|
||||||
|
memory_config: MemoryConfig | None = None,
|
||||||
|
vision_enabled: bool = False,
|
||||||
|
vision_detail: ImagePromptMessageContent.DETAIL,
|
||||||
|
variable_pool: VariablePool,
|
||||||
|
jinja2_variables: Sequence[VariableSelector],
|
||||||
|
context_files: list[File] | None = None,
|
||||||
|
template_renderer: TemplateRenderer | None = None,
|
||||||
|
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
|
||||||
|
prompt_messages: list[PromptMessage] = []
|
||||||
|
model_schema = fetch_model_schema(model_instance=model_instance)
|
||||||
|
|
||||||
|
if isinstance(prompt_template, list):
|
||||||
|
prompt_messages.extend(
|
||||||
|
handle_list_messages(
|
||||||
|
messages=prompt_template,
|
||||||
|
context=context,
|
||||||
|
jinja2_variables=jinja2_variables,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
vision_detail_config=vision_detail,
|
||||||
|
template_renderer=template_renderer,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_messages.extend(
|
||||||
|
handle_memory_chat_mode(
|
||||||
|
memory=memory,
|
||||||
|
memory_config=memory_config,
|
||||||
|
model_instance=model_instance,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if sys_query:
|
||||||
|
prompt_messages.extend(
|
||||||
|
handle_list_messages(
|
||||||
|
messages=[
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text=sys_query,
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
edition_type="basic",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
context="",
|
||||||
|
jinja2_variables=[],
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
vision_detail_config=vision_detail,
|
||||||
|
template_renderer=template_renderer,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
||||||
|
prompt_messages.extend(
|
||||||
|
handle_completion_template(
|
||||||
|
template=prompt_template,
|
||||||
|
context=context,
|
||||||
|
jinja2_variables=jinja2_variables,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
template_renderer=template_renderer,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
memory_text = handle_memory_completion_mode(
|
||||||
|
memory=memory,
|
||||||
|
memory_config=memory_config,
|
||||||
|
model_instance=model_instance,
|
||||||
|
)
|
||||||
|
prompt_content = prompt_messages[0].content
|
||||||
|
if isinstance(prompt_content, str):
|
||||||
|
prompt_content = str(prompt_content)
|
||||||
|
if "#histories#" in prompt_content:
|
||||||
|
prompt_content = prompt_content.replace("#histories#", memory_text)
|
||||||
|
else:
|
||||||
|
prompt_content = memory_text + "\n" + prompt_content
|
||||||
|
prompt_messages[0].content = prompt_content
|
||||||
|
elif isinstance(prompt_content, list):
|
||||||
|
for content_item in prompt_content:
|
||||||
|
if isinstance(content_item, TextPromptMessageContent):
|
||||||
|
if "#histories#" in content_item.data:
|
||||||
|
content_item.data = content_item.data.replace("#histories#", memory_text)
|
||||||
|
else:
|
||||||
|
content_item.data = memory_text + "\n" + content_item.data
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid prompt content type")
|
||||||
|
|
||||||
|
if sys_query:
|
||||||
|
if isinstance(prompt_content, str):
|
||||||
|
prompt_messages[0].content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
|
||||||
|
elif isinstance(prompt_content, list):
|
||||||
|
for content_item in prompt_content:
|
||||||
|
if isinstance(content_item, TextPromptMessageContent):
|
||||||
|
content_item.data = sys_query + "\n" + content_item.data
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid prompt content type")
|
||||||
|
else:
|
||||||
|
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
|
||||||
|
|
||||||
|
_append_file_prompts(
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
|
files=sys_files,
|
||||||
|
vision_enabled=vision_enabled,
|
||||||
|
vision_detail=vision_detail,
|
||||||
|
)
|
||||||
|
_append_file_prompts(
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
|
files=context_files or [],
|
||||||
|
vision_enabled=vision_enabled,
|
||||||
|
vision_detail=vision_detail,
|
||||||
|
)
|
||||||
|
|
||||||
|
filtered_prompt_messages: list[PromptMessage] = []
|
||||||
|
for prompt_message in prompt_messages:
|
||||||
|
if isinstance(prompt_message.content, list):
|
||||||
|
prompt_message_content: list[PromptMessageContentUnionTypes] = []
|
||||||
|
for content_item in prompt_message.content:
|
||||||
|
if not model_schema.features:
|
||||||
|
if content_item.type == PromptMessageContentType.TEXT:
|
||||||
|
prompt_message_content.append(content_item)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if (
|
||||||
|
(
|
||||||
|
content_item.type == PromptMessageContentType.IMAGE
|
||||||
|
and ModelFeature.VISION not in model_schema.features
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
content_item.type == PromptMessageContentType.DOCUMENT
|
||||||
|
and ModelFeature.DOCUMENT not in model_schema.features
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
content_item.type == PromptMessageContentType.VIDEO
|
||||||
|
and ModelFeature.VIDEO not in model_schema.features
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
content_item.type == PromptMessageContentType.AUDIO
|
||||||
|
and ModelFeature.AUDIO not in model_schema.features
|
||||||
|
)
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
prompt_message_content.append(content_item)
|
||||||
|
if prompt_message_content:
|
||||||
|
prompt_message.content = prompt_message_content
|
||||||
|
filtered_prompt_messages.append(prompt_message)
|
||||||
|
elif not prompt_message.is_empty():
|
||||||
|
filtered_prompt_messages.append(prompt_message)
|
||||||
|
|
||||||
|
if len(filtered_prompt_messages) == 0:
|
||||||
|
raise NoPromptFoundError(
|
||||||
|
"No prompt found in the LLM configuration. Please ensure a prompt is properly configured before proceeding."
|
||||||
|
)
|
||||||
|
|
||||||
|
return filtered_prompt_messages, stop
|
||||||
|
|
||||||
|
|
||||||
|
def handle_list_messages(
|
||||||
|
*,
|
||||||
|
messages: Sequence[LLMNodeChatModelMessage],
|
||||||
|
context: str | None,
|
||||||
|
jinja2_variables: Sequence[VariableSelector],
|
||||||
|
variable_pool: VariablePool,
|
||||||
|
vision_detail_config: ImagePromptMessageContent.DETAIL,
|
||||||
|
template_renderer: TemplateRenderer | None = None,
|
||||||
|
) -> Sequence[PromptMessage]:
|
||||||
|
prompt_messages: list[PromptMessage] = []
|
||||||
|
for message in messages:
|
||||||
|
if message.edition_type == "jinja2":
|
||||||
|
result_text = render_jinja2_message(
|
||||||
|
template=message.jinja2_text or "",
|
||||||
|
jinja2_variables=jinja2_variables,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
template_renderer=template_renderer,
|
||||||
|
)
|
||||||
|
prompt_messages.append(
|
||||||
|
combine_message_content_with_role(
|
||||||
|
contents=[TextPromptMessageContent(data=result_text)],
|
||||||
|
role=message.role,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
template = message.text.replace("{#context#}", context) if context else message.text
|
||||||
|
segment_group = variable_pool.convert_template(template)
|
||||||
|
file_contents: list[PromptMessageContentUnionTypes] = []
|
||||||
|
for segment in segment_group.value:
|
||||||
|
if isinstance(segment, ArrayFileSegment):
|
||||||
|
for file in segment.value:
|
||||||
|
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
|
||||||
|
file_contents.append(
|
||||||
|
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config)
|
||||||
|
)
|
||||||
|
elif isinstance(segment, FileSegment):
|
||||||
|
file = segment.value
|
||||||
|
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
|
||||||
|
file_contents.append(
|
||||||
|
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config)
|
||||||
|
)
|
||||||
|
|
||||||
|
if segment_group.text:
|
||||||
|
prompt_messages.append(
|
||||||
|
combine_message_content_with_role(
|
||||||
|
contents=[TextPromptMessageContent(data=segment_group.text)],
|
||||||
|
role=message.role,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if file_contents:
|
||||||
|
prompt_messages.append(combine_message_content_with_role(contents=file_contents, role=message.role))
|
||||||
|
|
||||||
|
return prompt_messages
|
||||||
|
|
||||||
|
|
||||||
|
def render_jinja2_message(
|
||||||
|
*,
|
||||||
|
template: str,
|
||||||
|
jinja2_variables: Sequence[VariableSelector],
|
||||||
|
variable_pool: VariablePool,
|
||||||
|
template_renderer: TemplateRenderer | None = None,
|
||||||
|
) -> str:
|
||||||
|
if not template:
|
||||||
|
return ""
|
||||||
|
if template_renderer is None:
|
||||||
|
raise ValueError("template_renderer is required for jinja2 prompt rendering")
|
||||||
|
|
||||||
|
jinja2_inputs: dict[str, Any] = {}
|
||||||
|
for jinja2_variable in jinja2_variables:
|
||||||
|
variable = variable_pool.get(jinja2_variable.value_selector)
|
||||||
|
jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
|
||||||
|
return template_renderer.render_jinja2(template=template, inputs=jinja2_inputs)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_completion_template(
|
||||||
|
*,
|
||||||
|
template: LLMNodeCompletionModelPromptTemplate,
|
||||||
|
context: str | None,
|
||||||
|
jinja2_variables: Sequence[VariableSelector],
|
||||||
|
variable_pool: VariablePool,
|
||||||
|
template_renderer: TemplateRenderer | None = None,
|
||||||
|
) -> Sequence[PromptMessage]:
|
||||||
|
if template.edition_type == "jinja2":
|
||||||
|
result_text = render_jinja2_message(
|
||||||
|
template=template.jinja2_text or "",
|
||||||
|
jinja2_variables=jinja2_variables,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
template_renderer=template_renderer,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
template_text = template.text.replace("{#context#}", context) if context else template.text
|
||||||
|
result_text = variable_pool.convert_template(template_text).text
|
||||||
|
return [
|
||||||
|
combine_message_content_with_role(
|
||||||
|
contents=[TextPromptMessageContent(data=result_text)],
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def combine_message_content_with_role(
|
||||||
|
*,
|
||||||
|
contents: str | list[PromptMessageContentUnionTypes] | None = None,
|
||||||
|
role: PromptMessageRole,
|
||||||
|
) -> PromptMessage:
|
||||||
|
match role:
|
||||||
|
case PromptMessageRole.USER:
|
||||||
|
return UserPromptMessage(content=contents)
|
||||||
|
case PromptMessageRole.ASSISTANT:
|
||||||
|
return AssistantPromptMessage(content=contents)
|
||||||
|
case PromptMessageRole.SYSTEM:
|
||||||
|
return SystemPromptMessage(content=contents)
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError(f"Role {role} is not supported")
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_rest_token(*, prompt_messages: list[PromptMessage], model_instance: ModelInstance) -> int:
|
||||||
|
rest_tokens = 2000
|
||||||
|
runtime_model_schema = fetch_model_schema(model_instance=model_instance)
|
||||||
|
runtime_model_parameters = model_instance.parameters
|
||||||
|
|
||||||
|
model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||||
|
if model_context_tokens:
|
||||||
|
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||||
|
|
||||||
|
max_tokens = 0
|
||||||
|
for parameter_rule in runtime_model_schema.parameter_rules:
|
||||||
|
if parameter_rule.name == "max_tokens" or (
|
||||||
|
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
|
||||||
|
):
|
||||||
|
max_tokens = (
|
||||||
|
runtime_model_parameters.get(parameter_rule.name)
|
||||||
|
or runtime_model_parameters.get(str(parameter_rule.use_template))
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
|
||||||
|
rest_tokens = max(rest_tokens, 0)
|
||||||
|
|
||||||
|
return rest_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def handle_memory_chat_mode(
|
||||||
|
*,
|
||||||
|
memory: PromptMessageMemory | None,
|
||||||
|
memory_config: MemoryConfig | None,
|
||||||
|
model_instance: ModelInstance,
|
||||||
|
) -> Sequence[PromptMessage]:
|
||||||
|
if not memory or not memory_config:
|
||||||
|
return []
|
||||||
|
rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance)
|
||||||
|
return memory.get_history_prompt_messages(
|
||||||
|
max_token_limit=rest_tokens,
|
||||||
|
message_limit=memory_config.window.size if memory_config.window.enabled else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_memory_completion_mode(
|
||||||
|
*,
|
||||||
|
memory: PromptMessageMemory | None,
|
||||||
|
memory_config: MemoryConfig | None,
|
||||||
|
model_instance: ModelInstance,
|
||||||
|
) -> str:
|
||||||
|
if not memory or not memory_config:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance)
|
||||||
|
if not memory_config.role_prefix:
|
||||||
|
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
|
||||||
|
|
||||||
|
return fetch_memory_text(
|
||||||
|
memory=memory,
|
||||||
|
max_token_limit=rest_tokens,
|
||||||
|
message_limit=memory_config.window.size if memory_config.window.enabled else None,
|
||||||
|
human_prefix=memory_config.role_prefix.user,
|
||||||
|
ai_prefix=memory_config.role_prefix.assistant,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _append_file_prompts(
|
||||||
|
*,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
files: Sequence[File],
|
||||||
|
vision_enabled: bool,
|
||||||
|
vision_detail: ImagePromptMessageContent.DETAIL,
|
||||||
|
) -> None:
|
||||||
|
if not vision_enabled or not files:
|
||||||
|
return
|
||||||
|
|
||||||
|
file_prompts = [file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in files]
|
||||||
|
if (
|
||||||
|
prompt_messages
|
||||||
|
and isinstance(prompt_messages[-1], UserPromptMessage)
|
||||||
|
and isinstance(prompt_messages[-1].content, list)
|
||||||
|
):
|
||||||
|
existing_contents = prompt_messages[-1].content
|
||||||
|
assert isinstance(existing_contents, list)
|
||||||
|
prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents)
|
||||||
|
else:
|
||||||
|
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
|
||||||
from core.llm_generator.output_parser.errors import OutputParserError
|
from core.llm_generator.output_parser.errors import OutputParserError
|
||||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
|
|
@ -28,11 +27,10 @@ from dify_graph.enums import (
|
||||||
WorkflowNodeExecutionMetadataKey,
|
WorkflowNodeExecutionMetadataKey,
|
||||||
WorkflowNodeExecutionStatus,
|
WorkflowNodeExecutionStatus,
|
||||||
)
|
)
|
||||||
from dify_graph.file import File, FileTransferMethod, FileType, file_manager
|
from dify_graph.file import File, FileTransferMethod, FileType
|
||||||
from dify_graph.model_runtime.entities import (
|
from dify_graph.model_runtime.entities import (
|
||||||
ImagePromptMessageContent,
|
ImagePromptMessageContent,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
PromptMessageContentType,
|
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
)
|
)
|
||||||
from dify_graph.model_runtime.entities.llm_entities import (
|
from dify_graph.model_runtime.entities.llm_entities import (
|
||||||
|
|
@ -43,14 +41,7 @@ from dify_graph.model_runtime.entities.llm_entities import (
|
||||||
LLMStructuredOutput,
|
LLMStructuredOutput,
|
||||||
LLMUsage,
|
LLMUsage,
|
||||||
)
|
)
|
||||||
from dify_graph.model_runtime.entities.message_entities import (
|
from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
||||||
AssistantPromptMessage,
|
|
||||||
PromptMessageContentUnionTypes,
|
|
||||||
PromptMessageRole,
|
|
||||||
SystemPromptMessage,
|
|
||||||
UserPromptMessage,
|
|
||||||
)
|
|
||||||
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
|
||||||
from dify_graph.model_runtime.memory import PromptMessageMemory
|
from dify_graph.model_runtime.memory import PromptMessageMemory
|
||||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from dify_graph.node_events import (
|
from dify_graph.node_events import (
|
||||||
|
|
@ -64,13 +55,12 @@ from dify_graph.node_events import (
|
||||||
from dify_graph.nodes.base.entities import VariableSelector
|
from dify_graph.nodes.base.entities import VariableSelector
|
||||||
from dify_graph.nodes.base.node import Node
|
from dify_graph.nodes.base.node import Node
|
||||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
|
||||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||||
from dify_graph.runtime import VariablePool
|
from dify_graph.runtime import VariablePool
|
||||||
from dify_graph.variables import (
|
from dify_graph.variables import (
|
||||||
ArrayFileSegment,
|
ArrayFileSegment,
|
||||||
ArraySegment,
|
ArraySegment,
|
||||||
FileSegment,
|
|
||||||
NoneSegment,
|
NoneSegment,
|
||||||
ObjectSegment,
|
ObjectSegment,
|
||||||
StringSegment,
|
StringSegment,
|
||||||
|
|
@ -89,9 +79,6 @@ from .exc import (
|
||||||
InvalidContextStructureError,
|
InvalidContextStructureError,
|
||||||
InvalidVariableTypeError,
|
InvalidVariableTypeError,
|
||||||
LLMNodeError,
|
LLMNodeError,
|
||||||
MemoryRolePrefixRequiredError,
|
|
||||||
NoPromptFoundError,
|
|
||||||
TemplateTypeNotSupportError,
|
|
||||||
VariableNotFoundError,
|
VariableNotFoundError,
|
||||||
)
|
)
|
||||||
from .file_saver import FileSaverImpl, LLMFileSaver
|
from .file_saver import FileSaverImpl, LLMFileSaver
|
||||||
|
|
@ -118,6 +105,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
_model_factory: ModelFactory
|
_model_factory: ModelFactory
|
||||||
_model_instance: ModelInstance
|
_model_instance: ModelInstance
|
||||||
_memory: PromptMessageMemory | None
|
_memory: PromptMessageMemory | None
|
||||||
|
_template_renderer: TemplateRenderer
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -130,6 +118,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
model_factory: ModelFactory,
|
model_factory: ModelFactory,
|
||||||
model_instance: ModelInstance,
|
model_instance: ModelInstance,
|
||||||
http_client: HttpClientProtocol,
|
http_client: HttpClientProtocol,
|
||||||
|
template_renderer: TemplateRenderer,
|
||||||
memory: PromptMessageMemory | None = None,
|
memory: PromptMessageMemory | None = None,
|
||||||
llm_file_saver: LLMFileSaver | None = None,
|
llm_file_saver: LLMFileSaver | None = None,
|
||||||
):
|
):
|
||||||
|
|
@ -146,6 +135,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
self._model_factory = model_factory
|
self._model_factory = model_factory
|
||||||
self._model_instance = model_instance
|
self._model_instance = model_instance
|
||||||
self._memory = memory
|
self._memory = memory
|
||||||
|
self._template_renderer = template_renderer
|
||||||
|
|
||||||
if llm_file_saver is None:
|
if llm_file_saver is None:
|
||||||
dify_ctx = self.require_dify_context()
|
dify_ctx = self.require_dify_context()
|
||||||
|
|
@ -240,6 +230,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
variable_pool=variable_pool,
|
variable_pool=variable_pool,
|
||||||
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||||
context_files=context_files,
|
context_files=context_files,
|
||||||
|
template_renderer=self._template_renderer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# handle invoke result
|
# handle invoke result
|
||||||
|
|
@ -773,182 +764,24 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
jinja2_variables: Sequence[VariableSelector],
|
jinja2_variables: Sequence[VariableSelector],
|
||||||
context_files: list[File] | None = None,
|
context_files: list[File] | None = None,
|
||||||
|
template_renderer: TemplateRenderer | None = None,
|
||||||
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
|
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
|
||||||
prompt_messages: list[PromptMessage] = []
|
return llm_utils.fetch_prompt_messages(
|
||||||
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
|
sys_query=sys_query,
|
||||||
|
sys_files=sys_files,
|
||||||
if isinstance(prompt_template, list):
|
context=context,
|
||||||
# For chat model
|
memory=memory,
|
||||||
prompt_messages.extend(
|
model_instance=model_instance,
|
||||||
LLMNode.handle_list_messages(
|
prompt_template=prompt_template,
|
||||||
messages=prompt_template,
|
stop=stop,
|
||||||
context=context,
|
memory_config=memory_config,
|
||||||
jinja2_variables=jinja2_variables,
|
vision_enabled=vision_enabled,
|
||||||
variable_pool=variable_pool,
|
vision_detail=vision_detail,
|
||||||
vision_detail_config=vision_detail,
|
variable_pool=variable_pool,
|
||||||
)
|
jinja2_variables=jinja2_variables,
|
||||||
)
|
context_files=context_files,
|
||||||
|
template_renderer=template_renderer,
|
||||||
# Get memory messages for chat mode
|
)
|
||||||
memory_messages = _handle_memory_chat_mode(
|
|
||||||
memory=memory,
|
|
||||||
memory_config=memory_config,
|
|
||||||
model_instance=model_instance,
|
|
||||||
)
|
|
||||||
# Extend prompt_messages with memory messages
|
|
||||||
prompt_messages.extend(memory_messages)
|
|
||||||
|
|
||||||
# Add current query to the prompt messages
|
|
||||||
if sys_query:
|
|
||||||
message = LLMNodeChatModelMessage(
|
|
||||||
text=sys_query,
|
|
||||||
role=PromptMessageRole.USER,
|
|
||||||
edition_type="basic",
|
|
||||||
)
|
|
||||||
prompt_messages.extend(
|
|
||||||
LLMNode.handle_list_messages(
|
|
||||||
messages=[message],
|
|
||||||
context="",
|
|
||||||
jinja2_variables=[],
|
|
||||||
variable_pool=variable_pool,
|
|
||||||
vision_detail_config=vision_detail,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
|
||||||
# For completion model
|
|
||||||
prompt_messages.extend(
|
|
||||||
_handle_completion_template(
|
|
||||||
template=prompt_template,
|
|
||||||
context=context,
|
|
||||||
jinja2_variables=jinja2_variables,
|
|
||||||
variable_pool=variable_pool,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get memory text for completion model
|
|
||||||
memory_text = _handle_memory_completion_mode(
|
|
||||||
memory=memory,
|
|
||||||
memory_config=memory_config,
|
|
||||||
model_instance=model_instance,
|
|
||||||
)
|
|
||||||
# Insert histories into the prompt
|
|
||||||
prompt_content = prompt_messages[0].content
|
|
||||||
# For issue #11247 - Check if prompt content is a string or a list
|
|
||||||
if isinstance(prompt_content, str):
|
|
||||||
prompt_content = str(prompt_content)
|
|
||||||
if "#histories#" in prompt_content:
|
|
||||||
prompt_content = prompt_content.replace("#histories#", memory_text)
|
|
||||||
else:
|
|
||||||
prompt_content = memory_text + "\n" + prompt_content
|
|
||||||
prompt_messages[0].content = prompt_content
|
|
||||||
elif isinstance(prompt_content, list):
|
|
||||||
for content_item in prompt_content:
|
|
||||||
if isinstance(content_item, TextPromptMessageContent):
|
|
||||||
if "#histories#" in content_item.data:
|
|
||||||
content_item.data = content_item.data.replace("#histories#", memory_text)
|
|
||||||
else:
|
|
||||||
content_item.data = memory_text + "\n" + content_item.data
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid prompt content type")
|
|
||||||
|
|
||||||
# Add current query to the prompt message
|
|
||||||
if sys_query:
|
|
||||||
if isinstance(prompt_content, str):
|
|
||||||
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
|
|
||||||
prompt_messages[0].content = prompt_content
|
|
||||||
elif isinstance(prompt_content, list):
|
|
||||||
for content_item in prompt_content:
|
|
||||||
if isinstance(content_item, TextPromptMessageContent):
|
|
||||||
content_item.data = sys_query + "\n" + content_item.data
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid prompt content type")
|
|
||||||
else:
|
|
||||||
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
|
|
||||||
|
|
||||||
# The sys_files will be deprecated later
|
|
||||||
if vision_enabled and sys_files:
|
|
||||||
file_prompts = []
|
|
||||||
for file in sys_files:
|
|
||||||
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
|
|
||||||
file_prompts.append(file_prompt)
|
|
||||||
# If last prompt is a user prompt, add files into its contents,
|
|
||||||
# otherwise append a new user prompt
|
|
||||||
if (
|
|
||||||
len(prompt_messages) > 0
|
|
||||||
and isinstance(prompt_messages[-1], UserPromptMessage)
|
|
||||||
and isinstance(prompt_messages[-1].content, list)
|
|
||||||
):
|
|
||||||
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
|
|
||||||
else:
|
|
||||||
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
|
||||||
|
|
||||||
# The context_files
|
|
||||||
if vision_enabled and context_files:
|
|
||||||
file_prompts = []
|
|
||||||
for file in context_files:
|
|
||||||
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
|
|
||||||
file_prompts.append(file_prompt)
|
|
||||||
# If last prompt is a user prompt, add files into its contents,
|
|
||||||
# otherwise append a new user prompt
|
|
||||||
if (
|
|
||||||
len(prompt_messages) > 0
|
|
||||||
and isinstance(prompt_messages[-1], UserPromptMessage)
|
|
||||||
and isinstance(prompt_messages[-1].content, list)
|
|
||||||
):
|
|
||||||
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
|
|
||||||
else:
|
|
||||||
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
|
||||||
|
|
||||||
# Remove empty messages and filter unsupported content
|
|
||||||
filtered_prompt_messages = []
|
|
||||||
for prompt_message in prompt_messages:
|
|
||||||
if isinstance(prompt_message.content, list):
|
|
||||||
prompt_message_content: list[PromptMessageContentUnionTypes] = []
|
|
||||||
for content_item in prompt_message.content:
|
|
||||||
# Skip content if features are not defined
|
|
||||||
if not model_schema.features:
|
|
||||||
if content_item.type != PromptMessageContentType.TEXT:
|
|
||||||
continue
|
|
||||||
prompt_message_content.append(content_item)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Skip content if corresponding feature is not supported
|
|
||||||
if (
|
|
||||||
(
|
|
||||||
content_item.type == PromptMessageContentType.IMAGE
|
|
||||||
and ModelFeature.VISION not in model_schema.features
|
|
||||||
)
|
|
||||||
or (
|
|
||||||
content_item.type == PromptMessageContentType.DOCUMENT
|
|
||||||
and ModelFeature.DOCUMENT not in model_schema.features
|
|
||||||
)
|
|
||||||
or (
|
|
||||||
content_item.type == PromptMessageContentType.VIDEO
|
|
||||||
and ModelFeature.VIDEO not in model_schema.features
|
|
||||||
)
|
|
||||||
or (
|
|
||||||
content_item.type == PromptMessageContentType.AUDIO
|
|
||||||
and ModelFeature.AUDIO not in model_schema.features
|
|
||||||
)
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
prompt_message_content.append(content_item)
|
|
||||||
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
|
|
||||||
prompt_message.content = prompt_message_content[0].data
|
|
||||||
else:
|
|
||||||
prompt_message.content = prompt_message_content
|
|
||||||
if prompt_message.is_empty():
|
|
||||||
continue
|
|
||||||
filtered_prompt_messages.append(prompt_message)
|
|
||||||
|
|
||||||
if len(filtered_prompt_messages) == 0:
|
|
||||||
raise NoPromptFoundError(
|
|
||||||
"No prompt found in the LLM configuration. "
|
|
||||||
"Please ensure a prompt is properly configured before proceeding."
|
|
||||||
)
|
|
||||||
|
|
||||||
return filtered_prompt_messages, stop
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_variable_selector_to_variable_mapping(
|
def _extract_variable_selector_to_variable_mapping(
|
||||||
|
|
@ -1048,59 +881,16 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
jinja2_variables: Sequence[VariableSelector],
|
jinja2_variables: Sequence[VariableSelector],
|
||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
vision_detail_config: ImagePromptMessageContent.DETAIL,
|
vision_detail_config: ImagePromptMessageContent.DETAIL,
|
||||||
|
template_renderer: TemplateRenderer | None = None,
|
||||||
) -> Sequence[PromptMessage]:
|
) -> Sequence[PromptMessage]:
|
||||||
prompt_messages: list[PromptMessage] = []
|
return llm_utils.handle_list_messages(
|
||||||
for message in messages:
|
messages=messages,
|
||||||
if message.edition_type == "jinja2":
|
context=context,
|
||||||
result_text = _render_jinja2_message(
|
jinja2_variables=jinja2_variables,
|
||||||
template=message.jinja2_text or "",
|
variable_pool=variable_pool,
|
||||||
jinja2_variables=jinja2_variables,
|
vision_detail_config=vision_detail_config,
|
||||||
variable_pool=variable_pool,
|
template_renderer=template_renderer,
|
||||||
)
|
)
|
||||||
prompt_message = _combine_message_content_with_role(
|
|
||||||
contents=[TextPromptMessageContent(data=result_text)], role=message.role
|
|
||||||
)
|
|
||||||
prompt_messages.append(prompt_message)
|
|
||||||
else:
|
|
||||||
# Get segment group from basic message
|
|
||||||
if context:
|
|
||||||
template = message.text.replace("{#context#}", context)
|
|
||||||
else:
|
|
||||||
template = message.text
|
|
||||||
segment_group = variable_pool.convert_template(template)
|
|
||||||
|
|
||||||
# Process segments for images
|
|
||||||
file_contents = []
|
|
||||||
for segment in segment_group.value:
|
|
||||||
if isinstance(segment, ArrayFileSegment):
|
|
||||||
for file in segment.value:
|
|
||||||
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
|
|
||||||
file_content = file_manager.to_prompt_message_content(
|
|
||||||
file, image_detail_config=vision_detail_config
|
|
||||||
)
|
|
||||||
file_contents.append(file_content)
|
|
||||||
elif isinstance(segment, FileSegment):
|
|
||||||
file = segment.value
|
|
||||||
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
|
|
||||||
file_content = file_manager.to_prompt_message_content(
|
|
||||||
file, image_detail_config=vision_detail_config
|
|
||||||
)
|
|
||||||
file_contents.append(file_content)
|
|
||||||
|
|
||||||
# Create message with text from all segments
|
|
||||||
plain_text = segment_group.text
|
|
||||||
if plain_text:
|
|
||||||
prompt_message = _combine_message_content_with_role(
|
|
||||||
contents=[TextPromptMessageContent(data=plain_text)], role=message.role
|
|
||||||
)
|
|
||||||
prompt_messages.append(prompt_message)
|
|
||||||
|
|
||||||
if file_contents:
|
|
||||||
# Create message with image contents
|
|
||||||
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
|
|
||||||
prompt_messages.append(prompt_message)
|
|
||||||
|
|
||||||
return prompt_messages
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def handle_blocking_result(
|
def handle_blocking_result(
|
||||||
|
|
@ -1239,152 +1029,3 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
@property
|
@property
|
||||||
def model_instance(self) -> ModelInstance:
|
def model_instance(self) -> ModelInstance:
|
||||||
return self._model_instance
|
return self._model_instance
|
||||||
|
|
||||||
|
|
||||||
def _combine_message_content_with_role(
|
|
||||||
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
|
|
||||||
):
|
|
||||||
match role:
|
|
||||||
case PromptMessageRole.USER:
|
|
||||||
return UserPromptMessage(content=contents)
|
|
||||||
case PromptMessageRole.ASSISTANT:
|
|
||||||
return AssistantPromptMessage(content=contents)
|
|
||||||
case PromptMessageRole.SYSTEM:
|
|
||||||
return SystemPromptMessage(content=contents)
|
|
||||||
case _:
|
|
||||||
raise NotImplementedError(f"Role {role} is not supported")
|
|
||||||
|
|
||||||
|
|
||||||
def _render_jinja2_message(
|
|
||||||
*,
|
|
||||||
template: str,
|
|
||||||
jinja2_variables: Sequence[VariableSelector],
|
|
||||||
variable_pool: VariablePool,
|
|
||||||
):
|
|
||||||
if not template:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
jinja2_inputs = {}
|
|
||||||
for jinja2_variable in jinja2_variables:
|
|
||||||
variable = variable_pool.get(jinja2_variable.value_selector)
|
|
||||||
jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
|
|
||||||
code_execute_resp = CodeExecutor.execute_workflow_code_template(
|
|
||||||
language=CodeLanguage.JINJA2,
|
|
||||||
code=template,
|
|
||||||
inputs=jinja2_inputs,
|
|
||||||
)
|
|
||||||
result_text = code_execute_resp["result"]
|
|
||||||
return result_text
|
|
||||||
|
|
||||||
|
|
||||||
def _calculate_rest_token(
|
|
||||||
*,
|
|
||||||
prompt_messages: list[PromptMessage],
|
|
||||||
model_instance: ModelInstance,
|
|
||||||
) -> int:
|
|
||||||
rest_tokens = 2000
|
|
||||||
runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
|
|
||||||
runtime_model_parameters = model_instance.parameters
|
|
||||||
|
|
||||||
model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
|
||||||
if model_context_tokens:
|
|
||||||
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
|
||||||
|
|
||||||
max_tokens = 0
|
|
||||||
for parameter_rule in runtime_model_schema.parameter_rules:
|
|
||||||
if parameter_rule.name == "max_tokens" or (
|
|
||||||
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
|
|
||||||
):
|
|
||||||
max_tokens = (
|
|
||||||
runtime_model_parameters.get(parameter_rule.name)
|
|
||||||
or runtime_model_parameters.get(str(parameter_rule.use_template))
|
|
||||||
or 0
|
|
||||||
)
|
|
||||||
|
|
||||||
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
|
|
||||||
rest_tokens = max(rest_tokens, 0)
|
|
||||||
|
|
||||||
return rest_tokens
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_memory_chat_mode(
|
|
||||||
*,
|
|
||||||
memory: PromptMessageMemory | None,
|
|
||||||
memory_config: MemoryConfig | None,
|
|
||||||
model_instance: ModelInstance,
|
|
||||||
) -> Sequence[PromptMessage]:
|
|
||||||
memory_messages: Sequence[PromptMessage] = []
|
|
||||||
# Get messages from memory for chat model
|
|
||||||
if memory and memory_config:
|
|
||||||
rest_tokens = _calculate_rest_token(
|
|
||||||
prompt_messages=[],
|
|
||||||
model_instance=model_instance,
|
|
||||||
)
|
|
||||||
memory_messages = memory.get_history_prompt_messages(
|
|
||||||
max_token_limit=rest_tokens,
|
|
||||||
message_limit=memory_config.window.size if memory_config.window.enabled else None,
|
|
||||||
)
|
|
||||||
return memory_messages
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_memory_completion_mode(
|
|
||||||
*,
|
|
||||||
memory: PromptMessageMemory | None,
|
|
||||||
memory_config: MemoryConfig | None,
|
|
||||||
model_instance: ModelInstance,
|
|
||||||
) -> str:
|
|
||||||
memory_text = ""
|
|
||||||
# Get history text from memory for completion model
|
|
||||||
if memory and memory_config:
|
|
||||||
rest_tokens = _calculate_rest_token(
|
|
||||||
prompt_messages=[],
|
|
||||||
model_instance=model_instance,
|
|
||||||
)
|
|
||||||
if not memory_config.role_prefix:
|
|
||||||
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
|
|
||||||
memory_text = llm_utils.fetch_memory_text(
|
|
||||||
memory=memory,
|
|
||||||
max_token_limit=rest_tokens,
|
|
||||||
message_limit=memory_config.window.size if memory_config.window.enabled else None,
|
|
||||||
human_prefix=memory_config.role_prefix.user,
|
|
||||||
ai_prefix=memory_config.role_prefix.assistant,
|
|
||||||
)
|
|
||||||
return memory_text
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_completion_template(
|
|
||||||
*,
|
|
||||||
template: LLMNodeCompletionModelPromptTemplate,
|
|
||||||
context: str | None,
|
|
||||||
jinja2_variables: Sequence[VariableSelector],
|
|
||||||
variable_pool: VariablePool,
|
|
||||||
) -> Sequence[PromptMessage]:
|
|
||||||
"""Handle completion template processing outside of LLMNode class.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
template: The completion model prompt template
|
|
||||||
context: Optional context string
|
|
||||||
jinja2_variables: Variables for jinja2 template rendering
|
|
||||||
variable_pool: Variable pool for template conversion
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Sequence of prompt messages
|
|
||||||
"""
|
|
||||||
prompt_messages = []
|
|
||||||
if template.edition_type == "jinja2":
|
|
||||||
result_text = _render_jinja2_message(
|
|
||||||
template=template.jinja2_text or "",
|
|
||||||
jinja2_variables=jinja2_variables,
|
|
||||||
variable_pool=variable_pool,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if context:
|
|
||||||
template_text = template.text.replace("{#context#}", context)
|
|
||||||
else:
|
|
||||||
template_text = template.text
|
|
||||||
result_text = variable_pool.convert_template(template_text).text
|
|
||||||
prompt_message = _combine_message_content_with_role(
|
|
||||||
contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER
|
|
||||||
)
|
|
||||||
prompt_messages.append(prompt_message)
|
|
||||||
return prompt_messages
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Mapping
|
||||||
from typing import Any, Protocol
|
from typing import Any, Protocol
|
||||||
|
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
|
|
@ -19,3 +20,11 @@ class ModelFactory(Protocol):
|
||||||
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
|
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
|
||||||
"""Create a model instance that is ready for schema lookup and invocation."""
|
"""Create a model instance that is ready for schema lookup and invocation."""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class TemplateRenderer(Protocol):
|
||||||
|
"""Port for rendering prompt templates used by LLM-compatible nodes."""
|
||||||
|
|
||||||
|
def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str:
|
||||||
|
"""Render the given Jinja2 template into plain text."""
|
||||||
|
...
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ from dify_graph.nodes.llm import (
|
||||||
llm_utils,
|
llm_utils,
|
||||||
)
|
)
|
||||||
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
||||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
|
||||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||||
|
|
||||||
|
|
@ -59,6 +59,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||||
_model_factory: "ModelFactory"
|
_model_factory: "ModelFactory"
|
||||||
_model_instance: ModelInstance
|
_model_instance: ModelInstance
|
||||||
_memory: PromptMessageMemory | None
|
_memory: PromptMessageMemory | None
|
||||||
|
_template_renderer: TemplateRenderer
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -71,6 +72,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||||
model_factory: "ModelFactory",
|
model_factory: "ModelFactory",
|
||||||
model_instance: ModelInstance,
|
model_instance: ModelInstance,
|
||||||
http_client: HttpClientProtocol,
|
http_client: HttpClientProtocol,
|
||||||
|
template_renderer: TemplateRenderer,
|
||||||
memory: PromptMessageMemory | None = None,
|
memory: PromptMessageMemory | None = None,
|
||||||
llm_file_saver: LLMFileSaver | None = None,
|
llm_file_saver: LLMFileSaver | None = None,
|
||||||
):
|
):
|
||||||
|
|
@ -87,6 +89,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||||
self._model_factory = model_factory
|
self._model_factory = model_factory
|
||||||
self._model_instance = model_instance
|
self._model_instance = model_instance
|
||||||
self._memory = memory
|
self._memory = memory
|
||||||
|
self._template_renderer = template_renderer
|
||||||
|
|
||||||
if llm_file_saver is None:
|
if llm_file_saver is None:
|
||||||
dify_ctx = self.require_dify_context()
|
dify_ctx = self.require_dify_context()
|
||||||
|
|
@ -142,7 +145,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||||
# If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
|
# If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
|
||||||
# two consecutive user prompts will be generated, causing model's error.
|
# two consecutive user prompts will be generated, causing model's error.
|
||||||
# To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
|
# To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
|
||||||
prompt_messages, stop = LLMNode.fetch_prompt_messages(
|
prompt_messages, stop = llm_utils.fetch_prompt_messages(
|
||||||
prompt_template=prompt_template,
|
prompt_template=prompt_template,
|
||||||
sys_query="",
|
sys_query="",
|
||||||
memory=memory,
|
memory=memory,
|
||||||
|
|
@ -153,6 +156,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||||
vision_detail=node_data.vision.configs.detail,
|
vision_detail=node_data.vision.configs.detail,
|
||||||
variable_pool=variable_pool,
|
variable_pool=variable_pool,
|
||||||
jinja2_variables=[],
|
jinja2_variables=[],
|
||||||
|
template_renderer=self._template_renderer,
|
||||||
)
|
)
|
||||||
|
|
||||||
result_text = ""
|
result_text = ""
|
||||||
|
|
@ -287,7 +291,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||||
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
|
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
|
||||||
|
|
||||||
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
|
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
|
||||||
prompt_messages, _ = LLMNode.fetch_prompt_messages(
|
prompt_messages, _ = llm_utils.fetch_prompt_messages(
|
||||||
prompt_template=prompt_template,
|
prompt_template=prompt_template,
|
||||||
sys_query="",
|
sys_query="",
|
||||||
sys_files=[],
|
sys_files=[],
|
||||||
|
|
@ -300,6 +304,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||||
vision_detail=node_data.vision.configs.detail,
|
vision_detail=node_data.vision.configs.detail,
|
||||||
variable_pool=self.graph_runtime_state.variable_pool,
|
variable_pool=self.graph_runtime_state.variable_pool,
|
||||||
jinja2_variables=[],
|
jinja2_variables=[],
|
||||||
|
template_renderer=self._template_renderer,
|
||||||
)
|
)
|
||||||
rest_tokens = 2000
|
rest_tokens = 2000
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from core.model_manager import ModelInstance
|
||||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||||
from dify_graph.node_events import StreamCompletedEvent
|
from dify_graph.node_events import StreamCompletedEvent
|
||||||
from dify_graph.nodes.llm.node import LLMNode
|
from dify_graph.nodes.llm.node import LLMNode
|
||||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
|
||||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||||
from dify_graph.system_variable import SystemVariable
|
from dify_graph.system_variable import SystemVariable
|
||||||
|
|
@ -75,6 +75,7 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||||
model_factory=MagicMock(spec=ModelFactory),
|
model_factory=MagicMock(spec=ModelFactory),
|
||||||
model_instance=MagicMock(spec=ModelInstance),
|
model_instance=MagicMock(spec=ModelInstance),
|
||||||
|
template_renderer=MagicMock(spec=TemplateRenderer),
|
||||||
http_client=MagicMock(spec=HttpClientProtocol),
|
http_client=MagicMock(spec=HttpClientProtocol),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -158,7 +159,7 @@ def test_execute_llm():
|
||||||
return mock_model_instance
|
return mock_model_instance
|
||||||
|
|
||||||
# Mock fetch_prompt_messages to avoid database calls
|
# Mock fetch_prompt_messages to avoid database calls
|
||||||
def mock_fetch_prompt_messages_1(**_kwargs):
|
def mock_fetch_prompt_messages_1(*_args, **_kwargs):
|
||||||
from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
|
from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
|
||||||
|
|
||||||
return [
|
return [
|
||||||
|
|
|
||||||
|
|
@@ -20,7 +20,7 @@ from dify_graph.nodes.code import CodeNode
 from dify_graph.nodes.document_extractor import DocumentExtractorNode
 from dify_graph.nodes.http_request import HttpRequestNode
 from dify_graph.nodes.llm import LLMNode
-from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
 from dify_graph.nodes.parameter_extractor import ParameterExtractorNode
 from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol
 from dify_graph.nodes.question_classifier import QuestionClassifierNode

@@ -68,6 +68,8 @@ class MockNodeMixin:
         kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))
         # LLM-like nodes now require an http_client; provide a mock by default for tests.
         kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol))
+        if isinstance(self, (LLMNode, QuestionClassifierNode)):
+            kwargs.setdefault("template_renderer", MagicMock(spec=TemplateRenderer))

         # Ensure TemplateTransformNode receives a renderer now required by constructor
         if isinstance(self, TemplateTransformNode):

@@ -34,8 +34,8 @@ from dify_graph.nodes.llm.entities import (
     VisionConfigOptions,
 )
 from dify_graph.nodes.llm.file_saver import LLMFileSaver
-from dify_graph.nodes.llm.node import LLMNode, _handle_memory_completion_mode
-from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from dify_graph.nodes.llm.node import LLMNode
+from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
 from dify_graph.runtime import GraphRuntimeState, VariablePool
 from dify_graph.system_variable import SystemVariable
 from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment

@@ -107,6 +107,7 @@ def llm_node(
     mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
     mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
     mock_model_factory = mock.MagicMock(spec=ModelFactory)
+    mock_template_renderer = mock.MagicMock(spec=TemplateRenderer)
     node_config = {
         "id": "1",
         "data": llm_node_data.model_dump(),

@@ -121,6 +122,7 @@ def llm_node(
         model_factory=mock_model_factory,
         model_instance=mock.MagicMock(spec=ModelInstance),
         llm_file_saver=mock_file_saver,
+        template_renderer=mock_template_renderer,
         http_client=http_client,
     )
     return node
@@ -590,6 +592,33 @@ def test_handle_list_messages_basic(llm_node):
     assert result[0].content == [TextPromptMessageContent(data="Hello, world")]


+def test_handle_list_messages_jinja2_uses_template_renderer(llm_node):
+    llm_node._template_renderer.render_jinja2.return_value = "Hello, world"
+    messages = [
+        LLMNodeChatModelMessage(
+            text="",
+            jinja2_text="Hello, {{ name }}",
+            role=PromptMessageRole.USER,
+            edition_type="jinja2",
+        )
+    ]
+
+    result = llm_node.handle_list_messages(
+        messages=messages,
+        context=None,
+        jinja2_variables=[],
+        variable_pool=llm_node.graph_runtime_state.variable_pool,
+        vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
+        template_renderer=llm_node._template_renderer,
+    )
+
+    assert result == [UserPromptMessage(content=[TextPromptMessageContent(data="Hello, world")])]
+    llm_node._template_renderer.render_jinja2.assert_called_once_with(
+        template="Hello, {{ name }}",
+        inputs={},
+    )
+
+
 def test_handle_memory_completion_mode_uses_prompt_message_interface():
     memory = mock.MagicMock(spec=MockTokenBufferMemory)
     memory.get_history_prompt_messages.return_value = [
@@ -613,8 +642,8 @@ def test_handle_memory_completion_mode_uses_prompt_message_interface():
         window=MemoryConfig.WindowConfig(enabled=True, size=3),
     )

-    with mock.patch("dify_graph.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token:
-        memory_text = _handle_memory_completion_mode(
+    with mock.patch("dify_graph.nodes.llm.llm_utils.calculate_rest_token", return_value=2000) as mock_rest_token:
+        memory_text = llm_utils.handle_memory_completion_mode(
             memory=memory,
             memory_config=memory_config,
             model_instance=model_instance,

@@ -630,6 +659,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
     mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
     mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
     mock_model_factory = mock.MagicMock(spec=ModelFactory)
+    mock_template_renderer = mock.MagicMock(spec=TemplateRenderer)
     node_config = {
         "id": "1",
         "data": llm_node_data.model_dump(),

@@ -644,6 +674,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
         model_factory=mock_model_factory,
         model_instance=mock.MagicMock(spec=ModelInstance),
         llm_file_saver=mock_file_saver,
+        template_renderer=mock_template_renderer,
         http_client=http_client,
     )
     return node, mock_file_saver
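In all of the fixtures and tests above, the node reaches the renderer only through that keyword-only render_jinja2 call, so any object honoring the same shape can stand in for the real implementation. A hand-rolled stub such as the one below, a hypothetical helper that is not part of this diff, could replace MagicMock(spec=TemplateRenderer) when a test prefers to inspect calls without unittest.mock:

    from typing import Any, Mapping


    class EchoTemplateRenderer:
        # Hypothetical test double: records every call and returns the template unrendered.
        def __init__(self) -> None:
            self.calls: list[tuple[str, dict[str, Any]]] = []

        def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str:
            self.calls.append((template, dict(inputs)))
            return template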
@@ -1,5 +1,14 @@
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
 from dify_graph.model_runtime.entities import ImagePromptMessageContent
-from dify_graph.nodes.question_classifier import QuestionClassifierNodeData
+from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
+from dify_graph.nodes.protocols import HttpClientProtocol
+from dify_graph.nodes.question_classifier import (
+    QuestionClassifierNode,
+    QuestionClassifierNodeData,
+)
+from tests.workflow_test_utils import build_test_graph_init_params


 def test_init_question_classifier_node_data():
@@ -65,3 +74,52 @@ def test_init_question_classifier_node_data_without_vision_config():
     assert node_data.vision.enabled == False
     assert node_data.vision.configs.variable_selector == ["sys", "files"]
     assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH
+
+
+def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(monkeypatch):
+    node_data = QuestionClassifierNodeData.model_validate(
+        {
+            "title": "test classifier node",
+            "query_variable_selector": ["id", "name"],
+            "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
+            "classes": [{"id": "1", "name": "class 1"}],
+            "instruction": "This is a test instruction",
+        }
+    )
+    template_renderer = MagicMock(spec=TemplateRenderer)
+    node = QuestionClassifierNode(
+        id="node-id",
+        config={"id": "node-id", "data": node_data.model_dump(mode="json")},
+        graph_init_params=build_test_graph_init_params(
+            workflow_id="workflow-id",
+            graph_config={},
+            tenant_id="tenant-id",
+            app_id="app-id",
+            user_id="user-id",
+        ),
+        graph_runtime_state=SimpleNamespace(variable_pool=MagicMock()),
+        credentials_provider=MagicMock(spec=CredentialsProvider),
+        model_factory=MagicMock(spec=ModelFactory),
+        model_instance=MagicMock(),
+        http_client=MagicMock(spec=HttpClientProtocol),
+        llm_file_saver=MagicMock(),
+        template_renderer=template_renderer,
+    )
+    fetch_prompt_messages = MagicMock(return_value=([], None))
+    monkeypatch.setattr(
+        "dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_prompt_messages",
+        fetch_prompt_messages,
+    )
+    monkeypatch.setattr(
+        "dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_model_schema",
+        MagicMock(return_value=SimpleNamespace(model_properties={}, parameter_rules=[])),
+    )
+
+    node._calculate_rest_token(
+        node_data=node_data,
+        query="hello",
+        model_instance=MagicMock(stop=(), parameters={}),
+        context="",
+    )
+
+    assert fetch_prompt_messages.call_args.kwargs["template_renderer"] is template_renderer
@@ -140,6 +140,29 @@ class TestDefaultWorkflowCodeExecutor:
         assert executor.is_execution_error(RuntimeError("boom")) is False


+class TestDefaultLLMTemplateRenderer:
+    def test_render_jinja2_delegates_to_code_executor(self, monkeypatch):
+        renderer = node_factory.DefaultLLMTemplateRenderer()
+        execute_workflow_code_template = MagicMock(return_value={"result": "hello world"})
+        monkeypatch.setattr(
+            node_factory.CodeExecutor,
+            "execute_workflow_code_template",
+            execute_workflow_code_template,
+        )
+
+        result = renderer.render_jinja2(
+            template="Hello {{ name }}",
+            inputs={"name": "world"},
+        )
+
+        assert result == "hello world"
+        execute_workflow_code_template.assert_called_once_with(
+            language=CodeLanguage.JINJA2,
+            code="Hello {{ name }}",
+            inputs={"name": "world"},
+        )
+
+
 class TestDifyNodeFactoryInit:
     def test_init_builds_default_dependencies(self):
         graph_init_params = SimpleNamespace(run_context={"context": "value"})
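The test above pins down the delegation contract: render_jinja2 forwards the template and inputs to CodeExecutor.execute_workflow_code_template with the JINJA2 language and returns the value stored under the "result" key. A minimal usage sketch, assuming the renderer is reachable through the same node_factory module used in the test and that a working sandbox is configured, would look like:

    # Sketch only; "node_factory" is the module imported by the tests above.
    renderer = node_factory.DefaultLLMTemplateRenderer()
    text = renderer.render_jinja2(
        template="Hello {{ name }}",
        inputs={"name": "world"},
    )
    # With a functioning sandbox this is expected to yield "Hello world".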
@@ -150,6 +173,7 @@ class TestDifyNodeFactoryInit:
         http_request_config = sentinel.http_request_config
         credentials_provider = sentinel.credentials_provider
         model_factory = sentinel.model_factory
+        llm_template_renderer = sentinel.llm_template_renderer

         with (
             patch.object(

@@ -172,6 +196,11 @@ class TestDifyNodeFactoryInit:
                 "build_http_request_config",
                 return_value=http_request_config,
             ),
+            patch.object(
+                node_factory,
+                "DefaultLLMTemplateRenderer",
+                return_value=llm_template_renderer,
+            ) as llm_renderer_factory,
             patch.object(
                 node_factory,
                 "build_dify_model_access",

@@ -186,11 +215,14 @@ class TestDifyNodeFactoryInit:
         resolve_dify_context.assert_called_once_with(graph_init_params.run_context)
         build_dify_model_access.assert_called_once_with("tenant-id")
         renderer_factory.assert_called_once()
+        llm_renderer_factory.assert_called_once()
         assert renderer_factory.call_args.kwargs["code_executor"] is factory._code_executor
         assert factory.graph_init_params is graph_init_params
         assert factory.graph_runtime_state is graph_runtime_state
         assert factory._dify_context is dify_context
         assert factory._template_renderer is template_renderer
+
+        assert factory._llm_template_renderer is llm_template_renderer
         assert factory._document_extractor_unstructured_api_config is unstructured_api_config
         assert factory._http_request_config is http_request_config
         assert factory._llm_credentials_provider is credentials_provider

@@ -242,6 +274,7 @@ class TestDifyNodeFactoryCreateNode:
         factory._code_executor = sentinel.code_executor
         factory._code_limits = sentinel.code_limits
         factory._template_renderer = sentinel.template_renderer
+        factory._llm_template_renderer = sentinel.llm_template_renderer
         factory._template_transform_max_output_length = 2048
         factory._http_request_http_client = sentinel.http_client
         factory._http_request_tool_file_manager_factory = sentinel.tool_file_manager_factory

@@ -378,8 +411,22 @@
     @pytest.mark.parametrize(
         ("node_type", "constructor_name", "expected_extra_kwargs"),
         [
-            (BuiltinNodeTypes.LLM, "LLMNode", {"http_client": sentinel.http_client}),
-            (BuiltinNodeTypes.QUESTION_CLASSIFIER, "QuestionClassifierNode", {"http_client": sentinel.http_client}),
+            (
+                BuiltinNodeTypes.LLM,
+                "LLMNode",
+                {
+                    "http_client": sentinel.http_client,
+                    "template_renderer": sentinel.llm_template_renderer,
+                },
+            ),
+            (
+                BuiltinNodeTypes.QUESTION_CLASSIFIER,
+                "QuestionClassifierNode",
+                {
+                    "http_client": sentinel.http_client,
+                    "template_renderer": sentinel.llm_template_renderer,
+                },
+            ),
             (BuiltinNodeTypes.PARAMETER_EXTRACTOR, "ParameterExtractorNode", {}),
         ],
     )