refactor: better `/context-generate` with frontend support

Signed-off-by: Stream <Stream_2@qq.com>
This commit is contained in:
Stream 2026-01-23 01:02:12 +08:00
parent 71f811930f
commit a409e3d32e
No known key found for this signature in database
GPG Key ID: 033728094B100D70
5 changed files with 191 additions and 308 deletions

View File

@ -16,6 +16,11 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.context_models import (
AvailableVarPayload,
CodeContextPayload,
ParameterInfoPayload,
)
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
@ -58,22 +63,21 @@ class InstructionTemplatePayload(BaseModel):
class ContextGeneratePayload(BaseModel):
"""Payload for generating extractor code node."""
workflow_id: str = Field(..., description="Workflow ID")
node_id: str = Field(..., description="Current tool/llm node ID")
parameter_name: str = Field(..., description="Parameter name to generate code for")
language: str = Field(default="python3", description="Code language (python3/javascript)")
prompt_messages: list[dict[str, Any]] = Field(
..., description="Multi-turn conversation history, last message is the current instruction"
)
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
available_vars: list[AvailableVarPayload] = Field(..., description="Available variables from upstream nodes")
parameter_info: ParameterInfoPayload = Field(..., description="Target parameter metadata from the frontend")
code_context: CodeContextPayload | None = Field(
default=None, description="Existing code node context for incremental generation"
)
class SuggestedQuestionsPayload(BaseModel):
"""Payload for generating suggested questions."""
workflow_id: str = Field(..., description="Workflow ID")
node_id: str = Field(..., description="Current tool/llm node ID")
parameter_name: str = Field(..., description="Parameter name")
language: str = Field(
default="English", description="Language for generated questions (e.g. English, Chinese, Japanese)"
)
@ -82,6 +86,8 @@ class SuggestedQuestionsPayload(BaseModel):
alias="model_config",
description="Model configuration (optional, uses system default if not provided)",
)
available_vars: list[AvailableVarPayload] = Field(..., description="Available variables from upstream nodes")
parameter_info: ParameterInfoPayload = Field(..., description="Target parameter metadata from the frontend")
def reg(cls: type[BaseModel]):
@ -328,17 +334,15 @@ class ContextGenerateApi(Resource):
args = ContextGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
prompt_messages = deserialize_prompt_messages(args.prompt_messages)
try:
return LLMGenerator.generate_with_context(
tenant_id=current_tenant_id,
workflow_id=args.workflow_id,
node_id=args.node_id,
parameter_name=args.parameter_name,
language=args.language,
prompt_messages=prompt_messages,
prompt_messages=deserialize_prompt_messages(args.prompt_messages),
model_config=args.model_config_data,
available_vars=args.available_vars,
parameter_info=args.parameter_info,
code_context=args.code_context,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -362,14 +366,12 @@ class SuggestedQuestionsApi(Resource):
def post(self):
args = SuggestedQuestionsPayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
return LLMGenerator.generate_suggested_questions(
tenant_id=current_tenant_id,
workflow_id=args.workflow_id,
node_id=args.node_id,
parameter_name=args.parameter_name,
language=args.language,
available_vars=args.available_vars,
parameter_info=args.parameter_info,
model_config=args.model_config_data,
)
except ProviderTokenNotInitError as ex:

View File

@ -0,0 +1,62 @@
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
class VariableSelectorPayload(BaseModel):
model_config = ConfigDict(extra="forbid")
variable: str = Field(..., description="Variable name used in generated code")
value_selector: list[str] = Field(..., description="Path to upstream node output, format: [node_id, output_name]")
class CodeOutputPayload(BaseModel):
model_config = ConfigDict(extra="forbid")
type: str = Field(..., description="Output variable type")
class CodeContextPayload(BaseModel):
# From web/app/components/workflow/nodes/tool/components/context-generate-modal/index.tsx (code node snapshot).
model_config = ConfigDict(extra="forbid")
code: str = Field(..., description="Existing code in the Code node")
outputs: dict[str, CodeOutputPayload] | None = Field(
default=None, description="Existing output definitions for the Code node"
)
variables: list[VariableSelectorPayload] | None = Field(
default=None, description="Existing variable selectors used by the Code node"
)
class AvailableVarPayload(BaseModel):
# From web/app/components/workflow/nodes/_base/hooks/use-available-var-list.ts (available variables).
model_config = ConfigDict(extra="forbid", populate_by_name=True)
value_selector: list[str] = Field(..., description="Path to upstream node output")
type: str = Field(..., description="Variable type, e.g. string, number, array[object]")
description: str | None = Field(default=None, description="Optional variable description")
node_id: str | None = Field(default=None, description="Source node ID")
node_title: str | None = Field(default=None, description="Source node title")
node_type: str | None = Field(default=None, description="Source node type")
json_schema: dict[str, Any] | None = Field(
default=None,
alias="schema",
description="Optional JSON schema for object variables",
)
class ParameterInfoPayload(BaseModel):
# From web/app/components/workflow/nodes/tool/use-config.ts (ToolParameter metadata).
model_config = ConfigDict(extra="forbid")
name: str = Field(..., description="Target parameter name")
type: str = Field(default="string", description="Target parameter type")
description: str = Field(default="", description="Parameter description")
required: bool | None = Field(default=None, description="Whether the parameter is required")
options: list[str] | None = Field(default=None, description="Allowed option values")
min: float | None = Field(default=None, description="Minimum numeric value")
max: float | None = Field(default=None, description="Maximum numeric value")
default: str | int | float | bool | None = Field(default=None, description="Default value")
multiple: bool | None = Field(default=None, description="Whether the parameter accepts multiple values")
label: str | None = Field(default=None, description="Optional display label")

View File

@ -1,11 +1,16 @@
import json
import logging
import re
from collections.abc import Mapping, Sequence
from typing import Any, Protocol, cast
from collections.abc import Sequence
from typing import Protocol
import json_repair
from core.llm_generator.context_models import (
AvailableVarPayload,
CodeContextPayload,
ParameterInfoPayload,
)
from core.llm_generator.output_models import (
CodeNodeStructuredOutput,
InstructionModifyOutput,
@ -132,9 +137,6 @@ class LLMGenerator:
return []
prompt_messages = [UserPromptMessage(content=prompt)]
questions: Sequence[str] = []
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
@ -402,24 +404,24 @@ class LLMGenerator:
def generate_with_context(
cls,
tenant_id: str,
workflow_id: str,
node_id: str,
parameter_name: str,
language: str,
prompt_messages: list[PromptMessage],
model_config: dict,
available_vars: Sequence[AvailableVarPayload],
parameter_info: ParameterInfoPayload,
code_context: CodeContextPayload,
) -> dict:
"""
Generate extractor code node based on conversation context.
Args:
tenant_id: Tenant/workspace ID
workflow_id: Workflow ID
node_id: Current tool/llm node ID
parameter_name: Parameter name to generate code for
language: Code language (python3/javascript)
prompt_messages: Multi-turn conversation history (last message is instruction)
model_config: Model configuration (provider, name, completion_params)
available_vars: Client-provided available variables with types/schema
parameter_info: Client-provided parameter metadata (type/constraints)
code_context: Client-provided existing code node context
Returns:
dict with CodeNodeData format:
@ -430,43 +432,12 @@ class LLMGenerator:
- message: Explanation
- error: Error message if any
"""
from sqlalchemy import select
from sqlalchemy.orm import Session
from services.workflow_service import WorkflowService
# available_vars/parameter_info/code_context are provided by the frontend context-generate modal.
# See web/app/components/workflow/nodes/tool/components/context-generate-modal/hooks/use-context-generate.ts
# Get workflow
with Session(db.engine) as session:
stmt = select(App).where(App.id == workflow_id)
app = session.scalar(stmt)
if not app:
return cls._error_response(f"App {workflow_id} not found")
workflow = WorkflowService().get_draft_workflow(app_model=app)
if not workflow:
return cls._error_response(f"Workflow for app {workflow_id} not found")
# Get upstream nodes via edge backtracking
upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id)
# Get current node info
current_node = cls._get_node_by_id(workflow.graph_dict, node_id)
if not current_node:
return cls._error_response(f"Node {node_id} not found")
# Get parameter info
parameter_info = cls._get_parameter_info(
tenant_id=tenant_id,
node_data=current_node.get("data", {}),
parameter_name=parameter_name,
)
# Build system prompt
system_prompt = cls._build_extractor_system_prompt(
upstream_nodes=upstream_nodes,
current_node=current_node,
parameter_info=parameter_info,
language=language,
available_vars=available_vars, parameter_info=parameter_info, language=language, code_context=code_context
)
# Construct complete prompt_messages with system prompt
@ -504,9 +475,14 @@ class LLMGenerator:
tenant_id=tenant_id,
)
return cls._parse_code_node_output(
response.structured_output, language, parameter_info.get("type", "string")
)
return {
"variables": response.variables,
"code_language": language,
"code": response.code,
"outputs": response.outputs,
"message": response.explanation,
"error": "",
}
except InvokeError as e:
return cls._error_response(str(e))
@ -530,49 +506,24 @@ class LLMGenerator:
def generate_suggested_questions(
cls,
tenant_id: str,
workflow_id: str,
node_id: str,
parameter_name: str,
language: str,
model_config: dict | None = None,
available_vars: Sequence[AvailableVarPayload],
parameter_info: ParameterInfoPayload,
model_config: dict,
) -> dict:
"""
Generate suggested questions for context generation.
Returns dict with questions array and error field.
"""
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.llm_generator.output_parser.structured_output import invoke_llm_with_pydantic_model
from services.workflow_service import WorkflowService
# Get workflow context (reuse existing logic)
with Session(db.engine) as session:
stmt = select(App).where(App.id == workflow_id)
app = session.scalar(stmt)
if not app:
return {"questions": [], "error": f"App {workflow_id} not found"}
workflow = WorkflowService().get_draft_workflow(app_model=app)
if not workflow:
return {"questions": [], "error": f"Workflow for app {workflow_id} not found"}
upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id)
current_node = cls._get_node_by_id(workflow.graph_dict, node_id)
if not current_node:
return {"questions": [], "error": f"Node {node_id} not found"}
parameter_info = cls._get_parameter_info(
tenant_id=tenant_id,
node_data=current_node.get("data", {}),
parameter_name=parameter_name,
)
# available_vars/parameter_info are provided by the frontend context-generate modal.
# See web/app/components/workflow/nodes/tool/components/context-generate-modal/hooks/use-context-generate.ts
# Build prompt
system_prompt = cls._build_suggested_questions_prompt(
upstream_nodes=upstream_nodes,
current_node=current_node,
available_vars=available_vars,
parameter_info=parameter_info,
language=language,
)
@ -617,8 +568,7 @@ class LLMGenerator:
tenant_id=tenant_id,
)
questions = response.structured_output.get("questions", []) if response.structured_output else []
return {"questions": questions, "error": ""}
return {"questions": response.questions, "error": ""}
except InvokeError as e:
return {"questions": [], "error": str(e)}
@ -629,213 +579,100 @@ class LLMGenerator:
@classmethod
def _build_suggested_questions_prompt(
cls,
upstream_nodes: list[dict],
current_node: dict,
parameter_info: dict,
available_vars: Sequence[AvailableVarPayload],
parameter_info: ParameterInfoPayload,
language: str = "English",
) -> str:
"""Build minimal prompt for suggested questions generation."""
# Simplify upstream nodes to reduce tokens
sources = [f"{n['title']}({','.join(n.get('outputs', {}).keys())})" for n in upstream_nodes[:5]]
param_type = parameter_info.get("type", "string")
param_desc = parameter_info.get("description", "")[:100]
parameter_block = cls._format_parameter_info(parameter_info)
available_vars_block = cls._format_available_vars(
available_vars,
max_items=30,
max_schema_chars=400,
max_description_chars=120,
)
return f"""Suggest 3 code generation questions for extracting data.
Sources: {", ".join(sources)}
Target: {parameter_info.get("name")}({param_type}) - {param_desc}
Output 3 short, practical questions in {language}."""
return f"""Suggest exactly 3 short questions that would help generate code for the target parameter.
## Target Parameter
{parameter_block}
## Available Variables
{available_vars_block}
## Constraints
- Output exactly 3 questions.
- Use {language}.
- Keep each question short and practical.
- Do not include code or variable syntax in the questions.
"""
@classmethod
def _get_upstream_nodes(cls, graph_dict: Mapping[str, Any], node_id: str) -> list[dict]:
"""
Get all upstream nodes via edge backtracking.
Traverses the graph backwards from node_id to collect all reachable nodes.
"""
from collections import defaultdict
nodes = {n["id"]: n for n in graph_dict.get("nodes", [])}
edges = graph_dict.get("edges", [])
# Build reverse adjacency list
reverse_adj: dict[str, list[str]] = defaultdict(list)
for edge in edges:
reverse_adj[edge["target"]].append(edge["source"])
# BFS to find all upstream nodes
visited: set[str] = set()
queue = [node_id]
upstream: list[dict] = []
while queue:
current = queue.pop(0)
for source in reverse_adj.get(current, []):
if source not in visited:
visited.add(source)
queue.append(source)
if source in nodes:
upstream.append(cls._extract_node_info(nodes[source]))
return upstream
def _format_parameter_info(cls, parameter_info: ParameterInfoPayload) -> str:
payload = parameter_info.model_dump(mode="python", by_alias=True)
return json.dumps(payload, ensure_ascii=False)
@classmethod
def _get_node_by_id(cls, graph_dict: Mapping[str, Any], node_id: str) -> dict | None:
"""Get node by ID from graph."""
for node in graph_dict.get("nodes", []):
if node["id"] == node_id:
return node
return None
def _format_available_vars(
cls,
available_vars: Sequence[AvailableVarPayload],
*,
max_items: int,
max_schema_chars: int,
max_description_chars: int,
) -> str:
payload = [item.model_dump(mode="python", by_alias=True) for item in available_vars]
return json.dumps(payload, ensure_ascii=False)
@classmethod
def _extract_node_info(cls, node: dict) -> dict:
"""Extract minimal node info with outputs based on node type."""
node_type = node["data"]["type"]
node_data = node.get("data", {})
# Build outputs based on node type (only type, no description to reduce tokens)
outputs: dict[str, str] = {}
match node_type:
case "start":
for var in node_data.get("variables", []):
name = var.get("variable", var.get("name", ""))
outputs[name] = var.get("type", "string")
case "llm":
outputs["text"] = "string"
case "code":
for name, output in node_data.get("outputs", {}).items():
outputs[name] = output.get("type", "string")
case "http-request":
outputs = {"body": "string", "status_code": "number", "headers": "object"}
case "knowledge-retrieval":
outputs["result"] = "array[object]"
case "tool":
outputs = {"text": "string", "json": "object"}
case _:
outputs["output"] = "string"
info: dict = {
"id": node["id"],
"title": node_data.get("title", node["id"]),
"outputs": outputs,
}
# Only include description if not empty
desc = node_data.get("desc", "")
if desc:
info["desc"] = desc
return info
@classmethod
def _get_parameter_info(cls, tenant_id: str, node_data: dict, parameter_name: str) -> dict:
"""Get parameter info from tool schema using ToolManager."""
default_info = {"name": parameter_name, "type": "string", "description": ""}
if node_data.get("type") != "tool":
return default_info
try:
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
provider_type_str = node_data.get("provider_type", "")
provider_type = ToolProviderType(provider_type_str) if provider_type_str else ToolProviderType.BUILT_IN
tool_runtime = ToolManager.get_tool_runtime(
provider_type=provider_type,
provider_id=node_data.get("provider_id", ""),
tool_name=node_data.get("tool_name", ""),
tenant_id=tenant_id,
invoke_from=InvokeFrom.DEBUGGER,
)
parameters = tool_runtime.get_merged_runtime_parameters()
for param in parameters:
if param.name == parameter_name:
return {
"name": param.name,
"type": param.type.value if hasattr(param.type, "value") else str(param.type),
"description": param.llm_description
or (param.human_description.en_US if param.human_description else ""),
"required": param.required,
}
except Exception as e:
logger.debug("Failed to get parameter info from ToolManager: %s", e)
return default_info
def _format_code_context(cls, code_context: CodeContextPayload | None) -> str:
if not code_context:
return ""
code = code_context.code
outputs = code_context.outputs
variables = code_context.variables
if not code and not outputs and not variables:
return ""
payload = code_context.model_dump(mode="python", by_alias=True)
return json.dumps(payload, ensure_ascii=False)
@classmethod
def _build_extractor_system_prompt(
cls,
upstream_nodes: list[dict],
current_node: dict,
parameter_info: dict,
available_vars: Sequence[AvailableVarPayload],
parameter_info: ParameterInfoPayload,
language: str,
code_context: CodeContextPayload,
) -> str:
"""Build system prompt for extractor code generation."""
upstream_json = json.dumps(upstream_nodes, indent=2, ensure_ascii=False)
param_type = parameter_info.get("type", "string")
return f"""You are a code generator for workflow automation.
param_type = parameter_info.type or "string"
parameter_block = cls._format_parameter_info(parameter_info)
available_vars_block = cls._format_available_vars(
available_vars,
max_items=80,
max_schema_chars=800,
max_description_chars=160,
)
code_context_block = cls._format_code_context(code_context)
code_context_section = f"\n{code_context_block}\n" if code_context_block else "\n"
return f"""You are a code generator for Dify workflow automation.
Generate {language} code to extract/transform upstream node outputs for the target parameter.
Generate {language} code to extract/transform available variables for the target parameter.
## Upstream Nodes
{upstream_json}
## Target Parameter
{parameter_block}
## Target
Node: {current_node["data"].get("title", current_node["id"])}
Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("description", "")}
## Requirements
- Write a main function that returns type: {param_type}
- Use value_selector format: ["node_id", "output_name"]
## Available Variables
{available_vars_block}
{code_context_section}## Requirements
- Use only the listed value_selector paths.
- Do not invent variables or fields that are not listed.
- Write a main function that returns type: {param_type}.
- Respect target constraints (options/min/max/default/multiple) if provided.
- If existing code is provided, adapt it instead of rewriting from scratch.
- Return only JSON that matches the provided schema.
"""
@classmethod
def _parse_code_node_output(cls, content: Mapping[str, Any] | None, language: str, parameter_type: str) -> dict:
"""
Parse structured output to CodeNodeData format.
Args:
content: Structured output dict from invoke_llm_with_structured_output
language: Code language
parameter_type: Expected parameter type
Returns dict with variables, code_language, code, outputs, message, error.
"""
if content is None:
return cls._error_response("Empty or invalid response from LLM")
# Validate and normalize variables
variables = [
{"variable": v.get("variable", ""), "value_selector": v.get("value_selector", [])}
for v in content.get("variables", [])
if isinstance(v, dict)
]
# Convert outputs from array format [{name, type}] to dict format {name: {type}}
# Array format is required for OpenAI/Azure strict JSON schema compatibility
raw_outputs = content.get("outputs", [])
if isinstance(raw_outputs, list):
outputs = {
item.get("name", "result"): {"type": item.get("type", parameter_type)}
for item in raw_outputs
if isinstance(item, dict) and item.get("name")
}
if not outputs:
outputs = {"result": {"type": parameter_type}}
else:
outputs = raw_outputs or {"result": {"type": parameter_type}}
return {
"variables": variables,
"code_language": language,
"code": content.get("code", ""),
"outputs": outputs,
"message": content.get("explanation", ""),
"error": "",
}
@staticmethod
def instruction_modify_legacy(
tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
@ -891,7 +728,7 @@ Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("de
raise ValueError("Workflow not found for the given app model.")
last_run = workflow_service.get_node_last_run(app_model=app, workflow=workflow, node_id=node_id)
try:
node_type = cast(WorkflowNodeExecutionModel, last_run).node_type
node_type = last_run.node_type
except Exception:
try:
node_type = [it for it in workflow.graph_dict["graph"]["nodes"] if it["id"] == node_id][0]["data"][
@ -1010,7 +847,7 @@ Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("de
model_parameters=model_parameters,
stream=False,
)
return response.structured_output or {}
return response.model_dump(mode="python")
except InvokeError as e:
error = str(e)
return {"error": f"Failed to generate code. Error: {error}"}

View File

@ -236,7 +236,6 @@ def invoke_llm_with_structured_output(
return generator()
@overload
def invoke_llm_with_pydantic_model(
*,
provider: str,
@ -251,24 +250,7 @@ def invoke_llm_with_pydantic_model(
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput: ...
def invoke_llm_with_pydantic_model(
*,
provider: str,
model_schema: AIModelEntity,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
output_model: type[T],
model_parameters: Mapping | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = False,
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput:
) -> T:
"""
Invoke large language model with a Pydantic output model.
@ -299,7 +281,7 @@ def invoke_llm_with_pydantic_model(
raise OutputParserError("Structured output is empty")
validated_output = _validate_structured_output(output_model, structured_output)
return result.model_copy(update={"structured_output": validated_output})
return output_model.model_validate(validated_output)
def _schema_from_pydantic(output_model: type[BaseModel]) -> dict[str, Any]:
@ -309,12 +291,12 @@ def _schema_from_pydantic(output_model: type[BaseModel]) -> dict[str, Any]:
def _validate_structured_output(
output_model: type[T],
structured_output: Mapping[str, Any],
) -> dict[str, Any]:
) -> T:
try:
validated_output = output_model.model_validate(structured_output)
except ValidationError as exc:
raise OutputParserError(f"Structured output validation failed: {exc}") from exc
return validated_output.model_dump(mode="python")
return validated_output
def _handle_native_json_schema(

View File

@ -493,8 +493,8 @@ def test_structured_output_with_pydantic_model():
stream=False,
)
assert isinstance(result, LLMResultWithStructuredOutput)
assert result.structured_output == {"name": "test"}
assert isinstance(result, ExampleOutput)
assert result.name == "test"
def test_structured_output_with_pydantic_model_streaming_rejected():