diff --git a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py new file mode 100644 index 0000000000..3759b6aa37 --- /dev/null +++ b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py @@ -0,0 +1,390 @@ +import base64 +import queue +from unittest.mock import MagicMock + +import pytest + +from core.base.tts.app_generator_tts_publisher import ( + AppGeneratorTTSPublisher, + AudioTrunk, + _invoice_tts, + _process_future, +) + +# ========================= +# Fixtures +# ========================= + + +@pytest.fixture +def mock_model_instance(mocker): + model = mocker.MagicMock() + model.invoke_tts.return_value = [b"audio1", b"audio2"] + model.get_tts_voices.return_value = [{"value": "voice1"}, {"value": "voice2"}] + return model + + +@pytest.fixture +def mock_model_manager(mocker, mock_model_instance): + manager = mocker.MagicMock() + manager.get_default_model_instance.return_value = mock_model_instance + mocker.patch( + "core.base.tts.app_generator_tts_publisher.ModelManager", + return_value=manager, + ) + return manager + + +@pytest.fixture(autouse=True) +def patch_threads(mocker): + """Prevent real threads from starting during tests""" + mocker.patch("threading.Thread.start", return_value=None) + + +# ========================= +# AudioTrunk Tests +# ========================= + + +class TestAudioTrunk: + def test_audio_trunk_initialization(self): + trunk = AudioTrunk("responding", b"data") + assert trunk.status == "responding" + assert trunk.audio == b"data" + + +# ========================= +# _invoice_tts Tests +# ========================= + + +class TestInvoiceTTS: + @pytest.mark.parametrize( + "text", + [None, "", " "], + ) + def test_invoice_tts_empty_or_none_returns_none(self, text, mock_model_instance): + result = _invoice_tts(text, mock_model_instance, "tenant", "voice1") + assert result is None + mock_model_instance.invoke_tts.assert_not_called() + + def test_invoice_tts_valid_text(self, mock_model_instance): + result = _invoice_tts(" hello ", mock_model_instance, "tenant", "voice1") + mock_model_instance.invoke_tts.assert_called_once_with( + content_text="hello", + user="responding_tts", + tenant_id="tenant", + voice="voice1", + ) + assert result == [b"audio1", b"audio2"] + + +# ========================= +# _process_future Tests +# ========================= + + +class TestProcessFuture: + def test_process_future_normal_flow(self): + future_queue = queue.Queue() + audio_queue = queue.Queue() + + future = MagicMock() + future.result.return_value = [b"abc"] + + future_queue.put(future) + future_queue.put(None) + + _process_future(future_queue, audio_queue) + + first = audio_queue.get() + assert first.status == "responding" + assert first.audio == base64.b64encode(b"abc") + + finish = audio_queue.get() + assert finish.status == "finish" + + def test_process_future_empty_result(self): + future_queue = queue.Queue() + audio_queue = queue.Queue() + + future = MagicMock() + future.result.return_value = None + + future_queue.put(future) + future_queue.put(None) + + _process_future(future_queue, audio_queue) + + finish = audio_queue.get() + assert finish.status == "finish" + + def test_process_future_exception(self, mocker): + future_queue = queue.Queue() + audio_queue = queue.Queue() + + future = MagicMock() + future.result.side_effect = Exception("error") + + future_queue.put(future) + + _process_future(future_queue, audio_queue) + + finish = audio_queue.get() + assert finish.status == "finish" + + +# ========================= +# AppGeneratorTTSPublisher Tests +# ========================= + + +class TestAppGeneratorTTSPublisher: + def test_initialization_valid_voice(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + assert publisher.voice == "voice1" + assert publisher.max_sentence == 2 + assert publisher.msg_text == "" + + def test_initialization_invalid_voice_fallback(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "invalid_voice") + assert publisher.voice == "voice1" + + def test_publish_puts_message_in_queue(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + message = MagicMock() + publisher.publish(message) + assert publisher._msg_queue.get() == message + + def test_check_and_get_audio_no_audio(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + result = publisher.check_and_get_audio() + assert result is None + + def test_check_and_get_audio_non_finish_event(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + trunk = AudioTrunk("responding", b"abc") + publisher._audio_queue.put(trunk) + + result = publisher.check_and_get_audio() + + assert result.status == "responding" + assert publisher._last_audio_event == trunk + + def test_check_and_get_audio_finish_event(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + finish_trunk = AudioTrunk("finish", b"") + publisher._audio_queue.put(finish_trunk) + + result = publisher.check_and_get_audio() + + assert result.status == "finish" + publisher.executor.shutdown.assert_called_once() + + def test_check_and_get_audio_cached_finish(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + publisher._last_audio_event = AudioTrunk("finish", b"") + + result = publisher.check_and_get_audio() + + assert result.status == "finish" + publisher.executor.shutdown.assert_called_once() + + @pytest.mark.parametrize( + ("text", "expected_sentences", "expected_remaining"), + [ + ("Hello world.", ["Hello world."], ""), + ("Hello world! How are you?", ["Hello world!", " How are you?"], ""), + ("No punctuation", [], "No punctuation"), + ("", [], ""), + ], + ) + def test_extract_sentence(self, mock_model_manager, text, expected_sentences, expected_remaining): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + sentences, remaining = publisher._extract_sentence(text) + assert sentences == expected_sentences + assert remaining == expected_remaining + + def test_runtime_handles_none_message_with_buffer(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + publisher.msg_text = "Hello." + + publisher._msg_queue.put(None) + publisher._runtime() + + publisher.executor.submit.assert_called_once() + + def test_runtime_handles_none_message_without_buffer(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + publisher.msg_text = " " + + publisher._msg_queue.put(None) + publisher._runtime() + + publisher.executor.submit.assert_not_called() + + def test_runtime_sentence_threshold_triggers_submit(self, mock_model_manager, mocker): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + # Force sentence extraction to hit threshold condition + mocker.patch.object( + publisher, + "_extract_sentence", + return_value=(["Hello world.", " Second sentence."], ""), + ) + + from core.app.entities.queue_entities import QueueTextChunkEvent + + event = MagicMock() + event.event = MagicMock(spec=QueueTextChunkEvent) + event.event.text = "Hello world. Second sentence." + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + assert publisher.executor.submit.called + + def test_runtime_handles_text_chunk_event(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + from core.app.entities.queue_entities import QueueTextChunkEvent + + event = MagicMock() + event.event = MagicMock(spec=QueueTextChunkEvent) + event.event.text = "Hello world." + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + assert publisher.executor.submit.called + + def test_runtime_handles_node_succeeded_event_with_output(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + from core.app.entities.queue_entities import QueueNodeSucceededEvent + + event = MagicMock() + event.event = MagicMock(spec=QueueNodeSucceededEvent) + event.event.outputs = {"output": "Hello world."} + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + assert publisher.executor.submit.called + + def test_runtime_handles_node_succeeded_event_without_output(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + from core.app.entities.queue_entities import QueueNodeSucceededEvent + + event = MagicMock() + event.event = MagicMock(spec=QueueNodeSucceededEvent) + event.event.outputs = None + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + publisher.executor.submit.assert_not_called() + + def test_runtime_handles_agent_message_event_list_content(self, mock_model_manager, mocker): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + from core.app.entities.queue_entities import QueueAgentMessageEvent + from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta + from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + TextPromptMessageContent, + ) + + chunk = LLMResultChunk( + model="model", + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=[ + TextPromptMessageContent(data="Hello "), + ImagePromptMessageContent(format="png", mime_type="image/png", base64_data="a"), + ] + ), + ), + ) + event = MagicMock(event=QueueAgentMessageEvent(chunk=chunk)) + + mocker.patch.object(publisher, "_extract_sentence", return_value=([], "")) + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + assert publisher.msg_text == "Hello " + + def test_runtime_handles_agent_message_event_empty_content(self, mock_model_manager, mocker): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + from core.app.entities.queue_entities import QueueAgentMessageEvent + from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta + from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage + + chunk = LLMResultChunk( + model="model", + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=""), + ), + ) + event = MagicMock(event=QueueAgentMessageEvent(chunk=chunk)) + + mocker.patch.object(publisher, "_extract_sentence", return_value=([], "")) + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + assert publisher.msg_text == "" + + def test_runtime_resets_msg_text_when_text_tmp_not_str(self, mock_model_manager, mocker): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + from core.app.entities.queue_entities import QueueTextChunkEvent + + event = MagicMock() + event.event = MagicMock(spec=QueueTextChunkEvent) + event.event.text = "Hello world. Another sentence." + + mocker.patch.object(publisher, "_extract_sentence", return_value=(["A.", "B."], None)) + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + assert publisher.msg_text == "" + + def test_runtime_exception_path(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher._msg_queue = MagicMock() + publisher._msg_queue.get.side_effect = Exception("error") + + publisher._runtime() diff --git a/api/tests/unit_tests/core/callback_handler/test_agent_tool_callback_handler.py b/api/tests/unit_tests/core/callback_handler/test_agent_tool_callback_handler.py new file mode 100644 index 0000000000..4c1aa33540 --- /dev/null +++ b/api/tests/unit_tests/core/callback_handler/test_agent_tool_callback_handler.py @@ -0,0 +1,197 @@ +from unittest.mock import MagicMock + +import pytest + +import core.callback_handler.agent_tool_callback_handler as module + +# ----------------------------- +# Fixtures +# ----------------------------- + + +@pytest.fixture +def enable_debug(mocker): + mocker.patch.object(module.dify_config, "DEBUG", True) + + +@pytest.fixture +def disable_debug(mocker): + mocker.patch.object(module.dify_config, "DEBUG", False) + + +@pytest.fixture +def mock_print(mocker): + return mocker.patch("builtins.print") + + +@pytest.fixture +def handler(): + return module.DifyAgentCallbackHandler(color="blue") + + +# ----------------------------- +# get_colored_text Tests +# ----------------------------- + + +class TestGetColoredText: + @pytest.mark.parametrize( + ("color", "expected_code"), + [ + ("blue", "36;1"), + ("yellow", "33;1"), + ("pink", "38;5;200"), + ("green", "32;1"), + ("red", "31;1"), + ], + ) + def test_get_colored_text_valid_colors(self, color, expected_code): + text = "hello" + result = module.get_colored_text(text, color) + assert expected_code in result + assert text in result + assert result.endswith("\u001b[0m") + + def test_get_colored_text_invalid_color_raises(self): + with pytest.raises(KeyError): + module.get_colored_text("hello", "invalid") + + def test_get_colored_text_empty_string(self): + result = module.get_colored_text("", "green") + assert "\u001b[" in result + + +# ----------------------------- +# print_text Tests +# ----------------------------- + + +class TestPrintText: + def test_print_text_without_color(self, mock_print): + module.print_text("hello") + mock_print.assert_called_once_with("hello", end="", file=None) + + def test_print_text_with_color(self, mocker, mock_print): + mock_get_color = mocker.patch( + "core.callback_handler.agent_tool_callback_handler.get_colored_text", + return_value="colored_text", + ) + + module.print_text("hello", color="green") + + mock_get_color.assert_called_once_with("hello", "green") + mock_print.assert_called_once_with("colored_text", end="", file=None) + + def test_print_text_with_file_flush(self, mocker): + mock_file = MagicMock() + mock_print = mocker.patch("builtins.print") + + module.print_text("hello", file=mock_file) + + mock_print.assert_called_once_with("hello", end="", file=mock_file) + mock_file.flush.assert_called_once() + + def test_print_text_with_end_parameter(self, mock_print): + module.print_text("hello", end="\n") + mock_print.assert_called_once_with("hello", end="\n", file=None) + + +# ----------------------------- +# DifyAgentCallbackHandler Tests +# ----------------------------- + + +class TestDifyAgentCallbackHandler: + def test_init_default_color(self): + handler = module.DifyAgentCallbackHandler() + assert handler.color == "green" + assert handler.current_loop == 1 + + def test_on_tool_start_debug_enabled(self, handler, enable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + handler.on_tool_start("tool1", {"a": 1}) + + mock_print_text.assert_called() + + def test_on_tool_start_debug_disabled(self, handler, disable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + handler.on_tool_start("tool1", {"a": 1}) + + mock_print_text.assert_not_called() + + def test_on_tool_end_debug_enabled_and_trace(self, handler, enable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + mock_trace_manager = MagicMock() + + handler.on_tool_end( + tool_name="tool1", + tool_inputs={"a": 1}, + tool_outputs="output", + message_id="msg1", + timer=123, + trace_manager=mock_trace_manager, + ) + + assert mock_print_text.call_count >= 1 + mock_trace_manager.add_trace_task.assert_called_once() + + def test_on_tool_end_without_trace_manager(self, handler, enable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + handler.on_tool_end( + tool_name="tool1", + tool_inputs={}, + tool_outputs="output", + ) + + assert mock_print_text.call_count >= 1 + + def test_on_tool_error_debug_enabled(self, handler, enable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + handler.on_tool_error(Exception("error")) + + mock_print_text.assert_called_once() + + def test_on_tool_error_debug_disabled(self, handler, disable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + handler.on_tool_error(Exception("error")) + + mock_print_text.assert_not_called() + + @pytest.mark.parametrize("thought", ["thinking", ""]) + def test_on_agent_start(self, handler, enable_debug, mocker, thought): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + handler.on_agent_start(thought) + + mock_print_text.assert_called() + + def test_on_agent_finish_increments_loop(self, handler, enable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + current_loop = handler.current_loop + handler.on_agent_finish() + + assert handler.current_loop == current_loop + 1 + mock_print_text.assert_called() + + def test_on_datasource_start_debug_enabled(self, handler, enable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + handler.on_datasource_start("ds1", {"x": 1}) + + mock_print_text.assert_called_once() + + def test_ignore_agent_property(self, disable_debug, handler): + assert handler.ignore_agent is True + + def test_ignore_chat_model_property(self, disable_debug, handler): + assert handler.ignore_chat_model is True + + def test_ignore_properties_when_debug_enabled(self, enable_debug, handler): + assert handler.ignore_agent is False + assert handler.ignore_chat_model is False diff --git a/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py new file mode 100644 index 0000000000..b37c4c57a1 --- /dev/null +++ b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py @@ -0,0 +1,162 @@ +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.callback_handler.index_tool_callback_handler import ( + DatasetIndexToolCallbackHandler, +) + + +@pytest.fixture +def mock_queue_manager(mocker): + return mocker.Mock() + + +@pytest.fixture +def handler(mock_queue_manager, mocker): + mocker.patch( + "core.callback_handler.index_tool_callback_handler.db", + ) + return DatasetIndexToolCallbackHandler( + queue_manager=mock_queue_manager, + app_id="app-1", + message_id="msg-1", + user_id="user-1", + invoke_from=mocker.Mock(), + ) + + +class TestOnQuery: + @pytest.mark.parametrize( + ("invoke_from", "expected_role"), + [ + (InvokeFrom.EXPLORE, "account"), + (InvokeFrom.DEBUGGER, "account"), + (InvokeFrom.WEB_APP, "end_user"), + ], + ) + def test_on_query_success_roles(self, mocker, mock_queue_manager, invoke_from, expected_role): + # Arrange + mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + + handler = DatasetIndexToolCallbackHandler( + queue_manager=mock_queue_manager, + app_id="app-1", + message_id="msg-1", + user_id="user-1", + invoke_from=mocker.Mock(), + ) + + handler._invoke_from = invoke_from + + # Act + handler.on_query("test query", "dataset-1") + + # Assert + mock_db.session.add.assert_called_once() + dataset_query = mock_db.session.add.call_args.args[0] + assert dataset_query.created_by_role == expected_role + mock_db.session.commit.assert_called_once() + + def test_on_query_none_values(self, mocker, mock_queue_manager): + mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + + handler = DatasetIndexToolCallbackHandler( + queue_manager=mock_queue_manager, + app_id=None, + message_id=None, + user_id=None, + invoke_from=None, + ) + + handler.on_query(None, None) + + mock_db.session.add.assert_called_once() + mock_db.session.commit.assert_called_once() + + +class TestOnToolEnd: + def test_on_tool_end_no_metadata(self, handler, mocker): + mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + + document = mocker.Mock() + document.metadata = None + + handler.on_tool_end([document]) + + mock_db.session.commit.assert_not_called() + + def test_on_tool_end_dataset_document_not_found(self, handler, mocker): + mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + mock_db.session.scalar.return_value = None + + document = mocker.Mock() + document.metadata = {"document_id": "doc-1", "doc_id": "node-1"} + + handler.on_tool_end([document]) + + mock_db.session.scalar.assert_called_once() + + def test_on_tool_end_parent_child_index_with_child(self, handler, mocker): + mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + + mock_dataset_doc = mocker.Mock() + from core.callback_handler.index_tool_callback_handler import IndexStructureType + + mock_dataset_doc.doc_form = IndexStructureType.PARENT_CHILD_INDEX + mock_dataset_doc.dataset_id = "dataset-1" + mock_dataset_doc.id = "doc-1" + + mock_child_chunk = mocker.Mock() + mock_child_chunk.segment_id = "segment-1" + + mock_db.session.scalar.side_effect = [mock_dataset_doc, mock_child_chunk] + + document = mocker.Mock() + document.metadata = {"document_id": "doc-1", "doc_id": "node-1"} + + mock_query = mocker.Mock() + mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + + handler.on_tool_end([document]) + + mock_query.update.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_on_tool_end_non_parent_child_index(self, handler, mocker): + mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + + mock_dataset_doc = mocker.Mock() + mock_dataset_doc.doc_form = "OTHER" + + mock_db.session.scalar.return_value = mock_dataset_doc + + document = mocker.Mock() + document.metadata = { + "document_id": "doc-1", + "doc_id": "node-1", + "dataset_id": "dataset-1", + } + + mock_query = mocker.Mock() + mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + + handler.on_tool_end([document]) + + mock_query.update.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_on_tool_end_empty_documents(self, handler): + handler.on_tool_end([]) + + +class TestReturnRetrieverResourceInfo: + def test_publish_called(self, handler, mock_queue_manager, mocker): + mock_event = mocker.patch("core.callback_handler.index_tool_callback_handler.QueueRetrieverResourcesEvent") + + resources = [mocker.Mock()] + + handler.return_retriever_resource_info(resources) + + mock_queue_manager.publish.assert_called_once() diff --git a/api/tests/unit_tests/core/callback_handler/test_workflow_tool_callback_handler.py b/api/tests/unit_tests/core/callback_handler/test_workflow_tool_callback_handler.py new file mode 100644 index 0000000000..131fb006ed --- /dev/null +++ b/api/tests/unit_tests/core/callback_handler/test_workflow_tool_callback_handler.py @@ -0,0 +1,184 @@ +from unittest.mock import MagicMock, call + +import pytest + +from core.callback_handler.workflow_tool_callback_handler import ( + DifyWorkflowCallbackHandler, +) + + +class DummyToolInvokeMessage: + """Lightweight dummy to simulate ToolInvokeMessage behavior.""" + + def __init__(self, json_value: str): + self._json_value = json_value + + def model_dump_json(self): + return self._json_value + + +@pytest.fixture +def handler(): + """Fixture to create handler instance with deterministic color.""" + instance = DifyWorkflowCallbackHandler() + instance.color = "blue" + return instance + + +@pytest.fixture +def mock_print_text(mocker): + """Mock print_text to avoid real stdout printing.""" + return mocker.patch("core.callback_handler.workflow_tool_callback_handler.print_text") + + +class TestDifyWorkflowCallbackHandler: + def test_on_tool_execution_single_output_success(self, handler, mock_print_text): + # Arrange + tool_name = "test_tool" + tool_inputs = {"a": 1} + message = DummyToolInvokeMessage('{"key": "value"}') + + # Act + results = list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs=tool_inputs, + tool_outputs=[message], + ) + ) + + # Assert + assert results == [message] + assert mock_print_text.call_count == 4 + mock_print_text.assert_has_calls( + [ + call("\n[on_tool_execution]\n", color="blue"), + call("Tool: test_tool\n", color="blue"), + call( + "Outputs: " + message.model_dump_json()[:1000] + "\n", + color="blue", + ), + call("\n"), + ] + ) + + def test_on_tool_execution_multiple_outputs(self, handler, mock_print_text): + # Arrange + tool_name = "multi_tool" + outputs = [ + DummyToolInvokeMessage('{"id": 1}'), + DummyToolInvokeMessage('{"id": 2}'), + ] + + # Act + results = list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs={}, + tool_outputs=outputs, + ) + ) + + # Assert + assert results == outputs + assert mock_print_text.call_count == 4 * len(outputs) + + def test_on_tool_execution_empty_iterable(self, handler, mock_print_text): + # Arrange + tool_name = "empty_tool" + + # Act + results = list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs={}, + tool_outputs=[], + ) + ) + + # Assert + assert results == [] + mock_print_text.assert_not_called() + + @pytest.mark.parametrize( + ("invalid_outputs", "expected_exception"), + [ + (None, TypeError), + (123, TypeError), + ("not_iterable", AttributeError), + ], + ) + def test_on_tool_execution_invalid_outputs_type(self, handler, invalid_outputs, expected_exception): + # Arrange + tool_name = "invalid_tool" + + # Act & Assert + with pytest.raises(expected_exception): + list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs={}, + tool_outputs=invalid_outputs, + ) + ) + + def test_on_tool_execution_long_json_truncation(self, handler, mock_print_text): + # Arrange + tool_name = "long_json_tool" + long_json = "x" * 1500 + message = DummyToolInvokeMessage(long_json) + + # Act + list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs={}, + tool_outputs=[message], + ) + ) + + # Assert + expected_truncated = long_json[:1000] + mock_print_text.assert_any_call( + "Outputs: " + expected_truncated + "\n", + color="blue", + ) + + def test_on_tool_execution_model_dump_json_exception(self, handler, mock_print_text): + # Arrange + tool_name = "exception_tool" + bad_message = MagicMock() + bad_message.model_dump_json.side_effect = ValueError("JSON error") + + # Act & Assert + with pytest.raises(ValueError): + list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs={}, + tool_outputs=[bad_message], + ) + ) + + # Ensure first two prints happened before failure + assert mock_print_text.call_count >= 2 + + def test_on_tool_execution_none_message_id_and_trace_manager(self, handler, mock_print_text): + # Arrange + tool_name = "optional_params_tool" + message = DummyToolInvokeMessage('{"data": "ok"}') + + # Act + results = list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs={}, + tool_outputs=[message], + message_id=None, + timer=None, + trace_manager=None, + ) + ) + + assert results == [message] + assert mock_print_text.call_count == 4 diff --git a/api/tests/unit_tests/core/entities/test_entities_agent_entities.py b/api/tests/unit_tests/core/entities/test_entities_agent_entities.py new file mode 100644 index 0000000000..2437602695 --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_agent_entities.py @@ -0,0 +1,9 @@ +from core.entities.agent_entities import PlanningStrategy + + +def test_planning_strategy_values_are_stable() -> None: + # Arrange / Act / Assert + assert PlanningStrategy.ROUTER.value == "router" + assert PlanningStrategy.REACT_ROUTER.value == "react_router" + assert PlanningStrategy.REACT.value == "react" + assert PlanningStrategy.FUNCTION_CALL.value == "function_call" diff --git a/api/tests/unit_tests/core/entities/test_entities_document_task.py b/api/tests/unit_tests/core/entities/test_entities_document_task.py new file mode 100644 index 0000000000..dd550930d7 --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_document_task.py @@ -0,0 +1,18 @@ +from core.entities.document_task import DocumentTask + + +def test_document_task_keeps_indexing_identifiers() -> None: + # Arrange + document_ids = ("doc-1", "doc-2") + + # Act + task = DocumentTask( + tenant_id="tenant-1", + dataset_id="dataset-1", + document_ids=document_ids, + ) + + # Assert + assert task.tenant_id == "tenant-1" + assert task.dataset_id == "dataset-1" + assert task.document_ids == document_ids diff --git a/api/tests/unit_tests/core/entities/test_entities_embedding_type.py b/api/tests/unit_tests/core/entities/test_entities_embedding_type.py new file mode 100644 index 0000000000..5a82fc4842 --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_embedding_type.py @@ -0,0 +1,7 @@ +from core.entities.embedding_type import EmbeddingInputType + + +def test_embedding_input_type_values_are_stable() -> None: + # Arrange / Act / Assert + assert EmbeddingInputType.DOCUMENT.value == "document" + assert EmbeddingInputType.QUERY.value == "query" diff --git a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py new file mode 100644 index 0000000000..2e4f6d34fb --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py @@ -0,0 +1,45 @@ +from core.entities.execution_extra_content import ( + ExecutionExtraContentDomainModel, + HumanInputContent, + HumanInputFormDefinition, + HumanInputFormSubmissionData, +) +from dify_graph.nodes.human_input.entities import FormInput, UserAction +from dify_graph.nodes.human_input.enums import FormInputType +from models.execution_extra_content import ExecutionContentType + + +def test_human_input_content_defaults_and_domain_alias() -> None: + # Arrange + form_definition = HumanInputFormDefinition( + form_id="form-1", + node_id="node-1", + node_title="Human Input", + form_content="Please confirm", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="answer")], + actions=[UserAction(id="confirm", title="Confirm")], + resolved_default_values={"answer": "yes"}, + expiration_time=1_700_000_000, + ) + submission_data = HumanInputFormSubmissionData( + node_id="node-1", + node_title="Human Input", + rendered_content="Please confirm", + action_id="confirm", + action_text="Confirm", + ) + + # Act + content = HumanInputContent( + workflow_run_id="workflow-run-1", + submitted=True, + form_definition=form_definition, + form_submission_data=submission_data, + ) + + # Assert + assert form_definition.model_config.get("frozen") is True + assert content.type == ExecutionContentType.HUMAN_INPUT + assert content.form_definition is form_definition + assert content.form_submission_data is submission_data + assert ExecutionExtraContentDomainModel is HumanInputContent diff --git a/api/tests/unit_tests/core/entities/test_entities_knowledge_entities.py b/api/tests/unit_tests/core/entities/test_entities_knowledge_entities.py new file mode 100644 index 0000000000..d25f20145f --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_knowledge_entities.py @@ -0,0 +1,45 @@ +from core.entities.knowledge_entities import ( + PipelineDataset, + PipelineDocument, + PipelineGenerateResponse, +) + + +def test_pipeline_dataset_normalizes_none_description() -> None: + # Arrange / Act + dataset = PipelineDataset( + id="dataset-1", + name="Dataset", + description=None, + chunk_structure="parent-child", + ) + + # Assert + assert dataset.description == "" + + +def test_pipeline_generate_response_builds_nested_models() -> None: + # Arrange + dataset = PipelineDataset( + id="dataset-1", + name="Dataset", + description="Knowledge base", + chunk_structure="parent-child", + ) + document = PipelineDocument( + id="doc-1", + position=1, + data_source_type="file", + data_source_info={"name": "spec.pdf"}, + name="spec.pdf", + indexing_status="completed", + enabled=True, + ) + + # Act + response = PipelineGenerateResponse(batch="batch-1", dataset=dataset, documents=[document]) + + # Assert + assert response.batch == "batch-1" + assert response.dataset.id == "dataset-1" + assert response.documents[0].id == "doc-1" diff --git a/api/tests/unit_tests/core/entities/test_entities_mcp_provider.py b/api/tests/unit_tests/core/entities/test_entities_mcp_provider.py new file mode 100644 index 0000000000..5449c63b45 --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_mcp_provider.py @@ -0,0 +1,450 @@ +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest + +from core.entities import mcp_provider as mcp_provider_module +from core.entities.mcp_provider import ( + DEFAULT_EXPIRES_IN, + DEFAULT_TOKEN_TYPE, + MCPProviderEntity, +) +from core.mcp.types import OAuthTokens + + +def _build_mcp_provider_entity() -> MCPProviderEntity: + now = datetime(2025, 1, 1, tzinfo=UTC) + return MCPProviderEntity( + id="provider-1", + provider_id="server-1", + name="Example MCP", + tenant_id="tenant-1", + user_id="user-1", + server_url="encrypted-server-url", + headers={}, + timeout=30, + sse_read_timeout=300, + authed=False, + credentials={}, + tools=[], + icon={"en_US": "icon.png"}, + created_at=now, + updated_at=now, + ) + + +def test_from_db_model_maps_fields() -> None: + # Arrange + now = datetime(2025, 1, 1, tzinfo=UTC) + db_provider = SimpleNamespace( + id="provider-1", + server_identifier="server-1", + name="Example MCP", + tenant_id="tenant-1", + user_id="user-1", + server_url="encrypted-server-url", + headers={"Authorization": "enc"}, + timeout=15, + sse_read_timeout=120, + authed=True, + credentials={"access_token": "enc-token"}, + tool_dict=[{"name": "search"}], + icon=None, + created_at=now, + updated_at=now, + ) + + # Act + entity = MCPProviderEntity.from_db_model(db_provider) + + # Assert + assert entity.provider_id == "server-1" + assert entity.tools == [{"name": "search"}] + assert entity.icon == "" + + +def test_redirect_url_uses_console_api_url(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + entity = _build_mcp_provider_entity() + monkeypatch.setattr(mcp_provider_module.dify_config, "CONSOLE_API_URL", "https://console.example.com") + + # Act + redirect_url = entity.redirect_url + + # Assert + assert redirect_url == "https://console.example.com/console/api/mcp/oauth/callback" + + +def test_client_metadata_for_authorization_code_flow() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}): + # Act + metadata = entity.client_metadata + + # Assert + assert metadata.grant_types == ["refresh_token", "authorization_code"] + assert metadata.redirect_uris == [entity.redirect_url] + assert metadata.response_types == ["code"] + + +def test_client_metadata_for_client_credentials_flow() -> None: + # Arrange + entity = _build_mcp_provider_entity() + credentials = {"client_information": {"grant_types": ["client_credentials"]}} + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials): + # Act + metadata = entity.client_metadata + + # Assert + assert metadata.grant_types == ["refresh_token", "client_credentials"] + assert metadata.redirect_uris == [] + assert metadata.response_types == [] + + +def test_client_metadata_prefers_nested_authorization_code_grant_type() -> None: + # Arrange + entity = _build_mcp_provider_entity() + credentials = { + "grant_type": "client_credentials", + "client_information": {"grant_types": ["authorization_code"]}, + } + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials): + # Act + metadata = entity.client_metadata + + # Assert + assert metadata.grant_types == ["refresh_token", "authorization_code"] + assert metadata.redirect_uris == [entity.redirect_url] + assert metadata.response_types == ["code"] + + +def test_provider_icon_returns_icon_dict_as_is() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy(update={"icon": {"en_US": "icon.png"}}) + + # Act + icon = entity.provider_icon + + # Assert + assert icon == {"en_US": "icon.png"} + + +def test_provider_icon_uses_signed_url_for_plain_path() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy(update={"icon": "icons/mcp.png"}) + + with patch( + "core.entities.mcp_provider.file_helpers.get_signed_file_url", + return_value="https://signed.example.com/icons/mcp.png", + ) as mock_get_signed_url: + # Act + icon = entity.provider_icon + + # Assert + mock_get_signed_url.assert_called_once_with("icons/mcp.png") + assert icon == "https://signed.example.com/icons/mcp.png" + + +def test_to_api_response_without_sensitive_data_skips_auth_related_work() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy(update={"icon": {"en_US": "icon.png"}}) + + with patch.object(MCPProviderEntity, "masked_server_url", return_value="https://api.example.com/******"): + # Act + response = entity.to_api_response(include_sensitive=False) + + # Assert + assert response["author"] == "Anonymous" + assert response["masked_headers"] == {} + assert response["is_dynamic_registration"] is True + assert "authentication" not in response + + +def test_to_api_response_with_sensitive_data_includes_masked_values() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy( + update={ + "credentials": {"client_information": {"is_dynamic_registration": False}}, + "icon": {"en_US": "icon.png"}, + } + ) + + with patch.object(MCPProviderEntity, "masked_server_url", return_value="https://api.example.com/******"): + with patch.object(MCPProviderEntity, "masked_headers", return_value={"Authorization": "Be****"}): + with patch.object(MCPProviderEntity, "masked_credentials", return_value={"client_id": "cl****"}): + # Act + response = entity.to_api_response(user_name="Rajat", include_sensitive=True) + + # Assert + assert response["author"] == "Rajat" + assert response["masked_headers"] == {"Authorization": "Be****"} + assert response["authentication"] == {"client_id": "cl****"} + assert response["is_dynamic_registration"] is False + + +def test_retrieve_client_information_decrypts_nested_secret() -> None: + # Arrange + entity = _build_mcp_provider_entity() + credentials = {"client_information": {"client_id": "client-1", "encrypted_client_secret": "enc-secret"}} + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials): + with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="plain-secret") as mock_decrypt: + # Act + client_info = entity.retrieve_client_information() + + # Assert + assert client_info is not None + assert client_info.client_id == "client-1" + assert client_info.client_secret == "plain-secret" + mock_decrypt.assert_called_once_with("tenant-1", "enc-secret") + + +def test_retrieve_client_information_returns_none_for_missing_data() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}): + # Act + result_empty = entity.retrieve_client_information() + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"client_information": "invalid"}): + # Act + result_invalid = entity.retrieve_client_information() + + # Assert + assert result_empty is None + assert result_invalid is None + + +def test_masked_server_url_hides_path_segments() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + with patch.object( + MCPProviderEntity, + "decrypt_server_url", + return_value="https://api.example.com/v1/mcp?query=1", + ): + # Act + masked_url = entity.masked_server_url() + + # Assert + assert masked_url == "https://api.example.com/******?query=1" + + +def test_mask_value_covers_short_and_long_values() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + # Act + short_masked = entity._mask_value("short") + long_masked = entity._mask_value("abcdefghijkl") + + # Assert + assert short_masked == "*****" + assert long_masked == "ab********kl" + + +def test_masked_headers_masks_all_decrypted_header_values() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + with patch.object(MCPProviderEntity, "decrypt_headers", return_value={"Authorization": "abcdefgh"}): + # Act + masked = entity.masked_headers() + + # Assert + assert masked == {"Authorization": "ab****gh"} + + +def test_masked_credentials_handles_nested_secret_fields() -> None: + # Arrange + entity = _build_mcp_provider_entity() + credentials = { + "client_information": { + "client_id": "client-id", + "encrypted_client_secret": "encrypted-value", + "client_secret": "plain-secret", + } + } + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials): + with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="decrypted-secret"): + # Act + masked = entity.masked_credentials() + + # Assert + assert masked["client_id"] == "cl*****id" + assert masked["client_secret"] == "pl********et" + + +def test_masked_credentials_returns_empty_for_missing_client_information() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}): + # Act + masked_empty = entity.masked_credentials() + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"client_information": "invalid"}): + # Act + masked_invalid = entity.masked_credentials() + + # Assert + assert masked_empty == {} + assert masked_invalid == {} + + +def test_retrieve_tokens_returns_defaults_when_optional_fields_missing() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy(update={"credentials": {"token": "encrypted"}}) + + with patch.object( + MCPProviderEntity, + "decrypt_credentials", + return_value={"access_token": "token", "expires_in": "", "refresh_token": "refresh"}, + ): + # Act + tokens = entity.retrieve_tokens() + + # Assert + assert isinstance(tokens, OAuthTokens) + assert tokens.access_token == "token" + assert tokens.token_type == DEFAULT_TOKEN_TYPE + assert tokens.expires_in == DEFAULT_EXPIRES_IN + assert tokens.refresh_token == "refresh" + + +def test_retrieve_tokens_returns_none_when_access_token_missing() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy(update={"credentials": {"token": "encrypted"}}) + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"access_token": ""}) as mock_decrypt: + # Act + tokens = entity.retrieve_tokens() + + # Assert + mock_decrypt.assert_called_once() + assert tokens is None + + +def test_decrypt_server_url_delegates_to_encrypter() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="https://api.example.com") as mock: + # Act + decrypted = entity.decrypt_server_url() + + # Assert + mock.assert_called_once_with("tenant-1", "encrypted-server-url") + assert decrypted == "https://api.example.com" + + +def test_decrypt_authentication_injects_authorization_for_oauth() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy(update={"authed": True, "headers": {}}) + + with patch.object(MCPProviderEntity, "decrypt_headers", return_value={}): + with patch.object( + MCPProviderEntity, + "retrieve_tokens", + return_value=OAuthTokens(access_token="abc123", token_type="bearer"), + ): + # Act + headers = entity.decrypt_authentication() + + # Assert + assert headers["Authorization"] == "Bearer abc123" + + +def test_decrypt_authentication_does_not_overwrite_existing_headers() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy( + update={"authed": True, "headers": {"Authorization": "encrypted-header"}} + ) + + with patch.object(MCPProviderEntity, "decrypt_headers", return_value={"Authorization": "existing"}): + with patch.object( + MCPProviderEntity, + "retrieve_tokens", + return_value=OAuthTokens(access_token="abc", token_type="bearer"), + ) as mock_tokens: + # Act + headers = entity.decrypt_authentication() + + # Assert + mock_tokens.assert_not_called() + assert headers == {"Authorization": "existing"} + + +def test_decrypt_dict_returns_empty_for_empty_input() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + # Act + decrypted = entity._decrypt_dict({}) + + # Assert + assert decrypted == {} + + +def test_decrypt_dict_returns_original_data_when_no_encrypted_fields() -> None: + # Arrange + entity = _build_mcp_provider_entity() + input_data = {"nested": {"k": "v"}, "count": 2, "empty": ""} + + # Act + result = entity._decrypt_dict(input_data) + + # Assert + assert result is input_data + + +def test_decrypt_dict_only_decrypts_top_level_string_values() -> None: + # Arrange + entity = _build_mcp_provider_entity() + decryptor = Mock() + decryptor.decrypt.return_value = {"api_key": "plain-key"} + + def _fake_create_provider_encrypter(*, tenant_id: str, config: list, cache): + assert tenant_id == "tenant-1" + assert any(item.name == "api_key" for item in config) + return decryptor, None + + with patch("core.tools.utils.encryption.create_provider_encrypter", side_effect=_fake_create_provider_encrypter): + # Act + result = entity._decrypt_dict( + { + "api_key": "encrypted-key", + "nested": {"client_id": "unchanged"}, + "empty": "", + "count": 2, + } + ) + + # Assert + decryptor.decrypt.assert_called_once_with({"api_key": "encrypted-key"}) + assert result["api_key"] == "plain-key" + assert result["nested"] == {"client_id": "unchanged"} + assert result["count"] == 2 + + +def test_decrypt_headers_and_credentials_delegate_to_decrypt_dict() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + with patch.object(MCPProviderEntity, "_decrypt_dict", side_effect=[{"h": "v"}, {"c": "v"}]) as mock: + # Act + headers = entity.decrypt_headers() + credentials = entity.decrypt_credentials() + + # Assert + assert mock.call_count == 2 + assert headers == {"h": "v"} + assert credentials == {"c": "v"} diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py new file mode 100644 index 0000000000..7a3d5e84ed --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py @@ -0,0 +1,92 @@ +"""Unit tests for model entity behavior and invariants. + +Covers DefaultModelEntity, DefaultModelProviderEntity, ModelStatus, +ProviderModelWithStatusEntity, and SimpleModelProviderEntity. Assumes i18n +labels are provided via I18nObject, model metadata aligns with FetchFrom and +ModelType expectations, and ProviderEntity/ConfigurateMethod interactions +drive provider mapping behavior. +""" + +import pytest + +from core.entities.model_entities import ( + DefaultModelEntity, + DefaultModelProviderEntity, + ModelStatus, + ProviderModelWithStatusEntity, + SimpleModelProviderEntity, +) +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType +from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity + + +def _build_model_with_status(status: ModelStatus) -> ProviderModelWithStatusEntity: + return ProviderModelWithStatusEntity( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + status=status, + ) + + +def test_simple_model_provider_entity_maps_from_provider_entity() -> None: + # Arrange + provider_entity = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) + + # Act + simple_provider = SimpleModelProviderEntity(provider_entity) + + # Assert + assert simple_provider.provider == "openai" + assert simple_provider.label.en_US == "OpenAI" + assert simple_provider.supported_model_types == [ModelType.LLM] + + +def test_provider_model_with_status_raises_for_known_error_statuses() -> None: + # Arrange + expectations = { + ModelStatus.NO_CONFIGURE: "Model is not configured", + ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded", + ModelStatus.NO_PERMISSION: "No permission to use this model", + ModelStatus.DISABLED: "Model is disabled", + } + + for status, message in expectations.items(): + # Act / Assert + with pytest.raises(ValueError, match=message): + _build_model_with_status(status).raise_for_status() + + +def test_provider_model_with_status_allows_active_and_credential_removed() -> None: + # Arrange + active_model = _build_model_with_status(ModelStatus.ACTIVE) + removed_model = _build_model_with_status(ModelStatus.CREDENTIAL_REMOVED) + + # Act / Assert + active_model.raise_for_status() + removed_model.raise_for_status() + + +def test_default_model_entity_accepts_model_field_name() -> None: + # Arrange / Act + default_model = DefaultModelEntity( + model="gpt-4o-mini", + model_type=ModelType.LLM, + provider=DefaultModelProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + ), + ) + + # Assert + assert default_model.model == "gpt-4o-mini" + assert default_model.provider.provider == "openai" diff --git a/api/tests/unit_tests/core/entities/test_entities_parameter_entities.py b/api/tests/unit_tests/core/entities/test_entities_parameter_entities.py new file mode 100644 index 0000000000..20b7bf2a9f --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_parameter_entities.py @@ -0,0 +1,22 @@ +from core.entities.parameter_entities import ( + AppSelectorScope, + CommonParameterType, + ModelSelectorScope, + ToolSelectorScope, +) + + +def test_common_parameter_type_values_are_stable() -> None: + # Arrange / Act / Assert + assert CommonParameterType.SECRET_INPUT.value == "secret-input" + assert CommonParameterType.MODEL_SELECTOR.value == "model-selector" + assert CommonParameterType.DYNAMIC_SELECT.value == "dynamic-select" + assert CommonParameterType.ARRAY.value == "array" + assert CommonParameterType.OBJECT.value == "object" + + +def test_selector_scope_values_are_stable() -> None: + # Arrange / Act / Assert + assert AppSelectorScope.WORKFLOW.value == "workflow" + assert ModelSelectorScope.TEXT_EMBEDDING.value == "text-embedding" + assert ToolSelectorScope.BUILTIN.value == "builtin" diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py new file mode 100644 index 0000000000..82f98d07a3 --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -0,0 +1,1850 @@ +from __future__ import annotations + +from contextlib import contextmanager +from types import SimpleNamespace +from typing import Any +from unittest.mock import Mock, patch + +import pytest + +from constants import HIDDEN_VALUE +from core.entities.model_entities import ModelStatus +from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations +from core.entities.provider_entities import ( + CredentialConfiguration, + CustomConfiguration, + CustomModelConfiguration, + CustomProviderConfiguration, + ModelLoadBalancingConfiguration, + ModelSettings, + ProviderQuotaType, + QuotaConfiguration, + QuotaUnit, + RestrictModel, + SystemConfiguration, + SystemConfigurationStatus, +) +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from dify_graph.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderEntity, +) +from models.provider import ProviderType +from models.provider_ids import ModelProviderID + +_UNSET = object() + + +def _build_provider_configuration(*, provider_name: str = "openai") -> ProviderConfiguration: + provider_entity = ProviderEntity( + provider=provider_name, + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) + system_configuration = SystemConfiguration( + enabled=True, + credentials={"api_key": "test-key"}, + current_quota_type=ProviderQuotaType.TRIAL, + quota_configurations=[ + QuotaConfiguration( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=1_000, + quota_used=0, + is_valid=True, + restrict_models=[], + ) + ], + ) + with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}): + return ProviderConfiguration( + tenant_id="tenant-1", + provider=provider_entity, + preferred_provider_type=ProviderType.SYSTEM, + using_provider_type=ProviderType.SYSTEM, + system_configuration=system_configuration, + custom_configuration=CustomConfiguration(provider=None, models=[]), + model_settings=[], + ) + + +def _build_ai_model(name: str, *, model_type: ModelType = ModelType.LLM) -> AIModelEntity: + return AIModelEntity( + model=name, + label=I18nObject(en_US=name), + model_type=model_type, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + ) + + +def _exec_result( + *, + scalar_one_or_none: Any = _UNSET, + scalar: Any = _UNSET, + scalars_all: Any = _UNSET, + scalars_first: Any = _UNSET, +) -> Mock: + result = Mock() + if scalar_one_or_none is not _UNSET: + result.scalar_one_or_none.return_value = scalar_one_or_none + if scalar is not _UNSET: + result.scalar.return_value = scalar + if scalars_all is not _UNSET or scalars_first is not _UNSET: + scalars = Mock() + if scalars_all is not _UNSET: + scalars.all.return_value = scalars_all + if scalars_first is not _UNSET: + scalars.first.return_value = scalars_first + result.scalars.return_value = scalars + return result + + +@contextmanager +def _patched_session(session: Mock): + with patch("core.entities.provider_configuration.db") as mock_db: + mock_db.engine = Mock() + with patch("core.entities.provider_configuration.Session") as mock_session_cls: + mock_session_cls.return_value.__enter__.return_value = session + yield mock_session_cls + + +def _build_secret_provider_schema() -> ProviderCredentialSchema: + return ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema( + variable="openai_api_key", + label=I18nObject(en_US="API Key"), + type=FormType.SECRET_INPUT, + ) + ] + ) + + +def _build_secret_model_schema() -> ModelCredentialSchema: + return ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model")), + credential_form_schemas=[ + CredentialFormSchema( + variable="openai_api_key", + label=I18nObject(en_US="API Key"), + type=FormType.SECRET_INPUT, + ) + ], + ) + + +def test_extract_secret_variables_returns_only_secret_inputs() -> None: + configuration = _build_provider_configuration() + credential_form_schemas = [ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key"), + type=FormType.SECRET_INPUT, + ), + CredentialFormSchema( + variable="endpoint", + label=I18nObject(en_US="Endpoint"), + type=FormType.TEXT_INPUT, + ), + ] + + secret_variables = configuration.extract_secret_variables(credential_form_schemas) + assert secret_variables == ["api_key"] + + +def test_obfuscated_credentials_masks_only_secret_fields() -> None: + configuration = _build_provider_configuration() + credential_form_schemas = [ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key"), + type=FormType.SECRET_INPUT, + ), + CredentialFormSchema( + variable="endpoint", + label=I18nObject(en_US="Endpoint"), + type=FormType.TEXT_INPUT, + ), + ] + + with patch( + "core.entities.provider_configuration.encrypter.obfuscated_token", + side_effect=lambda value: f"masked-{value[-2:]}", + ): + obfuscated = configuration.obfuscated_credentials( + credentials={"api_key": "sk-test-1234", "endpoint": "https://api.example.com"}, + credential_form_schemas=credential_form_schemas, + ) + + assert obfuscated["api_key"] == "masked-34" + assert obfuscated["endpoint"] == "https://api.example.com" + + +def test_provider_configurations_behave_like_keyed_container() -> None: + configuration = _build_provider_configuration() + provider_key = str(ModelProviderID("openai")) + configurations = ProviderConfigurations(tenant_id="tenant-1") + + configurations[provider_key] = configuration + + assert "openai" in configurations + assert configurations["openai"] is configuration + assert configurations.get("openai") is configuration + assert configurations.to_list() == [configuration] + assert list(configurations) == [(provider_key, configuration)] + + +def test_provider_configurations_get_models_forwards_filters() -> None: + configuration = _build_provider_configuration() + provider_key = str(ModelProviderID("openai")) + configurations = ProviderConfigurations(tenant_id="tenant-1") + configurations[provider_key] = configuration + expected_model = Mock() + + with patch.object(ProviderConfiguration, "get_provider_models", return_value=[expected_model]) as mock_get: + models = configurations.get_models(provider="openai", model_type=ModelType.LLM, only_active=True) + + mock_get.assert_called_once_with(ModelType.LLM, True) + assert models == [expected_model] + + +def test_provider_configurations_get_models_skips_non_matching_provider_filter() -> None: + configuration = _build_provider_configuration() + provider_key = str(ModelProviderID("openai")) + configurations = ProviderConfigurations(tenant_id="tenant-1") + configurations[provider_key] = configuration + + with patch.object(ProviderConfiguration, "get_provider_models", return_value=[Mock()]) as mock_get: + models = configurations.get_models(provider="anthropic", model_type=ModelType.LLM, only_active=True) + + assert models == [] + mock_get.assert_not_called() + + +def test_get_current_credentials_custom_provider_checks_current_credential() -> None: + configuration = _build_provider_configuration() + configuration.using_provider_type = ProviderType.CUSTOM + configuration.custom_configuration.provider = CustomProviderConfiguration( + credentials={"api_key": "provider-key"}, + current_credential_id="credential-1", + current_credential_name="Primary", + available_credentials=[], + ) + + with patch("core.helper.credential_utils.check_credential_policy_compliance") as mock_check: + credentials = configuration.get_current_credentials(ModelType.LLM, "gpt-4o") + + assert credentials == {"api_key": "provider-key"} + assert mock_check.call_count == 1 + assert mock_check.call_args.kwargs["credential_id"] == "credential-1" + assert mock_check.call_args.kwargs["provider"] == "openai" + + +def test_get_current_credentials_custom_provider_checks_all_available_credentials() -> None: + configuration = _build_provider_configuration() + configuration.using_provider_type = ProviderType.CUSTOM + configuration.custom_configuration.provider = CustomProviderConfiguration( + credentials={"api_key": "provider-key"}, + available_credentials=[ + CredentialConfiguration(credential_id="cred-1", credential_name="First"), + CredentialConfiguration(credential_id="cred-2", credential_name="Second"), + ], + ) + + with patch("core.helper.credential_utils.check_credential_policy_compliance") as mock_check: + credentials = configuration.get_current_credentials(ModelType.LLM, "gpt-4o") + + assert credentials == {"api_key": "provider-key"} + assert [c.kwargs["credential_id"] for c in mock_check.call_args_list] == ["cred-1", "cred-2"] + assert all(c.kwargs["provider"] == "openai" for c in mock_check.call_args_list) + + +def test_get_system_configuration_status_returns_none_when_current_quota_missing() -> None: + configuration = _build_provider_configuration() + configuration.system_configuration.current_quota_type = ProviderQuotaType.FREE + + status = configuration.get_system_configuration_status() + assert status is None + + +def test_get_provider_names_supports_legacy_and_full_plugin_id() -> None: + configuration = _build_provider_configuration() + configuration.provider.provider = "langgenius/openai/openai" + + provider_names = configuration._get_provider_names() + assert provider_names == ["langgenius/openai/openai", "openai"] + + +def test_generate_next_api_key_name_uses_highest_numeric_suffix() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalars.return_value.all.return_value = [ + SimpleNamespace(credential_name="API KEY 9"), + SimpleNamespace(credential_name="legacy"), + SimpleNamespace(credential_name=" API KEY 2 "), + ] + + name = configuration._generate_next_api_key_name(session=session, query_factory=lambda: Mock()) + assert name == "API KEY 10" + + +def test_generate_next_api_key_name_falls_back_to_default_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + + def _raise_query_error(): + raise RuntimeError("boom") + + name = configuration._generate_next_api_key_name(session=session, query_factory=_raise_query_error) + assert name == "API KEY 1" + + +def test_generate_provider_and_custom_model_names_delegate_to_shared_generator() -> None: + configuration = _build_provider_configuration() + + with patch.object(configuration, "_generate_next_api_key_name", return_value="API KEY 7") as mock_generator: + provider_name = configuration._generate_provider_credential_name(session=Mock()) + custom_model_name = configuration._generate_custom_model_credential_name( + model="gpt-4o", + model_type=ModelType.LLM, + session=Mock(), + ) + + assert provider_name == "API KEY 7" + assert custom_model_name == "API KEY 7" + assert mock_generator.call_count == 2 + + +def test_get_provider_credential_uses_specific_lookup_when_id_provided() -> None: + configuration = _build_provider_configuration() + + with patch.object(configuration, "_get_specific_provider_credential", return_value={"api_key": "***"}) as mock_get: + credential = configuration.get_provider_credential("credential-1") + + assert credential == {"api_key": "***"} + mock_get.assert_called_once_with("credential-1") + + +def test_validate_provider_credentials_handles_hidden_secret_value() -> None: + configuration = _build_provider_configuration() + configuration.provider.provider_credential_schema = ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema( + variable="openai_api_key", + label=I18nObject(en_US="API Key"), + type=FormType.SECRET_INPUT, + ) + ] + ) + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="encrypted-old-key") + mock_factory = Mock() + mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"} + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): + with patch( + "core.entities.provider_configuration.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc::{value}", + ): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"}, + credential_id="credential-1", + session=session, + ) + + assert validated["openai_api_key"] == "enc::restored-key" + assert validated["region"] == "us" + mock_factory.provider_credentials_validate.assert_called_once_with( + provider="openai", + credentials={"openai_api_key": "restored-key", "region": "us"}, + ) + + +def test_validate_provider_credentials_opens_session_when_not_passed() -> None: + configuration = _build_provider_configuration() + mock_session = Mock() + mock_factory = Mock() + mock_factory.provider_credentials_validate.return_value = {"region": "us"} + + with patch("core.entities.provider_configuration.Session") as mock_session_cls: + with patch("core.entities.provider_configuration.db") as mock_db: + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = mock_session + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + validated = configuration.validate_provider_credentials(credentials={"region": "us"}) + + assert validated == {"region": "us"} + mock_session_cls.assert_called_once() + + +def test_switch_preferred_provider_type_returns_early_when_no_change_or_unsupported() -> None: + configuration = _build_provider_configuration() + + with patch("core.entities.provider_configuration.Session") as mock_session_cls: + configuration.switch_preferred_provider_type(ProviderType.SYSTEM) + mock_session_cls.assert_not_called() + + configuration.preferred_provider_type = ProviderType.CUSTOM + configuration.system_configuration.enabled = False + with patch("core.entities.provider_configuration.Session") as mock_session_cls: + configuration.switch_preferred_provider_type(ProviderType.SYSTEM) + mock_session_cls.assert_not_called() + + +def test_switch_preferred_provider_type_updates_existing_record_with_session() -> None: + configuration = _build_provider_configuration() + configuration.preferred_provider_type = ProviderType.CUSTOM + session = Mock() + existing_record = SimpleNamespace(preferred_provider_type="custom") + session.execute.return_value.scalars.return_value.first.return_value = existing_record + + configuration.switch_preferred_provider_type(ProviderType.SYSTEM, session=session) + + assert existing_record.preferred_provider_type == ProviderType.SYSTEM.value + session.commit.assert_called_once() + + +def test_switch_preferred_provider_type_creates_record_when_missing() -> None: + configuration = _build_provider_configuration() + configuration.preferred_provider_type = ProviderType.SYSTEM + session = Mock() + session.execute.return_value.scalars.return_value.first.return_value = None + + configuration.switch_preferred_provider_type(ProviderType.CUSTOM, session=session) + + assert session.add.call_count == 1 + session.commit.assert_called_once() + + +def test_get_model_type_instance_and_schema_delegate_to_factory() -> None: + configuration = _build_provider_configuration() + mock_factory = Mock() + mock_model_type_instance = Mock() + mock_schema = _build_ai_model("gpt-4o") + mock_factory.get_model_type_instance.return_value = mock_model_type_instance + mock_factory.get_model_schema.return_value = mock_schema + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + model_type_instance = configuration.get_model_type_instance(ModelType.LLM) + model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) + + assert model_type_instance is mock_model_type_instance + assert model_schema is mock_schema + mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM) + mock_factory.get_model_schema.assert_called_once_with( + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"api_key": "x"}, + ) + + +def test_get_provider_model_returns_none_when_model_not_found() -> None: + configuration = _build_provider_configuration() + fake_model = SimpleNamespace(model="other-model") + + with patch.object(ProviderConfiguration, "get_provider_models", return_value=[fake_model]): + selected = configuration.get_provider_model(ModelType.LLM, "gpt-4o") + + assert selected is None + + +def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> None: + configuration = _build_provider_configuration() + configuration.provider.position = {"llm": ["b-model", "a-model"]} + configuration.model_settings = [ + ModelSettings(model="a-model", model_type=ModelType.LLM, enabled=False, load_balancing_configs=[]) + ] + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=[_build_ai_model("a-model"), _build_ai_model("b-model"), _build_ai_model("a-model")], + ) + mock_factory = Mock() + mock_factory.get_provider_schema.return_value = provider_schema + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + all_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=False) + active_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=True) + + assert [model.model for model in all_models] == ["b-model", "a-model"] + assert [model.status for model in all_models] == [ModelStatus.ACTIVE, ModelStatus.DISABLED] + assert [model.model for model in active_models] == ["b-model"] + + +def test_get_custom_provider_models_sets_status_for_removed_credentials_and_invalid_lb_configs() -> None: + configuration = _build_provider_configuration() + configuration.using_provider_type = ProviderType.CUSTOM + configuration.custom_configuration.provider = CustomProviderConfiguration(credentials={"api_key": "provider-key"}) + configuration.custom_configuration.models = [ + CustomModelConfiguration( + model="custom-model", + model_type=ModelType.LLM, + credentials=None, + available_model_credentials=[CredentialConfiguration(credential_id="c-1", credential_name="first")], + ) + ] + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=[_build_ai_model("base-model")], + ) + model_setting_map = { + ModelType.LLM: { + "base-model": ModelSettings( + model="base-model", + model_type=ModelType.LLM, + enabled=True, + load_balancing_enabled=True, + load_balancing_configs=[ + ModelLoadBalancingConfiguration( + id="lb-base", + name="LB Base", + credentials={}, + credential_source_type="provider", + ) + ], + ), + "custom-model": ModelSettings( + model="custom-model", + model_type=ModelType.LLM, + enabled=True, + load_balancing_enabled=True, + load_balancing_configs=[ + ModelLoadBalancingConfiguration( + id="lb-custom", + name="LB Custom", + credentials={}, + credential_source_type="custom_model", + ) + ], + ), + } + } + + with patch.object(ProviderConfiguration, "get_model_schema", return_value=_build_ai_model("custom-model")): + models = configuration._get_custom_provider_models( + model_types=[ModelType.LLM], + provider_schema=provider_schema, + model_setting_map=model_setting_map, + ) + + status_map = {model.model: model.status for model in models} + invalid_lb_map = {model.model: model.has_invalid_load_balancing_configs for model in models} + assert status_map["base-model"] == ModelStatus.ACTIVE + assert status_map["custom-model"] == ModelStatus.CREDENTIAL_REMOVED + assert invalid_lb_map["base-model"] is True + assert invalid_lb_map["custom-model"] is True + + +def test_validator_adds_predefined_model_for_customizable_provider_with_restrictions() -> None: + provider = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.CUSTOMIZABLE_MODEL], + ) + system_configuration = SystemConfiguration( + enabled=True, + credentials={"api_key": "test-key"}, + current_quota_type=ProviderQuotaType.TRIAL, + quota_configurations=[ + QuotaConfiguration( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + quota_used=0, + is_valid=True, + restrict_models=[ + RestrictModel(model="restricted", base_model_name="base-model", model_type=ModelType.LLM) + ], + ) + ], + ) + with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}): + configuration = ProviderConfiguration( + tenant_id="tenant-1", + provider=provider, + preferred_provider_type=ProviderType.SYSTEM, + using_provider_type=ProviderType.SYSTEM, + system_configuration=system_configuration, + custom_configuration=CustomConfiguration(provider=None, models=[]), + model_settings=[], + ) + + assert ConfigurateMethod.PREDEFINED_MODEL in configuration.provider.configurate_methods + + +def test_get_current_credentials_system_handles_disable_and_restricted_base_model() -> None: + configuration = _build_provider_configuration() + configuration.model_settings = [ + ModelSettings(model="gpt-4o", model_type=ModelType.LLM, enabled=False, load_balancing_configs=[]) + ] + + with pytest.raises(ValueError, match="Model gpt-4o is disabled"): + configuration.get_current_credentials(ModelType.LLM, "gpt-4o") + + configuration.model_settings = [] + configuration.system_configuration.quota_configurations[0].restrict_models = [ + RestrictModel(model="gpt-4o", base_model_name="base-model", model_type=ModelType.LLM) + ] + credentials = configuration.get_current_credentials(ModelType.LLM, "gpt-4o") + assert credentials["base_model_name"] == "base-model" + + +def test_get_current_credentials_prefers_model_specific_custom_credentials() -> None: + configuration = _build_provider_configuration() + configuration.using_provider_type = ProviderType.CUSTOM + configuration.custom_configuration.models = [ + CustomModelConfiguration( + model="gpt-4o", + model_type=ModelType.LLM, + credentials={"api_key": "model-key"}, + ) + ] + configuration.custom_configuration.provider = CustomProviderConfiguration(credentials={"api_key": "provider-key"}) + + credentials = configuration.get_current_credentials(ModelType.LLM, "gpt-4o") + assert credentials == {"api_key": "model-key"} + + +def test_get_system_configuration_status_falsey_quota_returns_unsupported() -> None: + class _FalseyQuota: + quota_type = ProviderQuotaType.TRIAL + is_valid = True + + def __bool__(self) -> bool: + return False + + configuration = _build_provider_configuration() + configuration.system_configuration.quota_configurations = [_FalseyQuota()] # type: ignore[list-item] + assert configuration.get_system_configuration_status() == SystemConfigurationStatus.UNSUPPORTED + + +def test_get_provider_credential_default_uses_custom_provider_credentials() -> None: + configuration = _build_provider_configuration() + configuration.custom_configuration.provider = CustomProviderConfiguration(credentials={"api_key": "provider-key"}) + obfuscated = configuration.get_provider_credential() + assert obfuscated == {"api_key": "provider-key"} + + +def test_custom_configuration_availability_and_provider_record_helpers() -> None: + configuration = _build_provider_configuration() + assert not configuration.is_custom_configuration_available() + + configuration.custom_configuration.provider = CustomProviderConfiguration( + credentials={"api_key": "provider-key"}, + available_credentials=[CredentialConfiguration(credential_id="cred-1", credential_name="Main")], + ) + assert configuration.is_custom_configuration_available() + + configuration.custom_configuration.provider = None + configuration.custom_configuration.models = [ + CustomModelConfiguration(model="gpt-4o", model_type=ModelType.LLM, credentials={"api_key": "model-key"}) + ] + assert configuration.is_custom_configuration_available() + + session = Mock() + provider_record = SimpleNamespace(id="provider-1") + session.execute.return_value.scalar_one_or_none.return_value = provider_record + assert configuration._get_provider_record(session) is provider_record + + session.execute.return_value.scalar_one_or_none.return_value = None + assert configuration._get_provider_record(session) is None + + +def test_check_provider_credential_name_exists_and_model_setting_lookup() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = "existing-id" + assert configuration._check_provider_credential_name_exists("Main", session) + + session.execute.return_value.scalar_one_or_none.return_value = None + assert not configuration._check_provider_credential_name_exists("Main", session, exclude_id="cred-2") + + setting = SimpleNamespace(id="setting-1") + session.execute.return_value.scalars.return_value.first.return_value = setting + assert configuration._get_provider_model_setting(ModelType.LLM, "gpt-4o", session) is setting + + +def test_validate_provider_credentials_handles_invalid_original_json() -> None: + configuration = _build_provider_configuration() + configuration.provider.provider_credential_schema = _build_secret_provider_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json") + mock_factory = Mock() + mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"} + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + session=session, + ) + + assert validated == {"openai_api_key": "enc-key"} + + +def test_generate_next_api_key_name_returns_default_when_no_records() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalars.return_value.all.return_value = [] + + name = configuration._generate_next_api_key_name(session=session, query_factory=lambda: Mock()) + assert name == "API KEY 1" + + +def test_create_provider_credential_creates_provider_record_when_missing() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.flush.side_effect = lambda: None + + with _patched_session(session): + with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with patch.object( + ProviderConfiguration, + "_generate_provider_credential_name", + return_value="API KEY 2", + ): + with patch.object(ProviderConfiguration, "switch_preferred_provider_type") as mock_switch: + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.create_provider_credential({"api_key": "raw"}, None) + + assert session.add.call_count == 2 + session.commit.assert_called_once() + mock_cache.return_value.delete.assert_called_once() + mock_switch.assert_called_once_with(provider_type=ProviderType.CUSTOM, session=session) + + +def test_create_provider_credential_marks_existing_provider_as_valid() -> None: + configuration = _build_provider_configuration() + session = Mock() + provider_record = SimpleNamespace(is_valid=False) + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False): + with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + configuration.create_provider_credential({"api_key": "raw"}, "Main") + + assert provider_record.is_valid is True + session.commit.assert_called_once() + + +def test_create_provider_credential_raises_when_duplicate_name_exists() -> None: + configuration = _build_provider_configuration() + session = Mock() + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=True): + with pytest.raises(ValueError, match="already exists"): + configuration.create_provider_credential({"api_key": "raw"}, "Main") + + +def test_update_provider_credential_success_updates_and_invalidates_cache() -> None: + configuration = _build_provider_configuration() + session = Mock() + credential_record = SimpleNamespace(id="cred-1", encrypted_config="{}", credential_name="Old", updated_at=None) + provider_record = SimpleNamespace(id="provider-1", credential_id="cred-1") + session.execute.return_value.scalar_one_or_none.return_value = credential_record + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False): + with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with patch.object( + ProviderConfiguration, + "_update_load_balancing_configs_with_credential", + ) as mock_lb: + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.update_provider_credential( + credentials={"api_key": "raw"}, + credential_id="cred-1", + credential_name="New Name", + ) + + assert credential_record.credential_name == "New Name" + session.commit.assert_called_once() + mock_cache.return_value.delete.assert_called_once() + mock_lb.assert_called_once() + + +def test_update_provider_credential_raises_when_record_not_found() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + + with _patched_session(session): + with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with pytest.raises(ValueError, match="Credential record not found"): + configuration.update_provider_credential({"api_key": "raw"}, "cred-1", None) + + +def test_update_load_balancing_configs_updates_all_matching_configs() -> None: + configuration = _build_provider_configuration() + session = Mock() + lb_config = SimpleNamespace(id="lb-1", encrypted_config="old", name="old", updated_at=None) + session.execute.return_value.scalars.return_value.all.return_value = [lb_config] + credential_record = SimpleNamespace(encrypted_config='{"api_key":"enc"}', credential_name="API KEY 3") + + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration._update_load_balancing_configs_with_credential( + credential_id="cred-1", + credential_record=credential_record, + credential_source="provider", + session=session, + ) + + assert lb_config.encrypted_config == '{"api_key":"enc"}' + assert lb_config.name == "API KEY 3" + mock_cache.return_value.delete.assert_called_once() + session.commit.assert_called_once() + + +def test_update_load_balancing_configs_returns_when_no_matching_configs() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalars.return_value.all.return_value = [] + + configuration._update_load_balancing_configs_with_credential( + credential_id="cred-1", + credential_record=SimpleNamespace(encrypted_config="{}", credential_name="Main"), + credential_source="provider", + session=session, + ) + + session.commit.assert_not_called() + + +def test_delete_provider_credential_removes_provider_record_when_last_credential() -> None: + configuration = _build_provider_configuration() + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + provider_record = SimpleNamespace(id="provider-1", credential_id="cred-1", updated_at=None) + session.execute.side_effect = [ + _exec_result(scalar_one_or_none=credential_record), + _exec_result(scalars_all=[]), + _exec_result(scalar=1), + ] + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with patch.object(ProviderConfiguration, "switch_preferred_provider_type") as mock_switch: + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.delete_provider_credential("cred-1") + + assert any(call.args and call.args[0] is provider_record for call in session.delete.call_args_list) + mock_cache.return_value.delete.assert_called_once() + mock_switch.assert_called_once_with(provider_type=ProviderType.SYSTEM, session=session) + + +def test_delete_provider_credential_raises_when_not_found() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + + with _patched_session(session): + with pytest.raises(ValueError, match="Credential record not found"): + configuration.delete_provider_credential("cred-1") + + +def test_delete_provider_credential_unsets_active_credential_when_more_available() -> None: + configuration = _build_provider_configuration() + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + lb_config = SimpleNamespace(id="lb-1") + provider_record = SimpleNamespace(id="provider-1", credential_id="cred-1", updated_at=None) + session.execute.side_effect = [ + _exec_result(scalar_one_or_none=credential_record), + _exec_result(scalars_all=[lb_config]), + _exec_result(scalar=2), + ] + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with patch.object(ProviderConfiguration, "switch_preferred_provider_type") as mock_switch: + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.delete_provider_credential("cred-1") + + assert provider_record.credential_id is None + assert mock_cache.return_value.delete.call_count == 2 + mock_switch.assert_called_once_with(provider_type=ProviderType.SYSTEM, session=session) + + +def test_switch_active_provider_credential_success_and_failures() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + with _patched_session(session): + with pytest.raises(ValueError, match="Credential record not found"): + configuration.switch_active_provider_credential("cred-1") + + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(id="cred-1") + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with pytest.raises(ValueError, match="Provider record not found"): + configuration.switch_active_provider_credential("cred-1") + + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + provider_record = SimpleNamespace(id="provider-1", credential_id=None, updated_at=None) + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with patch.object(ProviderConfiguration, "switch_preferred_provider_type") as mock_switch: + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.switch_active_provider_credential("cred-1") + + assert provider_record.credential_id == "cred-1" + mock_cache.return_value.delete.assert_called_once() + mock_switch.assert_called_once_with(ProviderType.CUSTOM, session=session) + + +def test_get_custom_model_record_supports_plugin_id_alias() -> None: + configuration = _build_provider_configuration(provider_name="langgenius/openai/openai") + session = Mock() + custom_model_record = SimpleNamespace(id="model-1") + session.execute.return_value.scalar_one_or_none.return_value = custom_model_record + + result = configuration._get_custom_model_record(ModelType.LLM, "gpt-4o", session) + assert result is custom_model_record + + +def test_get_specific_custom_model_credential_success_and_not_found() -> None: + configuration = _build_provider_configuration() + configuration.provider.model_credential_schema = _build_secret_model_schema() + session = Mock() + record = SimpleNamespace(id="cred-1", credential_name="Main", encrypted_config='{"openai_api_key":"enc"}') + session.execute.return_value.scalar_one_or_none.return_value = record + + with _patched_session(session): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): + with patch.object(ProviderConfiguration, "obfuscated_credentials", return_value={"openai_api_key": "***"}): + response = configuration._get_specific_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + assert response["current_credential_id"] == "cred-1" + assert response["credentials"] == {"openai_api_key": "***"} + + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + with _patched_session(session): + with pytest.raises(ValueError, match="Credential with id cred-1 not found"): + configuration._get_specific_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + id="cred-1", + credential_name="Main", + encrypted_config="{invalid-json", + ) + with _patched_session(session): + invalid_json = configuration._get_specific_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + assert invalid_json["credentials"] == {} + + +def test_check_custom_model_credential_name_exists_respects_exclusion() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(id="cred-1") + assert configuration._check_custom_model_credential_name_exists( + ModelType.LLM, "gpt-4o", "Main", session, exclude_id="other-id" + ) + + session.execute.return_value.scalar_one_or_none.return_value = None + assert not configuration._check_custom_model_credential_name_exists(ModelType.LLM, "gpt-4o", "Main", session) + + +def test_get_custom_model_credential_uses_specific_id_or_configuration_fallback() -> None: + configuration = _build_provider_configuration() + with patch.object( + ProviderConfiguration, + "_get_specific_custom_model_credential", + return_value={"current_credential_id": "cred-1"}, + ) as mock_specific: + result = configuration.get_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + assert result == {"current_credential_id": "cred-1"} + mock_specific.assert_called_once() + + configuration.provider.model_credential_schema = _build_secret_model_schema() + configuration.custom_configuration.models = [ + CustomModelConfiguration( + model="gpt-4o", + model_type=ModelType.LLM, + credentials={"openai_api_key": "raw"}, + current_credential_id="cred-1", + current_credential_name="Main", + ) + ] + with patch.object(ProviderConfiguration, "obfuscated_credentials", return_value={"openai_api_key": "***"}): + fallback = configuration.get_custom_model_credential(ModelType.LLM, "gpt-4o", None) + assert fallback == { + "current_credential_id": "cred-1", + "current_credential_name": "Main", + "credentials": {"openai_api_key": "***"}, + } + + configuration.custom_configuration.models = [] + assert configuration.get_custom_model_credential(ModelType.LLM, "gpt-4o", None) is None + + +def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless_path() -> None: + configuration = _build_provider_configuration() + configuration.provider.model_credential_schema = _build_secret_model_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config='{"openai_api_key":"enc"}' + ) + mock_factory = Mock() + mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + session=session, + ) + assert validated == {"openai_api_key": "enc-new"} + + session = Mock() + mock_factory = Mock() + mock_factory.model_credentials_validate.return_value = {"region": "us"} + with _patched_session(session): + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"region": "us"}, + ) + assert validated == {"region": "us"} + + +def test_create_update_delete_custom_model_credential_flow() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.flush.side_effect = lambda: None + provider_model_record = SimpleNamespace(id="model-1", credential_id="cred-1", updated_at=None) + credential_record = SimpleNamespace(id="cred-1", encrypted_config="{}", credential_name="Old", updated_at=None) + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_generate_custom_model_credential_name", return_value="API KEY 1"): + with patch.object( + ProviderConfiguration, + "validate_custom_model_credentials", + return_value={"openai_api_key": "enc"}, + ): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.create_custom_model_credential(ModelType.LLM, "gpt-4o", {"k": "v"}, None) + assert session.add.call_count == 2 + assert mock_cache.return_value.delete.call_count == 1 + + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_custom_model_credential_name_exists", return_value=False): + with patch.object( + ProviderConfiguration, + "validate_custom_model_credentials", + return_value={"openai_api_key": "enc2"}, + ): + with patch.object( + ProviderConfiguration, + "_get_custom_model_record", + return_value=provider_model_record, + ): + with patch.object( + ProviderConfiguration, + "_update_load_balancing_configs_with_credential", + ) as mock_lb: + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.update_custom_model_credential( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"k": "v"}, + credential_name="New Name", + credential_id="cred-1", + ) + assert credential_record.credential_name == "New Name" + assert mock_cache.return_value.delete.call_count == 1 + mock_lb.assert_called_once() + + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + lb_config = SimpleNamespace(id="lb-1") + provider_model_record = SimpleNamespace(id="model-1", credential_id="cred-1", updated_at=None) + session.execute.side_effect = [ + _exec_result(scalar_one_or_none=credential_record), + _exec_result(scalars_all=[lb_config]), + _exec_result(scalar=2), + ] + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=provider_model_record): + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.delete_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + assert provider_model_record.credential_id is None + assert mock_cache.return_value.delete.call_count == 2 + + +def test_add_model_credential_to_model_and_switch_custom_model_credential() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + with _patched_session(session): + with pytest.raises(ValueError, match="Credential record not found"): + configuration.add_model_credential_to_model(ModelType.LLM, "gpt-4o", "cred-1") + + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + configuration.add_model_credential_to_model(ModelType.LLM, "gpt-4o", "cred-1") + session.add.assert_called_once() + session.commit.assert_called_once() + + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + provider_model_record = SimpleNamespace(id="model-1", credential_id="cred-1", updated_at=None) + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=provider_model_record): + with pytest.raises(ValueError, match="Can't add same credential"): + configuration.add_model_credential_to_model(ModelType.LLM, "gpt-4o", "cred-1") + + session = Mock() + credential_record = SimpleNamespace(id="cred-2") + provider_model_record = SimpleNamespace(id="model-1", credential_id="cred-1", updated_at=None) + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=provider_model_record): + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.add_model_credential_to_model(ModelType.LLM, "gpt-4o", "cred-2") + assert provider_model_record.credential_id == "cred-2" + mock_cache.return_value.delete.assert_called_once() + + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + with _patched_session(session): + with pytest.raises(ValueError, match="Credential record not found"): + configuration.switch_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + with pytest.raises(ValueError, match="custom model record not found"): + configuration.switch_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + provider_model_record = SimpleNamespace(id="model-1", credential_id=None, updated_at=None) + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=provider_model_record): + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.switch_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + assert provider_model_record.credential_id == "cred-1" + mock_cache.return_value.delete.assert_called_once() + + +def test_delete_custom_model_and_model_setting_methods() -> None: + configuration = _build_provider_configuration() + session = Mock() + provider_model_record = SimpleNamespace(id="model-1") + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=provider_model_record): + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.delete_custom_model(ModelType.LLM, "gpt-4o") + session.delete.assert_called_once_with(provider_model_record) + session.commit.assert_called_once() + mock_cache.return_value.delete.assert_called_once() + + session = Mock() + existing = SimpleNamespace(enabled=False, updated_at=None) + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=existing): + assert configuration.enable_model(ModelType.LLM, "gpt-4o") is existing + assert existing.enabled is True + + session = Mock() + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=None): + created = configuration.enable_model(ModelType.LLM, "gpt-4o") + assert created.enabled is True + + session = Mock() + existing = SimpleNamespace(enabled=True, load_balancing_enabled=True, updated_at=None) + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=existing): + assert configuration.disable_model(ModelType.LLM, "gpt-4o") is existing + assert existing.enabled is False + + session = Mock() + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=None): + created = configuration.disable_model(ModelType.LLM, "gpt-4o") + assert created.enabled is False + + session = Mock() + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=existing): + result = configuration.get_provider_model_setting(ModelType.LLM, "gpt-4o") + assert result is existing + + +def test_model_load_balancing_enable_disable_and_switch_preferred_provider_type_without_session() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar.return_value = 1 + with _patched_session(session): + with pytest.raises(ValueError, match="must be more than 1"): + configuration.enable_model_load_balancing(ModelType.LLM, "gpt-4o") + + session = Mock() + session.execute.return_value.scalar.return_value = 2 + existing = SimpleNamespace(load_balancing_enabled=False, updated_at=None) + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=existing): + result = configuration.enable_model_load_balancing(ModelType.LLM, "gpt-4o") + assert result is existing + assert existing.load_balancing_enabled is True + + session = Mock() + session.execute.return_value.scalar.return_value = 2 + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=None): + created = configuration.enable_model_load_balancing(ModelType.LLM, "gpt-4o") + assert created.load_balancing_enabled is True + + session = Mock() + existing = SimpleNamespace(load_balancing_enabled=True, updated_at=None) + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=existing): + result = configuration.disable_model_load_balancing(ModelType.LLM, "gpt-4o") + assert result is existing + assert existing.load_balancing_enabled is False + + session = Mock() + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=None): + created = configuration.disable_model_load_balancing(ModelType.LLM, "gpt-4o") + assert created.load_balancing_enabled is False + + configuration.preferred_provider_type = ProviderType.SYSTEM + switch_session = Mock() + with _patched_session(switch_session): + switch_session.execute.return_value.scalars.return_value.first.return_value = None + configuration.switch_preferred_provider_type(ProviderType.CUSTOM) + assert any( + call.args and call.args[0].__class__.__name__ == "TenantPreferredModelProvider" + for call in switch_session.add.call_args_list + ) + switch_session.commit.assert_called() + + +def test_system_and_custom_provider_model_helpers_cover_remaining_skip_paths() -> None: + configuration = _build_provider_configuration() + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.CUSTOMIZABLE_MODEL], + models=[_build_ai_model("llm-model")], + ) + configuration.system_configuration.quota_configurations = [ + QuotaConfiguration( + quota_type=ProviderQuotaType.FREE, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + quota_used=0, + is_valid=True, + restrict_models=[ + RestrictModel(model="target", base_model_name="base", model_type=ModelType.LLM), + ], + ), + QuotaConfiguration( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + quota_used=0, + is_valid=True, + restrict_models=[ + RestrictModel(model="target", base_model_name="base", model_type=ModelType.LLM), + RestrictModel(model="error-model", base_model_name="base", model_type=ModelType.LLM), + RestrictModel(model="none-model", base_model_name="base", model_type=ModelType.LLM), + RestrictModel( + model="embed-model", + base_model_name="base", + model_type=ModelType.TEXT_EMBEDDING, + ), + ], + ), + ] + configuration.system_configuration.current_quota_type = ProviderQuotaType.TRIAL + + def _system_schema(*, model_type: ModelType, model: str, credentials: dict | None): + if model == "error-model": + raise RuntimeError("boom") + if model == "none-model": + return None + if model == "embed-model": + return _build_ai_model("embed-model", model_type=ModelType.TEXT_EMBEDDING) + return _build_ai_model("target") + + with patch( + "core.entities.provider_configuration.original_provider_configurate_methods", + {"openai": [ConfigurateMethod.CUSTOMIZABLE_MODEL]}, + ): + with patch.object(ProviderConfiguration, "get_model_schema", side_effect=_system_schema): + system_models = configuration._get_system_provider_models( + model_types=[ModelType.LLM], + provider_schema=provider_schema, + model_setting_map={ + ModelType.LLM: { + "target": ModelSettings( + model="target", + model_type=ModelType.LLM, + enabled=False, + load_balancing_configs=[], + ) + } + }, + ) + assert any(model.model == "target" and model.status == ModelStatus.DISABLED for model in system_models) + + configuration.using_provider_type = ProviderType.CUSTOM + configuration.custom_configuration.provider = CustomProviderConfiguration(credentials={"api_key": "provider-key"}) + configuration.custom_configuration.models = [ + CustomModelConfiguration( + model="skip-model-type", + model_type=ModelType.TEXT_EMBEDDING, + credentials={"k": "v"}, + ), + CustomModelConfiguration( + model="skip-unadded", + model_type=ModelType.LLM, + credentials={"k": "v"}, + unadded_to_model_list=True, + ), + CustomModelConfiguration( + model="skip-filter", + model_type=ModelType.LLM, + credentials={"k": "v"}, + ), + CustomModelConfiguration( + model="error-custom", + model_type=ModelType.LLM, + credentials={"k": "v"}, + ), + CustomModelConfiguration( + model="none-custom", + model_type=ModelType.LLM, + credentials={"k": "v"}, + ), + CustomModelConfiguration( + model="disabled-custom", + model_type=ModelType.LLM, + credentials={"k": "v"}, + ), + ] + + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=[_build_ai_model("base-disabled")], + ) + model_setting_map = { + ModelType.LLM: { + "base-disabled": ModelSettings( + model="base-disabled", + model_type=ModelType.LLM, + enabled=False, + load_balancing_enabled=True, + load_balancing_configs=[ModelLoadBalancingConfiguration(id="lb-1", name="lb", credentials={})], + ), + "disabled-custom": ModelSettings( + model="disabled-custom", + model_type=ModelType.LLM, + enabled=False, + load_balancing_enabled=False, + load_balancing_configs=[], + ), + } + } + + def _custom_schema(*, model_type: ModelType, model: str, credentials: dict | None): + if model == "error-custom": + raise RuntimeError("boom") + if model == "none-custom": + return None + return _build_ai_model(model) + + with patch.object(ProviderConfiguration, "get_model_schema", side_effect=_custom_schema): + custom_models = configuration._get_custom_provider_models( + model_types=[ModelType.LLM], + provider_schema=provider_schema, + model_setting_map=model_setting_map, + model="disabled-custom", + ) + assert any(model.model == "base-disabled" and model.status == ModelStatus.DISABLED for model in custom_models) + assert any(model.model == "disabled-custom" and model.status == ModelStatus.DISABLED for model in custom_models) + + +def test_get_current_credentials_skips_non_current_quota_restrictions() -> None: + configuration = _build_provider_configuration() + configuration.system_configuration.current_quota_type = ProviderQuotaType.TRIAL + configuration.system_configuration.quota_configurations = [ + QuotaConfiguration( + quota_type=ProviderQuotaType.FREE, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + quota_used=0, + is_valid=True, + restrict_models=[ + RestrictModel(model="gpt-4o", base_model_name="free-base", model_type=ModelType.LLM), + ], + ), + QuotaConfiguration( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + quota_used=0, + is_valid=True, + restrict_models=[ + RestrictModel(model="gpt-4o", base_model_name="trial-base", model_type=ModelType.LLM), + ], + ), + ] + + credentials = configuration.get_current_credentials(ModelType.LLM, "gpt-4o") + assert credentials["base_model_name"] == "trial-base" + + +def test_get_system_configuration_status_covers_disabled_and_quota_exceeded() -> None: + configuration = _build_provider_configuration() + configuration.system_configuration.enabled = False + assert configuration.get_system_configuration_status() == SystemConfigurationStatus.UNSUPPORTED + + configuration.system_configuration.enabled = True + configuration.system_configuration.quota_configurations = [ + QuotaConfiguration( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + quota_used=100, + is_valid=False, + restrict_models=[], + ) + ] + configuration.system_configuration.current_quota_type = ProviderQuotaType.TRIAL + assert configuration.get_system_configuration_status() == SystemConfigurationStatus.QUOTA_EXCEEDED + + +def test_get_specific_provider_credential_decrypts_and_obfuscates_credentials() -> None: + configuration = _build_provider_configuration() + configuration.provider.provider_credential_schema = _build_secret_provider_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config='{"openai_api_key":"enc-secret","region":"us"}' + ) + provider_record = SimpleNamespace(provider_name="aliased-openai") + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw-secret"): + with patch.object( + ProviderConfiguration, + "obfuscated_credentials", + side_effect=lambda credentials, credential_form_schemas: credentials, + ): + credentials = configuration._get_specific_provider_credential("cred-1") + + assert credentials == {"openai_api_key": "raw-secret", "region": "us"} + + +def test_get_specific_provider_credential_logs_when_decrypt_fails() -> None: + configuration = _build_provider_configuration() + configuration.provider.provider_credential_schema = _build_secret_provider_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config='{"openai_api_key":"enc-secret"}' + ) + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with patch( + "core.entities.provider_configuration.encrypter.decrypt_token", + side_effect=RuntimeError("boom"), + ): + with patch("core.entities.provider_configuration.logger.exception") as mock_logger: + with patch.object( + ProviderConfiguration, + "obfuscated_credentials", + side_effect=lambda credentials, credential_form_schemas: credentials, + ): + credentials = configuration._get_specific_provider_credential("cred-1") + + assert credentials == {"openai_api_key": "enc-secret"} + mock_logger.assert_called_once() + + +def test_validate_provider_credentials_uses_empty_original_when_record_missing() -> None: + configuration = _build_provider_configuration() + configuration.provider.provider_credential_schema = _build_secret_provider_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + mock_factory = Mock() + mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"} + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + session=session, + ) + + assert validated == {"openai_api_key": "enc-new"} + + +def test_create_provider_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.add.side_effect = RuntimeError("boom") + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_generate_provider_credential_name", return_value="API KEY 9"): + with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with pytest.raises(RuntimeError, match="boom"): + configuration.create_provider_credential({"api_key": "raw"}, None) + + session.rollback.assert_called_once() + + +def test_update_provider_credential_raises_on_duplicate_name() -> None: + configuration = _build_provider_configuration() + session = Mock() + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=True): + with pytest.raises(ValueError, match="already exists"): + configuration.update_provider_credential({"api_key": "raw"}, "cred-1", "Main") + + +def test_update_provider_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + id="cred-1", + encrypted_config="{}", + credential_name="Main", + updated_at=None, + ) + session.commit.side_effect = RuntimeError("boom") + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False): + with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with pytest.raises(RuntimeError, match="boom"): + configuration.update_provider_credential({"api_key": "raw"}, "cred-1", "Main") + + session.rollback.assert_called_once() + + +def test_delete_provider_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.delete.side_effect = RuntimeError("boom") + session.execute.side_effect = [ + _exec_result(scalar_one_or_none=SimpleNamespace(id="cred-1")), + _exec_result(scalars_all=[]), + _exec_result(scalar=2), + ] + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with pytest.raises(RuntimeError, match="boom"): + configuration.delete_provider_credential("cred-1") + + session.rollback.assert_called_once() + + +def test_switch_active_provider_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(id="cred-1") + session.commit.side_effect = RuntimeError("boom") + provider_record = SimpleNamespace(id="provider-1", credential_id=None, updated_at=None) + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with pytest.raises(RuntimeError, match="boom"): + configuration.switch_active_provider_credential("cred-1") + + session.rollback.assert_called_once() + + +def test_get_specific_custom_model_credential_logs_when_decrypt_fails() -> None: + configuration = _build_provider_configuration() + configuration.provider.model_credential_schema = _build_secret_model_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + id="cred-1", + credential_name="Main", + encrypted_config='{"openai_api_key":"enc-secret"}', + ) + + with _patched_session(session): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", side_effect=RuntimeError("boom")): + with patch("core.entities.provider_configuration.logger.exception") as mock_logger: + with patch.object( + ProviderConfiguration, + "obfuscated_credentials", + side_effect=lambda credentials, credential_form_schemas: credentials, + ): + result = configuration._get_specific_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + assert result["credentials"] == {"openai_api_key": "enc-secret"} + mock_logger.assert_called_once() + + +def test_validate_custom_model_credentials_handles_invalid_original_json() -> None: + configuration = _build_provider_configuration() + configuration.provider.model_credential_schema = _build_secret_model_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json") + mock_factory = Mock() + mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + session=session, + ) + + assert validated == {"openai_api_key": "enc-new"} + + +def test_create_custom_model_credential_raises_on_duplicate_name() -> None: + configuration = _build_provider_configuration() + session = Mock() + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_custom_model_credential_name_exists", return_value=True): + with pytest.raises(ValueError, match="already exists"): + configuration.create_custom_model_credential(ModelType.LLM, "gpt-4o", {"k": "v"}, "Main") + + +def test_create_custom_model_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.add.side_effect = RuntimeError("boom") + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_generate_custom_model_credential_name", return_value="API KEY 4"): + with patch.object( + ProviderConfiguration, + "validate_custom_model_credentials", + return_value={"openai_api_key": "enc"}, + ): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + with pytest.raises(RuntimeError, match="boom"): + configuration.create_custom_model_credential(ModelType.LLM, "gpt-4o", {"k": "v"}, None) + + session.rollback.assert_called_once() + + +def test_update_custom_model_credential_raises_on_duplicate_name() -> None: + configuration = _build_provider_configuration() + session = Mock() + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_custom_model_credential_name_exists", return_value=True): + with pytest.raises(ValueError, match="already exists"): + configuration.update_custom_model_credential( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"k": "v"}, + credential_name="Main", + credential_id="cred-1", + ) + + +def test_update_custom_model_credential_raises_when_record_not_found() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_custom_model_credential_name_exists", return_value=False): + with patch.object(ProviderConfiguration, "validate_custom_model_credentials", return_value={"k": "v"}): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + with pytest.raises(ValueError, match="Credential record not found"): + configuration.update_custom_model_credential( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"k": "v"}, + credential_name="Main", + credential_id="cred-1", + ) + + +def test_update_custom_model_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + id="cred-1", + encrypted_config="{}", + credential_name="Main", + updated_at=None, + ) + session.commit.side_effect = RuntimeError("boom") + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_custom_model_credential_name_exists", return_value=False): + with patch.object(ProviderConfiguration, "validate_custom_model_credentials", return_value={"k": "v"}): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + with pytest.raises(RuntimeError, match="boom"): + configuration.update_custom_model_credential( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"k": "v"}, + credential_name="Main", + credential_id="cred-1", + ) + + session.rollback.assert_called_once() + + +def test_delete_custom_model_credential_raises_when_record_not_found() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + + with _patched_session(session): + with pytest.raises(ValueError, match="Credential record not found"): + configuration.delete_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + +def test_delete_custom_model_credential_removes_custom_model_record_when_last_credential() -> None: + configuration = _build_provider_configuration() + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + provider_model_record = SimpleNamespace(id="model-1", credential_id="cred-1", updated_at=None) + session.execute.side_effect = [ + _exec_result(scalar_one_or_none=credential_record), + _exec_result(scalars_all=[]), + _exec_result(scalar=1), + ] + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=provider_model_record): + configuration.delete_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + assert any(call.args and call.args[0] is provider_model_record for call in session.delete.call_args_list) + + +def test_delete_custom_model_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.delete.side_effect = RuntimeError("boom") + session.execute.side_effect = [ + _exec_result(scalar_one_or_none=SimpleNamespace(id="cred-1")), + _exec_result(scalars_all=[]), + _exec_result(scalar=2), + ] + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + with pytest.raises(RuntimeError, match="boom"): + configuration.delete_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + session.rollback.assert_called_once() + + +def test_get_custom_provider_models_skips_schema_models_with_mismatched_type() -> None: + configuration = _build_provider_configuration() + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM, ModelType.TEXT_EMBEDDING], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=[ + _build_ai_model("llm-model", model_type=ModelType.LLM), + _build_ai_model("embed-model", model_type=ModelType.TEXT_EMBEDDING), + ], + ) + + models = configuration._get_custom_provider_models( + model_types=[ModelType.LLM], + provider_schema=provider_schema, + model_setting_map={}, + ) + + assert any(model.model == "llm-model" for model in models) + assert all(model.model != "embed-model" for model in models) + + +def test_get_custom_provider_models_skips_custom_models_on_schema_error_or_none() -> None: + configuration = _build_provider_configuration() + configuration.custom_configuration.models = [ + CustomModelConfiguration(model="error-custom", model_type=ModelType.LLM, credentials={"k": "v"}), + CustomModelConfiguration(model="none-custom", model_type=ModelType.LLM, credentials={"k": "v"}), + CustomModelConfiguration(model="ok-custom", model_type=ModelType.LLM, credentials={"k": "v"}), + ] + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=[], + ) + + def _schema(*, model_type: ModelType, model: str, credentials: dict | None): + if model == "error-custom": + raise RuntimeError("boom") + if model == "none-custom": + return None + return _build_ai_model(model) + + with patch("core.entities.provider_configuration.logger.warning") as mock_warning: + with patch.object(ProviderConfiguration, "get_model_schema", side_effect=_schema): + models = configuration._get_custom_provider_models( + model_types=[ModelType.LLM], + provider_schema=provider_schema, + model_setting_map={}, + ) + + assert mock_warning.call_count == 1 + assert any(model.model == "ok-custom" for model in models) + assert all(model.model != "none-custom" for model in models) diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py new file mode 100644 index 0000000000..c5bfd05a1e --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py @@ -0,0 +1,72 @@ +import pytest + +from core.entities.parameter_entities import AppSelectorScope +from core.entities.provider_entities import ( + BasicProviderConfig, + ModelSettings, + ProviderConfig, + ProviderQuotaType, +) +from core.tools.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType + + +def test_provider_quota_type_value_of_returns_enum_member() -> None: + # Arrange / Act + quota_type = ProviderQuotaType.value_of(ProviderQuotaType.TRIAL.value) + + # Assert + assert quota_type == ProviderQuotaType.TRIAL + + +def test_provider_quota_type_value_of_rejects_unknown_values() -> None: + # Arrange / Act / Assert + with pytest.raises(ValueError, match="No matching enum found"): + ProviderQuotaType.value_of("enterprise") + + +def test_basic_provider_config_type_value_of_handles_known_values() -> None: + # Arrange / Act + parameter_type = BasicProviderConfig.Type.value_of("text-input") + + # Assert + assert parameter_type == BasicProviderConfig.Type.TEXT_INPUT + + +def test_basic_provider_config_type_value_of_rejects_invalid_values() -> None: + # Arrange / Act / Assert + with pytest.raises(ValueError, match="invalid mode value"): + BasicProviderConfig.Type.value_of("unknown") + + +def test_provider_config_to_basic_provider_config_keeps_type_and_name() -> None: + # Arrange + provider_config = ProviderConfig( + type=BasicProviderConfig.Type.SELECT, + name="workspace", + scope=AppSelectorScope.ALL, + options=[ProviderConfig.Option(value="all", label=I18nObject(en_US="All"))], + ) + + # Act + basic_config = provider_config.to_basic_provider_config() + + # Assert + assert isinstance(basic_config, BasicProviderConfig) + assert basic_config.type == BasicProviderConfig.Type.SELECT + assert basic_config.name == "workspace" + + +def test_model_settings_accepts_model_field_name() -> None: + # Arrange / Act + settings = ModelSettings( + model="gpt-4o", + model_type=ModelType.LLM, + enabled=True, + load_balancing_enabled=False, + load_balancing_configs=[], + ) + + # Assert + assert settings.model == "gpt-4o" + assert settings.model_type == ModelType.LLM