feat: add tool call based structured output

This commit is contained in:
Novice 2026-01-26 14:17:16 +08:00
parent 39799b9db7
commit 87bcd70f59
No known key found for this signature in database
GPG Key ID: EE3F68E3105DAAAB
6 changed files with 521 additions and 28 deletions

View File

@ -26,7 +26,7 @@ from core.model_runtime.entities.message_entities import (
SystemPromptMessage,
TextPromptMessageContent,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule
from core.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ParameterRule
class ResponseFormat(StrEnum):
@ -44,6 +44,12 @@ class SpecialModelType(StrEnum):
OLLAMA = "ollama"
# Tool name for structured output via tool call
STRUCTURED_OUTPUT_TOOL_NAME = "structured_output"
# Features that indicate tool call support
TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL}
T = TypeVar("T", bound=BaseModel)
@ -132,20 +138,24 @@ def invoke_llm_with_structured_output(
file IDs in the output will be automatically converted to File objects.
:return: full response or stream response chunk generator result
"""
# handle native json schema
model_parameters_with_json_schema: dict[str, Any] = {
**(model_parameters or {}),
}
# Determine structured output strategy
if model_schema.support_structure_output:
model_parameters = _handle_native_json_schema(
# Priority 1: Native JSON schema support
model_parameters_with_json_schema = _handle_native_json_schema(
provider, model_schema, json_schema, model_parameters_with_json_schema, model_schema.parameter_rules
)
elif _supports_tool_call(model_schema):
# Priority 2: Tool call based structured output
structured_output_tool = _create_structured_output_tool(json_schema)
tools = [structured_output_tool]
else:
# Set appropriate response format based on model capabilities
# Priority 3: Prompt-based fallback
_set_response_format(model_parameters_with_json_schema, model_schema.parameter_rules)
# handle prompt based schema
prompt_messages = _handle_prompt_based_schema(
prompt_messages=prompt_messages,
structured_output_schema=json_schema,
@ -162,12 +172,11 @@ def invoke_llm_with_structured_output(
)
if isinstance(llm_result, LLMResult):
if not isinstance(llm_result.message.content, str):
raise OutputParserError(
f"Failed to parse structured output, LLM result is not a string: {llm_result.message.content}"
)
# Non-streaming result
structured_output = _extract_structured_output(llm_result)
structured_output = _parse_structured_output(llm_result.message.content)
# Fill missing fields with default values
structured_output = fill_defaults_from_schema(structured_output, json_schema)
# Convert file references if tenant_id is provided
if tenant_id is not None:
@ -189,13 +198,16 @@ def invoke_llm_with_structured_output(
def generator() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
result_text: str = ""
tool_call_args: dict[str, str] = {} # tool_call_id -> arguments
prompt_messages: Sequence[PromptMessage] = []
system_fingerprint: str | None = None
for event in llm_result:
if isinstance(event, LLMResultChunk):
prompt_messages = event.prompt_messages
system_fingerprint = event.system_fingerprint
# Collect text content
if isinstance(event.delta.message.content, str):
result_text += event.delta.message.content
elif isinstance(event.delta.message.content, list):
@ -203,6 +215,13 @@ def invoke_llm_with_structured_output(
if isinstance(item, TextPromptMessageContent):
result_text += item.data
# Collect tool call arguments
if event.delta.message.tool_calls:
for tool_call in event.delta.message.tool_calls:
call_id = tool_call.id or ""
if tool_call.function.arguments:
tool_call_args[call_id] = tool_call_args.get(call_id, "") + tool_call.function.arguments
yield LLMResultChunkWithStructuredOutput(
model=model_schema.model,
prompt_messages=prompt_messages,
@ -210,7 +229,11 @@ def invoke_llm_with_structured_output(
delta=event.delta,
)
structured_output = _parse_structured_output(result_text)
# Extract structured output: prefer tool call, fallback to text
structured_output = _extract_structured_output_from_stream(result_text, tool_call_args)
# Fill missing fields with default values
structured_output = fill_defaults_from_schema(structured_output, json_schema)
# Convert file references if tenant_id is provided
if tenant_id is not None:
@ -299,6 +322,144 @@ def _validate_structured_output(
return validated_output
def _supports_tool_call(model_schema: AIModelEntity) -> bool:
"""Check if model supports tool call feature."""
return bool(set(model_schema.features or []) & TOOL_CALL_FEATURES)
def _create_structured_output_tool(json_schema: Mapping[str, Any]) -> PromptMessageTool:
"""Create a tool definition for structured output."""
return PromptMessageTool(
name=STRUCTURED_OUTPUT_TOOL_NAME,
description="Generate structured output according to the provided schema. "
"You MUST call this function to provide your response in the required format.",
parameters=dict(json_schema),
)
def _extract_structured_output(llm_result: LLMResult) -> Mapping[str, Any]:
"""
Extract structured output from LLM result (non-streaming).
First tries to extract from tool_calls (if present), then falls back to text content.
"""
# Try to extract from tool call first
tool_calls = llm_result.message.tool_calls
if tool_calls:
for tool_call in tool_calls:
if tool_call.function.name == STRUCTURED_OUTPUT_TOOL_NAME:
return _parse_tool_call_arguments(tool_call.function.arguments)
# Fallback to text content parsing
content = llm_result.message.content
if not isinstance(content, str):
raise OutputParserError(f"Failed to parse structured output, LLM result is not a string: {content}")
return _parse_structured_output(content)
def _extract_structured_output_from_stream(
result_text: str,
tool_call_args: dict[str, str],
) -> Mapping[str, Any]:
"""
Extract structured output from streaming collected data.
First tries to parse from collected tool call arguments, then falls back to text content.
"""
# Try to parse from tool call arguments first
if tool_call_args:
# Use the first non-empty tool call arguments
for arguments in tool_call_args.values():
if arguments.strip():
return _parse_tool_call_arguments(arguments)
# Fallback to text content parsing
if not result_text:
raise OutputParserError("No tool call arguments and no text content to parse")
return _parse_structured_output(result_text)
def _parse_tool_call_arguments(arguments: str) -> Mapping[str, Any]:
"""Parse JSON from tool call arguments."""
if not arguments:
raise OutputParserError("Tool call arguments is empty")
try:
parsed = json.loads(arguments)
if not isinstance(parsed, dict):
raise OutputParserError(f"Tool call arguments is not a dict: {arguments}")
return parsed
except json.JSONDecodeError:
# Try to repair malformed JSON
repaired = json_repair.loads(arguments)
if not isinstance(repaired, dict):
raise OutputParserError(f"Failed to parse tool call arguments: {arguments}")
return cast(dict, repaired)
def _get_default_value_for_type(type_name: str | list[str] | None) -> Any:
"""Get default empty value for a JSON schema type."""
# Handle array of types (e.g., ["string", "null"])
if isinstance(type_name, list):
# Use the first non-null type
type_name = next((t for t in type_name if t != "null"), None)
if type_name == "string":
return ""
elif type_name == "object":
return {}
elif type_name == "array":
return []
elif type_name in {"number", "integer"}:
return 0
elif type_name == "boolean":
return False
elif type_name == "null" or type_name is None:
return None
else:
return None
def fill_defaults_from_schema(
output: Mapping[str, Any],
json_schema: Mapping[str, Any],
) -> dict[str, Any]:
"""
Fill missing required fields in output with default empty values based on JSON schema.
Only fills default values for fields that are marked as required in the schema.
Recursively processes nested objects to fill their required fields as well.
Default values by type:
- string ""
- object {} (with nested required fields filled)
- array []
- number/integer 0
- boolean False
- null None
"""
result = dict(output)
properties = json_schema.get("properties", {})
required_fields = set(json_schema.get("required", []))
for prop_name, prop_schema in properties.items():
prop_type = prop_schema.get("type")
is_required = prop_name in required_fields
if prop_name not in result:
# Field is missing from output
if is_required:
# Only fill default value for required fields
if prop_type == "object" and "properties" in prop_schema:
# Create empty object and recursively fill its required fields
result[prop_name] = fill_defaults_from_schema({}, prop_schema)
else:
result[prop_name] = _get_default_value_for_type(prop_type)
elif isinstance(result[prop_name], dict) and prop_type == "object" and "properties" in prop_schema:
# Field exists and is an object, recursively fill nested required fields
result[prop_name] = fill_defaults_from_schema(result[prop_name], prop_schema)
return result
def _handle_native_json_schema(
provider: str,
model_schema: AIModelEntity,

View File

@ -6,6 +6,8 @@ from pydantic import BaseModel, ConfigDict
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import (
_get_default_value_for_type,
fill_defaults_from_schema,
invoke_llm_with_pydantic_model,
invoke_llm_with_structured_output,
)
@ -530,3 +532,304 @@ def test_structured_output_with_pydantic_model_validation_error():
output_model=ExampleOutput,
stream=False,
)
class TestGetDefaultValueForType:
"""Test cases for _get_default_value_for_type function"""
def test_string_type(self):
assert _get_default_value_for_type("string") == ""
def test_object_type(self):
assert _get_default_value_for_type("object") == {}
def test_array_type(self):
assert _get_default_value_for_type("array") == []
def test_number_type(self):
assert _get_default_value_for_type("number") == 0
def test_integer_type(self):
assert _get_default_value_for_type("integer") == 0
def test_boolean_type(self):
assert _get_default_value_for_type("boolean") is False
def test_null_type(self):
assert _get_default_value_for_type("null") is None
def test_none_type(self):
assert _get_default_value_for_type(None) is None
def test_unknown_type(self):
assert _get_default_value_for_type("unknown") is None
def test_union_type_string_null(self):
# ["string", "null"] should return "" (first non-null type)
assert _get_default_value_for_type(["string", "null"]) == ""
def test_union_type_null_first(self):
# ["null", "integer"] should return 0 (first non-null type)
assert _get_default_value_for_type(["null", "integer"]) == 0
def test_union_type_only_null(self):
# ["null"] should return None
assert _get_default_value_for_type(["null"]) is None
class TestFillDefaultsFromSchema:
"""Test cases for fill_defaults_from_schema function"""
def test_simple_required_fields(self):
"""Test filling simple required fields"""
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
"email": {"type": "string"},
},
"required": ["name", "age"],
}
output = {"name": "Alice"}
result = fill_defaults_from_schema(output, schema)
assert result == {"name": "Alice", "age": 0}
# email is not required, so it should not be added
assert "email" not in result
def test_non_required_fields_not_filled(self):
"""Test that non-required fields are not filled"""
schema = {
"type": "object",
"properties": {
"required_field": {"type": "string"},
"optional_field": {"type": "string"},
},
"required": ["required_field"],
}
output = {}
result = fill_defaults_from_schema(output, schema)
assert result == {"required_field": ""}
assert "optional_field" not in result
def test_nested_object_required_fields(self):
"""Test filling nested object required fields"""
schema = {
"type": "object",
"properties": {
"user": {
"type": "object",
"properties": {
"name": {"type": "string"},
"email": {"type": "string"},
"address": {
"type": "object",
"properties": {
"city": {"type": "string"},
"street": {"type": "string"},
"zipcode": {"type": "string"},
},
"required": ["city", "street"],
},
},
"required": ["name", "email", "address"],
},
},
"required": ["user"],
}
output = {
"user": {
"name": "Alice",
"address": {
"city": "Beijing",
},
}
}
result = fill_defaults_from_schema(output, schema)
assert result == {
"user": {
"name": "Alice",
"email": "", # filled because required
"address": {
"city": "Beijing",
"street": "", # filled because required
# zipcode not filled because not required
},
}
}
def test_missing_nested_object_created(self):
"""Test that missing required nested objects are created"""
schema = {
"type": "object",
"properties": {
"metadata": {
"type": "object",
"properties": {
"created_at": {"type": "string"},
"updated_at": {"type": "string"},
},
"required": ["created_at"],
},
},
"required": ["metadata"],
}
output = {}
result = fill_defaults_from_schema(output, schema)
assert result == {
"metadata": {
"created_at": "",
}
}
def test_all_types_default_values(self):
"""Test default values for all types"""
schema = {
"type": "object",
"properties": {
"str_field": {"type": "string"},
"int_field": {"type": "integer"},
"num_field": {"type": "number"},
"bool_field": {"type": "boolean"},
"arr_field": {"type": "array"},
"obj_field": {"type": "object"},
},
"required": ["str_field", "int_field", "num_field", "bool_field", "arr_field", "obj_field"],
}
output = {}
result = fill_defaults_from_schema(output, schema)
assert result == {
"str_field": "",
"int_field": 0,
"num_field": 0,
"bool_field": False,
"arr_field": [],
"obj_field": {},
}
def test_existing_values_preserved(self):
"""Test that existing values are not overwritten"""
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"count": {"type": "integer"},
},
"required": ["name", "count"],
}
output = {"name": "Bob", "count": 42}
result = fill_defaults_from_schema(output, schema)
assert result == {"name": "Bob", "count": 42}
def test_complex_nested_structure(self):
"""Test complex nested structure with multiple levels"""
schema = {
"type": "object",
"properties": {
"user": {
"type": "object",
"properties": {
"name": {"type": "string"},
"email": {"type": "string"},
"age": {"type": "integer"},
"address": {
"type": "object",
"properties": {
"city": {"type": "string"},
"street": {"type": "string"},
"zipcode": {"type": "string"},
},
"required": ["city", "street"],
},
},
"required": ["name", "email", "address"],
},
"tags": {"type": "array"},
"orders": {"type": "array"},
"metadata": {
"type": "object",
"properties": {
"created_at": {"type": "string"},
"updated_at": {"type": "string"},
},
"required": ["created_at"],
},
"is_active": {"type": "boolean"},
"notes": {"type": "string"},
},
"required": ["user", "tags", "metadata", "is_active"],
}
output = {
"user": {
"name": "Alice",
"age": 25,
"address": {
"city": "Beijing",
},
},
"orders": [{"id": 1}],
"metadata": {
"updated_at": "2024-01-01",
},
}
result = fill_defaults_from_schema(output, schema)
expected = {
"user": {
"name": "Alice",
"email": "", # required, filled
"age": 25, # not required but exists
"address": {
"city": "Beijing",
"street": "", # required, filled
# zipcode not required
},
},
"tags": [], # required, filled
"orders": [{"id": 1}], # not required but exists
"metadata": {
"created_at": "", # required, filled
"updated_at": "2024-01-01", # exists
},
"is_active": False, # required, filled
# notes not required
}
assert result == expected
def test_empty_schema(self):
"""Test with empty schema"""
schema = {}
output = {"any": "value"}
result = fill_defaults_from_schema(output, schema)
assert result == {"any": "value"}
def test_schema_without_required(self):
"""Test schema without required field"""
schema = {
"type": "object",
"properties": {
"optional1": {"type": "string"},
"optional2": {"type": "integer"},
},
}
output = {}
result = fill_defaults_from_schema(output, schema)
# No required fields, so nothing should be added
assert result == {}

View File

@ -1,7 +1,7 @@
import type { FC } from 'react'
import type { LLMNodeType } from './types'
import type { NodePanelProps } from '@/app/components/workflow/types'
import { RiAlertFill, RiQuestionLine } from '@remixicon/react'
import { RiAlertFill, RiInformationLine, RiQuestionLine } from '@remixicon/react'
import * as React from 'react'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
@ -63,6 +63,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
handleVisionResolutionEnabledChange,
handleVisionResolutionChange,
isModelSupportStructuredOutput,
isModelSupportToolCall,
structuredOutputCollapsed,
setStructuredOutputCollapsed,
handleStructureOutputEnableChange,
@ -299,19 +300,37 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
operations={(
<div className="mr-4 flex shrink-0 items-center">
{(!isModelSupportStructuredOutput && !!inputs.structured_output_enabled) && (
<Tooltip
noDecoration
popupContent={(
<div className="w-[232px] rounded-xl border-[0.5px] border-components-panel-border bg-components-tooltip-bg px-4 py-3.5 shadow-lg backdrop-blur-[5px]">
<div className="title-xs-semi-bold text-text-primary">{t('structOutput.modelNotSupported', { ns: 'app' })}</div>
<div className="body-xs-regular mt-1 text-text-secondary">{t('structOutput.modelNotSupportedTip', { ns: 'app' })}</div>
</div>
)}
>
<div>
<RiAlertFill className="mr-1 size-4 text-text-warning-secondary" />
</div>
</Tooltip>
isModelSupportToolCall
? (
<Tooltip
noDecoration
popupContent={(
<div className="w-[232px] rounded-xl border-[0.5px] border-components-panel-border bg-components-tooltip-bg px-4 py-3.5 shadow-lg backdrop-blur-[5px]">
<div className="title-xs-semi-bold text-text-primary">{t('structOutput.toolCallFallback', { ns: 'app' })}</div>
<div className="body-xs-regular mt-1 text-text-secondary">{t('structOutput.toolCallFallbackTip', { ns: 'app' })}</div>
</div>
)}
>
<div>
<RiInformationLine className="mr-1 size-4 text-text-tertiary" />
</div>
</Tooltip>
)
: (
<Tooltip
noDecoration
popupContent={(
<div className="w-[232px] rounded-xl border-[0.5px] border-components-panel-border bg-components-tooltip-bg px-4 py-3.5 shadow-lg backdrop-blur-[5px]">
<div className="title-xs-semi-bold text-text-primary">{t('structOutput.modelNotSupported', { ns: 'app' })}</div>
<div className="body-xs-regular mt-1 text-text-secondary">{t('structOutput.modelNotSupportedTip', { ns: 'app' })}</div>
</div>
)}
>
<div>
<RiAlertFill className="mr-1 size-4 text-text-warning-secondary" />
</div>
</Tooltip>
)
)}
<div className="system-xs-medium-uppercase mr-0.5 text-text-tertiary">{t('structOutput.structured', { ns: 'app' })}</div>
<Tooltip popupContent={

View File

@ -313,12 +313,17 @@ const useConfig = (id: string, payload: LLMNodeType) => {
// structure output
const { data: modelList } = useModelList(ModelTypeEnum.textGeneration)
const isModelSupportStructuredOutput = modelList
const currentModelFeatures = modelList
?.find(provideItem => provideItem.provider === model?.provider)
?.models
.find(modelItem => modelItem.model === model?.name)
?.features
?.includes(ModelFeatureEnum.StructuredOutput)
const isModelSupportStructuredOutput = currentModelFeatures?.includes(ModelFeatureEnum.StructuredOutput)
const isModelSupportToolCall = currentModelFeatures?.some(
feature => [ModelFeatureEnum.toolCall, ModelFeatureEnum.multiToolCall, ModelFeatureEnum.streamToolCall].includes(feature),
)
const [structuredOutputCollapsed, setStructuredOutputCollapsed] = useState(true)
const handleStructureOutputEnableChange = useCallback((enabled: boolean) => {
@ -394,6 +399,7 @@ const useConfig = (id: string, payload: LLMNodeType) => {
handleVisionResolutionEnabledChange,
handleVisionResolutionChange,
isModelSupportStructuredOutput,
isModelSupportToolCall,
handleStructureOutputChange,
structuredOutputCollapsed,
setStructuredOutputCollapsed,

View File

@ -217,6 +217,8 @@
"structOutput.required": "Required",
"structOutput.structured": "Structured",
"structOutput.structuredTip": "Structured Outputs is a feature that ensures the model will always generate responses that adhere to your supplied JSON Schema",
"structOutput.toolCallFallback": "Using Tool Call mode",
"structOutput.toolCallFallbackTip": "The current model does not support native structured output, but supports tool calling. Structured output will be achieved via tool call.",
"switch": "Switch to Workflow Orchestrate",
"switchLabel": "The app copy to be created",
"switchStart": "Start switch",

View File

@ -215,6 +215,8 @@
"structOutput.required": "必填",
"structOutput.structured": "结构化输出",
"structOutput.structuredTip": "结构化输出是一项功能,可确保模型始终生成符合您提供的 JSON 模式的响应",
"structOutput.toolCallFallback": "使用工具调用模式",
"structOutput.toolCallFallbackTip": "当前模型不支持原生结构化输出,但支持工具调用。将通过工具调用实现结构化输出。",
"switch": "迁移为工作流编排",
"switchLabel": "新应用创建为",
"switchStart": "开始迁移",