From a409e3d32e787585615d2de33e5fa8d1817d8736 Mon Sep 17 00:00:00 2001 From: Stream Date: Fri, 23 Jan 2026 01:02:12 +0800 Subject: [PATCH] refactor: better `/context-generate` with frontend support Signed-off-by: Stream --- api/controllers/console/app/generator.py | 34 +- api/core/llm_generator/context_models.py | 62 +++ api/core/llm_generator/llm_generator.py | 373 +++++------------- .../output_parser/structured_output.py | 26 +- .../test_structured_output_parser.py | 4 +- 5 files changed, 191 insertions(+), 308 deletions(-) create mode 100644 api/core/llm_generator/context_models.py diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index b13b94f67d..2349f16b9b 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -16,6 +16,11 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider +from core.llm_generator.context_models import ( + AvailableVarPayload, + CodeContextPayload, + ParameterInfoPayload, +) from core.llm_generator.llm_generator import LLMGenerator from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db @@ -58,22 +63,21 @@ class InstructionTemplatePayload(BaseModel): class ContextGeneratePayload(BaseModel): """Payload for generating extractor code node.""" - workflow_id: str = Field(..., description="Workflow ID") - node_id: str = Field(..., description="Current tool/llm node ID") - parameter_name: str = Field(..., description="Parameter name to generate code for") language: str = Field(default="python3", description="Code language (python3/javascript)") prompt_messages: list[dict[str, Any]] = Field( ..., description="Multi-turn conversation history, last message is the current instruction" ) model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration") + available_vars: list[AvailableVarPayload] = Field(..., description="Available variables from upstream nodes") + parameter_info: ParameterInfoPayload = Field(..., description="Target parameter metadata from the frontend") + code_context: CodeContextPayload | None = Field( + default=None, description="Existing code node context for incremental generation" + ) class SuggestedQuestionsPayload(BaseModel): """Payload for generating suggested questions.""" - workflow_id: str = Field(..., description="Workflow ID") - node_id: str = Field(..., description="Current tool/llm node ID") - parameter_name: str = Field(..., description="Parameter name") language: str = Field( default="English", description="Language for generated questions (e.g. English, Chinese, Japanese)" ) @@ -82,6 +86,8 @@ class SuggestedQuestionsPayload(BaseModel): alias="model_config", description="Model configuration (optional, uses system default if not provided)", ) + available_vars: list[AvailableVarPayload] = Field(..., description="Available variables from upstream nodes") + parameter_info: ParameterInfoPayload = Field(..., description="Target parameter metadata from the frontend") def reg(cls: type[BaseModel]): @@ -328,17 +334,15 @@ class ContextGenerateApi(Resource): args = ContextGeneratePayload.model_validate(console_ns.payload) _, current_tenant_id = current_account_with_tenant() - prompt_messages = deserialize_prompt_messages(args.prompt_messages) - try: return LLMGenerator.generate_with_context( tenant_id=current_tenant_id, - workflow_id=args.workflow_id, - node_id=args.node_id, - parameter_name=args.parameter_name, language=args.language, - prompt_messages=prompt_messages, + prompt_messages=deserialize_prompt_messages(args.prompt_messages), model_config=args.model_config_data, + available_vars=args.available_vars, + parameter_info=args.parameter_info, + code_context=args.code_context, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -362,14 +366,12 @@ class SuggestedQuestionsApi(Resource): def post(self): args = SuggestedQuestionsPayload.model_validate(console_ns.payload) _, current_tenant_id = current_account_with_tenant() - try: return LLMGenerator.generate_suggested_questions( tenant_id=current_tenant_id, - workflow_id=args.workflow_id, - node_id=args.node_id, - parameter_name=args.parameter_name, language=args.language, + available_vars=args.available_vars, + parameter_info=args.parameter_info, model_config=args.model_config_data, ) except ProviderTokenNotInitError as ex: diff --git a/api/core/llm_generator/context_models.py b/api/core/llm_generator/context_models.py new file mode 100644 index 0000000000..66db0ac64b --- /dev/null +++ b/api/core/llm_generator/context_models.py @@ -0,0 +1,62 @@ +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + + +class VariableSelectorPayload(BaseModel): + model_config = ConfigDict(extra="forbid") + + variable: str = Field(..., description="Variable name used in generated code") + value_selector: list[str] = Field(..., description="Path to upstream node output, format: [node_id, output_name]") + + +class CodeOutputPayload(BaseModel): + model_config = ConfigDict(extra="forbid") + + type: str = Field(..., description="Output variable type") + + +class CodeContextPayload(BaseModel): + # From web/app/components/workflow/nodes/tool/components/context-generate-modal/index.tsx (code node snapshot). + model_config = ConfigDict(extra="forbid") + + code: str = Field(..., description="Existing code in the Code node") + outputs: dict[str, CodeOutputPayload] | None = Field( + default=None, description="Existing output definitions for the Code node" + ) + variables: list[VariableSelectorPayload] | None = Field( + default=None, description="Existing variable selectors used by the Code node" + ) + + +class AvailableVarPayload(BaseModel): + # From web/app/components/workflow/nodes/_base/hooks/use-available-var-list.ts (available variables). + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + value_selector: list[str] = Field(..., description="Path to upstream node output") + type: str = Field(..., description="Variable type, e.g. string, number, array[object]") + description: str | None = Field(default=None, description="Optional variable description") + node_id: str | None = Field(default=None, description="Source node ID") + node_title: str | None = Field(default=None, description="Source node title") + node_type: str | None = Field(default=None, description="Source node type") + json_schema: dict[str, Any] | None = Field( + default=None, + alias="schema", + description="Optional JSON schema for object variables", + ) + + +class ParameterInfoPayload(BaseModel): + # From web/app/components/workflow/nodes/tool/use-config.ts (ToolParameter metadata). + model_config = ConfigDict(extra="forbid") + + name: str = Field(..., description="Target parameter name") + type: str = Field(default="string", description="Target parameter type") + description: str = Field(default="", description="Parameter description") + required: bool | None = Field(default=None, description="Whether the parameter is required") + options: list[str] | None = Field(default=None, description="Allowed option values") + min: float | None = Field(default=None, description="Minimum numeric value") + max: float | None = Field(default=None, description="Maximum numeric value") + default: str | int | float | bool | None = Field(default=None, description="Default value") + multiple: bool | None = Field(default=None, description="Whether the parameter accepts multiple values") + label: str | None = Field(default=None, description="Optional display label") diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index aba069333a..efa9e05727 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -1,11 +1,16 @@ import json import logging import re -from collections.abc import Mapping, Sequence -from typing import Any, Protocol, cast +from collections.abc import Sequence +from typing import Protocol import json_repair +from core.llm_generator.context_models import ( + AvailableVarPayload, + CodeContextPayload, + ParameterInfoPayload, +) from core.llm_generator.output_models import ( CodeNodeStructuredOutput, InstructionModifyOutput, @@ -132,9 +137,6 @@ class LLMGenerator: return [] prompt_messages = [UserPromptMessage(content=prompt)] - - questions: Sequence[str] = [] - try: response: LLMResult = model_instance.invoke_llm( prompt_messages=list(prompt_messages), @@ -402,24 +404,24 @@ class LLMGenerator: def generate_with_context( cls, tenant_id: str, - workflow_id: str, - node_id: str, - parameter_name: str, language: str, prompt_messages: list[PromptMessage], model_config: dict, + available_vars: Sequence[AvailableVarPayload], + parameter_info: ParameterInfoPayload, + code_context: CodeContextPayload, ) -> dict: """ Generate extractor code node based on conversation context. Args: tenant_id: Tenant/workspace ID - workflow_id: Workflow ID - node_id: Current tool/llm node ID - parameter_name: Parameter name to generate code for language: Code language (python3/javascript) prompt_messages: Multi-turn conversation history (last message is instruction) model_config: Model configuration (provider, name, completion_params) + available_vars: Client-provided available variables with types/schema + parameter_info: Client-provided parameter metadata (type/constraints) + code_context: Client-provided existing code node context Returns: dict with CodeNodeData format: @@ -430,43 +432,12 @@ class LLMGenerator: - message: Explanation - error: Error message if any """ - from sqlalchemy import select - from sqlalchemy.orm import Session - from services.workflow_service import WorkflowService + # available_vars/parameter_info/code_context are provided by the frontend context-generate modal. + # See web/app/components/workflow/nodes/tool/components/context-generate-modal/hooks/use-context-generate.ts - # Get workflow - with Session(db.engine) as session: - stmt = select(App).where(App.id == workflow_id) - app = session.scalar(stmt) - if not app: - return cls._error_response(f"App {workflow_id} not found") - - workflow = WorkflowService().get_draft_workflow(app_model=app) - if not workflow: - return cls._error_response(f"Workflow for app {workflow_id} not found") - - # Get upstream nodes via edge backtracking - upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id) - - # Get current node info - current_node = cls._get_node_by_id(workflow.graph_dict, node_id) - if not current_node: - return cls._error_response(f"Node {node_id} not found") - - # Get parameter info - parameter_info = cls._get_parameter_info( - tenant_id=tenant_id, - node_data=current_node.get("data", {}), - parameter_name=parameter_name, - ) - - # Build system prompt system_prompt = cls._build_extractor_system_prompt( - upstream_nodes=upstream_nodes, - current_node=current_node, - parameter_info=parameter_info, - language=language, + available_vars=available_vars, parameter_info=parameter_info, language=language, code_context=code_context ) # Construct complete prompt_messages with system prompt @@ -504,9 +475,14 @@ class LLMGenerator: tenant_id=tenant_id, ) - return cls._parse_code_node_output( - response.structured_output, language, parameter_info.get("type", "string") - ) + return { + "variables": response.variables, + "code_language": language, + "code": response.code, + "outputs": response.outputs, + "message": response.explanation, + "error": "", + } except InvokeError as e: return cls._error_response(str(e)) @@ -530,49 +506,24 @@ class LLMGenerator: def generate_suggested_questions( cls, tenant_id: str, - workflow_id: str, - node_id: str, - parameter_name: str, language: str, - model_config: dict | None = None, + available_vars: Sequence[AvailableVarPayload], + parameter_info: ParameterInfoPayload, + model_config: dict, ) -> dict: """ Generate suggested questions for context generation. Returns dict with questions array and error field. """ - from sqlalchemy import select - from sqlalchemy.orm import Session from core.llm_generator.output_parser.structured_output import invoke_llm_with_pydantic_model - from services.workflow_service import WorkflowService - - # Get workflow context (reuse existing logic) - with Session(db.engine) as session: - stmt = select(App).where(App.id == workflow_id) - app = session.scalar(stmt) - if not app: - return {"questions": [], "error": f"App {workflow_id} not found"} - - workflow = WorkflowService().get_draft_workflow(app_model=app) - if not workflow: - return {"questions": [], "error": f"Workflow for app {workflow_id} not found"} - - upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id) - current_node = cls._get_node_by_id(workflow.graph_dict, node_id) - if not current_node: - return {"questions": [], "error": f"Node {node_id} not found"} - - parameter_info = cls._get_parameter_info( - tenant_id=tenant_id, - node_data=current_node.get("data", {}), - parameter_name=parameter_name, - ) + # available_vars/parameter_info are provided by the frontend context-generate modal. + # See web/app/components/workflow/nodes/tool/components/context-generate-modal/hooks/use-context-generate.ts # Build prompt system_prompt = cls._build_suggested_questions_prompt( - upstream_nodes=upstream_nodes, - current_node=current_node, + available_vars=available_vars, parameter_info=parameter_info, language=language, ) @@ -617,8 +568,7 @@ class LLMGenerator: tenant_id=tenant_id, ) - questions = response.structured_output.get("questions", []) if response.structured_output else [] - return {"questions": questions, "error": ""} + return {"questions": response.questions, "error": ""} except InvokeError as e: return {"questions": [], "error": str(e)} @@ -629,213 +579,100 @@ class LLMGenerator: @classmethod def _build_suggested_questions_prompt( cls, - upstream_nodes: list[dict], - current_node: dict, - parameter_info: dict, + available_vars: Sequence[AvailableVarPayload], + parameter_info: ParameterInfoPayload, language: str = "English", ) -> str: """Build minimal prompt for suggested questions generation.""" - # Simplify upstream nodes to reduce tokens - sources = [f"{n['title']}({','.join(n.get('outputs', {}).keys())})" for n in upstream_nodes[:5]] - param_type = parameter_info.get("type", "string") - param_desc = parameter_info.get("description", "")[:100] + parameter_block = cls._format_parameter_info(parameter_info) + available_vars_block = cls._format_available_vars( + available_vars, + max_items=30, + max_schema_chars=400, + max_description_chars=120, + ) - return f"""Suggest 3 code generation questions for extracting data. -Sources: {", ".join(sources)} -Target: {parameter_info.get("name")}({param_type}) - {param_desc} -Output 3 short, practical questions in {language}.""" + return f"""Suggest exactly 3 short questions that would help generate code for the target parameter. + +## Target Parameter +{parameter_block} + +## Available Variables +{available_vars_block} + +## Constraints +- Output exactly 3 questions. +- Use {language}. +- Keep each question short and practical. +- Do not include code or variable syntax in the questions. +""" @classmethod - def _get_upstream_nodes(cls, graph_dict: Mapping[str, Any], node_id: str) -> list[dict]: - """ - Get all upstream nodes via edge backtracking. - - Traverses the graph backwards from node_id to collect all reachable nodes. - """ - from collections import defaultdict - - nodes = {n["id"]: n for n in graph_dict.get("nodes", [])} - edges = graph_dict.get("edges", []) - - # Build reverse adjacency list - reverse_adj: dict[str, list[str]] = defaultdict(list) - for edge in edges: - reverse_adj[edge["target"]].append(edge["source"]) - - # BFS to find all upstream nodes - visited: set[str] = set() - queue = [node_id] - upstream: list[dict] = [] - - while queue: - current = queue.pop(0) - for source in reverse_adj.get(current, []): - if source not in visited: - visited.add(source) - queue.append(source) - if source in nodes: - upstream.append(cls._extract_node_info(nodes[source])) - - return upstream + def _format_parameter_info(cls, parameter_info: ParameterInfoPayload) -> str: + payload = parameter_info.model_dump(mode="python", by_alias=True) + return json.dumps(payload, ensure_ascii=False) @classmethod - def _get_node_by_id(cls, graph_dict: Mapping[str, Any], node_id: str) -> dict | None: - """Get node by ID from graph.""" - for node in graph_dict.get("nodes", []): - if node["id"] == node_id: - return node - return None + def _format_available_vars( + cls, + available_vars: Sequence[AvailableVarPayload], + *, + max_items: int, + max_schema_chars: int, + max_description_chars: int, + ) -> str: + payload = [item.model_dump(mode="python", by_alias=True) for item in available_vars] + return json.dumps(payload, ensure_ascii=False) @classmethod - def _extract_node_info(cls, node: dict) -> dict: - """Extract minimal node info with outputs based on node type.""" - node_type = node["data"]["type"] - node_data = node.get("data", {}) - - # Build outputs based on node type (only type, no description to reduce tokens) - outputs: dict[str, str] = {} - match node_type: - case "start": - for var in node_data.get("variables", []): - name = var.get("variable", var.get("name", "")) - outputs[name] = var.get("type", "string") - case "llm": - outputs["text"] = "string" - case "code": - for name, output in node_data.get("outputs", {}).items(): - outputs[name] = output.get("type", "string") - case "http-request": - outputs = {"body": "string", "status_code": "number", "headers": "object"} - case "knowledge-retrieval": - outputs["result"] = "array[object]" - case "tool": - outputs = {"text": "string", "json": "object"} - case _: - outputs["output"] = "string" - - info: dict = { - "id": node["id"], - "title": node_data.get("title", node["id"]), - "outputs": outputs, - } - # Only include description if not empty - desc = node_data.get("desc", "") - if desc: - info["desc"] = desc - - return info - - @classmethod - def _get_parameter_info(cls, tenant_id: str, node_data: dict, parameter_name: str) -> dict: - """Get parameter info from tool schema using ToolManager.""" - default_info = {"name": parameter_name, "type": "string", "description": ""} - - if node_data.get("type") != "tool": - return default_info - - try: - from core.app.entities.app_invoke_entities import InvokeFrom - from core.tools.entities.tool_entities import ToolProviderType - from core.tools.tool_manager import ToolManager - - provider_type_str = node_data.get("provider_type", "") - provider_type = ToolProviderType(provider_type_str) if provider_type_str else ToolProviderType.BUILT_IN - - tool_runtime = ToolManager.get_tool_runtime( - provider_type=provider_type, - provider_id=node_data.get("provider_id", ""), - tool_name=node_data.get("tool_name", ""), - tenant_id=tenant_id, - invoke_from=InvokeFrom.DEBUGGER, - ) - - parameters = tool_runtime.get_merged_runtime_parameters() - for param in parameters: - if param.name == parameter_name: - return { - "name": param.name, - "type": param.type.value if hasattr(param.type, "value") else str(param.type), - "description": param.llm_description - or (param.human_description.en_US if param.human_description else ""), - "required": param.required, - } - except Exception as e: - logger.debug("Failed to get parameter info from ToolManager: %s", e) - - return default_info + def _format_code_context(cls, code_context: CodeContextPayload | None) -> str: + if not code_context: + return "" + code = code_context.code + outputs = code_context.outputs + variables = code_context.variables + if not code and not outputs and not variables: + return "" + payload = code_context.model_dump(mode="python", by_alias=True) + return json.dumps(payload, ensure_ascii=False) @classmethod def _build_extractor_system_prompt( cls, - upstream_nodes: list[dict], - current_node: dict, - parameter_info: dict, + available_vars: Sequence[AvailableVarPayload], + parameter_info: ParameterInfoPayload, language: str, + code_context: CodeContextPayload, ) -> str: """Build system prompt for extractor code generation.""" - upstream_json = json.dumps(upstream_nodes, indent=2, ensure_ascii=False) - param_type = parameter_info.get("type", "string") - return f"""You are a code generator for workflow automation. + param_type = parameter_info.type or "string" + parameter_block = cls._format_parameter_info(parameter_info) + available_vars_block = cls._format_available_vars( + available_vars, + max_items=80, + max_schema_chars=800, + max_description_chars=160, + ) + code_context_block = cls._format_code_context(code_context) + code_context_section = f"\n{code_context_block}\n" if code_context_block else "\n" + return f"""You are a code generator for Dify workflow automation. -Generate {language} code to extract/transform upstream node outputs for the target parameter. +Generate {language} code to extract/transform available variables for the target parameter. -## Upstream Nodes -{upstream_json} +## Target Parameter +{parameter_block} -## Target -Node: {current_node["data"].get("title", current_node["id"])} -Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("description", "")} - -## Requirements -- Write a main function that returns type: {param_type} -- Use value_selector format: ["node_id", "output_name"] +## Available Variables +{available_vars_block} +{code_context_section}## Requirements +- Use only the listed value_selector paths. +- Do not invent variables or fields that are not listed. +- Write a main function that returns type: {param_type}. +- Respect target constraints (options/min/max/default/multiple) if provided. +- If existing code is provided, adapt it instead of rewriting from scratch. +- Return only JSON that matches the provided schema. """ - @classmethod - def _parse_code_node_output(cls, content: Mapping[str, Any] | None, language: str, parameter_type: str) -> dict: - """ - Parse structured output to CodeNodeData format. - - Args: - content: Structured output dict from invoke_llm_with_structured_output - language: Code language - parameter_type: Expected parameter type - - Returns dict with variables, code_language, code, outputs, message, error. - """ - if content is None: - return cls._error_response("Empty or invalid response from LLM") - - # Validate and normalize variables - variables = [ - {"variable": v.get("variable", ""), "value_selector": v.get("value_selector", [])} - for v in content.get("variables", []) - if isinstance(v, dict) - ] - - # Convert outputs from array format [{name, type}] to dict format {name: {type}} - # Array format is required for OpenAI/Azure strict JSON schema compatibility - raw_outputs = content.get("outputs", []) - if isinstance(raw_outputs, list): - outputs = { - item.get("name", "result"): {"type": item.get("type", parameter_type)} - for item in raw_outputs - if isinstance(item, dict) and item.get("name") - } - if not outputs: - outputs = {"result": {"type": parameter_type}} - else: - outputs = raw_outputs or {"result": {"type": parameter_type}} - - return { - "variables": variables, - "code_language": language, - "code": content.get("code", ""), - "outputs": outputs, - "message": content.get("explanation", ""), - "error": "", - } - @staticmethod def instruction_modify_legacy( tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None @@ -891,7 +728,7 @@ Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("de raise ValueError("Workflow not found for the given app model.") last_run = workflow_service.get_node_last_run(app_model=app, workflow=workflow, node_id=node_id) try: - node_type = cast(WorkflowNodeExecutionModel, last_run).node_type + node_type = last_run.node_type except Exception: try: node_type = [it for it in workflow.graph_dict["graph"]["nodes"] if it["id"] == node_id][0]["data"][ @@ -1010,7 +847,7 @@ Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("de model_parameters=model_parameters, stream=False, ) - return response.structured_output or {} + return response.model_dump(mode="python") except InvokeError as e: error = str(e) return {"error": f"Failed to generate code. Error: {error}"} diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 7e931fed32..924a2df783 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -236,7 +236,6 @@ def invoke_llm_with_structured_output( return generator() -@overload def invoke_llm_with_pydantic_model( *, provider: str, @@ -251,24 +250,7 @@ def invoke_llm_with_pydantic_model( user: str | None = None, callbacks: list[Callback] | None = None, tenant_id: str | None = None, -) -> LLMResultWithStructuredOutput: ... - - -def invoke_llm_with_pydantic_model( - *, - provider: str, - model_schema: AIModelEntity, - model_instance: ModelInstance, - prompt_messages: Sequence[PromptMessage], - output_model: type[T], - model_parameters: Mapping | None = None, - tools: Sequence[PromptMessageTool] | None = None, - stop: list[str] | None = None, - stream: bool = False, - user: str | None = None, - callbacks: list[Callback] | None = None, - tenant_id: str | None = None, -) -> LLMResultWithStructuredOutput: +) -> T: """ Invoke large language model with a Pydantic output model. @@ -299,7 +281,7 @@ def invoke_llm_with_pydantic_model( raise OutputParserError("Structured output is empty") validated_output = _validate_structured_output(output_model, structured_output) - return result.model_copy(update={"structured_output": validated_output}) + return output_model.model_validate(validated_output) def _schema_from_pydantic(output_model: type[BaseModel]) -> dict[str, Any]: @@ -309,12 +291,12 @@ def _schema_from_pydantic(output_model: type[BaseModel]) -> dict[str, Any]: def _validate_structured_output( output_model: type[T], structured_output: Mapping[str, Any], -) -> dict[str, Any]: +) -> T: try: validated_output = output_model.model_validate(structured_output) except ValidationError as exc: raise OutputParserError(f"Structured output validation failed: {exc}") from exc - return validated_output.model_dump(mode="python") + return validated_output def _handle_native_json_schema( diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py index 9742590cd4..3f24c5c8ae 100644 --- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py +++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py @@ -493,8 +493,8 @@ def test_structured_output_with_pydantic_model(): stream=False, ) - assert isinstance(result, LLMResultWithStructuredOutput) - assert result.structured_output == {"name": "test"} + assert isinstance(result, ExampleOutput) + assert result.name == "test" def test_structured_output_with_pydantic_model_streaming_rejected():