test: unit test cases for core.callback, core.base, core.entities module (#32471)

This commit is contained in:
Rajat Agarwal 2026-03-12 08:39:08 +05:30 committed by GitHub
parent 36c1f4d506
commit c59685748c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 3543 additions and 0 deletions

View File

@ -0,0 +1,390 @@
import base64
import queue
from unittest.mock import MagicMock
import pytest
from core.base.tts.app_generator_tts_publisher import (
AppGeneratorTTSPublisher,
AudioTrunk,
_invoice_tts,
_process_future,
)
# =========================
# Fixtures
# =========================
@pytest.fixture
def mock_model_instance(mocker):
model = mocker.MagicMock()
model.invoke_tts.return_value = [b"audio1", b"audio2"]
model.get_tts_voices.return_value = [{"value": "voice1"}, {"value": "voice2"}]
return model
@pytest.fixture
def mock_model_manager(mocker, mock_model_instance):
manager = mocker.MagicMock()
manager.get_default_model_instance.return_value = mock_model_instance
mocker.patch(
"core.base.tts.app_generator_tts_publisher.ModelManager",
return_value=manager,
)
return manager
@pytest.fixture(autouse=True)
def patch_threads(mocker):
"""Prevent real threads from starting during tests"""
mocker.patch("threading.Thread.start", return_value=None)
# =========================
# AudioTrunk Tests
# =========================
class TestAudioTrunk:
def test_audio_trunk_initialization(self):
trunk = AudioTrunk("responding", b"data")
assert trunk.status == "responding"
assert trunk.audio == b"data"
# =========================
# _invoice_tts Tests
# =========================
class TestInvoiceTTS:
@pytest.mark.parametrize(
"text",
[None, "", " "],
)
def test_invoice_tts_empty_or_none_returns_none(self, text, mock_model_instance):
result = _invoice_tts(text, mock_model_instance, "tenant", "voice1")
assert result is None
mock_model_instance.invoke_tts.assert_not_called()
def test_invoice_tts_valid_text(self, mock_model_instance):
result = _invoice_tts(" hello ", mock_model_instance, "tenant", "voice1")
mock_model_instance.invoke_tts.assert_called_once_with(
content_text="hello",
user="responding_tts",
tenant_id="tenant",
voice="voice1",
)
assert result == [b"audio1", b"audio2"]
# =========================
# _process_future Tests
# =========================
class TestProcessFuture:
def test_process_future_normal_flow(self):
future_queue = queue.Queue()
audio_queue = queue.Queue()
future = MagicMock()
future.result.return_value = [b"abc"]
future_queue.put(future)
future_queue.put(None)
_process_future(future_queue, audio_queue)
first = audio_queue.get()
assert first.status == "responding"
assert first.audio == base64.b64encode(b"abc")
finish = audio_queue.get()
assert finish.status == "finish"
def test_process_future_empty_result(self):
future_queue = queue.Queue()
audio_queue = queue.Queue()
future = MagicMock()
future.result.return_value = None
future_queue.put(future)
future_queue.put(None)
_process_future(future_queue, audio_queue)
finish = audio_queue.get()
assert finish.status == "finish"
def test_process_future_exception(self, mocker):
future_queue = queue.Queue()
audio_queue = queue.Queue()
future = MagicMock()
future.result.side_effect = Exception("error")
future_queue.put(future)
_process_future(future_queue, audio_queue)
finish = audio_queue.get()
assert finish.status == "finish"
# =========================
# AppGeneratorTTSPublisher Tests
# =========================
class TestAppGeneratorTTSPublisher:
def test_initialization_valid_voice(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
assert publisher.voice == "voice1"
assert publisher.max_sentence == 2
assert publisher.msg_text == ""
def test_initialization_invalid_voice_fallback(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "invalid_voice")
assert publisher.voice == "voice1"
def test_publish_puts_message_in_queue(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
message = MagicMock()
publisher.publish(message)
assert publisher._msg_queue.get() == message
def test_check_and_get_audio_no_audio(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
result = publisher.check_and_get_audio()
assert result is None
def test_check_and_get_audio_non_finish_event(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
trunk = AudioTrunk("responding", b"abc")
publisher._audio_queue.put(trunk)
result = publisher.check_and_get_audio()
assert result.status == "responding"
assert publisher._last_audio_event == trunk
def test_check_and_get_audio_finish_event(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
publisher.executor = MagicMock()
finish_trunk = AudioTrunk("finish", b"")
publisher._audio_queue.put(finish_trunk)
result = publisher.check_and_get_audio()
assert result.status == "finish"
publisher.executor.shutdown.assert_called_once()
def test_check_and_get_audio_cached_finish(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
publisher.executor = MagicMock()
publisher._last_audio_event = AudioTrunk("finish", b"")
result = publisher.check_and_get_audio()
assert result.status == "finish"
publisher.executor.shutdown.assert_called_once()
@pytest.mark.parametrize(
("text", "expected_sentences", "expected_remaining"),
[
("Hello world.", ["Hello world."], ""),
("Hello world! How are you?", ["Hello world!", " How are you?"], ""),
("No punctuation", [], "No punctuation"),
("", [], ""),
],
)
def test_extract_sentence(self, mock_model_manager, text, expected_sentences, expected_remaining):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
sentences, remaining = publisher._extract_sentence(text)
assert sentences == expected_sentences
assert remaining == expected_remaining
def test_runtime_handles_none_message_with_buffer(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
publisher.executor = MagicMock()
publisher.msg_text = "Hello."
publisher._msg_queue.put(None)
publisher._runtime()
publisher.executor.submit.assert_called_once()
def test_runtime_handles_none_message_without_buffer(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
publisher.executor = MagicMock()
publisher.msg_text = " "
publisher._msg_queue.put(None)
publisher._runtime()
publisher.executor.submit.assert_not_called()
def test_runtime_sentence_threshold_triggers_submit(self, mock_model_manager, mocker):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
publisher.executor = MagicMock()
# Force sentence extraction to hit threshold condition
mocker.patch.object(
publisher,
"_extract_sentence",
return_value=(["Hello world.", " Second sentence."], ""),
)
from core.app.entities.queue_entities import QueueTextChunkEvent
event = MagicMock()
event.event = MagicMock(spec=QueueTextChunkEvent)
event.event.text = "Hello world. Second sentence."
publisher._msg_queue.put(event)
publisher._msg_queue.put(None)
publisher._runtime()
assert publisher.executor.submit.called
def test_runtime_handles_text_chunk_event(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
publisher.executor = MagicMock()
from core.app.entities.queue_entities import QueueTextChunkEvent
event = MagicMock()
event.event = MagicMock(spec=QueueTextChunkEvent)
event.event.text = "Hello world."
publisher._msg_queue.put(event)
publisher._msg_queue.put(None)
publisher._runtime()
assert publisher.executor.submit.called
def test_runtime_handles_node_succeeded_event_with_output(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
publisher.executor = MagicMock()
from core.app.entities.queue_entities import QueueNodeSucceededEvent
event = MagicMock()
event.event = MagicMock(spec=QueueNodeSucceededEvent)
event.event.outputs = {"output": "Hello world."}
publisher._msg_queue.put(event)
publisher._msg_queue.put(None)
publisher._runtime()
assert publisher.executor.submit.called
def test_runtime_handles_node_succeeded_event_without_output(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
publisher.executor = MagicMock()
from core.app.entities.queue_entities import QueueNodeSucceededEvent
event = MagicMock()
event.event = MagicMock(spec=QueueNodeSucceededEvent)
event.event.outputs = None
publisher._msg_queue.put(event)
publisher._msg_queue.put(None)
publisher._runtime()
publisher.executor.submit.assert_not_called()
def test_runtime_handles_agent_message_event_list_content(self, mock_model_manager, mocker):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
publisher.executor = MagicMock()
from core.app.entities.queue_entities import QueueAgentMessageEvent
from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
TextPromptMessageContent,
)
chunk = LLMResultChunk(
model="model",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=[
TextPromptMessageContent(data="Hello "),
ImagePromptMessageContent(format="png", mime_type="image/png", base64_data="a"),
]
),
),
)
event = MagicMock(event=QueueAgentMessageEvent(chunk=chunk))
mocker.patch.object(publisher, "_extract_sentence", return_value=([], ""))
publisher._msg_queue.put(event)
publisher._msg_queue.put(None)
publisher._runtime()
assert publisher.msg_text == "Hello "
def test_runtime_handles_agent_message_event_empty_content(self, mock_model_manager, mocker):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
publisher.executor = MagicMock()
from core.app.entities.queue_entities import QueueAgentMessageEvent
from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage
chunk = LLMResultChunk(
model="model",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=""),
),
)
event = MagicMock(event=QueueAgentMessageEvent(chunk=chunk))
mocker.patch.object(publisher, "_extract_sentence", return_value=([], ""))
publisher._msg_queue.put(event)
publisher._msg_queue.put(None)
publisher._runtime()
assert publisher.msg_text == ""
def test_runtime_resets_msg_text_when_text_tmp_not_str(self, mock_model_manager, mocker):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
publisher.executor = MagicMock()
from core.app.entities.queue_entities import QueueTextChunkEvent
event = MagicMock()
event.event = MagicMock(spec=QueueTextChunkEvent)
event.event.text = "Hello world. Another sentence."
mocker.patch.object(publisher, "_extract_sentence", return_value=(["A.", "B."], None))
publisher._msg_queue.put(event)
publisher._msg_queue.put(None)
publisher._runtime()
assert publisher.msg_text == ""
def test_runtime_exception_path(self, mock_model_manager):
publisher = AppGeneratorTTSPublisher("tenant", "voice1")
publisher._msg_queue = MagicMock()
publisher._msg_queue.get.side_effect = Exception("error")
publisher._runtime()

View File

@ -0,0 +1,197 @@
from unittest.mock import MagicMock
import pytest
import core.callback_handler.agent_tool_callback_handler as module
# -----------------------------
# Fixtures
# -----------------------------
@pytest.fixture
def enable_debug(mocker):
mocker.patch.object(module.dify_config, "DEBUG", True)
@pytest.fixture
def disable_debug(mocker):
mocker.patch.object(module.dify_config, "DEBUG", False)
@pytest.fixture
def mock_print(mocker):
return mocker.patch("builtins.print")
@pytest.fixture
def handler():
return module.DifyAgentCallbackHandler(color="blue")
# -----------------------------
# get_colored_text Tests
# -----------------------------
class TestGetColoredText:
@pytest.mark.parametrize(
("color", "expected_code"),
[
("blue", "36;1"),
("yellow", "33;1"),
("pink", "38;5;200"),
("green", "32;1"),
("red", "31;1"),
],
)
def test_get_colored_text_valid_colors(self, color, expected_code):
text = "hello"
result = module.get_colored_text(text, color)
assert expected_code in result
assert text in result
assert result.endswith("\u001b[0m")
def test_get_colored_text_invalid_color_raises(self):
with pytest.raises(KeyError):
module.get_colored_text("hello", "invalid")
def test_get_colored_text_empty_string(self):
result = module.get_colored_text("", "green")
assert "\u001b[" in result
# -----------------------------
# print_text Tests
# -----------------------------
class TestPrintText:
def test_print_text_without_color(self, mock_print):
module.print_text("hello")
mock_print.assert_called_once_with("hello", end="", file=None)
def test_print_text_with_color(self, mocker, mock_print):
mock_get_color = mocker.patch(
"core.callback_handler.agent_tool_callback_handler.get_colored_text",
return_value="colored_text",
)
module.print_text("hello", color="green")
mock_get_color.assert_called_once_with("hello", "green")
mock_print.assert_called_once_with("colored_text", end="", file=None)
def test_print_text_with_file_flush(self, mocker):
mock_file = MagicMock()
mock_print = mocker.patch("builtins.print")
module.print_text("hello", file=mock_file)
mock_print.assert_called_once_with("hello", end="", file=mock_file)
mock_file.flush.assert_called_once()
def test_print_text_with_end_parameter(self, mock_print):
module.print_text("hello", end="\n")
mock_print.assert_called_once_with("hello", end="\n", file=None)
# -----------------------------
# DifyAgentCallbackHandler Tests
# -----------------------------
class TestDifyAgentCallbackHandler:
def test_init_default_color(self):
handler = module.DifyAgentCallbackHandler()
assert handler.color == "green"
assert handler.current_loop == 1
def test_on_tool_start_debug_enabled(self, handler, enable_debug, mocker):
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
handler.on_tool_start("tool1", {"a": 1})
mock_print_text.assert_called()
def test_on_tool_start_debug_disabled(self, handler, disable_debug, mocker):
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
handler.on_tool_start("tool1", {"a": 1})
mock_print_text.assert_not_called()
def test_on_tool_end_debug_enabled_and_trace(self, handler, enable_debug, mocker):
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
mock_trace_manager = MagicMock()
handler.on_tool_end(
tool_name="tool1",
tool_inputs={"a": 1},
tool_outputs="output",
message_id="msg1",
timer=123,
trace_manager=mock_trace_manager,
)
assert mock_print_text.call_count >= 1
mock_trace_manager.add_trace_task.assert_called_once()
def test_on_tool_end_without_trace_manager(self, handler, enable_debug, mocker):
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
handler.on_tool_end(
tool_name="tool1",
tool_inputs={},
tool_outputs="output",
)
assert mock_print_text.call_count >= 1
def test_on_tool_error_debug_enabled(self, handler, enable_debug, mocker):
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
handler.on_tool_error(Exception("error"))
mock_print_text.assert_called_once()
def test_on_tool_error_debug_disabled(self, handler, disable_debug, mocker):
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
handler.on_tool_error(Exception("error"))
mock_print_text.assert_not_called()
@pytest.mark.parametrize("thought", ["thinking", ""])
def test_on_agent_start(self, handler, enable_debug, mocker, thought):
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
handler.on_agent_start(thought)
mock_print_text.assert_called()
def test_on_agent_finish_increments_loop(self, handler, enable_debug, mocker):
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
current_loop = handler.current_loop
handler.on_agent_finish()
assert handler.current_loop == current_loop + 1
mock_print_text.assert_called()
def test_on_datasource_start_debug_enabled(self, handler, enable_debug, mocker):
mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text")
handler.on_datasource_start("ds1", {"x": 1})
mock_print_text.assert_called_once()
def test_ignore_agent_property(self, disable_debug, handler):
assert handler.ignore_agent is True
def test_ignore_chat_model_property(self, disable_debug, handler):
assert handler.ignore_chat_model is True
def test_ignore_properties_when_debug_enabled(self, enable_debug, handler):
assert handler.ignore_agent is False
assert handler.ignore_chat_model is False

View File

@ -0,0 +1,162 @@
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.index_tool_callback_handler import (
DatasetIndexToolCallbackHandler,
)
@pytest.fixture
def mock_queue_manager(mocker):
return mocker.Mock()
@pytest.fixture
def handler(mock_queue_manager, mocker):
mocker.patch(
"core.callback_handler.index_tool_callback_handler.db",
)
return DatasetIndexToolCallbackHandler(
queue_manager=mock_queue_manager,
app_id="app-1",
message_id="msg-1",
user_id="user-1",
invoke_from=mocker.Mock(),
)
class TestOnQuery:
@pytest.mark.parametrize(
("invoke_from", "expected_role"),
[
(InvokeFrom.EXPLORE, "account"),
(InvokeFrom.DEBUGGER, "account"),
(InvokeFrom.WEB_APP, "end_user"),
],
)
def test_on_query_success_roles(self, mocker, mock_queue_manager, invoke_from, expected_role):
# Arrange
mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
handler = DatasetIndexToolCallbackHandler(
queue_manager=mock_queue_manager,
app_id="app-1",
message_id="msg-1",
user_id="user-1",
invoke_from=mocker.Mock(),
)
handler._invoke_from = invoke_from
# Act
handler.on_query("test query", "dataset-1")
# Assert
mock_db.session.add.assert_called_once()
dataset_query = mock_db.session.add.call_args.args[0]
assert dataset_query.created_by_role == expected_role
mock_db.session.commit.assert_called_once()
def test_on_query_none_values(self, mocker, mock_queue_manager):
mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
handler = DatasetIndexToolCallbackHandler(
queue_manager=mock_queue_manager,
app_id=None,
message_id=None,
user_id=None,
invoke_from=None,
)
handler.on_query(None, None)
mock_db.session.add.assert_called_once()
mock_db.session.commit.assert_called_once()
class TestOnToolEnd:
def test_on_tool_end_no_metadata(self, handler, mocker):
mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
document = mocker.Mock()
document.metadata = None
handler.on_tool_end([document])
mock_db.session.commit.assert_not_called()
def test_on_tool_end_dataset_document_not_found(self, handler, mocker):
mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
mock_db.session.scalar.return_value = None
document = mocker.Mock()
document.metadata = {"document_id": "doc-1", "doc_id": "node-1"}
handler.on_tool_end([document])
mock_db.session.scalar.assert_called_once()
def test_on_tool_end_parent_child_index_with_child(self, handler, mocker):
mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
mock_dataset_doc = mocker.Mock()
from core.callback_handler.index_tool_callback_handler import IndexStructureType
mock_dataset_doc.doc_form = IndexStructureType.PARENT_CHILD_INDEX
mock_dataset_doc.dataset_id = "dataset-1"
mock_dataset_doc.id = "doc-1"
mock_child_chunk = mocker.Mock()
mock_child_chunk.segment_id = "segment-1"
mock_db.session.scalar.side_effect = [mock_dataset_doc, mock_child_chunk]
document = mocker.Mock()
document.metadata = {"document_id": "doc-1", "doc_id": "node-1"}
mock_query = mocker.Mock()
mock_db.session.query.return_value = mock_query
mock_query.where.return_value = mock_query
handler.on_tool_end([document])
mock_query.update.assert_called_once()
mock_db.session.commit.assert_called_once()
def test_on_tool_end_non_parent_child_index(self, handler, mocker):
mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db")
mock_dataset_doc = mocker.Mock()
mock_dataset_doc.doc_form = "OTHER"
mock_db.session.scalar.return_value = mock_dataset_doc
document = mocker.Mock()
document.metadata = {
"document_id": "doc-1",
"doc_id": "node-1",
"dataset_id": "dataset-1",
}
mock_query = mocker.Mock()
mock_db.session.query.return_value = mock_query
mock_query.where.return_value = mock_query
handler.on_tool_end([document])
mock_query.update.assert_called_once()
mock_db.session.commit.assert_called_once()
def test_on_tool_end_empty_documents(self, handler):
handler.on_tool_end([])
class TestReturnRetrieverResourceInfo:
def test_publish_called(self, handler, mock_queue_manager, mocker):
mock_event = mocker.patch("core.callback_handler.index_tool_callback_handler.QueueRetrieverResourcesEvent")
resources = [mocker.Mock()]
handler.return_retriever_resource_info(resources)
mock_queue_manager.publish.assert_called_once()

View File

@ -0,0 +1,184 @@
from unittest.mock import MagicMock, call
import pytest
from core.callback_handler.workflow_tool_callback_handler import (
DifyWorkflowCallbackHandler,
)
class DummyToolInvokeMessage:
"""Lightweight dummy to simulate ToolInvokeMessage behavior."""
def __init__(self, json_value: str):
self._json_value = json_value
def model_dump_json(self):
return self._json_value
@pytest.fixture
def handler():
"""Fixture to create handler instance with deterministic color."""
instance = DifyWorkflowCallbackHandler()
instance.color = "blue"
return instance
@pytest.fixture
def mock_print_text(mocker):
"""Mock print_text to avoid real stdout printing."""
return mocker.patch("core.callback_handler.workflow_tool_callback_handler.print_text")
class TestDifyWorkflowCallbackHandler:
def test_on_tool_execution_single_output_success(self, handler, mock_print_text):
# Arrange
tool_name = "test_tool"
tool_inputs = {"a": 1}
message = DummyToolInvokeMessage('{"key": "value"}')
# Act
results = list(
handler.on_tool_execution(
tool_name=tool_name,
tool_inputs=tool_inputs,
tool_outputs=[message],
)
)
# Assert
assert results == [message]
assert mock_print_text.call_count == 4
mock_print_text.assert_has_calls(
[
call("\n[on_tool_execution]\n", color="blue"),
call("Tool: test_tool\n", color="blue"),
call(
"Outputs: " + message.model_dump_json()[:1000] + "\n",
color="blue",
),
call("\n"),
]
)
def test_on_tool_execution_multiple_outputs(self, handler, mock_print_text):
# Arrange
tool_name = "multi_tool"
outputs = [
DummyToolInvokeMessage('{"id": 1}'),
DummyToolInvokeMessage('{"id": 2}'),
]
# Act
results = list(
handler.on_tool_execution(
tool_name=tool_name,
tool_inputs={},
tool_outputs=outputs,
)
)
# Assert
assert results == outputs
assert mock_print_text.call_count == 4 * len(outputs)
def test_on_tool_execution_empty_iterable(self, handler, mock_print_text):
# Arrange
tool_name = "empty_tool"
# Act
results = list(
handler.on_tool_execution(
tool_name=tool_name,
tool_inputs={},
tool_outputs=[],
)
)
# Assert
assert results == []
mock_print_text.assert_not_called()
@pytest.mark.parametrize(
("invalid_outputs", "expected_exception"),
[
(None, TypeError),
(123, TypeError),
("not_iterable", AttributeError),
],
)
def test_on_tool_execution_invalid_outputs_type(self, handler, invalid_outputs, expected_exception):
# Arrange
tool_name = "invalid_tool"
# Act & Assert
with pytest.raises(expected_exception):
list(
handler.on_tool_execution(
tool_name=tool_name,
tool_inputs={},
tool_outputs=invalid_outputs,
)
)
def test_on_tool_execution_long_json_truncation(self, handler, mock_print_text):
# Arrange
tool_name = "long_json_tool"
long_json = "x" * 1500
message = DummyToolInvokeMessage(long_json)
# Act
list(
handler.on_tool_execution(
tool_name=tool_name,
tool_inputs={},
tool_outputs=[message],
)
)
# Assert
expected_truncated = long_json[:1000]
mock_print_text.assert_any_call(
"Outputs: " + expected_truncated + "\n",
color="blue",
)
def test_on_tool_execution_model_dump_json_exception(self, handler, mock_print_text):
# Arrange
tool_name = "exception_tool"
bad_message = MagicMock()
bad_message.model_dump_json.side_effect = ValueError("JSON error")
# Act & Assert
with pytest.raises(ValueError):
list(
handler.on_tool_execution(
tool_name=tool_name,
tool_inputs={},
tool_outputs=[bad_message],
)
)
# Ensure first two prints happened before failure
assert mock_print_text.call_count >= 2
def test_on_tool_execution_none_message_id_and_trace_manager(self, handler, mock_print_text):
# Arrange
tool_name = "optional_params_tool"
message = DummyToolInvokeMessage('{"data": "ok"}')
# Act
results = list(
handler.on_tool_execution(
tool_name=tool_name,
tool_inputs={},
tool_outputs=[message],
message_id=None,
timer=None,
trace_manager=None,
)
)
assert results == [message]
assert mock_print_text.call_count == 4

View File

@ -0,0 +1,9 @@
from core.entities.agent_entities import PlanningStrategy
def test_planning_strategy_values_are_stable() -> None:
# Arrange / Act / Assert
assert PlanningStrategy.ROUTER.value == "router"
assert PlanningStrategy.REACT_ROUTER.value == "react_router"
assert PlanningStrategy.REACT.value == "react"
assert PlanningStrategy.FUNCTION_CALL.value == "function_call"

View File

@ -0,0 +1,18 @@
from core.entities.document_task import DocumentTask
def test_document_task_keeps_indexing_identifiers() -> None:
# Arrange
document_ids = ("doc-1", "doc-2")
# Act
task = DocumentTask(
tenant_id="tenant-1",
dataset_id="dataset-1",
document_ids=document_ids,
)
# Assert
assert task.tenant_id == "tenant-1"
assert task.dataset_id == "dataset-1"
assert task.document_ids == document_ids

View File

@ -0,0 +1,7 @@
from core.entities.embedding_type import EmbeddingInputType
def test_embedding_input_type_values_are_stable() -> None:
# Arrange / Act / Assert
assert EmbeddingInputType.DOCUMENT.value == "document"
assert EmbeddingInputType.QUERY.value == "query"

View File

@ -0,0 +1,45 @@
from core.entities.execution_extra_content import (
ExecutionExtraContentDomainModel,
HumanInputContent,
HumanInputFormDefinition,
HumanInputFormSubmissionData,
)
from dify_graph.nodes.human_input.entities import FormInput, UserAction
from dify_graph.nodes.human_input.enums import FormInputType
from models.execution_extra_content import ExecutionContentType
def test_human_input_content_defaults_and_domain_alias() -> None:
# Arrange
form_definition = HumanInputFormDefinition(
form_id="form-1",
node_id="node-1",
node_title="Human Input",
form_content="Please confirm",
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="answer")],
actions=[UserAction(id="confirm", title="Confirm")],
resolved_default_values={"answer": "yes"},
expiration_time=1_700_000_000,
)
submission_data = HumanInputFormSubmissionData(
node_id="node-1",
node_title="Human Input",
rendered_content="Please confirm",
action_id="confirm",
action_text="Confirm",
)
# Act
content = HumanInputContent(
workflow_run_id="workflow-run-1",
submitted=True,
form_definition=form_definition,
form_submission_data=submission_data,
)
# Assert
assert form_definition.model_config.get("frozen") is True
assert content.type == ExecutionContentType.HUMAN_INPUT
assert content.form_definition is form_definition
assert content.form_submission_data is submission_data
assert ExecutionExtraContentDomainModel is HumanInputContent

View File

@ -0,0 +1,45 @@
from core.entities.knowledge_entities import (
PipelineDataset,
PipelineDocument,
PipelineGenerateResponse,
)
def test_pipeline_dataset_normalizes_none_description() -> None:
# Arrange / Act
dataset = PipelineDataset(
id="dataset-1",
name="Dataset",
description=None,
chunk_structure="parent-child",
)
# Assert
assert dataset.description == ""
def test_pipeline_generate_response_builds_nested_models() -> None:
# Arrange
dataset = PipelineDataset(
id="dataset-1",
name="Dataset",
description="Knowledge base",
chunk_structure="parent-child",
)
document = PipelineDocument(
id="doc-1",
position=1,
data_source_type="file",
data_source_info={"name": "spec.pdf"},
name="spec.pdf",
indexing_status="completed",
enabled=True,
)
# Act
response = PipelineGenerateResponse(batch="batch-1", dataset=dataset, documents=[document])
# Assert
assert response.batch == "batch-1"
assert response.dataset.id == "dataset-1"
assert response.documents[0].id == "doc-1"

View File

@ -0,0 +1,450 @@
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from core.entities import mcp_provider as mcp_provider_module
from core.entities.mcp_provider import (
DEFAULT_EXPIRES_IN,
DEFAULT_TOKEN_TYPE,
MCPProviderEntity,
)
from core.mcp.types import OAuthTokens
def _build_mcp_provider_entity() -> MCPProviderEntity:
now = datetime(2025, 1, 1, tzinfo=UTC)
return MCPProviderEntity(
id="provider-1",
provider_id="server-1",
name="Example MCP",
tenant_id="tenant-1",
user_id="user-1",
server_url="encrypted-server-url",
headers={},
timeout=30,
sse_read_timeout=300,
authed=False,
credentials={},
tools=[],
icon={"en_US": "icon.png"},
created_at=now,
updated_at=now,
)
def test_from_db_model_maps_fields() -> None:
# Arrange
now = datetime(2025, 1, 1, tzinfo=UTC)
db_provider = SimpleNamespace(
id="provider-1",
server_identifier="server-1",
name="Example MCP",
tenant_id="tenant-1",
user_id="user-1",
server_url="encrypted-server-url",
headers={"Authorization": "enc"},
timeout=15,
sse_read_timeout=120,
authed=True,
credentials={"access_token": "enc-token"},
tool_dict=[{"name": "search"}],
icon=None,
created_at=now,
updated_at=now,
)
# Act
entity = MCPProviderEntity.from_db_model(db_provider)
# Assert
assert entity.provider_id == "server-1"
assert entity.tools == [{"name": "search"}]
assert entity.icon == ""
def test_redirect_url_uses_console_api_url(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
entity = _build_mcp_provider_entity()
monkeypatch.setattr(mcp_provider_module.dify_config, "CONSOLE_API_URL", "https://console.example.com")
# Act
redirect_url = entity.redirect_url
# Assert
assert redirect_url == "https://console.example.com/console/api/mcp/oauth/callback"
def test_client_metadata_for_authorization_code_flow() -> None:
# Arrange
entity = _build_mcp_provider_entity()
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}):
# Act
metadata = entity.client_metadata
# Assert
assert metadata.grant_types == ["refresh_token", "authorization_code"]
assert metadata.redirect_uris == [entity.redirect_url]
assert metadata.response_types == ["code"]
def test_client_metadata_for_client_credentials_flow() -> None:
# Arrange
entity = _build_mcp_provider_entity()
credentials = {"client_information": {"grant_types": ["client_credentials"]}}
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials):
# Act
metadata = entity.client_metadata
# Assert
assert metadata.grant_types == ["refresh_token", "client_credentials"]
assert metadata.redirect_uris == []
assert metadata.response_types == []
def test_client_metadata_prefers_nested_authorization_code_grant_type() -> None:
# Arrange
entity = _build_mcp_provider_entity()
credentials = {
"grant_type": "client_credentials",
"client_information": {"grant_types": ["authorization_code"]},
}
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials):
# Act
metadata = entity.client_metadata
# Assert
assert metadata.grant_types == ["refresh_token", "authorization_code"]
assert metadata.redirect_uris == [entity.redirect_url]
assert metadata.response_types == ["code"]
def test_provider_icon_returns_icon_dict_as_is() -> None:
# Arrange
entity = _build_mcp_provider_entity().model_copy(update={"icon": {"en_US": "icon.png"}})
# Act
icon = entity.provider_icon
# Assert
assert icon == {"en_US": "icon.png"}
def test_provider_icon_uses_signed_url_for_plain_path() -> None:
# Arrange
entity = _build_mcp_provider_entity().model_copy(update={"icon": "icons/mcp.png"})
with patch(
"core.entities.mcp_provider.file_helpers.get_signed_file_url",
return_value="https://signed.example.com/icons/mcp.png",
) as mock_get_signed_url:
# Act
icon = entity.provider_icon
# Assert
mock_get_signed_url.assert_called_once_with("icons/mcp.png")
assert icon == "https://signed.example.com/icons/mcp.png"
def test_to_api_response_without_sensitive_data_skips_auth_related_work() -> None:
# Arrange
entity = _build_mcp_provider_entity().model_copy(update={"icon": {"en_US": "icon.png"}})
with patch.object(MCPProviderEntity, "masked_server_url", return_value="https://api.example.com/******"):
# Act
response = entity.to_api_response(include_sensitive=False)
# Assert
assert response["author"] == "Anonymous"
assert response["masked_headers"] == {}
assert response["is_dynamic_registration"] is True
assert "authentication" not in response
def test_to_api_response_with_sensitive_data_includes_masked_values() -> None:
# Arrange
entity = _build_mcp_provider_entity().model_copy(
update={
"credentials": {"client_information": {"is_dynamic_registration": False}},
"icon": {"en_US": "icon.png"},
}
)
with patch.object(MCPProviderEntity, "masked_server_url", return_value="https://api.example.com/******"):
with patch.object(MCPProviderEntity, "masked_headers", return_value={"Authorization": "Be****"}):
with patch.object(MCPProviderEntity, "masked_credentials", return_value={"client_id": "cl****"}):
# Act
response = entity.to_api_response(user_name="Rajat", include_sensitive=True)
# Assert
assert response["author"] == "Rajat"
assert response["masked_headers"] == {"Authorization": "Be****"}
assert response["authentication"] == {"client_id": "cl****"}
assert response["is_dynamic_registration"] is False
def test_retrieve_client_information_decrypts_nested_secret() -> None:
# Arrange
entity = _build_mcp_provider_entity()
credentials = {"client_information": {"client_id": "client-1", "encrypted_client_secret": "enc-secret"}}
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials):
with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="plain-secret") as mock_decrypt:
# Act
client_info = entity.retrieve_client_information()
# Assert
assert client_info is not None
assert client_info.client_id == "client-1"
assert client_info.client_secret == "plain-secret"
mock_decrypt.assert_called_once_with("tenant-1", "enc-secret")
def test_retrieve_client_information_returns_none_for_missing_data() -> None:
# Arrange
entity = _build_mcp_provider_entity()
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}):
# Act
result_empty = entity.retrieve_client_information()
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"client_information": "invalid"}):
# Act
result_invalid = entity.retrieve_client_information()
# Assert
assert result_empty is None
assert result_invalid is None
def test_masked_server_url_hides_path_segments() -> None:
# Arrange
entity = _build_mcp_provider_entity()
with patch.object(
MCPProviderEntity,
"decrypt_server_url",
return_value="https://api.example.com/v1/mcp?query=1",
):
# Act
masked_url = entity.masked_server_url()
# Assert
assert masked_url == "https://api.example.com/******?query=1"
def test_mask_value_covers_short_and_long_values() -> None:
# Arrange
entity = _build_mcp_provider_entity()
# Act
short_masked = entity._mask_value("short")
long_masked = entity._mask_value("abcdefghijkl")
# Assert
assert short_masked == "*****"
assert long_masked == "ab********kl"
def test_masked_headers_masks_all_decrypted_header_values() -> None:
# Arrange
entity = _build_mcp_provider_entity()
with patch.object(MCPProviderEntity, "decrypt_headers", return_value={"Authorization": "abcdefgh"}):
# Act
masked = entity.masked_headers()
# Assert
assert masked == {"Authorization": "ab****gh"}
def test_masked_credentials_handles_nested_secret_fields() -> None:
# Arrange
entity = _build_mcp_provider_entity()
credentials = {
"client_information": {
"client_id": "client-id",
"encrypted_client_secret": "encrypted-value",
"client_secret": "plain-secret",
}
}
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials):
with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="decrypted-secret"):
# Act
masked = entity.masked_credentials()
# Assert
assert masked["client_id"] == "cl*****id"
assert masked["client_secret"] == "pl********et"
def test_masked_credentials_returns_empty_for_missing_client_information() -> None:
# Arrange
entity = _build_mcp_provider_entity()
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}):
# Act
masked_empty = entity.masked_credentials()
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"client_information": "invalid"}):
# Act
masked_invalid = entity.masked_credentials()
# Assert
assert masked_empty == {}
assert masked_invalid == {}
def test_retrieve_tokens_returns_defaults_when_optional_fields_missing() -> None:
# Arrange
entity = _build_mcp_provider_entity().model_copy(update={"credentials": {"token": "encrypted"}})
with patch.object(
MCPProviderEntity,
"decrypt_credentials",
return_value={"access_token": "token", "expires_in": "", "refresh_token": "refresh"},
):
# Act
tokens = entity.retrieve_tokens()
# Assert
assert isinstance(tokens, OAuthTokens)
assert tokens.access_token == "token"
assert tokens.token_type == DEFAULT_TOKEN_TYPE
assert tokens.expires_in == DEFAULT_EXPIRES_IN
assert tokens.refresh_token == "refresh"
def test_retrieve_tokens_returns_none_when_access_token_missing() -> None:
# Arrange
entity = _build_mcp_provider_entity().model_copy(update={"credentials": {"token": "encrypted"}})
with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"access_token": ""}) as mock_decrypt:
# Act
tokens = entity.retrieve_tokens()
# Assert
mock_decrypt.assert_called_once()
assert tokens is None
def test_decrypt_server_url_delegates_to_encrypter() -> None:
# Arrange
entity = _build_mcp_provider_entity()
with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="https://api.example.com") as mock:
# Act
decrypted = entity.decrypt_server_url()
# Assert
mock.assert_called_once_with("tenant-1", "encrypted-server-url")
assert decrypted == "https://api.example.com"
def test_decrypt_authentication_injects_authorization_for_oauth() -> None:
# Arrange
entity = _build_mcp_provider_entity().model_copy(update={"authed": True, "headers": {}})
with patch.object(MCPProviderEntity, "decrypt_headers", return_value={}):
with patch.object(
MCPProviderEntity,
"retrieve_tokens",
return_value=OAuthTokens(access_token="abc123", token_type="bearer"),
):
# Act
headers = entity.decrypt_authentication()
# Assert
assert headers["Authorization"] == "Bearer abc123"
def test_decrypt_authentication_does_not_overwrite_existing_headers() -> None:
# Arrange
entity = _build_mcp_provider_entity().model_copy(
update={"authed": True, "headers": {"Authorization": "encrypted-header"}}
)
with patch.object(MCPProviderEntity, "decrypt_headers", return_value={"Authorization": "existing"}):
with patch.object(
MCPProviderEntity,
"retrieve_tokens",
return_value=OAuthTokens(access_token="abc", token_type="bearer"),
) as mock_tokens:
# Act
headers = entity.decrypt_authentication()
# Assert
mock_tokens.assert_not_called()
assert headers == {"Authorization": "existing"}
def test_decrypt_dict_returns_empty_for_empty_input() -> None:
# Arrange
entity = _build_mcp_provider_entity()
# Act
decrypted = entity._decrypt_dict({})
# Assert
assert decrypted == {}
def test_decrypt_dict_returns_original_data_when_no_encrypted_fields() -> None:
# Arrange
entity = _build_mcp_provider_entity()
input_data = {"nested": {"k": "v"}, "count": 2, "empty": ""}
# Act
result = entity._decrypt_dict(input_data)
# Assert
assert result is input_data
def test_decrypt_dict_only_decrypts_top_level_string_values() -> None:
# Arrange
entity = _build_mcp_provider_entity()
decryptor = Mock()
decryptor.decrypt.return_value = {"api_key": "plain-key"}
def _fake_create_provider_encrypter(*, tenant_id: str, config: list, cache):
assert tenant_id == "tenant-1"
assert any(item.name == "api_key" for item in config)
return decryptor, None
with patch("core.tools.utils.encryption.create_provider_encrypter", side_effect=_fake_create_provider_encrypter):
# Act
result = entity._decrypt_dict(
{
"api_key": "encrypted-key",
"nested": {"client_id": "unchanged"},
"empty": "",
"count": 2,
}
)
# Assert
decryptor.decrypt.assert_called_once_with({"api_key": "encrypted-key"})
assert result["api_key"] == "plain-key"
assert result["nested"] == {"client_id": "unchanged"}
assert result["count"] == 2
def test_decrypt_headers_and_credentials_delegate_to_decrypt_dict() -> None:
# Arrange
entity = _build_mcp_provider_entity()
with patch.object(MCPProviderEntity, "_decrypt_dict", side_effect=[{"h": "v"}, {"c": "v"}]) as mock:
# Act
headers = entity.decrypt_headers()
credentials = entity.decrypt_credentials()
# Assert
assert mock.call_count == 2
assert headers == {"h": "v"}
assert credentials == {"c": "v"}

View File

@ -0,0 +1,92 @@
"""Unit tests for model entity behavior and invariants.
Covers DefaultModelEntity, DefaultModelProviderEntity, ModelStatus,
ProviderModelWithStatusEntity, and SimpleModelProviderEntity. Assumes i18n
labels are provided via I18nObject, model metadata aligns with FetchFrom and
ModelType expectations, and ProviderEntity/ConfigurateMethod interactions
drive provider mapping behavior.
"""
import pytest
from core.entities.model_entities import (
DefaultModelEntity,
DefaultModelProviderEntity,
ModelStatus,
ProviderModelWithStatusEntity,
SimpleModelProviderEntity,
)
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType
from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
def _build_model_with_status(status: ModelStatus) -> ProviderModelWithStatusEntity:
return ProviderModelWithStatusEntity(
model="gpt-4",
label=I18nObject(en_US="GPT-4"),
model_type=ModelType.LLM,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={},
status=status,
)
def test_simple_model_provider_entity_maps_from_provider_entity() -> None:
# Arrange
provider_entity = ProviderEntity(
provider="openai",
label=I18nObject(en_US="OpenAI"),
supported_model_types=[ModelType.LLM],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
)
# Act
simple_provider = SimpleModelProviderEntity(provider_entity)
# Assert
assert simple_provider.provider == "openai"
assert simple_provider.label.en_US == "OpenAI"
assert simple_provider.supported_model_types == [ModelType.LLM]
def test_provider_model_with_status_raises_for_known_error_statuses() -> None:
# Arrange
expectations = {
ModelStatus.NO_CONFIGURE: "Model is not configured",
ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded",
ModelStatus.NO_PERMISSION: "No permission to use this model",
ModelStatus.DISABLED: "Model is disabled",
}
for status, message in expectations.items():
# Act / Assert
with pytest.raises(ValueError, match=message):
_build_model_with_status(status).raise_for_status()
def test_provider_model_with_status_allows_active_and_credential_removed() -> None:
# Arrange
active_model = _build_model_with_status(ModelStatus.ACTIVE)
removed_model = _build_model_with_status(ModelStatus.CREDENTIAL_REMOVED)
# Act / Assert
active_model.raise_for_status()
removed_model.raise_for_status()
def test_default_model_entity_accepts_model_field_name() -> None:
# Arrange / Act
default_model = DefaultModelEntity(
model="gpt-4o-mini",
model_type=ModelType.LLM,
provider=DefaultModelProviderEntity(
provider="openai",
label=I18nObject(en_US="OpenAI"),
supported_model_types=[ModelType.LLM],
),
)
# Assert
assert default_model.model == "gpt-4o-mini"
assert default_model.provider.provider == "openai"

View File

@ -0,0 +1,22 @@
from core.entities.parameter_entities import (
AppSelectorScope,
CommonParameterType,
ModelSelectorScope,
ToolSelectorScope,
)
def test_common_parameter_type_values_are_stable() -> None:
# Arrange / Act / Assert
assert CommonParameterType.SECRET_INPUT.value == "secret-input"
assert CommonParameterType.MODEL_SELECTOR.value == "model-selector"
assert CommonParameterType.DYNAMIC_SELECT.value == "dynamic-select"
assert CommonParameterType.ARRAY.value == "array"
assert CommonParameterType.OBJECT.value == "object"
def test_selector_scope_values_are_stable() -> None:
# Arrange / Act / Assert
assert AppSelectorScope.WORKFLOW.value == "workflow"
assert ModelSelectorScope.TEXT_EMBEDDING.value == "text-embedding"
assert ToolSelectorScope.BUILTIN.value == "builtin"

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,72 @@
import pytest
from core.entities.parameter_entities import AppSelectorScope
from core.entities.provider_entities import (
BasicProviderConfig,
ModelSettings,
ProviderConfig,
ProviderQuotaType,
)
from core.tools.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import ModelType
def test_provider_quota_type_value_of_returns_enum_member() -> None:
# Arrange / Act
quota_type = ProviderQuotaType.value_of(ProviderQuotaType.TRIAL.value)
# Assert
assert quota_type == ProviderQuotaType.TRIAL
def test_provider_quota_type_value_of_rejects_unknown_values() -> None:
# Arrange / Act / Assert
with pytest.raises(ValueError, match="No matching enum found"):
ProviderQuotaType.value_of("enterprise")
def test_basic_provider_config_type_value_of_handles_known_values() -> None:
# Arrange / Act
parameter_type = BasicProviderConfig.Type.value_of("text-input")
# Assert
assert parameter_type == BasicProviderConfig.Type.TEXT_INPUT
def test_basic_provider_config_type_value_of_rejects_invalid_values() -> None:
# Arrange / Act / Assert
with pytest.raises(ValueError, match="invalid mode value"):
BasicProviderConfig.Type.value_of("unknown")
def test_provider_config_to_basic_provider_config_keeps_type_and_name() -> None:
# Arrange
provider_config = ProviderConfig(
type=BasicProviderConfig.Type.SELECT,
name="workspace",
scope=AppSelectorScope.ALL,
options=[ProviderConfig.Option(value="all", label=I18nObject(en_US="All"))],
)
# Act
basic_config = provider_config.to_basic_provider_config()
# Assert
assert isinstance(basic_config, BasicProviderConfig)
assert basic_config.type == BasicProviderConfig.Type.SELECT
assert basic_config.name == "workspace"
def test_model_settings_accepts_model_field_name() -> None:
# Arrange / Act
settings = ModelSettings(
model="gpt-4o",
model_type=ModelType.LLM,
enabled=True,
load_balancing_enabled=False,
load_balancing_configs=[],
)
# Assert
assert settings.model == "gpt-4o"
assert settings.model_type == ModelType.LLM