mirror of https://github.com/langgenius/dify.git
refactor: better `/context-generate` with frontend support
Signed-off-by: Stream <Stream_2@qq.com>
parent 71f811930f · commit a409e3d32e

Console app generator controller:
@@ -16,6 +16,11 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.context_models import (
    AvailableVarPayload,
    CodeContextPayload,
    ParameterInfoPayload,
)
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
@@ -58,22 +63,21 @@ class InstructionTemplatePayload(BaseModel):
class ContextGeneratePayload(BaseModel):
    """Payload for generating extractor code node."""

    workflow_id: str = Field(..., description="Workflow ID")
    node_id: str = Field(..., description="Current tool/llm node ID")
    parameter_name: str = Field(..., description="Parameter name to generate code for")
    language: str = Field(default="python3", description="Code language (python3/javascript)")
    prompt_messages: list[dict[str, Any]] = Field(
        ..., description="Multi-turn conversation history, last message is the current instruction"
    )
    model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
    available_vars: list[AvailableVarPayload] = Field(..., description="Available variables from upstream nodes")
    parameter_info: ParameterInfoPayload = Field(..., description="Target parameter metadata from the frontend")
    code_context: CodeContextPayload | None = Field(
        default=None, description="Existing code node context for incremental generation"
    )


class SuggestedQuestionsPayload(BaseModel):
    """Payload for generating suggested questions."""

    workflow_id: str = Field(..., description="Workflow ID")
    node_id: str = Field(..., description="Current tool/llm node ID")
    parameter_name: str = Field(..., description="Parameter name")
    language: str = Field(
        default="English", description="Language for generated questions (e.g. English, Chinese, Japanese)"
    )
@@ -82,6 +86,8 @@ class SuggestedQuestionsPayload(BaseModel):
        alias="model_config",
        description="Model configuration (optional, uses system default if not provided)",
    )
    available_vars: list[AvailableVarPayload] = Field(..., description="Available variables from upstream nodes")
    parameter_info: ParameterInfoPayload = Field(..., description="Target parameter metadata from the frontend")


def reg(cls: type[BaseModel]):
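For illustration, a request body that validates against the new ContextGeneratePayload could look like the sketch below (all IDs and values are made up). Note that the model configuration must be sent under the "model_config" alias because Pydantic v2 reserves the model_config attribute name on BaseModel, which is why the field is named model_config_data internally:

payload = {
    "workflow_id": "wf-123",                 # made-up workflow ID
    "node_id": "tool-1",                     # made-up node ID
    "parameter_name": "query",
    "language": "python3",
    "prompt_messages": [{"role": "user", "content": "Join the city names with commas"}],
    "model_config": {"provider": "openai", "name": "gpt-4o", "completion_params": {}},
    "available_vars": [{"value_selector": ["start-1", "cities"], "type": "array[string]"}],
    "parameter_info": {"name": "query", "type": "string", "description": "Search query"},
}

args = ContextGeneratePayload.model_validate(payload)
assert args.model_config_data == payload["model_config"]  # populated via the "model_config" alias
assert args.code_context is None                          # optional field defaults to None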
@@ -328,17 +334,15 @@ class ContextGenerateApi(Resource):
        args = ContextGeneratePayload.model_validate(console_ns.payload)
        _, current_tenant_id = current_account_with_tenant()

        prompt_messages = deserialize_prompt_messages(args.prompt_messages)

        try:
            return LLMGenerator.generate_with_context(
                tenant_id=current_tenant_id,
                workflow_id=args.workflow_id,
                node_id=args.node_id,
                parameter_name=args.parameter_name,
                language=args.language,
                prompt_messages=prompt_messages,
                prompt_messages=deserialize_prompt_messages(args.prompt_messages),
                model_config=args.model_config_data,
                available_vars=args.available_vars,
                parameter_info=args.parameter_info,
                code_context=args.code_context,
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
@@ -362,14 +366,12 @@ class SuggestedQuestionsApi(Resource):
    def post(self):
        args = SuggestedQuestionsPayload.model_validate(console_ns.payload)
        _, current_tenant_id = current_account_with_tenant()

        try:
            return LLMGenerator.generate_suggested_questions(
                tenant_id=current_tenant_id,
                workflow_id=args.workflow_id,
                node_id=args.node_id,
                parameter_name=args.parameter_name,
                language=args.language,
                available_vars=args.available_vars,
                parameter_info=args.parameter_info,
                model_config=args.model_config_data,
            )
        except ProviderTokenNotInitError as ex:

core/llm_generator/context_models.py (new file):
@@ -0,0 +1,62 @@
from typing import Any

from pydantic import BaseModel, ConfigDict, Field


class VariableSelectorPayload(BaseModel):
    model_config = ConfigDict(extra="forbid")

    variable: str = Field(..., description="Variable name used in generated code")
    value_selector: list[str] = Field(..., description="Path to upstream node output, format: [node_id, output_name]")


class CodeOutputPayload(BaseModel):
    model_config = ConfigDict(extra="forbid")

    type: str = Field(..., description="Output variable type")


class CodeContextPayload(BaseModel):
    # From web/app/components/workflow/nodes/tool/components/context-generate-modal/index.tsx (code node snapshot).
    model_config = ConfigDict(extra="forbid")

    code: str = Field(..., description="Existing code in the Code node")
    outputs: dict[str, CodeOutputPayload] | None = Field(
        default=None, description="Existing output definitions for the Code node"
    )
    variables: list[VariableSelectorPayload] | None = Field(
        default=None, description="Existing variable selectors used by the Code node"
    )


class AvailableVarPayload(BaseModel):
    # From web/app/components/workflow/nodes/_base/hooks/use-available-var-list.ts (available variables).
    model_config = ConfigDict(extra="forbid", populate_by_name=True)

    value_selector: list[str] = Field(..., description="Path to upstream node output")
    type: str = Field(..., description="Variable type, e.g. string, number, array[object]")
    description: str | None = Field(default=None, description="Optional variable description")
    node_id: str | None = Field(default=None, description="Source node ID")
    node_title: str | None = Field(default=None, description="Source node title")
    node_type: str | None = Field(default=None, description="Source node type")
    json_schema: dict[str, Any] | None = Field(
        default=None,
        alias="schema",
        description="Optional JSON schema for object variables",
    )


class ParameterInfoPayload(BaseModel):
    # From web/app/components/workflow/nodes/tool/use-config.ts (ToolParameter metadata).
    model_config = ConfigDict(extra="forbid")

    name: str = Field(..., description="Target parameter name")
    type: str = Field(default="string", description="Target parameter type")
    description: str = Field(default="", description="Parameter description")
    required: bool | None = Field(default=None, description="Whether the parameter is required")
    options: list[str] | None = Field(default=None, description="Allowed option values")
    min: float | None = Field(default=None, description="Minimum numeric value")
    max: float | None = Field(default=None, description="Maximum numeric value")
    default: str | int | float | bool | None = Field(default=None, description="Default value")
    multiple: bool | None = Field(default=None, description="Whether the parameter accepts multiple values")
    label: str | None = Field(default=None, description="Optional display label")
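A quick sketch of how the "schema" alias on AvailableVarPayload behaves (values invented for illustration): populate_by_name=True accepts either spelling on input, while model_dump(by_alias=True), which the formatter helpers later in this diff use, always emits the wire name:

var = AvailableVarPayload.model_validate(
    {"value_selector": ["llm-1", "text"], "type": "string", "schema": {"type": "string"}}
)
same = AvailableVarPayload(value_selector=["llm-1", "text"], type="string", json_schema={"type": "string"})
assert var == same                                           # both spellings produce the same model
assert var.model_dump(by_alias=True)["schema"] == {"type": "string"}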
core/llm_generator/llm_generator.py:
@@ -1,11 +1,16 @@
import json
import logging
import re
from collections.abc import Mapping, Sequence
from typing import Any, Protocol, cast
from collections.abc import Sequence
from typing import Protocol

import json_repair

from core.llm_generator.context_models import (
    AvailableVarPayload,
    CodeContextPayload,
    ParameterInfoPayload,
)
from core.llm_generator.output_models import (
    CodeNodeStructuredOutput,
    InstructionModifyOutput,
@@ -132,9 +137,6 @@ class LLMGenerator:
            return []

        prompt_messages = [UserPromptMessage(content=prompt)]

        questions: Sequence[str] = []

        try:
            response: LLMResult = model_instance.invoke_llm(
                prompt_messages=list(prompt_messages),
@@ -402,24 +404,24 @@ class LLMGenerator:
    def generate_with_context(
        cls,
        tenant_id: str,
        workflow_id: str,
        node_id: str,
        parameter_name: str,
        language: str,
        prompt_messages: list[PromptMessage],
        model_config: dict,
        available_vars: Sequence[AvailableVarPayload],
        parameter_info: ParameterInfoPayload,
        code_context: CodeContextPayload,
    ) -> dict:
        """
        Generate extractor code node based on conversation context.

        Args:
            tenant_id: Tenant/workspace ID
            workflow_id: Workflow ID
            node_id: Current tool/llm node ID
            parameter_name: Parameter name to generate code for
            language: Code language (python3/javascript)
            prompt_messages: Multi-turn conversation history (last message is instruction)
            model_config: Model configuration (provider, name, completion_params)
            available_vars: Client-provided available variables with types/schema
            parameter_info: Client-provided parameter metadata (type/constraints)
            code_context: Client-provided existing code node context

        Returns:
            dict with CodeNodeData format:
@@ -430,43 +432,12 @@ class LLMGenerator:
            - message: Explanation
            - error: Error message if any
        """
        from sqlalchemy import select
        from sqlalchemy.orm import Session

        from services.workflow_service import WorkflowService
        # available_vars/parameter_info/code_context are provided by the frontend context-generate modal.
        # See web/app/components/workflow/nodes/tool/components/context-generate-modal/hooks/use-context-generate.ts

        # Get workflow
        with Session(db.engine) as session:
            stmt = select(App).where(App.id == workflow_id)
            app = session.scalar(stmt)
            if not app:
                return cls._error_response(f"App {workflow_id} not found")

            workflow = WorkflowService().get_draft_workflow(app_model=app)
            if not workflow:
                return cls._error_response(f"Workflow for app {workflow_id} not found")

        # Get upstream nodes via edge backtracking
        upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id)

        # Get current node info
        current_node = cls._get_node_by_id(workflow.graph_dict, node_id)
        if not current_node:
            return cls._error_response(f"Node {node_id} not found")

        # Get parameter info
        parameter_info = cls._get_parameter_info(
            tenant_id=tenant_id,
            node_data=current_node.get("data", {}),
            parameter_name=parameter_name,
        )

        # Build system prompt
        system_prompt = cls._build_extractor_system_prompt(
            upstream_nodes=upstream_nodes,
            current_node=current_node,
            parameter_info=parameter_info,
            language=language,
            available_vars=available_vars, parameter_info=parameter_info, language=language, code_context=code_context
        )

        # Construct complete prompt_messages with system prompt
@@ -504,9 +475,14 @@ class LLMGenerator:
                tenant_id=tenant_id,
            )

            return cls._parse_code_node_output(
                response.structured_output, language, parameter_info.get("type", "string")
            )
            return {
                "variables": response.variables,
                "code_language": language,
                "code": response.code,
                "outputs": response.outputs,
                "message": response.explanation,
                "error": "",
            }

        except InvokeError as e:
            return cls._error_response(str(e))
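For illustration, a successful response from generate_with_context now carries the CodeNodeData-style dict built above. A plausible value is sketched below; everything in it is invented, and the exact serialization of response.outputs depends on CodeNodeStructuredOutput in core.llm_generator.output_models, which this diff does not show:

{
    "variables": [{"variable": "cities", "value_selector": ["start-1", "cities"]}],
    "code_language": "python3",
    "code": 'def main(cities: list) -> dict:\n    return {"result": ", ".join(cities)}',
    "outputs": {"result": {"type": "string"}},
    "message": "Joins the upstream city names with commas.",
    "error": "",
}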
@@ -530,49 +506,24 @@ class LLMGenerator:
    def generate_suggested_questions(
        cls,
        tenant_id: str,
        workflow_id: str,
        node_id: str,
        parameter_name: str,
        language: str,
        model_config: dict | None = None,
        available_vars: Sequence[AvailableVarPayload],
        parameter_info: ParameterInfoPayload,
        model_config: dict,
    ) -> dict:
        """
        Generate suggested questions for context generation.

        Returns dict with questions array and error field.
        """
        from sqlalchemy import select
        from sqlalchemy.orm import Session

        from core.llm_generator.output_parser.structured_output import invoke_llm_with_pydantic_model
        from services.workflow_service import WorkflowService

        # Get workflow context (reuse existing logic)
        with Session(db.engine) as session:
            stmt = select(App).where(App.id == workflow_id)
            app = session.scalar(stmt)
            if not app:
                return {"questions": [], "error": f"App {workflow_id} not found"}

            workflow = WorkflowService().get_draft_workflow(app_model=app)
            if not workflow:
                return {"questions": [], "error": f"Workflow for app {workflow_id} not found"}

        upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id)
        current_node = cls._get_node_by_id(workflow.graph_dict, node_id)
        if not current_node:
            return {"questions": [], "error": f"Node {node_id} not found"}

        parameter_info = cls._get_parameter_info(
            tenant_id=tenant_id,
            node_data=current_node.get("data", {}),
            parameter_name=parameter_name,
        )

        # available_vars/parameter_info are provided by the frontend context-generate modal.
        # See web/app/components/workflow/nodes/tool/components/context-generate-modal/hooks/use-context-generate.ts
        # Build prompt
        system_prompt = cls._build_suggested_questions_prompt(
            upstream_nodes=upstream_nodes,
            current_node=current_node,
            available_vars=available_vars,
            parameter_info=parameter_info,
            language=language,
        )
@@ -617,8 +568,7 @@ class LLMGenerator:
                tenant_id=tenant_id,
            )

            questions = response.structured_output.get("questions", []) if response.structured_output else []
            return {"questions": questions, "error": ""}
            return {"questions": response.questions, "error": ""}

        except InvokeError as e:
            return {"questions": [], "error": str(e)}
@@ -629,213 +579,100 @@ class LLMGenerator:
    @classmethod
    def _build_suggested_questions_prompt(
        cls,
        upstream_nodes: list[dict],
        current_node: dict,
        parameter_info: dict,
        available_vars: Sequence[AvailableVarPayload],
        parameter_info: ParameterInfoPayload,
        language: str = "English",
    ) -> str:
        """Build minimal prompt for suggested questions generation."""
        # Simplify upstream nodes to reduce tokens
        sources = [f"{n['title']}({','.join(n.get('outputs', {}).keys())})" for n in upstream_nodes[:5]]
        param_type = parameter_info.get("type", "string")
        param_desc = parameter_info.get("description", "")[:100]
        parameter_block = cls._format_parameter_info(parameter_info)
        available_vars_block = cls._format_available_vars(
            available_vars,
            max_items=30,
            max_schema_chars=400,
            max_description_chars=120,
        )

        return f"""Suggest 3 code generation questions for extracting data.
Sources: {", ".join(sources)}
Target: {parameter_info.get("name")}({param_type}) - {param_desc}
Output 3 short, practical questions in {language}."""
        return f"""Suggest exactly 3 short questions that would help generate code for the target parameter.

## Target Parameter
{parameter_block}

## Available Variables
{available_vars_block}

## Constraints
- Output exactly 3 questions.
- Use {language}.
- Keep each question short and practical.
- Do not include code or variable syntax in the questions.
"""

    @classmethod
    def _get_upstream_nodes(cls, graph_dict: Mapping[str, Any], node_id: str) -> list[dict]:
        """
        Get all upstream nodes via edge backtracking.

        Traverses the graph backwards from node_id to collect all reachable nodes.
        """
        from collections import defaultdict

        nodes = {n["id"]: n for n in graph_dict.get("nodes", [])}
        edges = graph_dict.get("edges", [])

        # Build reverse adjacency list
        reverse_adj: dict[str, list[str]] = defaultdict(list)
        for edge in edges:
            reverse_adj[edge["target"]].append(edge["source"])

        # BFS to find all upstream nodes
        visited: set[str] = set()
        queue = [node_id]
        upstream: list[dict] = []

        while queue:
            current = queue.pop(0)
            for source in reverse_adj.get(current, []):
                if source not in visited:
                    visited.add(source)
                    queue.append(source)
                    if source in nodes:
                        upstream.append(cls._extract_node_info(nodes[source]))

        return upstream
    def _format_parameter_info(cls, parameter_info: ParameterInfoPayload) -> str:
        payload = parameter_info.model_dump(mode="python", by_alias=True)
        return json.dumps(payload, ensure_ascii=False)

    @classmethod
    def _get_node_by_id(cls, graph_dict: Mapping[str, Any], node_id: str) -> dict | None:
        """Get node by ID from graph."""
        for node in graph_dict.get("nodes", []):
            if node["id"] == node_id:
                return node
        return None
    def _format_available_vars(
        cls,
        available_vars: Sequence[AvailableVarPayload],
        *,
        max_items: int,
        max_schema_chars: int,
        max_description_chars: int,
    ) -> str:
        payload = [item.model_dump(mode="python", by_alias=True) for item in available_vars]
        return json.dumps(payload, ensure_ascii=False)

    @classmethod
    def _extract_node_info(cls, node: dict) -> dict:
        """Extract minimal node info with outputs based on node type."""
        node_type = node["data"]["type"]
        node_data = node.get("data", {})

        # Build outputs based on node type (only type, no description to reduce tokens)
        outputs: dict[str, str] = {}
        match node_type:
            case "start":
                for var in node_data.get("variables", []):
                    name = var.get("variable", var.get("name", ""))
                    outputs[name] = var.get("type", "string")
            case "llm":
                outputs["text"] = "string"
            case "code":
                for name, output in node_data.get("outputs", {}).items():
                    outputs[name] = output.get("type", "string")
            case "http-request":
                outputs = {"body": "string", "status_code": "number", "headers": "object"}
            case "knowledge-retrieval":
                outputs["result"] = "array[object]"
            case "tool":
                outputs = {"text": "string", "json": "object"}
            case _:
                outputs["output"] = "string"

        info: dict = {
            "id": node["id"],
            "title": node_data.get("title", node["id"]),
            "outputs": outputs,
        }
        # Only include description if not empty
        desc = node_data.get("desc", "")
        if desc:
            info["desc"] = desc

        return info

    @classmethod
    def _get_parameter_info(cls, tenant_id: str, node_data: dict, parameter_name: str) -> dict:
        """Get parameter info from tool schema using ToolManager."""
        default_info = {"name": parameter_name, "type": "string", "description": ""}

        if node_data.get("type") != "tool":
            return default_info

        try:
            from core.app.entities.app_invoke_entities import InvokeFrom
            from core.tools.entities.tool_entities import ToolProviderType
            from core.tools.tool_manager import ToolManager

            provider_type_str = node_data.get("provider_type", "")
            provider_type = ToolProviderType(provider_type_str) if provider_type_str else ToolProviderType.BUILT_IN

            tool_runtime = ToolManager.get_tool_runtime(
                provider_type=provider_type,
                provider_id=node_data.get("provider_id", ""),
                tool_name=node_data.get("tool_name", ""),
                tenant_id=tenant_id,
                invoke_from=InvokeFrom.DEBUGGER,
            )

            parameters = tool_runtime.get_merged_runtime_parameters()
            for param in parameters:
                if param.name == parameter_name:
                    return {
                        "name": param.name,
                        "type": param.type.value if hasattr(param.type, "value") else str(param.type),
                        "description": param.llm_description
                        or (param.human_description.en_US if param.human_description else ""),
                        "required": param.required,
                    }
        except Exception as e:
            logger.debug("Failed to get parameter info from ToolManager: %s", e)

        return default_info
    def _format_code_context(cls, code_context: CodeContextPayload | None) -> str:
        if not code_context:
            return ""
        code = code_context.code
        outputs = code_context.outputs
        variables = code_context.variables
        if not code and not outputs and not variables:
            return ""
        payload = code_context.model_dump(mode="python", by_alias=True)
        return json.dumps(payload, ensure_ascii=False)

    @classmethod
    def _build_extractor_system_prompt(
        cls,
        upstream_nodes: list[dict],
        current_node: dict,
        parameter_info: dict,
        available_vars: Sequence[AvailableVarPayload],
        parameter_info: ParameterInfoPayload,
        language: str,
        code_context: CodeContextPayload,
    ) -> str:
        """Build system prompt for extractor code generation."""
        upstream_json = json.dumps(upstream_nodes, indent=2, ensure_ascii=False)
        param_type = parameter_info.get("type", "string")
        return f"""You are a code generator for workflow automation.
        param_type = parameter_info.type or "string"
        parameter_block = cls._format_parameter_info(parameter_info)
        available_vars_block = cls._format_available_vars(
            available_vars,
            max_items=80,
            max_schema_chars=800,
            max_description_chars=160,
        )
        code_context_block = cls._format_code_context(code_context)
        code_context_section = f"\n{code_context_block}\n" if code_context_block else "\n"
        return f"""You are a code generator for Dify workflow automation.

Generate {language} code to extract/transform upstream node outputs for the target parameter.
Generate {language} code to extract/transform available variables for the target parameter.

## Upstream Nodes
{upstream_json}
## Target Parameter
{parameter_block}

## Target
Node: {current_node["data"].get("title", current_node["id"])}
Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("description", "")}

## Requirements
- Write a main function that returns type: {param_type}
- Use value_selector format: ["node_id", "output_name"]
## Available Variables
{available_vars_block}
{code_context_section}## Requirements
- Use only the listed value_selector paths.
- Do not invent variables or fields that are not listed.
- Write a main function that returns type: {param_type}.
- Respect target constraints (options/min/max/default/multiple) if provided.
- If existing code is provided, adapt it instead of rewriting from scratch.
- Return only JSON that matches the provided schema.
"""

    @classmethod
    def _parse_code_node_output(cls, content: Mapping[str, Any] | None, language: str, parameter_type: str) -> dict:
        """
        Parse structured output to CodeNodeData format.

        Args:
            content: Structured output dict from invoke_llm_with_structured_output
            language: Code language
            parameter_type: Expected parameter type

        Returns dict with variables, code_language, code, outputs, message, error.
        """
        if content is None:
            return cls._error_response("Empty or invalid response from LLM")

        # Validate and normalize variables
        variables = [
            {"variable": v.get("variable", ""), "value_selector": v.get("value_selector", [])}
            for v in content.get("variables", [])
            if isinstance(v, dict)
        ]

        # Convert outputs from array format [{name, type}] to dict format {name: {type}}
        # Array format is required for OpenAI/Azure strict JSON schema compatibility
        raw_outputs = content.get("outputs", [])
        if isinstance(raw_outputs, list):
            outputs = {
                item.get("name", "result"): {"type": item.get("type", parameter_type)}
                for item in raw_outputs
                if isinstance(item, dict) and item.get("name")
            }
            if not outputs:
                outputs = {"result": {"type": parameter_type}}
        else:
            outputs = raw_outputs or {"result": {"type": parameter_type}}

        return {
            "variables": variables,
            "code_language": language,
            "code": content.get("code", ""),
            "outputs": outputs,
            "message": content.get("explanation", ""),
            "error": "",
        }

    @staticmethod
    def instruction_modify_legacy(
        tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
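For reference, the outputs conversion that _parse_code_node_output performed (the helper this refactor retires in favor of the typed CodeNodeStructuredOutput path) works as below on a made-up structured output, folding the strict-schema-friendly array form into the code-node dict form:

raw_outputs = [{"name": "result", "type": "string"}, {"name": "count", "type": "number"}]
parameter_type = "string"  # fallback type, as in the helper's signature
outputs = {
    item.get("name", "result"): {"type": item.get("type", parameter_type)}
    for item in raw_outputs
    if isinstance(item, dict) and item.get("name")
}
assert outputs == {"result": {"type": "string"}, "count": {"type": "number"}}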
@@ -891,7 +728,7 @@ Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("description", "")}
            raise ValueError("Workflow not found for the given app model.")
        last_run = workflow_service.get_node_last_run(app_model=app, workflow=workflow, node_id=node_id)
        try:
            node_type = cast(WorkflowNodeExecutionModel, last_run).node_type
            node_type = last_run.node_type
        except Exception:
            try:
                node_type = [it for it in workflow.graph_dict["graph"]["nodes"] if it["id"] == node_id][0]["data"][
@@ -1010,7 +847,7 @@ Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("description", "")}
                model_parameters=model_parameters,
                stream=False,
            )
            return response.structured_output or {}
            return response.model_dump(mode="python")
        except InvokeError as e:
            error = str(e)
            return {"error": f"Failed to generate code. Error: {error}"}

core/llm_generator/output_parser/structured_output.py:
@@ -236,7 +236,6 @@ def invoke_llm_with_structured_output(
    return generator()


@overload
def invoke_llm_with_pydantic_model(
    *,
    provider: str,
@@ -251,24 +250,7 @@ def invoke_llm_with_pydantic_model(
    user: str | None = None,
    callbacks: list[Callback] | None = None,
    tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput: ...


def invoke_llm_with_pydantic_model(
    *,
    provider: str,
    model_schema: AIModelEntity,
    model_instance: ModelInstance,
    prompt_messages: Sequence[PromptMessage],
    output_model: type[T],
    model_parameters: Mapping | None = None,
    tools: Sequence[PromptMessageTool] | None = None,
    stop: list[str] | None = None,
    stream: bool = False,
    user: str | None = None,
    callbacks: list[Callback] | None = None,
    tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput:
) -> T:
    """
    Invoke large language model with a Pydantic output model.
@@ -299,7 +281,7 @@ def invoke_llm_with_pydantic_model(
        raise OutputParserError("Structured output is empty")

    validated_output = _validate_structured_output(output_model, structured_output)
    return result.model_copy(update={"structured_output": validated_output})
    return output_model.model_validate(validated_output)


def _schema_from_pydantic(output_model: type[BaseModel]) -> dict[str, Any]:
@@ -309,12 +291,12 @@ def _schema_from_pydantic(output_model: type[BaseModel]) -> dict[str, Any]:
def _validate_structured_output(
    output_model: type[T],
    structured_output: Mapping[str, Any],
) -> dict[str, Any]:
) -> T:
    try:
        validated_output = output_model.model_validate(structured_output)
    except ValidationError as exc:
        raise OutputParserError(f"Structured output validation failed: {exc}") from exc
    return validated_output.model_dump(mode="python")
    return validated_output


def _handle_native_json_schema(
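A minimal sketch of the new call contract, per the signature above: invoke_llm_with_pydantic_model now returns the validated Pydantic instance itself rather than an LLMResultWithStructuredOutput. The provider string, model_schema, model_instance, system_prompt, and tenant_id below are placeholders assumed to be in scope in the surrounding service code:

class QuestionsOutput(BaseModel):
    questions: list[str]

result = invoke_llm_with_pydantic_model(
    provider="openai",             # placeholder provider name
    model_schema=model_schema,     # AIModelEntity, assumed in scope
    model_instance=model_instance, # ModelInstance, assumed in scope
    prompt_messages=[UserPromptMessage(content=system_prompt)],
    output_model=QuestionsOutput,
    stream=False,
    tenant_id=tenant_id,
)
assert isinstance(result, QuestionsOutput)  # validated model, not an LLMResult
questions = result.questions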
Structured output tests:
@@ -493,8 +493,8 @@ def test_structured_output_with_pydantic_model():
        stream=False,
    )

    assert isinstance(result, LLMResultWithStructuredOutput)
    assert result.structured_output == {"name": "test"}
    assert isinstance(result, ExampleOutput)
    assert result.name == "test"


def test_structured_output_with_pydantic_model_streaming_rejected():