test: unit test cases for core.app.apps module (#32482)

This commit is contained in:
Rajat Agarwal 2026-03-12 08:53:25 +05:30 committed by GitHub
parent 44713a5c0f
commit 0045e387f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
43 changed files with 8006 additions and 3 deletions

View File

@@ -114,7 +114,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):

View File

@@ -113,7 +113,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:

View File

@@ -113,7 +113,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:

View File

@@ -0,0 +1,75 @@
from types import SimpleNamespace
from unittest.mock import patch
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from models.model import AppMode
class TestAdvancedChatAppConfigManager:
    """Unit tests for AdvancedChatAppConfigManager using SimpleNamespace stand-ins for ORM models."""

    def test_get_app_config(self):
        """get_app_config copies the workflow id and app mode onto the returned config."""
        # Minimal stand-ins for the App and Workflow models; only the attributes
        # the manager reads are provided.
        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.ADVANCED_CHAT.value)
        workflow = SimpleNamespace(id="wf-1", features_dict={})
        with (
            patch(
                "core.app.apps.advanced_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert",
                return_value=None,
            ),
            patch(
                "core.app.apps.advanced_chat.app_config_manager.WorkflowVariablesConfigManager.convert",
                return_value=[],
            ),
        ):
            app_config = AdvancedChatAppConfigManager.get_app_config(app_model, workflow)
        assert app_config.workflow_id == "wf-1"
        assert app_config.app_mode == AppMode.ADVANCED_CHAT

    def test_config_validate_filters_keys(self):
        """config_validate keeps only the keys reported as related by each sub-validator."""

        def _add_key(key, value):
            # Build a side_effect that mimics a validate_and_set_defaults hook:
            # it injects `key` into the config and reports it as a related key.
            def _inner(*args, **kwargs):
                # Sub-validators are called with differing signatures; the config
                # dict is either the `config` kwarg or the last positional arg.
                config = kwargs.get("config") if kwargs else args[-1]
                config = {**config, key: value}
                return config, [key]

            return _inner

        with (
            patch(
                "core.app.apps.advanced_chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
                side_effect=_add_key("file_upload", 1),
            ),
            patch(
                "core.app.apps.advanced_chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults",
                side_effect=_add_key("opening_statement", 2),
            ),
            patch(
                "core.app.apps.advanced_chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults",
                side_effect=_add_key("suggested_questions_after_answer", 3),
            ),
            patch(
                "core.app.apps.advanced_chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults",
                side_effect=_add_key("speech_to_text", 4),
            ),
            patch(
                "core.app.apps.advanced_chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
                side_effect=_add_key("text_to_speech", 5),
            ),
            patch(
                "core.app.apps.advanced_chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults",
                side_effect=_add_key("retriever_resource", 6),
            ),
            patch(
                "core.app.apps.advanced_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
                side_effect=_add_key("sensitive_word_avoidance", 7),
            ),
        ):
            filtered = AdvancedChatAppConfigManager.config_validate(tenant_id="t1", config={})
        # Each injected key survived the filtering with the value its hook supplied.
        assert filtered["file_upload"] == 1
        assert filtered["opening_statement"] == 2
        assert filtered["suggested_questions_after_answer"] == 3
        assert filtered["speech_to_text"] == 4
        assert filtered["text_to_speech"] == 5
        assert filtered["retriever_resource"] == 6
        assert filtered["sensitive_word_avoidance"] == 7

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,96 @@
from collections.abc import Generator
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.entities.task_entities import (
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
MessageEndStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
)
from dify_graph.enums import WorkflowNodeExecutionStatus
class TestAdvancedChatGenerateResponseConverter:
    """Unit tests for AdvancedChatAppGenerateResponseConverter simple-response conversion."""

    def test_blocking_simple_response_metadata(self):
        """The simple blocking response strips non-simple metadata such as usage."""
        data = ChatbotAppBlockingResponse.Data(
            id="msg-1",
            mode="chat",
            conversation_id="c1",
            message_id="m1",
            answer="hi",
            metadata={"usage": {"total_tokens": 1}},
            created_at=1,
        )
        blocking = ChatbotAppBlockingResponse(task_id="t1", data=data)
        response = AdvancedChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
        # "usage" is not part of the simplified metadata exposed to clients.
        assert "usage" not in response["metadata"]

    def test_stream_simple_response_includes_node_events(self):
        """The simple stream keeps ping/node/error events in their original order."""
        node_start = NodeStartStreamResponse(
            task_id="t1",
            workflow_run_id="r1",
            data=NodeStartStreamResponse.Data(
                id="e1",
                node_id="n1",
                node_type="answer",
                title="Answer",
                index=1,
                created_at=1,
            ),
        )
        node_finish = NodeFinishStreamResponse(
            task_id="t1",
            workflow_run_id="r1",
            data=NodeFinishStreamResponse.Data(
                id="e1",
                node_id="n1",
                node_type="answer",
                title="Answer",
                index=1,
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                elapsed_time=0.1,
                created_at=1,
                finished_at=2,
            ),
        )

        def stream() -> Generator[ChatbotAppStreamResponse, None, None]:
            # Emit one of each event type the converter has a branch for.
            yield ChatbotAppStreamResponse(
                conversation_id="c1",
                message_id="m1",
                created_at=1,
                stream_response=PingStreamResponse(task_id="t1"),
            )
            yield ChatbotAppStreamResponse(
                conversation_id="c1",
                message_id="m1",
                created_at=1,
                stream_response=node_start,
            )
            yield ChatbotAppStreamResponse(
                conversation_id="c1",
                message_id="m1",
                created_at=1,
                stream_response=node_finish,
            )
            yield ChatbotAppStreamResponse(
                conversation_id="c1",
                message_id="m1",
                created_at=1,
                stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom")),
            )
            yield ChatbotAppStreamResponse(
                conversation_id="c1",
                message_id="m1",
                created_at=1,
                stream_response=MessageEndStreamResponse(task_id="t1", id="m1"),
            )

        converted = list(AdvancedChatAppGenerateResponseConverter.convert_stream_simple_response(stream()))
        # Pings are passed through as the literal string "ping"; other events
        # are dicts carrying their event name.
        assert converted[0] == "ping"
        assert converted[1]["event"] == "node_started"
        assert converted[2]["event"] == "node_finished"
        assert converted[3]["event"] == "error"

View File

@@ -0,0 +1,600 @@
from __future__ import annotations
from contextlib import contextmanager
from datetime import datetime
from types import SimpleNamespace
import pytest
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.queue_entities import (
QueueAdvancedChatMessageEndEvent,
QueueAnnotationReplyEvent,
QueueErrorEvent,
QueueHumanInputFormFilledEvent,
QueueHumanInputFormTimeoutEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueLoopCompletedEvent,
QueueLoopNextEvent,
QueueLoopStartEvent,
QueueMessageReplaceEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueuePingEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
AnnotationReply,
AnnotationReplyAccount,
MessageAudioStreamResponse,
MessageEndStreamResponse,
PingStreamResponse,
)
from core.base.tts.app_generator_tts_publisher import AudioTrunk
from dify_graph.enums import NodeType
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from models.enums import MessageStatus
from models.model import AppMode, EndUser
def _make_pipeline():
    """Build an AdvancedChatAppGenerateTaskPipeline wired entirely with lightweight stand-ins.

    SimpleNamespace objects replace the ORM models (message, conversation,
    workflow) and the queue manager, so individual pipeline handlers can be
    exercised without a database or a running queue.
    """
    app_config = WorkflowUIBasedAppConfig(
        tenant_id="tenant",
        app_id="app",
        app_mode=AppMode.ADVANCED_CHAT,
        additional_features=AppAdditionalFeatures(),
        variables=[],
        workflow_id="workflow-id",
    )
    # model_construct bypasses pydantic validation so the stub values above
    # are accepted as-is.
    application_generate_entity = AdvancedChatAppGenerateEntity.model_construct(
        task_id="task",
        app_config=app_config,
        inputs={},
        query="hello",
        files=[],
        user_id="user",
        stream=False,
        invoke_from=InvokeFrom.WEB_APP,
        extras={},
        trace_manager=None,
        workflow_run_id="run-id",
    )
    message = SimpleNamespace(
        id="message-id",
        query="hello",
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
        # consider datetime.now(UTC) if naive timestamps are not required.
        created_at=datetime.utcnow(),
        status=MessageStatus.NORMAL,
        answer="",
    )
    conversation = SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT)
    workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={})
    user = EndUser(tenant_id="tenant", type="session", name="tester", session_id="session")
    pipeline = AdvancedChatAppGenerateTaskPipeline(
        application_generate_entity=application_generate_entity,
        workflow=workflow,
        queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
        conversation=conversation,
        message=message,
        user=user,
        stream=False,
        dialogue_count=1,
        draft_var_saver_factory=lambda **kwargs: None,
    )
    return pipeline
class TestAdvancedChatGenerateTaskPipeline:
    """Unit tests for individual event handlers of AdvancedChatAppGenerateTaskPipeline.

    Each test builds a fresh pipeline via _make_pipeline() and replaces the
    private collaborators it needs (response converter, message cycle manager,
    database session) with lambdas/SimpleNamespace stubs.
    """

    def test_ensure_workflow_initialized_raises(self):
        """_ensure_workflow_initialized raises before any workflow run has started."""
        pipeline = _make_pipeline()
        with pytest.raises(ValueError, match="workflow run not initialized"):
            pipeline._ensure_workflow_initialized()

    def test_to_blocking_response_returns_message_end(self):
        """_to_blocking_response folds the MessageEnd event into a blocking response."""
        pipeline = _make_pipeline()
        pipeline._task_state.answer = "done"

        def _gen():
            yield MessageEndStreamResponse(task_id="task", id="message-id", metadata={"k": "v"})

        response = pipeline._to_blocking_response(_gen())
        assert response.data.answer == "done"
        assert response.data.metadata == {"k": "v"}

    def test_handle_text_chunk_event_updates_state(self):
        """A text-chunk event appends to the task-state answer and yields a response."""
        pipeline = _make_pipeline()
        pipeline._message_cycle_manager = SimpleNamespace(
            message_to_stream_response=lambda **kwargs: MessageEndStreamResponse(
                task_id="task", id="message-id", metadata={}
            )
        )
        event = SimpleNamespace(text="hi", from_variable_selector=None)
        responses = list(pipeline._handle_text_chunk_event(event))
        assert pipeline._task_state.answer == "hi"
        assert responses

    def test_listen_audio_msg_returns_audio_stream(self):
        """_listen_audio_msg wraps a streamed audio trunk in a MessageAudioStreamResponse."""
        pipeline = _make_pipeline()
        publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk(status="stream", audio="data"))
        response = pipeline._listen_audio_msg(publisher=publisher, task_id="task")
        assert isinstance(response, MessageAudioStreamResponse)

    def test_handle_ping_event(self):
        """A ping queue event yields the base pipeline's ping stream response."""
        pipeline = _make_pipeline()
        pipeline._base_task_pipeline.ping_stream_response = lambda: PingStreamResponse(task_id="task")
        responses = list(pipeline._handle_ping_event(QueuePingEvent()))
        assert isinstance(responses[0], PingStreamResponse)

    def test_handle_error_event(self):
        """An error queue event is converted via handle_error/error_to_stream_response."""
        pipeline = _make_pipeline()
        pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom")
        pipeline._base_task_pipeline.error_to_stream_response = lambda err: err

        @contextmanager
        def _fake_session():
            yield SimpleNamespace()

        pipeline._database_session = _fake_session
        responses = list(pipeline._handle_error_event(QueueErrorEvent(error=ValueError("boom"))))
        assert isinstance(responses[0], ValueError)

    def test_handle_workflow_started_event_sets_run_id(self, monkeypatch):
        """The workflow-started handler records the run id from the variable pool."""
        pipeline = _make_pipeline()
        pipeline._graph_runtime_state = GraphRuntimeState(
            variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
            start_at=0.0,
        )
        pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started"

        @contextmanager
        def _fake_session():
            yield SimpleNamespace()

        monkeypatch.setattr(pipeline, "_database_session", _fake_session)
        monkeypatch.setattr(pipeline, "_get_message", lambda **kwargs: SimpleNamespace())
        responses = list(pipeline._handle_workflow_started_event(QueueWorkflowStartedEvent()))
        assert pipeline._workflow_run_id == "run-id"
        assert responses == ["started"]

    def test_message_end_to_stream_response_strips_annotation_reply(self):
        """annotation_reply metadata is removed from the outgoing message-end response."""
        pipeline = _make_pipeline()
        pipeline._task_state.metadata.annotation_reply = AnnotationReply(
            id="ann",
            account=AnnotationReplyAccount(id="acc", name="acc"),
        )
        response = pipeline._message_end_to_stream_response()
        assert "annotation_reply" not in response.metadata

    def test_handle_output_moderation_chunk_publishes_stop(self):
        """When moderation demands direct output, the handler publishes text + stop events."""
        pipeline = _make_pipeline()
        events: list[object] = []

        class _Moderation:
            # Force the "direct output" path of the moderation handler.
            def should_direct_output(self):
                return True

            def get_final_output(self):
                return "final"

        pipeline._base_task_pipeline.output_moderation_handler = _Moderation()
        pipeline._base_task_pipeline.queue_manager = SimpleNamespace(
            publish=lambda event, pub_from: events.append(event)
        )
        result = pipeline._handle_output_moderation_chunk("ignored")
        assert result is True
        assert pipeline._task_state.answer == "final"
        assert any(isinstance(event, QueueTextChunkEvent) for event in events)
        assert any(isinstance(event, QueueStopEvent) for event in events)

    def test_handle_node_succeeded_event_records_files(self):
        """A succeeded node event records output files and yields the finish response."""
        pipeline = _make_pipeline()
        pipeline._workflow_response_converter.fetch_files_from_node_outputs = lambda outputs: [
            {"type": "file", "transfer_method": "local"}
        ]
        pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done"
        pipeline._save_output_for_event = lambda event, node_execution_id: None
        event = SimpleNamespace(
            node_type=NodeType.ANSWER,
            outputs={"k": "v"},
            node_execution_id="exec",
            node_id="node",
        )
        responses = list(pipeline._handle_node_succeeded_event(event))
        assert responses == ["done"]
        assert pipeline._recorded_files

    def test_iteration_and_loop_handlers(self):
        """Each iteration/loop handler forwards to its converter method verbatim."""
        pipeline = _make_pipeline()
        pipeline._workflow_run_id = "run-id"
        pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = (
            lambda **kwargs: "iter_start"
        )
        pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "iter_next"
        pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = (
            lambda **kwargs: "iter_done"
        )
        pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop_start"
        pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next"
        pipeline._workflow_response_converter.workflow_loop_completed_to_stream_response = lambda **kwargs: "loop_done"
        iter_start = QueueIterationStartEvent(
            node_execution_id="exec",
            node_id="node",
            node_type=NodeType.LLM,
            node_title="LLM",
            start_at=datetime.utcnow(),
            node_run_index=1,
        )
        iter_next = QueueIterationNextEvent(
            index=1,
            node_execution_id="exec",
            node_id="node",
            node_type=NodeType.LLM,
            node_title="LLM",
            node_run_index=1,
        )
        iter_done = QueueIterationCompletedEvent(
            node_execution_id="exec",
            node_id="node",
            node_type=NodeType.LLM,
            node_title="LLM",
            start_at=datetime.utcnow(),
            node_run_index=1,
        )
        loop_start = QueueLoopStartEvent(
            node_execution_id="exec",
            node_id="node",
            node_type=NodeType.LLM,
            node_title="LLM",
            start_at=datetime.utcnow(),
            node_run_index=1,
        )
        loop_next = QueueLoopNextEvent(
            index=1,
            node_execution_id="exec",
            node_id="node",
            node_type=NodeType.LLM,
            node_title="LLM",
            node_run_index=1,
        )
        loop_done = QueueLoopCompletedEvent(
            node_execution_id="exec",
            node_id="node",
            node_type=NodeType.LLM,
            node_title="LLM",
            start_at=datetime.utcnow(),
            node_run_index=1,
        )
        assert list(pipeline._handle_iteration_start_event(iter_start)) == ["iter_start"]
        assert list(pipeline._handle_iteration_next_event(iter_next)) == ["iter_next"]
        assert list(pipeline._handle_iteration_completed_event(iter_done)) == ["iter_done"]
        assert list(pipeline._handle_loop_start_event(loop_start)) == ["loop_start"]
        assert list(pipeline._handle_loop_next_event(loop_next)) == ["loop_next"]
        assert list(pipeline._handle_loop_completed_event(loop_done)) == ["loop_done"]

    def test_workflow_finish_handlers(self, monkeypatch):
        """Succeeded/partial-success yield message-end + finish; failed/paused yield converter output."""
        pipeline = _make_pipeline()
        pipeline._workflow_run_id = "run-id"
        pipeline._graph_runtime_state = GraphRuntimeState(
            variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
            start_at=0.0,
        )
        pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
        pipeline._workflow_response_converter.workflow_pause_to_stream_response = lambda **kwargs: ["pause"]
        pipeline._persist_human_input_extra_content = lambda **kwargs: None
        pipeline._save_message = lambda **kwargs: None
        pipeline._base_task_pipeline.queue_manager.publish = lambda *args, **kwargs: None
        pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom")
        pipeline._base_task_pipeline.error_to_stream_response = lambda err: err
        pipeline._get_message = lambda **kwargs: SimpleNamespace(id="message-id")

        @contextmanager
        def _fake_session():
            yield SimpleNamespace(scalar=lambda *args, **kwargs: None)

        monkeypatch.setattr(pipeline, "_database_session", _fake_session)
        succeeded_responses = list(pipeline._handle_workflow_succeeded_event(QueueWorkflowSucceededEvent(outputs={})))
        assert len(succeeded_responses) == 2
        assert isinstance(succeeded_responses[0], MessageEndStreamResponse)
        assert succeeded_responses[1] == "finish"
        partial_success_responses = list(
            pipeline._handle_workflow_partial_success_event(
                QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={})
            )
        )
        assert len(partial_success_responses) == 2
        assert isinstance(partial_success_responses[0], MessageEndStreamResponse)
        assert partial_success_responses[1] == "finish"
        assert (
            list(pipeline._handle_workflow_failed_event(QueueWorkflowFailedEvent(error="err", exceptions_count=1)))[0]
            == "finish"
        )
        assert list(pipeline._handle_workflow_paused_event(QueueWorkflowPausedEvent(reasons=[], outputs={}))) == [
            "pause"
        ]

    def test_node_failure_handlers(self):
        """Failed and exception node events both yield the node-finish stream response."""
        pipeline = _make_pipeline()
        pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "node_finish"
        pipeline._save_output_for_event = lambda event, node_execution_id: None
        failed_event = QueueNodeFailedEvent(
            node_execution_id="exec",
            node_id="node",
            node_type=NodeType.LLM,
            start_at=datetime.utcnow(),
            inputs={},
            outputs={},
            process_data={},
            error="err",
        )
        exc_event = QueueNodeExceptionEvent(
            node_execution_id="exec",
            node_id="node",
            node_type=NodeType.LLM,
            start_at=datetime.utcnow(),
            inputs={},
            outputs={},
            process_data={},
            error="err",
        )
        assert list(pipeline._handle_node_failed_events(failed_event)) == ["node_finish"]
        assert list(pipeline._handle_node_failed_events(exc_event)) == ["node_finish"]

    def test_handle_text_chunk_event_tracks_streaming_metrics(self):
        """Streaming a text chunk sets token timing metrics and forwards to the TTS publisher."""
        pipeline = _make_pipeline()
        published: list[object] = []

        class _Publisher:
            # Capture messages pushed to the TTS publisher.
            def publish(self, message):
                published.append(message)

        pipeline._message_cycle_manager = SimpleNamespace(message_to_stream_response=lambda **kwargs: "chunk")
        event = SimpleNamespace(text="hi", from_variable_selector=["a"])
        queue_message = SimpleNamespace(event=event)
        responses = list(
            pipeline._handle_text_chunk_event(event, tts_publisher=_Publisher(), queue_message=queue_message)
        )
        assert responses == ["chunk"]
        assert pipeline._task_state.is_streaming_response is True
        assert pipeline._task_state.first_token_time is not None
        assert pipeline._task_state.last_token_time is not None
        assert pipeline._task_state.answer == "hi"
        assert published == [queue_message]

    def test_handle_output_moderation_chunk_appends_token(self):
        """When moderation does not intercept, the token is appended to the handler."""
        pipeline = _make_pipeline()
        seen: list[str] = []

        class _Moderation:
            # Take the "pass-through" path: tokens are only accumulated.
            def should_direct_output(self):
                return False

            def append_new_token(self, text):
                seen.append(text)

        pipeline._base_task_pipeline.output_moderation_handler = _Moderation()
        result = pipeline._handle_output_moderation_chunk("token")
        assert result is False
        assert seen == ["token"]

    def test_handle_retriever_and_annotation_events(self):
        """Retriever/annotation events delegate to the message cycle manager and yield nothing."""
        pipeline = _make_pipeline()
        calls = {"retriever": 0, "annotation": 0}

        def _hit_retriever(event):
            calls["retriever"] += 1

        def _hit_annotation(event):
            calls["annotation"] += 1

        pipeline._message_cycle_manager.handle_retriever_resources = _hit_retriever
        pipeline._message_cycle_manager.handle_annotation_reply = _hit_annotation
        retriever_event = QueueRetrieverResourcesEvent(retriever_resources=[])
        annotation_event = QueueAnnotationReplyEvent(message_annotation_id="ann")
        assert list(pipeline._handle_retriever_resources_event(retriever_event)) == []
        assert list(pipeline._handle_annotation_reply_event(annotation_event)) == []
        assert calls == {"retriever": 1, "annotation": 1}

    def test_handle_message_replace_event(self):
        """A message-replace event yields the replace stream response from the cycle manager."""
        pipeline = _make_pipeline()
        pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace"
        event = QueueMessageReplaceEvent(
            text="new",
            reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
        )
        assert list(pipeline._handle_message_replace_event(event)) == ["replace"]

    def test_handle_human_input_events(self):
        """Form-filled persists extra content and both human-input handlers forward to the converter."""
        pipeline = _make_pipeline()
        persisted: list[str] = []
        pipeline._persist_human_input_extra_content = lambda **kwargs: persisted.append("saved")
        pipeline._workflow_response_converter.human_input_form_filled_to_stream_response = lambda **kwargs: "filled"
        pipeline._workflow_response_converter.human_input_form_timeout_to_stream_response = lambda **kwargs: "timeout"
        filled_event = QueueHumanInputFormFilledEvent(
            node_execution_id="exec",
            node_id="node",
            node_type=NodeType.LLM,
            node_title="title",
            rendered_content="content",
            action_id="action",
            action_text="action",
        )
        timeout_event = QueueHumanInputFormTimeoutEvent(
            node_id="node",
            node_type=NodeType.LLM,
            node_title="title",
            expiration_time=datetime.utcnow(),
        )
        assert list(pipeline._handle_human_input_form_filled_event(filled_event)) == ["filled"]
        assert list(pipeline._handle_human_input_form_timeout_event(timeout_event)) == ["timeout"]
        assert persisted == ["saved"]

    def test_save_message_strips_markdown_and_sets_usage(self):
        """_save_message strips recorded-file markdown from the answer and fills message metadata."""
        pipeline = _make_pipeline()
        pipeline._recorded_files = [
            {
                "type": "image",
                "transfer_method": "remote",
                "remote_url": "http://example.com/file.png",
                "related_id": "file-id",
            }
        ]
        pipeline._task_state.answer = "![img](url) hello"
        pipeline._task_state.is_streaming_response = True
        pipeline._task_state.first_token_time = pipeline._base_task_pipeline.start_at + 0.1
        pipeline._task_state.last_token_time = pipeline._base_task_pipeline.start_at + 0.2
        message = SimpleNamespace(
            id="message-id",
            status=MessageStatus.PAUSED,
            answer="",
            updated_at=None,
            provider_response_latency=None,
            message_tokens=None,
            message_unit_price=None,
            message_price_unit=None,
            answer_tokens=None,
            answer_unit_price=None,
            answer_price_unit=None,
            total_price=None,
            currency=None,
            message_metadata=None,
            invoke_from=InvokeFrom.WEB_APP,
            from_account_id=None,
            from_end_user_id="end-user",
        )

        class _Session:
            # Minimal session stub: scalar() returns the message; add_all() records items.
            def scalar(self, *args, **kwargs):
                return message

            def add_all(self, items):
                self.items = items

        graph_runtime_state = GraphRuntimeState(
            variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
            start_at=0.0,
        )
        pipeline._save_message(session=_Session(), graph_runtime_state=graph_runtime_state)
        assert message.status == MessageStatus.NORMAL
        # The image markdown was stripped, leaving only the plain text.
        assert message.answer == "hello"
        assert message.message_metadata

    def test_handle_stop_event_saves_message_for_moderation(self, monkeypatch):
        """A stop event triggered by input moderation saves the message and yields message-end."""
        pipeline = _make_pipeline()
        pipeline._message_end_to_stream_response = lambda: "end"
        saved: list[str] = []

        def _save_message(**kwargs):
            saved.append("saved")

        pipeline._save_message = _save_message

        @contextmanager
        def _fake_session():
            yield SimpleNamespace()

        monkeypatch.setattr(pipeline, "_database_session", _fake_session)
        responses = list(pipeline._handle_stop_event(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)))
        assert responses == ["end"]
        assert saved == ["saved"]

    def test_handle_message_end_event_applies_output_moderation(self, monkeypatch):
        """Message-end applies final output moderation, yielding replace then end responses."""
        pipeline = _make_pipeline()
        pipeline._graph_runtime_state = GraphRuntimeState(
            variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
            start_at=0.0,
        )
        pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe"
        pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace"
        pipeline._message_end_to_stream_response = lambda: "end"
        saved: list[str] = []

        def _save_message(**kwargs):
            saved.append("saved")

        pipeline._save_message = _save_message

        @contextmanager
        def _fake_session():
            yield SimpleNamespace()

        monkeypatch.setattr(pipeline, "_database_session", _fake_session)
        responses = list(pipeline._handle_advanced_chat_message_end_event(QueueAdvancedChatMessageEndEvent()))
        assert responses == ["replace", "end"]
        assert saved == ["saved"]

    def test_dispatch_event_handles_node_exception(self):
        """_dispatch_event routes a node-exception event to the failure handler."""
        pipeline = _make_pipeline()
        pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "failed"
        pipeline._save_output_for_event = lambda *args, **kwargs: None
        event = QueueNodeExceptionEvent(
            node_execution_id="exec",
            node_id="node",
            node_type=NodeType.LLM,
            start_at=datetime.utcnow(),
            inputs={},
            outputs={},
            process_data={},
            error="err",
        )
        assert list(pipeline._dispatch_event(event)) == ["failed"]

View File

@@ -0,0 +1,302 @@
import uuid
from types import SimpleNamespace
import pytest
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
from core.app.apps.agent_chat.app_config_manager import (
AgentChatAppConfigManager,
)
from core.entities.agent_entities import PlanningStrategy
class TestAgentChatAppConfigManagerGetAppConfig:
    """Tests for AgentChatAppConfigManager.get_app_config config-source resolution.

    All sub-config converters are patched out so only the source-selection
    logic (override vs. conversation vs. latest) is exercised.
    """

    def test_get_app_config_override_config(self, mocker):
        """An explicit override dict wins and marks the config source as ARGS."""
        app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
        app_model_config = mocker.MagicMock(id="cfg1")
        app_model_config.to_dict.return_value = {"ignored": True}
        override_config = {"model": {"provider": "p"}}
        mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert")
        mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert")
        mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert")
        mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert")
        mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert")
        mocker.patch.object(AgentChatAppConfigManager, "convert_features")
        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert",
            return_value=("variables", "external"),
        )
        # Replace the config dataclass with a SimpleNamespace factory so the
        # kwargs passed to it can be asserted directly.
        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig",
            side_effect=lambda **kwargs: SimpleNamespace(**kwargs),
        )
        result = AgentChatAppConfigManager.get_app_config(
            app_model=app_model,
            app_model_config=app_model_config,
            conversation=None,
            override_config_dict=override_config,
        )
        assert result.app_model_config_dict == override_config
        assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS
        assert result.variables == "variables"
        assert result.external_data_variables == "external"

    def test_get_app_config_conversation_specific(self, mocker):
        """With a conversation and no override, the conversation-specific source is used."""
        app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
        app_model_config = mocker.MagicMock(id="cfg1")
        app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
        conversation = mocker.MagicMock()
        mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert")
        mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert")
        mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert")
        mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert")
        mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert")
        mocker.patch.object(AgentChatAppConfigManager, "convert_features")
        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert",
            return_value=("variables", "external"),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig",
            side_effect=lambda **kwargs: SimpleNamespace(**kwargs),
        )
        result = AgentChatAppConfigManager.get_app_config(
            app_model=app_model,
            app_model_config=app_model_config,
            conversation=conversation,
            override_config_dict=None,
        )
        assert result.app_model_config_dict == app_model_config.to_dict.return_value
        assert result.app_model_config_from.value == "conversation-specific-config"

    def test_get_app_config_latest_config(self, mocker):
        """With neither override nor conversation, the app's latest config is used."""
        app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
        app_model_config = mocker.MagicMock(id="cfg1")
        app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
        mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert")
        mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert")
        mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert")
        mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert")
        mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert")
        mocker.patch.object(AgentChatAppConfigManager, "convert_features")
        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert",
            return_value=("variables", "external"),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig",
            side_effect=lambda **kwargs: SimpleNamespace(**kwargs),
        )
        result = AgentChatAppConfigManager.get_app_config(
            app_model=app_model,
            app_model_config=app_model_config,
            conversation=None,
            override_config_dict=None,
        )
        assert result.app_model_config_from.value == "app-latest-config"
class TestAgentChatAppConfigManagerConfigValidate:
    """Tests for AgentChatAppConfigManager.config_validate."""

    def test_config_validate_filters_related_keys(self, mocker):
        """config_validate should keep only the keys reported as "related"
        by the sub-config validators and drop everything else."""
        config = {
            "model": {},
            "user_input_form": {},
            "file_upload": {},
            "prompt_template": {},
            "agent_mode": {},
            "opening_statement": {},
            "suggested_questions_after_answer": {},
            "speech_to_text": {},
            "text_to_speech": {},
            "retriever_resource": {},
            "dataset": {},
            "moderation": {},
            "extra": "value",
        }

        def return_with_key(key):
            # Mimic the (config, related_config_keys) contract of the
            # validate_and_set_defaults helpers: each stub claims exactly
            # one key as related.
            return config, [key]

        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.ModelConfigManager.validate_and_set_defaults",
            side_effect=lambda tenant_id, cfg: return_with_key("model"),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.validate_and_set_defaults",
            side_effect=lambda tenant_id, cfg: return_with_key("user_input_form"),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
            side_effect=lambda cfg: return_with_key("file_upload"),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.validate_and_set_defaults",
            side_effect=lambda app_mode, cfg: return_with_key("prompt_template"),
        )
        mocker.patch.object(
            AgentChatAppConfigManager,
            "validate_agent_mode_and_set_defaults",
            side_effect=lambda tenant_id, cfg: return_with_key("agent_mode"),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults",
            side_effect=lambda cfg: return_with_key("opening_statement"),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults",
            side_effect=lambda cfg: return_with_key("suggested_questions_after_answer"),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults",
            side_effect=lambda cfg: return_with_key("speech_to_text"),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
            side_effect=lambda cfg: return_with_key("text_to_speech"),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults",
            side_effect=lambda cfg: return_with_key("retriever_resource"),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.validate_and_set_defaults",
            side_effect=lambda tenant_id, app_mode, cfg: return_with_key("dataset"),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
            side_effect=lambda tenant_id, cfg: return_with_key("moderation"),
        )
        filtered = AgentChatAppConfigManager.config_validate("tenant", config)
        # Exactly the union of the claimed keys survives; "extra" is dropped.
        assert set(filtered.keys()) == {
            "model",
            "user_input_form",
            "file_upload",
            "prompt_template",
            "agent_mode",
            "opening_statement",
            "suggested_questions_after_answer",
            "speech_to_text",
            "text_to_speech",
            "retriever_resource",
            "dataset",
            "moderation",
        }
        assert "extra" not in filtered
class TestValidateAgentModeAndSetDefaults:
    """Tests for AgentChatAppConfigManager.validate_agent_mode_and_set_defaults."""

    def test_defaults_when_missing(self):
        # An absent agent_mode section is replaced with a disabled default.
        config = {}
        updated, keys = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config)
        assert "agent_mode" in updated
        assert updated["agent_mode"]["enabled"] is False
        assert updated["agent_mode"]["tools"] == []
        assert keys == ["agent_mode"]

    @pytest.mark.parametrize(
        "agent_mode",
        ["invalid", 123],
    )
    def test_agent_mode_type_validation(self, agent_mode):
        # Scalars (string / int) are not a valid agent_mode value.
        with pytest.raises(ValueError):
            AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", {"agent_mode": agent_mode})

    def test_agent_mode_empty_list_defaults(self):
        # An empty list is treated like "not configured" and gets defaults.
        config = {"agent_mode": []}
        updated, _ = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config)
        assert updated["agent_mode"]["enabled"] is False
        assert updated["agent_mode"]["tools"] == []

    def test_enabled_must_be_bool(self):
        with pytest.raises(ValueError):
            AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", {"agent_mode": {"enabled": "yes"}})

    def test_strategy_must_be_valid(self):
        with pytest.raises(ValueError):
            AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
                "tenant", {"agent_mode": {"enabled": True, "strategy": "invalid"}}
            )

    def test_tools_must_be_list(self):
        with pytest.raises(ValueError):
            AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
                "tenant", {"agent_mode": {"enabled": True, "tools": "not-list"}}
            )

    def test_old_tool_dataset_requires_id(self):
        # Old-style dataset tools must carry a dataset id.
        with pytest.raises(ValueError):
            AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
                "tenant", {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True}}]}}
            )

    def test_old_tool_dataset_id_must_be_uuid(self):
        with pytest.raises(ValueError):
            AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
                "tenant",
                {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True, "id": "bad"}}]}},
            )

    def test_old_tool_dataset_id_not_exists(self, mocker):
        # A syntactically valid UUID still fails when the dataset lookup misses.
        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.is_dataset_exists",
            return_value=False,
        )
        dataset_id = str(uuid.uuid4())
        with pytest.raises(ValueError):
            AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
                "tenant",
                {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True, "id": dataset_id}}]}},
            )

    def test_old_tool_enabled_must_be_bool(self):
        with pytest.raises(ValueError):
            AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
                "tenant",
                {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": "yes", "id": str(uuid.uuid4())}}]}},
            )

    @pytest.mark.parametrize("missing_key", ["provider_type", "provider_id", "tool_name", "tool_parameters"])
    def test_new_style_tool_requires_fields(self, missing_key):
        # Start from a fully valid new-style tool and remove exactly one field,
        # so each parametrized case fails because of the field under test.
        # (Previously the base fixture omitted "tool_parameters" entirely, so
        # every case could raise for that always-missing key instead of the
        # one being removed.)
        tool = {
            "enabled": True,
            "provider_type": "type",
            "provider_id": "id",
            "tool_name": "tool",
            "tool_parameters": {},
        }
        tool.pop(missing_key, None)
        with pytest.raises(ValueError):
            AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
                "tenant", {"agent_mode": {"enabled": True, "tools": [tool]}}
            )

    def test_valid_old_and_new_style_tools(self, mocker):
        # A mixed list of old-style (dataset) and new-style tools validates,
        # and both get their "enabled" flag defaulted to False.
        mocker.patch(
            "core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.is_dataset_exists",
            return_value=True,
        )
        dataset_id = str(uuid.uuid4())
        config = {
            "agent_mode": {
                "enabled": True,
                "strategy": PlanningStrategy.ROUTER.value,
                "tools": [
                    {"dataset": {"id": dataset_id}},
                    {
                        "provider_type": "builtin",
                        "provider_id": "p1",
                        "tool_name": "tool",
                        "tool_parameters": {},
                    },
                ],
            }
        }
        updated, _ = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config)
        assert updated["agent_mode"]["tools"][0]["dataset"]["enabled"] is False
        assert updated["agent_mode"]["tools"][1]["enabled"] is False

View File

@ -0,0 +1,296 @@
import contextlib
import pytest
from pydantic import ValidationError
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
class DummyAccount:
    """Lightweight stand-in for an account: exposes an ``id`` and a
    ``session_id`` derived from it."""

    def __init__(self, user_id):
        self.id = user_id
        self.session_id = "session-{}".format(user_id)
@pytest.fixture
def generator(mocker):
    """Build an AgentChatAppGenerator with Flask and contextvars access stubbed."""
    app_generator = AgentChatAppGenerator()
    # Stub the Flask application proxy so no real app context is required.
    mocker.patch(
        "core.app.apps.agent_chat.app_generator.current_app",
        new=mocker.MagicMock(_get_current_object=mocker.MagicMock()),
    )
    # Replace contextvars.copy_context with a plain sentinel value.
    mocker.patch("core.app.apps.agent_chat.app_generator.contextvars.copy_context", return_value="ctx")
    return app_generator
class TestAgentChatAppGeneratorGenerate:
    """Tests for AgentChatAppGenerator.generate: argument validation and the
    happy path with all collaborators mocked out."""

    def test_generate_rejects_blocking_mode(self, generator, mocker):
        # Agent chat only supports streaming; streaming=False must be rejected.
        app_model = mocker.MagicMock()
        user = DummyAccount("user")
        with pytest.raises(ValueError):
            generator.generate(app_model=app_model, user=user, args={}, invoke_from=mocker.MagicMock(), streaming=False)

    def test_generate_requires_query(self, generator, mocker):
        app_model = mocker.MagicMock()
        user = DummyAccount("user")
        with pytest.raises(ValueError):
            generator.generate(app_model=app_model, user=user, args={"inputs": {}}, invoke_from=mocker.MagicMock())

    def test_generate_rejects_non_string_query(self, generator, mocker):
        app_model = mocker.MagicMock()
        user = DummyAccount("user")
        with pytest.raises(ValueError):
            generator.generate(
                app_model=app_model,
                user=user,
                args={"query": 123, "inputs": {}},
                invoke_from=mocker.MagicMock(),
            )

    def test_generate_override_requires_debugger(self, generator, mocker):
        # Supplying a model_config override is only allowed from the debugger.
        app_model = mocker.MagicMock()
        user = DummyAccount("user")
        with pytest.raises(ValueError):
            generator.generate(
                app_model=app_model,
                user=user,
                args={"query": "hi", "inputs": {}, "model_config": {"model": {"provider": "p"}}},
                invoke_from=InvokeFrom.WEB_APP,
            )

    def test_generate_success_with_debugger_override(self, generator, mocker):
        """Full happy path from the debugger with a config override: the worker
        thread is started and the converted response is returned."""
        app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
        app_model_config = mocker.MagicMock(id="cfg1")
        app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
        user = DummyAccount("user")
        invoke_from = InvokeFrom.DEBUGGER
        # Stub the generator's own helpers so only generate()'s orchestration
        # is exercised.
        generator._get_app_model_config = mocker.MagicMock(return_value=app_model_config)
        generator._prepare_user_inputs = mocker.MagicMock(return_value={"x": 1})
        generator._init_generate_records = mocker.MagicMock(
            return_value=(mocker.MagicMock(id="conv", mode="agent-chat"), mocker.MagicMock(id="msg"))
        )
        generator._handle_response = mocker.MagicMock(return_value="response")
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.config_validate",
            return_value={"validated": True},
        )
        app_config = mocker.MagicMock(variables={}, prompt_template=mocker.MagicMock(), external_data_variables=[])
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.get_app_config",
            return_value=app_config,
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.ModelConfigConverter.convert",
            return_value=mocker.MagicMock(),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.FileUploadConfigManager.convert",
            return_value=mocker.MagicMock(),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.file_factory.build_from_mappings",
            return_value=["file-obj"],
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.ConversationService.get_conversation",
            return_value=mocker.MagicMock(id="conv"),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.TraceQueueManager",
            return_value=mocker.MagicMock(),
        )
        queue_manager = mocker.MagicMock()
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.MessageBasedAppQueueManager",
            return_value=queue_manager,
        )
        thread_obj = mocker.MagicMock()
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.threading.Thread",
            return_value=thread_obj,
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateResponseConverter.convert",
            return_value={"result": "ok"},
        )
        app_entity = mocker.MagicMock(task_id="task", user_id="user", invoke_from=invoke_from)
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateEntity",
            return_value=app_entity,
        )
        args = {
            "query": "hello",
            "inputs": {"name": "world"},
            "conversation_id": "conv",
            "model_config": {"model": {"provider": "p"}},
            "files": [{"id": "f1"}],
        }
        result = generator.generate(app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=True)
        assert result == {"result": "ok"}
        # The background worker thread must have been launched exactly once.
        thread_obj.start.assert_called_once()

    def test_generate_without_file_config(self, generator, mocker):
        """Same happy path, but FileUploadConfigManager.convert returns None
        (no file-upload feature configured) — generation still succeeds."""
        app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
        app_model_config = mocker.MagicMock(id="cfg1")
        app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
        user = DummyAccount("user")
        generator._get_app_model_config = mocker.MagicMock(return_value=app_model_config)
        generator._prepare_user_inputs = mocker.MagicMock(return_value={"x": 1})
        generator._init_generate_records = mocker.MagicMock(
            return_value=(mocker.MagicMock(id="conv", mode="agent-chat"), mocker.MagicMock(id="msg"))
        )
        generator._handle_response = mocker.MagicMock(return_value="response")
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.get_app_config",
            return_value=mocker.MagicMock(variables={}, prompt_template=mocker.MagicMock(), external_data_variables=[]),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.ModelConfigConverter.convert",
            return_value=mocker.MagicMock(),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.FileUploadConfigManager.convert",
            return_value=None,
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.file_factory.build_from_mappings",
            return_value=["file-obj"],
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.TraceQueueManager",
            return_value=mocker.MagicMock(),
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.MessageBasedAppQueueManager",
            return_value=mocker.MagicMock(),
        )
        thread_obj = mocker.MagicMock()
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.threading.Thread",
            return_value=thread_obj,
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateResponseConverter.convert",
            return_value={"result": "ok"},
        )
        app_entity = mocker.MagicMock(task_id="task", user_id="user", invoke_from=InvokeFrom.WEB_APP)
        mocker.patch(
            "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateEntity",
            return_value=app_entity,
        )
        args = {"query": "hello", "inputs": {"name": "world"}}
        result = generator.generate(
            app_model=app_model,
            user=user,
            args=args,
            invoke_from=InvokeFrom.WEB_APP,
            streaming=True,
        )
        assert result == {"result": "ok"}
class TestAgentChatAppGeneratorWorker:
    """Tests for AgentChatAppGenerator._generate_worker error handling."""

    @pytest.fixture(autouse=True)
    def patch_context(self, mocker):
        # Replace preserve_flask_contexts with a no-op context manager so the
        # worker can run without a real Flask app/request context.
        @contextlib.contextmanager
        def ctx_manager(*args, **kwargs):
            yield
        mocker.patch("core.app.apps.agent_chat.app_generator.preserve_flask_contexts", ctx_manager)

    def test_generate_worker_handles_generate_task_stopped(self, generator, mocker):
        # GenerateTaskStoppedError is a normal stop signal: swallowed, and no
        # error is published to the queue.
        queue_manager = mocker.MagicMock()
        generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
        generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())
        runner = mocker.MagicMock()
        runner.run.side_effect = GenerateTaskStoppedError()
        mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner)
        mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close")
        generator._generate_worker(
            flask_app=mocker.MagicMock(),
            context=mocker.MagicMock(),
            application_generate_entity=mocker.MagicMock(),
            queue_manager=queue_manager,
            conversation_id="conv",
            message_id="msg",
        )
        queue_manager.publish_error.assert_not_called()

    @pytest.mark.parametrize(
        "error",
        [
            InvokeAuthorizationError("bad"),
            ValidationError.from_exception_data("TestModel", []),
            ValueError("bad"),
            Exception("bad"),
        ],
    )
    def test_generate_worker_publishes_errors(self, generator, mocker, error):
        # Any other exception type raised by the runner is forwarded to the
        # queue via publish_error.
        queue_manager = mocker.MagicMock()
        generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
        generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())
        runner = mocker.MagicMock()
        runner.run.side_effect = error
        mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner)
        mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close")
        generator._generate_worker(
            flask_app=mocker.MagicMock(),
            context=mocker.MagicMock(),
            application_generate_entity=mocker.MagicMock(),
            queue_manager=queue_manager,
            conversation_id="conv",
            message_id="msg",
        )
        assert queue_manager.publish_error.called

    def test_generate_worker_logs_value_error_when_debug(self, generator, mocker):
        # With dify_config.DEBUG enabled, a ValueError is additionally logged
        # via logger.exception.
        queue_manager = mocker.MagicMock()
        generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
        generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())
        runner = mocker.MagicMock()
        runner.run.side_effect = ValueError("bad")
        mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner)
        mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close")
        mocker.patch("core.app.apps.agent_chat.app_generator.dify_config", new=mocker.MagicMock(DEBUG=True))
        logger = mocker.patch("core.app.apps.agent_chat.app_generator.logger")
        generator._generate_worker(
            flask_app=mocker.MagicMock(),
            context=mocker.MagicMock(),
            application_generate_entity=mocker.MagicMock(),
            queue_manager=queue_manager,
            conversation_id="conv",
            message_id="msg",
        )
        logger.exception.assert_called_once()

View File

@ -0,0 +1,413 @@
import pytest
from core.agent.entities import AgentEntity
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
from core.moderation.base import ModerationError
from dify_graph.model_runtime.entities.llm_entities import LLMMode
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
@pytest.fixture
def runner():
    """Provide a fresh AgentChatAppRunner instance for each test."""
    instance = AgentChatAppRunner()
    return instance
class TestAgentChatAppRunnerRun:
    """Tests for AgentChatAppRunner.run: short-circuit paths (moderation,
    annotation reply, hosting moderation), strategy/LLM-mode dispatch to the
    concrete agent runners, and lookup-failure errors.

    NOTE(review): several tests patch db.session.scalar twice — first with a
    return_value and again with a side_effect list; the second patch is the
    one in effect when run() executes.
    """

    def test_run_app_not_found(self, runner, mocker):
        # db.session.scalar returning None for the app lookup raises.
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", agent=mocker.MagicMock())
        generate_entity = mocker.MagicMock(app_config=app_config, inputs={}, query="q", files=[], stream=True)
        mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=None)
        with pytest.raises(ValueError):
            runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())

    def test_run_moderation_error_direct_output(self, runner, mocker):
        # A ModerationError short-circuits run() into direct_output.
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = mocker.MagicMock()
        generate_entity = mocker.MagicMock(
            app_config=app_config,
            inputs={},
            query="q",
            files=[],
            stream=True,
            model_conf=mocker.MagicMock(),
            conversation_id=None,
        )
        mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
        mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
        mocker.patch.object(runner, "moderation_for_inputs", side_effect=ModerationError("bad"))
        mocker.patch.object(runner, "direct_output")
        runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
        runner.direct_output.assert_called_once()

    def test_run_annotation_reply_short_circuits(self, runner, mocker):
        # A matching annotation publishes a queue event and answers directly.
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = mocker.MagicMock()
        generate_entity = mocker.MagicMock(
            app_config=app_config,
            inputs={},
            query="q",
            files=[],
            stream=True,
            model_conf=mocker.MagicMock(),
            conversation_id=None,
            user_id="user",
            invoke_from=mocker.MagicMock(),
        )
        mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
        mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
        mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
        annotation = mocker.MagicMock(id="anno", content="answer")
        mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=annotation)
        mocker.patch.object(runner, "direct_output")
        queue_manager = mocker.MagicMock()
        runner.run(generate_entity, queue_manager, mocker.MagicMock(), mocker.MagicMock())
        queue_manager.publish.assert_called_once()
        runner.direct_output.assert_called_once()

    def test_run_hosting_moderation_short_circuits(self, runner, mocker):
        # check_hosting_moderation returning True stops run() early.
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = mocker.MagicMock()
        generate_entity = mocker.MagicMock(
            app_config=app_config,
            inputs={},
            query="q",
            files=[],
            stream=True,
            model_conf=mocker.MagicMock(),
            conversation_id=None,
            invoke_from=mocker.MagicMock(),
            user_id="user",
        )
        mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
        mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
        mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
        mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
        mocker.patch.object(runner, "check_hosting_moderation", return_value=True)
        runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())

    def test_run_model_schema_missing(self, runner, mocker):
        # get_model_schema returning None is a hard error.
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
        generate_entity = mocker.MagicMock(
            app_config=app_config,
            inputs={},
            query="q",
            files=[],
            stream=True,
            model_conf=mocker.MagicMock(
                provider_model_bundle=mocker.MagicMock(),
                model="m",
                provider="p",
                credentials={"k": "v"},
            ),
            conversation_id="conv",
            invoke_from=mocker.MagicMock(),
            user_id="user",
        )
        mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
        mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
        mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
        mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
        mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
        llm_instance = mocker.MagicMock()
        llm_instance.model_type_instance.get_model_schema.return_value = None
        mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
        with pytest.raises(ValueError):
            runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())

    @pytest.mark.parametrize(
        ("mode", "expected_runner"),
        [
            (LLMMode.CHAT, "CotChatAgentRunner"),
            (LLMMode.COMPLETION, "CotCompletionAgentRunner"),
        ],
    )
    def test_run_chain_of_thought_modes(self, runner, mocker, mode, expected_runner):
        # Chain-of-thought strategy dispatches to the CoT runner matching the
        # model's LLM mode (chat vs completion).
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
        generate_entity = mocker.MagicMock(
            app_config=app_config,
            inputs={},
            query="q",
            files=[],
            stream=True,
            model_conf=mocker.MagicMock(
                provider_model_bundle=mocker.MagicMock(),
                model="m",
                provider="p",
                credentials={"k": "v"},
            ),
            conversation_id="conv",
            invoke_from=mocker.MagicMock(),
            user_id="user",
        )
        mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
        mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
        mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
        mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
        mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
        model_schema = mocker.MagicMock()
        model_schema.features = []
        model_schema.model_properties = {ModelPropertyKey.MODE: mode}
        llm_instance = mocker.MagicMock()
        llm_instance.model_type_instance.get_model_schema.return_value = model_schema
        mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
        conversation = mocker.MagicMock(id="conv")
        message = mocker.MagicMock(id="msg")
        # Re-patch scalar with the ordered lookups: app, conversation, message.
        mocker.patch(
            "core.app.apps.agent_chat.app_runner.db.session.scalar",
            side_effect=[app_record, conversation, message],
        )
        runner_cls = mocker.MagicMock()
        mocker.patch(f"core.app.apps.agent_chat.app_runner.{expected_runner}", runner_cls)
        runner_instance = mocker.MagicMock()
        runner_cls.return_value = runner_instance
        runner_instance.run.return_value = []
        mocker.patch.object(runner, "_handle_invoke_result")
        runner.run(generate_entity, mocker.MagicMock(), conversation, message)
        runner_instance.run.assert_called_once()
        runner._handle_invoke_result.assert_called_once()

    def test_run_invalid_llm_mode_raises(self, runner, mocker):
        # An unknown LLM mode string in the model schema raises.
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
        generate_entity = mocker.MagicMock(
            app_config=app_config,
            inputs={},
            query="q",
            files=[],
            stream=True,
            model_conf=mocker.MagicMock(
                provider_model_bundle=mocker.MagicMock(),
                model="m",
                provider="p",
                credentials={"k": "v"},
            ),
            conversation_id="conv",
            invoke_from=mocker.MagicMock(),
            user_id="user",
        )
        mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
        mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
        mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
        mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
        mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
        model_schema = mocker.MagicMock()
        model_schema.features = []
        model_schema.model_properties = {ModelPropertyKey.MODE: "invalid"}
        llm_instance = mocker.MagicMock()
        llm_instance.model_type_instance.get_model_schema.return_value = model_schema
        mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
        conversation = mocker.MagicMock(id="conv")
        message = mocker.MagicMock(id="msg")
        mocker.patch(
            "core.app.apps.agent_chat.app_runner.db.session.scalar",
            side_effect=[app_record, conversation, message],
        )
        with pytest.raises(ValueError):
            runner.run(generate_entity, mocker.MagicMock(), conversation, message)

    def test_run_function_calling_strategy_selected_by_features(self, runner, mocker):
        # If the model advertises TOOL_CALL, the strategy is upgraded to
        # FUNCTION_CALLING and the FunctionCallAgentRunner is used.
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
        generate_entity = mocker.MagicMock(
            app_config=app_config,
            inputs={},
            query="q",
            files=[],
            stream=True,
            model_conf=mocker.MagicMock(
                provider_model_bundle=mocker.MagicMock(),
                model="m",
                provider="p",
                credentials={"k": "v"},
            ),
            conversation_id="conv",
            invoke_from=mocker.MagicMock(),
            user_id="user",
        )
        mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
        mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
        mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
        mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
        mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
        model_schema = mocker.MagicMock()
        model_schema.features = [ModelFeature.TOOL_CALL]
        model_schema.model_properties = {ModelPropertyKey.MODE: LLMMode.CHAT}
        llm_instance = mocker.MagicMock()
        llm_instance.model_type_instance.get_model_schema.return_value = model_schema
        mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
        conversation = mocker.MagicMock(id="conv")
        message = mocker.MagicMock(id="msg")
        mocker.patch(
            "core.app.apps.agent_chat.app_runner.db.session.scalar",
            side_effect=[app_record, conversation, message],
        )
        runner_cls = mocker.MagicMock()
        mocker.patch("core.app.apps.agent_chat.app_runner.FunctionCallAgentRunner", runner_cls)
        runner_instance = mocker.MagicMock()
        runner_cls.return_value = runner_instance
        runner_instance.run.return_value = []
        mocker.patch.object(runner, "_handle_invoke_result")
        runner.run(generate_entity, mocker.MagicMock(), conversation, message)
        assert app_config.agent.strategy == AgentEntity.Strategy.FUNCTION_CALLING
        runner_instance.run.assert_called_once()

    def test_run_conversation_not_found(self, runner, mocker):
        # Second scalar lookup (conversation) returning None raises.
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING)
        generate_entity = mocker.MagicMock(
            app_config=app_config,
            inputs={},
            query="q",
            files=[],
            stream=True,
            model_conf=mocker.MagicMock(
                provider_model_bundle=mocker.MagicMock(),
                model="m",
                provider="p",
                credentials={"k": "v"},
            ),
            conversation_id="conv",
            invoke_from=mocker.MagicMock(),
            user_id="user",
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_runner.db.session.scalar",
            side_effect=[app_record, None],
        )
        mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
        mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
        mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
        mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
        with pytest.raises(ValueError):
            runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg"))

    def test_run_message_not_found(self, runner, mocker):
        # Third scalar lookup (message) returning None raises.
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING)
        generate_entity = mocker.MagicMock(
            app_config=app_config,
            inputs={},
            query="q",
            files=[],
            stream=True,
            model_conf=mocker.MagicMock(
                provider_model_bundle=mocker.MagicMock(),
                model="m",
                provider="p",
                credentials={"k": "v"},
            ),
            conversation_id="conv",
            invoke_from=mocker.MagicMock(),
            user_id="user",
        )
        mocker.patch(
            "core.app.apps.agent_chat.app_runner.db.session.scalar",
            side_effect=[app_record, mocker.MagicMock(id="conv"), None],
        )
        mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
        mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
        mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
        mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
        with pytest.raises(ValueError):
            runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg"))

    def test_run_invalid_agent_strategy_raises(self, runner, mocker):
        # An unrecognized agent strategy value cannot be dispatched.
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = mocker.MagicMock(strategy="invalid", provider="p", model="m")
        generate_entity = mocker.MagicMock(
            app_config=app_config,
            inputs={},
            query="q",
            files=[],
            stream=True,
            model_conf=mocker.MagicMock(
                provider_model_bundle=mocker.MagicMock(),
                model="m",
                provider="p",
                credentials={"k": "v"},
            ),
            conversation_id="conv",
            invoke_from=mocker.MagicMock(),
            user_id="user",
        )
        mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
        mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
        mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
        mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
        mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
        model_schema = mocker.MagicMock()
        model_schema.features = []
        model_schema.model_properties = {ModelPropertyKey.MODE: LLMMode.CHAT}
        llm_instance = mocker.MagicMock()
        llm_instance.model_type_instance.get_model_schema.return_value = model_schema
        mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
        conversation = mocker.MagicMock(id="conv")
        message = mocker.MagicMock(id="msg")
        mocker.patch(
            "core.app.apps.agent_chat.app_runner.db.session.scalar",
            side_effect=[app_record, conversation, message],
        )
        with pytest.raises(ValueError):
            runner.run(generate_entity, mocker.MagicMock(), conversation, message)

View File

@ -0,0 +1,162 @@
from collections.abc import Generator
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
from core.app.entities.task_entities import (
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
MessageEndStreamResponse,
MessageStreamResponse,
PingStreamResponse,
)
class TestAgentChatAppGenerateResponseConverterBlocking:
    """Tests for the blocking-response conversion paths of
    AgentChatAppGenerateResponseConverter."""

    def test_convert_blocking_full_response(self):
        # The full response keeps the answer and metadata verbatim.
        blocking = ChatbotAppBlockingResponse(
            task_id="task",
            data=ChatbotAppBlockingResponse.Data(
                id="id",
                mode="agent-chat",
                conversation_id="conv",
                message_id="msg",
                answer="answer",
                metadata={"a": 1},
                created_at=123,
            ),
        )
        result = AgentChatAppGenerateResponseConverter.convert_blocking_full_response(blocking)
        assert result["event"] == "message"
        assert result["answer"] == "answer"
        assert result["metadata"] == {"a": 1}

    def test_convert_blocking_simple_response_with_dict_metadata(self):
        # The simple response strips internal metadata (annotation_reply,
        # usage) while keeping simplified retriever resources.
        blocking = ChatbotAppBlockingResponse(
            task_id="task",
            data=ChatbotAppBlockingResponse.Data(
                id="id",
                mode="agent-chat",
                conversation_id="conv",
                message_id="msg",
                answer="answer",
                metadata={
                    "retriever_resources": [
                        {
                            "segment_id": "s1",
                            "position": 1,
                            "document_name": "doc",
                            "score": 0.9,
                            "content": "content",
                        }
                    ],
                    "annotation_reply": {"id": "a"},
                    "usage": {"prompt_tokens": 1},
                },
                created_at=123,
            ),
        )
        result = AgentChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
        assert "annotation_reply" not in result["metadata"]
        assert "usage" not in result["metadata"]

    def test_convert_blocking_simple_response_with_non_dict_metadata(self):
        # model_construct bypasses pydantic validation so a non-dict metadata
        # can be injected; the converter must fall back to an empty dict.
        blocking = ChatbotAppBlockingResponse.model_construct(
            task_id="task",
            data=ChatbotAppBlockingResponse.Data.model_construct(
                id="id",
                mode="agent-chat",
                conversation_id="conv",
                message_id="msg",
                answer="answer",
                metadata="bad",
                created_at=123,
            ),
        )
        result = AgentChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
        assert result["metadata"] == {}
class TestAgentChatAppGenerateResponseConverterStream:
    """Streaming-mode conversion tests for the agent-chat response converter."""

    def build_stream(self) -> Generator[ChatbotAppStreamResponse, None, None]:
        """Return a stream yielding ping, message, message_end, and error events in that order."""

        def _gen():
            # 1) ping event — expected to convert to the literal string "ping"
            yield ChatbotAppStreamResponse(
                conversation_id="conv",
                message_id="msg",
                created_at=1,
                stream_response=PingStreamResponse(task_id="t"),
            )
            # 2) plain message chunk
            yield ChatbotAppStreamResponse(
                conversation_id="conv",
                message_id="msg",
                created_at=2,
                stream_response=MessageStreamResponse(task_id="t", id="m1", answer="hi"),
            )
            # 3) message_end carrying metadata the simple converter must trim
            #    ("extra" in the retriever resource and the annotation_reply /
            #    usage top-level keys are expected to be dropped).
            yield ChatbotAppStreamResponse(
                conversation_id="conv",
                message_id="msg",
                created_at=3,
                stream_response=MessageEndStreamResponse(
                    task_id="t",
                    id="m1",
                    metadata={
                        "retriever_resources": [
                            {
                                "segment_id": "s1",
                                "position": 1,
                                "document_name": "doc",
                                "score": 0.9,
                                "content": "content",
                                "summary": "summary",
                                "extra": "ignored",
                            }
                        ],
                        "annotation_reply": {"id": "a"},
                        "usage": {"prompt_tokens": 1},
                    },
                ),
            )
            # 4) terminal error event
            yield ChatbotAppStreamResponse(
                conversation_id="conv",
                message_id="msg",
                created_at=4,
                stream_response=ErrorStreamResponse(task_id="t", err=RuntimeError("bad")),
            )

        return _gen()

    def test_convert_stream_full_response(self):
        """Full conversion preserves each event type in stream order."""
        items = list(AgentChatAppGenerateResponseConverter.convert_stream_full_response(self.build_stream()))
        assert items[0] == "ping"
        assert items[1]["event"] == "message"
        assert "answer" in items[1]
        assert items[2]["event"] == "message_end"
        assert items[3]["event"] == "error"

    def test_convert_stream_simple_response(self):
        """Simple conversion drops annotation_reply/usage and trims retriever resources."""
        items = list(AgentChatAppGenerateResponseConverter.convert_stream_simple_response(self.build_stream()))
        assert items[0] == "ping"
        # Assert the message event structure and content at items[1]
        assert items[1]["event"] == "message"
        assert items[1]["answer"] == "hi" or "hi" in items[1]["answer"]
        assert items[2]["event"] == "message_end"
        assert "metadata" in items[2]
        metadata = items[2]["metadata"]
        assert "annotation_reply" not in metadata
        assert "usage" not in metadata
        # "extra" is stripped from the resource; the remaining keys survive.
        assert metadata["retriever_resources"] == [
            {
                "segment_id": "s1",
                "position": 1,
                "document_name": "doc",
                "score": 0.9,
                "content": "content",
                "summary": "summary",
            }
        ]
        assert items[3]["event"] == "error"

View File

@ -0,0 +1,113 @@
from types import SimpleNamespace
from unittest.mock import patch
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom, ModelConfigEntity, PromptTemplateEntity
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from models.model import AppMode
class TestChatAppConfigManager:
    """Tests for ChatAppConfigManager configuration assembly and validation."""

    def test_get_app_config_uses_override_dict(self):
        """An override dict switches the config source to ARGS and replaces the stored dict."""
        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT.value)
        app_model_config = SimpleNamespace(id="config-1", to_dict=lambda: {"model": "m"})
        override = {"model": "override"}
        model_entity = ModelConfigEntity(provider="p", model="m")
        prompt_entity = PromptTemplateEntity(
            prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
            simple_prompt_template="hi",
        )
        # Stub out every sub-config converter so only the manager's own wiring runs.
        with (
            patch("core.app.apps.chat.app_config_manager.ModelConfigManager.convert", return_value=model_entity),
            patch(
                "core.app.apps.chat.app_config_manager.PromptTemplateConfigManager.convert", return_value=prompt_entity
            ),
            patch(
                "core.app.apps.chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert",
                return_value=None,
            ),
            patch("core.app.apps.chat.app_config_manager.DatasetConfigManager.convert", return_value=None),
            patch("core.app.apps.chat.app_config_manager.BasicVariablesConfigManager.convert", return_value=([], [])),
        ):
            app_config = ChatAppConfigManager.get_app_config(
                app_model=app_model,
                app_model_config=app_model_config,
                conversation=None,
                override_config_dict=override,
            )
        assert app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS
        assert app_config.app_model_config_dict == override
        assert app_config.app_mode == AppMode.CHAT

    def test_config_validate_filters_related_keys(self):
        """config_validate keeps exactly the keys the sub-validators report as related."""
        config = {"extra": 1}

        def _add_key(key, value):
            # Stand-in validator factory: the returned callable extends the
            # config with `key` and reports `[key]` as its related keys.
            def _inner(*args, **kwargs):
                config = args[-1]  # the config dict is the last positional argument
                config = {**config, key: value}
                return config, [key]

            return _inner

        with (
            patch(
                "core.app.apps.chat.app_config_manager.ModelConfigManager.validate_and_set_defaults",
                side_effect=_add_key("model", 1),
            ),
            patch(
                "core.app.apps.chat.app_config_manager.BasicVariablesConfigManager.validate_and_set_defaults",
                side_effect=_add_key("inputs", 2),
            ),
            patch(
                "core.app.apps.chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
                side_effect=_add_key("file_upload", 3),
            ),
            patch(
                "core.app.apps.chat.app_config_manager.PromptTemplateConfigManager.validate_and_set_defaults",
                side_effect=_add_key("prompt", 4),
            ),
            patch(
                "core.app.apps.chat.app_config_manager.DatasetConfigManager.validate_and_set_defaults",
                side_effect=_add_key("dataset", 5),
            ),
            patch(
                "core.app.apps.chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults",
                side_effect=_add_key("opening_statement", 6),
            ),
            patch(
                "core.app.apps.chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults",
                side_effect=_add_key("suggested_questions_after_answer", 7),
            ),
            patch(
                "core.app.apps.chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults",
                side_effect=_add_key("speech_to_text", 8),
            ),
            patch(
                "core.app.apps.chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
                side_effect=_add_key("text_to_speech", 9),
            ),
            patch(
                "core.app.apps.chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults",
                side_effect=_add_key("retriever_resource", 10),
            ),
            patch(
                "core.app.apps.chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
                side_effect=_add_key("sensitive_word_avoidance", 11),
            ),
        ):
            filtered = ChatAppConfigManager.config_validate(tenant_id="t1", config=config)
        assert filtered["model"] == 1
        assert filtered["inputs"] == 2
        assert filtered["file_upload"] == 3
        assert filtered["prompt"] == 4
        assert filtered["dataset"] == 5
        assert filtered["opening_statement"] == 6
        assert filtered["suggested_questions_after_answer"] == 7
        assert filtered["speech_to_text"] == 8
        assert filtered["text_to_speech"] == 9
        assert filtered["retriever_resource"] == 10
        assert filtered["sensitive_word_avoidance"] == 11

View File

@ -0,0 +1,280 @@
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from core.app.apps.chat.app_generator import ChatAppGenerator
from core.app.apps.chat.app_runner import ChatAppRunner
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.moderation.base import ModerationError
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
from models.model import AppMode
class DummyGenerateEntity:
    """Attribute-bag stand-in for an application generate entity."""

    def __init__(self, **attributes):
        # Expose every keyword argument as an instance attribute.
        for name, value in attributes.items():
            setattr(self, name, value)
class DummyQueueManager:
    """Queue-manager stand-in that records everything published to it."""

    def __init__(self, *args, **kwargs):
        # Accepts and ignores any constructor arguments so it can be patched
        # in for the real queue manager class.
        self.published = []

    def _record(self, item, pub_from):
        """Append an (item, publisher) pair to the published log."""
        self.published.append((item, pub_from))

    def publish_error(self, error, pub_from):
        self._record(error, pub_from)

    def publish(self, event, pub_from):
        self._record(event, pub_from)
class TestChatAppGenerator:
    """Tests for ChatAppGenerator argument validation, debugger overrides, and worker error handling."""

    def test_generate_requires_query(self):
        """generate() raises ValueError when args omit the query."""
        generator = ChatAppGenerator()
        with pytest.raises(ValueError):
            generator.generate(
                app_model=SimpleNamespace(),
                user=SimpleNamespace(),
                args={"inputs": {}},
                invoke_from=InvokeFrom.SERVICE_API,
                streaming=False,
            )

    def test_generate_rejects_non_string_query(self):
        """generate() raises ValueError when the query is not a string."""
        generator = ChatAppGenerator()
        with pytest.raises(ValueError):
            generator.generate(
                app_model=SimpleNamespace(),
                user=SimpleNamespace(),
                args={"query": 1, "inputs": {}},
                invoke_from=InvokeFrom.SERVICE_API,
                streaming=False,
            )

    def test_generate_debugger_overrides_model_config(self):
        """In DEBUGGER mode a model_config override is validated and generation succeeds."""
        generator = ChatAppGenerator()
        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
        user = SimpleNamespace(id="user-1", session_id="session-1")
        args = {"query": "hi", "inputs": {}, "model_config": {"foo": "bar"}}
        # Patch out every collaborator so only generate()'s own control flow runs.
        with (
            patch("core.app.apps.chat.app_generator.ConversationService.get_conversation", return_value=None),
            patch("core.app.apps.chat.app_generator.ChatAppConfigManager.config_validate", return_value={"x": 1}),
            patch(
                "core.app.apps.chat.app_generator.ChatAppConfigManager.get_app_config",
                return_value=SimpleNamespace(
                    variables=[], external_data_variables=[], app_model_config_dict={}, app_mode=AppMode.CHAT
                ),
            ),
            patch("core.app.apps.chat.app_generator.ModelConfigConverter.convert", return_value=SimpleNamespace()),
            patch("core.app.apps.chat.app_generator.FileUploadConfigManager.convert", return_value=None),
            patch("core.app.apps.chat.app_generator.file_factory.build_from_mappings", return_value=[]),
            patch("core.app.apps.chat.app_generator.ChatAppGenerateEntity", DummyGenerateEntity),
            patch("core.app.apps.chat.app_generator.TraceQueueManager", return_value=SimpleNamespace()),
            patch("core.app.apps.chat.app_generator.MessageBasedAppQueueManager", DummyQueueManager),
            patch(
                "core.app.apps.chat.app_generator.ChatAppGenerateResponseConverter.convert", return_value={"ok": True}
            ),
            patch.object(ChatAppGenerator, "_get_app_model_config", return_value=SimpleNamespace(to_dict=lambda: {})),
            patch.object(ChatAppGenerator, "_prepare_user_inputs", return_value={}),
            patch.object(
                ChatAppGenerator,
                "_init_generate_records",
                return_value=(SimpleNamespace(id="c1", mode="chat"), SimpleNamespace(id="m1")),
            ),
            patch.object(ChatAppGenerator, "_handle_response", return_value={"response": True}),
            patch("core.app.apps.chat.app_generator.copy_current_request_context", side_effect=lambda f: f),
            patch("core.app.apps.chat.app_generator.threading.Thread") as mock_thread,
        ):
            mock_thread.return_value.start.return_value = None
            result = generator.generate(app_model, user, args, InvokeFrom.DEBUGGER, streaming=False)
        assert result == {"ok": True}

    def test_generate_rejects_model_config_override_for_non_debugger(self):
        """Outside DEBUGGER mode a model_config override is rejected with ValueError."""
        generator = ChatAppGenerator()
        with pytest.raises(ValueError):
            with (
                patch.object(
                    ChatAppGenerator, "_get_app_model_config", return_value=SimpleNamespace(to_dict=lambda: {})
                ),
            ):
                generator.generate(
                    app_model=SimpleNamespace(tenant_id="t1", id="a1", mode=AppMode.CHAT.value),
                    user=SimpleNamespace(id="u1", session_id="s1"),
                    args={"query": "hi", "inputs": {}, "model_config": {"foo": "bar"}},
                    invoke_from=InvokeFrom.SERVICE_API,
                    streaming=False,
                )

    def test_generate_worker_handles_exceptions(self):
        """_generate_worker publishes invoke errors and swallows GenerateTaskStoppedError."""
        from contextlib import nullcontext  # local import: supplies a real context manager

        generator = ChatAppGenerator()
        queue_manager = DummyQueueManager()
        entity = DummyGenerateEntity(task_id="t1", user_id="u1")
        # Fix: a plain Mock does not implement the context-manager protocol
        # (dunder methods are looked up on the type, so
        # Mock(__enter__=Mock(), __exit__=Mock()) fails inside
        # `with flask_app.app_context():`). Return a real nullcontext from
        # app_context() instead; side_effect yields a fresh one per call.
        flask_app = Mock(app_context=Mock(side_effect=lambda: nullcontext()))
        with (
            patch.object(ChatAppGenerator, "_get_conversation", return_value=SimpleNamespace()),
            patch.object(ChatAppGenerator, "_get_message", return_value=SimpleNamespace()),
            patch("core.app.apps.chat.app_generator.ChatAppRunner.run", side_effect=InvokeAuthorizationError()),
            patch("core.app.apps.chat.app_generator.db.session.close"),
        ):
            generator._generate_worker(
                flask_app=flask_app,
                application_generate_entity=entity,
                queue_manager=queue_manager,
                conversation_id="c1",
                message_id="m1",
            )
        # The invoke error must have been published to the queue.
        assert queue_manager.published
        with (
            patch.object(ChatAppGenerator, "_get_conversation", return_value=SimpleNamespace()),
            patch.object(ChatAppGenerator, "_get_message", return_value=SimpleNamespace()),
            patch("core.app.apps.chat.app_generator.ChatAppRunner.run", side_effect=GenerateTaskStoppedError()),
            patch("core.app.apps.chat.app_generator.db.session.close"),
        ):
            # GenerateTaskStoppedError is expected to be swallowed silently.
            generator._generate_worker(
                flask_app=flask_app,
                application_generate_entity=entity,
                queue_manager=queue_manager,
                conversation_id="c1",
                message_id="m1",
            )
class TestChatAppRunner:
    """Tests for ChatAppRunner.run covering its early-exit paths."""

    def test_run_raises_when_app_missing(self):
        """run() raises ValueError when the app record cannot be loaded."""
        runner = ChatAppRunner()
        app_config = SimpleNamespace(
            app_id="app-1", tenant_id="tenant-1", prompt_template=None, external_data_variables=[]
        )
        app_generate_entity = DummyGenerateEntity(
            app_config=app_config,
            model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}),
            inputs={},
            query="hi",
            files=[],
            file_upload_config=None,
            conversation_id=None,
            stream=False,
            user_id="user-1",
            invoke_from=InvokeFrom.SERVICE_API,
        )
        # db.session.scalar returning None simulates a missing App row.
        with patch("core.app.apps.chat.app_runner.db.session.scalar", return_value=None):
            with pytest.raises(ValueError):
                runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))

    def test_run_moderation_error_direct_output(self):
        """A ModerationError short-circuits the run into direct_output."""
        runner = ChatAppRunner()
        app_config = SimpleNamespace(
            app_id="app-1",
            tenant_id="tenant-1",
            prompt_template=None,
            external_data_variables=[],
            dataset=None,
            additional_features=None,
        )
        app_generate_entity = DummyGenerateEntity(
            app_config=app_config,
            model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}),
            inputs={},
            query="hi",
            files=[],
            file_upload_config=None,
            conversation_id=None,
            stream=False,
            user_id="user-1",
            invoke_from=InvokeFrom.SERVICE_API,
        )
        with (
            patch(
                "core.app.apps.chat.app_runner.db.session.scalar",
                return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
            ),
            patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
            patch.object(ChatAppRunner, "moderation_for_inputs", side_effect=ModerationError("blocked")),
            patch.object(ChatAppRunner, "direct_output") as mock_direct,
        ):
            runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
        mock_direct.assert_called_once()

    def test_run_annotation_reply_short_circuits(self):
        """A matching annotation publishes QueueAnnotationReplyEvent and answers via direct_output."""
        runner = ChatAppRunner()
        app_config = SimpleNamespace(
            app_id="app-1",
            tenant_id="tenant-1",
            prompt_template=None,
            external_data_variables=[],
            dataset=None,
            additional_features=None,
        )
        app_generate_entity = DummyGenerateEntity(
            app_config=app_config,
            model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}),
            inputs={},
            query="hi",
            files=[],
            file_upload_config=None,
            conversation_id=None,
            stream=False,
            user_id="user-1",
            invoke_from=InvokeFrom.SERVICE_API,
        )
        annotation = SimpleNamespace(id="ann-1", content="answer")
        with (
            patch(
                "core.app.apps.chat.app_runner.db.session.scalar",
                return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
            ),
            patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
            patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
            patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=annotation),
            patch.object(ChatAppRunner, "direct_output") as mock_direct,
        ):
            queue_manager = DummyQueueManager()
            runner.run(app_generate_entity, queue_manager, SimpleNamespace(), SimpleNamespace(id="m1"))
        assert any(isinstance(item[0], QueueAnnotationReplyEvent) for item in queue_manager.published)
        mock_direct.assert_called_once()

    def test_run_returns_when_hosting_moderation_blocks(self):
        """run() returns quietly when hosting moderation blocks the request."""
        runner = ChatAppRunner()
        app_config = SimpleNamespace(
            app_id="app-1",
            tenant_id="tenant-1",
            prompt_template=None,
            external_data_variables=[],
            dataset=None,
            additional_features=None,
        )
        app_generate_entity = DummyGenerateEntity(
            app_config=app_config,
            model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}),
            inputs={},
            query="hi",
            files=[],
            file_upload_config=None,
            conversation_id=None,
            stream=False,
            user_id="user-1",
            invoke_from=InvokeFrom.SERVICE_API,
        )
        with (
            patch(
                "core.app.apps.chat.app_runner.db.session.scalar",
                return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
            ),
            patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
            patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
            patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=None),
            patch.object(ChatAppRunner, "check_hosting_moderation", return_value=True),
        ):
            # No assertion: the test passes if run() exits without raising.
            runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))

View File

@ -0,0 +1,65 @@
from collections.abc import Generator
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
from core.app.entities.task_entities import (
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
MessageEndStreamResponse,
MessageStreamResponse,
PingStreamResponse,
)
class TestChatAppGenerateResponseConverter:
    """Conversion tests for the chat app response converter."""

    def test_convert_blocking_simple_response_metadata(self):
        """Simple blocking conversion strips the usage entry from metadata."""
        payload = ChatbotAppBlockingResponse.Data(
            id="msg-1",
            mode="chat",
            conversation_id="c1",
            message_id="m1",
            answer="hi",
            metadata={"usage": {"total_tokens": 1}},
            created_at=1,
        )
        wrapped = ChatbotAppBlockingResponse(task_id="t1", data=payload)
        converted = ChatAppGenerateResponseConverter.convert_blocking_simple_response(wrapped)
        assert "usage" not in converted["metadata"]

    def test_convert_stream_responses(self):
        """Full and simple stream conversions preserve event order."""

        def stream() -> Generator[ChatbotAppStreamResponse, None, None]:
            # Wrap each inner event in an identical stream envelope.
            events = [
                PingStreamResponse(task_id="t1"),
                MessageStreamResponse(task_id="t1", id="m1", answer="hi"),
                ErrorStreamResponse(task_id="t1", err=ValueError("boom")),
                MessageEndStreamResponse(task_id="t1", id="m1"),
            ]
            for event in events:
                yield ChatbotAppStreamResponse(
                    conversation_id="c1",
                    message_id="m1",
                    created_at=1,
                    stream_response=event,
                )

        full = list(ChatAppGenerateResponseConverter.convert_stream_full_response(stream()))
        assert full[0] == "ping"
        assert full[1]["event"] == "message"
        assert full[2]["event"] == "error"
        simple = list(ChatAppGenerateResponseConverter.convert_stream_simple_response(stream()))
        assert simple[0] == "ping"
        assert simple[-1]["event"] == "message_end"

View File

@ -0,0 +1,162 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
import core.app.apps.completion.app_runner as module
from core.app.apps.completion.app_runner import CompletionAppRunner
from core.moderation.base import ModerationError
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
@pytest.fixture
def runner():
    """Provide a fresh CompletionAppRunner instance for each test."""
    return CompletionAppRunner()
def _build_app_config(dataset=None, external_tools=None, additional_features=None):
app_config = MagicMock()
app_config.app_id = "app1"
app_config.tenant_id = "tenant"
app_config.prompt_template = MagicMock()
app_config.dataset = dataset
app_config.external_data_variables = external_tools or []
app_config.additional_features = additional_features
app_config.app_model_config_dict = {"file_upload": {"enabled": True}}
return app_config
def _build_generate_entity(app_config, file_upload_config=None):
model_conf = MagicMock(
provider_model_bundle="bundle",
model="model",
parameters={"max_tokens": 10},
stop=["stop"],
)
return SimpleNamespace(
app_config=app_config,
model_conf=model_conf,
inputs={"qvar": "query_from_input"},
query="original_query",
files=[],
file_upload_config=file_upload_config,
stream=True,
user_id="user",
invoke_from=MagicMock(),
)
class TestCompletionAppRunner:
    """Tests for CompletionAppRunner.run across its main branches."""

    def test_run_app_not_found(self, runner, mocker):
        """run() raises ValueError when the App row is missing."""
        session = mocker.MagicMock()
        session.scalar.return_value = None  # simulate: no App record in the DB
        mocker.patch.object(module.db, "session", session)
        app_config = _build_app_config()
        app_generate_entity = _build_generate_entity(app_config)
        with pytest.raises(ValueError):
            runner.run(app_generate_entity, MagicMock(), MagicMock())

    def test_run_moderation_error_outputs_direct(self, runner, mocker):
        """A ModerationError routes the answer through direct_output, skipping the model."""
        app_record = MagicMock(id="app1", tenant_id="tenant")
        session = mocker.MagicMock()
        session.scalar.return_value = app_record
        mocker.patch.object(module.db, "session", session)
        app_config = _build_app_config()
        app_generate_entity = _build_generate_entity(app_config)
        runner.organize_prompt_messages = MagicMock(return_value=([], None))
        runner.moderation_for_inputs = MagicMock(side_effect=ModerationError("blocked"))
        runner.direct_output = MagicMock()
        runner._handle_invoke_result = MagicMock()
        runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
        runner.direct_output.assert_called_once()
        runner._handle_invoke_result.assert_not_called()

    def test_run_hosting_moderation_stops(self, runner, mocker):
        """When hosting moderation flags the request, the model is never invoked."""
        app_record = MagicMock(id="app1", tenant_id="tenant")
        session = mocker.MagicMock()
        session.scalar.return_value = app_record
        mocker.patch.object(module.db, "session", session)
        app_config = _build_app_config()
        app_generate_entity = _build_generate_entity(app_config)
        runner.organize_prompt_messages = MagicMock(return_value=([], None))
        runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
        runner.check_hosting_moderation = MagicMock(return_value=True)
        runner._handle_invoke_result = MagicMock()
        runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
        runner._handle_invoke_result.assert_not_called()

    def test_run_dataset_and_external_tools_flow(self, runner, mocker):
        """With dataset + external tools, retrieval queries the configured input variable."""
        app_record = MagicMock(id="app1", tenant_id="tenant")
        session = mocker.MagicMock()
        session.scalar.return_value = app_record
        session.close = MagicMock()
        mocker.patch.object(module.db, "session", session)
        # Dataset config whose query variable points at the "qvar" input.
        retrieve_config = MagicMock(query_variable="qvar")
        dataset_config = MagicMock(dataset_ids=["ds"], retrieve_config=retrieve_config)
        additional_features = MagicMock(show_retrieve_source=True)
        app_config = _build_app_config(
            dataset=dataset_config,
            external_tools=["tool"],
            additional_features=additional_features,
        )
        file_upload_config = MagicMock()
        file_upload_config.image_config.detail = ImagePromptMessageContent.DETAIL.HIGH
        app_generate_entity = _build_generate_entity(app_config, file_upload_config=file_upload_config)
        # Two organize calls expected: pre-retrieval and final prompt assembly.
        runner.organize_prompt_messages = MagicMock(side_effect=[(["pm1"], ["stop"]), (["pm2"], ["stop"])])
        runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
        runner.fill_in_inputs_from_external_data_tools = MagicMock(return_value=app_generate_entity.inputs)
        runner.check_hosting_moderation = MagicMock(return_value=False)
        runner.recalc_llm_max_tokens = MagicMock()
        runner._handle_invoke_result = MagicMock()
        dataset_retrieval = MagicMock()
        dataset_retrieval.retrieve.return_value = ("ctx", ["file1"])
        mocker.patch.object(module, "DatasetRetrieval", return_value=dataset_retrieval)
        model_instance = MagicMock()
        model_instance.invoke_llm.return_value = "invoke_result"
        mocker.patch.object(module, "ModelInstance", return_value=model_instance)
        runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg", tenant_id="tenant"))
        dataset_retrieval.retrieve.assert_called_once()
        # The dataset query comes from the inputs entry named by query_variable.
        assert dataset_retrieval.retrieve.call_args.kwargs["query"] == "query_from_input"
        runner._handle_invoke_result.assert_called_once()

    def test_run_uses_low_image_detail_default(self, runner, mocker):
        """Without a file-upload config the image detail passed along defaults to LOW."""
        app_record = MagicMock(id="app1", tenant_id="tenant")
        session = mocker.MagicMock()
        session.scalar.return_value = app_record
        mocker.patch.object(module.db, "session", session)
        app_config = _build_app_config()
        app_generate_entity = _build_generate_entity(app_config, file_upload_config=None)
        runner.organize_prompt_messages = MagicMock(return_value=([], None))
        runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
        # Hosting moderation blocks so the run stops right after prompt organizing.
        runner.check_hosting_moderation = MagicMock(return_value=True)
        runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
        assert (
            runner.organize_prompt_messages.call_args.kwargs["image_detail_config"]
            == ImagePromptMessageContent.DETAIL.LOW
        )

View File

@ -0,0 +1,122 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
import core.app.apps.completion.app_config_manager as module
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from models.model import AppMode
class TestCompletionAppConfigManager:
    """Tests for CompletionAppConfigManager config assembly and validation."""

    def test_get_app_config_with_override(self, mocker):
        """An override dict switches the config source to ARGS and replaces the stored dict."""
        app_model = MagicMock(tenant_id="tenant", id="app1", mode=AppMode.COMPLETION.value)
        app_model_config = MagicMock(id="cfg1")
        app_model_config.to_dict.return_value = {"model": {"provider": "x"}}
        override_config = {"model": {"provider": "override"}}
        mocker.patch.object(module.ModelConfigManager, "convert", return_value="model")
        mocker.patch.object(module.PromptTemplateConfigManager, "convert", return_value="prompt")
        mocker.patch.object(module.SensitiveWordAvoidanceConfigManager, "convert", return_value="moderation")
        mocker.patch.object(module.DatasetConfigManager, "convert", return_value="dataset")
        mocker.patch.object(CompletionAppConfigManager, "convert_features", return_value="features")
        mocker.patch.object(module.BasicVariablesConfigManager, "convert", return_value=(["v1"], ["ext1"]))
        # Replace the config class with SimpleNamespace so kwargs pass straight through.
        mocker.patch.object(module, "CompletionAppConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
        result = CompletionAppConfigManager.get_app_config(
            app_model=app_model,
            app_model_config=app_model_config,
            override_config_dict=override_config,
        )
        assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS
        assert result.app_model_config_dict == override_config
        assert result.variables == ["v1"]
        assert result.external_data_variables == ["ext1"]
        assert result.app_mode == AppMode.COMPLETION

    def test_get_app_config_without_override_uses_model_config(self, mocker):
        """Without an override the stored config dict is used (APP_LATEST_CONFIG)."""
        app_model = MagicMock(tenant_id="tenant", id="app1", mode=AppMode.COMPLETION.value)
        app_model_config = MagicMock(id="cfg1")
        app_model_config.to_dict.return_value = {"model": {"provider": "x"}}
        mocker.patch.object(module.ModelConfigManager, "convert", return_value="model")
        mocker.patch.object(module.PromptTemplateConfigManager, "convert", return_value="prompt")
        mocker.patch.object(module.SensitiveWordAvoidanceConfigManager, "convert", return_value="moderation")
        mocker.patch.object(module.DatasetConfigManager, "convert", return_value="dataset")
        mocker.patch.object(CompletionAppConfigManager, "convert_features", return_value="features")
        mocker.patch.object(module.BasicVariablesConfigManager, "convert", return_value=([], []))
        mocker.patch.object(module, "CompletionAppConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
        result = CompletionAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config)
        assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
        assert result.app_model_config_dict == {"model": {"provider": "x"}}

    def test_config_validate_filters_related_keys(self, mocker):
        """config_validate keeps only the keys the sub-validators report as related."""
        config = {
            "model": {"provider": "x"},
            "variables": ["v"],
            "file_upload": {"enabled": True},
            "prompt": {"template": "t"},
            "dataset": {"enabled": True},
            "tts": {"enabled": True},
            "more_like_this": {"enabled": True},
            "moderation": {"enabled": True},
            "extra": "drop",
        }
        # Each validator returns the config unchanged plus its own related key.
        mocker.patch.object(
            module.ModelConfigManager,
            "validate_and_set_defaults",
            return_value=(config, ["model"]),
        )
        mocker.patch.object(
            module.BasicVariablesConfigManager,
            "validate_and_set_defaults",
            return_value=(config, ["variables"]),
        )
        mocker.patch.object(
            module.FileUploadConfigManager,
            "validate_and_set_defaults",
            return_value=(config, ["file_upload"]),
        )
        mocker.patch.object(
            module.PromptTemplateConfigManager,
            "validate_and_set_defaults",
            return_value=(config, ["prompt"]),
        )
        mocker.patch.object(
            module.DatasetConfigManager,
            "validate_and_set_defaults",
            return_value=(config, ["dataset"]),
        )
        mocker.patch.object(
            module.TextToSpeechConfigManager,
            "validate_and_set_defaults",
            return_value=(config, ["tts"]),
        )
        mocker.patch.object(
            module.MoreLikeThisConfigManager,
            "validate_and_set_defaults",
            return_value=(config, ["more_like_this"]),
        )
        mocker.patch.object(
            module.SensitiveWordAvoidanceConfigManager,
            "validate_and_set_defaults",
            return_value=(config, ["moderation"]),
        )
        filtered = CompletionAppConfigManager.config_validate("tenant", config)
        # "extra" is not reported by any validator, so it is dropped.
        assert "extra" not in filtered
        assert set(filtered.keys()) == {
            "model",
            "variables",
            "file_upload",
            "prompt",
            "dataset",
            "tts",
            "more_like_this",
            "moderation",
        }

View File

@ -0,0 +1,321 @@
import contextlib
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from pydantic import ValidationError
import core.app.apps.completion.app_generator as module
from core.app.apps.completion.app_generator import CompletionAppGenerator
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
from services.errors.app import MoreLikeThisDisabledError
from services.errors.message import MessageNotExistsError
@pytest.fixture
def generator(mocker):
    """CompletionAppGenerator with threading/queue/trace collaborators patched out."""
    gen = CompletionAppGenerator()
    # Run the worker body inline instead of copying the Flask request context.
    mocker.patch.object(module, "copy_current_request_context", side_effect=lambda fn: fn)
    flask_app = MagicMock()
    # app_context() must act as a context manager; nullcontext is a no-op one.
    flask_app.app_context.return_value = contextlib.nullcontext()
    mocker.patch.object(module, "current_app", MagicMock(_get_current_object=MagicMock(return_value=flask_app)))
    thread = MagicMock()
    # Neutralize the worker thread so generate() never spawns real work.
    mocker.patch.object(module.threading, "Thread", return_value=thread)
    mocker.patch.object(module, "MessageBasedAppQueueManager", return_value=MagicMock())
    mocker.patch.object(module, "TraceQueueManager", return_value=MagicMock())
    # Let the generate entity pass kwargs straight through as attributes.
    mocker.patch.object(module, "CompletionAppGenerateEntity", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
    return gen
def _build_app_model():
return MagicMock(tenant_id="tenant", id="app1", mode="completion")
def _build_user():
return MagicMock(id="user", session_id="session")
def _build_app_model_config():
config = MagicMock(id="cfg")
config.to_dict.return_value = {"model": {"provider": "x"}}
return config
class TestCompletionAppGenerator:
def test_generate_invalid_query_type(self, generator):
    """generate() rejects a non-string query with ValueError."""
    with pytest.raises(ValueError):
        generator.generate(
            app_model=_build_app_model(),
            user=_build_user(),
            args={"query": 123, "inputs": {}, "files": []},
            invoke_from=InvokeFrom.WEB_APP,
            streaming=True,
        )
def test_generate_override_not_debugger(self, generator):
    """A model_config override outside DEBUGGER mode raises ValueError."""
    with pytest.raises(ValueError):
        generator.generate(
            app_model=_build_app_model(),
            user=_build_user(),
            args={"query": "q", "inputs": {}, "files": [], "model_config": {}},
            invoke_from=InvokeFrom.WEB_APP,
            streaming=False,
        )
def test_generate_success_no_file_config(self, generator, mocker):
    """Without a file-upload config, files are never built and the converted result is returned."""
    app_model_config = _build_app_model_config()
    mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config)
    # FileUploadConfigManager.convert returning None means no upload config.
    mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=None)
    mocker.patch.object(module.file_factory, "build_from_mappings")
    app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={}))
    mocker.patch.object(module.CompletionAppConfigManager, "get_app_config", return_value=app_config)
    mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock())
    mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
    conversation = MagicMock(id="conv", mode="completion")
    message = MagicMock(id="msg")
    mocker.patch.object(generator, "_init_generate_records", return_value=(conversation, message))
    mocker.patch.object(generator, "_handle_response", return_value="response")
    mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted")
    result = generator.generate(
        app_model=_build_app_model(),
        user=_build_user(),
        args={"query": "q", "inputs": {"a": 1}, "files": []},
        invoke_from=InvokeFrom.WEB_APP,
        streaming=True,
    )
    assert result == "converted"
    # No upload config -> the file factory must not be invoked.
    module.file_factory.build_from_mappings.assert_not_called()
def test_generate_success_with_files(self, generator, mocker):
app_model_config = _build_app_model_config()
mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config)
file_extra_config = MagicMock()
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_extra_config)
mocker.patch.object(module.file_factory, "build_from_mappings", return_value=["file1"])
app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={}))
mocker.patch.object(module.CompletionAppConfigManager, "get_app_config", return_value=app_config)
mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock())
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
conversation = MagicMock(id="conv", mode="completion")
message = MagicMock(id="msg")
mocker.patch.object(generator, "_init_generate_records", return_value=(conversation, message))
mocker.patch.object(generator, "_handle_response", return_value="response")
mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted")
result = generator.generate(
app_model=_build_app_model(),
user=_build_user(),
args={"query": "q", "inputs": {"a": 1}, "files": [{"id": "f"}]},
invoke_from=InvokeFrom.WEB_APP,
streaming=False,
)
assert result == "converted"
module.file_factory.build_from_mappings.assert_called_once()
def test_generate_override_model_config_debugger(self, generator, mocker):
app_model_config = _build_app_model_config()
mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config)
override_config = {"model": {"provider": "override"}}
mocker.patch.object(module.CompletionAppConfigManager, "config_validate", return_value=override_config)
app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={}))
get_app_config = mocker.patch.object(
module.CompletionAppConfigManager,
"get_app_config",
return_value=app_config,
)
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=None)
mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock())
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
mocker.patch.object(
generator,
"_init_generate_records",
return_value=(MagicMock(id="conv", mode="completion"), MagicMock(id="msg")),
)
mocker.patch.object(generator, "_handle_response", return_value="response")
mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted")
generator.generate(
app_model=_build_app_model(),
user=_build_user(),
args={"query": "q", "inputs": {}, "files": [], "model_config": override_config},
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
)
assert get_app_config.call_args.kwargs["override_config_dict"] == override_config
def test_generate_more_like_this_message_not_found(self, generator, mocker):
session = mocker.MagicMock()
session.scalar.return_value = None
mocker.patch.object(module.db, "session", session)
with pytest.raises(MessageNotExistsError):
generator.generate_more_like_this(
app_model=_build_app_model(),
message_id="msg",
user=_build_user(),
invoke_from=InvokeFrom.WEB_APP,
)
def test_generate_more_like_this_disabled(self, generator, mocker):
app_model = _build_app_model()
app_model.app_model_config = MagicMock(more_like_this=False, more_like_this_dict={"enabled": False})
message = MagicMock()
session = mocker.MagicMock()
session.scalar.return_value = message
mocker.patch.object(module.db, "session", session)
with pytest.raises(MoreLikeThisDisabledError):
generator.generate_more_like_this(
app_model=app_model,
message_id="msg",
user=_build_user(),
invoke_from=InvokeFrom.WEB_APP,
)
def test_generate_more_like_this_app_model_config_missing(self, generator, mocker):
app_model = _build_app_model()
app_model.app_model_config = None
message = MagicMock()
session = mocker.MagicMock()
session.scalar.return_value = message
mocker.patch.object(module.db, "session", session)
with pytest.raises(MoreLikeThisDisabledError):
generator.generate_more_like_this(
app_model=app_model,
message_id="msg",
user=_build_user(),
invoke_from=InvokeFrom.WEB_APP,
)
def test_generate_more_like_this_message_config_none(self, generator, mocker):
app_model = _build_app_model()
app_model.app_model_config = MagicMock(more_like_this=True, more_like_this_dict={"enabled": True})
message = MagicMock(app_model_config=None)
session = mocker.MagicMock()
session.scalar.return_value = message
mocker.patch.object(module.db, "session", session)
with pytest.raises(ValueError):
generator.generate_more_like_this(
app_model=app_model,
message_id="msg",
user=_build_user(),
invoke_from=InvokeFrom.WEB_APP,
)
def test_generate_more_like_this_success(self, generator, mocker):
app_model = _build_app_model()
app_model.app_model_config = MagicMock(more_like_this=True, more_like_this_dict={"enabled": True})
message = MagicMock()
message.message_files = [{"id": "f"}]
message.inputs = {"a": 1}
message.query = "q"
app_model_config = MagicMock()
app_model_config.to_dict.return_value = {
"model": {"completion_params": {"temperature": 0.1}},
"file_upload": {"enabled": True},
}
message.app_model_config = app_model_config
session = mocker.MagicMock()
session.scalar.return_value = message
mocker.patch.object(module.db, "session", session)
file_extra_config = MagicMock()
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_extra_config)
mocker.patch.object(module.file_factory, "build_from_mappings", return_value=["file1"])
app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={}))
get_app_config = mocker.patch.object(
module.CompletionAppConfigManager,
"get_app_config",
return_value=app_config,
)
mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock())
mocker.patch.object(
generator,
"_init_generate_records",
return_value=(MagicMock(id="conv", mode="completion"), MagicMock(id="msg")),
)
mocker.patch.object(generator, "_handle_response", return_value="response")
mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted")
result = generator.generate_more_like_this(
app_model=app_model,
message_id="msg",
user=_build_user(),
invoke_from=InvokeFrom.WEB_APP,
stream=True,
)
assert result == "converted"
override_dict = get_app_config.call_args.kwargs["override_config_dict"]
assert override_dict["model"]["completion_params"]["temperature"] == 0.9
@pytest.mark.parametrize(
("error", "should_publish"),
[
(GenerateTaskStoppedError(), False),
(InvokeAuthorizationError("bad"), True),
(
ValidationError.from_exception_data(
"Model",
[{"type": "missing", "loc": ("x",), "msg": "Field required", "input": {}}],
),
True,
),
(ValueError("bad"), True),
(RuntimeError("boom"), True),
],
)
def test_generate_worker_error_handling(self, generator, mocker, error, should_publish):
flask_app = MagicMock()
flask_app.app_context.return_value = contextlib.nullcontext()
session = mocker.MagicMock()
mocker.patch.object(module.db, "session", session)
mocker.patch.object(generator, "_get_message", return_value=MagicMock())
runner_instance = MagicMock()
runner_instance.run.side_effect = error
mocker.patch.object(module, "CompletionAppRunner", return_value=runner_instance)
queue_manager = MagicMock()
generator._generate_worker(
flask_app=flask_app,
application_generate_entity=MagicMock(),
queue_manager=queue_manager,
message_id="msg",
)
assert queue_manager.publish_error.called is should_publish

View File

@ -0,0 +1,153 @@
from collections.abc import Generator
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
CompletionAppBlockingResponse,
CompletionAppStreamResponse,
ErrorStreamResponse,
MessageEndStreamResponse,
MessageStreamResponse,
PingStreamResponse,
)
class TestCompletionAppGenerateResponseConverter:
    """Unit tests for the completion app's blocking and streaming response converters."""
    def test_convert_blocking_full_response(self):
        """Full blocking conversion flattens the data payload and tags event 'message'."""
        blocking = CompletionAppBlockingResponse(
            task_id="task",
            data=CompletionAppBlockingResponse.Data(
                id="id",
                mode="completion",
                message_id="msg",
                answer="answer",
                metadata={"k": "v"},
                created_at=123,
            ),
        )
        result = CompletionAppGenerateResponseConverter.convert_blocking_full_response(blocking)
        assert result["event"] == "message"
        assert result["task_id"] == "task"
        assert result["message_id"] == "msg"
        assert result["answer"] == "answer"
        assert result["metadata"] == {"k": "v"}
    def test_convert_blocking_simple_response_metadata_simplified(self):
        """Simple conversion drops usage/annotation metadata and trims retriever resources."""
        metadata = {
            "retriever_resources": [
                {
                    "segment_id": "s",
                    "position": 1,
                    "document_name": "doc",
                    "score": 0.9,
                    "content": "c",
                    "summary": "sum",
                    "extra": "x",
                }
            ],
            "annotation_reply": {"a": 1},
            "usage": {"t": 2},
        }
        blocking = CompletionAppBlockingResponse(
            task_id="task",
            data=CompletionAppBlockingResponse.Data(
                id="id",
                mode="completion",
                message_id="msg",
                answer="answer",
                metadata=metadata,
                created_at=123,
            ),
        )
        result = CompletionAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
        assert "annotation_reply" not in result["metadata"]
        assert "usage" not in result["metadata"]
        assert result["metadata"]["retriever_resources"][0]["segment_id"] == "s"
        assert "extra" not in result["metadata"]["retriever_resources"][0]
    def test_convert_blocking_simple_response_metadata_not_dict(self):
        """Non-dict metadata degrades to an empty dict instead of raising."""
        # model_construct bypasses pydantic validation so the invalid metadata survives.
        data = CompletionAppBlockingResponse.Data.model_construct(
            id="id",
            mode="completion",
            message_id="msg",
            answer="answer",
            metadata="bad",
            created_at=123,
        )
        blocking = CompletionAppBlockingResponse.model_construct(task_id="task", data=data)
        result = CompletionAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
        assert result["metadata"] == {}
    def test_convert_stream_full_response(self):
        """Full stream conversion: ping passthrough, errors mapped to a code, messages kept."""
        def stream() -> Generator[AppStreamResponse, None, None]:
            yield CompletionAppStreamResponse(
                stream_response=PingStreamResponse(task_id="t"),
                message_id="m",
                created_at=1,
            )
            yield CompletionAppStreamResponse(
                stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")),
                message_id="m",
                created_at=2,
            )
            yield CompletionAppStreamResponse(
                stream_response=MessageStreamResponse(task_id="t", id="1", answer="ok"),
                message_id="m",
                created_at=3,
            )
        result = list(CompletionAppGenerateResponseConverter.convert_stream_full_response(stream()))
        assert result[0] == "ping"
        assert result[1]["event"] == "error"
        assert result[1]["code"] == "invalid_param"
        assert result[2]["event"] == "message"
    def test_convert_stream_simple_response(self):
        """Simple stream conversion simplifies message_end metadata and still maps errors."""
        def stream() -> Generator[AppStreamResponse, None, None]:
            yield CompletionAppStreamResponse(
                stream_response=PingStreamResponse(task_id="t"),
                message_id="m",
                created_at=1,
            )
            yield CompletionAppStreamResponse(
                stream_response=MessageEndStreamResponse(
                    task_id="t",
                    id="end",
                    metadata={
                        "retriever_resources": [
                            {
                                "segment_id": "s",
                                "position": 1,
                                "document_name": "doc",
                                "score": 0.9,
                                "content": "c",
                                "summary": "sum",
                            }
                        ],
                        "annotation_reply": {"a": 1},
                        "usage": {"t": 2},
                    },
                ),
                message_id="m",
                created_at=2,
            )
            yield CompletionAppStreamResponse(
                stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")),
                message_id="m",
                created_at=3,
            )
        result = list(CompletionAppGenerateResponseConverter.convert_stream_simple_response(stream()))
        assert result[0] == "ping"
        assert result[1]["event"] == "message_end"
        assert "annotation_reply" not in result[1]["metadata"]
        assert "usage" not in result[1]["metadata"]
        assert result[2]["event"] == "error"

View File

@ -0,0 +1,55 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
import core.app.apps.pipeline.pipeline_config_manager as module
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
from models.model import AppMode
def test_get_pipeline_config(mocker):
    """get_pipeline_config maps pipeline/workflow identity onto a PipelineConfig."""
    mock_pipeline = MagicMock(tenant_id="tenant", id="pipe1")
    mock_workflow = MagicMock(id="wf1")
    mocker.patch.object(
        module.WorkflowVariablesConfigManager,
        "convert_rag_pipeline_variable",
        return_value=["var1"],
    )
    # Capture the constructor kwargs on a plain namespace instead of a real PipelineConfig.
    mocker.patch.object(module, "PipelineConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
    config = PipelineConfigManager.get_pipeline_config(
        pipeline=mock_pipeline, workflow=mock_workflow, start_node_id="start"
    )
    assert config.tenant_id == "tenant"
    assert config.app_id == "pipe1"
    assert config.workflow_id == "wf1"
    assert config.app_mode == AppMode.RAG_PIPELINE
    assert config.rag_pipeline_variables == ["var1"]
def test_config_validate_filters_related_keys(mocker):
    """config_validate keeps only the keys reported as related by the sub-managers."""
    config = {
        "file_upload": {"enabled": True},
        "tts": {"enabled": True},
        "moderation": {"enabled": True},
        "extra": "drop",
    }
    # Each manager validates in turn and reports one related key; patch them in the
    # same order the production code consults them.
    for manager, related_key in (
        (module.FileUploadConfigManager, "file_upload"),
        (module.TextToSpeechConfigManager, "tts"),
        (module.SensitiveWordAvoidanceConfigManager, "moderation"),
    ):
        mocker.patch.object(
            manager,
            "validate_and_set_defaults",
            return_value=(config, [related_key]),
        )
    filtered = PipelineConfigManager.config_validate("tenant", config)
    assert set(filtered.keys()) == {"file_upload", "tts", "moderation"}

View File

@ -0,0 +1,111 @@
from collections.abc import Generator
from core.app.apps.pipeline.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ErrorStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
)
from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
def test_convert_blocking_full_and_simple_response():
    """For workflow apps, full and simple blocking conversions produce the same dict."""
    blocking = WorkflowAppBlockingResponse(
        task_id="task",
        workflow_run_id="run",
        data=WorkflowAppBlockingResponse.Data(
            id="id",
            workflow_id="wf",
            status=WorkflowExecutionStatus.SUCCEEDED,
            outputs={"k": "v"},
            error=None,
            elapsed_time=0.1,
            total_tokens=10,
            total_steps=1,
            created_at=1,
            finished_at=2,
        ),
    )
    full = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(blocking)
    simple = WorkflowAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
    assert full == simple
    assert full["workflow_run_id"] == "run"
    assert full["data"]["status"] == WorkflowExecutionStatus.SUCCEEDED
def test_convert_stream_full_response():
    """Full stream conversion passes pings through and maps errors to an error event."""
    def stream() -> Generator[AppStreamResponse, None, None]:
        yield WorkflowAppStreamResponse(
            stream_response=PingStreamResponse(task_id="t"),
            workflow_run_id="run",
        )
        yield WorkflowAppStreamResponse(
            stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")),
            workflow_run_id="run",
        )
    result = list(WorkflowAppGenerateResponseConverter.convert_stream_full_response(stream()))
    assert result[0] == "ping"
    assert result[1]["event"] == "error"
    assert result[1]["code"] == "invalid_param"
def test_convert_stream_simple_response_node_ignore_details():
    """Simple stream conversion blanks node detail payloads (inputs become None)."""
    node_start = NodeStartStreamResponse(
        task_id="t",
        workflow_run_id="run",
        data=NodeStartStreamResponse.Data(
            id="nid",
            node_id="node",
            node_type="type",
            title="Title",
            index=1,
            predecessor_node_id=None,
            inputs={"a": 1},
            inputs_truncated=False,
            created_at=1,
        ),
    )
    node_finish = NodeFinishStreamResponse(
        task_id="t",
        workflow_run_id="run",
        data=NodeFinishStreamResponse.Data(
            id="nid",
            node_id="node",
            node_type="type",
            title="Title",
            index=1,
            predecessor_node_id=None,
            inputs={"a": 1},
            inputs_truncated=False,
            process_data=None,
            process_data_truncated=False,
            outputs={"b": 2},
            outputs_truncated=False,
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            error=None,
            elapsed_time=0.1,
            execution_metadata=None,
            created_at=1,
            finished_at=2,
        ),
    )
    def stream() -> Generator[AppStreamResponse, None, None]:
        yield WorkflowAppStreamResponse(stream_response=node_start, workflow_run_id="run")
        yield WorkflowAppStreamResponse(stream_response=node_finish, workflow_run_id="run")
    result = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(stream()))
    assert result[0]["event"] == "node_started"
    assert result[0]["data"]["inputs"] is None
    assert result[1]["event"] == "node_finished"
    assert result[1]["data"]["inputs"] is None

View File

@ -0,0 +1,699 @@
import contextlib
from types import SimpleNamespace
from unittest.mock import MagicMock, PropertyMock
import pytest
import core.app.apps.pipeline.pipeline_generator as module
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import DatasourceProviderType
class FakeRagPipelineGenerateEntity(SimpleNamespace):
    """Permissive stand-in for RagPipelineGenerateEntity built on SimpleNamespace."""

    class SingleIterationRunEntity(SimpleNamespace):
        """Marker type mirroring the real single-iteration run entity."""

    class SingleLoopRunEntity(SimpleNamespace):
        """Marker type mirroring the real single-loop run entity."""

    def model_dump(self):
        """Mimic pydantic's model_dump by returning a shallow copy of the attributes."""
        return dict(vars(self))
@pytest.fixture
def generator(mocker):
    """Build a PipelineGenerator with entity classes and plugin contexts stubbed out."""
    gen = module.PipelineGenerator()
    # Swap the pydantic entity for a permissive namespace so tests can attach any attrs.
    mocker.patch.object(module, "RagPipelineGenerateEntity", FakeRagPipelineGenerateEntity)
    mocker.patch.object(module, "RagPipelineInvokeEntity", side_effect=lambda **kwargs: kwargs)
    # The context vars only need a settable surface; no real providers are loaded.
    mocker.patch.object(module.contexts, "plugin_tool_providers", SimpleNamespace(set=MagicMock()))
    mocker.patch.object(module.contexts, "plugin_tool_providers_lock", SimpleNamespace(set=MagicMock()))
    return gen
def _build_pipeline_dataset():
return SimpleNamespace(
id="ds",
name="dataset",
description="desc",
chunk_structure="chunk",
built_in_field_enabled=True,
tenant_id="tenant",
)
def _build_pipeline():
    """Return a mock pipeline whose retrieve_dataset yields the stub dataset."""
    mock_pipeline = MagicMock(tenant_id="tenant", id="pipe")
    mock_pipeline.retrieve_dataset.return_value = _build_pipeline_dataset()
    return mock_pipeline
def _build_workflow():
return MagicMock(id="wf", graph_dict={"nodes": [], "edges": []}, tenant_id="tenant")
def _build_user():
    """Return a mock account with id and session id for generator calls."""
    # NOTE: "name" is a reserved Mock constructor kwarg — it sets the mock's repr
    # name, not a plain .name attribute (see unittest.mock docs), so .name on the
    # returned mock is itself a child mock rather than the string "User".
    return MagicMock(id="user", name="User", session_id="session")
def _build_args():
    """Assemble the minimal argument mapping accepted by PipelineGenerator.generate."""
    datasource_type = DatasourceProviderType.LOCAL_FILE.value
    return {
        "inputs": {"k": "v"},
        "start_node_id": "start",
        "datasource_type": datasource_type,
        "datasource_info_list": [{"name": "file"}],
    }
def _patch_session(mocker, session):
    """Route module-level Session construction to the given stub and mock db.engine."""
    mocker.patch.object(module, "Session", return_value=session)
    # db.engine is a property on the extension class, so patch it via PropertyMock.
    mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock())
def _dummy_preserve(*args, **kwargs):
return contextlib.nullcontext()
class DummySession:
    """Context-manager stand-in for a DB session with a programmable scalar()."""

    def __init__(self):
        # Tests drive results through scalar.return_value / scalar.side_effect.
        self.scalar = MagicMock()

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # Never suppress exceptions raised inside the with-body.
        return False
def test_generate_dataset_missing(generator, mocker):
    """generate() raises when the pipeline has no associated dataset."""
    pipeline = _build_pipeline()
    pipeline.retrieve_dataset.return_value = None
    session = DummySession()
    _patch_session(mocker, session)
    with pytest.raises(ValueError):
        generator.generate(
            pipeline=pipeline,
            workflow=_build_workflow(),
            user=_build_user(),
            args=_build_args(),
            invoke_from=InvokeFrom.WEB_APP,
            streaming=False,
        )
def test_generate_debugger_calls_generate(generator, mocker):
    """Debugger invocations skip document creation and delegate straight to _generate."""
    pipeline = _build_pipeline()
    workflow = _build_workflow()
    session = DummySession()
    _patch_session(mocker, session)
    mocker.patch.object(
        generator,
        "_format_datasource_info_list",
        return_value=[{"name": "file"}],
    )
    mocker.patch.object(
        module.PipelineConfigManager,
        "get_pipeline_config",
        return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]),
    )
    mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
    mocker.patch.object(
        module.DifyCoreRepositoryFactory,
        "create_workflow_execution_repository",
        return_value=MagicMock(),
    )
    mocker.patch.object(
        module.DifyCoreRepositoryFactory,
        "create_workflow_node_execution_repository",
        return_value=MagicMock(),
    )
    mocker.patch.object(generator, "_generate", return_value={"result": "ok"})
    result = generator.generate(
        pipeline=pipeline,
        workflow=workflow,
        user=_build_user(),
        args=_build_args(),
        invoke_from=InvokeFrom.DEBUGGER,
        streaming=True,
    )
    assert result == {"result": "ok"}
def test_generate_published_pipeline_creates_documents_and_delay(generator, mocker):
    """Published-pipeline runs create one Document per datasource entry and enqueue the task."""
    pipeline = _build_pipeline()
    workflow = _build_workflow()
    session = DummySession()
    _patch_session(mocker, session)
    datasource_info_list = [{"name": "file1"}, {"name": "file2"}]
    mocker.patch.object(
        generator,
        "_format_datasource_info_list",
        return_value=datasource_info_list,
    )
    mocker.patch.object(
        module.PipelineConfigManager,
        "get_pipeline_config",
        return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]),
    )
    mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
    mocker.patch("services.dataset_service.DocumentService.get_documents_position", return_value=1)
    # Two stub documents, one per datasource entry above.
    document1 = SimpleNamespace(
        id="doc1",
        position=1,
        data_source_type=DatasourceProviderType.LOCAL_FILE,
        data_source_info="{}",
        name="file1",
        indexing_status="",
        error=None,
        enabled=True,
    )
    document2 = SimpleNamespace(
        id="doc2",
        position=2,
        data_source_type=DatasourceProviderType.LOCAL_FILE,
        data_source_info="{}",
        name="file2",
        indexing_status="",
        error=None,
        enabled=True,
    )
    mocker.patch.object(generator, "_build_document", side_effect=[document1, document2])
    mocker.patch.object(module, "DocumentPipelineExecutionLog", return_value=MagicMock())
    db_session = MagicMock()
    mocker.patch.object(module.db, "session", db_session)
    mocker.patch.object(
        module.DifyCoreRepositoryFactory,
        "create_workflow_execution_repository",
        return_value=MagicMock(),
    )
    mocker.patch.object(
        module.DifyCoreRepositoryFactory,
        "create_workflow_node_execution_repository",
        return_value=MagicMock(),
    )
    task_proxy = MagicMock()
    mocker.patch.object(module, "RagPipelineTaskProxy", return_value=task_proxy)
    result = generator.generate(
        pipeline=pipeline,
        workflow=workflow,
        user=_build_user(),
        args=_build_args(),
        invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
        streaming=False,
    )
    assert result["batch"]
    assert len(result["documents"]) == 2
    task_proxy.delay.assert_called_once()
def test_generate_is_retry_calls_generate(generator, mocker):
    """Retries of a published pipeline skip document creation and call _generate directly."""
    pipeline = _build_pipeline()
    workflow = _build_workflow()
    session = DummySession()
    _patch_session(mocker, session)
    mocker.patch.object(
        generator,
        "_format_datasource_info_list",
        return_value=[{"name": "file"}],
    )
    mocker.patch.object(
        module.PipelineConfigManager,
        "get_pipeline_config",
        return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]),
    )
    mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
    mocker.patch.object(
        module.DifyCoreRepositoryFactory,
        "create_workflow_execution_repository",
        return_value=MagicMock(),
    )
    mocker.patch.object(
        module.DifyCoreRepositoryFactory,
        "create_workflow_node_execution_repository",
        return_value=MagicMock(),
    )
    mocker.patch.object(generator, "_generate", return_value={"result": "ok"})
    result = generator.generate(
        pipeline=pipeline,
        workflow=workflow,
        user=_build_user(),
        args=_build_args(),
        invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
        streaming=True,
        is_retry=True,
    )
    assert result == {"result": "ok"}
def test_generate_worker_handles_errors(generator, mocker):
    """Exceptions raised by the runner are published as errors on the queue manager."""
    flask_app = MagicMock()
    flask_app.app_context.return_value = contextlib.nullcontext()
    mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
    mocker.patch.object(module.db, "session", MagicMock(close=MagicMock()))
    mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock())
    application_generate_entity = FakeRagPipelineGenerateEntity(
        app_config=SimpleNamespace(tenant_id="tenant", app_id="pipe", workflow_id="wf"),
        invoke_from=InvokeFrom.WEB_APP,
        user_id="user",
    )
    session = DummySession()
    # Two sequential scalar lookups — presumably workflow then end user; confirm
    # against _generate_worker's query order.
    session.scalar.side_effect = [MagicMock(), MagicMock(session_id="session")]
    _patch_session(mocker, session)
    runner_instance = MagicMock()
    runner_instance.run.side_effect = ValueError("bad")
    mocker.patch.object(module, "PipelineRunner", return_value=runner_instance)
    queue_manager = MagicMock()
    generator._generate_worker(
        flask_app=flask_app,
        application_generate_entity=application_generate_entity,
        queue_manager=queue_manager,
        context=contextlib.nullcontext(),
        variable_loader=MagicMock(),
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
    )
    queue_manager.publish_error.assert_called_once()
def test_generate_worker_sets_system_user_id_for_external_call(generator, mocker):
    """Web-app invocations pass the end user's session id to the runner as system_user_id."""
    flask_app = MagicMock()
    flask_app.app_context.return_value = contextlib.nullcontext()
    mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
    mocker.patch.object(module.db, "session", MagicMock(close=MagicMock()))
    mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock())
    application_generate_entity = FakeRagPipelineGenerateEntity(
        app_config=SimpleNamespace(tenant_id="tenant", app_id="pipe", workflow_id="wf"),
        invoke_from=InvokeFrom.WEB_APP,
        user_id="user",
    )
    session = DummySession()
    # Second scalar result is the end user whose session_id must be forwarded.
    session.scalar.side_effect = [MagicMock(), MagicMock(session_id="session")]
    _patch_session(mocker, session)
    runner_instance = MagicMock()
    mocker.patch.object(module, "PipelineRunner", return_value=runner_instance)
    generator._generate_worker(
        flask_app=flask_app,
        application_generate_entity=application_generate_entity,
        queue_manager=MagicMock(),
        context=contextlib.nullcontext(),
        variable_loader=MagicMock(),
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
    )
    assert module.PipelineRunner.call_args.kwargs["system_user_id"] == "session"
def test_generate_raises_when_workflow_not_found(generator, mocker):
    """_generate raises when the workflow id cannot be resolved from the database."""
    flask_app = MagicMock()
    mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
    session = MagicMock()
    session.query.return_value.where.return_value.first.return_value = None
    mocker.patch.object(module.db, "session", session)
    with pytest.raises(ValueError):
        generator._generate(
            flask_app=flask_app,
            context=contextlib.nullcontext(),
            pipeline=_build_pipeline(),
            workflow_id="wf",
            user=_build_user(),
            application_generate_entity=FakeRagPipelineGenerateEntity(
                task_id="t",
                app_config=SimpleNamespace(app_id="pipe"),
                user_id="user",
                invoke_from=InvokeFrom.DEBUGGER,
            ),
            invoke_from=InvokeFrom.DEBUGGER,
            workflow_execution_repository=MagicMock(),
            workflow_node_execution_repository=MagicMock(),
            streaming=True,
        )
def test_generate_success_returns_converted(generator, mocker):
    """_generate wires the queue manager and worker thread, then returns the converted response."""
    flask_app = MagicMock()
    mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
    workflow = MagicMock(id="wf", tenant_id="tenant", app_id="pipe", graph_dict={})
    session = MagicMock()
    session.query.return_value.where.return_value.first.return_value = workflow
    mocker.patch.object(module.db, "session", session)
    queue_manager = MagicMock()
    mocker.patch.object(module, "PipelineQueueManager", return_value=queue_manager)
    worker_thread = MagicMock()
    # The worker thread is mocked so no background generation actually runs.
    mocker.patch.object(module.threading, "Thread", return_value=worker_thread)
    mocker.patch.object(generator, "_get_draft_var_saver_factory", return_value=MagicMock())
    mocker.patch.object(generator, "_handle_response", return_value="response")
    mocker.patch.object(module.WorkflowAppGenerateResponseConverter, "convert", return_value="converted")
    result = generator._generate(
        flask_app=flask_app,
        context=contextlib.nullcontext(),
        pipeline=_build_pipeline(),
        workflow_id="wf",
        user=_build_user(),
        application_generate_entity=FakeRagPipelineGenerateEntity(
            task_id="t",
            app_config=SimpleNamespace(app_id="pipe"),
            user_id="user",
            invoke_from=InvokeFrom.DEBUGGER,
        ),
        invoke_from=InvokeFrom.DEBUGGER,
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
        streaming=True,
    )
    assert result == "converted"
def test_single_iteration_generate_validates_inputs(generator, mocker):
    """An empty node id or missing inputs is rejected with ValueError."""
    with pytest.raises(ValueError):
        generator.single_iteration_generate(_build_pipeline(), _build_workflow(), "", _build_user(), {})
    with pytest.raises(ValueError):
        generator.single_iteration_generate(
            _build_pipeline(), _build_workflow(), "node", _build_user(), {"inputs": None}
        )
def test_single_iteration_generate_dataset_required(generator, mocker):
    """Single-iteration runs also require the pipeline to have a dataset."""
    pipeline = _build_pipeline()
    pipeline.retrieve_dataset.return_value = None
    session = DummySession()
    _patch_session(mocker, session)
    with pytest.raises(ValueError):
        generator.single_iteration_generate(
            pipeline,
            _build_workflow(),
            "node",
            _build_user(),
            {"inputs": {"a": 1}},
        )
def test_single_iteration_generate_success(generator, mocker):
    """A valid single-iteration request builds repositories and delegates to _generate."""
    pipeline = _build_pipeline()
    session = DummySession()
    _patch_session(mocker, session)
    mocker.patch.object(
        module.PipelineConfigManager,
        "get_pipeline_config",
        return_value=SimpleNamespace(app_id="pipe", tenant_id="tenant"),
    )
    mocker.patch.object(
        module.DifyCoreRepositoryFactory,
        "create_workflow_execution_repository",
        return_value=MagicMock(),
    )
    mocker.patch.object(
        module.DifyCoreRepositoryFactory,
        "create_workflow_node_execution_repository",
        return_value=MagicMock(),
    )
    mocker.patch.object(module.db, "session", MagicMock(return_value=MagicMock()))
    mocker.patch.object(module, "WorkflowDraftVariableService", return_value=MagicMock())
    mocker.patch.object(module, "DraftVarLoader", return_value=MagicMock())
    mocker.patch.object(generator, "_generate", return_value={"ok": True})
    result = generator.single_iteration_generate(
        pipeline,
        _build_workflow(),
        "node",
        _build_user(),
        {"inputs": {"a": 1}},
        streaming=False,
    )
    assert result == {"ok": True}
def test_single_loop_generate_success(generator, mocker):
    """A valid single-loop request builds repositories and delegates to _generate."""
    pipeline = _build_pipeline()
    session = DummySession()
    _patch_session(mocker, session)
    mocker.patch.object(
        module.PipelineConfigManager,
        "get_pipeline_config",
        return_value=SimpleNamespace(app_id="pipe", tenant_id="tenant"),
    )
    mocker.patch.object(
        module.DifyCoreRepositoryFactory,
        "create_workflow_execution_repository",
        return_value=MagicMock(),
    )
    mocker.patch.object(
        module.DifyCoreRepositoryFactory,
        "create_workflow_node_execution_repository",
        return_value=MagicMock(),
    )
    mocker.patch.object(module.db, "session", MagicMock(return_value=MagicMock()))
    mocker.patch.object(module, "WorkflowDraftVariableService", return_value=MagicMock())
    mocker.patch.object(module, "DraftVarLoader", return_value=MagicMock())
    mocker.patch.object(generator, "_generate", return_value={"ok": True})
    result = generator.single_loop_generate(
        pipeline,
        _build_workflow(),
        "node",
        _build_user(),
        {"inputs": {"a": 1}},
        streaming=False,
    )
    assert result == {"ok": True}
def test_handle_response_value_error_triggers_generate_task_stopped(generator, mocker):
    """A closed-stream ValueError from the task pipeline surfaces as GenerateTaskStoppedError."""
    pipeline = _build_pipeline()
    workflow = _build_workflow()
    app_entity = FakeRagPipelineGenerateEntity(task_id="t")
    task_pipeline = MagicMock()
    # The exact message matters: _handle_response presumably recognizes this
    # specific ValueError text — confirm against the implementation.
    task_pipeline.process.side_effect = ValueError("I/O operation on closed file.")
    mocker.patch.object(module, "WorkflowAppGenerateTaskPipeline", return_value=task_pipeline)
    with pytest.raises(GenerateTaskStoppedError):
        generator._handle_response(
            application_generate_entity=app_entity,
            workflow=workflow,
            queue_manager=MagicMock(),
            user=_build_user(),
            draft_var_saver_factory=MagicMock(),
            stream=False,
        )
def test_build_document_sets_metadata_for_builtin_fields(generator, mocker):
    """With built-in fields enabled, _build_document names the doc and fills doc_metadata."""
    class DummyDocument(SimpleNamespace):
        """Capture Document constructor kwargs as plain attributes."""
    mocker.patch.object(module, "Document", side_effect=lambda **kwargs: DummyDocument(**kwargs))
    document = generator._build_document(
        tenant_id="tenant",
        dataset_id="ds",
        built_in_field_enabled=True,
        datasource_type=DatasourceProviderType.LOCAL_FILE,
        datasource_info={"name": "file"},
        created_from="rag-pipeline",
        position=1,
        account=_build_user(),
        batch="batch",
        document_form="text",
    )
    assert document.name == "file"
    assert document.doc_metadata
def test_build_document_invalid_datasource_type(generator):
    """An unrecognized datasource type is rejected with ValueError."""
    with pytest.raises(ValueError):
        generator._build_document(
            tenant_id="tenant",
            dataset_id="ds",
            built_in_field_enabled=False,
            datasource_type="invalid",
            datasource_info={},
            created_from="rag-pipeline",
            position=1,
            account=_build_user(),
            batch="batch",
            document_form="text",
        )
def test_format_datasource_info_list_non_online_drive(generator):
    """Non-online-drive datasource info lists pass through unchanged."""
    result = generator._format_datasource_info_list(
        DatasourceProviderType.LOCAL_FILE,
        [{"name": "file"}],
        _build_pipeline(),
        _build_workflow(),
        "start",
        _build_user(),
    )
    assert result == [{"name": "file"}]
def test_format_datasource_info_list_missing_node_data(generator):
    """A start node without data in the workflow graph raises ValueError."""
    workflow = MagicMock(graph_dict={"nodes": []})
    with pytest.raises(ValueError):
        generator._format_datasource_info_list(
            DatasourceProviderType.ONLINE_DRIVE,
            [],
            _build_pipeline(),
            workflow,
            "start",
            _build_user(),
        )
def test_format_datasource_info_list_online_drive_folder(generator, mocker):
    """Folder entries for an online-drive datasource are expanded via _get_files_in_folder.

    The workflow graph provides the start node whose `data` dict supplies
    plugin/provider/credential identifiers; the datasource runtime and the
    credentials service are both stubbed so no external calls happen.
    """
    # Minimal graph: one start node carrying the datasource configuration.
    workflow = MagicMock(
        graph_dict={
            "nodes": [
                {
                    "id": "start",
                    "data": {
                        "plugin_id": "p",
                        "provider_name": "provider",
                        "datasource_name": "drive",
                        "credential_id": "cred",
                    },
                }
            ]
        }
    )
    runtime = MagicMock()
    # credentials=None forces the code under test to fetch credentials itself.
    runtime.runtime = SimpleNamespace(credentials=None)
    runtime.datasource_provider_type.return_value = DatasourceProviderType.ONLINE_DRIVE
    mocker.patch(
        "core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime",
        return_value=runtime,
    )
    mocker.patch.object(module.DatasourceProviderService, "get_datasource_credentials", return_value={"k": "v"})
    # args[4] is presumably the `all_files` accumulator passed positionally —
    # the stub appends one fake file instead of browsing the drive.
    mocker.patch.object(
        generator,
        "_get_files_in_folder",
        side_effect=lambda *args, **kwargs: args[4].append({"id": "f"}),
    )
    result = generator._format_datasource_info_list(
        DatasourceProviderType.ONLINE_DRIVE,
        [{"id": "folder", "type": "folder", "name": "Folder", "bucket": "b"}],
        _build_pipeline(),
        workflow,
        "start",
        _build_user(),
    )
    # The folder entry is replaced by the files collected from inside it.
    assert result == [{"id": "f"}]
def test_get_files_in_folder_recurses_and_collects(generator):
    """_get_files_in_folder follows pagination and recurses into subfolders.

    The fake runtime serves: page 1 of "root" (file f1 + folder fd, truncated),
    a second page for the {"page": 2} continuation, and the contents of folder
    "fd" (file f2). Both files must end up in `all_files`.
    """

    class File:
        # Mimics a remote drive entry; `type` is "file" or "folder".
        def __init__(self, id, name, type):
            self.id = id
            self.name = name
            self.type = type

    class FilesPage:
        # One page of a paginated listing; next_page_parameters drives paging.
        def __init__(self, files, is_truncated=False, next_page_parameters=None):
            self.files = files
            self.is_truncated = is_truncated
            self.next_page_parameters = next_page_parameters

    class Result:
        def __init__(self, result):
            self.result = result

    class Runtime:
        def __init__(self):
            self.calls = []

        def datasource_provider_type(self):
            return DatasourceProviderType.ONLINE_DRIVE

        def online_drive_browse_files(self, user_id, request, provider_type):
            self.calls.append(request.next_page_parameters)
            # NOTE: this branch must stay first — without it, recursing into
            # folder "fd" (next_page_parameters=None) would hit the branch
            # below and return the f1/fd page again, recursing forever.
            if request.prefix == "fd":
                return iter([Result([FilesPage([File("f2", "file2", "file")], False, None)])])
            if request.next_page_parameters is None:
                return iter(
                    [
                        Result(
                            [FilesPage([File("f1", "file", "file"), File("fd", "folder", "folder")], True, {"page": 2})]
                        )
                    ]
                )
            # Continuation page for the truncated "root" listing.
            return iter([Result([FilesPage([File("f2", "file2", "file")], False, None)])])

    runtime = Runtime()
    all_files = []
    generator._get_files_in_folder(
        datasource_runtime=runtime,
        prefix="root",
        bucket="b",
        user_id="user",
        all_files=all_files,
        datasource_info={},
    )
    assert {f["id"] for f in all_files} == {"f1", "f2"}

View File

@ -0,0 +1,57 @@
import pytest
import core.app.apps.pipeline.pipeline_queue_manager as module
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
QueueErrorEvent,
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowSucceededEvent,
)
from dify_graph.model_runtime.entities.llm_entities import LLMResult
def test_publish_sets_stop_listen_and_raises_on_stopped(mocker):
    """Publishing while already stopped raises GenerateTaskStoppedError and still stops listening."""
    queue_manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag")
    queue_manager._q = mocker.MagicMock()
    queue_manager.stop_listen = mocker.MagicMock()
    queue_manager._is_stopped = mocker.MagicMock(return_value=True)
    stop_event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)
    with pytest.raises(GenerateTaskStoppedError):
        queue_manager._publish(stop_event, PublishFrom.APPLICATION_MANAGER)
    queue_manager.stop_listen.assert_called_once()
def test_publish_stop_events_trigger_stop_listen(mocker):
    """Each terminal queue event type triggers stop_listen exactly once."""
    queue_manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag")
    queue_manager._q = mocker.MagicMock()
    queue_manager.stop_listen = mocker.MagicMock()
    queue_manager._is_stopped = mocker.MagicMock(return_value=False)
    terminal_events = (
        QueueErrorEvent(error=ValueError("bad")),
        QueueMessageEndEvent(llm_result=LLMResult.model_construct()),
        QueueWorkflowSucceededEvent(),
        QueueWorkflowFailedEvent(error="failed", exceptions_count=1),
        QueueWorkflowPartialSuccessEvent(exceptions_count=1),
    )
    for terminal_event in terminal_events:
        queue_manager.stop_listen.reset_mock()
        queue_manager._publish(terminal_event, PublishFrom.TASK_PIPELINE)
        queue_manager.stop_listen.assert_called_once()
def test_publish_non_stop_event_no_stop_listen(mocker):
    """Ordinary queue events are forwarded without ending the listen loop."""
    queue_manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag")
    queue_manager._q = mocker.MagicMock()
    queue_manager.stop_listen = mocker.MagicMock()
    queue_manager._is_stopped = mocker.MagicMock(return_value=False)
    ordinary_event = mocker.MagicMock(spec=module.AppQueueEvent)
    queue_manager._publish(ordinary_event, PublishFrom.TASK_PIPELINE)
    queue_manager.stop_listen.assert_not_called()

View File

@ -0,0 +1,297 @@
"""
Unit tests for PipelineRunner behavior.
Asserts correct event handling, error propagation, and user invocation logic.
Primary collaborators: PipelineRunner, InvokeFrom, GraphRunFailedEvent, UserFrom, and mocked dependencies.
Cross-references: core.app.apps.pipeline.pipeline_runner, core.app.entities.app_invoke_entities.
"""
"""Unit tests for PipelineRunner behavior.
This module validates core control-flow outcomes for
``core.app.apps.pipeline.pipeline_runner``: app/workflow lookup, graph
initialization guards, invoke-source to user-source resolution, and failed-run
event handling. Invariants asserted here include strict graph-config
validation, correct ``InvokeFrom`` to ``UserFrom`` mapping, and publishing
error paths driven by ``GraphRunFailedEvent`` through mocked collaborators.
Primary collaborators include ``PipelineRunner``,
``core.app.entities.app_invoke_entities.InvokeFrom``, ``GraphRunFailedEvent``,
``UserFrom``, and patched DB/runtime dependencies used by the runner.
"""
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
import core.app.apps.pipeline.pipeline_runner as module
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from dify_graph.graph_events import GraphRunFailedEvent
def _build_app_generate_entity() -> SimpleNamespace:
    """Assemble a minimal stand-in for a pipeline generate entity used by the tests."""
    return SimpleNamespace(
        app_config=SimpleNamespace(app_id="pipe", workflow_id="wf", tenant_id="tenant"),
        invoke_from=InvokeFrom.WEB_APP,
        user_id="user",
        trace_manager=MagicMock(),
        inputs={"input1": "v1"},
        files=[],
        workflow_execution_id="run",
        document_id="doc",
        original_document_id=None,
        batch="batch",
        dataset_id="ds",
        datasource_type="local_file",
        datasource_info={"name": "file"},
        start_node_id="start",
        call_depth=0,
        single_iteration_run=None,
        single_loop_run=None,
    )
@pytest.fixture
def runner():
    """Provide a PipelineRunner wired entirely with mock collaborators."""
    return PipelineRunner(
        application_generate_entity=_build_app_generate_entity(),
        queue_manager=MagicMock(),
        variable_loader=MagicMock(),
        workflow=MagicMock(),
        system_user_id="sys",
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
    )
def test_get_app_id(runner):
    """The runner reports the app id taken from its app config."""
    assert runner._get_app_id() == "pipe"
def test_get_workflow_returns_workflow(mocker, runner):
    """get_workflow returns exactly the row produced by the DB query chain."""
    pipeline = MagicMock(tenant_id="tenant", id="pipe")
    expected_workflow = MagicMock(id="wf")
    workflow_query = MagicMock()
    workflow_query.where.return_value.first.return_value = expected_workflow
    mocker.patch.object(module.db, "session", MagicMock(query=MagicMock(return_value=workflow_query)))
    assert runner.get_workflow(pipeline=pipeline, workflow_id="wf") == expected_workflow
def test_init_rag_pipeline_graph_invalid_config(mocker, runner):
    """Graph configs that are empty or have non-list nodes/edges are rejected."""
    workflow = MagicMock(id="wf", tenant_id="tenant", graph_dict={})
    with pytest.raises(ValueError):
        runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock())
    # Both malformed shapes must fail the same way.
    for bad_graph in ({"nodes": "bad", "edges": []}, {"nodes": [], "edges": "bad"}):
        workflow.graph_dict = bad_graph
        with pytest.raises(ValueError):
            runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock())
def test_init_rag_pipeline_graph_not_found(mocker, runner):
    """A ValueError is raised when Graph.init yields no graph instance."""
    workflow = MagicMock(id="wf", tenant_id="tenant", graph_dict={"nodes": [], "edges": []})
    mocker.patch.object(module.Graph, "init", return_value=None)
    with pytest.raises(ValueError):
        runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock())
def test_update_document_status_on_failure(mocker, runner):
    """A failed graph run marks the document as errored and commits the session."""
    document = MagicMock()
    document_query = MagicMock()
    document_query.where.return_value.first.return_value = document
    db_session = MagicMock()
    db_session.query.return_value = document_query
    mocker.patch.object(module.db, "session", db_session)
    failure_event = GraphRunFailedEvent(error="boom")
    runner._update_document_status(failure_event, document_id="doc", dataset_id="ds")
    assert document.indexing_status == "error"
    assert document.error == "boom"
    db_session.commit.assert_called_once()
def test_run_pipeline_not_found(mocker):
    """run() fails fast with ValueError when the pipeline row cannot be located."""
    app_generate_entity = _build_app_generate_entity()
    app_generate_entity.invoke_from = InvokeFrom.WEB_APP
    app_generate_entity.single_iteration_run = None
    app_generate_entity.single_loop_run = None
    missing_query = MagicMock()
    missing_query.where.return_value.first.return_value = None
    db_session = MagicMock()
    db_session.query.return_value = missing_query
    mocker.patch.object(module.db, "session", db_session)
    pipeline_runner = PipelineRunner(
        application_generate_entity=app_generate_entity,
        queue_manager=MagicMock(),
        variable_loader=MagicMock(),
        workflow=MagicMock(),
        system_user_id="sys",
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
    )
    with pytest.raises(ValueError):
        pipeline_runner.run()
def test_run_workflow_not_initialized(mocker):
    """run() raises ValueError when the pipeline exists but has no workflow."""
    app_generate_entity = _build_app_generate_entity()
    found_pipeline = MagicMock(id="pipe")
    pipeline_query = MagicMock()
    pipeline_query.where.return_value.first.return_value = found_pipeline
    db_session = MagicMock()
    db_session.query.return_value = pipeline_query
    mocker.patch.object(module.db, "session", db_session)
    pipeline_runner = PipelineRunner(
        application_generate_entity=app_generate_entity,
        queue_manager=MagicMock(),
        variable_loader=MagicMock(),
        workflow=MagicMock(),
        system_user_id="sys",
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
    )
    pipeline_runner.get_workflow = MagicMock(return_value=None)
    with pytest.raises(ValueError):
        pipeline_runner.run()
def test_run_single_iteration_path(mocker):
    """A single_iteration_run request routes through _prepare_single_node_execution.

    All DB access and runner internals are mocked; the test asserts only that
    the single-node preparation hook and event handling are exercised.
    """
    app_generate_entity = _build_app_generate_entity()
    app_generate_entity.single_iteration_run = MagicMock()
    pipeline = MagicMock(id="pipe")
    query_pipeline = MagicMock()
    query_pipeline.where.return_value.first.return_value = pipeline
    query_end_user = MagicMock()
    query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess")
    session = MagicMock()
    # Ordering matters: the first session.query(...) call inside run() gets the
    # end-user query, the second gets the pipeline query.
    session.query.side_effect = [query_end_user, query_pipeline]
    mocker.patch.object(module.db, "session", session)
    runner = PipelineRunner(
        application_generate_entity=app_generate_entity,
        queue_manager=MagicMock(),
        variable_loader=MagicMock(),
        workflow=MagicMock(),
        system_user_id="sys",
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
    )
    runner._resolve_user_from = MagicMock(return_value=UserFrom.ACCOUNT)
    runner.get_workflow = MagicMock(
        return_value=MagicMock(
            id="wf",
            tenant_id="tenant",
            app_id="pipe",
            graph_dict={},
            type="rag-pipeline",
            version="v1",
        )
    )
    # Stub out the heavy collaborators so run() only drives control flow.
    runner._prepare_single_node_execution = MagicMock(return_value=("graph", "pool", "state"))
    runner._update_document_status = MagicMock()
    runner._handle_event = MagicMock()
    workflow_entry = MagicMock()
    workflow_entry.graph_engine = MagicMock()
    # One fake event from the workflow run, so _handle_event must fire.
    workflow_entry.run.return_value = [MagicMock()]
    mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry)
    mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock())
    runner.run()
    runner._prepare_single_node_execution.assert_called_once()
    runner._handle_event.assert_called()
def test_run_normal_path_builds_graph(mocker):
    """The normal (non single-node) run path builds the RAG pipeline graph once.

    DB lookups, variable-pool construction, and the workflow entry are all
    mocked; the assertion pins that _init_rag_pipeline_graph is the chosen
    graph-construction path.
    """
    app_generate_entity = _build_app_generate_entity()
    pipeline = MagicMock(id="pipe")
    query_pipeline = MagicMock()
    query_pipeline.where.return_value.first.return_value = pipeline
    query_end_user = MagicMock()
    query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess")
    session = MagicMock()
    # First session.query(...) call returns the end-user query, second the
    # pipeline query — run() must issue them in that order.
    session.query.side_effect = [query_end_user, query_pipeline]
    mocker.patch.object(module.db, "session", session)
    workflow = MagicMock(
        id="wf",
        tenant_id="tenant",
        app_id="pipe",
        graph_dict={"nodes": [], "edges": []},
        environment_variables=[],
        rag_pipeline_variables=[{"variable": "input1", "belong_to_node_id": "start"}],
        type="rag-pipeline",
        version="v1",
    )
    runner = PipelineRunner(
        application_generate_entity=app_generate_entity,
        queue_manager=MagicMock(),
        variable_loader=MagicMock(),
        workflow=workflow,
        system_user_id="sys",
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
    )
    runner._resolve_user_from = MagicMock(return_value=UserFrom.ACCOUNT)
    runner.get_workflow = MagicMock(return_value=workflow)
    runner._init_rag_pipeline_graph = MagicMock(return_value="graph")
    runner._update_document_status = MagicMock()
    runner._handle_event = MagicMock()
    # The validated pipeline variable matches the entity's "input1" input on
    # the "start" node, so the input wiring path is exercised.
    mocker.patch.object(
        module.RAGPipelineVariable,
        "model_validate",
        return_value=SimpleNamespace(belong_to_node_id="start", variable="input1"),
    )
    mocker.patch.object(module, "RAGPipelineVariableInput", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
    mocker.patch.object(module, "VariablePool", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
    workflow_entry = MagicMock()
    workflow_entry.graph_engine = MagicMock()
    workflow_entry.run.return_value = []
    mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry)
    mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock())
    runner.run()
    runner._init_rag_pipeline_graph.assert_called_once()

View File

@ -1,3 +1,5 @@
from unittest.mock import MagicMock
import pytest
from core.app.apps.base_app_generator import BaseAppGenerator
@ -366,3 +368,132 @@ def test_validate_inputs_optional_file_with_empty_string_ignores_default():
)
assert result is None
class TestBaseAppGeneratorExtras:
    """Additional coverage for BaseAppGenerator input preparation and streaming helpers."""

    def test_prepare_user_inputs_converts_files_and_lists(self, monkeypatch):
        """FILE and FILE_LIST variables are built via file_factory; JSON passes through."""
        base_app_generator = BaseAppGenerator()
        variables = [
            VariableEntity(
                variable="file",
                label="file",
                type=VariableEntityType.FILE,
                required=False,
                allowed_file_types=[],
                allowed_file_extensions=[],
                allowed_file_upload_methods=[],
            ),
            VariableEntity(
                variable="file_list",
                label="file_list",
                type=VariableEntityType.FILE_LIST,
                required=False,
                allowed_file_types=[],
                allowed_file_extensions=[],
                allowed_file_upload_methods=[],
            ),
            VariableEntity(
                variable="json",
                label="json",
                type=VariableEntityType.JSON_OBJECT,
                required=False,
            ),
        ]
        # Replace the file factory so no storage or upload config is needed.
        monkeypatch.setattr(
            "core.app.apps.base_app_generator.file_factory.build_from_mapping",
            lambda mapping, tenant_id, config, strict_type_validation=False: "file-object",
        )
        monkeypatch.setattr(
            "core.app.apps.base_app_generator.file_factory.build_from_mappings",
            lambda mappings, tenant_id, config: ["file-1", "file-2"],
        )
        user_inputs = {
            "file": {"id": "file-id"},
            "file_list": [{"id": "file-1"}, {"id": "file-2"}],
            "json": {"key": "value"},
        }
        prepared = base_app_generator._prepare_user_inputs(
            user_inputs=user_inputs,
            variables=variables,
            tenant_id="tenant-id",
        )
        assert prepared["file"] == "file-object"
        assert prepared["file_list"] == ["file-1", "file-2"]
        assert prepared["json"] == {"key": "value"}

    def test_prepare_user_inputs_rejects_invalid_dict_inputs(self):
        """A dict supplied for a TEXT_INPUT variable is rejected."""
        base_app_generator = BaseAppGenerator()
        variables = [
            VariableEntity(
                variable="text",
                label="text",
                type=VariableEntityType.TEXT_INPUT,
                required=False,
            )
        ]
        with pytest.raises(ValueError, match="must be a string"):
            base_app_generator._prepare_user_inputs(
                user_inputs={"text": {"unexpected": "dict"}},
                variables=variables,
                tenant_id="tenant-id",
            )

    def test_prepare_user_inputs_rejects_invalid_list_inputs(self):
        """A list supplied for a TEXT_INPUT variable is rejected."""
        base_app_generator = BaseAppGenerator()
        variables = [
            VariableEntity(
                variable="text",
                label="text",
                type=VariableEntityType.TEXT_INPUT,
                required=False,
            )
        ]
        with pytest.raises(ValueError, match="must be a string"):
            base_app_generator._prepare_user_inputs(
                user_inputs={"text": [{"unexpected": "dict"}]},
                variables=variables,
                tenant_id="tenant-id",
            )

    def test_convert_to_event_stream(self):
        """Mappings pass through unchanged; generators become SSE-formatted strings."""
        base_app_generator = BaseAppGenerator()
        assert base_app_generator.convert_to_event_stream({"ok": True}) == {"ok": True}

        def _gen():
            yield {"delta": "hi"}
            yield "ping"

        converted = list(base_app_generator.convert_to_event_stream(_gen()))
        # Dict chunks become "data: ..." SSE frames; bare strings become events.
        assert converted[0].startswith("data: ")
        assert "\n\n" in converted[0]
        assert converted[1] == "event: ping\n\n"

    def test_get_draft_var_saver_factory_debugger(self):
        """DEBUGGER invocations yield a usable draft-variable saver factory."""
        from core.app.entities.app_invoke_entities import InvokeFrom
        from dify_graph.enums import NodeType
        from models import Account

        base_app_generator = BaseAppGenerator()
        account = Account(name="Tester", email="tester@example.com")
        account.id = "account-id"
        account.tenant_id = "tenant-id"
        factory = base_app_generator._get_draft_var_saver_factory(InvokeFrom.DEBUGGER, account)
        saver = factory(
            session=MagicMock(),
            app_id="app-id",
            node_id="node-id",
            node_type=NodeType.START,
            node_execution_id="node-exec-id",
        )
        assert saver is not None

View File

@ -0,0 +1,61 @@
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueErrorEvent
class DummyQueueManager(AppQueueManager):
    """Queue-manager test double that records published events in memory."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Each entry is an (event, publish_source) pair, in publish order.
        self.published = []

    def _publish(self, event, pub_from):
        record = (event, pub_from)
        self.published.append(record)
class TestBaseAppQueueManager:
    """Behavioral checks for AppQueueManager helpers around redis-backed state."""

    def test_init_requires_user_id(self):
        """Constructing a manager without a user id is rejected."""
        with pytest.raises(ValueError):
            DummyQueueManager(task_id="t1", user_id="", invoke_from=InvokeFrom.SERVICE_API)

    def test_publish_error_records_event(self):
        """publish_error wraps the exception in a QueueErrorEvent."""
        with patch("core.app.apps.base_app_queue_manager.redis_client") as redis_stub:
            redis_stub.setex.return_value = True
            queue_manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API)
            queue_manager.publish_error(ValueError("boom"), PublishFrom.TASK_PIPELINE)
            first_event = queue_manager.published[0][0]
            assert isinstance(first_event, QueueErrorEvent)

    def test_set_stop_flag_checks_user(self):
        """The stop flag is written when the cached task user matches."""
        with patch("core.app.apps.base_app_queue_manager.redis_client") as redis_stub:
            redis_stub.get.return_value = b"end-user-u1"
            AppQueueManager.set_stop_flag(task_id="t1", invoke_from=InvokeFrom.SERVICE_API, user_id="u1")
            redis_stub.setex.assert_called_once()

    def test_set_stop_flag_no_user_check(self):
        """The unchecked variant always writes the stop flag."""
        with patch("core.app.apps.base_app_queue_manager.redis_client") as redis_stub:
            AppQueueManager.set_stop_flag_no_user_check(task_id="t1")
            redis_stub.setex.assert_called_once()

    def test_is_stopped_reads_cache(self):
        """_is_stopped reflects the cached stop flag."""
        with patch("core.app.apps.base_app_queue_manager.redis_client") as redis_stub:
            redis_stub.setex.return_value = True
            redis_stub.get.return_value = b"1"
            queue_manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API)
            assert queue_manager._is_stopped() is True

    def test_check_for_sqlalchemy_models_raises(self):
        """Objects carrying SQLAlchemy instance state must not be queued."""
        with patch("core.app.apps.base_app_queue_manager.redis_client") as redis_stub:
            redis_stub.setex.return_value = True
            queue_manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API)
            orm_like = SimpleNamespace(_sa_instance_state=True)
            with pytest.raises(TypeError):
                queue_manager._check_for_sqlalchemy_models(orm_like)

View File

@ -0,0 +1,442 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.app.app_config.entities import (
AdvancedChatMessageEntity,
AdvancedChatPromptTemplateEntity,
AdvancedCompletionPromptTemplateEntity,
PromptTemplateEntity,
)
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessageRole,
TextPromptMessageContent,
)
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError
from models.model import AppMode
class _DummyParameterRule:
def __init__(self, name: str, use_template: str | None = None) -> None:
self.name = name
self.use_template = use_template
class _QueueRecorder:
def __init__(self) -> None:
self.events: list[object] = []
def publish(self, event, pub_from):
_ = pub_from
self.events.append(event)
class TestAppRunner:
def test_recalc_llm_max_tokens_updates_parameters(self, monkeypatch):
runner = AppRunner()
model_schema = SimpleNamespace(
model_properties={ModelPropertyKey.CONTEXT_SIZE: 100},
parameter_rules=[_DummyParameterRule("max_tokens")],
)
model_config = SimpleNamespace(
provider_model_bundle=object(),
model="mock",
model_schema=model_schema,
parameters={"max_tokens": 30},
)
monkeypatch.setattr(
"core.app.apps.base_app_runner.ModelInstance",
lambda provider_model_bundle, model: SimpleNamespace(get_llm_num_tokens=lambda messages: 80),
)
runner.recalc_llm_max_tokens(model_config, prompt_messages=[AssistantPromptMessage(content="hi")])
assert model_config.parameters["max_tokens"] == 20
def test_recalc_llm_max_tokens_returns_minus_one_when_no_context(self, monkeypatch):
runner = AppRunner()
model_schema = SimpleNamespace(
model_properties={},
parameter_rules=[_DummyParameterRule("max_tokens")],
)
model_config = SimpleNamespace(
provider_model_bundle=object(),
model="mock",
model_schema=model_schema,
parameters={"max_tokens": 30},
)
monkeypatch.setattr(
"core.app.apps.base_app_runner.ModelInstance",
lambda provider_model_bundle, model: SimpleNamespace(get_llm_num_tokens=lambda messages: 10),
)
assert runner.recalc_llm_max_tokens(model_config, prompt_messages=[]) == -1
def test_direct_output_streaming_publishes_chunks_and_end(self, monkeypatch):
runner = AppRunner()
queue = _QueueRecorder()
app_generate_entity = SimpleNamespace(model_conf=SimpleNamespace(model="mock"), stream=True)
monkeypatch.setattr("core.app.apps.base_app_runner.time.sleep", lambda _: None)
runner.direct_output(
queue_manager=queue,
app_generate_entity=app_generate_entity,
prompt_messages=[],
text="hi",
stream=True,
)
assert any(isinstance(event, QueueLLMChunkEvent) for event in queue.events)
assert isinstance(queue.events[-1], QueueMessageEndEvent)
def test_handle_invoke_result_direct_publishes_end_event(self):
runner = AppRunner()
queue = _QueueRecorder()
llm_result = LLMResult(
model="mock",
prompt_messages=[],
message=AssistantPromptMessage(content="done"),
usage=LLMUsage.empty_usage(),
)
runner._handle_invoke_result(
invoke_result=llm_result,
queue_manager=queue,
stream=False,
)
assert isinstance(queue.events[-1], QueueMessageEndEvent)
def test_handle_invoke_result_invalid_type_raises(self):
runner = AppRunner()
queue = _QueueRecorder()
with pytest.raises(NotImplementedError):
runner._handle_invoke_result(
invoke_result=["unexpected"],
queue_manager=queue,
stream=True,
)
def test_organize_prompt_messages_simple_template(self, monkeypatch):
runner = AppRunner()
model_config = SimpleNamespace(mode="chat", stop=["STOP"])
prompt_template_entity = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
simple_prompt_template="hello",
)
monkeypatch.setattr(
"core.app.apps.base_app_runner.SimplePromptTransform.get_prompt",
lambda self, **kwargs: (["simple-message"], ["simple-stop"]),
)
prompt_messages, stop = runner.organize_prompt_messages(
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
model_config=model_config,
prompt_template_entity=prompt_template_entity,
inputs={},
files=[],
query="q",
)
assert prompt_messages == ["simple-message"]
assert stop == ["simple-stop"]
def test_organize_prompt_messages_advanced_completion_template(self, monkeypatch):
runner = AppRunner()
model_config = SimpleNamespace(mode="completion", stop=["<END>"])
captured: dict[str, object] = {}
prompt_template_entity = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
prompt="answer",
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="U", assistant="A"),
),
)
def _fake_advanced_prompt(self, **kwargs):
captured.update(kwargs)
return ["advanced-completion-message"]
monkeypatch.setattr("core.app.apps.base_app_runner.AdvancedPromptTransform.get_prompt", _fake_advanced_prompt)
prompt_messages, stop = runner.organize_prompt_messages(
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
model_config=model_config,
prompt_template_entity=prompt_template_entity,
inputs={},
files=[],
query="q",
)
assert prompt_messages == ["advanced-completion-message"]
assert stop == ["<END>"]
memory_config = captured["memory_config"]
assert memory_config.role_prefix.user == "U"
assert memory_config.role_prefix.assistant == "A"
def test_organize_prompt_messages_advanced_chat_template(self, monkeypatch):
runner = AppRunner()
model_config = SimpleNamespace(mode="chat", stop=["<END>"])
captured: dict[str, object] = {}
prompt_template_entity = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
messages=[
AdvancedChatMessageEntity(text="hello", role=PromptMessageRole.USER),
AdvancedChatMessageEntity(text="world", role=PromptMessageRole.ASSISTANT),
]
),
)
def _fake_advanced_prompt(self, **kwargs):
captured.update(kwargs)
return ["advanced-chat-message"]
monkeypatch.setattr("core.app.apps.base_app_runner.AdvancedPromptTransform.get_prompt", _fake_advanced_prompt)
prompt_messages, stop = runner.organize_prompt_messages(
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
model_config=model_config,
prompt_template_entity=prompt_template_entity,
inputs={},
files=[],
query="q",
)
assert prompt_messages == ["advanced-chat-message"]
assert stop == ["<END>"]
assert len(captured["prompt_template"]) == 2
def test_organize_prompt_messages_advanced_missing_templates_raise(self):
runner = AppRunner()
with pytest.raises(InvokeBadRequestError, match="Advanced completion prompt template is required"):
runner.organize_prompt_messages(
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
model_config=SimpleNamespace(mode="completion", stop=[]),
prompt_template_entity=PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.ADVANCED),
inputs={},
files=[],
)
with pytest.raises(InvokeBadRequestError, match="Advanced chat prompt template is required"):
runner.organize_prompt_messages(
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
model_config=SimpleNamespace(mode="chat", stop=[]),
prompt_template_entity=PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.ADVANCED),
inputs={},
files=[],
)
def test_handle_invoke_result_stream_routes_chunks_and_builds_message(self, monkeypatch):
runner = AppRunner()
queue = _QueueRecorder()
warning_logger = MagicMock()
monkeypatch.setattr("core.app.apps.base_app_runner._logger.warning", warning_logger)
image_content = ImagePromptMessageContent(
url="https://example.com/image.png", format="png", mime_type="image/png"
)
def _stream():
yield LLMResultChunk(
model="stream-model",
prompt_messages=[AssistantPromptMessage(content="prompt")],
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage.model_construct(
content=[
"a",
TextPromptMessageContent(data="b"),
SimpleNamespace(data="c"),
image_content,
]
),
),
)
runner._handle_invoke_result(
invoke_result=_stream(),
queue_manager=queue,
stream=True,
agent=False,
)
assert isinstance(queue.events[0], QueueLLMChunkEvent)
assert isinstance(queue.events[-1], QueueMessageEndEvent)
assert queue.events[-1].llm_result.message.content == "abc"
warning_logger.assert_called_once()
def test_handle_invoke_result_stream_agent_mode_handles_multimodal_errors(self, monkeypatch):
runner = AppRunner()
queue = _QueueRecorder()
exception_logger = MagicMock()
monkeypatch.setattr("core.app.apps.base_app_runner._logger.exception", exception_logger)
monkeypatch.setattr(
runner,
"_handle_multimodal_image_content",
MagicMock(side_effect=RuntimeError("failed to save image")),
)
usage = LLMUsage.empty_usage()
def _stream():
yield LLMResultChunk(
model="agent-model",
prompt_messages=[AssistantPromptMessage(content="prompt")],
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=[
ImagePromptMessageContent(
url="https://example.com/image.png",
format="png",
mime_type="image/png",
),
TextPromptMessageContent(data="done"),
]
),
usage=usage,
),
)
runner._handle_invoke_result_stream(
invoke_result=_stream(),
queue_manager=queue,
agent=True,
message_id="message-id",
user_id="user-id",
tenant_id="tenant-id",
)
assert isinstance(queue.events[0], QueueAgentMessageEvent)
assert isinstance(queue.events[-1], QueueMessageEndEvent)
assert queue.events[-1].llm_result.usage == usage
exception_logger.assert_called_once()
def test_handle_multimodal_image_content_fallback_return_branch(self, monkeypatch):
runner = AppRunner()
class _ToggleBool:
def __init__(self, values: list[bool]):
self._values = values
self._index = 0
def __bool__(self):
value = self._values[min(self._index, len(self._values) - 1)]
self._index += 1
return value
content = SimpleNamespace(
url=_ToggleBool([False, False]),
base64_data=_ToggleBool([True, False]),
mime_type="image/png",
)
db_session = SimpleNamespace(add=MagicMock(), commit=MagicMock(), refresh=MagicMock())
monkeypatch.setattr("core.app.apps.base_app_runner.ToolFileManager", lambda: MagicMock())
monkeypatch.setattr("core.app.apps.base_app_runner.db", SimpleNamespace(session=db_session))
queue_manager = SimpleNamespace(invoke_from=InvokeFrom.SERVICE_API, publish=MagicMock())
runner._handle_multimodal_image_content(
content=content,
message_id="message-id",
user_id="user-id",
tenant_id="tenant-id",
queue_manager=queue_manager,
)
db_session.add.assert_not_called()
queue_manager.publish.assert_not_called()
def test_check_hosting_moderation_direct_output_called(self, monkeypatch):
runner = AppRunner()
queue = _QueueRecorder()
app_generate_entity = SimpleNamespace(stream=False)
monkeypatch.setattr(
"core.app.apps.base_app_runner.HostingModerationFeature.check",
lambda self, application_generate_entity, prompt_messages: True,
)
direct_output = MagicMock()
monkeypatch.setattr(runner, "direct_output", direct_output)
result = runner.check_hosting_moderation(
application_generate_entity=app_generate_entity,
queue_manager=queue,
prompt_messages=[],
)
assert result is True
assert direct_output.called
def test_fill_in_inputs_from_external_data_tools(self, monkeypatch):
runner = AppRunner()
monkeypatch.setattr(
"core.app.apps.base_app_runner.ExternalDataFetch.fetch",
lambda self, tenant_id, app_id, external_data_tools, inputs, query: {"foo": "bar"},
)
result = runner.fill_in_inputs_from_external_data_tools(
tenant_id="tenant",
app_id="app",
external_data_tools=[],
inputs={},
query="q",
)
assert result == {"foo": "bar"}
def test_moderation_for_inputs_returns_result(self, monkeypatch):
runner = AppRunner()
monkeypatch.setattr(
"core.app.apps.base_app_runner.InputModeration.check",
lambda self, app_id, tenant_id, app_config, inputs, query, message_id, trace_manager: (True, {}, ""),
)
app_generate_entity = SimpleNamespace(app_config=SimpleNamespace(), trace_manager=None)
result = runner.moderation_for_inputs(
app_id="app",
tenant_id="tenant",
app_generate_entity=app_generate_entity,
inputs={},
query="q",
message_id="msg",
)
assert result == (True, {}, "")
def test_query_app_annotations_to_reply(self, monkeypatch):
    """Annotation lookups should be delegated to AnnotationReplyFeature.query."""
    app_runner = AppRunner()
    monkeypatch.setattr(
        "core.app.apps.base_app_runner.AnnotationReplyFeature.query",
        lambda self, app_record, message, query, user_id, invoke_from: "reply",
    )
    annotation = app_runner.query_app_annotations_to_reply(
        app_record=SimpleNamespace(),
        message=SimpleNamespace(),
        query="hello",
        user_id="user",
        invoke_from=InvokeFrom.WEB_APP,
    )
    assert annotation == "reply"

View File

@ -0,0 +1,7 @@
from core.app.apps.exc import GenerateTaskStoppedError
class TestAppsExceptions:
    """Tests for the exception types exposed by core.app.apps.exc."""

    def test_generate_task_stopped_error(self):
        """The message given at construction is preserved by str()."""
        message = "stopped"
        error = GenerateTaskStoppedError(message)
        assert str(error) == message

View File

@ -13,9 +13,11 @@ from core.app.app_config.entities import (
PromptTemplateEntity,
)
from core.app.apps import message_based_app_generator
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
from models.model import AppMode, Conversation, Message
from services.errors.app_model_config import AppModelConfigBrokenError
class DummyModelConf:
@ -125,3 +127,55 @@ def test_init_generate_records_sets_conversation_fields_for_chat_entity():
assert entity.conversation_id == "generated-conversation-id"
assert entity.is_new_conversation is True
assert conversation.id == "generated-conversation-id"
class TestMessageBasedAppGeneratorExtras:
    """Additional edge-case coverage for MessageBasedAppGenerator."""

    def test_handle_response_closed_file_raises_stopped(self, monkeypatch):
        """A closed-file ValueError from the pipeline maps to GenerateTaskStoppedError."""

        class _ClosedFilePipeline:
            def __init__(self, **kwargs) -> None:
                _ = kwargs

            def process(self):
                raise ValueError("I/O operation on closed file.")

        monkeypatch.setattr(
            "core.app.apps.message_based_app_generator.EasyUIBasedGenerateTaskPipeline",
            _ClosedFilePipeline,
        )
        generator = MessageBasedAppGenerator()
        entity = _make_chat_generate_entity(_make_app_config(AppMode.CHAT))
        with pytest.raises(GenerateTaskStoppedError):
            generator._handle_response(
                application_generate_entity=entity,
                queue_manager=SimpleNamespace(),
                conversation=SimpleNamespace(id="conv"),
                message=SimpleNamespace(id="msg"),
                user=SimpleNamespace(),
                stream=False,
            )

    def test_get_app_model_config_requires_valid_config(self, monkeypatch):
        """Missing config on the app or conversation raises AppModelConfigBrokenError."""
        generator = MessageBasedAppGenerator()
        broken_app = SimpleNamespace(id="app", app_model_config_id=None, app_model_config=None)
        with pytest.raises(AppModelConfigBrokenError):
            generator._get_app_model_config(broken_app, conversation=None)
        # A conversation pointing at a config id the database cannot resolve is also broken.
        monkeypatch.setattr(
            message_based_app_generator, "db", SimpleNamespace(session=SimpleNamespace(scalar=lambda _: None))
        )
        stale_conversation = SimpleNamespace(app_model_config_id="missing-id")
        with pytest.raises(AppModelConfigBrokenError):
            generator._get_app_model_config(app_model=SimpleNamespace(id="app"), conversation=stale_conversation)

    def test_get_conversation_introduction_handles_missing_inputs(self):
        """Unresolvable template placeholders degrade to single-brace output."""
        config = _make_app_config(AppMode.CHAT)
        config.additional_features.opening_statement = "Hello {{name}}"
        entity = _make_chat_generate_entity(config)
        entity.inputs = {}
        generator = MessageBasedAppGenerator()
        assert generator._get_conversation_introduction(entity) == "Hello {name}"

View File

@ -0,0 +1,65 @@
from unittest.mock import Mock, patch
import pytest
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueErrorEvent, QueueMessageEndEvent, QueueStopEvent
class TestMessageBasedAppQueueManager:
    """Tests for MessageBasedAppQueueManager._publish terminal-event handling."""

    def test_publish_stops_on_terminal_events(self):
        """Publishing a terminal QueueStopEvent must stop the listen loop."""
        with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
            mock_redis.setex.return_value = True
            manager = MessageBasedAppQueueManager(
                task_id="t1",
                user_id="u1",
                invoke_from=InvokeFrom.SERVICE_API,
                conversation_id="c1",
                app_mode="chat",
                message_id="m1",
            )
            manager.stop_listen = Mock()
            manager._is_stopped = Mock(return_value=False)
            # Fix: pass a real PublishFrom member instead of a bare Mock() so the
            # publish_from argument has the enum type production code supplies,
            # consistent with the other tests in this class.
            manager._publish(
                QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
                PublishFrom.APPLICATION_MANAGER,
            )
            manager.stop_listen.assert_called_once()

    def test_publish_raises_when_stopped(self):
        """Publishing after the task was stopped raises GenerateTaskStoppedError."""
        with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
            mock_redis.setex.return_value = True
            manager = MessageBasedAppQueueManager(
                task_id="t1",
                user_id="u1",
                invoke_from=InvokeFrom.SERVICE_API,
                conversation_id="c1",
                app_mode="chat",
                message_id="m1",
            )
            manager._is_stopped = Mock(return_value=True)
            with pytest.raises(GenerateTaskStoppedError):
                manager._publish(QueueErrorEvent(error=ValueError("boom")), PublishFrom.APPLICATION_MANAGER)

    def test_publish_enqueues_message_end(self):
        """A message-end event is placed on the internal queue exactly once."""
        with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
            mock_redis.setex.return_value = True
            manager = MessageBasedAppQueueManager(
                task_id="t1",
                user_id="u1",
                invoke_from=InvokeFrom.SERVICE_API,
                conversation_id="c1",
                app_mode="chat",
                message_id="m1",
            )
            manager._is_stopped = Mock(return_value=False)
            manager.stop_listen = Mock()
            manager._publish(QueueMessageEndEvent(), PublishFrom.TASK_PIPELINE)
            assert manager._q.qsize() == 1

View File

@ -0,0 +1,29 @@
from unittest.mock import Mock, patch
from core.app.apps.message_generator import MessageGenerator
from models.model import AppMode
class TestMessageGenerator:
    """Tests for MessageGenerator topic resolution and event retrieval."""

    def test_get_response_topic(self):
        """get_response_topic resolves the channel key for the given run."""
        broadcast_channel = Mock()
        broadcast_channel.topic.return_value = "topic"
        with patch("core.app.apps.message_generator.get_pubsub_broadcast_channel", return_value=broadcast_channel):
            resolved = MessageGenerator.get_response_topic(AppMode.WORKFLOW, "run-1")
        assert resolved == "topic"
        channel_key = MessageGenerator._make_channel_key(AppMode.WORKFLOW, "run-1")
        broadcast_channel.topic.assert_called_once_with(channel_key)

    def test_retrieve_events_passes_arguments(self):
        """retrieve_events streams events from the resolved topic."""
        with (
            patch("core.app.apps.message_generator.MessageGenerator.get_response_topic", return_value="topic"),
            patch(
                "core.app.apps.message_generator.stream_topic_events", return_value=iter([{"event": "ping"}])
            ) as mock_stream,
        ):
            collected = list(
                MessageGenerator.retrieve_events(AppMode.WORKFLOW, "run-1", idle_timeout=1, ping_interval=2)
            )
        assert collected == [{"event": "ping"}]
        mock_stream.assert_called_once()

View File

@ -6,6 +6,7 @@ import queue
import pytest
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.streaming_utils import _normalize_terminal_events, stream_topic_events
from core.app.entities.task_entities import StreamEvent
from models.model import AppMode
@ -78,3 +79,30 @@ def test_retrieve_events_calls_on_subscribe_after_subscription(monkeypatch):
assert event["event"] == StreamEvent.WORKFLOW_FINISHED.value
with pytest.raises(StopIteration):
next(generator)
def test_normalize_terminal_events_defaults():
    """With no explicit set, finished and paused are the terminal events."""
    expected_defaults = {
        StreamEvent.WORKFLOW_FINISHED.value,
        StreamEvent.WORKFLOW_PAUSED.value,
    }
    assert _normalize_terminal_events(None) == expected_defaults
def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch):
    """Idle receives surface ping events whenever the ping interval elapses."""
    topic = FakeTopic()
    # Each call to time.time() consumes the next tick in order.
    clock_ticks = iter([1000.0, 1000.0, 1001.0, 1001.0, 1002.0])
    monkeypatch.setattr("core.app.apps.streaming_utils.time.time", lambda: next(clock_ticks))
    event_stream = stream_topic_events(
        topic=topic,
        idle_timeout=10.0,
        ping_interval=1.0,
    )
    assert next(event_stream) == StreamEvent.PING.value
    # The next idle receive crosses the ping interval again.
    assert next(event_stream) == StreamEvent.PING.value

View File

@ -0,0 +1,261 @@
from __future__ import annotations
from datetime import datetime
from types import SimpleNamespace
import pytest
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueIterationCompletedEvent,
QueueLoopCompletedEvent,
QueueTextChunkEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import NodeType
from dify_graph.graph_events import (
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunAgentLogEvent,
NodeRunIterationSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
)
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
class TestWorkflowBasedAppRunner:
    """Unit tests for WorkflowBasedAppRunner: user-source resolution, graph
    initialization validation, single-node runs, and graph/node event handling."""

    def test_resolve_user_from(self):
        """Console-originated invokes map to ACCOUNT; web-app invokes map to END_USER."""
        runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
        assert runner._resolve_user_from(InvokeFrom.EXPLORE) == UserFrom.ACCOUNT
        assert runner._resolve_user_from(InvokeFrom.DEBUGGER) == UserFrom.ACCOUNT
        assert runner._resolve_user_from(InvokeFrom.WEB_APP) == UserFrom.END_USER

    def test_init_graph_validates_graph_structure(self):
        """_init_graph rejects configs missing nodes/edges or using wrong container types."""
        runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
        runtime_state = GraphRuntimeState(
            variable_pool=VariablePool(system_variables=SystemVariable.default()),
            start_at=0.0,
        )
        # Empty config: neither nodes nor edges present.
        with pytest.raises(ValueError, match="nodes or edges not found"):
            runner._init_graph(
                graph_config={},
                graph_runtime_state=runtime_state,
                user_from=UserFrom.ACCOUNT,
                invoke_from=InvokeFrom.DEBUGGER,
            )
        # nodes must be a list, not a mapping.
        with pytest.raises(ValueError, match="nodes in workflow graph must be a list"):
            runner._init_graph(
                graph_config={"nodes": {}, "edges": []},
                graph_runtime_state=runtime_state,
                user_from=UserFrom.ACCOUNT,
                invoke_from=InvokeFrom.DEBUGGER,
            )
        # edges must be a list, not a mapping.
        with pytest.raises(ValueError, match="edges in workflow graph must be a list"):
            runner._init_graph(
                graph_config={"nodes": [], "edges": {}},
                graph_runtime_state=runtime_state,
                user_from=UserFrom.ACCOUNT,
                invoke_from=InvokeFrom.DEBUGGER,
            )

    def test_prepare_single_node_execution_requires_run(self):
        """Single-node execution needs either an iteration run or a loop run."""
        runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
        workflow = SimpleNamespace(environment_variables=[], graph_dict={})
        with pytest.raises(ValueError, match="Neither single_iteration_run nor single_loop_run"):
            runner._prepare_single_node_execution(workflow, None, None)

    def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch):
        """Single-node setup returns a graph and the runtime state's own variable pool."""
        runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
        graph_runtime_state = GraphRuntimeState(
            variable_pool=VariablePool(system_variables=SystemVariable.default()),
            start_at=0.0,
        )
        graph_config = {
            "nodes": [{"id": "node-1", "data": {"type": "start", "version": "1"}}],
            "edges": [],
        }
        workflow = SimpleNamespace(tenant_id="tenant", id="workflow", graph_dict=graph_config)
        # Stub graph construction — only non-None identity is asserted below.
        monkeypatch.setattr(
            "core.app.apps.workflow_app_runner.Graph.init",
            lambda **kwargs: SimpleNamespace(),
        )

        class _NodeCls:
            # Minimal stand-in for a node class: reports no variable mappings.
            @staticmethod
            def extract_variable_selector_to_variable_mapping(graph_config, config):
                return {}

        from core.app.apps import workflow_app_runner

        # Register the stub node class for the "start" node type, version "1".
        monkeypatch.setitem(
            workflow_app_runner.NODE_TYPE_CLASSES_MAPPING,
            NodeType.START,
            {"1": _NodeCls},
        )
        monkeypatch.setattr(
            "core.app.apps.workflow_app_runner.load_into_variable_pool",
            lambda **kwargs: None,
        )
        monkeypatch.setattr(
            "core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool",
            lambda **kwargs: None,
        )
        graph, variable_pool = runner._get_graph_and_variable_pool_for_single_node_run(
            workflow=workflow,
            node_id="node-1",
            user_inputs={},
            graph_runtime_state=graph_runtime_state,
            node_type_filter_key="iteration_id",
            node_type_label="iteration",
        )
        assert graph is not None
        # The returned pool must be the very object held by the runtime state.
        assert variable_pool is graph_runtime_state.variable_pool

    def test_handle_graph_run_events_and_pause_notifications(self, monkeypatch):
        """Graph lifecycle events publish queue events; a pause dispatches email tasks."""
        published: list[object] = []

        class _QueueManager:
            # Records (event, publish_from) pairs instead of real publishing.
            def publish(self, event, publish_from):
                published.append((event, publish_from))

        runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
        graph_runtime_state = GraphRuntimeState(
            variable_pool=VariablePool(system_variables=SystemVariable.default()),
            start_at=0.0,
        )
        graph_runtime_state.register_paused_node("node-1")
        workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))
        emails: list[dict] = []

        class _Dispatch:
            # Captures celery-style apply_async calls instead of queueing real tasks.
            def apply_async(self, *, kwargs, queue):
                emails.append({"kwargs": kwargs, "queue": queue})

        monkeypatch.setattr(
            "core.app.apps.workflow_app_runner.dispatch_human_input_email_task",
            _Dispatch(),
        )
        reason = HumanInputRequired(
            form_id="form",
            form_content="content",
            node_id="node-1",
            node_title="Node",
        )
        runner._handle_event(workflow_entry, GraphRunStartedEvent())
        runner._handle_event(workflow_entry, GraphRunSucceededEvent(outputs={"ok": True}))
        runner._handle_event(workflow_entry, GraphRunPausedEvent(reasons=[reason], outputs={}))
        assert any(isinstance(event, QueueWorkflowStartedEvent) for event, _ in published)
        assert any(isinstance(event, QueueWorkflowSucceededEvent) for event, _ in published)
        paused_event = next(event for event, _ in published if isinstance(event, QueueWorkflowPausedEvent))
        assert paused_event.paused_nodes == ["node-1"]
        # The human-input pause must have dispatched at least one email task.
        assert emails

    def test_handle_node_events_publishes_queue_events(self):
        """Node-level graph events are translated into their queue-event counterparts."""
        published: list[object] = []

        class _QueueManager:
            # Only the events themselves matter for these assertions.
            def publish(self, event, publish_from):
                published.append(event)

        runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
        graph_runtime_state = GraphRuntimeState(
            variable_pool=VariablePool(system_variables=SystemVariable.default()),
            start_at=0.0,
        )
        workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
        # consider datetime.now(UTC) if the event models accept aware datetimes.
        runner._handle_event(
            workflow_entry,
            NodeRunStartedEvent(
                id="exec",
                node_id="node",
                node_type=NodeType.START,
                node_title="Start",
                start_at=datetime.utcnow(),
            ),
        )
        runner._handle_event(
            workflow_entry,
            NodeRunStreamChunkEvent(
                id="exec",
                node_id="node",
                node_type=NodeType.START,
                selector=["node", "text"],
                chunk="hi",
                is_final=False,
            ),
        )
        runner._handle_event(
            workflow_entry,
            NodeRunAgentLogEvent(
                id="exec",
                node_id="node",
                node_type=NodeType.START,
                message_id="msg",
                label="label",
                node_execution_id="exec",
                parent_id=None,
                error=None,
                status="done",
                data={},
                metadata={},
            ),
        )
        runner._handle_event(
            workflow_entry,
            NodeRunIterationSucceededEvent(
                id="exec",
                node_id="node",
                node_type=NodeType.LLM,
                node_title="Iter",
                start_at=datetime.utcnow(),
                inputs={},
                outputs={"ok": True},
                metadata={},
                steps=1,
            ),
        )
        runner._handle_event(
            workflow_entry,
            NodeRunLoopFailedEvent(
                id="exec",
                node_id="node",
                node_type=NodeType.LLM,
                node_title="Loop",
                start_at=datetime.utcnow(),
                inputs={},
                outputs={},
                metadata={},
                steps=1,
                error="boom",
            ),
        )
        assert any(isinstance(event, QueueTextChunkEvent) for event in published)
        assert any(isinstance(event, QueueAgentLogEvent) for event in published)
        assert any(isinstance(event, QueueIterationCompletedEvent) for event in published)
        assert any(isinstance(event, QueueLoopCompletedEvent) for event in published)

View File

@ -0,0 +1,61 @@
from types import SimpleNamespace
from unittest.mock import patch
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from models.model import AppMode
class TestWorkflowAppConfigManager:
    """Tests for WorkflowAppConfigManager conversion and validation."""

    def test_get_app_config(self):
        """get_app_config carries the workflow id and app mode into the config."""
        model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
        wf = SimpleNamespace(id="wf-1", features_dict={})
        with (
            patch(
                "core.app.apps.workflow.app_config_manager.WorkflowVariablesConfigManager.convert",
                return_value=[],
            ),
            patch(
                "core.app.apps.workflow.app_config_manager.SensitiveWordAvoidanceConfigManager.convert",
                return_value=None,
            ),
        ):
            config = WorkflowAppConfigManager.get_app_config(model, wf)
        assert config.workflow_id == "wf-1"
        assert config.app_mode == AppMode.WORKFLOW

    def test_config_validate_filters_keys(self):
        """config_validate keeps only the keys each sub-validator reports."""

        def _validator_for(key, value):
            """Build a stub validator that injects ``key`` and reports it as related."""

            def _validate(*args, **kwargs):
                # The real validators receive config positionally or by keyword.
                if "config" in kwargs:
                    cfg = kwargs["config"]
                elif args:
                    cfg = args[0]
                else:
                    cfg = {}
                cfg[key] = value
                return cfg, [key]

            return _validate

        with (
            patch(
                "core.app.apps.workflow.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
                side_effect=_validator_for("file_upload", 1),
            ),
            patch(
                "core.app.apps.workflow.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
                side_effect=_validator_for("text_to_speech", 2),
            ),
            patch(
                "core.app.apps.workflow.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
                side_effect=_validator_for("sensitive_word_avoidance", 3),
            ),
        ):
            filtered = WorkflowAppConfigManager.config_validate(tenant_id="t1", config={})
        assert filtered["file_upload"] == 1
        assert filtered["text_to_speech"] == 2
        assert filtered["sensitive_word_avoidance"] == 3

View File

@ -0,0 +1,187 @@
from __future__ import annotations
from types import SimpleNamespace
import pytest
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.ops.ops_trace_manager import TraceQueueManager
from models.model import AppMode
class TestWorkflowAppGeneratorValidation:
    """Argument validation tests for WorkflowAppGenerator single-run helpers."""

    def test_should_prepare_user_inputs(self):
        """Inputs are prepared unless the skip flag is present and truthy."""
        generator = WorkflowAppGenerator()
        assert generator._should_prepare_user_inputs({}) is True
        assert generator._should_prepare_user_inputs({SKIP_PREPARE_USER_INPUTS_KEY: True}) is False

    def test_single_iteration_generate_validates_args(self):
        """single_iteration_generate requires a node id and an inputs mapping."""
        generator = WorkflowAppGenerator()
        with pytest.raises(ValueError, match="node_id is required"):
            generator.single_iteration_generate(
                app_model=SimpleNamespace(),
                workflow=SimpleNamespace(),
                node_id="",
                user=SimpleNamespace(),
                args={"inputs": {}},
                streaming=False,
            )
        with pytest.raises(ValueError, match="inputs is required"):
            generator.single_iteration_generate(
                app_model=SimpleNamespace(),
                workflow=SimpleNamespace(),
                node_id="node",
                user=SimpleNamespace(),
                args={},
                streaming=False,
            )

    def test_single_loop_generate_validates_args(self):
        """single_loop_generate requires a node id and non-None inputs."""
        generator = WorkflowAppGenerator()
        with pytest.raises(ValueError, match="node_id is required"):
            generator.single_loop_generate(
                app_model=SimpleNamespace(),
                workflow=SimpleNamespace(),
                node_id="",
                user=SimpleNamespace(),
                args=SimpleNamespace(inputs={}),
                streaming=False,
            )
        with pytest.raises(ValueError, match="inputs is required"):
            generator.single_loop_generate(
                app_model=SimpleNamespace(),
                workflow=SimpleNamespace(),
                node_id="node",
                user=SimpleNamespace(),
                args=SimpleNamespace(inputs=None),
                streaming=False,
            )
class TestWorkflowAppGeneratorHandleResponse:
    """Tests for WorkflowAppGenerator._handle_response error translation."""

    def test_handle_response_closed_file_raises_stopped(self, monkeypatch):
        """A closed-file ValueError from the pipeline becomes GenerateTaskStoppedError."""

        class _ClosedFilePipeline:
            def __init__(self, **kwargs) -> None:
                _ = kwargs

            def process(self):
                raise ValueError("I/O operation on closed file.")

        monkeypatch.setattr(
            "core.app.apps.workflow.app_generator.WorkflowAppGenerateTaskPipeline",
            _ClosedFilePipeline,
        )
        config = WorkflowUIBasedAppConfig(
            tenant_id="tenant",
            app_id="app",
            app_mode=AppMode.WORKFLOW,
            additional_features=AppAdditionalFeatures(),
            variables=[],
            workflow_id="workflow-id",
        )
        entity = WorkflowAppGenerateEntity.model_construct(
            task_id="task",
            app_config=config,
            inputs={},
            files=[],
            user_id="user",
            stream=False,
            invoke_from=InvokeFrom.WEB_APP,
            extras={},
            trace_manager=None,
            workflow_execution_id="run-id",
            call_depth=0,
        )
        generator = WorkflowAppGenerator()
        with pytest.raises(GenerateTaskStoppedError):
            generator._handle_response(
                application_generate_entity=entity,
                workflow=SimpleNamespace(),
                queue_manager=SimpleNamespace(),
                user=SimpleNamespace(),
                draft_var_saver_factory=lambda **kwargs: None,
                stream=False,
            )
class TestWorkflowAppGeneratorGenerate:
    """Tests for WorkflowAppGenerator.generate with prepared-inputs skipping."""

    def test_generate_skips_prepare_inputs_when_flag_set(self, monkeypatch):
        """When SKIP_PREPARE_USER_INPUTS_KEY is set, _prepare_user_inputs must not run."""
        generator = WorkflowAppGenerator()
        app_config = WorkflowUIBasedAppConfig(
            tenant_id="tenant",
            app_id="app",
            app_mode=AppMode.WORKFLOW,
            additional_features=AppAdditionalFeatures(),
            variables=[],
            workflow_id="workflow-id",
        )
        # Short-circuit every collaborator generate() touches before reaching _generate.
        monkeypatch.setattr(
            "core.app.apps.workflow.app_generator.WorkflowAppConfigManager.get_app_config",
            lambda app_model, workflow: app_config,
        )
        monkeypatch.setattr(
            "core.app.apps.workflow.app_generator.FileUploadConfigManager.convert",
            lambda features_dict, is_vision=False: None,
        )
        monkeypatch.setattr(
            "core.app.apps.workflow.app_generator.file_factory.build_from_mappings",
            lambda **kwargs: [],
        )
        # Subclass TraceQueueManager with a no-op __init__ (the setattr-or-setattr
        # lambda returns None) so no tracing backend is required.
        DummyTraceQueueManager = type(
            "_DummyTraceQueueManager",
            (TraceQueueManager,),
            {
                "__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id)
                or setattr(self, "user_id", user_id)
            },
        )
        monkeypatch.setattr(
            "core.app.apps.workflow.app_generator.TraceQueueManager",
            DummyTraceQueueManager,
        )
        monkeypatch.setattr(
            "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository",
            lambda **kwargs: SimpleNamespace(),
        )
        monkeypatch.setattr(
            "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
            lambda **kwargs: SimpleNamespace(),
        )
        monkeypatch.setattr(
            "core.app.apps.workflow.app_generator.db",
            SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)),
        )
        monkeypatch.setattr(
            "core.app.apps.workflow.app_generator.sessionmaker",
            lambda **kwargs: SimpleNamespace(),
        )
        # If _prepare_user_inputs is ever invoked, pytest.fail aborts the test.
        prepare_inputs = pytest.fail
        monkeypatch.setattr(generator, "_prepare_user_inputs", lambda **kwargs: prepare_inputs())
        monkeypatch.setattr(generator, "_generate", lambda **kwargs: {"ok": True})
        result = generator.generate(
            app_model=SimpleNamespace(id="app", tenant_id="tenant"),
            workflow=SimpleNamespace(features_dict={}),
            user=SimpleNamespace(id="user", session_id="session"),
            args={"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: True},
            invoke_from=InvokeFrom.WEB_APP,
            streaming=False,
            call_depth=0,
        )
        assert result == {"ok": True}

View File

@ -0,0 +1,33 @@
from __future__ import annotations
import pytest
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueMessageEndEvent, QueuePingEvent
class TestWorkflowAppQueueManager:
    """Tests for WorkflowAppQueueManager._publish stop semantics."""

    @staticmethod
    def _build_manager():
        """Create a queue manager with the fixed test identifiers."""
        return WorkflowAppQueueManager(
            task_id="task",
            user_id="user",
            invoke_from=InvokeFrom.DEBUGGER,
            app_mode="workflow",
        )

    def test_publish_stop_events_trigger_stop(self):
        """Publishing a terminal event after stop raises GenerateTaskStoppedError."""
        manager = self._build_manager()
        manager._is_stopped = lambda: True
        with pytest.raises(GenerateTaskStoppedError):
            manager._publish(QueueMessageEndEvent(llm_result=None), PublishFrom.APPLICATION_MANAGER)

    def test_publish_non_stop_event_does_not_raise(self):
        """A ping event publishes cleanly while the manager is running."""
        manager = self._build_manager()
        manager._publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)

View File

@ -0,0 +1,9 @@
from core.app.apps.workflow.errors import WorkflowPausedInBlockingModeError
class TestWorkflowErrors:
    """Tests for workflow-specific HTTP error types."""

    def test_workflow_paused_in_blocking_mode_error_attributes(self):
        """The paused-in-blocking-mode error exposes its code and description."""
        error = WorkflowPausedInBlockingModeError()
        assert error.error_code == "workflow_paused_in_blocking_mode"
        assert error.code == 400
        assert "blocking response mode" in error.description

View File

@ -0,0 +1,133 @@
from collections.abc import Generator
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.entities.task_entities import (
ErrorStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
)
from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
class TestWorkflowGenerateResponseConverter:
    """Tests for WorkflowAppGenerateResponseConverter blocking and stream modes."""

    def test_blocking_full_response(self):
        """Blocking conversion keeps the workflow run id in the resulting dict."""
        blocking = WorkflowAppBlockingResponse(
            task_id="t1",
            workflow_run_id="r1",
            data=WorkflowAppBlockingResponse.Data(
                id="exec-1",
                workflow_id="wf-1",
                status=WorkflowExecutionStatus.SUCCEEDED,
                outputs={"ok": True},
                error=None,
                elapsed_time=1.2,
                total_tokens=10,
                total_steps=2,
                created_at=1,
                finished_at=2,
            ),
        )
        response = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(blocking)
        assert response["workflow_run_id"] == "r1"

    def test_stream_simple_response_node_events(self):
        """Simple streaming passes ping, node, and error events through in order."""
        node_start = NodeStartStreamResponse(
            task_id="t1",
            workflow_run_id="r1",
            data=NodeStartStreamResponse.Data(
                id="e1",
                node_id="n1",
                node_type="answer",
                title="Answer",
                index=1,
                created_at=1,
            ),
        )
        node_finish = NodeFinishStreamResponse(
            task_id="t1",
            workflow_run_id="r1",
            data=NodeFinishStreamResponse.Data(
                id="e1",
                node_id="n1",
                node_type="answer",
                title="Answer",
                index=1,
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                elapsed_time=0.1,
                created_at=1,
                finished_at=2,
            ),
        )

        def stream() -> Generator[WorkflowAppStreamResponse, None, None]:
            # One event of each kind the converter must handle.
            yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=PingStreamResponse(task_id="t1"))
            yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=node_start)
            yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=node_finish)
            yield WorkflowAppStreamResponse(
                workflow_run_id="r1", stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom"))
            )

        converted = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(stream()))
        # Pings collapse to the bare "ping" string; other events keep their event tag.
        assert converted[0] == "ping"
        assert converted[1]["event"] == "node_started"
        assert converted[2]["event"] == "node_finished"
        assert converted[3]["event"] == "error"

    def test_convert_stream_simple_response_handles_ping_and_nodes(self):
        """Simple streaming also handles node lifecycle events built inline."""

        def _gen():
            yield WorkflowAppStreamResponse(stream_response=PingStreamResponse(task_id="task"))
            yield WorkflowAppStreamResponse(
                workflow_run_id="run",
                stream_response=NodeStartStreamResponse(
                    task_id="task",
                    workflow_run_id="run",
                    data=NodeStartStreamResponse.Data(
                        id="node-exec",
                        node_id="node",
                        node_type="start",
                        title="Start",
                        index=1,
                        created_at=1,
                    ),
                ),
            )
            yield WorkflowAppStreamResponse(
                workflow_run_id="run",
                stream_response=NodeFinishStreamResponse(
                    task_id="task",
                    workflow_run_id="run",
                    data=NodeFinishStreamResponse.Data(
                        id="node-exec",
                        node_id="node",
                        node_type="start",
                        title="Start",
                        index=1,
                        status=WorkflowNodeExecutionStatus.SUCCEEDED,
                        outputs={},
                        created_at=1,
                        finished_at=2,
                        elapsed_time=1.0,
                        error=None,
                    ),
                ),
            )

        chunks = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(_gen()))
        assert chunks[0] == "ping"
        assert chunks[1]["event"] == "node_started"
        assert chunks[2]["event"] == "node_finished"

    def test_convert_stream_full_response_handles_error(self):
        """Full streaming converts error stream responses into error chunks."""

        def _gen():
            yield WorkflowAppStreamResponse(
                workflow_run_id="run",
                stream_response=ErrorStreamResponse(task_id="task", err=ValueError("boom")),
            )

        chunks = list(WorkflowAppGenerateResponseConverter.convert_stream_full_response(_gen()))
        assert chunks[0]["event"] == "error"

View File

@ -0,0 +1,868 @@
from __future__ import annotations
from contextlib import contextmanager
from datetime import datetime
from types import SimpleNamespace
import pytest
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueErrorEvent,
QueueHumanInputFormFilledEvent,
QueueHumanInputFormTimeoutEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueLoopCompletedEvent,
QueueLoopNextEvent,
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueuePingEvent,
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
PingStreamResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
WorkflowStartStreamResponse,
)
from core.base.tts.app_generator_tts_publisher import AudioTrunk
from dify_graph.enums import NodeType, WorkflowExecutionStatus
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from models.enums import CreatorUserRole
from models.model import AppMode, EndUser
def _make_pipeline():
    """Build a WorkflowAppGenerateTaskPipeline wired with stub collaborators."""
    config = WorkflowUIBasedAppConfig(
        tenant_id="tenant",
        app_id="app",
        app_mode=AppMode.WORKFLOW,
        additional_features=AppAdditionalFeatures(),
        variables=[],
        workflow_id="workflow-id",
    )
    generate_entity = WorkflowAppGenerateEntity.model_construct(
        task_id="task",
        app_config=config,
        inputs={},
        files=[],
        user_id="user",
        stream=False,
        invoke_from=InvokeFrom.WEB_APP,
        trace_manager=None,
        workflow_execution_id="run-id",
        extras={},
        call_depth=0,
    )
    return WorkflowAppGenerateTaskPipeline(
        application_generate_entity=generate_entity,
        workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}),
        queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
        user=SimpleNamespace(id="user", session_id="session"),
        stream=False,
        draft_var_saver_factory=lambda **kwargs: None,
    )
class TestWorkflowGenerateTaskPipeline:
def test_to_blocking_response_handles_pause(self):
    """A pause stream response becomes a blocking response with PAUSED status."""
    pipeline = _make_pipeline()
    pause_response = WorkflowPauseStreamResponse(
        task_id="task",
        workflow_run_id="run",
        data=WorkflowPauseStreamResponse.Data(
            workflow_run_id="run",
            status=WorkflowExecutionStatus.PAUSED,
            outputs={},
            created_at=1,
            elapsed_time=0.1,
            total_tokens=0,
            total_steps=0,
        ),
    )
    blocking = pipeline._to_blocking_response(iter([pause_response]))
    assert blocking.data.status == WorkflowExecutionStatus.PAUSED
def test_to_blocking_response_handles_finish(self):
    """A finish stream response surfaces the workflow outputs."""
    pipeline = _make_pipeline()
    finish_response = WorkflowFinishStreamResponse(
        task_id="task",
        workflow_run_id="run",
        data=WorkflowFinishStreamResponse.Data(
            id="run",
            workflow_id="workflow-id",
            status=WorkflowExecutionStatus.SUCCEEDED,
            outputs={"ok": True},
            error=None,
            elapsed_time=1.0,
            total_tokens=5,
            total_steps=2,
            created_at=1,
            finished_at=2,
        ),
    )
    blocking = pipeline._to_blocking_response(iter([finish_response]))
    assert blocking.data.outputs == {"ok": True}
def test_listen_audio_msg_returns_audio_stream(self):
    """An audio chunk from the TTS publisher becomes an audio stream response."""
    pipeline = _make_pipeline()
    fake_publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk(status="stream", audio="data"))
    audio_response = pipeline._listen_audio_msg(publisher=fake_publisher, task_id="task")
    assert isinstance(audio_response, MessageAudioStreamResponse)
def test_handle_ping_event(self):
    """Ping queue events map straight to ping stream responses."""
    pipeline = _make_pipeline()
    pipeline._base_task_pipeline.ping_stream_response = lambda: PingStreamResponse(task_id="task")
    emitted = list(pipeline._handle_ping_event(QueuePingEvent()))
    assert isinstance(emitted[0], PingStreamResponse)
def test_handle_error_event(self):
    """Error queue events are converted via handle_error/error_to_stream_response."""
    pipeline = _make_pipeline()
    pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom")
    pipeline._base_task_pipeline.error_to_stream_response = lambda err: err
    emitted = list(pipeline._handle_error_event(QueueErrorEvent(error=ValueError("boom"))))
    assert isinstance(emitted[0], ValueError)
def test_handle_workflow_started_event_sets_run_id(self, monkeypatch):
    """The workflow-started handler records the execution id from system variables."""
    pipeline = _make_pipeline()
    pipeline._graph_runtime_state = GraphRuntimeState(
        variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
        start_at=0.0,
    )
    pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started"

    @contextmanager
    def _stub_session():
        yield SimpleNamespace()

    monkeypatch.setattr(pipeline, "_database_session", _stub_session)
    monkeypatch.setattr(pipeline, "_save_workflow_app_log", lambda **kwargs: None)
    emitted = list(pipeline._handle_workflow_started_event(QueueWorkflowStartedEvent()))
    assert pipeline._workflow_execution_id == "run-id"
    assert emitted == ["started"]
def test_handle_node_succeeded_event_saves_output(self):
    """A node-succeeded event yields the converter's node-finish response."""
    pipeline = _make_pipeline()
    pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done"
    pipeline._save_output_for_event = lambda event, node_execution_id: None
    pipeline._workflow_execution_id = "run-id"
    succeeded = QueueNodeSucceededEvent(
        node_execution_id="exec",
        node_id="node",
        node_type=NodeType.START,
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
        # presumably the event model expects naive datetimes — confirm.
        start_at=datetime.utcnow(),
        inputs={},
        outputs={},
        process_data={},
    )
    assert list(pipeline._handle_node_succeeded_event(succeeded)) == ["done"]
def test_handle_workflow_failed_event_yields_error(self):
    """A workflow-failed event still produces the workflow finish stream response first."""
    pipeline = _make_pipeline()
    pipeline._workflow_execution_id = "run-id"
    pipeline._graph_runtime_state = GraphRuntimeState(
        variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
        start_at=0.0,
    )
    pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
    pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom")
    pipeline._base_task_pipeline.error_to_stream_response = lambda err: err
    failed_event = QueueWorkflowFailedEvent(error="fail", exceptions_count=1)
    emitted = list(pipeline._handle_workflow_failed_and_stop_events(failed_event))
    assert emitted[0] == "finish"
def test_handle_text_chunk_event_publishes_tts(self):
    """Text chunks are yielded as stream responses and the queue message goes to the TTS publisher."""
    pipeline = _make_pipeline()
    captured: list[object] = []

    class _RecordingPublisher:
        def publish(self, message):
            captured.append(message)

    chunk_event = QueueTextChunkEvent(text="hi", from_variable_selector=["x"])
    wrapper = SimpleNamespace(event=chunk_event)
    emitted = list(
        pipeline._handle_text_chunk_event(chunk_event, tts_publisher=_RecordingPublisher(), queue_message=wrapper)
    )
    assert emitted[0].data.text == "hi"
    assert captured == [wrapper]
def test_dispatch_event_handles_node_failed(self):
    """_dispatch_event routes node-failed events through the node-finish converter."""
    pipeline = _make_pipeline()
    pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done"
    failed_event = QueueNodeFailedEvent(
        node_execution_id="exec",
        node_id="node",
        node_type=NodeType.START,
        start_at=datetime.utcnow(),
        inputs={},
        outputs={},
        process_data={},
        error="err",
    )
    responses = list(pipeline._dispatch_event(failed_event))
    assert responses == ["done"]
def test_handle_stop_event_yields_finish(self):
    """A manual stop event still emits the workflow finish response."""
    pipeline = _make_pipeline()
    pipeline._workflow_execution_id = "run-id"
    pipeline._graph_runtime_state = GraphRuntimeState(
        variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
        start_at=0.0,
    )
    pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
    stop_event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)
    assert list(pipeline._handle_workflow_failed_and_stop_events(stop_event)) == ["finish"]
def test_save_workflow_app_log_created_from(self):
    """Service-API invocations persist a workflow app log row via session.add."""
    pipeline = _make_pipeline()
    pipeline._application_generate_entity.invoke_from = InvokeFrom.SERVICE_API
    pipeline._user_id = "user"
    recorded: list[object] = []

    class _RecordingSession:
        def add(self, item):
            recorded.append(item)

    pipeline._save_workflow_app_log(session=_RecordingSession(), workflow_run_id="run-id")
    assert recorded
def test_iteration_loop_and_human_input_handlers(self):
    """Each iteration/loop/human-input/agent-log handler yields exactly the converter's response.

    All converter methods are stubbed to return sentinel strings so the test only
    verifies routing, not the response payloads themselves.
    """
    pipeline = _make_pipeline()
    pipeline._workflow_execution_id = "run-id"
    # Stub every converter entry point these handlers dispatch to.
    pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = lambda **kwargs: "iter"
    pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "next"
    pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = lambda **kwargs: "done"
    pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop"
    pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next"
    pipeline._workflow_response_converter.workflow_loop_completed_to_stream_response = lambda **kwargs: "loop_done"
    pipeline._workflow_response_converter.human_input_form_filled_to_stream_response = lambda **kwargs: "filled"
    pipeline._workflow_response_converter.human_input_form_timeout_to_stream_response = lambda **kwargs: "timeout"
    pipeline._workflow_response_converter.handle_agent_log = lambda **kwargs: "log"
    # Minimal valid events for each handler; field values are arbitrary.
    iter_start = QueueIterationStartEvent(
        node_execution_id="exec",
        node_id="node",
        node_type=NodeType.LLM,
        node_title="LLM",
        start_at=datetime.utcnow(),
        node_run_index=1,
    )
    iter_next = QueueIterationNextEvent(
        index=1,
        node_execution_id="exec",
        node_id="node",
        node_type=NodeType.LLM,
        node_title="LLM",
        node_run_index=1,
    )
    iter_done = QueueIterationCompletedEvent(
        node_execution_id="exec",
        node_id="node",
        node_type=NodeType.LLM,
        node_title="LLM",
        start_at=datetime.utcnow(),
        node_run_index=1,
    )
    loop_start = QueueLoopStartEvent(
        node_execution_id="exec",
        node_id="node",
        node_type=NodeType.LLM,
        node_title="LLM",
        start_at=datetime.utcnow(),
        node_run_index=1,
    )
    loop_next = QueueLoopNextEvent(
        index=1,
        node_execution_id="exec",
        node_id="node",
        node_type=NodeType.LLM,
        node_title="LLM",
        node_run_index=1,
    )
    loop_done = QueueLoopCompletedEvent(
        node_execution_id="exec",
        node_id="node",
        node_type=NodeType.LLM,
        node_title="LLM",
        start_at=datetime.utcnow(),
        node_run_index=1,
    )
    filled_event = QueueHumanInputFormFilledEvent(
        node_execution_id="exec",
        node_id="node",
        node_type=NodeType.LLM,
        node_title="title",
        rendered_content="content",
        action_id="action",
        action_text="action",
    )
    timeout_event = QueueHumanInputFormTimeoutEvent(
        node_id="node",
        node_type=NodeType.LLM,
        node_title="title",
        expiration_time=datetime.utcnow(),
    )
    agent_event = QueueAgentLogEvent(
        id="log",
        label="label",
        node_execution_id="exec",
        parent_id=None,
        error=None,
        status="done",
        data={},
        metadata={},
        node_id="node",
    )
    # Every handler must yield exactly the sentinel its stubbed converter returns.
    assert list(pipeline._handle_iteration_start_event(iter_start)) == ["iter"]
    assert list(pipeline._handle_iteration_next_event(iter_next)) == ["next"]
    assert list(pipeline._handle_iteration_completed_event(iter_done)) == ["done"]
    assert list(pipeline._handle_loop_start_event(loop_start)) == ["loop"]
    assert list(pipeline._handle_loop_next_event(loop_next)) == ["loop_next"]
    assert list(pipeline._handle_loop_completed_event(loop_done)) == ["loop_done"]
    assert list(pipeline._handle_human_input_form_filled_event(filled_event)) == ["filled"]
    assert list(pipeline._handle_human_input_form_timeout_event(timeout_event)) == ["timeout"]
    assert list(pipeline._handle_agent_log_event(agent_event)) == ["log"]
def test_wrapper_process_stream_response_emits_audio_end(self, monkeypatch):
    """With TTS enabled the wrapper interleaves audio chunks and terminates with an audio-end response."""
    pipeline = _make_pipeline()
    # Feature flags that switch the TTS publisher on inside the wrapper.
    pipeline._workflow_features_dict = {
        "text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"}
    }
    pipeline._process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")])

    # Staged publisher: 1st poll -> audio chunk, 2nd -> None (no audio yet),
    # then -> finish, which should make the wrapper emit the audio-end response.
    class _Publisher:
        def __init__(self, *args, **kwargs):
            self.calls = 0

        def check_and_get_audio(self):
            self.calls += 1
            if self.calls == 1:
                return AudioTrunk(status="stream", audio="data")
            if self.calls == 2:
                return None
            return AudioTrunk(status="finish", audio="")

        def publish(self, message):
            return None

    monkeypatch.setattr(
        "core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher",
        _Publisher,
    )
    responses = list(pipeline._wrapper_process_stream_response())
    assert any(isinstance(item, MessageAudioStreamResponse) for item in responses)
    assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses)
def test_init_with_end_user_sets_role_and_system_user(self):
    """Constructing the pipeline with an EndUser sets END_USER role and uses the session id as system user."""
    app_config = WorkflowUIBasedAppConfig(
        tenant_id="tenant",
        app_id="app",
        app_mode=AppMode.WORKFLOW,
        additional_features=AppAdditionalFeatures(),
        variables=[],
        workflow_id="workflow-id",
    )
    # model_construct skips pydantic validation so only the fields the
    # pipeline constructor actually reads need to be supplied.
    application_generate_entity = WorkflowAppGenerateEntity.model_construct(
        task_id="task",
        app_config=app_config,
        inputs={},
        files=[],
        user_id="end-user-id",
        stream=False,
        invoke_from=InvokeFrom.WEB_APP,
        trace_manager=None,
        workflow_execution_id="run-id",
        extras={},
        call_depth=0,
    )
    workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={})
    queue_manager = SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None)
    end_user = EndUser(tenant_id="tenant", type="session", name="user", session_id="session-id")
    end_user.id = "end-user-id"
    pipeline = WorkflowAppGenerateTaskPipeline(
        application_generate_entity=application_generate_entity,
        workflow=workflow,
        queue_manager=queue_manager,
        user=end_user,
        stream=False,
        draft_var_saver_factory=lambda **kwargs: None,
    )
    assert pipeline._created_by_role == CreatorUserRole.END_USER
    # For end users the system-variable user id is the session id, not the row id.
    assert pipeline._workflow_system_variables.user_id == "session-id"
def test_process_returns_stream_and_blocking_variants(self):
    """process() yields stream chunks when streaming and a blocking response otherwise."""
    pipeline = _make_pipeline()
    # Streaming mode: process() returns an iterator of stream responses.
    pipeline._base_task_pipeline.stream = True
    pipeline._wrapper_process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")])
    stream_response = list(pipeline.process())
    assert len(stream_response) == 1
    # No workflow-start response was seen, so no run id is stamped.
    assert stream_response[0].workflow_run_id is None
    # Blocking mode: process() consumes the stream and returns the final response.
    pipeline._base_task_pipeline.stream = False
    pipeline._wrapper_process_stream_response = lambda **kwargs: iter(
        [
            WorkflowFinishStreamResponse(
                task_id="task",
                workflow_run_id="run-id",
                data=WorkflowFinishStreamResponse.Data(
                    id="run-id",
                    workflow_id="workflow-id",
                    status=WorkflowExecutionStatus.SUCCEEDED,
                    outputs={},
                    error=None,
                    elapsed_time=0.1,
                    total_tokens=0,
                    total_steps=0,
                    created_at=1,
                    finished_at=2,
                ),
            )
        ]
    )
    blocking_response = pipeline.process()
    assert blocking_response.workflow_run_id == "run-id"
def test_to_blocking_response_handles_error_and_unexpected_end(self):
    """Blocking conversion re-raises stream errors and rejects streams that end early."""
    pipeline = _make_pipeline()

    def _error_stream():
        yield ErrorStreamResponse(task_id="task", err=ValueError("boom"))

    with pytest.raises(ValueError, match="boom"):
        pipeline._to_blocking_response(_error_stream())

    def _truncated_stream():
        yield PingStreamResponse(task_id="task")

    with pytest.raises(ValueError, match="queue listening stopped unexpectedly"):
        pipeline._to_blocking_response(_truncated_stream())
def test_to_stream_response_tracks_workflow_run_id(self):
    """Once a workflow-start response is seen, its run id is stamped on every later chunk."""
    pipeline = _make_pipeline()

    def _responses():
        yield WorkflowStartStreamResponse(
            task_id="task",
            workflow_run_id="run-id",
            data=WorkflowStartStreamResponse.Data(
                id="run-id",
                workflow_id="workflow-id",
                inputs={},
                created_at=1,
            ),
        )
        yield PingStreamResponse(task_id="task")

    first, second = list(pipeline._to_stream_response(_responses()))
    assert first.workflow_run_id == "run-id"
    assert second.workflow_run_id == "run-id"
def test_listen_audio_msg_returns_none_without_publisher(self):
    """Without a TTS publisher there is no audio message to emit."""
    pipeline = _make_pipeline()
    result = pipeline._listen_audio_msg(publisher=None, task_id="task")
    assert result is None
def test_wrapper_process_stream_response_without_tts(self):
    """When TTS is not configured the wrapper passes stream responses through untouched."""
    pipeline = _make_pipeline()
    pipeline._workflow_features_dict = {}
    pipeline._process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")])
    emitted = list(pipeline._wrapper_process_stream_response())
    assert emitted == [PingStreamResponse(task_id="task")]
def test_wrapper_process_stream_response_final_audio_none_then_finish(self, monkeypatch):
    """After the main stream ends, the wrapper polls for trailing audio (sleeping on None) until finish."""
    pipeline = _make_pipeline()
    pipeline._workflow_features_dict = {
        "text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"}
    }
    # Empty main stream forces the wrapper straight into the trailing-audio drain loop.
    pipeline._process_stream_response = lambda **kwargs: iter([])
    sleep_spy = []

    # 1st poll -> None (no audio yet, wrapper should sleep), then finish.
    class _Publisher:
        def __init__(self, *args, **kwargs):
            self.calls = 0

        def check_and_get_audio(self):
            self.calls += 1
            if self.calls == 1:
                return None
            return AudioTrunk(status="finish", audio="")

        def publish(self, message):
            _ = message

    # time.time() values drive the drain loop's deadline check; exactly three
    # calls are expected — a StopIteration here would mean the loop polled more.
    time_values = iter([0.0, 0.0, 0.2])
    monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.time.time", lambda: next(time_values))
    monkeypatch.setattr(
        "core.app.apps.workflow.generate_task_pipeline.time.sleep", lambda _: sleep_spy.append(True)
    )
    monkeypatch.setattr(
        "core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher",
        _Publisher,
    )
    responses = list(pipeline._wrapper_process_stream_response())
    assert sleep_spy
    assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses)
def test_wrapper_process_stream_response_handles_audio_exception(self, monkeypatch):
    """An exception while polling trailing audio is logged, and the wrapper still emits audio-end."""
    pipeline = _make_pipeline()
    pipeline._workflow_features_dict = {
        "text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"}
    }
    # Empty main stream puts the wrapper directly into the trailing-audio loop.
    pipeline._process_stream_response = lambda **kwargs: iter([])

    # Raises once on the first poll, then reports finish so the loop terminates.
    class _Publisher:
        def __init__(self, *args, **kwargs):
            self.called = False

        def check_and_get_audio(self):
            if not self.called:
                self.called = True
                raise RuntimeError("tts failure")
            return AudioTrunk(status="finish", audio="")

        def publish(self, message):
            _ = message

    logger_exception = []
    monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.time.time", lambda: 0.0)
    # Capture logger.exception calls to prove the failure was logged, not swallowed silently.
    monkeypatch.setattr(
        "core.app.apps.workflow.generate_task_pipeline.logger.exception",
        lambda *args, **kwargs: logger_exception.append((args, kwargs)),
    )
    monkeypatch.setattr(
        "core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher",
        _Publisher,
    )
    responses = list(pipeline._wrapper_process_stream_response())
    assert logger_exception
    assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses)
def test_database_session_rolls_back_on_error(self, monkeypatch):
    """An exception inside _database_session triggers rollback (and no commit) before re-raising."""
    pipeline = _make_pipeline()
    calls = {"commit": 0, "rollback": 0}

    # Minimal Session stand-in implementing just the context-manager protocol
    # plus the commit/rollback calls the pipeline makes.
    class _Session:
        def __init__(self, *args, **kwargs):
            _ = args, kwargs

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            # Returning False propagates the exception to the caller.
            return False

        def commit(self):
            calls["commit"] += 1

        def rollback(self):
            calls["rollback"] += 1

    monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session)
    monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object()))
    with pytest.raises(RuntimeError, match="db error"):
        with pipeline._database_session():
            raise RuntimeError("db error")
    assert calls["commit"] == 0
    assert calls["rollback"] == 1
def test_node_retry_and_started_handlers_cover_none_and_value(self):
    """Retry/started handlers yield nothing when the converter returns None, else the converted value."""
    pipeline = _make_pipeline()
    pipeline._workflow_execution_id = "run-id"
    retry_event = QueueNodeRetryEvent(
        node_execution_id="exec",
        node_id="node",
        node_title="title",
        node_type=NodeType.LLM,
        node_run_index=1,
        start_at=datetime.utcnow(),
        provider_type="provider",
        provider_id="provider-id",
        error="error",
        retry_index=1,
    )
    started_event = QueueNodeStartedEvent(
        node_execution_id="exec",
        node_id="node",
        node_title="title",
        node_type=NodeType.LLM,
        node_run_index=1,
        start_at=datetime.utcnow(),
        provider_type="provider",
        provider_id="provider-id",
    )
    # None from the converter means the handler emits nothing.
    pipeline._workflow_response_converter.workflow_node_retry_to_stream_response = lambda **kwargs: None
    assert list(pipeline._handle_node_retry_event(retry_event)) == []
    pipeline._workflow_response_converter.workflow_node_retry_to_stream_response = lambda **kwargs: "retry"
    assert list(pipeline._handle_node_retry_event(retry_event)) == ["retry"]
    pipeline._workflow_response_converter.workflow_node_start_to_stream_response = lambda **kwargs: None
    assert list(pipeline._handle_node_started_event(started_event)) == []
    pipeline._workflow_response_converter.workflow_node_start_to_stream_response = lambda **kwargs: "started"
    assert list(pipeline._handle_node_started_event(started_event)) == ["started"]
def test_handle_node_exception_event_saves_output(self):
    """Node-exception events emit the failure response and persist the node's output."""
    pipeline = _make_pipeline()
    persisted: list[str] = []
    pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "failed"
    pipeline._save_output_for_event = lambda event, node_execution_id: persisted.append(node_execution_id)
    exception_event = QueueNodeExceptionEvent(
        node_execution_id="exec-id",
        node_id="node",
        node_type=NodeType.START,
        start_at=datetime.utcnow(),
        inputs={},
        outputs={},
        process_data={},
        error="boom",
    )
    assert list(pipeline._handle_node_failed_events(exception_event)) == ["failed"]
    assert persisted == ["exec-id"]
def test_success_partial_and_pause_handlers(self):
    """Succeeded/partial-success events yield one finish response; paused yields each pause response."""
    pipeline = _make_pipeline()
    pipeline._workflow_execution_id = "run-id"
    pipeline._graph_runtime_state = GraphRuntimeState(
        variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
        start_at=0.0,
    )
    pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
    assert list(pipeline._handle_workflow_succeeded_event(QueueWorkflowSucceededEvent(outputs={}))) == ["finish"]
    assert list(
        pipeline._handle_workflow_partial_success_event(
            QueueWorkflowPartialSuccessEvent(exceptions_count=2, outputs={})
        )
    ) == ["finish"]
    # The pause converter returns a list; the handler is expected to flatten it
    # into individual stream responses.
    pipeline._workflow_response_converter.workflow_pause_to_stream_response = lambda **kwargs: [
        "pause-a",
        "pause-b",
    ]
    pause_event = QueueWorkflowPausedEvent(reasons=[], outputs={}, paused_nodes=["node"])
    assert list(pipeline._handle_workflow_paused_event(pause_event)) == ["pause-a", "pause-b"]
def test_text_chunk_handler_returns_empty_when_text_missing(self):
    """A chunk event carrying no text yields nothing."""
    pipeline = _make_pipeline()
    # model_construct bypasses validation so text=None can be forced in.
    empty_chunk = QueueTextChunkEvent.model_construct(text=None, from_variable_selector=None)
    assert list(pipeline._handle_text_chunk_event(empty_chunk)) == []
def test_dispatch_event_direct_failed_and_unhandled_paths(self):
    """_dispatch_event routes ping and workflow-failed events, and ignores unknown event types."""
    pipeline = _make_pipeline()
    pipeline._workflow_execution_id = "run-id"
    pipeline._graph_runtime_state = GraphRuntimeState(
        variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
        start_at=0.0,
    )
    pipeline._handle_ping_event = lambda event, **kwargs: iter(["ping"])
    assert list(pipeline._dispatch_event(QueuePingEvent())) == ["ping"]
    pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["workflow-failed"])
    failed_event = QueueWorkflowFailedEvent(error="failed", exceptions_count=1)
    assert list(pipeline._dispatch_event(failed_event)) == ["workflow-failed"]
    # An event type with no registered handler yields nothing.
    assert list(pipeline._dispatch_event(SimpleNamespace())) == []
def test_process_stream_response_main_match_paths_and_cleanup(self):
    """The main loop routes started/text/other/error events and publishes a final None to the TTS publisher."""
    pipeline = _make_pipeline()
    pipeline._graph_runtime_state = GraphRuntimeState(
        variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
        start_at=0.0,
    )
    # One queue message per branch of the stream loop's dispatch.
    pipeline._base_task_pipeline.queue_manager.listen = lambda: iter(
        [
            SimpleNamespace(event=QueueWorkflowStartedEvent()),
            SimpleNamespace(event=QueueTextChunkEvent(text="hello")),
            SimpleNamespace(event=QueuePingEvent()),
            SimpleNamespace(event=QueueErrorEvent(error="e")),
        ]
    )
    pipeline._handle_workflow_started_event = lambda event, **kwargs: iter(["started"])
    pipeline._handle_text_chunk_event = lambda event, **kwargs: iter(["text"])
    pipeline._dispatch_event = lambda event, **kwargs: iter(["dispatched"])
    pipeline._handle_error_event = lambda event, **kwargs: iter(["error"])
    publisher_calls: list[object] = []

    class _Publisher:
        def publish(self, message):
            publisher_calls.append(message)

    responses = list(pipeline._process_stream_response(tts_publisher=_Publisher()))
    assert responses == ["started", "text", "dispatched", "error"]
    # On shutdown the loop publishes a single None sentinel to flush the TTS publisher.
    assert publisher_calls == [None]
def test_process_stream_response_break_paths(self):
    """Failed, paused, and stopped events each terminate the stream loop after one response."""
    pipeline = _make_pipeline()
    queue_manager = pipeline._base_task_pipeline.queue_manager

    queue_manager.listen = lambda: iter(
        [SimpleNamespace(event=QueueWorkflowFailedEvent(error="fail", exceptions_count=1))]
    )
    pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["failed"])
    assert list(pipeline._process_stream_response()) == ["failed"]

    queue_manager.listen = lambda: iter(
        [SimpleNamespace(event=QueueWorkflowPausedEvent(reasons=[], outputs={}, paused_nodes=[]))]
    )
    pipeline._handle_workflow_paused_event = lambda event, **kwargs: iter(["paused"])
    assert list(pipeline._process_stream_response()) == ["paused"]

    queue_manager.listen = lambda: iter(
        [SimpleNamespace(event=QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))]
    )
    pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["stopped"])
    assert list(pipeline._process_stream_response()) == ["stopped"]
def test_save_workflow_app_log_covers_invoke_from_variants(self):
    """created_from follows the invoke source; debugger runs and missing run ids skip logging."""
    pipeline = _make_pipeline()
    pipeline._user_id = "user-id"
    recorded: list[object] = []

    class _RecordingSession:
        def add(self, item):
            recorded.append(item)

    session = _RecordingSession()

    pipeline._application_generate_entity.invoke_from = InvokeFrom.EXPLORE
    pipeline._save_workflow_app_log(session=session, workflow_run_id="run-id")
    assert recorded[-1].created_from == "installed-app"

    pipeline._application_generate_entity.invoke_from = InvokeFrom.WEB_APP
    pipeline._save_workflow_app_log(session=session, workflow_run_id="run-id")
    assert recorded[-1].created_from == "web-app"

    baseline = len(recorded)
    pipeline._application_generate_entity.invoke_from = InvokeFrom.DEBUGGER
    pipeline._save_workflow_app_log(session=session, workflow_run_id="run-id")
    assert len(recorded) == baseline

    pipeline._application_generate_entity.invoke_from = InvokeFrom.WEB_APP
    pipeline._save_workflow_app_log(session=session, workflow_run_id=None)
    assert len(recorded) == baseline
def test_save_output_for_event_writes_draft_variables(self, monkeypatch):
    """_save_output_for_event builds a saver from the factory and persists process data and outputs."""
    pipeline = _make_pipeline()
    saver_calls: list[tuple[object, object]] = []
    captured_factory_args: dict[str, object] = {}

    # Saver stub that records exactly what would be written.
    class _Saver:
        def save(self, process_data, outputs):
            saver_calls.append((process_data, outputs))

    # Factory stub capturing the kwargs the pipeline passes when building the saver.
    def _factory(**kwargs):
        captured_factory_args.update(kwargs)
        return _Saver()

    # Stand-in for session.begin()'s transaction context manager.
    class _Begin:
        def __enter__(self):
            return None

        def __exit__(self, exc_type, exc, tb):
            return False

    # Minimal SQLAlchemy Session stand-in: context manager + begin().
    class _Session:
        def __init__(self, *args, **kwargs):
            _ = args, kwargs

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def begin(self):
            return _Begin()

    pipeline._draft_var_saver_factory = _factory
    monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session)
    monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object()))
    event = QueueNodeSucceededEvent(
        node_execution_id="exec-id",
        node_id="node-id",
        node_type=NodeType.START,
        in_loop_id="loop-id",
        start_at=datetime.utcnow(),
        process_data={"k": "v"},
        outputs={"out": 1},
    )
    pipeline._save_output_for_event(event=event, node_execution_id="exec-id")
    assert captured_factory_args["node_execution_id"] == "exec-id"
    # The loop id is forwarded as the enclosing node for draft-variable scoping.
    assert captured_factory_args["enclosing_node_id"] == "loop-id"
    assert saver_calls == [({"k": "v"}, {"out": 1})]