mirror of https://github.com/langgenius/dify.git
test: unit test cases for core.app.apps module (#32482)
This commit is contained in:
parent
44713a5c0f
commit
0045e387f5
|
|
@ -114,7 +114,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
|
|
|
|||
|
|
@ -113,7 +113,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -113,7 +113,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,75 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestAdvancedChatAppConfigManager:
|
||||
def test_get_app_config(self):
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.ADVANCED_CHAT.value)
|
||||
workflow = SimpleNamespace(id="wf-1", features_dict={})
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.WorkflowVariablesConfigManager.convert",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(app_model, workflow)
|
||||
|
||||
assert app_config.workflow_id == "wf-1"
|
||||
assert app_config.app_mode == AppMode.ADVANCED_CHAT
|
||||
|
||||
def test_config_validate_filters_keys(self):
|
||||
def _add_key(key, value):
|
||||
def _inner(*args, **kwargs):
|
||||
config = kwargs.get("config") if kwargs else args[-1]
|
||||
config = {**config, key: value}
|
||||
return config, [key]
|
||||
|
||||
return _inner
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("file_upload", 1),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("opening_statement", 2),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("suggested_questions_after_answer", 3),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("speech_to_text", 4),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("text_to_speech", 5),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("retriever_resource", 6),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("sensitive_word_avoidance", 7),
|
||||
),
|
||||
):
|
||||
filtered = AdvancedChatAppConfigManager.config_validate(tenant_id="t1", config={})
|
||||
|
||||
assert filtered["file_upload"] == 1
|
||||
assert filtered["opening_statement"] == 2
|
||||
assert filtered["suggested_questions_after_answer"] == 3
|
||||
assert filtered["speech_to_text"] == 4
|
||||
assert filtered["text_to_speech"] == 5
|
||||
assert filtered["retriever_resource"] == 6
|
||||
assert filtered["sensitive_word_avoidance"] == 7
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,96 @@
|
|||
from collections.abc import Generator
|
||||
|
||||
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class TestAdvancedChatGenerateResponseConverter:
|
||||
def test_blocking_simple_response_metadata(self):
|
||||
data = ChatbotAppBlockingResponse.Data(
|
||||
id="msg-1",
|
||||
mode="chat",
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
answer="hi",
|
||||
metadata={"usage": {"total_tokens": 1}},
|
||||
created_at=1,
|
||||
)
|
||||
blocking = ChatbotAppBlockingResponse(task_id="t1", data=data)
|
||||
response = AdvancedChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
assert "usage" not in response["metadata"]
|
||||
|
||||
def test_stream_simple_response_includes_node_events(self):
|
||||
node_start = NodeStartStreamResponse(
|
||||
task_id="t1",
|
||||
workflow_run_id="r1",
|
||||
data=NodeStartStreamResponse.Data(
|
||||
id="e1",
|
||||
node_id="n1",
|
||||
node_type="answer",
|
||||
title="Answer",
|
||||
index=1,
|
||||
created_at=1,
|
||||
),
|
||||
)
|
||||
node_finish = NodeFinishStreamResponse(
|
||||
task_id="t1",
|
||||
workflow_run_id="r1",
|
||||
data=NodeFinishStreamResponse.Data(
|
||||
id="e1",
|
||||
node_id="n1",
|
||||
node_type="answer",
|
||||
title="Answer",
|
||||
index=1,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
elapsed_time=0.1,
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
),
|
||||
)
|
||||
|
||||
def stream() -> Generator[ChatbotAppStreamResponse, None, None]:
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=PingStreamResponse(task_id="t1"),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=node_start,
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=node_finish,
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom")),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=MessageEndStreamResponse(task_id="t1", id="m1"),
|
||||
)
|
||||
|
||||
converted = list(AdvancedChatAppGenerateResponseConverter.convert_stream_simple_response(stream()))
|
||||
assert converted[0] == "ping"
|
||||
assert converted[1]["event"] == "node_started"
|
||||
assert converted[2]["event"] == "node_finished"
|
||||
assert converted[3]["event"] == "error"
|
||||
|
|
@ -0,0 +1,600 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
|
||||
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAdvancedChatMessageEndEvent,
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueErrorEvent,
|
||||
QueueHumanInputFormFilledEvent,
|
||||
QueueHumanInputFormTimeoutEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueLoopCompletedEvent,
|
||||
QueueLoopNextEvent,
|
||||
QueueLoopStartEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueuePingEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AnnotationReply,
|
||||
AnnotationReplyAccount,
|
||||
MessageAudioStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
from core.base.tts.app_generator_tts_publisher import AudioTrunk
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from models.enums import MessageStatus
|
||||
from models.model import AppMode, EndUser
|
||||
|
||||
|
||||
def _make_pipeline():
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
app_mode=AppMode.ADVANCED_CHAT,
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
workflow_id="workflow-id",
|
||||
)
|
||||
application_generate_entity = AdvancedChatAppGenerateEntity.model_construct(
|
||||
task_id="task",
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="hello",
|
||||
files=[],
|
||||
user_id="user",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
extras={},
|
||||
trace_manager=None,
|
||||
workflow_run_id="run-id",
|
||||
)
|
||||
|
||||
message = SimpleNamespace(
|
||||
id="message-id",
|
||||
query="hello",
|
||||
created_at=datetime.utcnow(),
|
||||
status=MessageStatus.NORMAL,
|
||||
answer="",
|
||||
)
|
||||
conversation = SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT)
|
||||
workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={})
|
||||
user = EndUser(tenant_id="tenant", type="session", name="tester", session_id="session")
|
||||
|
||||
pipeline = AdvancedChatAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=False,
|
||||
dialogue_count=1,
|
||||
draft_var_saver_factory=lambda **kwargs: None,
|
||||
)
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
class TestAdvancedChatGenerateTaskPipeline:
|
||||
def test_ensure_workflow_initialized_raises(self):
|
||||
pipeline = _make_pipeline()
|
||||
|
||||
with pytest.raises(ValueError, match="workflow run not initialized"):
|
||||
pipeline._ensure_workflow_initialized()
|
||||
|
||||
def test_to_blocking_response_returns_message_end(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._task_state.answer = "done"
|
||||
|
||||
def _gen():
|
||||
yield MessageEndStreamResponse(task_id="task", id="message-id", metadata={"k": "v"})
|
||||
|
||||
response = pipeline._to_blocking_response(_gen())
|
||||
|
||||
assert response.data.answer == "done"
|
||||
assert response.data.metadata == {"k": "v"}
|
||||
|
||||
def test_handle_text_chunk_event_updates_state(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._message_cycle_manager = SimpleNamespace(
|
||||
message_to_stream_response=lambda **kwargs: MessageEndStreamResponse(
|
||||
task_id="task", id="message-id", metadata={}
|
||||
)
|
||||
)
|
||||
|
||||
event = SimpleNamespace(text="hi", from_variable_selector=None)
|
||||
|
||||
responses = list(pipeline._handle_text_chunk_event(event))
|
||||
|
||||
assert pipeline._task_state.answer == "hi"
|
||||
assert responses
|
||||
|
||||
def test_listen_audio_msg_returns_audio_stream(self):
|
||||
pipeline = _make_pipeline()
|
||||
publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk(status="stream", audio="data"))
|
||||
|
||||
response = pipeline._listen_audio_msg(publisher=publisher, task_id="task")
|
||||
|
||||
assert isinstance(response, MessageAudioStreamResponse)
|
||||
|
||||
def test_handle_ping_event(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._base_task_pipeline.ping_stream_response = lambda: PingStreamResponse(task_id="task")
|
||||
|
||||
responses = list(pipeline._handle_ping_event(QueuePingEvent()))
|
||||
|
||||
assert isinstance(responses[0], PingStreamResponse)
|
||||
|
||||
def test_handle_error_event(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom")
|
||||
pipeline._base_task_pipeline.error_to_stream_response = lambda err: err
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace()
|
||||
|
||||
pipeline._database_session = _fake_session
|
||||
|
||||
responses = list(pipeline._handle_error_event(QueueErrorEvent(error=ValueError("boom"))))
|
||||
|
||||
assert isinstance(responses[0], ValueError)
|
||||
|
||||
def test_handle_workflow_started_event_sets_run_id(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started"
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace()
|
||||
|
||||
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
|
||||
monkeypatch.setattr(pipeline, "_get_message", lambda **kwargs: SimpleNamespace())
|
||||
|
||||
responses = list(pipeline._handle_workflow_started_event(QueueWorkflowStartedEvent()))
|
||||
|
||||
assert pipeline._workflow_run_id == "run-id"
|
||||
assert responses == ["started"]
|
||||
|
||||
def test_message_end_to_stream_response_strips_annotation_reply(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._task_state.metadata.annotation_reply = AnnotationReply(
|
||||
id="ann",
|
||||
account=AnnotationReplyAccount(id="acc", name="acc"),
|
||||
)
|
||||
|
||||
response = pipeline._message_end_to_stream_response()
|
||||
|
||||
assert "annotation_reply" not in response.metadata
|
||||
|
||||
def test_handle_output_moderation_chunk_publishes_stop(self):
|
||||
pipeline = _make_pipeline()
|
||||
events: list[object] = []
|
||||
|
||||
class _Moderation:
|
||||
def should_direct_output(self):
|
||||
return True
|
||||
|
||||
def get_final_output(self):
|
||||
return "final"
|
||||
|
||||
pipeline._base_task_pipeline.output_moderation_handler = _Moderation()
|
||||
pipeline._base_task_pipeline.queue_manager = SimpleNamespace(
|
||||
publish=lambda event, pub_from: events.append(event)
|
||||
)
|
||||
|
||||
result = pipeline._handle_output_moderation_chunk("ignored")
|
||||
|
||||
assert result is True
|
||||
assert pipeline._task_state.answer == "final"
|
||||
assert any(isinstance(event, QueueTextChunkEvent) for event in events)
|
||||
assert any(isinstance(event, QueueStopEvent) for event in events)
|
||||
|
||||
def test_handle_node_succeeded_event_records_files(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_response_converter.fetch_files_from_node_outputs = lambda outputs: [
|
||||
{"type": "file", "transfer_method": "local"}
|
||||
]
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done"
|
||||
pipeline._save_output_for_event = lambda event, node_execution_id: None
|
||||
|
||||
event = SimpleNamespace(
|
||||
node_type=NodeType.ANSWER,
|
||||
outputs={"k": "v"},
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
)
|
||||
|
||||
responses = list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
assert responses == ["done"]
|
||||
assert pipeline._recorded_files
|
||||
|
||||
def test_iteration_and_loop_handlers(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_run_id = "run-id"
|
||||
pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = (
|
||||
lambda **kwargs: "iter_start"
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "iter_next"
|
||||
pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = (
|
||||
lambda **kwargs: "iter_done"
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop_start"
|
||||
pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next"
|
||||
pipeline._workflow_response_converter.workflow_loop_completed_to_stream_response = lambda **kwargs: "loop_done"
|
||||
|
||||
iter_start = QueueIterationStartEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
iter_next = QueueIterationNextEvent(
|
||||
index=1,
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
node_run_index=1,
|
||||
)
|
||||
iter_done = QueueIterationCompletedEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
loop_start = QueueLoopStartEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
loop_next = QueueLoopNextEvent(
|
||||
index=1,
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
node_run_index=1,
|
||||
)
|
||||
loop_done = QueueLoopCompletedEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
|
||||
assert list(pipeline._handle_iteration_start_event(iter_start)) == ["iter_start"]
|
||||
assert list(pipeline._handle_iteration_next_event(iter_next)) == ["iter_next"]
|
||||
assert list(pipeline._handle_iteration_completed_event(iter_done)) == ["iter_done"]
|
||||
assert list(pipeline._handle_loop_start_event(loop_start)) == ["loop_start"]
|
||||
assert list(pipeline._handle_loop_next_event(loop_next)) == ["loop_next"]
|
||||
assert list(pipeline._handle_loop_completed_event(loop_done)) == ["loop_done"]
|
||||
|
||||
def test_workflow_finish_handlers(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_run_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
|
||||
pipeline._workflow_response_converter.workflow_pause_to_stream_response = lambda **kwargs: ["pause"]
|
||||
pipeline._persist_human_input_extra_content = lambda **kwargs: None
|
||||
pipeline._save_message = lambda **kwargs: None
|
||||
pipeline._base_task_pipeline.queue_manager.publish = lambda *args, **kwargs: None
|
||||
pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom")
|
||||
pipeline._base_task_pipeline.error_to_stream_response = lambda err: err
|
||||
pipeline._get_message = lambda **kwargs: SimpleNamespace(id="message-id")
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace(scalar=lambda *args, **kwargs: None)
|
||||
|
||||
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
|
||||
|
||||
succeeded_responses = list(pipeline._handle_workflow_succeeded_event(QueueWorkflowSucceededEvent(outputs={})))
|
||||
assert len(succeeded_responses) == 2
|
||||
assert isinstance(succeeded_responses[0], MessageEndStreamResponse)
|
||||
assert succeeded_responses[1] == "finish"
|
||||
|
||||
partial_success_responses = list(
|
||||
pipeline._handle_workflow_partial_success_event(
|
||||
QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={})
|
||||
)
|
||||
)
|
||||
assert len(partial_success_responses) == 2
|
||||
assert isinstance(partial_success_responses[0], MessageEndStreamResponse)
|
||||
assert partial_success_responses[1] == "finish"
|
||||
assert (
|
||||
list(pipeline._handle_workflow_failed_event(QueueWorkflowFailedEvent(error="err", exceptions_count=1)))[0]
|
||||
== "finish"
|
||||
)
|
||||
assert list(pipeline._handle_workflow_paused_event(QueueWorkflowPausedEvent(reasons=[], outputs={}))) == [
|
||||
"pause"
|
||||
]
|
||||
|
||||
def test_node_failure_handlers(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "node_finish"
|
||||
pipeline._save_output_for_event = lambda event, node_execution_id: None
|
||||
|
||||
failed_event = QueueNodeFailedEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
process_data={},
|
||||
error="err",
|
||||
)
|
||||
exc_event = QueueNodeExceptionEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
process_data={},
|
||||
error="err",
|
||||
)
|
||||
|
||||
assert list(pipeline._handle_node_failed_events(failed_event)) == ["node_finish"]
|
||||
assert list(pipeline._handle_node_failed_events(exc_event)) == ["node_finish"]
|
||||
|
||||
def test_handle_text_chunk_event_tracks_streaming_metrics(self):
|
||||
pipeline = _make_pipeline()
|
||||
published: list[object] = []
|
||||
|
||||
class _Publisher:
|
||||
def publish(self, message):
|
||||
published.append(message)
|
||||
|
||||
pipeline._message_cycle_manager = SimpleNamespace(message_to_stream_response=lambda **kwargs: "chunk")
|
||||
|
||||
event = SimpleNamespace(text="hi", from_variable_selector=["a"])
|
||||
queue_message = SimpleNamespace(event=event)
|
||||
|
||||
responses = list(
|
||||
pipeline._handle_text_chunk_event(event, tts_publisher=_Publisher(), queue_message=queue_message)
|
||||
)
|
||||
|
||||
assert responses == ["chunk"]
|
||||
assert pipeline._task_state.is_streaming_response is True
|
||||
assert pipeline._task_state.first_token_time is not None
|
||||
assert pipeline._task_state.last_token_time is not None
|
||||
assert pipeline._task_state.answer == "hi"
|
||||
assert published == [queue_message]
|
||||
|
||||
def test_handle_output_moderation_chunk_appends_token(self):
|
||||
pipeline = _make_pipeline()
|
||||
seen: list[str] = []
|
||||
|
||||
class _Moderation:
|
||||
def should_direct_output(self):
|
||||
return False
|
||||
|
||||
def append_new_token(self, text):
|
||||
seen.append(text)
|
||||
|
||||
pipeline._base_task_pipeline.output_moderation_handler = _Moderation()
|
||||
|
||||
result = pipeline._handle_output_moderation_chunk("token")
|
||||
|
||||
assert result is False
|
||||
assert seen == ["token"]
|
||||
|
||||
def test_handle_retriever_and_annotation_events(self):
|
||||
pipeline = _make_pipeline()
|
||||
calls = {"retriever": 0, "annotation": 0}
|
||||
|
||||
def _hit_retriever(event):
|
||||
calls["retriever"] += 1
|
||||
|
||||
def _hit_annotation(event):
|
||||
calls["annotation"] += 1
|
||||
|
||||
pipeline._message_cycle_manager.handle_retriever_resources = _hit_retriever
|
||||
pipeline._message_cycle_manager.handle_annotation_reply = _hit_annotation
|
||||
|
||||
retriever_event = QueueRetrieverResourcesEvent(retriever_resources=[])
|
||||
annotation_event = QueueAnnotationReplyEvent(message_annotation_id="ann")
|
||||
|
||||
assert list(pipeline._handle_retriever_resources_event(retriever_event)) == []
|
||||
assert list(pipeline._handle_annotation_reply_event(annotation_event)) == []
|
||||
assert calls == {"retriever": 1, "annotation": 1}
|
||||
|
||||
def test_handle_message_replace_event(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace"
|
||||
|
||||
event = QueueMessageReplaceEvent(
|
||||
text="new",
|
||||
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
|
||||
)
|
||||
|
||||
assert list(pipeline._handle_message_replace_event(event)) == ["replace"]
|
||||
|
||||
def test_handle_human_input_events(self):
|
||||
pipeline = _make_pipeline()
|
||||
persisted: list[str] = []
|
||||
pipeline._persist_human_input_extra_content = lambda **kwargs: persisted.append("saved")
|
||||
pipeline._workflow_response_converter.human_input_form_filled_to_stream_response = lambda **kwargs: "filled"
|
||||
pipeline._workflow_response_converter.human_input_form_timeout_to_stream_response = lambda **kwargs: "timeout"
|
||||
|
||||
filled_event = QueueHumanInputFormFilledEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="title",
|
||||
rendered_content="content",
|
||||
action_id="action",
|
||||
action_text="action",
|
||||
)
|
||||
timeout_event = QueueHumanInputFormTimeoutEvent(
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="title",
|
||||
expiration_time=datetime.utcnow(),
|
||||
)
|
||||
|
||||
assert list(pipeline._handle_human_input_form_filled_event(filled_event)) == ["filled"]
|
||||
assert list(pipeline._handle_human_input_form_timeout_event(timeout_event)) == ["timeout"]
|
||||
assert persisted == ["saved"]
|
||||
|
||||
def test_save_message_strips_markdown_and_sets_usage(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._recorded_files = [
|
||||
{
|
||||
"type": "image",
|
||||
"transfer_method": "remote",
|
||||
"remote_url": "http://example.com/file.png",
|
||||
"related_id": "file-id",
|
||||
}
|
||||
]
|
||||
pipeline._task_state.answer = " hello"
|
||||
pipeline._task_state.is_streaming_response = True
|
||||
pipeline._task_state.first_token_time = pipeline._base_task_pipeline.start_at + 0.1
|
||||
pipeline._task_state.last_token_time = pipeline._base_task_pipeline.start_at + 0.2
|
||||
|
||||
message = SimpleNamespace(
|
||||
id="message-id",
|
||||
status=MessageStatus.PAUSED,
|
||||
answer="",
|
||||
updated_at=None,
|
||||
provider_response_latency=None,
|
||||
message_tokens=None,
|
||||
message_unit_price=None,
|
||||
message_price_unit=None,
|
||||
answer_tokens=None,
|
||||
answer_unit_price=None,
|
||||
answer_price_unit=None,
|
||||
total_price=None,
|
||||
currency=None,
|
||||
message_metadata=None,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
from_account_id=None,
|
||||
from_end_user_id="end-user",
|
||||
)
|
||||
|
||||
class _Session:
|
||||
def scalar(self, *args, **kwargs):
|
||||
return message
|
||||
|
||||
def add_all(self, items):
|
||||
self.items = items
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
pipeline._save_message(session=_Session(), graph_runtime_state=graph_runtime_state)
|
||||
|
||||
assert message.status == MessageStatus.NORMAL
|
||||
assert message.answer == "hello"
|
||||
assert message.message_metadata
|
||||
|
||||
def test_handle_stop_event_saves_message_for_moderation(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._message_end_to_stream_response = lambda: "end"
|
||||
saved: list[str] = []
|
||||
|
||||
def _save_message(**kwargs):
|
||||
saved.append("saved")
|
||||
|
||||
pipeline._save_message = _save_message
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace()
|
||||
|
||||
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
|
||||
|
||||
responses = list(pipeline._handle_stop_event(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)))
|
||||
|
||||
assert responses == ["end"]
|
||||
assert saved == ["saved"]
|
||||
|
||||
def test_handle_message_end_event_applies_output_moderation(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe"
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace"
|
||||
pipeline._message_end_to_stream_response = lambda: "end"
|
||||
|
||||
saved: list[str] = []
|
||||
|
||||
def _save_message(**kwargs):
|
||||
saved.append("saved")
|
||||
|
||||
pipeline._save_message = _save_message
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace()
|
||||
|
||||
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
|
||||
|
||||
responses = list(pipeline._handle_advanced_chat_message_end_event(QueueAdvancedChatMessageEndEvent()))
|
||||
|
||||
assert responses == ["replace", "end"]
|
||||
assert saved == ["saved"]
|
||||
|
||||
def test_dispatch_event_handles_node_exception(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "failed"
|
||||
pipeline._save_output_for_event = lambda *args, **kwargs: None
|
||||
|
||||
event = QueueNodeExceptionEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
process_data={},
|
||||
error="err",
|
||||
)
|
||||
|
||||
assert list(pipeline._dispatch_event(event)) == ["failed"]
|
||||
|
|
@ -0,0 +1,302 @@
|
|||
import uuid
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
|
||||
from core.app.apps.agent_chat.app_config_manager import (
|
||||
AgentChatAppConfigManager,
|
||||
)
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
|
||||
|
||||
class TestAgentChatAppConfigManagerGetAppConfig:
|
||||
def test_get_app_config_override_config(self, mocker):
|
||||
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
|
||||
app_model_config = mocker.MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"ignored": True}
|
||||
|
||||
override_config = {"model": {"provider": "p"}}
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert")
|
||||
mocker.patch.object(AgentChatAppConfigManager, "convert_features")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert",
|
||||
return_value=("variables", "external"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig",
|
||||
side_effect=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
result = AgentChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=None,
|
||||
override_config_dict=override_config,
|
||||
)
|
||||
|
||||
assert result.app_model_config_dict == override_config
|
||||
assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS
|
||||
assert result.variables == "variables"
|
||||
assert result.external_data_variables == "external"
|
||||
|
||||
def test_get_app_config_conversation_specific(self, mocker):
|
||||
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
|
||||
app_model_config = mocker.MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
|
||||
conversation = mocker.MagicMock()
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert")
|
||||
mocker.patch.object(AgentChatAppConfigManager, "convert_features")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert",
|
||||
return_value=("variables", "external"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig",
|
||||
side_effect=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
result = AgentChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=conversation,
|
||||
override_config_dict=None,
|
||||
)
|
||||
|
||||
assert result.app_model_config_dict == app_model_config.to_dict.return_value
|
||||
assert result.app_model_config_from.value == "conversation-specific-config"
|
||||
|
||||
def test_get_app_config_latest_config(self, mocker):
|
||||
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
|
||||
app_model_config = mocker.MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert")
|
||||
mocker.patch.object(AgentChatAppConfigManager, "convert_features")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert",
|
||||
return_value=("variables", "external"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig",
|
||||
side_effect=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
result = AgentChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=None,
|
||||
override_config_dict=None,
|
||||
)
|
||||
|
||||
assert result.app_model_config_from.value == "app-latest-config"
|
||||
|
||||
|
||||
class TestAgentChatAppConfigManagerConfigValidate:
|
||||
def test_config_validate_filters_related_keys(self, mocker):
|
||||
config = {
|
||||
"model": {},
|
||||
"user_input_form": {},
|
||||
"file_upload": {},
|
||||
"prompt_template": {},
|
||||
"agent_mode": {},
|
||||
"opening_statement": {},
|
||||
"suggested_questions_after_answer": {},
|
||||
"speech_to_text": {},
|
||||
"text_to_speech": {},
|
||||
"retriever_resource": {},
|
||||
"dataset": {},
|
||||
"moderation": {},
|
||||
"extra": "value",
|
||||
}
|
||||
|
||||
def return_with_key(key):
|
||||
return config, [key]
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.ModelConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda tenant_id, cfg: return_with_key("model"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda tenant_id, cfg: return_with_key("user_input_form"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("file_upload"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda app_mode, cfg: return_with_key("prompt_template"),
|
||||
)
|
||||
mocker.patch.object(
|
||||
AgentChatAppConfigManager,
|
||||
"validate_agent_mode_and_set_defaults",
|
||||
side_effect=lambda tenant_id, cfg: return_with_key("agent_mode"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("opening_statement"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("suggested_questions_after_answer"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("speech_to_text"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("text_to_speech"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("retriever_resource"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda tenant_id, app_mode, cfg: return_with_key("dataset"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda tenant_id, cfg: return_with_key("moderation"),
|
||||
)
|
||||
|
||||
filtered = AgentChatAppConfigManager.config_validate("tenant", config)
|
||||
assert set(filtered.keys()) == {
|
||||
"model",
|
||||
"user_input_form",
|
||||
"file_upload",
|
||||
"prompt_template",
|
||||
"agent_mode",
|
||||
"opening_statement",
|
||||
"suggested_questions_after_answer",
|
||||
"speech_to_text",
|
||||
"text_to_speech",
|
||||
"retriever_resource",
|
||||
"dataset",
|
||||
"moderation",
|
||||
}
|
||||
assert "extra" not in filtered
|
||||
|
||||
|
||||
class TestValidateAgentModeAndSetDefaults:
|
||||
def test_defaults_when_missing(self):
|
||||
config = {}
|
||||
updated, keys = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config)
|
||||
assert "agent_mode" in updated
|
||||
assert updated["agent_mode"]["enabled"] is False
|
||||
assert updated["agent_mode"]["tools"] == []
|
||||
assert keys == ["agent_mode"]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_mode",
|
||||
["invalid", 123],
|
||||
)
|
||||
def test_agent_mode_type_validation(self, agent_mode):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", {"agent_mode": agent_mode})
|
||||
|
||||
def test_agent_mode_empty_list_defaults(self):
|
||||
config = {"agent_mode": []}
|
||||
updated, _ = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config)
|
||||
assert updated["agent_mode"]["enabled"] is False
|
||||
assert updated["agent_mode"]["tools"] == []
|
||||
|
||||
def test_enabled_must_be_bool(self):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", {"agent_mode": {"enabled": "yes"}})
|
||||
|
||||
def test_strategy_must_be_valid(self):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant", {"agent_mode": {"enabled": True, "strategy": "invalid"}}
|
||||
)
|
||||
|
||||
def test_tools_must_be_list(self):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant", {"agent_mode": {"enabled": True, "tools": "not-list"}}
|
||||
)
|
||||
|
||||
def test_old_tool_dataset_requires_id(self):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant", {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True}}]}}
|
||||
)
|
||||
|
||||
def test_old_tool_dataset_id_must_be_uuid(self):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant",
|
||||
{"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True, "id": "bad"}}]}},
|
||||
)
|
||||
|
||||
def test_old_tool_dataset_id_not_exists(self, mocker):
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.is_dataset_exists",
|
||||
return_value=False,
|
||||
)
|
||||
dataset_id = str(uuid.uuid4())
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant",
|
||||
{"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True, "id": dataset_id}}]}},
|
||||
)
|
||||
|
||||
def test_old_tool_enabled_must_be_bool(self):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant",
|
||||
{"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": "yes", "id": str(uuid.uuid4())}}]}},
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("missing_key", ["provider_type", "provider_id", "tool_name", "tool_parameters"])
|
||||
def test_new_style_tool_requires_fields(self, missing_key):
|
||||
tool = {"enabled": True, "provider_type": "type", "provider_id": "id", "tool_name": "tool"}
|
||||
tool.pop(missing_key, None)
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant", {"agent_mode": {"enabled": True, "tools": [tool]}}
|
||||
)
|
||||
|
||||
def test_valid_old_and_new_style_tools(self, mocker):
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.is_dataset_exists",
|
||||
return_value=True,
|
||||
)
|
||||
dataset_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"agent_mode": {
|
||||
"enabled": True,
|
||||
"strategy": PlanningStrategy.ROUTER.value,
|
||||
"tools": [
|
||||
{"dataset": {"id": dataset_id}},
|
||||
{
|
||||
"provider_type": "builtin",
|
||||
"provider_id": "p1",
|
||||
"tool_name": "tool",
|
||||
"tool_parameters": {},
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
updated, _ = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config)
|
||||
assert updated["agent_mode"]["tools"][0]["dataset"]["enabled"] is False
|
||||
assert updated["agent_mode"]["tools"][1]["enabled"] is False
|
||||
|
|
@ -0,0 +1,296 @@
|
|||
import contextlib
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
|
||||
class DummyAccount:
|
||||
def __init__(self, user_id):
|
||||
self.id = user_id
|
||||
self.session_id = f"session-{user_id}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generator(mocker):
|
||||
gen = AgentChatAppGenerator()
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.current_app",
|
||||
new=mocker.MagicMock(_get_current_object=mocker.MagicMock()),
|
||||
)
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.contextvars.copy_context", return_value="ctx")
|
||||
return gen
|
||||
|
||||
|
||||
class TestAgentChatAppGeneratorGenerate:
|
||||
def test_generate_rejects_blocking_mode(self, generator, mocker):
|
||||
app_model = mocker.MagicMock()
|
||||
user = DummyAccount("user")
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(app_model=app_model, user=user, args={}, invoke_from=mocker.MagicMock(), streaming=False)
|
||||
|
||||
def test_generate_requires_query(self, generator, mocker):
|
||||
app_model = mocker.MagicMock()
|
||||
user = DummyAccount("user")
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(app_model=app_model, user=user, args={"inputs": {}}, invoke_from=mocker.MagicMock())
|
||||
|
||||
def test_generate_rejects_non_string_query(self, generator, mocker):
|
||||
app_model = mocker.MagicMock()
|
||||
user = DummyAccount("user")
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args={"query": 123, "inputs": {}},
|
||||
invoke_from=mocker.MagicMock(),
|
||||
)
|
||||
|
||||
def test_generate_override_requires_debugger(self, generator, mocker):
|
||||
app_model = mocker.MagicMock()
|
||||
user = DummyAccount("user")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args={"query": "hi", "inputs": {}, "model_config": {"model": {"provider": "p"}}},
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
def test_generate_success_with_debugger_override(self, generator, mocker):
|
||||
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
|
||||
app_model_config = mocker.MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
|
||||
|
||||
user = DummyAccount("user")
|
||||
invoke_from = InvokeFrom.DEBUGGER
|
||||
|
||||
generator._get_app_model_config = mocker.MagicMock(return_value=app_model_config)
|
||||
generator._prepare_user_inputs = mocker.MagicMock(return_value={"x": 1})
|
||||
generator._init_generate_records = mocker.MagicMock(
|
||||
return_value=(mocker.MagicMock(id="conv", mode="agent-chat"), mocker.MagicMock(id="msg"))
|
||||
)
|
||||
generator._handle_response = mocker.MagicMock(return_value="response")
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.config_validate",
|
||||
return_value={"validated": True},
|
||||
)
|
||||
app_config = mocker.MagicMock(variables={}, prompt_template=mocker.MagicMock(), external_data_variables=[])
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.get_app_config",
|
||||
return_value=app_config,
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.ModelConfigConverter.convert",
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.FileUploadConfigManager.convert",
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.file_factory.build_from_mappings",
|
||||
return_value=["file-obj"],
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.ConversationService.get_conversation",
|
||||
return_value=mocker.MagicMock(id="conv"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.TraceQueueManager",
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
|
||||
queue_manager = mocker.MagicMock()
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.MessageBasedAppQueueManager",
|
||||
return_value=queue_manager,
|
||||
)
|
||||
|
||||
thread_obj = mocker.MagicMock()
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.threading.Thread",
|
||||
return_value=thread_obj,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppGenerateResponseConverter.convert",
|
||||
return_value={"result": "ok"},
|
||||
)
|
||||
app_entity = mocker.MagicMock(task_id="task", user_id="user", invoke_from=invoke_from)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppGenerateEntity",
|
||||
return_value=app_entity,
|
||||
)
|
||||
|
||||
args = {
|
||||
"query": "hello",
|
||||
"inputs": {"name": "world"},
|
||||
"conversation_id": "conv",
|
||||
"model_config": {"model": {"provider": "p"}},
|
||||
"files": [{"id": "f1"}],
|
||||
}
|
||||
|
||||
result = generator.generate(app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=True)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
thread_obj.start.assert_called_once()
|
||||
|
||||
def test_generate_without_file_config(self, generator, mocker):
|
||||
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
|
||||
app_model_config = mocker.MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
|
||||
|
||||
user = DummyAccount("user")
|
||||
|
||||
generator._get_app_model_config = mocker.MagicMock(return_value=app_model_config)
|
||||
generator._prepare_user_inputs = mocker.MagicMock(return_value={"x": 1})
|
||||
generator._init_generate_records = mocker.MagicMock(
|
||||
return_value=(mocker.MagicMock(id="conv", mode="agent-chat"), mocker.MagicMock(id="msg"))
|
||||
)
|
||||
generator._handle_response = mocker.MagicMock(return_value="response")
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.get_app_config",
|
||||
return_value=mocker.MagicMock(variables={}, prompt_template=mocker.MagicMock(), external_data_variables=[]),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.ModelConfigConverter.convert",
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.FileUploadConfigManager.convert",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.file_factory.build_from_mappings",
|
||||
return_value=["file-obj"],
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.TraceQueueManager",
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.MessageBasedAppQueueManager",
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
|
||||
thread_obj = mocker.MagicMock()
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.threading.Thread",
|
||||
return_value=thread_obj,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppGenerateResponseConverter.convert",
|
||||
return_value={"result": "ok"},
|
||||
)
|
||||
app_entity = mocker.MagicMock(task_id="task", user_id="user", invoke_from=InvokeFrom.WEB_APP)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppGenerateEntity",
|
||||
return_value=app_entity,
|
||||
)
|
||||
|
||||
args = {"query": "hello", "inputs": {"name": "world"}}
|
||||
|
||||
result = generator.generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
|
||||
class TestAgentChatAppGeneratorWorker:
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_context(self, mocker):
|
||||
@contextlib.contextmanager
|
||||
def ctx_manager(*args, **kwargs):
|
||||
yield
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.preserve_flask_contexts", ctx_manager)
|
||||
|
||||
def test_generate_worker_handles_generate_task_stopped(self, generator, mocker):
|
||||
queue_manager = mocker.MagicMock()
|
||||
generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
|
||||
generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())
|
||||
|
||||
runner = mocker.MagicMock()
|
||||
runner.run.side_effect = GenerateTaskStoppedError()
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner)
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close")
|
||||
|
||||
generator._generate_worker(
|
||||
flask_app=mocker.MagicMock(),
|
||||
context=mocker.MagicMock(),
|
||||
application_generate_entity=mocker.MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
)
|
||||
|
||||
queue_manager.publish_error.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"error",
|
||||
[
|
||||
InvokeAuthorizationError("bad"),
|
||||
ValidationError.from_exception_data("TestModel", []),
|
||||
ValueError("bad"),
|
||||
Exception("bad"),
|
||||
],
|
||||
)
|
||||
def test_generate_worker_publishes_errors(self, generator, mocker, error):
|
||||
queue_manager = mocker.MagicMock()
|
||||
generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
|
||||
generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())
|
||||
|
||||
runner = mocker.MagicMock()
|
||||
runner.run.side_effect = error
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner)
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close")
|
||||
|
||||
generator._generate_worker(
|
||||
flask_app=mocker.MagicMock(),
|
||||
context=mocker.MagicMock(),
|
||||
application_generate_entity=mocker.MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
)
|
||||
|
||||
assert queue_manager.publish_error.called
|
||||
|
||||
def test_generate_worker_logs_value_error_when_debug(self, generator, mocker):
|
||||
queue_manager = mocker.MagicMock()
|
||||
generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
|
||||
generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())
|
||||
|
||||
runner = mocker.MagicMock()
|
||||
runner.run.side_effect = ValueError("bad")
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner)
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close")
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.dify_config", new=mocker.MagicMock(DEBUG=True))
|
||||
logger = mocker.patch("core.app.apps.agent_chat.app_generator.logger")
|
||||
|
||||
generator._generate_worker(
|
||||
flask_app=mocker.MagicMock(),
|
||||
context=mocker.MagicMock(),
|
||||
application_generate_entity=mocker.MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
)
|
||||
|
||||
logger.exception.assert_called_once()
|
||||
|
|
@ -0,0 +1,413 @@
|
|||
import pytest
|
||||
|
||||
from core.agent.entities import AgentEntity
|
||||
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
|
||||
from core.moderation.base import ModerationError
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMMode
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
return AgentChatAppRunner()
|
||||
|
||||
|
||||
class TestAgentChatAppRunnerRun:
|
||||
def test_run_app_not_found(self, runner, mocker):
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", agent=mocker.MagicMock())
|
||||
generate_entity = mocker.MagicMock(app_config=app_config, inputs={}, query="q", files=[], stream=True)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
|
||||
|
||||
def test_run_moderation_error_direct_output(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = mocker.MagicMock()
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(),
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", side_effect=ModerationError("bad"))
|
||||
mocker.patch.object(runner, "direct_output")
|
||||
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
|
||||
|
||||
runner.direct_output.assert_called_once()
|
||||
|
||||
def test_run_annotation_reply_short_circuits(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = mocker.MagicMock()
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(),
|
||||
conversation_id=None,
|
||||
user_id="user",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
annotation = mocker.MagicMock(id="anno", content="answer")
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=annotation)
|
||||
mocker.patch.object(runner, "direct_output")
|
||||
|
||||
queue_manager = mocker.MagicMock()
|
||||
runner.run(generate_entity, queue_manager, mocker.MagicMock(), mocker.MagicMock())
|
||||
|
||||
queue_manager.publish.assert_called_once()
|
||||
runner.direct_output.assert_called_once()
|
||||
|
||||
def test_run_hosting_moderation_short_circuits(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = mocker.MagicMock()
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(),
|
||||
conversation_id=None,
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=True)
|
||||
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
|
||||
|
||||
def test_run_model_schema_missing(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
llm_instance = mocker.MagicMock()
|
||||
llm_instance.model_type_instance.get_model_schema.return_value = None
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mode", "expected_runner"),
|
||||
[
|
||||
(LLMMode.CHAT, "CotChatAgentRunner"),
|
||||
(LLMMode.COMPLETION, "CotCompletionAgentRunner"),
|
||||
],
|
||||
)
|
||||
def test_run_chain_of_thought_modes(self, runner, mocker, mode, expected_runner):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
model_schema = mocker.MagicMock()
|
||||
model_schema.features = []
|
||||
model_schema.model_properties = {ModelPropertyKey.MODE: mode}
|
||||
|
||||
llm_instance = mocker.MagicMock()
|
||||
llm_instance.model_type_instance.get_model_schema.return_value = model_schema
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
|
||||
|
||||
conversation = mocker.MagicMock(id="conv")
|
||||
message = mocker.MagicMock(id="msg")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
|
||||
runner_cls = mocker.MagicMock()
|
||||
mocker.patch(f"core.app.apps.agent_chat.app_runner.{expected_runner}", runner_cls)
|
||||
|
||||
runner_instance = mocker.MagicMock()
|
||||
runner_cls.return_value = runner_instance
|
||||
runner_instance.run.return_value = []
|
||||
mocker.patch.object(runner, "_handle_invoke_result")
|
||||
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
|
||||
runner_instance.run.assert_called_once()
|
||||
runner._handle_invoke_result.assert_called_once()
|
||||
|
||||
def test_run_invalid_llm_mode_raises(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
model_schema = mocker.MagicMock()
|
||||
model_schema.features = []
|
||||
model_schema.model_properties = {ModelPropertyKey.MODE: "invalid"}
|
||||
|
||||
llm_instance = mocker.MagicMock()
|
||||
llm_instance.model_type_instance.get_model_schema.return_value = model_schema
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
|
||||
|
||||
conversation = mocker.MagicMock(id="conv")
|
||||
message = mocker.MagicMock(id="msg")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
|
||||
def test_run_function_calling_strategy_selected_by_features(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
model_schema = mocker.MagicMock()
|
||||
model_schema.features = [ModelFeature.TOOL_CALL]
|
||||
model_schema.model_properties = {ModelPropertyKey.MODE: LLMMode.CHAT}
|
||||
|
||||
llm_instance = mocker.MagicMock()
|
||||
llm_instance.model_type_instance.get_model_schema.return_value = model_schema
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
|
||||
|
||||
conversation = mocker.MagicMock(id="conv")
|
||||
message = mocker.MagicMock(id="msg")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
|
||||
runner_cls = mocker.MagicMock()
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.FunctionCallAgentRunner", runner_cls)
|
||||
|
||||
runner_instance = mocker.MagicMock()
|
||||
runner_cls.return_value = runner_instance
|
||||
runner_instance.run.return_value = []
|
||||
mocker.patch.object(runner, "_handle_invoke_result")
|
||||
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
|
||||
assert app_config.agent.strategy == AgentEntity.Strategy.FUNCTION_CALLING
|
||||
runner_instance.run.assert_called_once()
|
||||
|
||||
def test_run_conversation_not_found(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING)
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, None],
|
||||
)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg"))
|
||||
|
||||
def test_run_message_not_found(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING)
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, mocker.MagicMock(id="conv"), None],
|
||||
)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg"))
|
||||
|
||||
def test_run_invalid_agent_strategy_raises(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = mocker.MagicMock(strategy="invalid", provider="p", model="m")
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
model_schema = mocker.MagicMock()
|
||||
model_schema.features = []
|
||||
model_schema.model_properties = {ModelPropertyKey.MODE: LLMMode.CHAT}
|
||||
|
||||
llm_instance = mocker.MagicMock()
|
||||
llm_instance.model_type_instance.get_model_schema.return_value = model_schema
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
|
||||
|
||||
conversation = mocker.MagicMock(id="conv")
|
||||
message = mocker.MagicMock(id="msg")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
|
|
@ -0,0 +1,162 @@
|
|||
from collections.abc import Generator
|
||||
|
||||
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
MessageStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
|
||||
|
||||
class TestAgentChatAppGenerateResponseConverterBlocking:
|
||||
def test_convert_blocking_full_response(self):
|
||||
blocking = ChatbotAppBlockingResponse(
|
||||
task_id="task",
|
||||
data=ChatbotAppBlockingResponse.Data(
|
||||
id="id",
|
||||
mode="agent-chat",
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
answer="answer",
|
||||
metadata={"a": 1},
|
||||
created_at=123,
|
||||
),
|
||||
)
|
||||
|
||||
result = AgentChatAppGenerateResponseConverter.convert_blocking_full_response(blocking)
|
||||
|
||||
assert result["event"] == "message"
|
||||
assert result["answer"] == "answer"
|
||||
assert result["metadata"] == {"a": 1}
|
||||
|
||||
def test_convert_blocking_simple_response_with_dict_metadata(self):
|
||||
blocking = ChatbotAppBlockingResponse(
|
||||
task_id="task",
|
||||
data=ChatbotAppBlockingResponse.Data(
|
||||
id="id",
|
||||
mode="agent-chat",
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
answer="answer",
|
||||
metadata={
|
||||
"retriever_resources": [
|
||||
{
|
||||
"segment_id": "s1",
|
||||
"position": 1,
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"content": "content",
|
||||
}
|
||||
],
|
||||
"annotation_reply": {"id": "a"},
|
||||
"usage": {"prompt_tokens": 1},
|
||||
},
|
||||
created_at=123,
|
||||
),
|
||||
)
|
||||
|
||||
result = AgentChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
|
||||
assert "annotation_reply" not in result["metadata"]
|
||||
assert "usage" not in result["metadata"]
|
||||
|
||||
def test_convert_blocking_simple_response_with_non_dict_metadata(self):
|
||||
blocking = ChatbotAppBlockingResponse.model_construct(
|
||||
task_id="task",
|
||||
data=ChatbotAppBlockingResponse.Data.model_construct(
|
||||
id="id",
|
||||
mode="agent-chat",
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
answer="answer",
|
||||
metadata="bad",
|
||||
created_at=123,
|
||||
),
|
||||
)
|
||||
|
||||
result = AgentChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
|
||||
assert result["metadata"] == {}
|
||||
|
||||
|
||||
class TestAgentChatAppGenerateResponseConverterStream:
|
||||
def build_stream(self) -> Generator[ChatbotAppStreamResponse, None, None]:
|
||||
def _gen():
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
created_at=1,
|
||||
stream_response=PingStreamResponse(task_id="t"),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
created_at=2,
|
||||
stream_response=MessageStreamResponse(task_id="t", id="m1", answer="hi"),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
created_at=3,
|
||||
stream_response=MessageEndStreamResponse(
|
||||
task_id="t",
|
||||
id="m1",
|
||||
metadata={
|
||||
"retriever_resources": [
|
||||
{
|
||||
"segment_id": "s1",
|
||||
"position": 1,
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"content": "content",
|
||||
"summary": "summary",
|
||||
"extra": "ignored",
|
||||
}
|
||||
],
|
||||
"annotation_reply": {"id": "a"},
|
||||
"usage": {"prompt_tokens": 1},
|
||||
},
|
||||
),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
created_at=4,
|
||||
stream_response=ErrorStreamResponse(task_id="t", err=RuntimeError("bad")),
|
||||
)
|
||||
|
||||
return _gen()
|
||||
|
||||
def test_convert_stream_full_response(self):
|
||||
items = list(AgentChatAppGenerateResponseConverter.convert_stream_full_response(self.build_stream()))
|
||||
assert items[0] == "ping"
|
||||
assert items[1]["event"] == "message"
|
||||
assert "answer" in items[1]
|
||||
assert items[2]["event"] == "message_end"
|
||||
assert items[3]["event"] == "error"
|
||||
|
||||
def test_convert_stream_simple_response(self):
|
||||
items = list(AgentChatAppGenerateResponseConverter.convert_stream_simple_response(self.build_stream()))
|
||||
assert items[0] == "ping"
|
||||
# Assert the message event structure and content at items[1]
|
||||
assert items[1]["event"] == "message"
|
||||
assert items[1]["answer"] == "hi" or "hi" in items[1]["answer"]
|
||||
assert items[2]["event"] == "message_end"
|
||||
assert "metadata" in items[2]
|
||||
metadata = items[2]["metadata"]
|
||||
assert "annotation_reply" not in metadata
|
||||
assert "usage" not in metadata
|
||||
assert metadata["retriever_resources"] == [
|
||||
{
|
||||
"segment_id": "s1",
|
||||
"position": 1,
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"content": "content",
|
||||
"summary": "summary",
|
||||
}
|
||||
]
|
||||
assert items[3]["event"] == "error"
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom, ModelConfigEntity, PromptTemplateEntity
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestChatAppConfigManager:
|
||||
def test_get_app_config_uses_override_dict(self):
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT.value)
|
||||
app_model_config = SimpleNamespace(id="config-1", to_dict=lambda: {"model": "m"})
|
||||
override = {"model": "override"}
|
||||
|
||||
model_entity = ModelConfigEntity(provider="p", model="m")
|
||||
prompt_entity = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
|
||||
simple_prompt_template="hi",
|
||||
)
|
||||
|
||||
with (
|
||||
patch("core.app.apps.chat.app_config_manager.ModelConfigManager.convert", return_value=model_entity),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.PromptTemplateConfigManager.convert", return_value=prompt_entity
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert",
|
||||
return_value=None,
|
||||
),
|
||||
patch("core.app.apps.chat.app_config_manager.DatasetConfigManager.convert", return_value=None),
|
||||
patch("core.app.apps.chat.app_config_manager.BasicVariablesConfigManager.convert", return_value=([], [])),
|
||||
):
|
||||
app_config = ChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=None,
|
||||
override_config_dict=override,
|
||||
)
|
||||
|
||||
assert app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS
|
||||
assert app_config.app_model_config_dict == override
|
||||
assert app_config.app_mode == AppMode.CHAT
|
||||
|
||||
def test_config_validate_filters_related_keys(self):
|
||||
config = {"extra": 1}
|
||||
|
||||
def _add_key(key, value):
|
||||
def _inner(*args, **kwargs):
|
||||
config = args[-1]
|
||||
config = {**config, key: value}
|
||||
return config, [key]
|
||||
|
||||
return _inner
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.ModelConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("model", 1),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.BasicVariablesConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("inputs", 2),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("file_upload", 3),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.PromptTemplateConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("prompt", 4),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.DatasetConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("dataset", 5),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("opening_statement", 6),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("suggested_questions_after_answer", 7),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("speech_to_text", 8),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("text_to_speech", 9),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("retriever_resource", 10),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("sensitive_word_avoidance", 11),
|
||||
),
|
||||
):
|
||||
filtered = ChatAppConfigManager.config_validate(tenant_id="t1", config=config)
|
||||
|
||||
assert filtered["model"] == 1
|
||||
assert filtered["inputs"] == 2
|
||||
assert filtered["file_upload"] == 3
|
||||
assert filtered["prompt"] == 4
|
||||
assert filtered["dataset"] == 5
|
||||
assert filtered["opening_statement"] == 6
|
||||
assert filtered["suggested_questions_after_answer"] == 7
|
||||
assert filtered["speech_to_text"] == 8
|
||||
assert filtered["text_to_speech"] == 9
|
||||
assert filtered["retriever_resource"] == 10
|
||||
assert filtered["sensitive_word_avoidance"] == 11
|
||||
|
|
@ -0,0 +1,280 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.chat.app_generator import ChatAppGenerator
|
||||
from core.app.apps.chat.app_runner import ChatAppRunner
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.moderation.base import ModerationError
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class DummyGenerateEntity:
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
|
||||
class DummyQueueManager:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.published = []
|
||||
|
||||
def publish_error(self, error, pub_from):
|
||||
self.published.append((error, pub_from))
|
||||
|
||||
def publish(self, event, pub_from):
|
||||
self.published.append((event, pub_from))
|
||||
|
||||
|
||||
class TestChatAppGenerator:
|
||||
def test_generate_requires_query(self):
|
||||
generator = ChatAppGenerator()
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
app_model=SimpleNamespace(),
|
||||
user=SimpleNamespace(),
|
||||
args={"inputs": {}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
def test_generate_rejects_non_string_query(self):
|
||||
generator = ChatAppGenerator()
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
app_model=SimpleNamespace(),
|
||||
user=SimpleNamespace(),
|
||||
args={"query": 1, "inputs": {}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
def test_generate_debugger_overrides_model_config(self):
|
||||
generator = ChatAppGenerator()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
user = SimpleNamespace(id="user-1", session_id="session-1")
|
||||
args = {"query": "hi", "inputs": {}, "model_config": {"foo": "bar"}}
|
||||
|
||||
with (
|
||||
patch("core.app.apps.chat.app_generator.ConversationService.get_conversation", return_value=None),
|
||||
patch("core.app.apps.chat.app_generator.ChatAppConfigManager.config_validate", return_value={"x": 1}),
|
||||
patch(
|
||||
"core.app.apps.chat.app_generator.ChatAppConfigManager.get_app_config",
|
||||
return_value=SimpleNamespace(
|
||||
variables=[], external_data_variables=[], app_model_config_dict={}, app_mode=AppMode.CHAT
|
||||
),
|
||||
),
|
||||
patch("core.app.apps.chat.app_generator.ModelConfigConverter.convert", return_value=SimpleNamespace()),
|
||||
patch("core.app.apps.chat.app_generator.FileUploadConfigManager.convert", return_value=None),
|
||||
patch("core.app.apps.chat.app_generator.file_factory.build_from_mappings", return_value=[]),
|
||||
patch("core.app.apps.chat.app_generator.ChatAppGenerateEntity", DummyGenerateEntity),
|
||||
patch("core.app.apps.chat.app_generator.TraceQueueManager", return_value=SimpleNamespace()),
|
||||
patch("core.app.apps.chat.app_generator.MessageBasedAppQueueManager", DummyQueueManager),
|
||||
patch(
|
||||
"core.app.apps.chat.app_generator.ChatAppGenerateResponseConverter.convert", return_value={"ok": True}
|
||||
),
|
||||
patch.object(ChatAppGenerator, "_get_app_model_config", return_value=SimpleNamespace(to_dict=lambda: {})),
|
||||
patch.object(ChatAppGenerator, "_prepare_user_inputs", return_value={}),
|
||||
patch.object(
|
||||
ChatAppGenerator,
|
||||
"_init_generate_records",
|
||||
return_value=(SimpleNamespace(id="c1", mode="chat"), SimpleNamespace(id="m1")),
|
||||
),
|
||||
patch.object(ChatAppGenerator, "_handle_response", return_value={"response": True}),
|
||||
patch("core.app.apps.chat.app_generator.copy_current_request_context", side_effect=lambda f: f),
|
||||
patch("core.app.apps.chat.app_generator.threading.Thread") as mock_thread,
|
||||
):
|
||||
mock_thread.return_value.start.return_value = None
|
||||
result = generator.generate(app_model, user, args, InvokeFrom.DEBUGGER, streaming=False)
|
||||
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_generate_rejects_model_config_override_for_non_debugger(self):
|
||||
generator = ChatAppGenerator()
|
||||
with pytest.raises(ValueError):
|
||||
with (
|
||||
patch.object(
|
||||
ChatAppGenerator, "_get_app_model_config", return_value=SimpleNamespace(to_dict=lambda: {})
|
||||
),
|
||||
):
|
||||
generator.generate(
|
||||
app_model=SimpleNamespace(tenant_id="t1", id="a1", mode=AppMode.CHAT.value),
|
||||
user=SimpleNamespace(id="u1", session_id="s1"),
|
||||
args={"query": "hi", "inputs": {}, "model_config": {"foo": "bar"}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
def test_generate_worker_handles_exceptions(self):
|
||||
generator = ChatAppGenerator()
|
||||
queue_manager = DummyQueueManager()
|
||||
entity = DummyGenerateEntity(task_id="t1", user_id="u1")
|
||||
|
||||
with (
|
||||
patch.object(ChatAppGenerator, "_get_conversation", return_value=SimpleNamespace()),
|
||||
patch.object(ChatAppGenerator, "_get_message", return_value=SimpleNamespace()),
|
||||
patch("core.app.apps.chat.app_generator.ChatAppRunner.run", side_effect=InvokeAuthorizationError()),
|
||||
patch("core.app.apps.chat.app_generator.db.session.close"),
|
||||
):
|
||||
generator._generate_worker(
|
||||
flask_app=Mock(app_context=Mock(return_value=Mock(__enter__=Mock(), __exit__=Mock()))),
|
||||
application_generate_entity=entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
assert queue_manager.published
|
||||
|
||||
with (
|
||||
patch.object(ChatAppGenerator, "_get_conversation", return_value=SimpleNamespace()),
|
||||
patch.object(ChatAppGenerator, "_get_message", return_value=SimpleNamespace()),
|
||||
patch("core.app.apps.chat.app_generator.ChatAppRunner.run", side_effect=GenerateTaskStoppedError()),
|
||||
patch("core.app.apps.chat.app_generator.db.session.close"),
|
||||
):
|
||||
generator._generate_worker(
|
||||
flask_app=Mock(app_context=Mock(return_value=Mock(__enter__=Mock(), __exit__=Mock()))),
|
||||
application_generate_entity=entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
|
||||
class TestChatAppRunner:
|
||||
def test_run_raises_when_app_missing(self):
|
||||
runner = ChatAppRunner()
|
||||
app_config = SimpleNamespace(
|
||||
app_id="app-1", tenant_id="tenant-1", prompt_template=None, external_data_variables=[]
|
||||
)
|
||||
app_generate_entity = DummyGenerateEntity(
|
||||
app_config=app_config,
|
||||
model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}),
|
||||
inputs={},
|
||||
query="hi",
|
||||
files=[],
|
||||
file_upload_config=None,
|
||||
conversation_id=None,
|
||||
stream=False,
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
with patch("core.app.apps.chat.app_runner.db.session.scalar", return_value=None):
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
|
||||
|
||||
def test_run_moderation_error_direct_output(self):
|
||||
runner = ChatAppRunner()
|
||||
app_config = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
prompt_template=None,
|
||||
external_data_variables=[],
|
||||
dataset=None,
|
||||
additional_features=None,
|
||||
)
|
||||
app_generate_entity = DummyGenerateEntity(
|
||||
app_config=app_config,
|
||||
model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}),
|
||||
inputs={},
|
||||
query="hi",
|
||||
files=[],
|
||||
file_upload_config=None,
|
||||
conversation_id=None,
|
||||
stream=False,
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.chat.app_runner.db.session.scalar",
|
||||
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
|
||||
),
|
||||
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
|
||||
patch.object(ChatAppRunner, "moderation_for_inputs", side_effect=ModerationError("blocked")),
|
||||
patch.object(ChatAppRunner, "direct_output") as mock_direct,
|
||||
):
|
||||
runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
|
||||
|
||||
mock_direct.assert_called_once()
|
||||
|
||||
def test_run_annotation_reply_short_circuits(self):
|
||||
runner = ChatAppRunner()
|
||||
app_config = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
prompt_template=None,
|
||||
external_data_variables=[],
|
||||
dataset=None,
|
||||
additional_features=None,
|
||||
)
|
||||
app_generate_entity = DummyGenerateEntity(
|
||||
app_config=app_config,
|
||||
model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}),
|
||||
inputs={},
|
||||
query="hi",
|
||||
files=[],
|
||||
file_upload_config=None,
|
||||
conversation_id=None,
|
||||
stream=False,
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
annotation = SimpleNamespace(id="ann-1", content="answer")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.chat.app_runner.db.session.scalar",
|
||||
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
|
||||
),
|
||||
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
|
||||
patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
|
||||
patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=annotation),
|
||||
patch.object(ChatAppRunner, "direct_output") as mock_direct,
|
||||
):
|
||||
queue_manager = DummyQueueManager()
|
||||
runner.run(app_generate_entity, queue_manager, SimpleNamespace(), SimpleNamespace(id="m1"))
|
||||
|
||||
assert any(isinstance(item[0], QueueAnnotationReplyEvent) for item in queue_manager.published)
|
||||
mock_direct.assert_called_once()
|
||||
|
||||
def test_run_returns_when_hosting_moderation_blocks(self):
|
||||
runner = ChatAppRunner()
|
||||
app_config = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
prompt_template=None,
|
||||
external_data_variables=[],
|
||||
dataset=None,
|
||||
additional_features=None,
|
||||
)
|
||||
app_generate_entity = DummyGenerateEntity(
|
||||
app_config=app_config,
|
||||
model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}),
|
||||
inputs={},
|
||||
query="hi",
|
||||
files=[],
|
||||
file_upload_config=None,
|
||||
conversation_id=None,
|
||||
stream=False,
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.chat.app_runner.db.session.scalar",
|
||||
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
|
||||
),
|
||||
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
|
||||
patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
|
||||
patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=None),
|
||||
patch.object(ChatAppRunner, "check_hosting_moderation", return_value=True),
|
||||
):
|
||||
runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
from collections.abc import Generator
|
||||
|
||||
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
MessageStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
|
||||
|
||||
class TestChatAppGenerateResponseConverter:
|
||||
def test_convert_blocking_simple_response_metadata(self):
|
||||
data = ChatbotAppBlockingResponse.Data(
|
||||
id="msg-1",
|
||||
mode="chat",
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
answer="hi",
|
||||
metadata={"usage": {"total_tokens": 1}},
|
||||
created_at=1,
|
||||
)
|
||||
blocking = ChatbotAppBlockingResponse(task_id="t1", data=data)
|
||||
|
||||
response = ChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
|
||||
assert "usage" not in response["metadata"]
|
||||
|
||||
def test_convert_stream_responses(self):
|
||||
def stream() -> Generator[ChatbotAppStreamResponse, None, None]:
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=PingStreamResponse(task_id="t1"),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=MessageStreamResponse(task_id="t1", id="m1", answer="hi"),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom")),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=MessageEndStreamResponse(task_id="t1", id="m1"),
|
||||
)
|
||||
|
||||
full = list(ChatAppGenerateResponseConverter.convert_stream_full_response(stream()))
|
||||
assert full[0] == "ping"
|
||||
assert full[1]["event"] == "message"
|
||||
assert full[2]["event"] == "error"
|
||||
|
||||
simple = list(ChatAppGenerateResponseConverter.convert_stream_simple_response(stream()))
|
||||
assert simple[0] == "ping"
|
||||
assert simple[-1]["event"] == "message_end"
|
||||
|
|
@ -0,0 +1,162 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import core.app.apps.completion.app_runner as module
|
||||
from core.app.apps.completion.app_runner import CompletionAppRunner
|
||||
from core.moderation.base import ModerationError
|
||||
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
return CompletionAppRunner()
|
||||
|
||||
|
||||
def _build_app_config(dataset=None, external_tools=None, additional_features=None):
|
||||
app_config = MagicMock()
|
||||
app_config.app_id = "app1"
|
||||
app_config.tenant_id = "tenant"
|
||||
app_config.prompt_template = MagicMock()
|
||||
app_config.dataset = dataset
|
||||
app_config.external_data_variables = external_tools or []
|
||||
app_config.additional_features = additional_features
|
||||
app_config.app_model_config_dict = {"file_upload": {"enabled": True}}
|
||||
return app_config
|
||||
|
||||
|
||||
def _build_generate_entity(app_config, file_upload_config=None):
|
||||
model_conf = MagicMock(
|
||||
provider_model_bundle="bundle",
|
||||
model="model",
|
||||
parameters={"max_tokens": 10},
|
||||
stop=["stop"],
|
||||
)
|
||||
return SimpleNamespace(
|
||||
app_config=app_config,
|
||||
model_conf=model_conf,
|
||||
inputs={"qvar": "query_from_input"},
|
||||
query="original_query",
|
||||
files=[],
|
||||
file_upload_config=file_upload_config,
|
||||
stream=True,
|
||||
user_id="user",
|
||||
invoke_from=MagicMock(),
|
||||
)
|
||||
|
||||
|
||||
class TestCompletionAppRunner:
|
||||
def test_run_app_not_found(self, runner, mocker):
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = None
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock())
|
||||
|
||||
def test_run_moderation_error_outputs_direct(self, runner, mocker):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = app_record
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config)
|
||||
|
||||
runner.organize_prompt_messages = MagicMock(return_value=([], None))
|
||||
runner.moderation_for_inputs = MagicMock(side_effect=ModerationError("blocked"))
|
||||
runner.direct_output = MagicMock()
|
||||
runner._handle_invoke_result = MagicMock()
|
||||
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
|
||||
runner.direct_output.assert_called_once()
|
||||
runner._handle_invoke_result.assert_not_called()
|
||||
|
||||
def test_run_hosting_moderation_stops(self, runner, mocker):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = app_record
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config)
|
||||
|
||||
runner.organize_prompt_messages = MagicMock(return_value=([], None))
|
||||
runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
|
||||
runner.check_hosting_moderation = MagicMock(return_value=True)
|
||||
runner._handle_invoke_result = MagicMock()
|
||||
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
|
||||
runner._handle_invoke_result.assert_not_called()
|
||||
|
||||
def test_run_dataset_and_external_tools_flow(self, runner, mocker):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = app_record
|
||||
session.close = MagicMock()
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
retrieve_config = MagicMock(query_variable="qvar")
|
||||
dataset_config = MagicMock(dataset_ids=["ds"], retrieve_config=retrieve_config)
|
||||
additional_features = MagicMock(show_retrieve_source=True)
|
||||
app_config = _build_app_config(
|
||||
dataset=dataset_config,
|
||||
external_tools=["tool"],
|
||||
additional_features=additional_features,
|
||||
)
|
||||
|
||||
file_upload_config = MagicMock()
|
||||
file_upload_config.image_config.detail = ImagePromptMessageContent.DETAIL.HIGH
|
||||
|
||||
app_generate_entity = _build_generate_entity(app_config, file_upload_config=file_upload_config)
|
||||
|
||||
runner.organize_prompt_messages = MagicMock(side_effect=[(["pm1"], ["stop"]), (["pm2"], ["stop"])])
|
||||
runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
|
||||
runner.fill_in_inputs_from_external_data_tools = MagicMock(return_value=app_generate_entity.inputs)
|
||||
runner.check_hosting_moderation = MagicMock(return_value=False)
|
||||
runner.recalc_llm_max_tokens = MagicMock()
|
||||
runner._handle_invoke_result = MagicMock()
|
||||
|
||||
dataset_retrieval = MagicMock()
|
||||
dataset_retrieval.retrieve.return_value = ("ctx", ["file1"])
|
||||
mocker.patch.object(module, "DatasetRetrieval", return_value=dataset_retrieval)
|
||||
|
||||
model_instance = MagicMock()
|
||||
model_instance.invoke_llm.return_value = "invoke_result"
|
||||
mocker.patch.object(module, "ModelInstance", return_value=model_instance)
|
||||
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg", tenant_id="tenant"))
|
||||
|
||||
dataset_retrieval.retrieve.assert_called_once()
|
||||
assert dataset_retrieval.retrieve.call_args.kwargs["query"] == "query_from_input"
|
||||
runner._handle_invoke_result.assert_called_once()
|
||||
|
||||
def test_run_uses_low_image_detail_default(self, runner, mocker):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = app_record
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config, file_upload_config=None)
|
||||
|
||||
runner.organize_prompt_messages = MagicMock(return_value=([], None))
|
||||
runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
|
||||
runner.check_hosting_moderation = MagicMock(return_value=True)
|
||||
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
|
||||
assert (
|
||||
runner.organize_prompt_messages.call_args.kwargs["image_detail_config"]
|
||||
== ImagePromptMessageContent.DETAIL.LOW
|
||||
)
|
||||
|
|
@ -0,0 +1,122 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import core.app.apps.completion.app_config_manager as module
|
||||
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestCompletionAppConfigManager:
|
||||
def test_get_app_config_with_override(self, mocker):
|
||||
app_model = MagicMock(tenant_id="tenant", id="app1", mode=AppMode.COMPLETION.value)
|
||||
app_model_config = MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"model": {"provider": "x"}}
|
||||
|
||||
override_config = {"model": {"provider": "override"}}
|
||||
|
||||
mocker.patch.object(module.ModelConfigManager, "convert", return_value="model")
|
||||
mocker.patch.object(module.PromptTemplateConfigManager, "convert", return_value="prompt")
|
||||
mocker.patch.object(module.SensitiveWordAvoidanceConfigManager, "convert", return_value="moderation")
|
||||
mocker.patch.object(module.DatasetConfigManager, "convert", return_value="dataset")
|
||||
mocker.patch.object(CompletionAppConfigManager, "convert_features", return_value="features")
|
||||
mocker.patch.object(module.BasicVariablesConfigManager, "convert", return_value=(["v1"], ["ext1"]))
|
||||
mocker.patch.object(module, "CompletionAppConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
|
||||
result = CompletionAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
override_config_dict=override_config,
|
||||
)
|
||||
|
||||
assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS
|
||||
assert result.app_model_config_dict == override_config
|
||||
assert result.variables == ["v1"]
|
||||
assert result.external_data_variables == ["ext1"]
|
||||
assert result.app_mode == AppMode.COMPLETION
|
||||
|
||||
def test_get_app_config_without_override_uses_model_config(self, mocker):
|
||||
app_model = MagicMock(tenant_id="tenant", id="app1", mode=AppMode.COMPLETION.value)
|
||||
app_model_config = MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"model": {"provider": "x"}}
|
||||
|
||||
mocker.patch.object(module.ModelConfigManager, "convert", return_value="model")
|
||||
mocker.patch.object(module.PromptTemplateConfigManager, "convert", return_value="prompt")
|
||||
mocker.patch.object(module.SensitiveWordAvoidanceConfigManager, "convert", return_value="moderation")
|
||||
mocker.patch.object(module.DatasetConfigManager, "convert", return_value="dataset")
|
||||
mocker.patch.object(CompletionAppConfigManager, "convert_features", return_value="features")
|
||||
mocker.patch.object(module.BasicVariablesConfigManager, "convert", return_value=([], []))
|
||||
mocker.patch.object(module, "CompletionAppConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
|
||||
result = CompletionAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config)
|
||||
|
||||
assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
|
||||
assert result.app_model_config_dict == {"model": {"provider": "x"}}
|
||||
|
||||
def test_config_validate_filters_related_keys(self, mocker):
|
||||
config = {
|
||||
"model": {"provider": "x"},
|
||||
"variables": ["v"],
|
||||
"file_upload": {"enabled": True},
|
||||
"prompt": {"template": "t"},
|
||||
"dataset": {"enabled": True},
|
||||
"tts": {"enabled": True},
|
||||
"more_like_this": {"enabled": True},
|
||||
"moderation": {"enabled": True},
|
||||
"extra": "drop",
|
||||
}
|
||||
|
||||
mocker.patch.object(
|
||||
module.ModelConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["model"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.BasicVariablesConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["variables"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.FileUploadConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["file_upload"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.PromptTemplateConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["prompt"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DatasetConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["dataset"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.TextToSpeechConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["tts"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.MoreLikeThisConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["more_like_this"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.SensitiveWordAvoidanceConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["moderation"]),
|
||||
)
|
||||
|
||||
filtered = CompletionAppConfigManager.config_validate("tenant", config)
|
||||
|
||||
assert "extra" not in filtered
|
||||
assert set(filtered.keys()) == {
|
||||
"model",
|
||||
"variables",
|
||||
"file_upload",
|
||||
"prompt",
|
||||
"dataset",
|
||||
"tts",
|
||||
"more_like_this",
|
||||
"moderation",
|
||||
}
|
||||
|
|
@ -0,0 +1,321 @@
|
|||
import contextlib
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
import core.app.apps.completion.app_generator as module
|
||||
from core.app.apps.completion.app_generator import CompletionAppGenerator
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generator(mocker):
|
||||
gen = CompletionAppGenerator()
|
||||
|
||||
mocker.patch.object(module, "copy_current_request_context", side_effect=lambda fn: fn)
|
||||
|
||||
flask_app = MagicMock()
|
||||
flask_app.app_context.return_value = contextlib.nullcontext()
|
||||
mocker.patch.object(module, "current_app", MagicMock(_get_current_object=MagicMock(return_value=flask_app)))
|
||||
|
||||
thread = MagicMock()
|
||||
mocker.patch.object(module.threading, "Thread", return_value=thread)
|
||||
|
||||
mocker.patch.object(module, "MessageBasedAppQueueManager", return_value=MagicMock())
|
||||
mocker.patch.object(module, "TraceQueueManager", return_value=MagicMock())
|
||||
mocker.patch.object(module, "CompletionAppGenerateEntity", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
|
||||
return gen
|
||||
|
||||
|
||||
def _build_app_model():
|
||||
return MagicMock(tenant_id="tenant", id="app1", mode="completion")
|
||||
|
||||
|
||||
def _build_user():
|
||||
return MagicMock(id="user", session_id="session")
|
||||
|
||||
|
||||
def _build_app_model_config():
|
||||
config = MagicMock(id="cfg")
|
||||
config.to_dict.return_value = {"model": {"provider": "x"}}
|
||||
return config
|
||||
|
||||
|
||||
class TestCompletionAppGenerator:
|
||||
def test_generate_invalid_query_type(self, generator):
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
app_model=_build_app_model(),
|
||||
user=_build_user(),
|
||||
args={"query": 123, "inputs": {}, "files": []},
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
def test_generate_override_not_debugger(self, generator):
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
app_model=_build_app_model(),
|
||||
user=_build_user(),
|
||||
args={"query": "q", "inputs": {}, "files": [], "model_config": {}},
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
def test_generate_success_no_file_config(self, generator, mocker):
|
||||
app_model_config = _build_app_model_config()
|
||||
mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config)
|
||||
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=None)
|
||||
mocker.patch.object(module.file_factory, "build_from_mappings")
|
||||
|
||||
app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={}))
|
||||
mocker.patch.object(module.CompletionAppConfigManager, "get_app_config", return_value=app_config)
|
||||
mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock())
|
||||
|
||||
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
|
||||
|
||||
conversation = MagicMock(id="conv", mode="completion")
|
||||
message = MagicMock(id="msg")
|
||||
mocker.patch.object(generator, "_init_generate_records", return_value=(conversation, message))
|
||||
|
||||
mocker.patch.object(generator, "_handle_response", return_value="response")
|
||||
mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted")
|
||||
|
||||
result = generator.generate(
|
||||
app_model=_build_app_model(),
|
||||
user=_build_user(),
|
||||
args={"query": "q", "inputs": {"a": 1}, "files": []},
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
assert result == "converted"
|
||||
module.file_factory.build_from_mappings.assert_not_called()
|
||||
|
||||
def test_generate_success_with_files(self, generator, mocker):
|
||||
app_model_config = _build_app_model_config()
|
||||
mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config)
|
||||
|
||||
file_extra_config = MagicMock()
|
||||
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_extra_config)
|
||||
mocker.patch.object(module.file_factory, "build_from_mappings", return_value=["file1"])
|
||||
|
||||
app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={}))
|
||||
mocker.patch.object(module.CompletionAppConfigManager, "get_app_config", return_value=app_config)
|
||||
mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock())
|
||||
|
||||
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
|
||||
|
||||
conversation = MagicMock(id="conv", mode="completion")
|
||||
message = MagicMock(id="msg")
|
||||
mocker.patch.object(generator, "_init_generate_records", return_value=(conversation, message))
|
||||
|
||||
mocker.patch.object(generator, "_handle_response", return_value="response")
|
||||
mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted")
|
||||
|
||||
result = generator.generate(
|
||||
app_model=_build_app_model(),
|
||||
user=_build_user(),
|
||||
args={"query": "q", "inputs": {"a": 1}, "files": [{"id": "f"}]},
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
assert result == "converted"
|
||||
module.file_factory.build_from_mappings.assert_called_once()
|
||||
|
||||
def test_generate_override_model_config_debugger(self, generator, mocker):
|
||||
app_model_config = _build_app_model_config()
|
||||
mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config)
|
||||
|
||||
override_config = {"model": {"provider": "override"}}
|
||||
mocker.patch.object(module.CompletionAppConfigManager, "config_validate", return_value=override_config)
|
||||
|
||||
app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={}))
|
||||
get_app_config = mocker.patch.object(
|
||||
module.CompletionAppConfigManager,
|
||||
"get_app_config",
|
||||
return_value=app_config,
|
||||
)
|
||||
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=None)
|
||||
mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock())
|
||||
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
|
||||
mocker.patch.object(
|
||||
generator,
|
||||
"_init_generate_records",
|
||||
return_value=(MagicMock(id="conv", mode="completion"), MagicMock(id="msg")),
|
||||
)
|
||||
mocker.patch.object(generator, "_handle_response", return_value="response")
|
||||
mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted")
|
||||
|
||||
generator.generate(
|
||||
app_model=_build_app_model(),
|
||||
user=_build_user(),
|
||||
args={"query": "q", "inputs": {}, "files": [], "model_config": override_config},
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
assert get_app_config.call_args.kwargs["override_config_dict"] == override_config
|
||||
|
||||
def test_generate_more_like_this_message_not_found(self, generator, mocker):
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = None
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
with pytest.raises(MessageNotExistsError):
|
||||
generator.generate_more_like_this(
|
||||
app_model=_build_app_model(),
|
||||
message_id="msg",
|
||||
user=_build_user(),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
def test_generate_more_like_this_disabled(self, generator, mocker):
|
||||
app_model = _build_app_model()
|
||||
app_model.app_model_config = MagicMock(more_like_this=False, more_like_this_dict={"enabled": False})
|
||||
|
||||
message = MagicMock()
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = message
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
with pytest.raises(MoreLikeThisDisabledError):
|
||||
generator.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
message_id="msg",
|
||||
user=_build_user(),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
def test_generate_more_like_this_app_model_config_missing(self, generator, mocker):
|
||||
app_model = _build_app_model()
|
||||
app_model.app_model_config = None
|
||||
|
||||
message = MagicMock()
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = message
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
with pytest.raises(MoreLikeThisDisabledError):
|
||||
generator.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
message_id="msg",
|
||||
user=_build_user(),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
def test_generate_more_like_this_message_config_none(self, generator, mocker):
|
||||
app_model = _build_app_model()
|
||||
app_model.app_model_config = MagicMock(more_like_this=True, more_like_this_dict={"enabled": True})
|
||||
|
||||
message = MagicMock(app_model_config=None)
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = message
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
message_id="msg",
|
||||
user=_build_user(),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
def test_generate_more_like_this_success(self, generator, mocker):
|
||||
app_model = _build_app_model()
|
||||
app_model.app_model_config = MagicMock(more_like_this=True, more_like_this_dict={"enabled": True})
|
||||
|
||||
message = MagicMock()
|
||||
message.message_files = [{"id": "f"}]
|
||||
message.inputs = {"a": 1}
|
||||
message.query = "q"
|
||||
|
||||
app_model_config = MagicMock()
|
||||
app_model_config.to_dict.return_value = {
|
||||
"model": {"completion_params": {"temperature": 0.1}},
|
||||
"file_upload": {"enabled": True},
|
||||
}
|
||||
message.app_model_config = app_model_config
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = message
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
file_extra_config = MagicMock()
|
||||
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_extra_config)
|
||||
mocker.patch.object(module.file_factory, "build_from_mappings", return_value=["file1"])
|
||||
|
||||
app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={}))
|
||||
get_app_config = mocker.patch.object(
|
||||
module.CompletionAppConfigManager,
|
||||
"get_app_config",
|
||||
return_value=app_config,
|
||||
)
|
||||
mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock())
|
||||
|
||||
mocker.patch.object(
|
||||
generator,
|
||||
"_init_generate_records",
|
||||
return_value=(MagicMock(id="conv", mode="completion"), MagicMock(id="msg")),
|
||||
)
|
||||
mocker.patch.object(generator, "_handle_response", return_value="response")
|
||||
mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted")
|
||||
|
||||
result = generator.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
message_id="msg",
|
||||
user=_build_user(),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
assert result == "converted"
|
||||
override_dict = get_app_config.call_args.kwargs["override_config_dict"]
|
||||
assert override_dict["model"]["completion_params"]["temperature"] == 0.9
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("error", "should_publish"),
|
||||
[
|
||||
(GenerateTaskStoppedError(), False),
|
||||
(InvokeAuthorizationError("bad"), True),
|
||||
(
|
||||
ValidationError.from_exception_data(
|
||||
"Model",
|
||||
[{"type": "missing", "loc": ("x",), "msg": "Field required", "input": {}}],
|
||||
),
|
||||
True,
|
||||
),
|
||||
(ValueError("bad"), True),
|
||||
(RuntimeError("boom"), True),
|
||||
],
|
||||
)
|
||||
def test_generate_worker_error_handling(self, generator, mocker, error, should_publish):
|
||||
flask_app = MagicMock()
|
||||
flask_app.app_context.return_value = contextlib.nullcontext()
|
||||
|
||||
session = mocker.MagicMock()
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
mocker.patch.object(generator, "_get_message", return_value=MagicMock())
|
||||
|
||||
runner_instance = MagicMock()
|
||||
runner_instance.run.side_effect = error
|
||||
mocker.patch.object(module, "CompletionAppRunner", return_value=runner_instance)
|
||||
|
||||
queue_manager = MagicMock()
|
||||
generator._generate_worker(
|
||||
flask_app=flask_app,
|
||||
application_generate_entity=MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
message_id="msg",
|
||||
)
|
||||
|
||||
assert queue_manager.publish_error.called is should_publish
|
||||
|
|
@ -0,0 +1,153 @@
|
|||
from collections.abc import Generator
|
||||
|
||||
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
CompletionAppBlockingResponse,
|
||||
CompletionAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
MessageStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
|
||||
|
||||
class TestCompletionAppGenerateResponseConverter:
|
||||
def test_convert_blocking_full_response(self):
|
||||
blocking = CompletionAppBlockingResponse(
|
||||
task_id="task",
|
||||
data=CompletionAppBlockingResponse.Data(
|
||||
id="id",
|
||||
mode="completion",
|
||||
message_id="msg",
|
||||
answer="answer",
|
||||
metadata={"k": "v"},
|
||||
created_at=123,
|
||||
),
|
||||
)
|
||||
|
||||
result = CompletionAppGenerateResponseConverter.convert_blocking_full_response(blocking)
|
||||
|
||||
assert result["event"] == "message"
|
||||
assert result["task_id"] == "task"
|
||||
assert result["message_id"] == "msg"
|
||||
assert result["answer"] == "answer"
|
||||
assert result["metadata"] == {"k": "v"}
|
||||
|
||||
def test_convert_blocking_simple_response_metadata_simplified(self):
|
||||
metadata = {
|
||||
"retriever_resources": [
|
||||
{
|
||||
"segment_id": "s",
|
||||
"position": 1,
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"content": "c",
|
||||
"summary": "sum",
|
||||
"extra": "x",
|
||||
}
|
||||
],
|
||||
"annotation_reply": {"a": 1},
|
||||
"usage": {"t": 2},
|
||||
}
|
||||
blocking = CompletionAppBlockingResponse(
|
||||
task_id="task",
|
||||
data=CompletionAppBlockingResponse.Data(
|
||||
id="id",
|
||||
mode="completion",
|
||||
message_id="msg",
|
||||
answer="answer",
|
||||
metadata=metadata,
|
||||
created_at=123,
|
||||
),
|
||||
)
|
||||
|
||||
result = CompletionAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
|
||||
assert "annotation_reply" not in result["metadata"]
|
||||
assert "usage" not in result["metadata"]
|
||||
assert result["metadata"]["retriever_resources"][0]["segment_id"] == "s"
|
||||
assert "extra" not in result["metadata"]["retriever_resources"][0]
|
||||
|
||||
def test_convert_blocking_simple_response_metadata_not_dict(self):
|
||||
data = CompletionAppBlockingResponse.Data.model_construct(
|
||||
id="id",
|
||||
mode="completion",
|
||||
message_id="msg",
|
||||
answer="answer",
|
||||
metadata="bad",
|
||||
created_at=123,
|
||||
)
|
||||
blocking = CompletionAppBlockingResponse.model_construct(task_id="task", data=data)
|
||||
|
||||
result = CompletionAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
|
||||
assert result["metadata"] == {}
|
||||
|
||||
def test_convert_stream_full_response(self):
|
||||
def stream() -> Generator[AppStreamResponse, None, None]:
|
||||
yield CompletionAppStreamResponse(
|
||||
stream_response=PingStreamResponse(task_id="t"),
|
||||
message_id="m",
|
||||
created_at=1,
|
||||
)
|
||||
yield CompletionAppStreamResponse(
|
||||
stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")),
|
||||
message_id="m",
|
||||
created_at=2,
|
||||
)
|
||||
yield CompletionAppStreamResponse(
|
||||
stream_response=MessageStreamResponse(task_id="t", id="1", answer="ok"),
|
||||
message_id="m",
|
||||
created_at=3,
|
||||
)
|
||||
|
||||
result = list(CompletionAppGenerateResponseConverter.convert_stream_full_response(stream()))
|
||||
|
||||
assert result[0] == "ping"
|
||||
assert result[1]["event"] == "error"
|
||||
assert result[1]["code"] == "invalid_param"
|
||||
assert result[2]["event"] == "message"
|
||||
|
||||
def test_convert_stream_simple_response(self):
|
||||
def stream() -> Generator[AppStreamResponse, None, None]:
|
||||
yield CompletionAppStreamResponse(
|
||||
stream_response=PingStreamResponse(task_id="t"),
|
||||
message_id="m",
|
||||
created_at=1,
|
||||
)
|
||||
yield CompletionAppStreamResponse(
|
||||
stream_response=MessageEndStreamResponse(
|
||||
task_id="t",
|
||||
id="end",
|
||||
metadata={
|
||||
"retriever_resources": [
|
||||
{
|
||||
"segment_id": "s",
|
||||
"position": 1,
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"content": "c",
|
||||
"summary": "sum",
|
||||
}
|
||||
],
|
||||
"annotation_reply": {"a": 1},
|
||||
"usage": {"t": 2},
|
||||
},
|
||||
),
|
||||
message_id="m",
|
||||
created_at=2,
|
||||
)
|
||||
yield CompletionAppStreamResponse(
|
||||
stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")),
|
||||
message_id="m",
|
||||
created_at=3,
|
||||
)
|
||||
|
||||
result = list(CompletionAppGenerateResponseConverter.convert_stream_simple_response(stream()))
|
||||
|
||||
assert result[0] == "ping"
|
||||
assert result[1]["event"] == "message_end"
|
||||
assert "annotation_reply" not in result[1]["metadata"]
|
||||
assert "usage" not in result[1]["metadata"]
|
||||
assert result[2]["event"] == "error"
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import core.app.apps.pipeline.pipeline_config_manager as module
|
||||
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def test_get_pipeline_config(mocker):
|
||||
pipeline = MagicMock(tenant_id="tenant", id="pipe1")
|
||||
workflow = MagicMock(id="wf1")
|
||||
|
||||
mocker.patch.object(
|
||||
module.WorkflowVariablesConfigManager,
|
||||
"convert_rag_pipeline_variable",
|
||||
return_value=["var1"],
|
||||
)
|
||||
mocker.patch.object(module, "PipelineConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
|
||||
result = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow, start_node_id="start")
|
||||
|
||||
assert result.tenant_id == "tenant"
|
||||
assert result.app_id == "pipe1"
|
||||
assert result.workflow_id == "wf1"
|
||||
assert result.app_mode == AppMode.RAG_PIPELINE
|
||||
assert result.rag_pipeline_variables == ["var1"]
|
||||
|
||||
|
||||
def test_config_validate_filters_related_keys(mocker):
|
||||
config = {
|
||||
"file_upload": {"enabled": True},
|
||||
"tts": {"enabled": True},
|
||||
"moderation": {"enabled": True},
|
||||
"extra": "drop",
|
||||
}
|
||||
|
||||
mocker.patch.object(
|
||||
module.FileUploadConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["file_upload"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.TextToSpeechConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["tts"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.SensitiveWordAvoidanceConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["moderation"]),
|
||||
)
|
||||
|
||||
filtered = PipelineConfigManager.config_validate("tenant", config)
|
||||
|
||||
assert set(filtered.keys()) == {"file_upload", "tts", "moderation"}
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
from collections.abc import Generator
|
||||
|
||||
from core.app.apps.pipeline.generate_response_converter import WorkflowAppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
)
|
||||
from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
def test_convert_blocking_full_and_simple_response():
|
||||
blocking = WorkflowAppBlockingResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run",
|
||||
data=WorkflowAppBlockingResponse.Data(
|
||||
id="id",
|
||||
workflow_id="wf",
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
outputs={"k": "v"},
|
||||
error=None,
|
||||
elapsed_time=0.1,
|
||||
total_tokens=10,
|
||||
total_steps=1,
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
),
|
||||
)
|
||||
|
||||
full = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(blocking)
|
||||
simple = WorkflowAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
|
||||
assert full == simple
|
||||
assert full["workflow_run_id"] == "run"
|
||||
assert full["data"]["status"] == WorkflowExecutionStatus.SUCCEEDED
|
||||
|
||||
|
||||
def test_convert_stream_full_response():
|
||||
def stream() -> Generator[AppStreamResponse, None, None]:
|
||||
yield WorkflowAppStreamResponse(
|
||||
stream_response=PingStreamResponse(task_id="t"),
|
||||
workflow_run_id="run",
|
||||
)
|
||||
yield WorkflowAppStreamResponse(
|
||||
stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")),
|
||||
workflow_run_id="run",
|
||||
)
|
||||
|
||||
result = list(WorkflowAppGenerateResponseConverter.convert_stream_full_response(stream()))
|
||||
|
||||
assert result[0] == "ping"
|
||||
assert result[1]["event"] == "error"
|
||||
assert result[1]["code"] == "invalid_param"
|
||||
|
||||
|
||||
def test_convert_stream_simple_response_node_ignore_details():
|
||||
node_start = NodeStartStreamResponse(
|
||||
task_id="t",
|
||||
workflow_run_id="run",
|
||||
data=NodeStartStreamResponse.Data(
|
||||
id="nid",
|
||||
node_id="node",
|
||||
node_type="type",
|
||||
title="Title",
|
||||
index=1,
|
||||
predecessor_node_id=None,
|
||||
inputs={"a": 1},
|
||||
inputs_truncated=False,
|
||||
created_at=1,
|
||||
),
|
||||
)
|
||||
node_finish = NodeFinishStreamResponse(
|
||||
task_id="t",
|
||||
workflow_run_id="run",
|
||||
data=NodeFinishStreamResponse.Data(
|
||||
id="nid",
|
||||
node_id="node",
|
||||
node_type="type",
|
||||
title="Title",
|
||||
index=1,
|
||||
predecessor_node_id=None,
|
||||
inputs={"a": 1},
|
||||
inputs_truncated=False,
|
||||
process_data=None,
|
||||
process_data_truncated=False,
|
||||
outputs={"b": 2},
|
||||
outputs_truncated=False,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
error=None,
|
||||
elapsed_time=0.1,
|
||||
execution_metadata=None,
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
files=[],
|
||||
),
|
||||
)
|
||||
|
||||
def stream() -> Generator[AppStreamResponse, None, None]:
|
||||
yield WorkflowAppStreamResponse(stream_response=node_start, workflow_run_id="run")
|
||||
yield WorkflowAppStreamResponse(stream_response=node_finish, workflow_run_id="run")
|
||||
|
||||
result = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(stream()))
|
||||
|
||||
assert result[0]["event"] == "node_started"
|
||||
assert result[0]["data"]["inputs"] is None
|
||||
assert result[1]["event"] == "node_finished"
|
||||
assert result[1]["data"]["inputs"] is None
|
||||
|
|
@ -0,0 +1,699 @@
|
|||
import contextlib
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, PropertyMock
|
||||
|
||||
import pytest
|
||||
|
||||
import core.app.apps.pipeline.pipeline_generator as module
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
|
||||
|
||||
class FakeRagPipelineGenerateEntity(SimpleNamespace):
|
||||
class SingleIterationRunEntity(SimpleNamespace):
|
||||
pass
|
||||
|
||||
class SingleLoopRunEntity(SimpleNamespace):
|
||||
pass
|
||||
|
||||
def model_dump(self):
|
||||
return dict(self.__dict__)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generator(mocker):
|
||||
gen = module.PipelineGenerator()
|
||||
|
||||
mocker.patch.object(module, "RagPipelineGenerateEntity", FakeRagPipelineGenerateEntity)
|
||||
mocker.patch.object(module, "RagPipelineInvokeEntity", side_effect=lambda **kwargs: kwargs)
|
||||
mocker.patch.object(module.contexts, "plugin_tool_providers", SimpleNamespace(set=MagicMock()))
|
||||
mocker.patch.object(module.contexts, "plugin_tool_providers_lock", SimpleNamespace(set=MagicMock()))
|
||||
|
||||
return gen
|
||||
|
||||
|
||||
def _build_pipeline_dataset():
|
||||
return SimpleNamespace(
|
||||
id="ds",
|
||||
name="dataset",
|
||||
description="desc",
|
||||
chunk_structure="chunk",
|
||||
built_in_field_enabled=True,
|
||||
tenant_id="tenant",
|
||||
)
|
||||
|
||||
|
||||
def _build_pipeline():
|
||||
pipeline = MagicMock(tenant_id="tenant", id="pipe")
|
||||
pipeline.retrieve_dataset.return_value = _build_pipeline_dataset()
|
||||
return pipeline
|
||||
|
||||
|
||||
def _build_workflow():
|
||||
return MagicMock(id="wf", graph_dict={"nodes": [], "edges": []}, tenant_id="tenant")
|
||||
|
||||
|
||||
def _build_user():
|
||||
return MagicMock(id="user", name="User", session_id="session")
|
||||
|
||||
|
||||
def _build_args():
|
||||
return {
|
||||
"inputs": {"k": "v"},
|
||||
"start_node_id": "start",
|
||||
"datasource_type": DatasourceProviderType.LOCAL_FILE.value,
|
||||
"datasource_info_list": [{"name": "file"}],
|
||||
}
|
||||
|
||||
|
||||
def _patch_session(mocker, session):
|
||||
mocker.patch.object(module, "Session", return_value=session)
|
||||
mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock())
|
||||
|
||||
|
||||
def _dummy_preserve(*args, **kwargs):
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
||||
class DummySession:
|
||||
def __init__(self):
|
||||
self.scalar = MagicMock()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
def test_generate_dataset_missing(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
pipeline.retrieve_dataset.return_value = None
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
pipeline=pipeline,
|
||||
workflow=_build_workflow(),
|
||||
user=_build_user(),
|
||||
args=_build_args(),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
|
||||
def test_generate_debugger_calls_generate(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
workflow = _build_workflow()
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
mocker.patch.object(
|
||||
generator,
|
||||
"_format_datasource_info_list",
|
||||
return_value=[{"name": "file"}],
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.PipelineConfigManager,
|
||||
"get_pipeline_config",
|
||||
return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]),
|
||||
)
|
||||
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
|
||||
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_node_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
|
||||
mocker.patch.object(generator, "_generate", return_value={"result": "ok"})
|
||||
|
||||
result = generator.generate(
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
user=_build_user(),
|
||||
args=_build_args(),
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
|
||||
def test_generate_published_pipeline_creates_documents_and_delay(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
workflow = _build_workflow()
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
datasource_info_list = [{"name": "file1"}, {"name": "file2"}]
|
||||
|
||||
mocker.patch.object(
|
||||
generator,
|
||||
"_format_datasource_info_list",
|
||||
return_value=datasource_info_list,
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.PipelineConfigManager,
|
||||
"get_pipeline_config",
|
||||
return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]),
|
||||
)
|
||||
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
|
||||
|
||||
mocker.patch("services.dataset_service.DocumentService.get_documents_position", return_value=1)
|
||||
|
||||
document1 = SimpleNamespace(
|
||||
id="doc1",
|
||||
position=1,
|
||||
data_source_type=DatasourceProviderType.LOCAL_FILE,
|
||||
data_source_info="{}",
|
||||
name="file1",
|
||||
indexing_status="",
|
||||
error=None,
|
||||
enabled=True,
|
||||
)
|
||||
document2 = SimpleNamespace(
|
||||
id="doc2",
|
||||
position=2,
|
||||
data_source_type=DatasourceProviderType.LOCAL_FILE,
|
||||
data_source_info="{}",
|
||||
name="file2",
|
||||
indexing_status="",
|
||||
error=None,
|
||||
enabled=True,
|
||||
)
|
||||
mocker.patch.object(generator, "_build_document", side_effect=[document1, document2])
|
||||
|
||||
mocker.patch.object(module, "DocumentPipelineExecutionLog", return_value=MagicMock())
|
||||
|
||||
db_session = MagicMock()
|
||||
mocker.patch.object(module.db, "session", db_session)
|
||||
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_node_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
|
||||
task_proxy = MagicMock()
|
||||
mocker.patch.object(module, "RagPipelineTaskProxy", return_value=task_proxy)
|
||||
|
||||
result = generator.generate(
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
user=_build_user(),
|
||||
args=_build_args(),
|
||||
invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
assert result["batch"]
|
||||
assert len(result["documents"]) == 2
|
||||
task_proxy.delay.assert_called_once()
|
||||
|
||||
|
||||
def test_generate_is_retry_calls_generate(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
workflow = _build_workflow()
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
mocker.patch.object(
|
||||
generator,
|
||||
"_format_datasource_info_list",
|
||||
return_value=[{"name": "file"}],
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.PipelineConfigManager,
|
||||
"get_pipeline_config",
|
||||
return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]),
|
||||
)
|
||||
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
|
||||
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_node_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
|
||||
mocker.patch.object(generator, "_generate", return_value={"result": "ok"})
|
||||
|
||||
result = generator.generate(
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
user=_build_user(),
|
||||
args=_build_args(),
|
||||
invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
|
||||
streaming=True,
|
||||
is_retry=True,
|
||||
)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
|
||||
def test_generate_worker_handles_errors(generator, mocker):
|
||||
flask_app = MagicMock()
|
||||
flask_app.app_context.return_value = contextlib.nullcontext()
|
||||
mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
|
||||
mocker.patch.object(module.db, "session", MagicMock(close=MagicMock()))
|
||||
mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock())
|
||||
|
||||
application_generate_entity = FakeRagPipelineGenerateEntity(
|
||||
app_config=SimpleNamespace(tenant_id="tenant", app_id="pipe", workflow_id="wf"),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
session = DummySession()
|
||||
session.scalar.side_effect = [MagicMock(), MagicMock(session_id="session")]
|
||||
_patch_session(mocker, session)
|
||||
|
||||
runner_instance = MagicMock()
|
||||
runner_instance.run.side_effect = ValueError("bad")
|
||||
mocker.patch.object(module, "PipelineRunner", return_value=runner_instance)
|
||||
|
||||
queue_manager = MagicMock()
|
||||
generator._generate_worker(
|
||||
flask_app=flask_app,
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
context=contextlib.nullcontext(),
|
||||
variable_loader=MagicMock(),
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
queue_manager.publish_error.assert_called_once()
|
||||
|
||||
|
||||
def test_generate_worker_sets_system_user_id_for_external_call(generator, mocker):
|
||||
flask_app = MagicMock()
|
||||
flask_app.app_context.return_value = contextlib.nullcontext()
|
||||
mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
|
||||
mocker.patch.object(module.db, "session", MagicMock(close=MagicMock()))
|
||||
mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock())
|
||||
|
||||
application_generate_entity = FakeRagPipelineGenerateEntity(
|
||||
app_config=SimpleNamespace(tenant_id="tenant", app_id="pipe", workflow_id="wf"),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
session = DummySession()
|
||||
session.scalar.side_effect = [MagicMock(), MagicMock(session_id="session")]
|
||||
_patch_session(mocker, session)
|
||||
|
||||
runner_instance = MagicMock()
|
||||
mocker.patch.object(module, "PipelineRunner", return_value=runner_instance)
|
||||
|
||||
generator._generate_worker(
|
||||
flask_app=flask_app,
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=MagicMock(),
|
||||
context=contextlib.nullcontext(),
|
||||
variable_loader=MagicMock(),
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
assert module.PipelineRunner.call_args.kwargs["system_user_id"] == "session"
|
||||
|
||||
|
||||
def test_generate_raises_when_workflow_not_found(generator, mocker):
|
||||
flask_app = MagicMock()
|
||||
mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator._generate(
|
||||
flask_app=flask_app,
|
||||
context=contextlib.nullcontext(),
|
||||
pipeline=_build_pipeline(),
|
||||
workflow_id="wf",
|
||||
user=_build_user(),
|
||||
application_generate_entity=FakeRagPipelineGenerateEntity(
|
||||
task_id="t",
|
||||
app_config=SimpleNamespace(app_id="pipe"),
|
||||
user_id="user",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
),
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
|
||||
def test_generate_success_returns_converted(generator, mocker):
|
||||
flask_app = MagicMock()
|
||||
mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
|
||||
|
||||
workflow = MagicMock(id="wf", tenant_id="tenant", app_id="pipe", graph_dict={})
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = workflow
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
queue_manager = MagicMock()
|
||||
mocker.patch.object(module, "PipelineQueueManager", return_value=queue_manager)
|
||||
|
||||
worker_thread = MagicMock()
|
||||
mocker.patch.object(module.threading, "Thread", return_value=worker_thread)
|
||||
|
||||
mocker.patch.object(generator, "_get_draft_var_saver_factory", return_value=MagicMock())
|
||||
mocker.patch.object(generator, "_handle_response", return_value="response")
|
||||
mocker.patch.object(module.WorkflowAppGenerateResponseConverter, "convert", return_value="converted")
|
||||
|
||||
result = generator._generate(
|
||||
flask_app=flask_app,
|
||||
context=contextlib.nullcontext(),
|
||||
pipeline=_build_pipeline(),
|
||||
workflow_id="wf",
|
||||
user=_build_user(),
|
||||
application_generate_entity=FakeRagPipelineGenerateEntity(
|
||||
task_id="t",
|
||||
app_config=SimpleNamespace(app_id="pipe"),
|
||||
user_id="user",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
),
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
assert result == "converted"
|
||||
|
||||
|
||||
def test_single_iteration_generate_validates_inputs(generator, mocker):
|
||||
with pytest.raises(ValueError):
|
||||
generator.single_iteration_generate(_build_pipeline(), _build_workflow(), "", _build_user(), {})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator.single_iteration_generate(
|
||||
_build_pipeline(), _build_workflow(), "node", _build_user(), {"inputs": None}
|
||||
)
|
||||
|
||||
|
||||
def test_single_iteration_generate_dataset_required(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
pipeline.retrieve_dataset.return_value = None
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator.single_iteration_generate(
|
||||
pipeline,
|
||||
_build_workflow(),
|
||||
"node",
|
||||
_build_user(),
|
||||
{"inputs": {"a": 1}},
|
||||
)
|
||||
|
||||
|
||||
def test_single_iteration_generate_success(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
mocker.patch.object(
|
||||
module.PipelineConfigManager,
|
||||
"get_pipeline_config",
|
||||
return_value=SimpleNamespace(app_id="pipe", tenant_id="tenant"),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_node_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(module.db, "session", MagicMock(return_value=MagicMock()))
|
||||
|
||||
mocker.patch.object(module, "WorkflowDraftVariableService", return_value=MagicMock())
|
||||
mocker.patch.object(module, "DraftVarLoader", return_value=MagicMock())
|
||||
|
||||
mocker.patch.object(generator, "_generate", return_value={"ok": True})
|
||||
|
||||
result = generator.single_iteration_generate(
|
||||
pipeline,
|
||||
_build_workflow(),
|
||||
"node",
|
||||
_build_user(),
|
||||
{"inputs": {"a": 1}},
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
assert result == {"ok": True}
|
||||
|
||||
|
||||
def test_single_loop_generate_success(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
mocker.patch.object(
|
||||
module.PipelineConfigManager,
|
||||
"get_pipeline_config",
|
||||
return_value=SimpleNamespace(app_id="pipe", tenant_id="tenant"),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_node_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(module.db, "session", MagicMock(return_value=MagicMock()))
|
||||
|
||||
mocker.patch.object(module, "WorkflowDraftVariableService", return_value=MagicMock())
|
||||
mocker.patch.object(module, "DraftVarLoader", return_value=MagicMock())
|
||||
|
||||
mocker.patch.object(generator, "_generate", return_value={"ok": True})
|
||||
|
||||
result = generator.single_loop_generate(
|
||||
pipeline,
|
||||
_build_workflow(),
|
||||
"node",
|
||||
_build_user(),
|
||||
{"inputs": {"a": 1}},
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
assert result == {"ok": True}
|
||||
|
||||
|
||||
def test_handle_response_value_error_triggers_generate_task_stopped(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
workflow = _build_workflow()
|
||||
app_entity = FakeRagPipelineGenerateEntity(task_id="t")
|
||||
|
||||
task_pipeline = MagicMock()
|
||||
task_pipeline.process.side_effect = ValueError("I/O operation on closed file.")
|
||||
mocker.patch.object(module, "WorkflowAppGenerateTaskPipeline", return_value=task_pipeline)
|
||||
|
||||
with pytest.raises(GenerateTaskStoppedError):
|
||||
generator._handle_response(
|
||||
application_generate_entity=app_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=MagicMock(),
|
||||
user=_build_user(),
|
||||
draft_var_saver_factory=MagicMock(),
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
def test_build_document_sets_metadata_for_builtin_fields(generator, mocker):
|
||||
class DummyDocument(SimpleNamespace):
|
||||
pass
|
||||
|
||||
mocker.patch.object(module, "Document", side_effect=lambda **kwargs: DummyDocument(**kwargs))
|
||||
|
||||
document = generator._build_document(
|
||||
tenant_id="tenant",
|
||||
dataset_id="ds",
|
||||
built_in_field_enabled=True,
|
||||
datasource_type=DatasourceProviderType.LOCAL_FILE,
|
||||
datasource_info={"name": "file"},
|
||||
created_from="rag-pipeline",
|
||||
position=1,
|
||||
account=_build_user(),
|
||||
batch="batch",
|
||||
document_form="text",
|
||||
)
|
||||
|
||||
assert document.name == "file"
|
||||
assert document.doc_metadata
|
||||
|
||||
|
||||
def test_build_document_invalid_datasource_type(generator):
|
||||
with pytest.raises(ValueError):
|
||||
generator._build_document(
|
||||
tenant_id="tenant",
|
||||
dataset_id="ds",
|
||||
built_in_field_enabled=False,
|
||||
datasource_type="invalid",
|
||||
datasource_info={},
|
||||
created_from="rag-pipeline",
|
||||
position=1,
|
||||
account=_build_user(),
|
||||
batch="batch",
|
||||
document_form="text",
|
||||
)
|
||||
|
||||
|
||||
def test_format_datasource_info_list_non_online_drive(generator):
|
||||
result = generator._format_datasource_info_list(
|
||||
DatasourceProviderType.LOCAL_FILE,
|
||||
[{"name": "file"}],
|
||||
_build_pipeline(),
|
||||
_build_workflow(),
|
||||
"start",
|
||||
_build_user(),
|
||||
)
|
||||
|
||||
assert result == [{"name": "file"}]
|
||||
|
||||
|
||||
def test_format_datasource_info_list_missing_node_data(generator):
|
||||
workflow = MagicMock(graph_dict={"nodes": []})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator._format_datasource_info_list(
|
||||
DatasourceProviderType.ONLINE_DRIVE,
|
||||
[],
|
||||
_build_pipeline(),
|
||||
workflow,
|
||||
"start",
|
||||
_build_user(),
|
||||
)
|
||||
|
||||
|
||||
def test_format_datasource_info_list_online_drive_folder(generator, mocker):
|
||||
workflow = MagicMock(
|
||||
graph_dict={
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start",
|
||||
"data": {
|
||||
"plugin_id": "p",
|
||||
"provider_name": "provider",
|
||||
"datasource_name": "drive",
|
||||
"credential_id": "cred",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
runtime = MagicMock()
|
||||
runtime.runtime = SimpleNamespace(credentials=None)
|
||||
runtime.datasource_provider_type.return_value = DatasourceProviderType.ONLINE_DRIVE
|
||||
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime",
|
||||
return_value=runtime,
|
||||
)
|
||||
mocker.patch.object(module.DatasourceProviderService, "get_datasource_credentials", return_value={"k": "v"})
|
||||
|
||||
mocker.patch.object(
|
||||
generator,
|
||||
"_get_files_in_folder",
|
||||
side_effect=lambda *args, **kwargs: args[4].append({"id": "f"}),
|
||||
)
|
||||
|
||||
result = generator._format_datasource_info_list(
|
||||
DatasourceProviderType.ONLINE_DRIVE,
|
||||
[{"id": "folder", "type": "folder", "name": "Folder", "bucket": "b"}],
|
||||
_build_pipeline(),
|
||||
workflow,
|
||||
"start",
|
||||
_build_user(),
|
||||
)
|
||||
|
||||
assert result == [{"id": "f"}]
|
||||
|
||||
|
||||
def test_get_files_in_folder_recurses_and_collects(generator):
|
||||
class File:
|
||||
def __init__(self, id, name, type):
|
||||
self.id = id
|
||||
self.name = name
|
||||
self.type = type
|
||||
|
||||
class FilesPage:
|
||||
def __init__(self, files, is_truncated=False, next_page_parameters=None):
|
||||
self.files = files
|
||||
self.is_truncated = is_truncated
|
||||
self.next_page_parameters = next_page_parameters
|
||||
|
||||
class Result:
|
||||
def __init__(self, result):
|
||||
self.result = result
|
||||
|
||||
class Runtime:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def datasource_provider_type(self):
|
||||
return DatasourceProviderType.ONLINE_DRIVE
|
||||
|
||||
def online_drive_browse_files(self, user_id, request, provider_type):
|
||||
self.calls.append(request.next_page_parameters)
|
||||
if request.prefix == "fd":
|
||||
return iter([Result([FilesPage([File("f2", "file2", "file")], False, None)])])
|
||||
if request.next_page_parameters is None:
|
||||
return iter(
|
||||
[
|
||||
Result(
|
||||
[FilesPage([File("f1", "file", "file"), File("fd", "folder", "folder")], True, {"page": 2})]
|
||||
)
|
||||
]
|
||||
)
|
||||
return iter([Result([FilesPage([File("f2", "file2", "file")], False, None)])])
|
||||
|
||||
runtime = Runtime()
|
||||
all_files = []
|
||||
|
||||
generator._get_files_in_folder(
|
||||
datasource_runtime=runtime,
|
||||
prefix="root",
|
||||
bucket="b",
|
||||
user_id="user",
|
||||
all_files=all_files,
|
||||
datasource_info={},
|
||||
)
|
||||
|
||||
assert {f["id"] for f in all_files} == {"f1", "f2"}
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
import pytest
|
||||
|
||||
import core.app.apps.pipeline.pipeline_queue_manager as module
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueErrorEvent,
|
||||
QueueMessageEndEvent,
|
||||
QueueStopEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult
|
||||
|
||||
|
||||
def test_publish_sets_stop_listen_and_raises_on_stopped(mocker):
|
||||
manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag")
|
||||
manager._q = mocker.MagicMock()
|
||||
manager.stop_listen = mocker.MagicMock()
|
||||
manager._is_stopped = mocker.MagicMock(return_value=True)
|
||||
|
||||
with pytest.raises(GenerateTaskStoppedError):
|
||||
manager._publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
manager.stop_listen.assert_called_once()
|
||||
|
||||
|
||||
def test_publish_stop_events_trigger_stop_listen(mocker):
|
||||
manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag")
|
||||
manager._q = mocker.MagicMock()
|
||||
manager.stop_listen = mocker.MagicMock()
|
||||
manager._is_stopped = mocker.MagicMock(return_value=False)
|
||||
|
||||
for event in [
|
||||
QueueErrorEvent(error=ValueError("bad")),
|
||||
QueueMessageEndEvent(llm_result=LLMResult.model_construct()),
|
||||
QueueWorkflowSucceededEvent(),
|
||||
QueueWorkflowFailedEvent(error="failed", exceptions_count=1),
|
||||
QueueWorkflowPartialSuccessEvent(exceptions_count=1),
|
||||
]:
|
||||
manager.stop_listen.reset_mock()
|
||||
manager._publish(event, PublishFrom.TASK_PIPELINE)
|
||||
manager.stop_listen.assert_called_once()
|
||||
|
||||
|
||||
def test_publish_non_stop_event_no_stop_listen(mocker):
|
||||
manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag")
|
||||
manager._q = mocker.MagicMock()
|
||||
manager.stop_listen = mocker.MagicMock()
|
||||
manager._is_stopped = mocker.MagicMock(return_value=False)
|
||||
|
||||
non_stop_event = mocker.MagicMock(spec=module.AppQueueEvent)
|
||||
manager._publish(non_stop_event, PublishFrom.TASK_PIPELINE)
|
||||
manager.stop_listen.assert_not_called()
|
||||
|
|
@ -0,0 +1,297 @@
|
|||
"""
|
||||
Unit tests for PipelineRunner behavior.
|
||||
Asserts correct event handling, error propagation, and user invocation logic.
|
||||
Primary collaborators: PipelineRunner, InvokeFrom, GraphRunFailedEvent, UserFrom, and mocked dependencies.
|
||||
Cross-references: core.app.apps.pipeline.pipeline_runner, core.app.entities.app_invoke_entities.
|
||||
"""
|
||||
|
||||
"""Unit tests for PipelineRunner behavior.
|
||||
|
||||
This module validates core control-flow outcomes for
|
||||
``core.app.apps.pipeline.pipeline_runner``: app/workflow lookup, graph
|
||||
initialization guards, invoke-source to user-source resolution, and failed-run
|
||||
event handling. Invariants asserted here include strict graph-config
|
||||
validation, correct ``InvokeFrom`` to ``UserFrom`` mapping, and publishing
|
||||
error paths driven by ``GraphRunFailedEvent`` through mocked collaborators.
|
||||
Primary collaborators include ``PipelineRunner``,
|
||||
``core.app.entities.app_invoke_entities.InvokeFrom``, ``GraphRunFailedEvent``,
|
||||
``UserFrom``, and patched DB/runtime dependencies used by the runner.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import core.app.apps.pipeline.pipeline_runner as module
|
||||
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from dify_graph.graph_events import GraphRunFailedEvent
|
||||
|
||||
|
||||
def _build_app_generate_entity() -> SimpleNamespace:
|
||||
app_config = SimpleNamespace(app_id="pipe", workflow_id="wf", tenant_id="tenant")
|
||||
return SimpleNamespace(
|
||||
app_config=app_config,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
user_id="user",
|
||||
trace_manager=MagicMock(),
|
||||
inputs={"input1": "v1"},
|
||||
files=[],
|
||||
workflow_execution_id="run",
|
||||
document_id="doc",
|
||||
original_document_id=None,
|
||||
batch="batch",
|
||||
dataset_id="ds",
|
||||
datasource_type="local_file",
|
||||
datasource_info={"name": "file"},
|
||||
start_node_id="start",
|
||||
call_depth=0,
|
||||
single_iteration_run=None,
|
||||
single_loop_run=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
app_generate_entity = _build_app_generate_entity()
|
||||
queue_manager = MagicMock()
|
||||
variable_loader = MagicMock()
|
||||
workflow = MagicMock()
|
||||
workflow_execution_repository = MagicMock()
|
||||
workflow_node_execution_repository = MagicMock()
|
||||
|
||||
return PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
variable_loader=variable_loader,
|
||||
workflow=workflow,
|
||||
system_user_id="sys",
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
)
|
||||
|
||||
|
||||
def test_get_app_id(runner):
|
||||
assert runner._get_app_id() == "pipe"
|
||||
|
||||
|
||||
def test_get_workflow_returns_workflow(mocker, runner):
|
||||
pipeline = MagicMock(tenant_id="tenant", id="pipe")
|
||||
workflow = MagicMock(id="wf")
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = workflow
|
||||
mocker.patch.object(module.db, "session", MagicMock(query=MagicMock(return_value=query)))
|
||||
|
||||
result = runner.get_workflow(pipeline=pipeline, workflow_id="wf")
|
||||
|
||||
assert result == workflow
|
||||
|
||||
|
||||
def test_init_rag_pipeline_graph_invalid_config(mocker, runner):
|
||||
workflow = MagicMock(id="wf", tenant_id="tenant", graph_dict={})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock())
|
||||
|
||||
workflow.graph_dict = {"nodes": "bad", "edges": []}
|
||||
with pytest.raises(ValueError):
|
||||
runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock())
|
||||
|
||||
workflow.graph_dict = {"nodes": [], "edges": "bad"}
|
||||
with pytest.raises(ValueError):
|
||||
runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock())
|
||||
|
||||
|
||||
def test_init_rag_pipeline_graph_not_found(mocker, runner):
|
||||
workflow = MagicMock(id="wf", tenant_id="tenant", graph_dict={"nodes": [], "edges": []})
|
||||
mocker.patch.object(module.Graph, "init", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock())
|
||||
|
||||
|
||||
def test_update_document_status_on_failure(mocker, runner):
|
||||
document = MagicMock()
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = document
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
event = GraphRunFailedEvent(error="boom")
|
||||
|
||||
runner._update_document_status(event, document_id="doc", dataset_id="ds")
|
||||
|
||||
assert document.indexing_status == "error"
|
||||
assert document.error == "boom"
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_run_pipeline_not_found(mocker):
|
||||
app_generate_entity = _build_app_generate_entity()
|
||||
app_generate_entity.invoke_from = InvokeFrom.WEB_APP
|
||||
app_generate_entity.single_iteration_run = None
|
||||
app_generate_entity.single_loop_run = None
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = None
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
queue_manager=MagicMock(),
|
||||
variable_loader=MagicMock(),
|
||||
workflow=MagicMock(),
|
||||
system_user_id="sys",
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run()
|
||||
|
||||
|
||||
def test_run_workflow_not_initialized(mocker):
|
||||
app_generate_entity = _build_app_generate_entity()
|
||||
|
||||
pipeline = MagicMock(id="pipe")
|
||||
query_pipeline = MagicMock()
|
||||
query_pipeline.where.return_value.first.return_value = pipeline
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query_pipeline
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
queue_manager=MagicMock(),
|
||||
variable_loader=MagicMock(),
|
||||
workflow=MagicMock(),
|
||||
system_user_id="sys",
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
runner.get_workflow = MagicMock(return_value=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run()
|
||||
|
||||
|
||||
def test_run_single_iteration_path(mocker):
|
||||
app_generate_entity = _build_app_generate_entity()
|
||||
app_generate_entity.single_iteration_run = MagicMock()
|
||||
|
||||
pipeline = MagicMock(id="pipe")
|
||||
query_pipeline = MagicMock()
|
||||
query_pipeline.where.return_value.first.return_value = pipeline
|
||||
|
||||
query_end_user = MagicMock()
|
||||
query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess")
|
||||
|
||||
session = MagicMock()
|
||||
session.query.side_effect = [query_end_user, query_pipeline]
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
queue_manager=MagicMock(),
|
||||
variable_loader=MagicMock(),
|
||||
workflow=MagicMock(),
|
||||
system_user_id="sys",
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
runner._resolve_user_from = MagicMock(return_value=UserFrom.ACCOUNT)
|
||||
runner.get_workflow = MagicMock(
|
||||
return_value=MagicMock(
|
||||
id="wf",
|
||||
tenant_id="tenant",
|
||||
app_id="pipe",
|
||||
graph_dict={},
|
||||
type="rag-pipeline",
|
||||
version="v1",
|
||||
)
|
||||
)
|
||||
runner._prepare_single_node_execution = MagicMock(return_value=("graph", "pool", "state"))
|
||||
runner._update_document_status = MagicMock()
|
||||
runner._handle_event = MagicMock()
|
||||
|
||||
workflow_entry = MagicMock()
|
||||
workflow_entry.graph_engine = MagicMock()
|
||||
workflow_entry.run.return_value = [MagicMock()]
|
||||
mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry)
|
||||
|
||||
mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock())
|
||||
|
||||
runner.run()
|
||||
|
||||
runner._prepare_single_node_execution.assert_called_once()
|
||||
runner._handle_event.assert_called()
|
||||
|
||||
|
||||
def test_run_normal_path_builds_graph(mocker):
|
||||
app_generate_entity = _build_app_generate_entity()
|
||||
|
||||
pipeline = MagicMock(id="pipe")
|
||||
query_pipeline = MagicMock()
|
||||
query_pipeline.where.return_value.first.return_value = pipeline
|
||||
|
||||
query_end_user = MagicMock()
|
||||
query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess")
|
||||
|
||||
session = MagicMock()
|
||||
session.query.side_effect = [query_end_user, query_pipeline]
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
workflow = MagicMock(
|
||||
id="wf",
|
||||
tenant_id="tenant",
|
||||
app_id="pipe",
|
||||
graph_dict={"nodes": [], "edges": []},
|
||||
environment_variables=[],
|
||||
rag_pipeline_variables=[{"variable": "input1", "belong_to_node_id": "start"}],
|
||||
type="rag-pipeline",
|
||||
version="v1",
|
||||
)
|
||||
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
queue_manager=MagicMock(),
|
||||
variable_loader=MagicMock(),
|
||||
workflow=workflow,
|
||||
system_user_id="sys",
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
runner._resolve_user_from = MagicMock(return_value=UserFrom.ACCOUNT)
|
||||
runner.get_workflow = MagicMock(return_value=workflow)
|
||||
runner._init_rag_pipeline_graph = MagicMock(return_value="graph")
|
||||
runner._update_document_status = MagicMock()
|
||||
runner._handle_event = MagicMock()
|
||||
|
||||
mocker.patch.object(
|
||||
module.RAGPipelineVariable,
|
||||
"model_validate",
|
||||
return_value=SimpleNamespace(belong_to_node_id="start", variable="input1"),
|
||||
)
|
||||
mocker.patch.object(module, "RAGPipelineVariableInput", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
mocker.patch.object(module, "VariablePool", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
|
||||
workflow_entry = MagicMock()
|
||||
workflow_entry.graph_engine = MagicMock()
|
||||
workflow_entry.run.return_value = []
|
||||
mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry)
|
||||
mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock())
|
||||
|
||||
runner.run()
|
||||
|
||||
runner._init_rag_pipeline_graph.assert_called_once()
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
|
|
@ -366,3 +368,132 @@ def test_validate_inputs_optional_file_with_empty_string_ignores_default():
|
|||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestBaseAppGeneratorExtras:
|
||||
def test_prepare_user_inputs_converts_files_and_lists(self, monkeypatch):
|
||||
base_app_generator = BaseAppGenerator()
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="file",
|
||||
label="file",
|
||||
type=VariableEntityType.FILE,
|
||||
required=False,
|
||||
allowed_file_types=[],
|
||||
allowed_file_extensions=[],
|
||||
allowed_file_upload_methods=[],
|
||||
),
|
||||
VariableEntity(
|
||||
variable="file_list",
|
||||
label="file_list",
|
||||
type=VariableEntityType.FILE_LIST,
|
||||
required=False,
|
||||
allowed_file_types=[],
|
||||
allowed_file_extensions=[],
|
||||
allowed_file_upload_methods=[],
|
||||
),
|
||||
VariableEntity(
|
||||
variable="json",
|
||||
label="json",
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
required=False,
|
||||
),
|
||||
]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_generator.file_factory.build_from_mapping",
|
||||
lambda mapping, tenant_id, config, strict_type_validation=False: "file-object",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_generator.file_factory.build_from_mappings",
|
||||
lambda mappings, tenant_id, config: ["file-1", "file-2"],
|
||||
)
|
||||
|
||||
user_inputs = {
|
||||
"file": {"id": "file-id"},
|
||||
"file_list": [{"id": "file-1"}, {"id": "file-2"}],
|
||||
"json": {"key": "value"},
|
||||
}
|
||||
|
||||
prepared = base_app_generator._prepare_user_inputs(
|
||||
user_inputs=user_inputs,
|
||||
variables=variables,
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
assert prepared["file"] == "file-object"
|
||||
assert prepared["file_list"] == ["file-1", "file-2"]
|
||||
assert prepared["json"] == {"key": "value"}
|
||||
|
||||
def test_prepare_user_inputs_rejects_invalid_dict_inputs(self):
|
||||
base_app_generator = BaseAppGenerator()
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="text",
|
||||
label="text",
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
required=False,
|
||||
)
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="must be a string"):
|
||||
base_app_generator._prepare_user_inputs(
|
||||
user_inputs={"text": {"unexpected": "dict"}},
|
||||
variables=variables,
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
def test_prepare_user_inputs_rejects_invalid_list_inputs(self):
|
||||
base_app_generator = BaseAppGenerator()
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="text",
|
||||
label="text",
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
required=False,
|
||||
)
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="must be a string"):
|
||||
base_app_generator._prepare_user_inputs(
|
||||
user_inputs={"text": [{"unexpected": "dict"}]},
|
||||
variables=variables,
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
def test_convert_to_event_stream(self):
|
||||
base_app_generator = BaseAppGenerator()
|
||||
|
||||
assert base_app_generator.convert_to_event_stream({"ok": True}) == {"ok": True}
|
||||
|
||||
def _gen():
|
||||
yield {"delta": "hi"}
|
||||
yield "ping"
|
||||
|
||||
converted = list(base_app_generator.convert_to_event_stream(_gen()))
|
||||
|
||||
assert converted[0].startswith("data: ")
|
||||
assert "\n\n" in converted[0]
|
||||
assert converted[1] == "event: ping\n\n"
|
||||
|
||||
def test_get_draft_var_saver_factory_debugger(self):
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from dify_graph.enums import NodeType
|
||||
from models import Account
|
||||
|
||||
base_app_generator = BaseAppGenerator()
|
||||
account = Account(name="Tester", email="tester@example.com")
|
||||
account.id = "account-id"
|
||||
account.tenant_id = "tenant-id"
|
||||
|
||||
factory = base_app_generator._get_draft_var_saver_factory(InvokeFrom.DEBUGGER, account)
|
||||
saver = factory(
|
||||
session=MagicMock(),
|
||||
app_id="app-id",
|
||||
node_id="node-id",
|
||||
node_type=NodeType.START,
|
||||
node_execution_id="node-exec-id",
|
||||
)
|
||||
|
||||
assert saver is not None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,61 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueErrorEvent
|
||||
|
||||
|
||||
class DummyQueueManager(AppQueueManager):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.published = []
|
||||
|
||||
def _publish(self, event, pub_from):
|
||||
self.published.append((event, pub_from))
|
||||
|
||||
|
||||
class TestBaseAppQueueManager:
|
||||
def test_init_requires_user_id(self):
|
||||
with pytest.raises(ValueError):
|
||||
DummyQueueManager(task_id="t1", user_id="", invoke_from=InvokeFrom.SERVICE_API)
|
||||
|
||||
def test_publish_error_records_event(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.setex.return_value = True
|
||||
manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API)
|
||||
manager.publish_error(ValueError("boom"), PublishFrom.TASK_PIPELINE)
|
||||
|
||||
assert isinstance(manager.published[0][0], QueueErrorEvent)
|
||||
|
||||
def test_set_stop_flag_checks_user(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.get.return_value = b"end-user-u1"
|
||||
AppQueueManager.set_stop_flag(task_id="t1", invoke_from=InvokeFrom.SERVICE_API, user_id="u1")
|
||||
|
||||
mock_redis.setex.assert_called_once()
|
||||
|
||||
def test_set_stop_flag_no_user_check(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id="t1")
|
||||
|
||||
mock_redis.setex.assert_called_once()
|
||||
|
||||
def test_is_stopped_reads_cache(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.setex.return_value = True
|
||||
mock_redis.get.return_value = b"1"
|
||||
manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API)
|
||||
|
||||
assert manager._is_stopped() is True
|
||||
|
||||
def test_check_for_sqlalchemy_models_raises(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.setex.return_value = True
|
||||
manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API)
|
||||
|
||||
bad = SimpleNamespace(_sa_instance_state=True)
|
||||
with pytest.raises(TypeError):
|
||||
manager._check_for_sqlalchemy_models(bad)
|
||||
|
|
@ -0,0 +1,442 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
AdvancedChatMessageEntity,
|
||||
AdvancedChatPromptTemplateEntity,
|
||||
AdvancedCompletionPromptTemplateEntity,
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class _DummyParameterRule:
|
||||
def __init__(self, name: str, use_template: str | None = None) -> None:
|
||||
self.name = name
|
||||
self.use_template = use_template
|
||||
|
||||
|
||||
class _QueueRecorder:
|
||||
def __init__(self) -> None:
|
||||
self.events: list[object] = []
|
||||
|
||||
def publish(self, event, pub_from):
|
||||
_ = pub_from
|
||||
self.events.append(event)
|
||||
|
||||
|
||||
class TestAppRunner:
|
||||
def test_recalc_llm_max_tokens_updates_parameters(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
|
||||
model_schema = SimpleNamespace(
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 100},
|
||||
parameter_rules=[_DummyParameterRule("max_tokens")],
|
||||
)
|
||||
model_config = SimpleNamespace(
|
||||
provider_model_bundle=object(),
|
||||
model="mock",
|
||||
model_schema=model_schema,
|
||||
parameters={"max_tokens": 30},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.ModelInstance",
|
||||
lambda provider_model_bundle, model: SimpleNamespace(get_llm_num_tokens=lambda messages: 80),
|
||||
)
|
||||
|
||||
runner.recalc_llm_max_tokens(model_config, prompt_messages=[AssistantPromptMessage(content="hi")])
|
||||
|
||||
assert model_config.parameters["max_tokens"] == 20
|
||||
|
||||
def test_recalc_llm_max_tokens_returns_minus_one_when_no_context(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
|
||||
model_schema = SimpleNamespace(
|
||||
model_properties={},
|
||||
parameter_rules=[_DummyParameterRule("max_tokens")],
|
||||
)
|
||||
model_config = SimpleNamespace(
|
||||
provider_model_bundle=object(),
|
||||
model="mock",
|
||||
model_schema=model_schema,
|
||||
parameters={"max_tokens": 30},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.ModelInstance",
|
||||
lambda provider_model_bundle, model: SimpleNamespace(get_llm_num_tokens=lambda messages: 10),
|
||||
)
|
||||
|
||||
assert runner.recalc_llm_max_tokens(model_config, prompt_messages=[]) == -1
|
||||
|
||||
def test_direct_output_streaming_publishes_chunks_and_end(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
queue = _QueueRecorder()
|
||||
app_generate_entity = SimpleNamespace(model_conf=SimpleNamespace(model="mock"), stream=True)
|
||||
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner.time.sleep", lambda _: None)
|
||||
|
||||
runner.direct_output(
|
||||
queue_manager=queue,
|
||||
app_generate_entity=app_generate_entity,
|
||||
prompt_messages=[],
|
||||
text="hi",
|
||||
stream=True,
|
||||
)
|
||||
|
||||
assert any(isinstance(event, QueueLLMChunkEvent) for event in queue.events)
|
||||
assert isinstance(queue.events[-1], QueueMessageEndEvent)
|
||||
|
||||
def test_handle_invoke_result_direct_publishes_end_event(self):
|
||||
runner = AppRunner()
|
||||
queue = _QueueRecorder()
|
||||
llm_result = LLMResult(
|
||||
model="mock",
|
||||
prompt_messages=[],
|
||||
message=AssistantPromptMessage(content="done"),
|
||||
usage=LLMUsage.empty_usage(),
|
||||
)
|
||||
|
||||
runner._handle_invoke_result(
|
||||
invoke_result=llm_result,
|
||||
queue_manager=queue,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert isinstance(queue.events[-1], QueueMessageEndEvent)
|
||||
|
||||
def test_handle_invoke_result_invalid_type_raises(self):
|
||||
runner = AppRunner()
|
||||
queue = _QueueRecorder()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
runner._handle_invoke_result(
|
||||
invoke_result=["unexpected"],
|
||||
queue_manager=queue,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
def test_organize_prompt_messages_simple_template(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
model_config = SimpleNamespace(mode="chat", stop=["STOP"])
|
||||
prompt_template_entity = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
|
||||
simple_prompt_template="hello",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.SimplePromptTransform.get_prompt",
|
||||
lambda self, **kwargs: (["simple-message"], ["simple-stop"]),
|
||||
)
|
||||
|
||||
prompt_messages, stop = runner.organize_prompt_messages(
|
||||
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
|
||||
model_config=model_config,
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs={},
|
||||
files=[],
|
||||
query="q",
|
||||
)
|
||||
|
||||
assert prompt_messages == ["simple-message"]
|
||||
assert stop == ["simple-stop"]
|
||||
|
||||
def test_organize_prompt_messages_advanced_completion_template(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
model_config = SimpleNamespace(mode="completion", stop=["<END>"])
|
||||
captured: dict[str, object] = {}
|
||||
prompt_template_entity = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
|
||||
advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
|
||||
prompt="answer",
|
||||
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="U", assistant="A"),
|
||||
),
|
||||
)
|
||||
|
||||
def _fake_advanced_prompt(self, **kwargs):
|
||||
captured.update(kwargs)
|
||||
return ["advanced-completion-message"]
|
||||
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner.AdvancedPromptTransform.get_prompt", _fake_advanced_prompt)
|
||||
|
||||
prompt_messages, stop = runner.organize_prompt_messages(
|
||||
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
|
||||
model_config=model_config,
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs={},
|
||||
files=[],
|
||||
query="q",
|
||||
)
|
||||
|
||||
assert prompt_messages == ["advanced-completion-message"]
|
||||
assert stop == ["<END>"]
|
||||
memory_config = captured["memory_config"]
|
||||
assert memory_config.role_prefix.user == "U"
|
||||
assert memory_config.role_prefix.assistant == "A"
|
||||
|
||||
def test_organize_prompt_messages_advanced_chat_template(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
model_config = SimpleNamespace(mode="chat", stop=["<END>"])
|
||||
captured: dict[str, object] = {}
|
||||
prompt_template_entity = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
|
||||
advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
|
||||
messages=[
|
||||
AdvancedChatMessageEntity(text="hello", role=PromptMessageRole.USER),
|
||||
AdvancedChatMessageEntity(text="world", role=PromptMessageRole.ASSISTANT),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
def _fake_advanced_prompt(self, **kwargs):
|
||||
captured.update(kwargs)
|
||||
return ["advanced-chat-message"]
|
||||
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner.AdvancedPromptTransform.get_prompt", _fake_advanced_prompt)
|
||||
|
||||
prompt_messages, stop = runner.organize_prompt_messages(
|
||||
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
|
||||
model_config=model_config,
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs={},
|
||||
files=[],
|
||||
query="q",
|
||||
)
|
||||
|
||||
assert prompt_messages == ["advanced-chat-message"]
|
||||
assert stop == ["<END>"]
|
||||
assert len(captured["prompt_template"]) == 2
|
||||
|
||||
def test_organize_prompt_messages_advanced_missing_templates_raise(self):
|
||||
runner = AppRunner()
|
||||
|
||||
with pytest.raises(InvokeBadRequestError, match="Advanced completion prompt template is required"):
|
||||
runner.organize_prompt_messages(
|
||||
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
|
||||
model_config=SimpleNamespace(mode="completion", stop=[]),
|
||||
prompt_template_entity=PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.ADVANCED),
|
||||
inputs={},
|
||||
files=[],
|
||||
)
|
||||
|
||||
with pytest.raises(InvokeBadRequestError, match="Advanced chat prompt template is required"):
|
||||
runner.organize_prompt_messages(
|
||||
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
|
||||
model_config=SimpleNamespace(mode="chat", stop=[]),
|
||||
prompt_template_entity=PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.ADVANCED),
|
||||
inputs={},
|
||||
files=[],
|
||||
)
|
||||
|
||||
def test_handle_invoke_result_stream_routes_chunks_and_builds_message(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
queue = _QueueRecorder()
|
||||
warning_logger = MagicMock()
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner._logger.warning", warning_logger)
|
||||
|
||||
image_content = ImagePromptMessageContent(
|
||||
url="https://example.com/image.png", format="png", mime_type="image/png"
|
||||
)
|
||||
|
||||
def _stream():
|
||||
yield LLMResultChunk(
|
||||
model="stream-model",
|
||||
prompt_messages=[AssistantPromptMessage(content="prompt")],
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage.model_construct(
|
||||
content=[
|
||||
"a",
|
||||
TextPromptMessageContent(data="b"),
|
||||
SimpleNamespace(data="c"),
|
||||
image_content,
|
||||
]
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
runner._handle_invoke_result(
|
||||
invoke_result=_stream(),
|
||||
queue_manager=queue,
|
||||
stream=True,
|
||||
agent=False,
|
||||
)
|
||||
|
||||
assert isinstance(queue.events[0], QueueLLMChunkEvent)
|
||||
assert isinstance(queue.events[-1], QueueMessageEndEvent)
|
||||
assert queue.events[-1].llm_result.message.content == "abc"
|
||||
warning_logger.assert_called_once()
|
||||
|
||||
def test_handle_invoke_result_stream_agent_mode_handles_multimodal_errors(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
queue = _QueueRecorder()
|
||||
exception_logger = MagicMock()
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner._logger.exception", exception_logger)
|
||||
|
||||
monkeypatch.setattr(
|
||||
runner,
|
||||
"_handle_multimodal_image_content",
|
||||
MagicMock(side_effect=RuntimeError("failed to save image")),
|
||||
)
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
def _stream():
|
||||
yield LLMResultChunk(
|
||||
model="agent-model",
|
||||
prompt_messages=[AssistantPromptMessage(content="prompt")],
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=[
|
||||
ImagePromptMessageContent(
|
||||
url="https://example.com/image.png",
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
TextPromptMessageContent(data="done"),
|
||||
]
|
||||
),
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
|
||||
runner._handle_invoke_result_stream(
|
||||
invoke_result=_stream(),
|
||||
queue_manager=queue,
|
||||
agent=True,
|
||||
message_id="message-id",
|
||||
user_id="user-id",
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
assert isinstance(queue.events[0], QueueAgentMessageEvent)
|
||||
assert isinstance(queue.events[-1], QueueMessageEndEvent)
|
||||
assert queue.events[-1].llm_result.usage == usage
|
||||
exception_logger.assert_called_once()
|
||||
|
||||
def test_handle_multimodal_image_content_fallback_return_branch(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
|
||||
class _ToggleBool:
|
||||
def __init__(self, values: list[bool]):
|
||||
self._values = values
|
||||
self._index = 0
|
||||
|
||||
def __bool__(self):
|
||||
value = self._values[min(self._index, len(self._values) - 1)]
|
||||
self._index += 1
|
||||
return value
|
||||
|
||||
content = SimpleNamespace(
|
||||
url=_ToggleBool([False, False]),
|
||||
base64_data=_ToggleBool([True, False]),
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
db_session = SimpleNamespace(add=MagicMock(), commit=MagicMock(), refresh=MagicMock())
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner.ToolFileManager", lambda: MagicMock())
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner.db", SimpleNamespace(session=db_session))
|
||||
|
||||
queue_manager = SimpleNamespace(invoke_from=InvokeFrom.SERVICE_API, publish=MagicMock())
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id="message-id",
|
||||
user_id="user-id",
|
||||
tenant_id="tenant-id",
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
|
||||
db_session.add.assert_not_called()
|
||||
queue_manager.publish.assert_not_called()
|
||||
|
||||
def test_check_hosting_moderation_direct_output_called(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
queue = _QueueRecorder()
|
||||
app_generate_entity = SimpleNamespace(stream=False)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.HostingModerationFeature.check",
|
||||
lambda self, application_generate_entity, prompt_messages: True,
|
||||
)
|
||||
direct_output = MagicMock()
|
||||
monkeypatch.setattr(runner, "direct_output", direct_output)
|
||||
|
||||
result = runner.check_hosting_moderation(
|
||||
application_generate_entity=app_generate_entity,
|
||||
queue_manager=queue,
|
||||
prompt_messages=[],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert direct_output.called
|
||||
|
||||
def test_fill_in_inputs_from_external_data_tools(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.ExternalDataFetch.fetch",
|
||||
lambda self, tenant_id, app_id, external_data_tools, inputs, query: {"foo": "bar"},
|
||||
)
|
||||
|
||||
result = runner.fill_in_inputs_from_external_data_tools(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
external_data_tools=[],
|
||||
inputs={},
|
||||
query="q",
|
||||
)
|
||||
|
||||
assert result == {"foo": "bar"}
|
||||
|
||||
def test_moderation_for_inputs_returns_result(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.InputModeration.check",
|
||||
lambda self, app_id, tenant_id, app_config, inputs, query, message_id, trace_manager: (True, {}, ""),
|
||||
)
|
||||
app_generate_entity = SimpleNamespace(app_config=SimpleNamespace(), trace_manager=None)
|
||||
|
||||
result = runner.moderation_for_inputs(
|
||||
app_id="app",
|
||||
tenant_id="tenant",
|
||||
app_generate_entity=app_generate_entity,
|
||||
inputs={},
|
||||
query="q",
|
||||
message_id="msg",
|
||||
)
|
||||
|
||||
assert result == (True, {}, "")
|
||||
|
||||
def test_query_app_annotations_to_reply(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.AnnotationReplyFeature.query",
|
||||
lambda self, app_record, message, query, user_id, invoke_from: "reply",
|
||||
)
|
||||
|
||||
response = runner.query_app_annotations_to_reply(
|
||||
app_record=SimpleNamespace(),
|
||||
message=SimpleNamespace(),
|
||||
query="hello",
|
||||
user_id="user",
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
assert response == "reply"
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
|
||||
|
||||
class TestAppsExceptions:
|
||||
def test_generate_task_stopped_error(self):
|
||||
err = GenerateTaskStoppedError("stopped")
|
||||
assert str(err) == "stopped"
|
||||
|
|
@ -13,9 +13,11 @@ from core.app.app_config.entities import (
|
|||
PromptTemplateEntity,
|
||||
)
|
||||
from core.app.apps import message_based_app_generator
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
||||
from models.model import AppMode, Conversation, Message
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
|
||||
|
||||
class DummyModelConf:
|
||||
|
|
@ -125,3 +127,55 @@ def test_init_generate_records_sets_conversation_fields_for_chat_entity():
|
|||
assert entity.conversation_id == "generated-conversation-id"
|
||||
assert entity.is_new_conversation is True
|
||||
assert conversation.id == "generated-conversation-id"
|
||||
|
||||
|
||||
class TestMessageBasedAppGeneratorExtras:
|
||||
def test_handle_response_closed_file_raises_stopped(self, monkeypatch):
|
||||
generator = MessageBasedAppGenerator()
|
||||
|
||||
class _Pipeline:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
_ = kwargs
|
||||
|
||||
def process(self):
|
||||
raise ValueError("I/O operation on closed file.")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.message_based_app_generator.EasyUIBasedGenerateTaskPipeline",
|
||||
_Pipeline,
|
||||
)
|
||||
|
||||
with pytest.raises(GenerateTaskStoppedError):
|
||||
generator._handle_response(
|
||||
application_generate_entity=_make_chat_generate_entity(_make_app_config(AppMode.CHAT)),
|
||||
queue_manager=SimpleNamespace(),
|
||||
conversation=SimpleNamespace(id="conv"),
|
||||
message=SimpleNamespace(id="msg"),
|
||||
user=SimpleNamespace(),
|
||||
stream=False,
|
||||
)
|
||||
|
||||
def test_get_app_model_config_requires_valid_config(self, monkeypatch):
|
||||
generator = MessageBasedAppGenerator()
|
||||
app_model = SimpleNamespace(id="app", app_model_config_id=None, app_model_config=None)
|
||||
|
||||
with pytest.raises(AppModelConfigBrokenError):
|
||||
generator._get_app_model_config(app_model, conversation=None)
|
||||
|
||||
conversation = SimpleNamespace(app_model_config_id="missing-id")
|
||||
monkeypatch.setattr(
|
||||
message_based_app_generator, "db", SimpleNamespace(session=SimpleNamespace(scalar=lambda _: None))
|
||||
)
|
||||
|
||||
with pytest.raises(AppModelConfigBrokenError):
|
||||
generator._get_app_model_config(app_model=SimpleNamespace(id="app"), conversation=conversation)
|
||||
|
||||
def test_get_conversation_introduction_handles_missing_inputs(self):
|
||||
app_config = _make_app_config(AppMode.CHAT)
|
||||
app_config.additional_features.opening_statement = "Hello {{name}}"
|
||||
entity = _make_chat_generate_entity(app_config)
|
||||
entity.inputs = {}
|
||||
|
||||
generator = MessageBasedAppGenerator()
|
||||
|
||||
assert generator._get_conversation_introduction(entity) == "Hello {name}"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,65 @@
|
|||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueErrorEvent, QueueMessageEndEvent, QueueStopEvent
|
||||
|
||||
|
||||
class TestMessageBasedAppQueueManager:
|
||||
def test_publish_stops_on_terminal_events(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.setex.return_value = True
|
||||
manager = MessageBasedAppQueueManager(
|
||||
task_id="t1",
|
||||
user_id="u1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
conversation_id="c1",
|
||||
app_mode="chat",
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
manager.stop_listen = Mock()
|
||||
manager._is_stopped = Mock(return_value=False)
|
||||
|
||||
manager._publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), Mock())
|
||||
manager.stop_listen.assert_called_once()
|
||||
|
||||
def test_publish_raises_when_stopped(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.setex.return_value = True
|
||||
manager = MessageBasedAppQueueManager(
|
||||
task_id="t1",
|
||||
user_id="u1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
conversation_id="c1",
|
||||
app_mode="chat",
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
manager._is_stopped = Mock(return_value=True)
|
||||
|
||||
with pytest.raises(GenerateTaskStoppedError):
|
||||
manager._publish(QueueErrorEvent(error=ValueError("boom")), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
def test_publish_enqueues_message_end(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.setex.return_value = True
|
||||
manager = MessageBasedAppQueueManager(
|
||||
task_id="t1",
|
||||
user_id="u1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
conversation_id="c1",
|
||||
app_mode="chat",
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
manager._is_stopped = Mock(return_value=False)
|
||||
manager.stop_listen = Mock()
|
||||
|
||||
manager._publish(QueueMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
|
||||
assert manager._q.qsize() == 1
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestMessageGenerator:
|
||||
def test_get_response_topic(self):
|
||||
channel = Mock()
|
||||
channel.topic.return_value = "topic"
|
||||
|
||||
with patch("core.app.apps.message_generator.get_pubsub_broadcast_channel", return_value=channel):
|
||||
topic = MessageGenerator.get_response_topic(AppMode.WORKFLOW, "run-1")
|
||||
|
||||
assert topic == "topic"
|
||||
expected_key = MessageGenerator._make_channel_key(AppMode.WORKFLOW, "run-1")
|
||||
channel.topic.assert_called_once_with(expected_key)
|
||||
|
||||
def test_retrieve_events_passes_arguments(self):
|
||||
with (
|
||||
patch("core.app.apps.message_generator.MessageGenerator.get_response_topic", return_value="topic"),
|
||||
patch(
|
||||
"core.app.apps.message_generator.stream_topic_events", return_value=iter([{"event": "ping"}])
|
||||
) as mock_stream,
|
||||
):
|
||||
events = list(MessageGenerator.retrieve_events(AppMode.WORKFLOW, "run-1", idle_timeout=1, ping_interval=2))
|
||||
|
||||
assert events == [{"event": "ping"}]
|
||||
mock_stream.assert_called_once()
|
||||
|
|
@ -6,6 +6,7 @@ import queue
|
|||
import pytest
|
||||
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.streaming_utils import _normalize_terminal_events, stream_topic_events
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from models.model import AppMode
|
||||
|
||||
|
|
@ -78,3 +79,30 @@ def test_retrieve_events_calls_on_subscribe_after_subscription(monkeypatch):
|
|||
assert event["event"] == StreamEvent.WORKFLOW_FINISHED.value
|
||||
with pytest.raises(StopIteration):
|
||||
next(generator)
|
||||
|
||||
|
||||
def test_normalize_terminal_events_defaults():
|
||||
assert _normalize_terminal_events(None) == {
|
||||
StreamEvent.WORKFLOW_FINISHED.value,
|
||||
StreamEvent.WORKFLOW_PAUSED.value,
|
||||
}
|
||||
|
||||
|
||||
def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch):
|
||||
topic = FakeTopic()
|
||||
times = [1000.0, 1000.0, 1001.0, 1001.0, 1002.0]
|
||||
|
||||
def fake_time():
|
||||
return times.pop(0)
|
||||
|
||||
monkeypatch.setattr("core.app.apps.streaming_utils.time.time", fake_time)
|
||||
|
||||
generator = stream_topic_events(
|
||||
topic=topic,
|
||||
idle_timeout=10.0,
|
||||
ping_interval=1.0,
|
||||
)
|
||||
|
||||
assert next(generator) == StreamEvent.PING.value
|
||||
# next receive yields None -> ping interval triggers
|
||||
assert next(generator) == StreamEvent.PING.value
|
||||
|
|
|
|||
|
|
@ -0,0 +1,261 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueLoopCompletedEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.graph_events import (
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunAgentLogEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
)
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
|
||||
|
||||
class TestWorkflowBasedAppRunner:
|
||||
def test_resolve_user_from(self):
|
||||
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
|
||||
|
||||
assert runner._resolve_user_from(InvokeFrom.EXPLORE) == UserFrom.ACCOUNT
|
||||
assert runner._resolve_user_from(InvokeFrom.DEBUGGER) == UserFrom.ACCOUNT
|
||||
assert runner._resolve_user_from(InvokeFrom.WEB_APP) == UserFrom.END_USER
|
||||
|
||||
def test_init_graph_validates_graph_structure(self):
|
||||
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
|
||||
|
||||
runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable.default()),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="nodes or edges not found"):
|
||||
runner._init_graph(
|
||||
graph_config={},
|
||||
graph_runtime_state=runtime_state,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="nodes in workflow graph must be a list"):
|
||||
runner._init_graph(
|
||||
graph_config={"nodes": {}, "edges": []},
|
||||
graph_runtime_state=runtime_state,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="edges in workflow graph must be a list"):
|
||||
runner._init_graph(
|
||||
graph_config={"nodes": [], "edges": {}},
|
||||
graph_runtime_state=runtime_state,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
|
||||
def test_prepare_single_node_execution_requires_run(self):
|
||||
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
|
||||
|
||||
workflow = SimpleNamespace(environment_variables=[], graph_dict={})
|
||||
|
||||
with pytest.raises(ValueError, match="Neither single_iteration_run nor single_loop_run"):
|
||||
runner._prepare_single_node_execution(workflow, None, None)
|
||||
|
||||
def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch):
|
||||
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable.default()),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
graph_config = {
|
||||
"nodes": [{"id": "node-1", "data": {"type": "start", "version": "1"}}],
|
||||
"edges": [],
|
||||
}
|
||||
workflow = SimpleNamespace(tenant_id="tenant", id="workflow", graph_dict=graph_config)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow_app_runner.Graph.init",
|
||||
lambda **kwargs: SimpleNamespace(),
|
||||
)
|
||||
|
||||
class _NodeCls:
|
||||
@staticmethod
|
||||
def extract_variable_selector_to_variable_mapping(graph_config, config):
|
||||
return {}
|
||||
|
||||
from core.app.apps import workflow_app_runner
|
||||
|
||||
monkeypatch.setitem(
|
||||
workflow_app_runner.NODE_TYPE_CLASSES_MAPPING,
|
||||
NodeType.START,
|
||||
{"1": _NodeCls},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow_app_runner.load_into_variable_pool",
|
||||
lambda **kwargs: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool",
|
||||
lambda **kwargs: None,
|
||||
)
|
||||
|
||||
graph, variable_pool = runner._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id="node-1",
|
||||
user_inputs={},
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="iteration_id",
|
||||
node_type_label="iteration",
|
||||
)
|
||||
|
||||
assert graph is not None
|
||||
assert variable_pool is graph_runtime_state.variable_pool
|
||||
|
||||
def test_handle_graph_run_events_and_pause_notifications(self, monkeypatch):
|
||||
published: list[object] = []
|
||||
|
||||
class _QueueManager:
|
||||
def publish(self, event, publish_from):
|
||||
published.append((event, publish_from))
|
||||
|
||||
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable.default()),
|
||||
start_at=0.0,
|
||||
)
|
||||
graph_runtime_state.register_paused_node("node-1")
|
||||
workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))
|
||||
|
||||
emails: list[dict] = []
|
||||
|
||||
class _Dispatch:
|
||||
def apply_async(self, *, kwargs, queue):
|
||||
emails.append({"kwargs": kwargs, "queue": queue})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow_app_runner.dispatch_human_input_email_task",
|
||||
_Dispatch(),
|
||||
)
|
||||
|
||||
reason = HumanInputRequired(
|
||||
form_id="form",
|
||||
form_content="content",
|
||||
node_id="node-1",
|
||||
node_title="Node",
|
||||
)
|
||||
|
||||
runner._handle_event(workflow_entry, GraphRunStartedEvent())
|
||||
runner._handle_event(workflow_entry, GraphRunSucceededEvent(outputs={"ok": True}))
|
||||
runner._handle_event(workflow_entry, GraphRunPausedEvent(reasons=[reason], outputs={}))
|
||||
|
||||
assert any(isinstance(event, QueueWorkflowStartedEvent) for event, _ in published)
|
||||
assert any(isinstance(event, QueueWorkflowSucceededEvent) for event, _ in published)
|
||||
paused_event = next(event for event, _ in published if isinstance(event, QueueWorkflowPausedEvent))
|
||||
assert paused_event.paused_nodes == ["node-1"]
|
||||
assert emails
|
||||
|
||||
def test_handle_node_events_publishes_queue_events(self):
|
||||
published: list[object] = []
|
||||
|
||||
class _QueueManager:
|
||||
def publish(self, event, publish_from):
|
||||
published.append(event)
|
||||
|
||||
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable.default()),
|
||||
start_at=0.0,
|
||||
)
|
||||
workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))
|
||||
|
||||
runner._handle_event(
|
||||
workflow_entry,
|
||||
NodeRunStartedEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.START,
|
||||
node_title="Start",
|
||||
start_at=datetime.utcnow(),
|
||||
),
|
||||
)
|
||||
runner._handle_event(
|
||||
workflow_entry,
|
||||
NodeRunStreamChunkEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.START,
|
||||
selector=["node", "text"],
|
||||
chunk="hi",
|
||||
is_final=False,
|
||||
),
|
||||
)
|
||||
runner._handle_event(
|
||||
workflow_entry,
|
||||
NodeRunAgentLogEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.START,
|
||||
message_id="msg",
|
||||
label="label",
|
||||
node_execution_id="exec",
|
||||
parent_id=None,
|
||||
error=None,
|
||||
status="done",
|
||||
data={},
|
||||
metadata={},
|
||||
),
|
||||
)
|
||||
runner._handle_event(
|
||||
workflow_entry,
|
||||
NodeRunIterationSucceededEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="Iter",
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={"ok": True},
|
||||
metadata={},
|
||||
steps=1,
|
||||
),
|
||||
)
|
||||
runner._handle_event(
|
||||
workflow_entry,
|
||||
NodeRunLoopFailedEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="Loop",
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
metadata={},
|
||||
steps=1,
|
||||
error="boom",
|
||||
),
|
||||
)
|
||||
|
||||
assert any(isinstance(event, QueueTextChunkEvent) for event in published)
|
||||
assert any(isinstance(event, QueueAgentLogEvent) for event in published)
|
||||
assert any(isinstance(event, QueueIterationCompletedEvent) for event in published)
|
||||
assert any(isinstance(event, QueueLoopCompletedEvent) for event in published)
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestWorkflowAppConfigManager:
|
||||
def test_get_app_config(self):
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
workflow = SimpleNamespace(id="wf-1", features_dict={})
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.workflow.app_config_manager.SensitiveWordAvoidanceConfigManager.convert",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.workflow.app_config_manager.WorkflowVariablesConfigManager.convert",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
app_config = WorkflowAppConfigManager.get_app_config(app_model, workflow)
|
||||
|
||||
assert app_config.workflow_id == "wf-1"
|
||||
assert app_config.app_mode == AppMode.WORKFLOW
|
||||
|
||||
def test_config_validate_filters_keys(self):
|
||||
def _add_key(key, value):
|
||||
def _inner(*args, **kwargs):
|
||||
# Support both positional and keyword arguments for config
|
||||
if "config" in kwargs:
|
||||
config = kwargs["config"]
|
||||
elif len(args) > 0:
|
||||
config = args[0]
|
||||
else:
|
||||
config = {}
|
||||
config[key] = value
|
||||
return config, [key]
|
||||
|
||||
return _inner
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.workflow.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("file_upload", 1),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.workflow.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("text_to_speech", 2),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.workflow.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("sensitive_word_avoidance", 3),
|
||||
),
|
||||
):
|
||||
filtered = WorkflowAppConfigManager.config_validate(tenant_id="t1", config={})
|
||||
|
||||
assert filtered["file_upload"] == 1
|
||||
assert filtered["text_to_speech"] == 2
|
||||
assert filtered["sensitive_word_avoidance"] == 3
|
||||
|
|
@ -0,0 +1,187 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestWorkflowAppGeneratorValidation:
|
||||
def test_should_prepare_user_inputs(self):
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
assert generator._should_prepare_user_inputs({}) is True
|
||||
assert generator._should_prepare_user_inputs({SKIP_PREPARE_USER_INPUTS_KEY: True}) is False
|
||||
|
||||
def test_single_iteration_generate_validates_args(self):
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
with pytest.raises(ValueError, match="node_id is required"):
|
||||
generator.single_iteration_generate(
|
||||
app_model=SimpleNamespace(),
|
||||
workflow=SimpleNamespace(),
|
||||
node_id="",
|
||||
user=SimpleNamespace(),
|
||||
args={"inputs": {}},
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="inputs is required"):
|
||||
generator.single_iteration_generate(
|
||||
app_model=SimpleNamespace(),
|
||||
workflow=SimpleNamespace(),
|
||||
node_id="node",
|
||||
user=SimpleNamespace(),
|
||||
args={},
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
def test_single_loop_generate_validates_args(self):
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
with pytest.raises(ValueError, match="node_id is required"):
|
||||
generator.single_loop_generate(
|
||||
app_model=SimpleNamespace(),
|
||||
workflow=SimpleNamespace(),
|
||||
node_id="",
|
||||
user=SimpleNamespace(),
|
||||
args=SimpleNamespace(inputs={}),
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="inputs is required"):
|
||||
generator.single_loop_generate(
|
||||
app_model=SimpleNamespace(),
|
||||
workflow=SimpleNamespace(),
|
||||
node_id="node",
|
||||
user=SimpleNamespace(),
|
||||
args=SimpleNamespace(inputs=None),
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
|
||||
class TestWorkflowAppGeneratorHandleResponse:
|
||||
def test_handle_response_closed_file_raises_stopped(self, monkeypatch):
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
workflow_id="workflow-id",
|
||||
)
|
||||
application_generate_entity = WorkflowAppGenerateEntity.model_construct(
|
||||
task_id="task",
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id="user",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
extras={},
|
||||
trace_manager=None,
|
||||
workflow_execution_id="run-id",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
class _Pipeline:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
_ = kwargs
|
||||
|
||||
def process(self):
|
||||
raise ValueError("I/O operation on closed file.")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.WorkflowAppGenerateTaskPipeline",
|
||||
_Pipeline,
|
||||
)
|
||||
|
||||
with pytest.raises(GenerateTaskStoppedError):
|
||||
generator._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=SimpleNamespace(),
|
||||
queue_manager=SimpleNamespace(),
|
||||
user=SimpleNamespace(),
|
||||
draft_var_saver_factory=lambda **kwargs: None,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
class TestWorkflowAppGeneratorGenerate:
|
||||
def test_generate_skips_prepare_inputs_when_flag_set(self, monkeypatch):
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
workflow_id="workflow-id",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.WorkflowAppConfigManager.get_app_config",
|
||||
lambda app_model, workflow: app_config,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.FileUploadConfigManager.convert",
|
||||
lambda features_dict, is_vision=False: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.file_factory.build_from_mappings",
|
||||
lambda **kwargs: [],
|
||||
)
|
||||
DummyTraceQueueManager = type(
|
||||
"_DummyTraceQueueManager",
|
||||
(TraceQueueManager,),
|
||||
{
|
||||
"__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id)
|
||||
or setattr(self, "user_id", user_id)
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.TraceQueueManager",
|
||||
DummyTraceQueueManager,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository",
|
||||
lambda **kwargs: SimpleNamespace(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
|
||||
lambda **kwargs: SimpleNamespace(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.db",
|
||||
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.sessionmaker",
|
||||
lambda **kwargs: SimpleNamespace(),
|
||||
)
|
||||
|
||||
prepare_inputs = pytest.fail
|
||||
monkeypatch.setattr(generator, "_prepare_user_inputs", lambda **kwargs: prepare_inputs())
|
||||
|
||||
monkeypatch.setattr(generator, "_generate", lambda **kwargs: {"ok": True})
|
||||
|
||||
result = generator.generate(
|
||||
app_model=SimpleNamespace(id="app", tenant_id="tenant"),
|
||||
workflow=SimpleNamespace(features_dict={}),
|
||||
user=SimpleNamespace(id="user", session_id="session"),
|
||||
args={"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: True},
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=False,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
assert result == {"ok": True}
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueMessageEndEvent, QueuePingEvent
|
||||
|
||||
|
||||
class TestWorkflowAppQueueManager:
|
||||
def test_publish_stop_events_trigger_stop(self):
|
||||
manager = WorkflowAppQueueManager(
|
||||
task_id="task",
|
||||
user_id="user",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
app_mode="workflow",
|
||||
)
|
||||
manager._is_stopped = lambda: True
|
||||
|
||||
with pytest.raises(GenerateTaskStoppedError):
|
||||
manager._publish(QueueMessageEndEvent(llm_result=None), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
def test_publish_non_stop_event_does_not_raise(self):
|
||||
manager = WorkflowAppQueueManager(
|
||||
task_id="task",
|
||||
user_id="user",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
app_mode="workflow",
|
||||
)
|
||||
|
||||
manager._publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
from core.app.apps.workflow.errors import WorkflowPausedInBlockingModeError
|
||||
|
||||
|
||||
class TestWorkflowErrors:
|
||||
def test_workflow_paused_in_blocking_mode_error_attributes(self):
|
||||
err = WorkflowPausedInBlockingModeError()
|
||||
assert err.error_code == "workflow_paused_in_blocking_mode"
|
||||
assert err.code == 400
|
||||
assert "blocking response mode" in err.description
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
from collections.abc import Generator
|
||||
|
||||
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
ErrorStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
)
|
||||
from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class TestWorkflowGenerateResponseConverter:
|
||||
def test_blocking_full_response(self):
|
||||
blocking = WorkflowAppBlockingResponse(
|
||||
task_id="t1",
|
||||
workflow_run_id="r1",
|
||||
data=WorkflowAppBlockingResponse.Data(
|
||||
id="exec-1",
|
||||
workflow_id="wf-1",
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
outputs={"ok": True},
|
||||
error=None,
|
||||
elapsed_time=1.2,
|
||||
total_tokens=10,
|
||||
total_steps=2,
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
),
|
||||
)
|
||||
response = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(blocking)
|
||||
assert response["workflow_run_id"] == "r1"
|
||||
|
||||
def test_stream_simple_response_node_events(self):
|
||||
node_start = NodeStartStreamResponse(
|
||||
task_id="t1",
|
||||
workflow_run_id="r1",
|
||||
data=NodeStartStreamResponse.Data(
|
||||
id="e1",
|
||||
node_id="n1",
|
||||
node_type="answer",
|
||||
title="Answer",
|
||||
index=1,
|
||||
created_at=1,
|
||||
),
|
||||
)
|
||||
node_finish = NodeFinishStreamResponse(
|
||||
task_id="t1",
|
||||
workflow_run_id="r1",
|
||||
data=NodeFinishStreamResponse.Data(
|
||||
id="e1",
|
||||
node_id="n1",
|
||||
node_type="answer",
|
||||
title="Answer",
|
||||
index=1,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
elapsed_time=0.1,
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
),
|
||||
)
|
||||
|
||||
def stream() -> Generator[WorkflowAppStreamResponse, None, None]:
|
||||
yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=PingStreamResponse(task_id="t1"))
|
||||
yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=node_start)
|
||||
yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=node_finish)
|
||||
yield WorkflowAppStreamResponse(
|
||||
workflow_run_id="r1", stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom"))
|
||||
)
|
||||
|
||||
converted = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(stream()))
|
||||
assert converted[0] == "ping"
|
||||
assert converted[1]["event"] == "node_started"
|
||||
assert converted[2]["event"] == "node_finished"
|
||||
assert converted[3]["event"] == "error"
|
||||
|
||||
def test_convert_stream_simple_response_handles_ping_and_nodes(self):
|
||||
def _gen():
|
||||
yield WorkflowAppStreamResponse(stream_response=PingStreamResponse(task_id="task"))
|
||||
yield WorkflowAppStreamResponse(
|
||||
workflow_run_id="run",
|
||||
stream_response=NodeStartStreamResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run",
|
||||
data=NodeStartStreamResponse.Data(
|
||||
id="node-exec",
|
||||
node_id="node",
|
||||
node_type="start",
|
||||
title="Start",
|
||||
index=1,
|
||||
created_at=1,
|
||||
),
|
||||
),
|
||||
)
|
||||
yield WorkflowAppStreamResponse(
|
||||
workflow_run_id="run",
|
||||
stream_response=NodeFinishStreamResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run",
|
||||
data=NodeFinishStreamResponse.Data(
|
||||
id="node-exec",
|
||||
node_id="node",
|
||||
node_type="start",
|
||||
title="Start",
|
||||
index=1,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={},
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
elapsed_time=1.0,
|
||||
error=None,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
chunks = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(_gen()))
|
||||
|
||||
assert chunks[0] == "ping"
|
||||
assert chunks[1]["event"] == "node_started"
|
||||
assert chunks[2]["event"] == "node_finished"
|
||||
|
||||
def test_convert_stream_full_response_handles_error(self):
|
||||
def _gen():
|
||||
yield WorkflowAppStreamResponse(
|
||||
workflow_run_id="run",
|
||||
stream_response=ErrorStreamResponse(task_id="task", err=ValueError("boom")),
|
||||
)
|
||||
|
||||
chunks = list(WorkflowAppGenerateResponseConverter.convert_stream_full_response(_gen()))
|
||||
|
||||
assert chunks[0]["event"] == "error"
|
||||
|
|
@ -0,0 +1,868 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueErrorEvent,
|
||||
QueueHumanInputFormFilledEvent,
|
||||
QueueHumanInputFormTimeoutEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueLoopCompletedEvent,
|
||||
QueueLoopNextEvent,
|
||||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueuePingEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
ErrorStreamResponse,
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
PingStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowPauseStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.base.tts.app_generator_tts_publisher import AudioTrunk
|
||||
from dify_graph.enums import NodeType, WorkflowExecutionStatus
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode, EndUser
|
||||
|
||||
|
||||
def _make_pipeline():
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
workflow_id="workflow-id",
|
||||
)
|
||||
application_generate_entity = WorkflowAppGenerateEntity.model_construct(
|
||||
task_id="task",
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id="user",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
trace_manager=None,
|
||||
workflow_execution_id="run-id",
|
||||
extras={},
|
||||
call_depth=0,
|
||||
)
|
||||
workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={})
|
||||
user = SimpleNamespace(id="user", session_id="session")
|
||||
|
||||
pipeline = WorkflowAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
|
||||
user=user,
|
||||
stream=False,
|
||||
draft_var_saver_factory=lambda **kwargs: None,
|
||||
)
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
class TestWorkflowGenerateTaskPipeline:
|
||||
def test_to_blocking_response_handles_pause(self):
|
||||
pipeline = _make_pipeline()
|
||||
|
||||
def _gen():
|
||||
yield WorkflowPauseStreamResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run",
|
||||
data=WorkflowPauseStreamResponse.Data(
|
||||
workflow_run_id="run",
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
outputs={},
|
||||
created_at=1,
|
||||
elapsed_time=0.1,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
),
|
||||
)
|
||||
|
||||
response = pipeline._to_blocking_response(_gen())
|
||||
|
||||
assert response.data.status == WorkflowExecutionStatus.PAUSED
|
||||
|
||||
def test_to_blocking_response_handles_finish(self):
|
||||
pipeline = _make_pipeline()
|
||||
|
||||
def _gen():
|
||||
yield WorkflowFinishStreamResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run",
|
||||
data=WorkflowFinishStreamResponse.Data(
|
||||
id="run",
|
||||
workflow_id="workflow-id",
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
outputs={"ok": True},
|
||||
error=None,
|
||||
elapsed_time=1.0,
|
||||
total_tokens=5,
|
||||
total_steps=2,
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
),
|
||||
)
|
||||
|
||||
response = pipeline._to_blocking_response(_gen())
|
||||
|
||||
assert response.data.outputs == {"ok": True}
|
||||
|
||||
def test_listen_audio_msg_returns_audio_stream(self):
|
||||
pipeline = _make_pipeline()
|
||||
publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk(status="stream", audio="data"))
|
||||
|
||||
response = pipeline._listen_audio_msg(publisher=publisher, task_id="task")
|
||||
|
||||
assert isinstance(response, MessageAudioStreamResponse)
|
||||
|
||||
def test_handle_ping_event(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._base_task_pipeline.ping_stream_response = lambda: PingStreamResponse(task_id="task")
|
||||
|
||||
responses = list(pipeline._handle_ping_event(QueuePingEvent()))
|
||||
|
||||
assert isinstance(responses[0], PingStreamResponse)
|
||||
|
||||
def test_handle_error_event(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom")
|
||||
pipeline._base_task_pipeline.error_to_stream_response = lambda err: err
|
||||
|
||||
responses = list(pipeline._handle_error_event(QueueErrorEvent(error=ValueError("boom"))))
|
||||
|
||||
assert isinstance(responses[0], ValueError)
|
||||
|
||||
def test_handle_workflow_started_event_sets_run_id(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started"
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace()
|
||||
|
||||
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
|
||||
monkeypatch.setattr(pipeline, "_save_workflow_app_log", lambda **kwargs: None)
|
||||
|
||||
responses = list(pipeline._handle_workflow_started_event(QueueWorkflowStartedEvent()))
|
||||
|
||||
assert pipeline._workflow_execution_id == "run-id"
|
||||
assert responses == ["started"]
|
||||
|
||||
def test_handle_node_succeeded_event_saves_output(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done"
|
||||
pipeline._save_output_for_event = lambda event, node_execution_id: None
|
||||
pipeline._workflow_execution_id = "run-id"
|
||||
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.START,
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
process_data={},
|
||||
)
|
||||
|
||||
responses = list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
assert responses == ["done"]
|
||||
|
||||
def test_handle_workflow_failed_event_yields_error(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_execution_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
|
||||
pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom")
|
||||
pipeline._base_task_pipeline.error_to_stream_response = lambda err: err
|
||||
|
||||
responses = list(
|
||||
pipeline._handle_workflow_failed_and_stop_events(QueueWorkflowFailedEvent(error="fail", exceptions_count=1))
|
||||
)
|
||||
|
||||
assert responses[0] == "finish"
|
||||
|
||||
def test_handle_text_chunk_event_publishes_tts(self):
|
||||
pipeline = _make_pipeline()
|
||||
published: list[object] = []
|
||||
|
||||
class _Publisher:
|
||||
def publish(self, message):
|
||||
published.append(message)
|
||||
|
||||
event = QueueTextChunkEvent(text="hi", from_variable_selector=["x"])
|
||||
queue_message = SimpleNamespace(event=event)
|
||||
|
||||
responses = list(
|
||||
pipeline._handle_text_chunk_event(event, tts_publisher=_Publisher(), queue_message=queue_message)
|
||||
)
|
||||
|
||||
assert responses[0].data.text == "hi"
|
||||
assert published == [queue_message]
|
||||
|
||||
def test_dispatch_event_handles_node_failed(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done"
|
||||
|
||||
event = QueueNodeFailedEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.START,
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
process_data={},
|
||||
error="err",
|
||||
)
|
||||
|
||||
assert list(pipeline._dispatch_event(event)) == ["done"]
|
||||
|
||||
def test_handle_stop_event_yields_finish(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_execution_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
|
||||
|
||||
responses = list(
|
||||
pipeline._handle_workflow_failed_and_stop_events(
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)
|
||||
)
|
||||
)
|
||||
|
||||
assert responses == ["finish"]
|
||||
|
||||
def test_save_workflow_app_log_created_from(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._application_generate_entity.invoke_from = InvokeFrom.SERVICE_API
|
||||
pipeline._user_id = "user"
|
||||
added: list[object] = []
|
||||
|
||||
class _Session:
|
||||
def add(self, item):
|
||||
added.append(item)
|
||||
|
||||
pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id")
|
||||
|
||||
assert added
|
||||
|
||||
def test_iteration_loop_and_human_input_handlers(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_execution_id = "run-id"
|
||||
pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = lambda **kwargs: "iter"
|
||||
pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "next"
|
||||
pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = lambda **kwargs: "done"
|
||||
pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop"
|
||||
pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next"
|
||||
pipeline._workflow_response_converter.workflow_loop_completed_to_stream_response = lambda **kwargs: "loop_done"
|
||||
pipeline._workflow_response_converter.human_input_form_filled_to_stream_response = lambda **kwargs: "filled"
|
||||
pipeline._workflow_response_converter.human_input_form_timeout_to_stream_response = lambda **kwargs: "timeout"
|
||||
pipeline._workflow_response_converter.handle_agent_log = lambda **kwargs: "log"
|
||||
|
||||
iter_start = QueueIterationStartEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
iter_next = QueueIterationNextEvent(
|
||||
index=1,
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
node_run_index=1,
|
||||
)
|
||||
iter_done = QueueIterationCompletedEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
loop_start = QueueLoopStartEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
loop_next = QueueLoopNextEvent(
|
||||
index=1,
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
node_run_index=1,
|
||||
)
|
||||
loop_done = QueueLoopCompletedEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
filled_event = QueueHumanInputFormFilledEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="title",
|
||||
rendered_content="content",
|
||||
action_id="action",
|
||||
action_text="action",
|
||||
)
|
||||
timeout_event = QueueHumanInputFormTimeoutEvent(
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="title",
|
||||
expiration_time=datetime.utcnow(),
|
||||
)
|
||||
agent_event = QueueAgentLogEvent(
|
||||
id="log",
|
||||
label="label",
|
||||
node_execution_id="exec",
|
||||
parent_id=None,
|
||||
error=None,
|
||||
status="done",
|
||||
data={},
|
||||
metadata={},
|
||||
node_id="node",
|
||||
)
|
||||
|
||||
assert list(pipeline._handle_iteration_start_event(iter_start)) == ["iter"]
|
||||
assert list(pipeline._handle_iteration_next_event(iter_next)) == ["next"]
|
||||
assert list(pipeline._handle_iteration_completed_event(iter_done)) == ["done"]
|
||||
assert list(pipeline._handle_loop_start_event(loop_start)) == ["loop"]
|
||||
assert list(pipeline._handle_loop_next_event(loop_next)) == ["loop_next"]
|
||||
assert list(pipeline._handle_loop_completed_event(loop_done)) == ["loop_done"]
|
||||
assert list(pipeline._handle_human_input_form_filled_event(filled_event)) == ["filled"]
|
||||
assert list(pipeline._handle_human_input_form_timeout_event(timeout_event)) == ["timeout"]
|
||||
assert list(pipeline._handle_agent_log_event(agent_event)) == ["log"]
|
||||
|
||||
def test_wrapper_process_stream_response_emits_audio_end(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_features_dict = {
|
||||
"text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"}
|
||||
}
|
||||
pipeline._process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")])
|
||||
|
||||
class _Publisher:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.calls = 0
|
||||
|
||||
def check_and_get_audio(self):
|
||||
self.calls += 1
|
||||
if self.calls == 1:
|
||||
return AudioTrunk(status="stream", audio="data")
|
||||
if self.calls == 2:
|
||||
return None
|
||||
return AudioTrunk(status="finish", audio="")
|
||||
|
||||
def publish(self, message):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher",
|
||||
_Publisher,
|
||||
)
|
||||
|
||||
responses = list(pipeline._wrapper_process_stream_response())
|
||||
|
||||
assert any(isinstance(item, MessageAudioStreamResponse) for item in responses)
|
||||
assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses)
|
||||
|
||||
def test_init_with_end_user_sets_role_and_system_user(self):
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
workflow_id="workflow-id",
|
||||
)
|
||||
application_generate_entity = WorkflowAppGenerateEntity.model_construct(
|
||||
task_id="task",
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id="end-user-id",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
trace_manager=None,
|
||||
workflow_execution_id="run-id",
|
||||
extras={},
|
||||
call_depth=0,
|
||||
)
|
||||
workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={})
|
||||
queue_manager = SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None)
|
||||
end_user = EndUser(tenant_id="tenant", type="session", name="user", session_id="session-id")
|
||||
end_user.id = "end-user-id"
|
||||
|
||||
pipeline = WorkflowAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
user=end_user,
|
||||
stream=False,
|
||||
draft_var_saver_factory=lambda **kwargs: None,
|
||||
)
|
||||
|
||||
assert pipeline._created_by_role == CreatorUserRole.END_USER
|
||||
assert pipeline._workflow_system_variables.user_id == "session-id"
|
||||
|
||||
def test_process_returns_stream_and_blocking_variants(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._base_task_pipeline.stream = True
|
||||
pipeline._wrapper_process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")])
|
||||
|
||||
stream_response = list(pipeline.process())
|
||||
assert len(stream_response) == 1
|
||||
assert stream_response[0].workflow_run_id is None
|
||||
|
||||
pipeline._base_task_pipeline.stream = False
|
||||
pipeline._wrapper_process_stream_response = lambda **kwargs: iter(
|
||||
[
|
||||
WorkflowFinishStreamResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run-id",
|
||||
data=WorkflowFinishStreamResponse.Data(
|
||||
id="run-id",
|
||||
workflow_id="workflow-id",
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
outputs={},
|
||||
error=None,
|
||||
elapsed_time=0.1,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
blocking_response = pipeline.process()
|
||||
assert blocking_response.workflow_run_id == "run-id"
|
||||
|
||||
def test_to_blocking_response_handles_error_and_unexpected_end(self):
|
||||
pipeline = _make_pipeline()
|
||||
|
||||
def _error_gen():
|
||||
yield ErrorStreamResponse(task_id="task", err=ValueError("boom"))
|
||||
|
||||
with pytest.raises(ValueError, match="boom"):
|
||||
pipeline._to_blocking_response(_error_gen())
|
||||
|
||||
def _unexpected_gen():
|
||||
yield PingStreamResponse(task_id="task")
|
||||
|
||||
with pytest.raises(ValueError, match="queue listening stopped unexpectedly"):
|
||||
pipeline._to_blocking_response(_unexpected_gen())
|
||||
|
||||
def test_to_stream_response_tracks_workflow_run_id(self):
|
||||
pipeline = _make_pipeline()
|
||||
|
||||
def _gen():
|
||||
yield WorkflowStartStreamResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run-id",
|
||||
data=WorkflowStartStreamResponse.Data(
|
||||
id="run-id",
|
||||
workflow_id="workflow-id",
|
||||
inputs={},
|
||||
created_at=1,
|
||||
),
|
||||
)
|
||||
yield PingStreamResponse(task_id="task")
|
||||
|
||||
stream_responses = list(pipeline._to_stream_response(_gen()))
|
||||
assert stream_responses[0].workflow_run_id == "run-id"
|
||||
assert stream_responses[1].workflow_run_id == "run-id"
|
||||
|
||||
def test_listen_audio_msg_returns_none_without_publisher(self):
|
||||
pipeline = _make_pipeline()
|
||||
assert pipeline._listen_audio_msg(publisher=None, task_id="task") is None
|
||||
|
||||
def test_wrapper_process_stream_response_without_tts(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_features_dict = {}
|
||||
pipeline._process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")])
|
||||
|
||||
responses = list(pipeline._wrapper_process_stream_response())
|
||||
assert responses == [PingStreamResponse(task_id="task")]
|
||||
|
||||
def test_wrapper_process_stream_response_final_audio_none_then_finish(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_features_dict = {
|
||||
"text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"}
|
||||
}
|
||||
pipeline._process_stream_response = lambda **kwargs: iter([])
|
||||
|
||||
sleep_spy = []
|
||||
|
||||
class _Publisher:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.calls = 0
|
||||
|
||||
def check_and_get_audio(self):
|
||||
self.calls += 1
|
||||
if self.calls == 1:
|
||||
return None
|
||||
return AudioTrunk(status="finish", audio="")
|
||||
|
||||
def publish(self, message):
|
||||
_ = message
|
||||
|
||||
time_values = iter([0.0, 0.0, 0.2])
|
||||
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.time.time", lambda: next(time_values))
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.generate_task_pipeline.time.sleep", lambda _: sleep_spy.append(True)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher",
|
||||
_Publisher,
|
||||
)
|
||||
|
||||
responses = list(pipeline._wrapper_process_stream_response())
|
||||
|
||||
assert sleep_spy
|
||||
assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses)
|
||||
|
||||
def test_wrapper_process_stream_response_handles_audio_exception(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_features_dict = {
|
||||
"text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"}
|
||||
}
|
||||
pipeline._process_stream_response = lambda **kwargs: iter([])
|
||||
|
||||
class _Publisher:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.called = False
|
||||
|
||||
def check_and_get_audio(self):
|
||||
if not self.called:
|
||||
self.called = True
|
||||
raise RuntimeError("tts failure")
|
||||
return AudioTrunk(status="finish", audio="")
|
||||
|
||||
def publish(self, message):
|
||||
_ = message
|
||||
|
||||
logger_exception = []
|
||||
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.time.time", lambda: 0.0)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.generate_task_pipeline.logger.exception",
|
||||
lambda *args, **kwargs: logger_exception.append((args, kwargs)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher",
|
||||
_Publisher,
|
||||
)
|
||||
|
||||
responses = list(pipeline._wrapper_process_stream_response())
|
||||
|
||||
assert logger_exception
|
||||
assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses)
|
||||
|
||||
def test_database_session_rolls_back_on_error(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
calls = {"commit": 0, "rollback": 0}
|
||||
|
||||
class _Session:
|
||||
def __init__(self, *args, **kwargs):
|
||||
_ = args, kwargs
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def commit(self):
|
||||
calls["commit"] += 1
|
||||
|
||||
def rollback(self):
|
||||
calls["rollback"] += 1
|
||||
|
||||
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session)
|
||||
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object()))
|
||||
|
||||
with pytest.raises(RuntimeError, match="db error"):
|
||||
with pipeline._database_session():
|
||||
raise RuntimeError("db error")
|
||||
|
||||
assert calls["commit"] == 0
|
||||
assert calls["rollback"] == 1
|
||||
|
||||
def test_node_retry_and_started_handlers_cover_none_and_value(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_execution_id = "run-id"
|
||||
|
||||
retry_event = QueueNodeRetryEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_title="title",
|
||||
node_type=NodeType.LLM,
|
||||
node_run_index=1,
|
||||
start_at=datetime.utcnow(),
|
||||
provider_type="provider",
|
||||
provider_id="provider-id",
|
||||
error="error",
|
||||
retry_index=1,
|
||||
)
|
||||
started_event = QueueNodeStartedEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_title="title",
|
||||
node_type=NodeType.LLM,
|
||||
node_run_index=1,
|
||||
start_at=datetime.utcnow(),
|
||||
provider_type="provider",
|
||||
provider_id="provider-id",
|
||||
)
|
||||
|
||||
pipeline._workflow_response_converter.workflow_node_retry_to_stream_response = lambda **kwargs: None
|
||||
assert list(pipeline._handle_node_retry_event(retry_event)) == []
|
||||
pipeline._workflow_response_converter.workflow_node_retry_to_stream_response = lambda **kwargs: "retry"
|
||||
assert list(pipeline._handle_node_retry_event(retry_event)) == ["retry"]
|
||||
|
||||
pipeline._workflow_response_converter.workflow_node_start_to_stream_response = lambda **kwargs: None
|
||||
assert list(pipeline._handle_node_started_event(started_event)) == []
|
||||
pipeline._workflow_response_converter.workflow_node_start_to_stream_response = lambda **kwargs: "started"
|
||||
assert list(pipeline._handle_node_started_event(started_event)) == ["started"]
|
||||
|
||||
def test_handle_node_exception_event_saves_output(self):
|
||||
pipeline = _make_pipeline()
|
||||
saved_ids: list[str] = []
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "failed"
|
||||
pipeline._save_output_for_event = lambda event, node_execution_id: saved_ids.append(node_execution_id)
|
||||
|
||||
event = QueueNodeExceptionEvent(
|
||||
node_execution_id="exec-id",
|
||||
node_id="node",
|
||||
node_type=NodeType.START,
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
process_data={},
|
||||
error="boom",
|
||||
)
|
||||
|
||||
responses = list(pipeline._handle_node_failed_events(event))
|
||||
assert responses == ["failed"]
|
||||
assert saved_ids == ["exec-id"]
|
||||
|
||||
def test_success_partial_and_pause_handlers(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_execution_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
|
||||
assert list(pipeline._handle_workflow_succeeded_event(QueueWorkflowSucceededEvent(outputs={}))) == ["finish"]
|
||||
assert list(
|
||||
pipeline._handle_workflow_partial_success_event(
|
||||
QueueWorkflowPartialSuccessEvent(exceptions_count=2, outputs={})
|
||||
)
|
||||
) == ["finish"]
|
||||
|
||||
pipeline._workflow_response_converter.workflow_pause_to_stream_response = lambda **kwargs: [
|
||||
"pause-a",
|
||||
"pause-b",
|
||||
]
|
||||
pause_event = QueueWorkflowPausedEvent(reasons=[], outputs={}, paused_nodes=["node"])
|
||||
assert list(pipeline._handle_workflow_paused_event(pause_event)) == ["pause-a", "pause-b"]
|
||||
|
||||
def test_text_chunk_handler_returns_empty_when_text_missing(self):
|
||||
pipeline = _make_pipeline()
|
||||
event = QueueTextChunkEvent.model_construct(text=None, from_variable_selector=None)
|
||||
assert list(pipeline._handle_text_chunk_event(event)) == []
|
||||
|
||||
def test_dispatch_event_direct_failed_and_unhandled_paths(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_execution_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._handle_ping_event = lambda event, **kwargs: iter(["ping"])
|
||||
assert list(pipeline._dispatch_event(QueuePingEvent())) == ["ping"]
|
||||
|
||||
pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["workflow-failed"])
|
||||
assert list(pipeline._dispatch_event(QueueWorkflowFailedEvent(error="failed", exceptions_count=1))) == [
|
||||
"workflow-failed"
|
||||
]
|
||||
|
||||
assert list(pipeline._dispatch_event(SimpleNamespace())) == []
|
||||
|
||||
def test_process_stream_response_main_match_paths_and_cleanup(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._base_task_pipeline.queue_manager.listen = lambda: iter(
|
||||
[
|
||||
SimpleNamespace(event=QueueWorkflowStartedEvent()),
|
||||
SimpleNamespace(event=QueueTextChunkEvent(text="hello")),
|
||||
SimpleNamespace(event=QueuePingEvent()),
|
||||
SimpleNamespace(event=QueueErrorEvent(error="e")),
|
||||
]
|
||||
)
|
||||
pipeline._handle_workflow_started_event = lambda event, **kwargs: iter(["started"])
|
||||
pipeline._handle_text_chunk_event = lambda event, **kwargs: iter(["text"])
|
||||
pipeline._dispatch_event = lambda event, **kwargs: iter(["dispatched"])
|
||||
pipeline._handle_error_event = lambda event, **kwargs: iter(["error"])
|
||||
publisher_calls: list[object] = []
|
||||
|
||||
class _Publisher:
|
||||
def publish(self, message):
|
||||
publisher_calls.append(message)
|
||||
|
||||
responses = list(pipeline._process_stream_response(tts_publisher=_Publisher()))
|
||||
assert responses == ["started", "text", "dispatched", "error"]
|
||||
assert publisher_calls == [None]
|
||||
|
||||
def test_process_stream_response_break_paths(self):
|
||||
pipeline = _make_pipeline()
|
||||
|
||||
pipeline._base_task_pipeline.queue_manager.listen = lambda: iter(
|
||||
[SimpleNamespace(event=QueueWorkflowFailedEvent(error="fail", exceptions_count=1))]
|
||||
)
|
||||
pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["failed"])
|
||||
assert list(pipeline._process_stream_response()) == ["failed"]
|
||||
|
||||
pipeline._base_task_pipeline.queue_manager.listen = lambda: iter(
|
||||
[SimpleNamespace(event=QueueWorkflowPausedEvent(reasons=[], outputs={}, paused_nodes=[]))]
|
||||
)
|
||||
pipeline._handle_workflow_paused_event = lambda event, **kwargs: iter(["paused"])
|
||||
assert list(pipeline._process_stream_response()) == ["paused"]
|
||||
|
||||
pipeline._base_task_pipeline.queue_manager.listen = lambda: iter(
|
||||
[SimpleNamespace(event=QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))]
|
||||
)
|
||||
pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["stopped"])
|
||||
assert list(pipeline._process_stream_response()) == ["stopped"]
|
||||
|
||||
def test_save_workflow_app_log_covers_invoke_from_variants(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._user_id = "user-id"
|
||||
added: list[object] = []
|
||||
|
||||
class _Session:
|
||||
def add(self, item):
|
||||
added.append(item)
|
||||
|
||||
pipeline._application_generate_entity.invoke_from = InvokeFrom.EXPLORE
|
||||
pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id")
|
||||
assert added[-1].created_from == "installed-app"
|
||||
|
||||
pipeline._application_generate_entity.invoke_from = InvokeFrom.WEB_APP
|
||||
pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id")
|
||||
assert added[-1].created_from == "web-app"
|
||||
|
||||
count_before = len(added)
|
||||
pipeline._application_generate_entity.invoke_from = InvokeFrom.DEBUGGER
|
||||
pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id")
|
||||
assert len(added) == count_before
|
||||
|
||||
pipeline._application_generate_entity.invoke_from = InvokeFrom.WEB_APP
|
||||
pipeline._save_workflow_app_log(session=_Session(), workflow_run_id=None)
|
||||
assert len(added) == count_before
|
||||
|
||||
def test_save_output_for_event_writes_draft_variables(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
saver_calls: list[tuple[object, object]] = []
|
||||
captured_factory_args: dict[str, object] = {}
|
||||
|
||||
class _Saver:
|
||||
def save(self, process_data, outputs):
|
||||
saver_calls.append((process_data, outputs))
|
||||
|
||||
def _factory(**kwargs):
|
||||
captured_factory_args.update(kwargs)
|
||||
return _Saver()
|
||||
|
||||
class _Begin:
|
||||
def __enter__(self):
|
||||
return None
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class _Session:
|
||||
def __init__(self, *args, **kwargs):
|
||||
_ = args, kwargs
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def begin(self):
|
||||
return _Begin()
|
||||
|
||||
pipeline._draft_var_saver_factory = _factory
|
||||
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session)
|
||||
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object()))
|
||||
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="exec-id",
|
||||
node_id="node-id",
|
||||
node_type=NodeType.START,
|
||||
in_loop_id="loop-id",
|
||||
start_at=datetime.utcnow(),
|
||||
process_data={"k": "v"},
|
||||
outputs={"out": 1},
|
||||
)
|
||||
pipeline._save_output_for_event(event=event, node_execution_id="exec-id")
|
||||
|
||||
assert captured_factory_args["node_execution_id"] == "exec-id"
|
||||
assert captured_factory_args["enclosing_node_id"] == "loop-id"
|
||||
assert saver_calls == [({"k": "v"}, {"out": 1})]
|
||||
Loading…
Reference in New Issue