mirror of https://github.com/langgenius/dify.git
fix: show citations in advanced chat apps (#32985)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
f751864ab3
commit
ad81513b6a
|
|
@ -516,8 +516,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
graph_runtime_state=validated_state,
|
||||
)
|
||||
|
||||
yield from self._handle_advanced_chat_message_end_event(
|
||||
QueueAdvancedChatMessageEndEvent(), graph_runtime_state=validated_state
|
||||
)
|
||||
yield workflow_finish_resp
|
||||
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
|
||||
def _handle_workflow_partial_success_event(
|
||||
self,
|
||||
|
|
@ -538,6 +540,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
exceptions_count=event.exceptions_count,
|
||||
)
|
||||
|
||||
yield from self._handle_advanced_chat_message_end_event(
|
||||
QueueAdvancedChatMessageEndEvent(), graph_runtime_state=validated_state
|
||||
)
|
||||
yield workflow_finish_resp
|
||||
|
||||
def _handle_workflow_paused_event(
|
||||
|
|
@ -854,6 +859,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
yield from self._handle_workflow_paused_event(event)
|
||||
break
|
||||
|
||||
case QueueWorkflowSucceededEvent():
|
||||
yield from self._handle_workflow_succeeded_event(event, trace_manager=trace_manager)
|
||||
break
|
||||
|
||||
case QueueWorkflowPartialSuccessEvent():
|
||||
yield from self._handle_workflow_partial_success_event(event, trace_manager=trace_manager)
|
||||
break
|
||||
|
||||
case QueueStopEvent():
|
||||
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
|
||||
break
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||
|
||||
try:
|
||||
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
|
||||
outputs = {"result": ArrayObjectSegment(value=[item.model_dump() for item in results])}
|
||||
outputs = {"result": ArrayObjectSegment(value=[item.model_dump(by_alias=True) for item in results])}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
|
|
|
|||
|
|
@ -9,8 +9,16 @@ import pytest
|
|||
|
||||
from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueTextChunkEvent, QueueWorkflowPausedEvent
|
||||
from core.app.entities.queue_entities import (
|
||||
QueuePingEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from models.enums import MessageStatus
|
||||
from models.execution_extra_content import HumanInputContent
|
||||
from models.model import EndUser
|
||||
|
|
@ -185,3 +193,97 @@ def test_resume_appends_chunks_to_paused_answer() -> None:
|
|||
|
||||
assert message.answer == "beforeafter"
|
||||
assert message.status == MessageStatus.NORMAL
|
||||
|
||||
|
||||
def test_workflow_succeeded_emits_message_end_before_workflow_finished() -> None:
|
||||
pipeline = _build_pipeline()
|
||||
pipeline._application_generate_entity = SimpleNamespace(task_id="task-1")
|
||||
pipeline._workflow_id = "workflow-1"
|
||||
pipeline._ensure_workflow_initialized = mock.Mock()
|
||||
runtime_state = SimpleNamespace()
|
||||
pipeline._ensure_graph_runtime_initialized = mock.Mock(return_value=runtime_state)
|
||||
pipeline._handle_advanced_chat_message_end_event = mock.Mock(
|
||||
return_value=iter([SimpleNamespace(event=StreamEvent.MESSAGE_END)])
|
||||
)
|
||||
pipeline._workflow_response_converter = mock.Mock()
|
||||
pipeline._workflow_response_converter.workflow_finish_to_stream_response.return_value = SimpleNamespace(
|
||||
event=StreamEvent.WORKFLOW_FINISHED,
|
||||
data=SimpleNamespace(status=WorkflowExecutionStatus.SUCCEEDED),
|
||||
)
|
||||
|
||||
event = QueueWorkflowSucceededEvent(outputs={})
|
||||
responses = list(pipeline._handle_workflow_succeeded_event(event))
|
||||
|
||||
assert [resp.event for resp in responses] == [StreamEvent.MESSAGE_END, StreamEvent.WORKFLOW_FINISHED]
|
||||
|
||||
|
||||
def test_workflow_partial_success_emits_message_end_before_workflow_finished() -> None:
|
||||
pipeline = _build_pipeline()
|
||||
pipeline._application_generate_entity = SimpleNamespace(task_id="task-1")
|
||||
pipeline._workflow_id = "workflow-1"
|
||||
pipeline._ensure_workflow_initialized = mock.Mock()
|
||||
runtime_state = SimpleNamespace()
|
||||
pipeline._ensure_graph_runtime_initialized = mock.Mock(return_value=runtime_state)
|
||||
pipeline._handle_advanced_chat_message_end_event = mock.Mock(
|
||||
return_value=iter([SimpleNamespace(event=StreamEvent.MESSAGE_END)])
|
||||
)
|
||||
pipeline._workflow_response_converter = mock.Mock()
|
||||
pipeline._workflow_response_converter.workflow_finish_to_stream_response.return_value = SimpleNamespace(
|
||||
event=StreamEvent.WORKFLOW_FINISHED,
|
||||
data=SimpleNamespace(status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED),
|
||||
)
|
||||
|
||||
event = QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={})
|
||||
responses = list(pipeline._handle_workflow_partial_success_event(event))
|
||||
|
||||
assert [resp.event for resp in responses] == [StreamEvent.MESSAGE_END, StreamEvent.WORKFLOW_FINISHED]
|
||||
|
||||
|
||||
def test_process_stream_response_breaks_after_workflow_succeeded() -> None:
|
||||
pipeline = _build_pipeline()
|
||||
succeeded_event = QueueWorkflowSucceededEvent(outputs={})
|
||||
ping_event = QueuePingEvent()
|
||||
queue_messages = [
|
||||
SimpleNamespace(event=succeeded_event),
|
||||
SimpleNamespace(event=ping_event),
|
||||
]
|
||||
|
||||
pipeline._conversation_name_generate_thread = None
|
||||
pipeline._base_task_pipeline = mock.Mock()
|
||||
pipeline._base_task_pipeline.queue_manager = mock.Mock()
|
||||
pipeline._base_task_pipeline.queue_manager.listen.return_value = iter(queue_messages)
|
||||
pipeline._base_task_pipeline.ping_stream_response = mock.Mock(return_value=SimpleNamespace(event=StreamEvent.PING))
|
||||
pipeline._handle_workflow_succeeded_event = mock.Mock(
|
||||
return_value=iter([SimpleNamespace(event=StreamEvent.WORKFLOW_FINISHED)])
|
||||
)
|
||||
|
||||
responses = list(pipeline._process_stream_response())
|
||||
|
||||
assert [resp.event for resp in responses] == [StreamEvent.WORKFLOW_FINISHED]
|
||||
pipeline._handle_workflow_succeeded_event.assert_called_once_with(succeeded_event, trace_manager=None)
|
||||
pipeline._base_task_pipeline.ping_stream_response.assert_not_called()
|
||||
|
||||
|
||||
def test_process_stream_response_breaks_after_workflow_partial_success() -> None:
|
||||
pipeline = _build_pipeline()
|
||||
partial_event = QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={})
|
||||
ping_event = QueuePingEvent()
|
||||
queue_messages = [
|
||||
SimpleNamespace(event=partial_event),
|
||||
SimpleNamespace(event=ping_event),
|
||||
]
|
||||
|
||||
pipeline._conversation_name_generate_thread = None
|
||||
pipeline._base_task_pipeline = mock.Mock()
|
||||
pipeline._base_task_pipeline.queue_manager = mock.Mock()
|
||||
pipeline._base_task_pipeline.queue_manager.listen.return_value = iter(queue_messages)
|
||||
pipeline._base_task_pipeline.ping_stream_response = mock.Mock(return_value=SimpleNamespace(event=StreamEvent.PING))
|
||||
pipeline._handle_workflow_partial_success_event = mock.Mock(
|
||||
return_value=iter([SimpleNamespace(event=StreamEvent.WORKFLOW_FINISHED)])
|
||||
)
|
||||
|
||||
responses = list(pipeline._process_stream_response())
|
||||
|
||||
assert [resp.event for resp in responses] == [StreamEvent.WORKFLOW_FINISHED]
|
||||
pipeline._handle_workflow_partial_success_event.assert_called_once_with(partial_event, trace_manager=None)
|
||||
pipeline._base_task_pipeline.ping_stream_response.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -205,6 +205,7 @@ class TestKnowledgeRetrievalNode:
|
|||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert "result" in result.outputs
|
||||
assert mock_rag_retrieval.knowledge_retrieval.called
|
||||
mock_source.model_dump.assert_called_once_with(by_alias=True)
|
||||
|
||||
def test_run_with_query_variable_multiple_mode(
|
||||
self,
|
||||
|
|
|
|||
Loading…
Reference in New Issue