mirror of https://github.com/langgenius/dify.git
test: unit test cases for core.callback, core.base, core.entities module (#32471)
This commit is contained in:
parent
36c1f4d506
commit
c59685748c
|
|
@ -0,0 +1,390 @@
|
|||
import base64
|
||||
import queue
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.base.tts.app_generator_tts_publisher import (
|
||||
AppGeneratorTTSPublisher,
|
||||
AudioTrunk,
|
||||
_invoice_tts,
|
||||
_process_future,
|
||||
)
|
||||
|
||||
# =========================
|
||||
# Fixtures
|
||||
# =========================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_instance(mocker):
|
||||
model = mocker.MagicMock()
|
||||
model.invoke_tts.return_value = [b"audio1", b"audio2"]
|
||||
model.get_tts_voices.return_value = [{"value": "voice1"}, {"value": "voice2"}]
|
||||
return model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_manager(mocker, mock_model_instance):
|
||||
manager = mocker.MagicMock()
|
||||
manager.get_default_model_instance.return_value = mock_model_instance
|
||||
mocker.patch(
|
||||
"core.base.tts.app_generator_tts_publisher.ModelManager",
|
||||
return_value=manager,
|
||||
)
|
||||
return manager
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_threads(mocker):
|
||||
"""Prevent real threads from starting during tests"""
|
||||
mocker.patch("threading.Thread.start", return_value=None)
|
||||
|
||||
|
||||
# =========================
|
||||
# AudioTrunk Tests
|
||||
# =========================
|
||||
|
||||
|
||||
class TestAudioTrunk:
|
||||
def test_audio_trunk_initialization(self):
|
||||
trunk = AudioTrunk("responding", b"data")
|
||||
assert trunk.status == "responding"
|
||||
assert trunk.audio == b"data"
|
||||
|
||||
|
||||
# =========================
|
||||
# _invoice_tts Tests
|
||||
# =========================
|
||||
|
||||
|
||||
class TestInvoiceTTS:
|
||||
@pytest.mark.parametrize(
|
||||
"text",
|
||||
[None, "", " "],
|
||||
)
|
||||
def test_invoice_tts_empty_or_none_returns_none(self, text, mock_model_instance):
|
||||
result = _invoice_tts(text, mock_model_instance, "tenant", "voice1")
|
||||
assert result is None
|
||||
mock_model_instance.invoke_tts.assert_not_called()
|
||||
|
||||
def test_invoice_tts_valid_text(self, mock_model_instance):
|
||||
result = _invoice_tts(" hello ", mock_model_instance, "tenant", "voice1")
|
||||
mock_model_instance.invoke_tts.assert_called_once_with(
|
||||
content_text="hello",
|
||||
user="responding_tts",
|
||||
tenant_id="tenant",
|
||||
voice="voice1",
|
||||
)
|
||||
assert result == [b"audio1", b"audio2"]
|
||||
|
||||
|
||||
# =========================
|
||||
# _process_future Tests
|
||||
# =========================
|
||||
|
||||
|
||||
class TestProcessFuture:
|
||||
def test_process_future_normal_flow(self):
|
||||
future_queue = queue.Queue()
|
||||
audio_queue = queue.Queue()
|
||||
|
||||
future = MagicMock()
|
||||
future.result.return_value = [b"abc"]
|
||||
|
||||
future_queue.put(future)
|
||||
future_queue.put(None)
|
||||
|
||||
_process_future(future_queue, audio_queue)
|
||||
|
||||
first = audio_queue.get()
|
||||
assert first.status == "responding"
|
||||
assert first.audio == base64.b64encode(b"abc")
|
||||
|
||||
finish = audio_queue.get()
|
||||
assert finish.status == "finish"
|
||||
|
||||
def test_process_future_empty_result(self):
|
||||
future_queue = queue.Queue()
|
||||
audio_queue = queue.Queue()
|
||||
|
||||
future = MagicMock()
|
||||
future.result.return_value = None
|
||||
|
||||
future_queue.put(future)
|
||||
future_queue.put(None)
|
||||
|
||||
_process_future(future_queue, audio_queue)
|
||||
|
||||
finish = audio_queue.get()
|
||||
assert finish.status == "finish"
|
||||
|
||||
def test_process_future_exception(self, mocker):
|
||||
future_queue = queue.Queue()
|
||||
audio_queue = queue.Queue()
|
||||
|
||||
future = MagicMock()
|
||||
future.result.side_effect = Exception("error")
|
||||
|
||||
future_queue.put(future)
|
||||
|
||||
_process_future(future_queue, audio_queue)
|
||||
|
||||
finish = audio_queue.get()
|
||||
assert finish.status == "finish"
|
||||
|
||||
|
||||
# =========================
|
||||
# AppGeneratorTTSPublisher Tests
|
||||
# =========================
|
||||
|
||||
|
||||
class TestAppGeneratorTTSPublisher:
|
||||
def test_initialization_valid_voice(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
assert publisher.voice == "voice1"
|
||||
assert publisher.max_sentence == 2
|
||||
assert publisher.msg_text == ""
|
||||
|
||||
def test_initialization_invalid_voice_fallback(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "invalid_voice")
|
||||
assert publisher.voice == "voice1"
|
||||
|
||||
def test_publish_puts_message_in_queue(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
message = MagicMock()
|
||||
publisher.publish(message)
|
||||
assert publisher._msg_queue.get() == message
|
||||
|
||||
def test_check_and_get_audio_no_audio(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
result = publisher.check_and_get_audio()
|
||||
assert result is None
|
||||
|
||||
def test_check_and_get_audio_non_finish_event(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
trunk = AudioTrunk("responding", b"abc")
|
||||
publisher._audio_queue.put(trunk)
|
||||
|
||||
result = publisher.check_and_get_audio()
|
||||
|
||||
assert result.status == "responding"
|
||||
assert publisher._last_audio_event == trunk
|
||||
|
||||
def test_check_and_get_audio_finish_event(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
finish_trunk = AudioTrunk("finish", b"")
|
||||
publisher._audio_queue.put(finish_trunk)
|
||||
|
||||
result = publisher.check_and_get_audio()
|
||||
|
||||
assert result.status == "finish"
|
||||
publisher.executor.shutdown.assert_called_once()
|
||||
|
||||
def test_check_and_get_audio_cached_finish(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
publisher._last_audio_event = AudioTrunk("finish", b"")
|
||||
|
||||
result = publisher.check_and_get_audio()
|
||||
|
||||
assert result.status == "finish"
|
||||
publisher.executor.shutdown.assert_called_once()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("text", "expected_sentences", "expected_remaining"),
|
||||
[
|
||||
("Hello world.", ["Hello world."], ""),
|
||||
("Hello world! How are you?", ["Hello world!", " How are you?"], ""),
|
||||
("No punctuation", [], "No punctuation"),
|
||||
("", [], ""),
|
||||
],
|
||||
)
|
||||
def test_extract_sentence(self, mock_model_manager, text, expected_sentences, expected_remaining):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
sentences, remaining = publisher._extract_sentence(text)
|
||||
assert sentences == expected_sentences
|
||||
assert remaining == expected_remaining
|
||||
|
||||
def test_runtime_handles_none_message_with_buffer(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
publisher.msg_text = "Hello."
|
||||
|
||||
publisher._msg_queue.put(None)
|
||||
publisher._runtime()
|
||||
|
||||
publisher.executor.submit.assert_called_once()
|
||||
|
||||
def test_runtime_handles_none_message_without_buffer(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
publisher.msg_text = " "
|
||||
|
||||
publisher._msg_queue.put(None)
|
||||
publisher._runtime()
|
||||
|
||||
publisher.executor.submit.assert_not_called()
|
||||
|
||||
def test_runtime_sentence_threshold_triggers_submit(self, mock_model_manager, mocker):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
|
||||
# Force sentence extraction to hit threshold condition
|
||||
mocker.patch.object(
|
||||
publisher,
|
||||
"_extract_sentence",
|
||||
return_value=(["Hello world.", " Second sentence."], ""),
|
||||
)
|
||||
|
||||
from core.app.entities.queue_entities import QueueTextChunkEvent
|
||||
|
||||
event = MagicMock()
|
||||
event.event = MagicMock(spec=QueueTextChunkEvent)
|
||||
event.event.text = "Hello world. Second sentence."
|
||||
|
||||
publisher._msg_queue.put(event)
|
||||
publisher._msg_queue.put(None)
|
||||
|
||||
publisher._runtime()
|
||||
|
||||
assert publisher.executor.submit.called
|
||||
|
||||
def test_runtime_handles_text_chunk_event(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
|
||||
from core.app.entities.queue_entities import QueueTextChunkEvent
|
||||
|
||||
event = MagicMock()
|
||||
event.event = MagicMock(spec=QueueTextChunkEvent)
|
||||
event.event.text = "Hello world."
|
||||
|
||||
publisher._msg_queue.put(event)
|
||||
publisher._msg_queue.put(None)
|
||||
|
||||
publisher._runtime()
|
||||
|
||||
assert publisher.executor.submit.called
|
||||
|
||||
def test_runtime_handles_node_succeeded_event_with_output(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
|
||||
from core.app.entities.queue_entities import QueueNodeSucceededEvent
|
||||
|
||||
event = MagicMock()
|
||||
event.event = MagicMock(spec=QueueNodeSucceededEvent)
|
||||
event.event.outputs = {"output": "Hello world."}
|
||||
|
||||
publisher._msg_queue.put(event)
|
||||
publisher._msg_queue.put(None)
|
||||
|
||||
publisher._runtime()
|
||||
|
||||
assert publisher.executor.submit.called
|
||||
|
||||
def test_runtime_handles_node_succeeded_event_without_output(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
|
||||
from core.app.entities.queue_entities import QueueNodeSucceededEvent
|
||||
|
||||
event = MagicMock()
|
||||
event.event = MagicMock(spec=QueueNodeSucceededEvent)
|
||||
event.event.outputs = None
|
||||
|
||||
publisher._msg_queue.put(event)
|
||||
publisher._msg_queue.put(None)
|
||||
|
||||
publisher._runtime()
|
||||
|
||||
publisher.executor.submit.assert_not_called()
|
||||
|
||||
def test_runtime_handles_agent_message_event_list_content(self, mock_model_manager, mocker):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
|
||||
from core.app.entities.queue_entities import QueueAgentMessageEvent
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
|
||||
chunk = LLMResultChunk(
|
||||
model="model",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data="Hello "),
|
||||
ImagePromptMessageContent(format="png", mime_type="image/png", base64_data="a"),
|
||||
]
|
||||
),
|
||||
),
|
||||
)
|
||||
event = MagicMock(event=QueueAgentMessageEvent(chunk=chunk))
|
||||
|
||||
mocker.patch.object(publisher, "_extract_sentence", return_value=([], ""))
|
||||
|
||||
publisher._msg_queue.put(event)
|
||||
publisher._msg_queue.put(None)
|
||||
|
||||
publisher._runtime()
|
||||
|
||||
assert publisher.msg_text == "Hello "
|
||||
|
||||
def test_runtime_handles_agent_message_event_empty_content(self, mock_model_manager, mocker):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
|
||||
from core.app.entities.queue_entities import QueueAgentMessageEvent
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
|
||||
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
|
||||
chunk = LLMResultChunk(
|
||||
model="model",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
),
|
||||
)
|
||||
event = MagicMock(event=QueueAgentMessageEvent(chunk=chunk))
|
||||
|
||||
mocker.patch.object(publisher, "_extract_sentence", return_value=([], ""))
|
||||
|
||||
publisher._msg_queue.put(event)
|
||||
publisher._msg_queue.put(None)
|
||||
|
||||
publisher._runtime()
|
||||
|
||||
assert publisher.msg_text == ""
|
||||
|
||||
def test_runtime_resets_msg_text_when_text_tmp_not_str(self, mock_model_manager, mocker):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher.executor = MagicMock()
|
||||
|
||||
from core.app.entities.queue_entities import QueueTextChunkEvent
|
||||
|
||||
event = MagicMock()
|
||||
event.event = MagicMock(spec=QueueTextChunkEvent)
|
||||
event.event.text = "Hello world. Another sentence."
|
||||
|
||||
mocker.patch.object(publisher, "_extract_sentence", return_value=(["A.", "B."], None))
|
||||
|
||||
publisher._msg_queue.put(event)
|
||||
publisher._msg_queue.put(None)
|
||||
|
||||
publisher._runtime()
|
||||
|
||||
assert publisher.msg_text == ""
|
||||
|
||||
def test_runtime_exception_path(self, mock_model_manager):
|
||||
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
|
||||
publisher._msg_queue = MagicMock()
|
||||
publisher._msg_queue.get.side_effect = Exception("error")
|
||||
|
||||
publisher._runtime()
|
||||
|
|
@ -0,0 +1,197 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import core.callback_handler.agent_tool_callback_handler as module
|
||||
|
||||
# -----------------------------
|
||||
# Fixtures
|
||||
# -----------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def enable_debug(mocker):
|
||||
mocker.patch.object(module.dify_config, "DEBUG", True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_debug(mocker):
|
||||
mocker.patch.object(module.dify_config, "DEBUG", False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_print(mocker):
|
||||
return mocker.patch("builtins.print")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def handler():
|
||||
return module.DifyAgentCallbackHandler(color="blue")
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# get_colored_text Tests
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class TestGetColoredText:
|
||||
@pytest.mark.parametrize(
|
||||
("color", "expected_code"),
|
||||
[
|
||||
("blue", "36;1"),
|
||||
("yellow", "33;1"),
|
||||
("pink", "38;5;200"),
|
||||
("green", "32;1"),
|
||||
("red", "31;1"),
|
||||
],
|
||||
)
|
||||
def test_get_colored_text_valid_colors(self, color, expected_code):
|
||||
text = "hello"
|
||||
result = module.get_colored_text(text, color)
|
||||
assert expected_code in result
|
||||
assert text in result
|
||||
assert result.endswith("\u001b[0m")
|
||||
|
||||
def test_get_colored_text_invalid_color_raises(self):
|
||||
with pytest.raises(KeyError):
|
||||
module.get_colored_text("hello", "invalid")
|
||||
|
||||
def test_get_colored_text_empty_string(self):
|
||||
result = module.get_colored_text("", "green")
|
||||
assert "\u001b[" in result
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# print_text Tests
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class TestPrintText:
|
||||
def test_print_text_without_color(self, mock_print):
|
||||
module.print_text("hello")
|
||||
mock_print.assert_called_once_with("hello", end="", file=None)
|
||||
|
||||
def test_print_text_with_color(self, mocker, mock_print):
|
||||
mock_get_color = mocker.patch(
|
||||
"core.callback_handler.agent_tool_callback_handler.get_colored_text",
|
||||
return_value="colored_text",
|
||||
)
|
||||
|
||||
module.print_text("hello", color="green")
|
||||
|
||||
mock_get_color.assert_called_once_with("hello", "green")
|
||||
mock_print.assert_called_once_with("colored_text", end="", file=None)
|
||||
|
||||
def test_print_text_with_file_flush(self, mocker):
|
||||
mock_file = MagicMock()
|
||||
mock_print = mocker.patch("builtins.print")
|
||||
|
||||
module.print_text("hello", file=mock_file)
|
||||
|
||||
mock_print.assert_called_once_with("hello", end="", file=mock_file)
|
||||
mock_file.flush.assert_called_once()
|
||||
|
||||
def test_print_text_with_end_parameter(self, mock_print):
|
||||
module.print_text("hello", end="\n")
|
||||
mock_print.assert_called_once_with("hello", end="\n", file=None)
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# DifyAgentCallbackHandler Tests
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class TestDifyAgentCallbackHandler:
|
||||
def test_init_default_color(self):
|
||||
handler = module.DifyAgentCallbackHandler()
|
||||
assert handler.color == "green"
|
||||
assert handler.current_loop == 1
|
||||
|
||||
def test_on_tool_start_debug_enabled(self, handler, enable_debug, mocker):
|
||||
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
|
||||
|
||||
handler.on_tool_start("tool1", {"a": 1})
|
||||
|
||||
mock_print_text.assert_called()
|
||||
|
||||
def test_on_tool_start_debug_disabled(self, handler, disable_debug, mocker):
|
||||
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
|
||||
|
||||
handler.on_tool_start("tool1", {"a": 1})
|
||||
|
||||
mock_print_text.assert_not_called()
|
||||
|
||||
def test_on_tool_end_debug_enabled_and_trace(self, handler, enable_debug, mocker):
|
||||
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
|
||||
mock_trace_manager = MagicMock()
|
||||
|
||||
handler.on_tool_end(
|
||||
tool_name="tool1",
|
||||
tool_inputs={"a": 1},
|
||||
tool_outputs="output",
|
||||
message_id="msg1",
|
||||
timer=123,
|
||||
trace_manager=mock_trace_manager,
|
||||
)
|
||||
|
||||
assert mock_print_text.call_count >= 1
|
||||
mock_trace_manager.add_trace_task.assert_called_once()
|
||||
|
||||
def test_on_tool_end_without_trace_manager(self, handler, enable_debug, mocker):
|
||||
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
|
||||
|
||||
handler.on_tool_end(
|
||||
tool_name="tool1",
|
||||
tool_inputs={},
|
||||
tool_outputs="output",
|
||||
)
|
||||
|
||||
assert mock_print_text.call_count >= 1
|
||||
|
||||
def test_on_tool_error_debug_enabled(self, handler, enable_debug, mocker):
|
||||
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
|
||||
|
||||
handler.on_tool_error(Exception("error"))
|
||||
|
||||
mock_print_text.assert_called_once()
|
||||
|
||||
def test_on_tool_error_debug_disabled(self, handler, disable_debug, mocker):
|
||||
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
|
||||
|
||||
handler.on_tool_error(Exception("error"))
|
||||
|
||||
mock_print_text.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize("thought", ["thinking", ""])
|
||||
def test_on_agent_start(self, handler, enable_debug, mocker, thought):
|
||||
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
|
||||
|
||||
handler.on_agent_start(thought)
|
||||
|
||||
mock_print_text.assert_called()
|
||||
|
||||
def test_on_agent_finish_increments_loop(self, handler, enable_debug, mocker):
|
||||
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
|
||||
|
||||
current_loop = handler.current_loop
|
||||
handler.on_agent_finish()
|
||||
|
||||
assert handler.current_loop == current_loop + 1
|
||||
mock_print_text.assert_called()
|
||||
|
||||
def test_on_datasource_start_debug_enabled(self, handler, enable_debug, mocker):
|
||||
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
|
||||
|
||||
handler.on_datasource_start("ds1", {"x": 1})
|
||||
|
||||
mock_print_text.assert_called_once()
|
||||
|
||||
def test_ignore_agent_property(self, disable_debug, handler):
|
||||
assert handler.ignore_agent is True
|
||||
|
||||
def test_ignore_chat_model_property(self, disable_debug, handler):
|
||||
assert handler.ignore_chat_model is True
|
||||
|
||||
def test_ignore_properties_when_debug_enabled(self, enable_debug, handler):
|
||||
assert handler.ignore_agent is False
|
||||
assert handler.ignore_chat_model is False
|
||||
|
|
@ -0,0 +1,162 @@
|
|||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.index_tool_callback_handler import (
|
||||
DatasetIndexToolCallbackHandler,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_queue_manager(mocker):
|
||||
return mocker.Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def handler(mock_queue_manager, mocker):
|
||||
mocker.patch(
|
||||
"core.callback_handler.index_tool_callback_handler.db",
|
||||
)
|
||||
return DatasetIndexToolCallbackHandler(
|
||||
queue_manager=mock_queue_manager,
|
||||
app_id="app-1",
|
||||
message_id="msg-1",
|
||||
user_id="user-1",
|
||||
invoke_from=mocker.Mock(),
|
||||
)
|
||||
|
||||
|
||||
class TestOnQuery:
|
||||
@pytest.mark.parametrize(
|
||||
("invoke_from", "expected_role"),
|
||||
[
|
||||
(InvokeFrom.EXPLORE, "account"),
|
||||
(InvokeFrom.DEBUGGER, "account"),
|
||||
(InvokeFrom.WEB_APP, "end_user"),
|
||||
],
|
||||
)
|
||||
def test_on_query_success_roles(self, mocker, mock_queue_manager, invoke_from, expected_role):
|
||||
# Arrange
|
||||
mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
|
||||
|
||||
handler = DatasetIndexToolCallbackHandler(
|
||||
queue_manager=mock_queue_manager,
|
||||
app_id="app-1",
|
||||
message_id="msg-1",
|
||||
user_id="user-1",
|
||||
invoke_from=mocker.Mock(),
|
||||
)
|
||||
|
||||
handler._invoke_from = invoke_from
|
||||
|
||||
# Act
|
||||
handler.on_query("test query", "dataset-1")
|
||||
|
||||
# Assert
|
||||
mock_db.session.add.assert_called_once()
|
||||
dataset_query = mock_db.session.add.call_args.args[0]
|
||||
assert dataset_query.created_by_role == expected_role
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
def test_on_query_none_values(self, mocker, mock_queue_manager):
|
||||
mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
|
||||
|
||||
handler = DatasetIndexToolCallbackHandler(
|
||||
queue_manager=mock_queue_manager,
|
||||
app_id=None,
|
||||
message_id=None,
|
||||
user_id=None,
|
||||
invoke_from=None,
|
||||
)
|
||||
|
||||
handler.on_query(None, None)
|
||||
|
||||
mock_db.session.add.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestOnToolEnd:
|
||||
def test_on_tool_end_no_metadata(self, handler, mocker):
|
||||
mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
|
||||
|
||||
document = mocker.Mock()
|
||||
document.metadata = None
|
||||
|
||||
handler.on_tool_end([document])
|
||||
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
def test_on_tool_end_dataset_document_not_found(self, handler, mocker):
|
||||
mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
document = mocker.Mock()
|
||||
document.metadata = {"document_id": "doc-1", "doc_id": "node-1"}
|
||||
|
||||
handler.on_tool_end([document])
|
||||
|
||||
mock_db.session.scalar.assert_called_once()
|
||||
|
||||
def test_on_tool_end_parent_child_index_with_child(self, handler, mocker):
|
||||
mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
|
||||
|
||||
mock_dataset_doc = mocker.Mock()
|
||||
from core.callback_handler.index_tool_callback_handler import IndexStructureType
|
||||
|
||||
mock_dataset_doc.doc_form = IndexStructureType.PARENT_CHILD_INDEX
|
||||
mock_dataset_doc.dataset_id = "dataset-1"
|
||||
mock_dataset_doc.id = "doc-1"
|
||||
|
||||
mock_child_chunk = mocker.Mock()
|
||||
mock_child_chunk.segment_id = "segment-1"
|
||||
|
||||
mock_db.session.scalar.side_effect = [mock_dataset_doc, mock_child_chunk]
|
||||
|
||||
document = mocker.Mock()
|
||||
document.metadata = {"document_id": "doc-1", "doc_id": "node-1"}
|
||||
|
||||
mock_query = mocker.Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
|
||||
handler.on_tool_end([document])
|
||||
|
||||
mock_query.update.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
def test_on_tool_end_non_parent_child_index(self, handler, mocker):
|
||||
mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
|
||||
|
||||
mock_dataset_doc = mocker.Mock()
|
||||
mock_dataset_doc.doc_form = "OTHER"
|
||||
|
||||
mock_db.session.scalar.return_value = mock_dataset_doc
|
||||
|
||||
document = mocker.Mock()
|
||||
document.metadata = {
|
||||
"document_id": "doc-1",
|
||||
"doc_id": "node-1",
|
||||
"dataset_id": "dataset-1",
|
||||
}
|
||||
|
||||
mock_query = mocker.Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
|
||||
handler.on_tool_end([document])
|
||||
|
||||
mock_query.update.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
def test_on_tool_end_empty_documents(self, handler):
|
||||
handler.on_tool_end([])
|
||||
|
||||
|
||||
class TestReturnRetrieverResourceInfo:
|
||||
def test_publish_called(self, handler, mock_queue_manager, mocker):
|
||||
mock_event = mocker.patch("core.callback_handler.index_tool_callback_handler.QueueRetrieverResourcesEvent")
|
||||
|
||||
resources = [mocker.Mock()]
|
||||
|
||||
handler.return_retriever_resource_info(resources)
|
||||
|
||||
mock_queue_manager.publish.assert_called_once()
|
||||
|
|
@ -0,0 +1,184 @@
|
|||
from unittest.mock import MagicMock, call
|
||||
|
||||
import pytest
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import (
|
||||
DifyWorkflowCallbackHandler,
|
||||
)
|
||||
|
||||
|
||||
class DummyToolInvokeMessage:
|
||||
"""Lightweight dummy to simulate ToolInvokeMessage behavior."""
|
||||
|
||||
def __init__(self, json_value: str):
|
||||
self._json_value = json_value
|
||||
|
||||
def model_dump_json(self):
|
||||
return self._json_value
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def handler():
|
||||
"""Fixture to create handler instance with deterministic color."""
|
||||
instance = DifyWorkflowCallbackHandler()
|
||||
instance.color = "blue"
|
||||
return instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_print_text(mocker):
|
||||
"""Mock print_text to avoid real stdout printing."""
|
||||
return mocker.patch("core.callback_handler.workflow_tool_callback_handler.print_text")
|
||||
|
||||
|
||||
class TestDifyWorkflowCallbackHandler:
|
||||
def test_on_tool_execution_single_output_success(self, handler, mock_print_text):
|
||||
# Arrange
|
||||
tool_name = "test_tool"
|
||||
tool_inputs = {"a": 1}
|
||||
message = DummyToolInvokeMessage('{"key": "value"}')
|
||||
|
||||
# Act
|
||||
results = list(
|
||||
handler.on_tool_execution(
|
||||
tool_name=tool_name,
|
||||
tool_inputs=tool_inputs,
|
||||
tool_outputs=[message],
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert results == [message]
|
||||
assert mock_print_text.call_count == 4
|
||||
mock_print_text.assert_has_calls(
|
||||
[
|
||||
call("\n[on_tool_execution]\n", color="blue"),
|
||||
call("Tool: test_tool\n", color="blue"),
|
||||
call(
|
||||
"Outputs: " + message.model_dump_json()[:1000] + "\n",
|
||||
color="blue",
|
||||
),
|
||||
call("\n"),
|
||||
]
|
||||
)
|
||||
|
||||
def test_on_tool_execution_multiple_outputs(self, handler, mock_print_text):
|
||||
# Arrange
|
||||
tool_name = "multi_tool"
|
||||
outputs = [
|
||||
DummyToolInvokeMessage('{"id": 1}'),
|
||||
DummyToolInvokeMessage('{"id": 2}'),
|
||||
]
|
||||
|
||||
# Act
|
||||
results = list(
|
||||
handler.on_tool_execution(
|
||||
tool_name=tool_name,
|
||||
tool_inputs={},
|
||||
tool_outputs=outputs,
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert results == outputs
|
||||
assert mock_print_text.call_count == 4 * len(outputs)
|
||||
|
||||
def test_on_tool_execution_empty_iterable(self, handler, mock_print_text):
|
||||
# Arrange
|
||||
tool_name = "empty_tool"
|
||||
|
||||
# Act
|
||||
results = list(
|
||||
handler.on_tool_execution(
|
||||
tool_name=tool_name,
|
||||
tool_inputs={},
|
||||
tool_outputs=[],
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert results == []
|
||||
mock_print_text.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("invalid_outputs", "expected_exception"),
|
||||
[
|
||||
(None, TypeError),
|
||||
(123, TypeError),
|
||||
("not_iterable", AttributeError),
|
||||
],
|
||||
)
|
||||
def test_on_tool_execution_invalid_outputs_type(self, handler, invalid_outputs, expected_exception):
|
||||
# Arrange
|
||||
tool_name = "invalid_tool"
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(expected_exception):
|
||||
list(
|
||||
handler.on_tool_execution(
|
||||
tool_name=tool_name,
|
||||
tool_inputs={},
|
||||
tool_outputs=invalid_outputs,
|
||||
)
|
||||
)
|
||||
|
||||
def test_on_tool_execution_long_json_truncation(self, handler, mock_print_text):
|
||||
# Arrange
|
||||
tool_name = "long_json_tool"
|
||||
long_json = "x" * 1500
|
||||
message = DummyToolInvokeMessage(long_json)
|
||||
|
||||
# Act
|
||||
list(
|
||||
handler.on_tool_execution(
|
||||
tool_name=tool_name,
|
||||
tool_inputs={},
|
||||
tool_outputs=[message],
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
expected_truncated = long_json[:1000]
|
||||
mock_print_text.assert_any_call(
|
||||
"Outputs: " + expected_truncated + "\n",
|
||||
color="blue",
|
||||
)
|
||||
|
||||
def test_on_tool_execution_model_dump_json_exception(self, handler, mock_print_text):
|
||||
# Arrange
|
||||
tool_name = "exception_tool"
|
||||
bad_message = MagicMock()
|
||||
bad_message.model_dump_json.side_effect = ValueError("JSON error")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
list(
|
||||
handler.on_tool_execution(
|
||||
tool_name=tool_name,
|
||||
tool_inputs={},
|
||||
tool_outputs=[bad_message],
|
||||
)
|
||||
)
|
||||
|
||||
# Ensure first two prints happened before failure
|
||||
assert mock_print_text.call_count >= 2
|
||||
|
||||
def test_on_tool_execution_none_message_id_and_trace_manager(self, handler, mock_print_text):
|
||||
# Arrange
|
||||
tool_name = "optional_params_tool"
|
||||
message = DummyToolInvokeMessage('{"data": "ok"}')
|
||||
|
||||
# Act
|
||||
results = list(
|
||||
handler.on_tool_execution(
|
||||
tool_name=tool_name,
|
||||
tool_inputs={},
|
||||
tool_outputs=[message],
|
||||
message_id=None,
|
||||
timer=None,
|
||||
trace_manager=None,
|
||||
)
|
||||
)
|
||||
|
||||
assert results == [message]
|
||||
assert mock_print_text.call_count == 4
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
from core.entities.agent_entities import PlanningStrategy
|
||||
|
||||
|
||||
def test_planning_strategy_values_are_stable() -> None:
|
||||
# Arrange / Act / Assert
|
||||
assert PlanningStrategy.ROUTER.value == "router"
|
||||
assert PlanningStrategy.REACT_ROUTER.value == "react_router"
|
||||
assert PlanningStrategy.REACT.value == "react"
|
||||
assert PlanningStrategy.FUNCTION_CALL.value == "function_call"
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
from core.entities.document_task import DocumentTask
|
||||
|
||||
|
||||
def test_document_task_keeps_indexing_identifiers() -> None:
|
||||
# Arrange
|
||||
document_ids = ("doc-1", "doc-2")
|
||||
|
||||
# Act
|
||||
task = DocumentTask(
|
||||
tenant_id="tenant-1",
|
||||
dataset_id="dataset-1",
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert task.tenant_id == "tenant-1"
|
||||
assert task.dataset_id == "dataset-1"
|
||||
assert task.document_ids == document_ids
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
from core.entities.embedding_type import EmbeddingInputType
|
||||
|
||||
|
||||
def test_embedding_input_type_values_are_stable() -> None:
|
||||
# Arrange / Act / Assert
|
||||
assert EmbeddingInputType.DOCUMENT.value == "document"
|
||||
assert EmbeddingInputType.QUERY.value == "query"
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
from core.entities.execution_extra_content import (
|
||||
ExecutionExtraContentDomainModel,
|
||||
HumanInputContent,
|
||||
HumanInputFormDefinition,
|
||||
HumanInputFormSubmissionData,
|
||||
)
|
||||
from dify_graph.nodes.human_input.entities import FormInput, UserAction
|
||||
from dify_graph.nodes.human_input.enums import FormInputType
|
||||
from models.execution_extra_content import ExecutionContentType
|
||||
|
||||
|
||||
def test_human_input_content_defaults_and_domain_alias() -> None:
|
||||
# Arrange
|
||||
form_definition = HumanInputFormDefinition(
|
||||
form_id="form-1",
|
||||
node_id="node-1",
|
||||
node_title="Human Input",
|
||||
form_content="Please confirm",
|
||||
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="answer")],
|
||||
actions=[UserAction(id="confirm", title="Confirm")],
|
||||
resolved_default_values={"answer": "yes"},
|
||||
expiration_time=1_700_000_000,
|
||||
)
|
||||
submission_data = HumanInputFormSubmissionData(
|
||||
node_id="node-1",
|
||||
node_title="Human Input",
|
||||
rendered_content="Please confirm",
|
||||
action_id="confirm",
|
||||
action_text="Confirm",
|
||||
)
|
||||
|
||||
# Act
|
||||
content = HumanInputContent(
|
||||
workflow_run_id="workflow-run-1",
|
||||
submitted=True,
|
||||
form_definition=form_definition,
|
||||
form_submission_data=submission_data,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert form_definition.model_config.get("frozen") is True
|
||||
assert content.type == ExecutionContentType.HUMAN_INPUT
|
||||
assert content.form_definition is form_definition
|
||||
assert content.form_submission_data is submission_data
|
||||
assert ExecutionExtraContentDomainModel is HumanInputContent
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
from core.entities.knowledge_entities import (
|
||||
PipelineDataset,
|
||||
PipelineDocument,
|
||||
PipelineGenerateResponse,
|
||||
)
|
||||
|
||||
|
||||
def test_pipeline_dataset_normalizes_none_description() -> None:
|
||||
# Arrange / Act
|
||||
dataset = PipelineDataset(
|
||||
id="dataset-1",
|
||||
name="Dataset",
|
||||
description=None,
|
||||
chunk_structure="parent-child",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert dataset.description == ""
|
||||
|
||||
|
||||
def test_pipeline_generate_response_builds_nested_models() -> None:
|
||||
# Arrange
|
||||
dataset = PipelineDataset(
|
||||
id="dataset-1",
|
||||
name="Dataset",
|
||||
description="Knowledge base",
|
||||
chunk_structure="parent-child",
|
||||
)
|
||||
document = PipelineDocument(
|
||||
id="doc-1",
|
||||
position=1,
|
||||
data_source_type="file",
|
||||
data_source_info={"name": "spec.pdf"},
|
||||
name="spec.pdf",
|
||||
indexing_status="completed",
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
# Act
|
||||
response = PipelineGenerateResponse(batch="batch-1", dataset=dataset, documents=[document])
|
||||
|
||||
# Assert
|
||||
assert response.batch == "batch-1"
|
||||
assert response.dataset.id == "dataset-1"
|
||||
assert response.documents[0].id == "doc-1"
|
||||
|
|
@ -0,0 +1,450 @@
|
|||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities import mcp_provider as mcp_provider_module
|
||||
from core.entities.mcp_provider import (
|
||||
DEFAULT_EXPIRES_IN,
|
||||
DEFAULT_TOKEN_TYPE,
|
||||
MCPProviderEntity,
|
||||
)
|
||||
from core.mcp.types import OAuthTokens
|
||||
|
||||
|
||||
def _build_mcp_provider_entity() -> MCPProviderEntity:
|
||||
now = datetime(2025, 1, 1, tzinfo=UTC)
|
||||
return MCPProviderEntity(
|
||||
id="provider-1",
|
||||
provider_id="server-1",
|
||||
name="Example MCP",
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
server_url="encrypted-server-url",
|
||||
headers={},
|
||||
timeout=30,
|
||||
sse_read_timeout=300,
|
||||
authed=False,
|
||||
credentials={},
|
||||
tools=[],
|
||||
icon={"en_US": "icon.png"},
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
def test_from_db_model_maps_fields() -> None:
|
||||
# Arrange
|
||||
now = datetime(2025, 1, 1, tzinfo=UTC)
|
||||
db_provider = SimpleNamespace(
|
||||
id="provider-1",
|
||||
server_identifier="server-1",
|
||||
name="Example MCP",
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
server_url="encrypted-server-url",
|
||||
headers={"Authorization": "enc"},
|
||||
timeout=15,
|
||||
sse_read_timeout=120,
|
||||
authed=True,
|
||||
credentials={"access_token": "enc-token"},
|
||||
tool_dict=[{"name": "search"}],
|
||||
icon=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
# Act
|
||||
entity = MCPProviderEntity.from_db_model(db_provider)
|
||||
|
||||
# Assert
|
||||
assert entity.provider_id == "server-1"
|
||||
assert entity.tools == [{"name": "search"}]
|
||||
assert entity.icon == ""
|
||||
|
||||
|
||||
def test_redirect_url_uses_console_api_url(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
monkeypatch.setattr(mcp_provider_module.dify_config, "CONSOLE_API_URL", "https://console.example.com")
|
||||
|
||||
# Act
|
||||
redirect_url = entity.redirect_url
|
||||
|
||||
# Assert
|
||||
assert redirect_url == "https://console.example.com/console/api/mcp/oauth/callback"
|
||||
|
||||
|
||||
def test_client_metadata_for_authorization_code_flow() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}):
|
||||
# Act
|
||||
metadata = entity.client_metadata
|
||||
|
||||
# Assert
|
||||
assert metadata.grant_types == ["refresh_token", "authorization_code"]
|
||||
assert metadata.redirect_uris == [entity.redirect_url]
|
||||
assert metadata.response_types == ["code"]
|
||||
|
||||
|
||||
def test_client_metadata_for_client_credentials_flow() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
credentials = {"client_information": {"grant_types": ["client_credentials"]}}
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials):
|
||||
# Act
|
||||
metadata = entity.client_metadata
|
||||
|
||||
# Assert
|
||||
assert metadata.grant_types == ["refresh_token", "client_credentials"]
|
||||
assert metadata.redirect_uris == []
|
||||
assert metadata.response_types == []
|
||||
|
||||
|
||||
def test_client_metadata_prefers_nested_authorization_code_grant_type() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
credentials = {
|
||||
"grant_type": "client_credentials",
|
||||
"client_information": {"grant_types": ["authorization_code"]},
|
||||
}
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials):
|
||||
# Act
|
||||
metadata = entity.client_metadata
|
||||
|
||||
# Assert
|
||||
assert metadata.grant_types == ["refresh_token", "authorization_code"]
|
||||
assert metadata.redirect_uris == [entity.redirect_url]
|
||||
assert metadata.response_types == ["code"]
|
||||
|
||||
|
||||
def test_provider_icon_returns_icon_dict_as_is() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity().model_copy(update={"icon": {"en_US": "icon.png"}})
|
||||
|
||||
# Act
|
||||
icon = entity.provider_icon
|
||||
|
||||
# Assert
|
||||
assert icon == {"en_US": "icon.png"}
|
||||
|
||||
|
||||
def test_provider_icon_uses_signed_url_for_plain_path() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity().model_copy(update={"icon": "icons/mcp.png"})
|
||||
|
||||
with patch(
|
||||
"core.entities.mcp_provider.file_helpers.get_signed_file_url",
|
||||
return_value="https://signed.example.com/icons/mcp.png",
|
||||
) as mock_get_signed_url:
|
||||
# Act
|
||||
icon = entity.provider_icon
|
||||
|
||||
# Assert
|
||||
mock_get_signed_url.assert_called_once_with("icons/mcp.png")
|
||||
assert icon == "https://signed.example.com/icons/mcp.png"
|
||||
|
||||
|
||||
def test_to_api_response_without_sensitive_data_skips_auth_related_work() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity().model_copy(update={"icon": {"en_US": "icon.png"}})
|
||||
|
||||
with patch.object(MCPProviderEntity, "masked_server_url", return_value="https://api.example.com/******"):
|
||||
# Act
|
||||
response = entity.to_api_response(include_sensitive=False)
|
||||
|
||||
# Assert
|
||||
assert response["author"] == "Anonymous"
|
||||
assert response["masked_headers"] == {}
|
||||
assert response["is_dynamic_registration"] is True
|
||||
assert "authentication" not in response
|
||||
|
||||
|
||||
def test_to_api_response_with_sensitive_data_includes_masked_values() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity().model_copy(
|
||||
update={
|
||||
"credentials": {"client_information": {"is_dynamic_registration": False}},
|
||||
"icon": {"en_US": "icon.png"},
|
||||
}
|
||||
)
|
||||
|
||||
with patch.object(MCPProviderEntity, "masked_server_url", return_value="https://api.example.com/******"):
|
||||
with patch.object(MCPProviderEntity, "masked_headers", return_value={"Authorization": "Be****"}):
|
||||
with patch.object(MCPProviderEntity, "masked_credentials", return_value={"client_id": "cl****"}):
|
||||
# Act
|
||||
response = entity.to_api_response(user_name="Rajat", include_sensitive=True)
|
||||
|
||||
# Assert
|
||||
assert response["author"] == "Rajat"
|
||||
assert response["masked_headers"] == {"Authorization": "Be****"}
|
||||
assert response["authentication"] == {"client_id": "cl****"}
|
||||
assert response["is_dynamic_registration"] is False
|
||||
|
||||
|
||||
def test_retrieve_client_information_decrypts_nested_secret() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
credentials = {"client_information": {"client_id": "client-1", "encrypted_client_secret": "enc-secret"}}
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials):
|
||||
with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="plain-secret") as mock_decrypt:
|
||||
# Act
|
||||
client_info = entity.retrieve_client_information()
|
||||
|
||||
# Assert
|
||||
assert client_info is not None
|
||||
assert client_info.client_id == "client-1"
|
||||
assert client_info.client_secret == "plain-secret"
|
||||
mock_decrypt.assert_called_once_with("tenant-1", "enc-secret")
|
||||
|
||||
|
||||
def test_retrieve_client_information_returns_none_for_missing_data() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}):
|
||||
# Act
|
||||
result_empty = entity.retrieve_client_information()
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"client_information": "invalid"}):
|
||||
# Act
|
||||
result_invalid = entity.retrieve_client_information()
|
||||
|
||||
# Assert
|
||||
assert result_empty is None
|
||||
assert result_invalid is None
|
||||
|
||||
|
||||
def test_masked_server_url_hides_path_segments() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
|
||||
with patch.object(
|
||||
MCPProviderEntity,
|
||||
"decrypt_server_url",
|
||||
return_value="https://api.example.com/v1/mcp?query=1",
|
||||
):
|
||||
# Act
|
||||
masked_url = entity.masked_server_url()
|
||||
|
||||
# Assert
|
||||
assert masked_url == "https://api.example.com/******?query=1"
|
||||
|
||||
|
||||
def test_mask_value_covers_short_and_long_values() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
|
||||
# Act
|
||||
short_masked = entity._mask_value("short")
|
||||
long_masked = entity._mask_value("abcdefghijkl")
|
||||
|
||||
# Assert
|
||||
assert short_masked == "*****"
|
||||
assert long_masked == "ab********kl"
|
||||
|
||||
|
||||
def test_masked_headers_masks_all_decrypted_header_values() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_headers", return_value={"Authorization": "abcdefgh"}):
|
||||
# Act
|
||||
masked = entity.masked_headers()
|
||||
|
||||
# Assert
|
||||
assert masked == {"Authorization": "ab****gh"}
|
||||
|
||||
|
||||
def test_masked_credentials_handles_nested_secret_fields() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
credentials = {
|
||||
"client_information": {
|
||||
"client_id": "client-id",
|
||||
"encrypted_client_secret": "encrypted-value",
|
||||
"client_secret": "plain-secret",
|
||||
}
|
||||
}
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials):
|
||||
with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="decrypted-secret"):
|
||||
# Act
|
||||
masked = entity.masked_credentials()
|
||||
|
||||
# Assert
|
||||
assert masked["client_id"] == "cl*****id"
|
||||
assert masked["client_secret"] == "pl********et"
|
||||
|
||||
|
||||
def test_masked_credentials_returns_empty_for_missing_client_information() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}):
|
||||
# Act
|
||||
masked_empty = entity.masked_credentials()
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"client_information": "invalid"}):
|
||||
# Act
|
||||
masked_invalid = entity.masked_credentials()
|
||||
|
||||
# Assert
|
||||
assert masked_empty == {}
|
||||
assert masked_invalid == {}
|
||||
|
||||
|
||||
def test_retrieve_tokens_returns_defaults_when_optional_fields_missing() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity().model_copy(update={"credentials": {"token": "encrypted"}})
|
||||
|
||||
with patch.object(
|
||||
MCPProviderEntity,
|
||||
"decrypt_credentials",
|
||||
return_value={"access_token": "token", "expires_in": "", "refresh_token": "refresh"},
|
||||
):
|
||||
# Act
|
||||
tokens = entity.retrieve_tokens()
|
||||
|
||||
# Assert
|
||||
assert isinstance(tokens, OAuthTokens)
|
||||
assert tokens.access_token == "token"
|
||||
assert tokens.token_type == DEFAULT_TOKEN_TYPE
|
||||
assert tokens.expires_in == DEFAULT_EXPIRES_IN
|
||||
assert tokens.refresh_token == "refresh"
|
||||
|
||||
|
||||
def test_retrieve_tokens_returns_none_when_access_token_missing() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity().model_copy(update={"credentials": {"token": "encrypted"}})
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"access_token": ""}) as mock_decrypt:
|
||||
# Act
|
||||
tokens = entity.retrieve_tokens()
|
||||
|
||||
# Assert
|
||||
mock_decrypt.assert_called_once()
|
||||
assert tokens is None
|
||||
|
||||
|
||||
def test_decrypt_server_url_delegates_to_encrypter() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
|
||||
with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="https://api.example.com") as mock:
|
||||
# Act
|
||||
decrypted = entity.decrypt_server_url()
|
||||
|
||||
# Assert
|
||||
mock.assert_called_once_with("tenant-1", "encrypted-server-url")
|
||||
assert decrypted == "https://api.example.com"
|
||||
|
||||
|
||||
def test_decrypt_authentication_injects_authorization_for_oauth() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity().model_copy(update={"authed": True, "headers": {}})
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_headers", return_value={}):
|
||||
with patch.object(
|
||||
MCPProviderEntity,
|
||||
"retrieve_tokens",
|
||||
return_value=OAuthTokens(access_token="abc123", token_type="bearer"),
|
||||
):
|
||||
# Act
|
||||
headers = entity.decrypt_authentication()
|
||||
|
||||
# Assert
|
||||
assert headers["Authorization"] == "Bearer abc123"
|
||||
|
||||
|
||||
def test_decrypt_authentication_does_not_overwrite_existing_headers() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity().model_copy(
|
||||
update={"authed": True, "headers": {"Authorization": "encrypted-header"}}
|
||||
)
|
||||
|
||||
with patch.object(MCPProviderEntity, "decrypt_headers", return_value={"Authorization": "existing"}):
|
||||
with patch.object(
|
||||
MCPProviderEntity,
|
||||
"retrieve_tokens",
|
||||
return_value=OAuthTokens(access_token="abc", token_type="bearer"),
|
||||
) as mock_tokens:
|
||||
# Act
|
||||
headers = entity.decrypt_authentication()
|
||||
|
||||
# Assert
|
||||
mock_tokens.assert_not_called()
|
||||
assert headers == {"Authorization": "existing"}
|
||||
|
||||
|
||||
def test_decrypt_dict_returns_empty_for_empty_input() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
|
||||
# Act
|
||||
decrypted = entity._decrypt_dict({})
|
||||
|
||||
# Assert
|
||||
assert decrypted == {}
|
||||
|
||||
|
||||
def test_decrypt_dict_returns_original_data_when_no_encrypted_fields() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
input_data = {"nested": {"k": "v"}, "count": 2, "empty": ""}
|
||||
|
||||
# Act
|
||||
result = entity._decrypt_dict(input_data)
|
||||
|
||||
# Assert
|
||||
assert result is input_data
|
||||
|
||||
|
||||
def test_decrypt_dict_only_decrypts_top_level_string_values() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
decryptor = Mock()
|
||||
decryptor.decrypt.return_value = {"api_key": "plain-key"}
|
||||
|
||||
def _fake_create_provider_encrypter(*, tenant_id: str, config: list, cache):
|
||||
assert tenant_id == "tenant-1"
|
||||
assert any(item.name == "api_key" for item in config)
|
||||
return decryptor, None
|
||||
|
||||
with patch("core.tools.utils.encryption.create_provider_encrypter", side_effect=_fake_create_provider_encrypter):
|
||||
# Act
|
||||
result = entity._decrypt_dict(
|
||||
{
|
||||
"api_key": "encrypted-key",
|
||||
"nested": {"client_id": "unchanged"},
|
||||
"empty": "",
|
||||
"count": 2,
|
||||
}
|
||||
)
|
||||
|
||||
# Assert
|
||||
decryptor.decrypt.assert_called_once_with({"api_key": "encrypted-key"})
|
||||
assert result["api_key"] == "plain-key"
|
||||
assert result["nested"] == {"client_id": "unchanged"}
|
||||
assert result["count"] == 2
|
||||
|
||||
|
||||
def test_decrypt_headers_and_credentials_delegate_to_decrypt_dict() -> None:
|
||||
# Arrange
|
||||
entity = _build_mcp_provider_entity()
|
||||
|
||||
with patch.object(MCPProviderEntity, "_decrypt_dict", side_effect=[{"h": "v"}, {"c": "v"}]) as mock:
|
||||
# Act
|
||||
headers = entity.decrypt_headers()
|
||||
credentials = entity.decrypt_credentials()
|
||||
|
||||
# Assert
|
||||
assert mock.call_count == 2
|
||||
assert headers == {"h": "v"}
|
||||
assert credentials == {"c": "v"}
|
||||
|
|
@ -0,0 +1,92 @@
|
|||
"""Unit tests for model entity behavior and invariants.
|
||||
|
||||
Covers DefaultModelEntity, DefaultModelProviderEntity, ModelStatus,
|
||||
ProviderModelWithStatusEntity, and SimpleModelProviderEntity. Assumes i18n
|
||||
labels are provided via I18nObject, model metadata aligns with FetchFrom and
|
||||
ModelType expectations, and ProviderEntity/ConfigurateMethod interactions
|
||||
drive provider mapping behavior.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.model_entities import (
|
||||
DefaultModelEntity,
|
||||
DefaultModelProviderEntity,
|
||||
ModelStatus,
|
||||
ProviderModelWithStatusEntity,
|
||||
SimpleModelProviderEntity,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType
|
||||
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
|
||||
|
||||
def _build_model_with_status(status: ModelStatus) -> ProviderModelWithStatusEntity:
|
||||
return ProviderModelWithStatusEntity(
|
||||
model="gpt-4",
|
||||
label=I18nObject(en_US="GPT-4"),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={},
|
||||
status=status,
|
||||
)
|
||||
|
||||
|
||||
def test_simple_model_provider_entity_maps_from_provider_entity() -> None:
|
||||
# Arrange
|
||||
provider_entity = ProviderEntity(
|
||||
provider="openai",
|
||||
label=I18nObject(en_US="OpenAI"),
|
||||
supported_model_types=[ModelType.LLM],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
)
|
||||
|
||||
# Act
|
||||
simple_provider = SimpleModelProviderEntity(provider_entity)
|
||||
|
||||
# Assert
|
||||
assert simple_provider.provider == "openai"
|
||||
assert simple_provider.label.en_US == "OpenAI"
|
||||
assert simple_provider.supported_model_types == [ModelType.LLM]
|
||||
|
||||
|
||||
def test_provider_model_with_status_raises_for_known_error_statuses() -> None:
|
||||
# Arrange
|
||||
expectations = {
|
||||
ModelStatus.NO_CONFIGURE: "Model is not configured",
|
||||
ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded",
|
||||
ModelStatus.NO_PERMISSION: "No permission to use this model",
|
||||
ModelStatus.DISABLED: "Model is disabled",
|
||||
}
|
||||
|
||||
for status, message in expectations.items():
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match=message):
|
||||
_build_model_with_status(status).raise_for_status()
|
||||
|
||||
|
||||
def test_provider_model_with_status_allows_active_and_credential_removed() -> None:
|
||||
# Arrange
|
||||
active_model = _build_model_with_status(ModelStatus.ACTIVE)
|
||||
removed_model = _build_model_with_status(ModelStatus.CREDENTIAL_REMOVED)
|
||||
|
||||
# Act / Assert
|
||||
active_model.raise_for_status()
|
||||
removed_model.raise_for_status()
|
||||
|
||||
|
||||
def test_default_model_entity_accepts_model_field_name() -> None:
|
||||
# Arrange / Act
|
||||
default_model = DefaultModelEntity(
|
||||
model="gpt-4o-mini",
|
||||
model_type=ModelType.LLM,
|
||||
provider=DefaultModelProviderEntity(
|
||||
provider="openai",
|
||||
label=I18nObject(en_US="OpenAI"),
|
||||
supported_model_types=[ModelType.LLM],
|
||||
),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert default_model.model == "gpt-4o-mini"
|
||||
assert default_model.provider.provider == "openai"
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
from core.entities.parameter_entities import (
|
||||
AppSelectorScope,
|
||||
CommonParameterType,
|
||||
ModelSelectorScope,
|
||||
ToolSelectorScope,
|
||||
)
|
||||
|
||||
|
||||
def test_common_parameter_type_values_are_stable() -> None:
|
||||
# Arrange / Act / Assert
|
||||
assert CommonParameterType.SECRET_INPUT.value == "secret-input"
|
||||
assert CommonParameterType.MODEL_SELECTOR.value == "model-selector"
|
||||
assert CommonParameterType.DYNAMIC_SELECT.value == "dynamic-select"
|
||||
assert CommonParameterType.ARRAY.value == "array"
|
||||
assert CommonParameterType.OBJECT.value == "object"
|
||||
|
||||
|
||||
def test_selector_scope_values_are_stable() -> None:
|
||||
# Arrange / Act / Assert
|
||||
assert AppSelectorScope.WORKFLOW.value == "workflow"
|
||||
assert ModelSelectorScope.TEXT_EMBEDDING.value == "text-embedding"
|
||||
assert ToolSelectorScope.BUILTIN.value == "builtin"
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,72 @@
|
|||
import pytest
|
||||
|
||||
from core.entities.parameter_entities import AppSelectorScope
|
||||
from core.entities.provider_entities import (
|
||||
BasicProviderConfig,
|
||||
ModelSettings,
|
||||
ProviderConfig,
|
||||
ProviderQuotaType,
|
||||
)
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
|
||||
def test_provider_quota_type_value_of_returns_enum_member() -> None:
|
||||
# Arrange / Act
|
||||
quota_type = ProviderQuotaType.value_of(ProviderQuotaType.TRIAL.value)
|
||||
|
||||
# Assert
|
||||
assert quota_type == ProviderQuotaType.TRIAL
|
||||
|
||||
|
||||
def test_provider_quota_type_value_of_rejects_unknown_values() -> None:
|
||||
# Arrange / Act / Assert
|
||||
with pytest.raises(ValueError, match="No matching enum found"):
|
||||
ProviderQuotaType.value_of("enterprise")
|
||||
|
||||
|
||||
def test_basic_provider_config_type_value_of_handles_known_values() -> None:
|
||||
# Arrange / Act
|
||||
parameter_type = BasicProviderConfig.Type.value_of("text-input")
|
||||
|
||||
# Assert
|
||||
assert parameter_type == BasicProviderConfig.Type.TEXT_INPUT
|
||||
|
||||
|
||||
def test_basic_provider_config_type_value_of_rejects_invalid_values() -> None:
|
||||
# Arrange / Act / Assert
|
||||
with pytest.raises(ValueError, match="invalid mode value"):
|
||||
BasicProviderConfig.Type.value_of("unknown")
|
||||
|
||||
|
||||
def test_provider_config_to_basic_provider_config_keeps_type_and_name() -> None:
|
||||
# Arrange
|
||||
provider_config = ProviderConfig(
|
||||
type=BasicProviderConfig.Type.SELECT,
|
||||
name="workspace",
|
||||
scope=AppSelectorScope.ALL,
|
||||
options=[ProviderConfig.Option(value="all", label=I18nObject(en_US="All"))],
|
||||
)
|
||||
|
||||
# Act
|
||||
basic_config = provider_config.to_basic_provider_config()
|
||||
|
||||
# Assert
|
||||
assert isinstance(basic_config, BasicProviderConfig)
|
||||
assert basic_config.type == BasicProviderConfig.Type.SELECT
|
||||
assert basic_config.name == "workspace"
|
||||
|
||||
|
||||
def test_model_settings_accepts_model_field_name() -> None:
|
||||
# Arrange / Act
|
||||
settings = ModelSettings(
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM,
|
||||
enabled=True,
|
||||
load_balancing_enabled=False,
|
||||
load_balancing_configs=[],
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert settings.model == "gpt-4o"
|
||||
assert settings.model_type == ModelType.LLM
|
||||
Loading…
Reference in New Issue