fix: show citations in advanced chat apps (#32985)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
kurokobo 2026-03-06 10:56:14 +09:00 committed by GitHub
parent f751864ab3
commit ad81513b6a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 119 additions and 3 deletions

View File

@ -516,8 +516,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
graph_runtime_state=validated_state,
)
yield from self._handle_advanced_chat_message_end_event(
QueueAdvancedChatMessageEndEvent(), graph_runtime_state=validated_state
)
yield workflow_finish_resp
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
def _handle_workflow_partial_success_event(
self,
@ -538,6 +540,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
exceptions_count=event.exceptions_count,
)
yield from self._handle_advanced_chat_message_end_event(
QueueAdvancedChatMessageEndEvent(), graph_runtime_state=validated_state
)
yield workflow_finish_resp
def _handle_workflow_paused_event(
@ -854,6 +859,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
yield from self._handle_workflow_paused_event(event)
break
case QueueWorkflowSucceededEvent():
yield from self._handle_workflow_succeeded_event(event, trace_manager=trace_manager)
break
case QueueWorkflowPartialSuccessEvent():
yield from self._handle_workflow_partial_success_event(event, trace_manager=trace_manager)
break
case QueueStopEvent():
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
break

View File

@ -116,7 +116,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
try:
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
outputs = {"result": ArrayObjectSegment(value=[item.model_dump() for item in results])}
outputs = {"result": ArrayObjectSegment(value=[item.model_dump(by_alias=True) for item in results])}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,

View File

@ -9,8 +9,16 @@ import pytest
from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueTextChunkEvent, QueueWorkflowPausedEvent
from core.app.entities.queue_entities import (
QueuePingEvent,
QueueTextChunkEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import StreamEvent
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import WorkflowExecutionStatus
from models.enums import MessageStatus
from models.execution_extra_content import HumanInputContent
from models.model import EndUser
@ -185,3 +193,97 @@ def test_resume_appends_chunks_to_paused_answer() -> None:
assert message.answer == "beforeafter"
assert message.status == MessageStatus.NORMAL
def test_workflow_succeeded_emits_message_end_before_workflow_finished() -> None:
pipeline = _build_pipeline()
pipeline._application_generate_entity = SimpleNamespace(task_id="task-1")
pipeline._workflow_id = "workflow-1"
pipeline._ensure_workflow_initialized = mock.Mock()
runtime_state = SimpleNamespace()
pipeline._ensure_graph_runtime_initialized = mock.Mock(return_value=runtime_state)
pipeline._handle_advanced_chat_message_end_event = mock.Mock(
return_value=iter([SimpleNamespace(event=StreamEvent.MESSAGE_END)])
)
pipeline._workflow_response_converter = mock.Mock()
pipeline._workflow_response_converter.workflow_finish_to_stream_response.return_value = SimpleNamespace(
event=StreamEvent.WORKFLOW_FINISHED,
data=SimpleNamespace(status=WorkflowExecutionStatus.SUCCEEDED),
)
event = QueueWorkflowSucceededEvent(outputs={})
responses = list(pipeline._handle_workflow_succeeded_event(event))
assert [resp.event for resp in responses] == [StreamEvent.MESSAGE_END, StreamEvent.WORKFLOW_FINISHED]
def test_workflow_partial_success_emits_message_end_before_workflow_finished() -> None:
pipeline = _build_pipeline()
pipeline._application_generate_entity = SimpleNamespace(task_id="task-1")
pipeline._workflow_id = "workflow-1"
pipeline._ensure_workflow_initialized = mock.Mock()
runtime_state = SimpleNamespace()
pipeline._ensure_graph_runtime_initialized = mock.Mock(return_value=runtime_state)
pipeline._handle_advanced_chat_message_end_event = mock.Mock(
return_value=iter([SimpleNamespace(event=StreamEvent.MESSAGE_END)])
)
pipeline._workflow_response_converter = mock.Mock()
pipeline._workflow_response_converter.workflow_finish_to_stream_response.return_value = SimpleNamespace(
event=StreamEvent.WORKFLOW_FINISHED,
data=SimpleNamespace(status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED),
)
event = QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={})
responses = list(pipeline._handle_workflow_partial_success_event(event))
assert [resp.event for resp in responses] == [StreamEvent.MESSAGE_END, StreamEvent.WORKFLOW_FINISHED]
def test_process_stream_response_breaks_after_workflow_succeeded() -> None:
pipeline = _build_pipeline()
succeeded_event = QueueWorkflowSucceededEvent(outputs={})
ping_event = QueuePingEvent()
queue_messages = [
SimpleNamespace(event=succeeded_event),
SimpleNamespace(event=ping_event),
]
pipeline._conversation_name_generate_thread = None
pipeline._base_task_pipeline = mock.Mock()
pipeline._base_task_pipeline.queue_manager = mock.Mock()
pipeline._base_task_pipeline.queue_manager.listen.return_value = iter(queue_messages)
pipeline._base_task_pipeline.ping_stream_response = mock.Mock(return_value=SimpleNamespace(event=StreamEvent.PING))
pipeline._handle_workflow_succeeded_event = mock.Mock(
return_value=iter([SimpleNamespace(event=StreamEvent.WORKFLOW_FINISHED)])
)
responses = list(pipeline._process_stream_response())
assert [resp.event for resp in responses] == [StreamEvent.WORKFLOW_FINISHED]
pipeline._handle_workflow_succeeded_event.assert_called_once_with(succeeded_event, trace_manager=None)
pipeline._base_task_pipeline.ping_stream_response.assert_not_called()
def test_process_stream_response_breaks_after_workflow_partial_success() -> None:
pipeline = _build_pipeline()
partial_event = QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={})
ping_event = QueuePingEvent()
queue_messages = [
SimpleNamespace(event=partial_event),
SimpleNamespace(event=ping_event),
]
pipeline._conversation_name_generate_thread = None
pipeline._base_task_pipeline = mock.Mock()
pipeline._base_task_pipeline.queue_manager = mock.Mock()
pipeline._base_task_pipeline.queue_manager.listen.return_value = iter(queue_messages)
pipeline._base_task_pipeline.ping_stream_response = mock.Mock(return_value=SimpleNamespace(event=StreamEvent.PING))
pipeline._handle_workflow_partial_success_event = mock.Mock(
return_value=iter([SimpleNamespace(event=StreamEvent.WORKFLOW_FINISHED)])
)
responses = list(pipeline._process_stream_response())
assert [resp.event for resp in responses] == [StreamEvent.WORKFLOW_FINISHED]
pipeline._handle_workflow_partial_success_event.assert_called_once_with(partial_event, trace_manager=None)
pipeline._base_task_pipeline.ping_stream_response.assert_not_called()

View File

@ -205,6 +205,7 @@ class TestKnowledgeRetrievalNode:
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert "result" in result.outputs
assert mock_rag_retrieval.knowledge_retrieval.called
mock_source.model_dump.assert_called_once_with(by_alias=True)
def test_run_with_query_variable_multiple_mode(
self,