From 245f6b824d7278704c5845929f6abe509931656b Mon Sep 17 00:00:00 2001 From: mahammadasim <135003320+mahammadasim@users.noreply.github.com> Date: Thu, 12 Mar 2026 09:14:38 +0530 Subject: [PATCH] test: add test for core extension, external_data_tool and llm generator (#32468) --- api/core/llm_generator/llm_generator.py | 6 +- .../test_api_based_extension_requestor.py | 137 ++++ .../core/extension/test_extensible.py | 281 +++++++++ .../core/extension/test_extension.py | 90 +++ .../core/external_data_tool/api/test_api.py | 145 +++++ .../core/external_data_tool/test_base.py | 66 ++ .../test_external_data_fetch.py | 115 ++++ .../core/external_data_tool/test_factory.py | 58 ++ .../test_rule_config_generator.py | 103 +++ .../output_parser/test_structured_output.py | 402 ++++++++++++ .../core/llm_generator/test_llm_generator.py | 589 ++++++++++++++++++ 11 files changed, 1990 insertions(+), 2 deletions(-) create mode 100644 api/tests/unit_tests/core/extension/test_api_based_extension_requestor.py create mode 100644 api/tests/unit_tests/core/extension/test_extensible.py create mode 100644 api/tests/unit_tests/core/extension/test_extension.py create mode 100644 api/tests/unit_tests/core/external_data_tool/api/test_api.py create mode 100644 api/tests/unit_tests/core/external_data_tool/test_base.py create mode 100644 api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py create mode 100644 api/tests/unit_tests/core/external_data_tool/test_factory.py create mode 100644 api/tests/unit_tests/core/llm_generator/output_parser/test_rule_config_generator.py create mode 100644 api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py create mode 100644 api/tests/unit_tests/core/llm_generator/test_llm_generator.py diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 6a09dbff35..c8848336d9 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -193,7 
+193,8 @@ class LLMGenerator: error_step = "generate rule config" except Exception as e: logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name) - rule_config["error"] = str(e) + error = str(e) + error_step = "generate rule config" rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" @@ -279,7 +280,8 @@ class LLMGenerator: except Exception as e: logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name) - rule_config["error"] = str(e) + error = str(e) + error_step = "handle unexpected exception" rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" diff --git a/api/tests/unit_tests/core/extension/test_api_based_extension_requestor.py b/api/tests/unit_tests/core/extension/test_api_based_extension_requestor.py new file mode 100644 index 0000000000..399b531205 --- /dev/null +++ b/api/tests/unit_tests/core/extension/test_api_based_extension_requestor.py @@ -0,0 +1,137 @@ +import httpx +import pytest + +from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor +from models.api_based_extension import APIBasedExtensionPoint + + +def test_request_success(mocker): + # Mock httpx.Client and its context manager + mock_client = mocker.MagicMock() + mock_client_instance = mock_client.__enter__.return_value + mocker.patch("httpx.Client", return_value=mock_client) + + mock_response = mocker.MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"result": "success"} + mock_client_instance.request.return_value = mock_response + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + result = requestor.request(APIBasedExtensionPoint.PING, {"foo": "bar"}) + + assert result == {"result": "success"} + mock_client_instance.request.assert_called_once_with( + method="POST", + url="http://example.com", + json={"point": APIBasedExtensionPoint.PING.value, "params": {"foo": 
"bar"}}, + headers={"Content-Type": "application/json", "Authorization": "Bearer test_key"}, + ) + + +def test_request_with_ssrf_proxy(mocker): + # Mock dify_config + mocker.patch("configs.dify_config.SSRF_PROXY_HTTP_URL", "http://proxy:8080") + mocker.patch("configs.dify_config.SSRF_PROXY_HTTPS_URL", "https://proxy:8081") + + # Mock httpx.Client + mock_client = mocker.MagicMock() + mock_client_class = mocker.patch("httpx.Client", return_value=mock_client) + mock_client_instance = mock_client.__enter__.return_value + + # Mock response + mock_response = mocker.MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"result": "success"} + mock_client_instance.request.return_value = mock_response + + # Mock HTTPTransport + mock_transport = mocker.patch("httpx.HTTPTransport") + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + requestor.request(APIBasedExtensionPoint.PING, {}) + + # Verify httpx.Client was called with mounts + mock_client_class.assert_called_once() + kwargs = mock_client_class.call_args.kwargs + assert "mounts" in kwargs + assert "http://" in kwargs["mounts"] + assert "https://" in kwargs["mounts"] + assert mock_transport.call_count == 2 + + +def test_request_with_only_one_proxy_config(mocker): + # Mock dify_config with only one proxy + mocker.patch("configs.dify_config.SSRF_PROXY_HTTP_URL", "http://proxy:8080") + mocker.patch("configs.dify_config.SSRF_PROXY_HTTPS_URL", None) + + # Mock httpx.Client + mock_client = mocker.MagicMock() + mock_client_class = mocker.patch("httpx.Client", return_value=mock_client) + mock_client_instance = mock_client.__enter__.return_value + + # Mock response + mock_response = mocker.MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"result": "success"} + mock_client_instance.request.return_value = mock_response + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + 
requestor.request(APIBasedExtensionPoint.PING, {}) + + # Verify httpx.Client was called with mounts=None (default) + mock_client_class.assert_called_once() + kwargs = mock_client_class.call_args.kwargs + assert kwargs.get("mounts") is None + + +def test_request_timeout(mocker): + mock_client = mocker.MagicMock() + mock_client_instance = mock_client.__enter__.return_value + mocker.patch("httpx.Client", return_value=mock_client) + mock_client_instance.request.side_effect = httpx.TimeoutException("timeout") + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + with pytest.raises(ValueError, match="request timeout"): + requestor.request(APIBasedExtensionPoint.PING, {}) + + +def test_request_connection_error(mocker): + mock_client = mocker.MagicMock() + mock_client_instance = mock_client.__enter__.return_value + mocker.patch("httpx.Client", return_value=mock_client) + mock_client_instance.request.side_effect = httpx.RequestError("error") + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + with pytest.raises(ValueError, match="request connection error"): + requestor.request(APIBasedExtensionPoint.PING, {}) + + +def test_request_error_status_code(mocker): + mock_client = mocker.MagicMock() + mock_client_instance = mock_client.__enter__.return_value + mocker.patch("httpx.Client", return_value=mock_client) + + mock_response = mocker.MagicMock() + mock_response.status_code = 404 + mock_response.text = "Not Found" + mock_client_instance.request.return_value = mock_response + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + with pytest.raises(ValueError, match="request error, status_code: 404, content: Not Found"): + requestor.request(APIBasedExtensionPoint.PING, {}) + + +def test_request_error_status_code_long_content(mocker): + mock_client = mocker.MagicMock() + mock_client_instance = mock_client.__enter__.return_value + 
mocker.patch("httpx.Client", return_value=mock_client) + + mock_response = mocker.MagicMock() + mock_response.status_code = 500 + mock_response.text = "A" * 200 # Testing truncation of content + mock_client_instance.request.return_value = mock_response + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + expected_content = "A" * 100 + with pytest.raises(ValueError, match=f"request error, status_code: 500, content: {expected_content}"): + requestor.request(APIBasedExtensionPoint.PING, {}) diff --git a/api/tests/unit_tests/core/extension/test_extensible.py b/api/tests/unit_tests/core/extension/test_extensible.py new file mode 100644 index 0000000000..9bce0cd7c8 --- /dev/null +++ b/api/tests/unit_tests/core/extension/test_extensible.py @@ -0,0 +1,281 @@ +import json +import types +from unittest.mock import MagicMock, mock_open, patch + +import pytest + +from core.extension.extensible import Extensible + + +class TestExtensible: + def test_init(self): + tenant_id = "tenant_123" + config = {"key": "value"} + ext = Extensible(tenant_id, config) + assert ext.tenant_id == tenant_id + assert ext.config == config + + @patch("core.extension.extensible.importlib.util.find_spec") + @patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + @patch("core.extension.extensible.os.path.exists") + @patch("core.extension.extensible.Path.read_text") + @patch("core.extension.extensible.importlib.util.module_from_spec") + @patch("core.extension.extensible.sort_to_dict_by_position_map") + def test_scan_extensions_success( + self, + mock_sort, + mock_module_from_spec, + mock_read_text, + mock_exists, + mock_isdir, + mock_listdir, + mock_dirname, + mock_find_spec, + ): + # Setup + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + + module_spec = MagicMock() + module_spec.loader = MagicMock() + + 
mock_find_spec.side_effect = [package_spec, module_spec] + mock_dirname.return_value = "/path/to/pkg" + + mock_listdir.side_effect = [ + ["ext1"], # package_dir + ["ext1.py", "__builtin__"], # subdir_path + ] + mock_isdir.return_value = True + + mock_exists.return_value = True + mock_read_text.return_value = "10" + + # Use types.ModuleType to avoid MagicMock __dict__ issues + mock_mod = types.ModuleType("ext1") + + class MockExtension(Extensible): + pass + + mock_mod.MockExtension = MockExtension + mock_module_from_spec.return_value = mock_mod + + mock_sort.side_effect = lambda position_map, data, name_func: data + + # Execute + results = Extensible.scan_extensions() + + # Assert + assert len(results) == 1 + assert results[0].name == "ext1" + assert results[0].position == 10 + assert results[0].builtin is True + assert results[0].extension_class == MockExtension + + @patch("core.extension.extensible.importlib.util.find_spec") + def test_scan_extensions_package_not_found(self, mock_find_spec): + mock_find_spec.return_value = None + with pytest.raises(ImportError, match="Could not find package"): + Extensible.scan_extensions() + + @patch("core.extension.extensible.importlib.util.find_spec") + @patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + def test_scan_extensions_skip_subdirs(self, mock_isdir, mock_listdir, mock_dirname, mock_find_spec): + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + mock_find_spec.return_value = package_spec + mock_dirname.return_value = "/path/to/pkg" + + mock_listdir.side_effect = [["__pycache__", "not_a_dir", "missing_py_file"], []] + + mock_isdir.side_effect = [False, True] + + with patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]): + results = Extensible.scan_extensions() + assert len(results) == 0 + + @patch("core.extension.extensible.importlib.util.find_spec") + 
@patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + @patch("core.extension.extensible.os.path.exists") + @patch("core.extension.extensible.importlib.util.module_from_spec") + def test_scan_extensions_not_builtin_success( + self, mock_module_from_spec, mock_exists, mock_isdir, mock_listdir, mock_dirname, mock_find_spec + ): + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + + module_spec = MagicMock() + module_spec.loader = MagicMock() + + mock_find_spec.side_effect = [package_spec, module_spec] + mock_dirname.return_value = "/path/to/pkg" + + mock_listdir.side_effect = [["ext1"], ["ext1.py", "schema.json"]] + mock_isdir.return_value = True + + # exists checks: only schema.json needs to exist + mock_exists.return_value = True + + mock_mod = types.ModuleType("ext1") + + class MockExtension(Extensible): + pass + + mock_mod.MockExtension = MockExtension + mock_module_from_spec.return_value = mock_mod + + schema_content = json.dumps({"label": {"en": "Test"}, "form_schema": [{"name": "field1"}]}) + + with ( + patch("builtins.open", mock_open(read_data=schema_content)), + patch( + "core.extension.extensible.sort_to_dict_by_position_map", + side_effect=lambda position_map, data, name_func: data, + ), + ): + results = Extensible.scan_extensions() + + assert len(results) == 1 + assert results[0].name == "ext1" + assert results[0].builtin is False + assert results[0].label == {"en": "Test"} + + @patch("core.extension.extensible.importlib.util.find_spec") + @patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + @patch("core.extension.extensible.os.path.exists") + @patch("core.extension.extensible.importlib.util.module_from_spec") + def test_scan_extensions_not_builtin_missing_schema( + self, mock_module_from_spec, mock_exists, mock_isdir, 
mock_listdir, mock_dirname, mock_find_spec + ): + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + + module_spec = MagicMock() + module_spec.loader = MagicMock() + + mock_find_spec.side_effect = [package_spec, module_spec] + mock_dirname.return_value = "/path/to/pkg" + + mock_listdir.side_effect = [["ext1"], ["ext1.py"]] + mock_isdir.return_value = True + + # exists: only schema.json checked, and return False + mock_exists.return_value = False + + mock_mod = types.ModuleType("ext1") + + class MockExtension(Extensible): + pass + + mock_mod.MockExtension = MockExtension + mock_module_from_spec.return_value = mock_mod + + with patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]): + results = Extensible.scan_extensions() + + assert len(results) == 0 + + @patch("core.extension.extensible.importlib.util.find_spec") + @patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + @patch("core.extension.extensible.importlib.util.module_from_spec") + @patch("core.extension.extensible.os.path.exists") + def test_scan_extensions_no_extension_class( + self, mock_exists, mock_module_from_spec, mock_isdir, mock_listdir, mock_dirname, mock_find_spec + ): + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + module_spec = MagicMock() + module_spec.loader = MagicMock() + + mock_find_spec.side_effect = [package_spec, module_spec] + mock_dirname.return_value = "/path/to/pkg" + + mock_listdir.side_effect = [["ext1"], ["ext1.py"]] + mock_isdir.return_value = True + + # Mock not builtin + mock_exists.return_value = False + + mock_mod = types.ModuleType("ext1") + mock_mod.SomeOtherClass = type("SomeOtherClass", (), {}) + mock_module_from_spec.return_value = mock_mod + + # We need to ensure we don't crash if checking schema (but we won't reach there because class not found) + + with 
patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]): + results = Extensible.scan_extensions() + + assert len(results) == 0 + + @patch("core.extension.extensible.importlib.util.find_spec") + @patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + def test_scan_extensions_module_import_error(self, mock_isdir, mock_listdir, mock_dirname, mock_find_spec): + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + + mock_find_spec.side_effect = [package_spec, None] # No module spec + mock_dirname.return_value = "/path/to/pkg" + + mock_listdir.side_effect = [["ext1"], ["ext1.py"]] + mock_isdir.return_value = True + + with pytest.raises(ImportError, match="Failed to load module"): + Extensible.scan_extensions() + + @patch("core.extension.extensible.importlib.util.find_spec") + def test_scan_extensions_general_exception(self, mock_find_spec): + mock_find_spec.side_effect = Exception("Unexpected error") + with pytest.raises(Exception, match="Unexpected error"): + Extensible.scan_extensions() + + @patch("core.extension.extensible.importlib.util.find_spec") + @patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + @patch("core.extension.extensible.os.path.exists") + @patch("core.extension.extensible.Path.read_text") + @patch("core.extension.extensible.importlib.util.module_from_spec") + def test_scan_extensions_builtin_without_position_file( + self, mock_module_from_spec, mock_read_text, mock_exists, mock_isdir, mock_listdir, mock_dirname, mock_find_spec + ): + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + module_spec = MagicMock() + module_spec.loader = MagicMock() + + mock_find_spec.side_effect = [package_spec, module_spec] + mock_dirname.return_value = "/path/to/pkg" + mock_listdir.side_effect 
= [["ext1"], ["ext1.py", "__builtin__"]] + mock_isdir.return_value = True + + # builtin exists in listdir, but os.path.exists(builtin_file_path) returns False + mock_exists.return_value = False + + mock_mod = types.ModuleType("ext1") + + class MockExtension(Extensible): + pass + + mock_mod.MockExtension = MockExtension + mock_module_from_spec.return_value = mock_mod + + with patch( + "core.extension.extensible.sort_to_dict_by_position_map", + side_effect=lambda position_map, data, name_func: data, + ): + results = Extensible.scan_extensions() + + assert len(results) == 1 + assert results[0].position == 0 diff --git a/api/tests/unit_tests/core/extension/test_extension.py b/api/tests/unit_tests/core/extension/test_extension.py new file mode 100644 index 0000000000..4ad32d3840 --- /dev/null +++ b/api/tests/unit_tests/core/extension/test_extension.py @@ -0,0 +1,90 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.extension.extensible import ExtensionModule, ModuleExtension +from core.extension.extension import Extension + + +class TestExtension: + def setup_method(self): + # Reset the private class attribute before each test + Extension._Extension__module_extensions = {} + + def test_init(self): + # Mock scan_extensions for Moderation and ExternalDataTool + mock_mod_extensions = {"mod1": ModuleExtension(name="mod1")} + mock_ext_extensions = {"ext1": ModuleExtension(name="ext1")} + + extension = Extension() + + # We need to mock scan_extensions on the classes defined in Extension.module_classes + with ( + patch("core.extension.extension.Moderation.scan_extensions", return_value=mock_mod_extensions), + patch("core.extension.extension.ExternalDataTool.scan_extensions", return_value=mock_ext_extensions), + ): + extension.init() + + # Check if internal state is updated + internal_state = Extension._Extension__module_extensions + assert internal_state[ExtensionModule.MODERATION.value] == mock_mod_extensions + assert 
internal_state[ExtensionModule.EXTERNAL_DATA_TOOL.value] == mock_ext_extensions + + def test_module_extensions_success(self): + # Setup data + mock_extensions = {"name1": ModuleExtension(name="name1"), "name2": ModuleExtension(name="name2")} + Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: mock_extensions} + + extension = Extension() + result = extension.module_extensions(ExtensionModule.MODERATION.value) + + assert len(result) == 2 + assert any(e.name == "name1" for e in result) + assert any(e.name == "name2" for e in result) + + def test_module_extensions_not_found(self): + extension = Extension() + with pytest.raises(ValueError, match="Extension Module unknown not found"): + extension.module_extensions("unknown") + + def test_module_extension_success(self): + mock_ext = ModuleExtension(name="test_ext") + Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}} + + extension = Extension() + result = extension.module_extension(ExtensionModule.MODERATION, "test_ext") + assert result == mock_ext + + def test_module_extension_module_not_found(self): + extension = Extension() + # ExtensionModule.MODERATION is "moderation" + with pytest.raises(ValueError, match="Extension Module moderation not found"): + extension.module_extension(ExtensionModule.MODERATION, "any") + + def test_module_extension_extension_not_found(self): + # We need a non-empty dict because 'if not module_extensions' in extension.py + # returns True for an empty dict, which raises the module not found error instead. 
+ Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"other": MagicMock()}} + + extension = Extension() + with pytest.raises(ValueError, match="Extension unknown not found"): + extension.module_extension(ExtensionModule.MODERATION, "unknown") + + def test_extension_class_success(self): + class MockClass: + pass + + mock_ext = ModuleExtension(name="test_ext", extension_class=MockClass) + Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}} + + extension = Extension() + result = extension.extension_class(ExtensionModule.MODERATION, "test_ext") + assert result == MockClass + + def test_extension_class_none(self): + mock_ext = ModuleExtension(name="test_ext", extension_class=None) + Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}} + + extension = Extension() + with pytest.raises(AssertionError): + extension.extension_class(ExtensionModule.MODERATION, "test_ext") diff --git a/api/tests/unit_tests/core/external_data_tool/api/test_api.py b/api/tests/unit_tests/core/external_data_tool/api/test_api.py new file mode 100644 index 0000000000..1653124bd8 --- /dev/null +++ b/api/tests/unit_tests/core/external_data_tool/api/test_api.py @@ -0,0 +1,145 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.external_data_tool.api.api import ApiExternalDataTool +from models.api_based_extension import APIBasedExtensionPoint + + +def test_api_external_data_tool_name(): + assert ApiExternalDataTool.name == "api" + + +@patch("core.external_data_tool.api.api.db") +def test_validate_config_success(mock_db): + mock_extension = MagicMock() + mock_extension.id = "ext_id" + mock_extension.tenant_id = "tenant_id" + mock_db.session.scalar.return_value = mock_extension + + # Should not raise exception + ApiExternalDataTool.validate_config("tenant_id", {"api_based_extension_id": "ext_id"}) + + +def test_validate_config_missing_id(): + with 
pytest.raises(ValueError, match="api_based_extension_id is required"): + ApiExternalDataTool.validate_config("tenant_id", {}) + + +@patch("core.external_data_tool.api.api.db") +def test_validate_config_invalid_id(mock_db): + mock_db.session.scalar.return_value = None + + with pytest.raises(ValueError, match="api_based_extension_id is invalid"): + ApiExternalDataTool.validate_config("tenant_id", {"api_based_extension_id": "ext_id"}) + + +@pytest.fixture +def api_tool(): + # Use standard kwargs as it inherits from ExternalDataTool which is typically a Pydantic BaseModel + return ApiExternalDataTool( + tenant_id="tenant_id", app_id="app_id", variable="var1", config={"api_based_extension_id": "ext_id"} + ) + + +@patch("core.external_data_tool.api.api.db") +@patch("core.external_data_tool.api.api.encrypter") +@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor") +def test_query_success(mock_requestor_class, mock_encrypter, mock_db, api_tool): + mock_extension = MagicMock() + mock_extension.id = "ext_id" + mock_extension.tenant_id = "tenant_id" + mock_extension.api_endpoint = "http://api" + mock_extension.api_key = "encrypted_key" + mock_db.session.scalar.return_value = mock_extension + mock_encrypter.decrypt_token.return_value = "decrypted_key" + + mock_requestor = mock_requestor_class.return_value + mock_requestor.request.return_value = {"result": "success_result"} + + res = api_tool.query({"input1": "value1"}, "query_str") + + assert res == "success_result" + + mock_requestor_class.assert_called_once_with(api_endpoint="http://api", api_key="decrypted_key") + mock_requestor.request.assert_called_once_with( + point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, + params={"app_id": "app_id", "tool_variable": "var1", "inputs": {"input1": "value1"}, "query": "query_str"}, + ) + + +def test_query_missing_config(): + api_tool = ApiExternalDataTool(tenant_id="tenant_id", app_id="app_id", variable="var1") + api_tool.config = None # Force None + with 
pytest.raises(ValueError, match="config is required"): + api_tool.query({}, "") + + +def test_query_missing_extension_id(): + api_tool = ApiExternalDataTool(tenant_id="tenant_id", app_id="app_id", variable="var1", config={"dummy": "value"}) + with pytest.raises(AssertionError, match="api_based_extension_id is required"): + api_tool.query({}, "") + + +@patch("core.external_data_tool.api.api.db") +def test_query_invalid_extension(mock_db, api_tool): + mock_db.session.scalar.return_value = None + + with pytest.raises(ValueError, match=".*error: api_based_extension_id is invalid"): + api_tool.query({}, "") + + +@patch("core.external_data_tool.api.api.db") +@patch("core.external_data_tool.api.api.encrypter") +@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor") +def test_query_requestor_init_error(mock_requestor_class, mock_encrypter, mock_db, api_tool): + mock_extension = MagicMock() + mock_extension.id = "ext_id" + mock_extension.tenant_id = "tenant_id" + mock_extension.api_endpoint = "http://api" + mock_extension.api_key = "encrypted_key" + mock_db.session.scalar.return_value = mock_extension + mock_encrypter.decrypt_token.return_value = "decrypted_key" + + mock_requestor_class.side_effect = Exception("init error") + + with pytest.raises(ValueError, match=".*error: init error"): + api_tool.query({}, "") + + +@patch("core.external_data_tool.api.api.db") +@patch("core.external_data_tool.api.api.encrypter") +@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor") +def test_query_no_result_in_response(mock_requestor_class, mock_encrypter, mock_db, api_tool): + mock_extension = MagicMock() + mock_extension.id = "ext_id" + mock_extension.tenant_id = "tenant_id" + mock_extension.api_endpoint = "http://api" + mock_extension.api_key = "encrypted_key" + mock_db.session.scalar.return_value = mock_extension + mock_encrypter.decrypt_token.return_value = "decrypted_key" + + mock_requestor = mock_requestor_class.return_value + 
mock_requestor.request.return_value = {"other": "value"} + + with pytest.raises(ValueError, match=".*error: result not found in response"): + api_tool.query({}, "") + + +@patch("core.external_data_tool.api.api.db") +@patch("core.external_data_tool.api.api.encrypter") +@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor") +def test_query_result_not_string(mock_requestor_class, mock_encrypter, mock_db, api_tool): + mock_extension = MagicMock() + mock_extension.id = "ext_id" + mock_extension.tenant_id = "tenant_id" + mock_extension.api_endpoint = "http://api" + mock_extension.api_key = "encrypted_key" + mock_db.session.scalar.return_value = mock_extension + mock_encrypter.decrypt_token.return_value = "decrypted_key" + + mock_requestor = mock_requestor_class.return_value + mock_requestor.request.return_value = {"result": 123} # Not a string + + with pytest.raises(ValueError, match=".*error: result is not string"): + api_tool.query({}, "") diff --git a/api/tests/unit_tests/core/external_data_tool/test_base.py b/api/tests/unit_tests/core/external_data_tool/test_base.py new file mode 100644 index 0000000000..216cda83c5 --- /dev/null +++ b/api/tests/unit_tests/core/external_data_tool/test_base.py @@ -0,0 +1,66 @@ +import pytest + +from core.extension.extensible import ExtensionModule +from core.external_data_tool.base import ExternalDataTool + + +class TestExternalDataTool: + def test_module_attribute(self): + assert ExternalDataTool.module == ExtensionModule.EXTERNAL_DATA_TOOL + + def test_init(self): + # Create a concrete subclass to test init + class ConcreteTool(ExternalDataTool): + @classmethod + def validate_config(cls, tenant_id: str, config: dict): + return super().validate_config(tenant_id, config) + + def query(self, inputs: dict, query: str | None = None) -> str: + return super().query(inputs, query) + + tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1", config={"key": "value"}) + assert tool.tenant_id == "tenant_1" + 
assert tool.app_id == "app_1" + assert tool.variable == "var_1" + assert tool.config == {"key": "value"} + + def test_init_without_config(self): + # Create a concrete subclass to test init + class ConcreteTool(ExternalDataTool): + @classmethod + def validate_config(cls, tenant_id: str, config: dict): + pass + + def query(self, inputs: dict, query: str | None = None) -> str: + return "" + + tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1") + assert tool.tenant_id == "tenant_1" + assert tool.app_id == "app_1" + assert tool.variable == "var_1" + assert tool.config is None + + def test_validate_config_raises_not_implemented(self): + class ConcreteTool(ExternalDataTool): + @classmethod + def validate_config(cls, tenant_id: str, config: dict): + return super().validate_config(tenant_id, config) + + def query(self, inputs: dict, query: str | None = None) -> str: + return "" + + with pytest.raises(NotImplementedError): + ConcreteTool.validate_config("tenant_1", {}) + + def test_query_raises_not_implemented(self): + class ConcreteTool(ExternalDataTool): + @classmethod + def validate_config(cls, tenant_id: str, config: dict): + pass + + def query(self, inputs: dict, query: str | None = None) -> str: + return super().query(inputs, query) + + tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1") + with pytest.raises(NotImplementedError): + tool.query({}) diff --git a/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py b/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py new file mode 100644 index 0000000000..86b461cf04 --- /dev/null +++ b/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py @@ -0,0 +1,115 @@ +from unittest.mock import patch + +import pytest +from flask import Flask + +from core.app.app_config.entities import ExternalDataVariableEntity +from core.external_data_tool.external_data_fetch import ExternalDataFetch + + +class TestExternalDataFetch: + 
@pytest.fixture + def app(self): + app = Flask(__name__) + return app + + def test_fetch_success(self, app): + with app.app_context(): + fetcher = ExternalDataFetch() + + # Setup mocks + tool1 = ExternalDataVariableEntity(variable="var1", type="type1", config={"c1": "v1"}) + tool2 = ExternalDataVariableEntity(variable="var2", type="type2", config={"c2": "v2"}) + + external_data_tools = [tool1, tool2] + inputs = {"input_key": "input_value"} + query = "test query" + + with patch("core.external_data_tool.external_data_fetch.ExternalDataToolFactory") as MockFactory: + # Create distinct mock instances for each tool to ensure deterministic results + # This approach is robust regardless of thread scheduling order + from unittest.mock import MagicMock + + def factory_side_effect(*args, **kwargs): + variable = kwargs.get("variable") + mock_instance = MagicMock() + if variable == "var1": + mock_instance.query.return_value = "result1" + elif variable == "var2": + mock_instance.query.return_value = "result2" + return mock_instance + + MockFactory.side_effect = factory_side_effect + + result_inputs = fetcher.fetch( + tenant_id="tenant1", + app_id="app1", + external_data_tools=external_data_tools, + inputs=inputs, + query=query, + ) + + # Each tool gets its deterministic result regardless of thread completion order + assert result_inputs["var1"] == "result1" + assert result_inputs["var2"] == "result2" + assert result_inputs["input_key"] == "input_value" + assert len(result_inputs) == 3 + + # Verify factory calls + assert MockFactory.call_count == 2 + MockFactory.assert_any_call( + name="type1", tenant_id="tenant1", app_id="app1", variable="var1", config={"c1": "v1"} + ) + MockFactory.assert_any_call( + name="type2", tenant_id="tenant1", app_id="app1", variable="var2", config={"c2": "v2"} + ) + + def test_fetch_no_tools(self): + # We don't necessarily need app_context if there are no tools, + # but fetch calls current_app._get_current_object() only inside the loop. 
+ # Wait, let's look at the code. + # for tool in external_data_tools: + # executor.submit(..., current_app._get_current_object(), ...) + # So if external_data_tools is empty, it shouldn't access current_app. + fetcher = ExternalDataFetch() + inputs = {"input_key": "input_value"} + result_inputs = fetcher.fetch( + tenant_id="tenant1", app_id="app1", external_data_tools=[], inputs=inputs, query="test query" + ) + assert result_inputs == inputs + assert result_inputs is not inputs # Should be a copy + + def test_fetch_with_none_variable(self, app): + with app.app_context(): + fetcher = ExternalDataFetch() + tool = ExternalDataVariableEntity(variable="var1", type="type1", config={}) + + # Patch _query_external_data_tool to return None variable + with patch.object(ExternalDataFetch, "_query_external_data_tool") as mock_query: + mock_query.return_value = (None, "some_result") + + result_inputs = fetcher.fetch( + tenant_id="t1", app_id="a1", external_data_tools=[tool], inputs={"in": "val"}, query="q" + ) + + assert "var1" not in result_inputs + assert result_inputs == {"in": "val"} + + def test_query_external_data_tool(self, app): + fetcher = ExternalDataFetch() + tool = ExternalDataVariableEntity(variable="var1", type="type1", config={"k": "v"}) + + with patch("core.external_data_tool.external_data_fetch.ExternalDataToolFactory") as MockFactory: + mock_factory_instance = MockFactory.return_value + mock_factory_instance.query.return_value = "query_result" + + var, res = fetcher._query_external_data_tool( + flask_app=app, tenant_id="t1", app_id="a1", external_data_tool=tool, inputs={"i": "v"}, query="q" + ) + + assert var == "var1" + assert res == "query_result" + MockFactory.assert_called_once_with( + name="type1", tenant_id="t1", app_id="a1", variable="var1", config={"k": "v"} + ) + mock_factory_instance.query.assert_called_once_with(inputs={"i": "v"}, query="q") diff --git a/api/tests/unit_tests/core/external_data_tool/test_factory.py 
b/api/tests/unit_tests/core/external_data_tool/test_factory.py new file mode 100644 index 0000000000..6bb384b0ac --- /dev/null +++ b/api/tests/unit_tests/core/external_data_tool/test_factory.py @@ -0,0 +1,58 @@ +from unittest.mock import MagicMock, patch + +from core.extension.extensible import ExtensionModule +from core.external_data_tool.factory import ExternalDataToolFactory + + +def test_external_data_tool_factory_init(): + with patch("core.external_data_tool.factory.code_based_extension") as mock_code_based_extension: + mock_extension_class = MagicMock() + mock_code_based_extension.extension_class.return_value = mock_extension_class + + name = "test_tool" + tenant_id = "tenant_123" + app_id = "app_456" + variable = "var_v" + config = {"key": "value"} + + factory = ExternalDataToolFactory(name, tenant_id, app_id, variable, config) + + mock_code_based_extension.extension_class.assert_called_once_with(ExtensionModule.EXTERNAL_DATA_TOOL, name) + mock_extension_class.assert_called_once_with( + tenant_id=tenant_id, app_id=app_id, variable=variable, config=config + ) + + +def test_external_data_tool_factory_validate_config(): + with patch("core.external_data_tool.factory.code_based_extension") as mock_code_based_extension: + mock_extension_class = MagicMock() + mock_code_based_extension.extension_class.return_value = mock_extension_class + + name = "test_tool" + tenant_id = "tenant_123" + config = {"key": "value"} + + ExternalDataToolFactory.validate_config(name, tenant_id, config) + + mock_code_based_extension.extension_class.assert_called_once_with(ExtensionModule.EXTERNAL_DATA_TOOL, name) + mock_extension_class.validate_config.assert_called_once_with(tenant_id, config) + + +def test_external_data_tool_factory_query(): + with patch("core.external_data_tool.factory.code_based_extension") as mock_code_based_extension: + mock_extension_class = MagicMock() + mock_extension_instance = MagicMock() + mock_extension_class.return_value = mock_extension_instance + 
mock_code_based_extension.extension_class.return_value = mock_extension_class + + mock_extension_instance.query.return_value = "query_result" + + factory = ExternalDataToolFactory("name", "tenant", "app", "var", {}) + + inputs = {"input_key": "input_value"} + query = "search_query" + + result = factory.query(inputs, query) + + assert result == "query_result" + mock_extension_instance.query.assert_called_once_with(inputs, query) diff --git a/api/tests/unit_tests/core/llm_generator/output_parser/test_rule_config_generator.py b/api/tests/unit_tests/core/llm_generator/output_parser/test_rule_config_generator.py new file mode 100644 index 0000000000..b2783bdf99 --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/output_parser/test_rule_config_generator.py @@ -0,0 +1,103 @@ +import pytest + +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser +from core.llm_generator.prompts import ( + RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, + RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, + RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE, +) + + +class TestRuleConfigGeneratorOutputParser: + def test_get_format_instructions(self): + parser = RuleConfigGeneratorOutputParser() + instructions = parser.get_format_instructions() + assert instructions == ( + RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, + RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, + RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE, + ) + + def test_parse_success(self): + parser = RuleConfigGeneratorOutputParser() + text = """ +```json +{ + "prompt": "This is a prompt", + "variables": ["var1", "var2"], + "opening_statement": "Hello!" +} +``` +""" + result = parser.parse(text) + assert result["prompt"] == "This is a prompt" + assert result["variables"] == ["var1", "var2"] + assert result["opening_statement"] == "Hello!" 
+ + def test_parse_invalid_json(self): + parser = RuleConfigGeneratorOutputParser() + text = "invalid json" + with pytest.raises(OutputParserError) as excinfo: + parser.parse(text) + assert "Parsing text" in str(excinfo.value) + assert "could not find json block in the output" in str(excinfo.value) + + def test_parse_missing_keys(self): + parser = RuleConfigGeneratorOutputParser() + text = """ +```json +{ + "prompt": "This is a prompt", + "variables": ["var1", "var2"] +} +``` +""" + with pytest.raises(OutputParserError) as excinfo: + parser.parse(text) + assert "expected key `opening_statement` to be present" in str(excinfo.value) + + def test_parse_wrong_type_prompt(self): + parser = RuleConfigGeneratorOutputParser() + text = """ +```json +{ + "prompt": 123, + "variables": ["var1", "var2"], + "opening_statement": "Hello!" +} +``` +""" + with pytest.raises(OutputParserError) as excinfo: + parser.parse(text) + assert "Expected 'prompt' to be a string" in str(excinfo.value) + + def test_parse_wrong_type_variables(self): + parser = RuleConfigGeneratorOutputParser() + text = """ +```json +{ + "prompt": "This is a prompt", + "variables": "not a list", + "opening_statement": "Hello!" 
+} +``` +""" + with pytest.raises(OutputParserError) as excinfo: + parser.parse(text) + assert "Expected 'variables' to be a list" in str(excinfo.value) + + def test_parse_wrong_type_opening_statement(self): + parser = RuleConfigGeneratorOutputParser() + text = """ +```json +{ + "prompt": "This is a prompt", + "variables": ["var1", "var2"], + "opening_statement": 123 +} +``` +""" + with pytest.raises(OutputParserError) as excinfo: + parser.parse(text) + assert "Expected 'opening_statement' to be a str" in str(excinfo.value) diff --git a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py new file mode 100644 index 0000000000..46c9dc6f9c --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py @@ -0,0 +1,402 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest + +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import ( + ResponseFormat, + _handle_native_json_schema, + _handle_prompt_based_schema, + _parse_structured_output, + _prepare_schema_for_model, + _set_response_format, + convert_boolean_to_string, + invoke_llm_with_structured_output, + remove_additional_properties, +) +from core.model_manager import ModelInstance +from dify_graph.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultWithStructuredOutput, + LLMUsage, +) +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType + + +class TestStructuredOutput: + def test_remove_additional_properties(self): + schema = { + "type": "object", + "properties": {"name": {"type": "string"}, "age": 
{"type": "integer"}}, + "additionalProperties": False, + "nested": {"type": "object", "additionalProperties": True}, + "items": [{"type": "object", "additionalProperties": False}], + } + remove_additional_properties(schema) + assert "additionalProperties" not in schema + assert "additionalProperties" not in schema["nested"] + assert "additionalProperties" not in schema["items"][0] + + # Test with non-dict input + remove_additional_properties(None) # Should not raise + remove_additional_properties([]) # Should not raise + + def test_convert_boolean_to_string(self): + schema = { + "type": "object", + "properties": { + "is_active": {"type": "boolean"}, + "tags": {"type": "array", "items": {"type": "boolean"}}, + "list_schema": [{"type": "boolean"}], + }, + } + convert_boolean_to_string(schema) + assert schema["properties"]["is_active"]["type"] == "string" + assert schema["properties"]["tags"]["items"]["type"] == "string" + assert schema["properties"]["list_schema"][0]["type"] == "string" + + # Test with non-dict input + convert_boolean_to_string(None) # Should not raise + convert_boolean_to_string([]) # Should not raise + + def test_parse_structured_output_valid(self): + text = '{"key": "value"}' + assert _parse_structured_output(text) == {"key": "value"} + + def test_parse_structured_output_non_dict_valid_json(self): + # Even if it's valid JSON, if it's not a dict, it should try repair or fail + text = '["a", "b"]' + with patch("json_repair.loads") as mock_repair: + mock_repair.return_value = {"key": "value"} + assert _parse_structured_output(text) == {"key": "value"} + + def test_parse_structured_output_not_dict_fail_via_validate(self): + # Force TypeAdapter to return a non-dict to trigger line 292 + with patch("pydantic.TypeAdapter.validate_json") as mock_validate: + mock_validate.return_value = ["a list"] + with pytest.raises(OutputParserError) as excinfo: + _parse_structured_output('["a list"]') + assert "Failed to parse structured output" in str(excinfo.value) + 
+ def test_parse_structured_output_repair_success(self): + text = "{'key': 'value'}" # Invalid JSON (single quotes) + # json_repair should handle this + assert _parse_structured_output(text) == {"key": "value"} + + def test_parse_structured_output_repair_list(self): + # Deepseek-r1 case: result is a list containing a dict + text = '[{"key": "value"}]' + assert _parse_structured_output(text) == {"key": "value"} + + def test_parse_structured_output_repair_list_no_dict(self): + # Deepseek-r1 case: result is a list with NO dict + text = "[1, 2, 3]" + assert _parse_structured_output(text) == {} + + def test_parse_structured_output_repair_fail(self): + text = "not a json at all" + with patch("json_repair.loads") as mock_repair: + mock_repair.return_value = "still not a dict or list" + with pytest.raises(OutputParserError): + _parse_structured_output(text) + + def test_set_response_format(self): + # Test JSON + params = {} + rules = [ + ParameterRule( + name="response_format", + label={"en_US": ""}, + type=ParameterType.STRING, + help={"en_US": ""}, + options=[ResponseFormat.JSON], + ) + ] + _set_response_format(params, rules) + assert params["response_format"] == ResponseFormat.JSON + + # Test JSON_OBJECT + params = {} + rules = [ + ParameterRule( + name="response_format", + label={"en_US": ""}, + type=ParameterType.STRING, + help={"en_US": ""}, + options=[ResponseFormat.JSON_OBJECT], + ) + ] + _set_response_format(params, rules) + assert params["response_format"] == ResponseFormat.JSON_OBJECT + + def test_handle_native_json_schema(self): + provider = "openai" + model_schema = MagicMock(spec=AIModelEntity) + model_schema.model = "gpt-4" + structured_output_schema = {"type": "object"} + model_parameters = {} + rules = [ + ParameterRule( + name="response_format", + label={"en_US": ""}, + type=ParameterType.STRING, + help={"en_US": ""}, + options=[ResponseFormat.JSON_SCHEMA], + ) + ] + + updated_params = _handle_native_json_schema( + provider, model_schema, 
structured_output_schema, model_parameters, rules + ) + + assert "json_schema" in updated_params + assert json.loads(updated_params["json_schema"]) == {"schema": {"type": "object"}, "name": "llm_response"} + assert updated_params["response_format"] == ResponseFormat.JSON_SCHEMA + + def test_handle_native_json_schema_no_format_rule(self): + provider = "openai" + model_schema = MagicMock(spec=AIModelEntity) + model_schema.model = "gpt-4" + structured_output_schema = {"type": "object"} + model_parameters = {} + rules = [] + + updated_params = _handle_native_json_schema( + provider, model_schema, structured_output_schema, model_parameters, rules + ) + + assert "json_schema" in updated_params + assert "response_format" not in updated_params + + def test_handle_prompt_based_schema_with_system_prompt(self): + prompt_messages = [ + SystemPromptMessage(content="Existing system prompt"), + UserPromptMessage(content="User question"), + ] + schema = {"type": "object"} + + result = _handle_prompt_based_schema(prompt_messages, schema) + + assert len(result) == 2 + assert isinstance(result[0], SystemPromptMessage) + assert "Existing system prompt" in result[0].content + assert json.dumps(schema) in result[0].content + assert isinstance(result[1], UserPromptMessage) + + def test_handle_prompt_based_schema_without_system_prompt(self): + prompt_messages = [UserPromptMessage(content="User question")] + schema = {"type": "object"} + + result = _handle_prompt_based_schema(prompt_messages, schema) + + assert len(result) == 2 + assert isinstance(result[0], SystemPromptMessage) + assert json.dumps(schema) in result[0].content + assert isinstance(result[1], UserPromptMessage) + + def test_prepare_schema_for_model_gemini(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.model = "gemini-1.5-pro" + schema = {"type": "object", "additionalProperties": False} + + result = _prepare_schema_for_model("google", model_schema, schema) + assert "additionalProperties" not in result + 
+ def test_prepare_schema_for_model_ollama(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.model = "llama3" + schema = {"type": "object"} + + result = _prepare_schema_for_model("ollama", model_schema, schema) + assert result == schema + + def test_prepare_schema_for_model_default(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.model = "gpt-4" + schema = {"type": "object"} + + result = _prepare_schema_for_model("openai", model_schema, schema) + assert result == {"schema": schema, "name": "llm_response"} + + def test_invoke_llm_with_structured_output_no_stream_native(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.support_structure_output = True + model_schema.parameter_rules = [ + ParameterRule( + name="response_format", + label={"en_US": ""}, + type=ParameterType.STRING, + help={"en_US": ""}, + options=[ResponseFormat.JSON_SCHEMA], + ) + ] + model_schema.model = "gpt-4o" + + model_instance = MagicMock(spec=ModelInstance) + mock_result = MagicMock(spec=LLMResult) + mock_result.message = AssistantPromptMessage(content='{"result": "success"}') + mock_result.model = "gpt-4o" + mock_result.usage = LLMUsage.empty_usage() + mock_result.system_fingerprint = "fp_native" + mock_result.prompt_messages = [UserPromptMessage(content="hi")] + + model_instance.invoke_llm.return_value = mock_result + + result = invoke_llm_with_structured_output( + provider="openai", + model_schema=model_schema, + model_instance=model_instance, + prompt_messages=[UserPromptMessage(content="hi")], + json_schema={"type": "object"}, + stream=False, + ) + + assert isinstance(result, LLMResultWithStructuredOutput) + assert result.structured_output == {"result": "success"} + assert result.system_fingerprint == "fp_native" + + def test_invoke_llm_with_structured_output_no_stream_prompt_based(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.support_structure_output = False + model_schema.parameter_rules = [ + ParameterRule( + 
name="response_format", + label={"en_US": ""}, + type=ParameterType.STRING, + help={"en_US": ""}, + options=[ResponseFormat.JSON], + ) + ] + model_schema.model = "claude-3" + + model_instance = MagicMock(spec=ModelInstance) + mock_result = MagicMock(spec=LLMResult) + mock_result.message = AssistantPromptMessage(content='{"result": "success"}') + mock_result.model = "claude-3" + mock_result.usage = LLMUsage.empty_usage() + mock_result.system_fingerprint = "fp_prompt" + mock_result.prompt_messages = [] + + model_instance.invoke_llm.return_value = mock_result + + result = invoke_llm_with_structured_output( + provider="anthropic", + model_schema=model_schema, + model_instance=model_instance, + prompt_messages=[UserPromptMessage(content="hi")], + json_schema={"type": "object"}, + stream=False, + ) + + assert isinstance(result, LLMResultWithStructuredOutput) + assert result.structured_output == {"result": "success"} + assert result.system_fingerprint == "fp_prompt" + + def test_invoke_llm_with_structured_output_no_string_error(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.support_structure_output = False + model_schema.parameter_rules = [] + + model_instance = MagicMock(spec=ModelInstance) + mock_result = MagicMock(spec=LLMResult) + mock_result.message = AssistantPromptMessage(content=[TextPromptMessageContent(data="not a string")]) + + model_instance.invoke_llm.return_value = mock_result + + with pytest.raises(OutputParserError) as excinfo: + invoke_llm_with_structured_output( + provider="anthropic", + model_schema=model_schema, + model_instance=model_instance, + prompt_messages=[], + json_schema={}, + stream=False, + ) + assert "Failed to parse structured output, LLM result is not a string" in str(excinfo.value) + + def test_invoke_llm_with_structured_output_stream(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.support_structure_output = False + model_schema.parameter_rules = [] + model_schema.model = "gpt-4" + + 
model_instance = MagicMock(spec=ModelInstance) + + # Mock chunks + chunk1 = MagicMock(spec=LLMResultChunk) + chunk1.delta = LLMResultChunkDelta( + index=0, message=AssistantPromptMessage(content='{"key": '), usage=LLMUsage.empty_usage() + ) + chunk1.prompt_messages = [UserPromptMessage(content="hi")] + chunk1.system_fingerprint = "fp1" + + chunk2 = MagicMock(spec=LLMResultChunk) + chunk2.delta = LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content='"value"}')) + chunk2.prompt_messages = [UserPromptMessage(content="hi")] + chunk2.system_fingerprint = "fp1" + + chunk3 = MagicMock(spec=LLMResultChunk) + chunk3.delta = LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=[ + TextPromptMessageContent(data=" "), + ] + ), + ) + chunk3.prompt_messages = [UserPromptMessage(content="hi")] + chunk3.system_fingerprint = "fp1" + + event4 = MagicMock() + event4.delta = LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content="")) + + model_instance.invoke_llm.return_value = [chunk1, chunk2, chunk3, event4] + + generator = invoke_llm_with_structured_output( + provider="openai", + model_schema=model_schema, + model_instance=model_instance, + prompt_messages=[UserPromptMessage(content="hi")], + json_schema={}, + stream=True, + ) + + chunks = list(generator) + assert len(chunks) == 5 + assert chunks[-1].structured_output == {"key": "value"} + assert chunks[-1].system_fingerprint == "fp1" + assert chunks[-1].prompt_messages == [UserPromptMessage(content="hi")] + + def test_invoke_llm_with_structured_output_stream_no_id_events(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.support_structure_output = False + model_schema.parameter_rules = [] + model_schema.model = "gpt-4" + + model_instance = MagicMock(spec=ModelInstance) + model_instance.invoke_llm.return_value = [] + + generator = invoke_llm_with_structured_output( + provider="openai", + model_schema=model_schema, + model_instance=model_instance, + 
prompt_messages=[], + json_schema={}, + stream=True, + ) + + with pytest.raises(OutputParserError): + list(generator) + + def test_parse_structured_output_empty_string(self): + with pytest.raises(OutputParserError): + _parse_structured_output("") diff --git a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py new file mode 100644 index 0000000000..5b7640696f --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py @@ -0,0 +1,589 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest + +from core.app.app_config.entities import ModelConfig +from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload +from core.llm_generator.llm_generator import LLMGenerator +from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError + + +class TestLLMGenerator: + @pytest.fixture + def mock_model_instance(self): + with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + instance = MagicMock() + mock_manager.return_value.get_default_model_instance.return_value = instance + mock_manager.return_value.get_model_instance.return_value = instance + yield instance + + @pytest.fixture + def model_config_entity(self): + return ModelConfig(provider="openai", name="gpt-4", mode=LLMMode.CHAT, completion_params={"temperature": 0.7}) + + def test_generate_conversation_name_success(self, mock_model_instance): + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = json.dumps({"Your Output": "Test Conversation Name"}) + mock_model_instance.invoke_llm.return_value = mock_response + + with patch("core.llm_generator.llm_generator.TraceQueueManager") as mock_trace: + name = LLMGenerator.generate_conversation_name("tenant_id", "test query") + assert name == "Test Conversation 
Name" + mock_trace.assert_called_once() + + def test_generate_conversation_name_truncated(self, mock_model_instance): + long_query = "a" * 2100 + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = json.dumps({"Your Output": "Short Name"}) + mock_model_instance.invoke_llm.return_value = mock_response + + with patch("core.llm_generator.llm_generator.TraceQueueManager"): + name = LLMGenerator.generate_conversation_name("tenant_id", long_query) + assert name == "Short Name" + + def test_generate_conversation_name_empty_answer(self, mock_model_instance): + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = "" + mock_model_instance.invoke_llm.return_value = mock_response + + name = LLMGenerator.generate_conversation_name("tenant_id", "test query") + assert name == "" + + def test_generate_conversation_name_json_repair(self, mock_model_instance): + mock_response = MagicMock() + # Invalid JSON that json_repair can fix + mock_response.message.get_text_content.return_value = "{'Your Output': 'Repaired Name'}" + mock_model_instance.invoke_llm.return_value = mock_response + + with patch("core.llm_generator.llm_generator.TraceQueueManager"): + name = LLMGenerator.generate_conversation_name("tenant_id", "test query") + assert name == "Repaired Name" + + def test_generate_conversation_name_not_dict_result(self, mock_model_instance): + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '["not a dict"]' + mock_model_instance.invoke_llm.return_value = mock_response + with patch("core.llm_generator.llm_generator.TraceQueueManager"): + name = LLMGenerator.generate_conversation_name("tenant_id", "test query") + assert name == "test query" + + def test_generate_conversation_name_no_output_in_dict(self, mock_model_instance): + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"something": "else"}' + mock_model_instance.invoke_llm.return_value = 
mock_response + with patch("core.llm_generator.llm_generator.TraceQueueManager"): + name = LLMGenerator.generate_conversation_name("tenant_id", "test query") + assert name == "test query" + + def test_generate_conversation_name_long_output(self, mock_model_instance): + long_output = "a" * 100 + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = json.dumps({"Your Output": long_output}) + mock_model_instance.invoke_llm.return_value = mock_response + + with patch("core.llm_generator.llm_generator.TraceQueueManager"): + name = LLMGenerator.generate_conversation_name("tenant_id", "test query") + assert len(name) == 78 # 75 + "..." + assert name.endswith("...") + + def test_generate_suggested_questions_after_answer_success(self, mock_model_instance): + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '["Question 1?", "Question 2?"]' + mock_model_instance.invoke_llm.return_value = mock_response + + questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories") + assert len(questions) == 2 + assert questions[0] == "Question 1?" 
+
+    @staticmethod
+    def _text_response(text):
+        """Return a mocked LLM result whose message text content is ``text``."""
+        response = MagicMock()
+        response.message.get_text_content.return_value = text
+        return response
+
+    @staticmethod
+    def _modify_legacy(model_config_entity, current="current", instruction="instruction"):
+        """Call ``instruction_modify_legacy`` with the fixed arguments shared by the legacy tests."""
+        return LLMGenerator.instruction_modify_legacy(
+            "tenant_id", "flow_id", current, instruction, model_config_entity, "ideal"
+        )
+
+    @staticmethod
+    def _modify_workflow(model_config_entity, workflow_service):
+        """Call ``instruction_modify_workflow`` with the fixed arguments shared by the workflow tests."""
+        return LLMGenerator.instruction_modify_workflow(
+            "tenant_id",
+            "flow_id",
+            "node_id",
+            "current",
+            "instruction",
+            model_config_entity,
+            "ideal",
+            workflow_service,
+        )
+
+    def test_generate_suggested_questions_after_answer_auth_error(self, mock_model_instance):
+        """An authorization error while resolving the default model yields an empty list."""
+        with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager:
+            mock_manager.return_value.get_default_model_instance.side_effect = InvokeAuthorizationError("Auth failed")
+            questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories")
+            assert questions == []
+
+    def test_generate_suggested_questions_after_answer_invoke_error(self, mock_model_instance):
+        """An ``InvokeError`` from the model yields an empty list."""
+        mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed")
+        questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories")
+        assert questions == []
+
+    def test_generate_suggested_questions_after_answer_exception(self, mock_model_instance):
+        """A generic exception from the model yields an empty list."""
+        mock_model_instance.invoke_llm.side_effect = Exception("Random error")
+        questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories")
+        assert questions == []
+
+    def test_generate_rule_config_no_variable_success(self, mock_model_instance, model_config_entity):
+        """With ``no_variable=True`` the model text is returned as the prompt and the error is empty."""
+        payload = RuleGeneratePayload(
+            instruction="test instruction", model_config=model_config_entity, no_variable=True
+        )
+        mock_model_instance.invoke_llm.return_value = self._text_response("Generated Prompt")
+
+        result = LLMGenerator.generate_rule_config("tenant_id", payload)
+        assert result["prompt"] == "Generated Prompt"
+        assert result["error"] == ""
+
+    def test_generate_rule_config_no_variable_invoke_error(self, mock_model_instance, model_config_entity):
+        """An ``InvokeError`` in the no-variable path is reported as a rule-config failure."""
+        payload = RuleGeneratePayload(
+            instruction="test instruction", model_config=model_config_entity, no_variable=True
+        )
+        mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed")
+
+        result = LLMGenerator.generate_rule_config("tenant_id", payload)
+        assert "Failed to generate rule config" in result["error"]
+
+    def test_generate_rule_config_no_variable_exception(self, mock_model_instance, model_config_entity):
+        """A generic exception in the no-variable path reports both the failed step and the message."""
+        payload = RuleGeneratePayload(
+            instruction="test instruction", model_config=model_config_entity, no_variable=True
+        )
+        mock_model_instance.invoke_llm.side_effect = Exception("Random error")
+
+        result = LLMGenerator.generate_rule_config("tenant_id", payload)
+        assert "Failed to generate rule config" in result["error"]
+        assert "Random error" in result["error"]
+
+    def test_generate_rule_config_with_variable_success(self, mock_model_instance, model_config_entity):
+        """The three-step flow returns the prompt, the parsed variables and the opening statement."""
+        payload = RuleGeneratePayload(
+            instruction="test instruction", model_config=model_config_entity, no_variable=False
+        )
+        # One mocked reply per invoke_llm call, consumed in order.
+        mock_model_instance.invoke_llm.side_effect = [
+            self._text_response("Step 1 Prompt"),
+            self._text_response('"var1", "var2"'),
+            self._text_response("Opening Statement"),
+        ]
+
+        result = LLMGenerator.generate_rule_config("tenant_id", payload)
+        assert result["prompt"] == "Step 1 Prompt"
+        assert result["variables"] == ["var1", "var2"]
+        assert result["opening_statement"] == "Opening Statement"
+        assert result["error"] == ""
+
+    def test_generate_rule_config_with_variable_step1_error(self, mock_model_instance, model_config_entity):
+        """An ``InvokeError`` on the first call is reported as a prefix-prompt failure."""
+        payload = RuleGeneratePayload(
+            instruction="test instruction", model_config=model_config_entity, no_variable=False
+        )
+        mock_model_instance.invoke_llm.side_effect = InvokeError("Step 1 Failed")
+
+        result = LLMGenerator.generate_rule_config("tenant_id", payload)
+        assert "Failed to generate prefix prompt" in result["error"]
+
+    def test_generate_rule_config_with_variable_step2_error(self, mock_model_instance, model_config_entity):
+        """An ``InvokeError`` on the second call is reported as a variables failure."""
+        payload = RuleGeneratePayload(
+            instruction="test instruction", model_config=model_config_entity, no_variable=False
+        )
+        # Step 2 fails
+        mock_model_instance.invoke_llm.side_effect = [
+            self._text_response("Step 1 Prompt"),
+            InvokeError("Step 2 Failed"),
+            MagicMock(),
+        ]
+
+        result = LLMGenerator.generate_rule_config("tenant_id", payload)
+        assert "Failed to generate variables" in result["error"]
+
+    def test_generate_rule_config_with_variable_step3_error(self, mock_model_instance, model_config_entity):
+        """An ``InvokeError`` on the third call is reported as a conversation-opener failure."""
+        payload = RuleGeneratePayload(
+            instruction="test instruction", model_config=model_config_entity, no_variable=False
+        )
+        # Step 3 fails
+        mock_model_instance.invoke_llm.side_effect = [
+            self._text_response("Step 1 Prompt"),
+            self._text_response('"var1"'),
+            InvokeError("Step 3 Failed"),
+        ]
+
+        result = LLMGenerator.generate_rule_config("tenant_id", payload)
+        assert "Failed to generate conversation opener" in result["error"]
+
+    def test_generate_rule_config_with_variable_exception(self, mock_model_instance, model_config_entity):
+        """A generic exception in the multi-step flow is reported as an unexpected exception."""
+        payload = RuleGeneratePayload(
+            instruction="test instruction", model_config=model_config_entity, no_variable=False
+        )
+        # Every invoke_llm call raises, so the first step already fails.
+        mock_model_instance.invoke_llm.side_effect = Exception("Unexpected multi-step error")
+
+        result = LLMGenerator.generate_rule_config("tenant_id", payload)
+        assert "Failed to handle unexpected exception" in result["error"]
+        assert "Unexpected multi-step error" in result["error"]
+
+    def test_generate_code_python_success(self, mock_model_instance, model_config_entity):
+        """Python code generation returns the model text and echoes the language."""
+        payload = RuleCodeGeneratePayload(
+            instruction="print hello", code_language="python", model_config=model_config_entity
+        )
+        mock_model_instance.invoke_llm.return_value = self._text_response("print('hello')")
+
+        result = LLMGenerator.generate_code("tenant_id", payload)
+        assert result["code"] == "print('hello')"
+        assert result["language"] == "python"
+
+    def test_generate_code_javascript_success(self, mock_model_instance, model_config_entity):
+        """JavaScript code generation returns the model text and echoes the language."""
+        payload = RuleCodeGeneratePayload(
+            instruction="console log hello", code_language="javascript", model_config=model_config_entity
+        )
+        mock_model_instance.invoke_llm.return_value = self._text_response("console.log('hello')")
+
+        result = LLMGenerator.generate_code("tenant_id", payload)
+        assert result["code"] == "console.log('hello')"
+        assert result["language"] == "javascript"
+
+    def test_generate_code_invoke_error(self, mock_model_instance, model_config_entity):
+        """An ``InvokeError`` during code generation is reported as a code-generation failure."""
+        payload = RuleCodeGeneratePayload(instruction="error", code_language="python", model_config=model_config_entity)
+        mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed")
+
+        result = LLMGenerator.generate_code("tenant_id", payload)
+        assert "Failed to generate code" in result["error"]
+
+    def test_generate_code_exception(self, mock_model_instance, model_config_entity):
+        """A generic exception during code generation is reported as an unexpected error."""
+        payload = RuleCodeGeneratePayload(instruction="error", code_language="python", model_config=model_config_entity)
+        mock_model_instance.invoke_llm.side_effect = Exception("Random error")
+
+        result = LLMGenerator.generate_code("tenant_id", payload)
+        assert "An unexpected error occurred" in result["error"]
+
+    def test_generate_qa_document_success(self, mock_model_instance):
+        """A proper ``LLMResult`` has its text content returned verbatim."""
+        # spec=LLMResult so the mock is accepted where an LLMResult instance is required.
+        mock_response = MagicMock(spec=LLMResult)
+        mock_response.message = MagicMock()
+        mock_response.message.get_text_content.return_value = "QA Document Content"
+        mock_model_instance.invoke_llm.return_value = mock_response
+
+        result = LLMGenerator.generate_qa_document("tenant_id", "query", "English")
+        assert result == "QA Document Content"
+
+    def test_generate_qa_document_type_error(self, mock_model_instance):
+        """A non-``LLMResult`` return value raises ``TypeError``."""
+        mock_model_instance.invoke_llm.return_value = "Not an LLMResult"
+
+        with pytest.raises(TypeError, match="Expected LLMResult when stream=False"):
+            LLMGenerator.generate_qa_document("tenant_id", "query", "English")
+
+    def test_generate_structured_output_success(self, mock_model_instance, model_config_entity):
+        """Valid JSON from the model is returned as a JSON string with an empty error."""
+        payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity)
+        mock_model_instance.invoke_llm.return_value = self._text_response('{"type": "object", "properties": {}}')
+
+        result = LLMGenerator.generate_structured_output("tenant_id", payload)
+        parsed_output = json.loads(result["output"])
+        assert parsed_output["type"] == "object"
+        assert result["error"] == ""
+
+    def test_generate_structured_output_json_repair(self, mock_model_instance, model_config_entity):
+        """Single-quoted pseudo-JSON from the model still comes back as parseable JSON."""
+        payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity)
+        mock_model_instance.invoke_llm.return_value = self._text_response("{'type': 'object'}")
+
+        result = LLMGenerator.generate_structured_output("tenant_id", payload)
+        parsed_output = json.loads(result["output"])
+        assert parsed_output["type"] == "object"
+
+    def test_generate_structured_output_not_dict_or_list(self, mock_model_instance, model_config_entity):
+        """A scalar JSON value is reported as a structured-output parse failure."""
+        payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity)
+        mock_model_instance.invoke_llm.return_value = self._text_response("true")  # parsed as bool
+
+        result = LLMGenerator.generate_structured_output("tenant_id", payload)
+        assert "An unexpected error occurred" in result["error"]
+        assert "Failed to parse structured output" in result["error"]
+
+    def test_generate_structured_output_invoke_error(self, mock_model_instance, model_config_entity):
+        """An ``InvokeError`` is reported as a JSON Schema generation failure."""
+        payload = RuleStructuredOutputPayload(instruction="error", model_config=model_config_entity)
+        mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed")
+
+        result = LLMGenerator.generate_structured_output("tenant_id", payload)
+        assert "Failed to generate JSON Schema" in result["error"]
+
+    def test_generate_structured_output_exception(self, mock_model_instance, model_config_entity):
+        """A generic exception is reported as an unexpected error."""
+        payload = RuleStructuredOutputPayload(instruction="error", model_config=model_config_entity)
+        mock_model_instance.invoke_llm.side_effect = Exception("Random error")
+
+        result = LLMGenerator.generate_structured_output("tenant_id", payload)
+        assert "An unexpected error occurred" in result["error"]
+
+    def test_instruction_modify_legacy_no_last_run(self, mock_model_instance, model_config_entity):
+        """Without a previous run, the parsed JSON object from the model is returned."""
+        with patch("extensions.ext_database.db.session.query") as mock_query:
+            mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
+
+            # The shared modify routine is reached through invoke_llm.
+            mock_model_instance.invoke_llm.return_value = self._text_response('{"modified": "prompt"}')
+
+            result = self._modify_legacy(model_config_entity)
+            assert result == {"modified": "prompt"}
+
+    def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity):
+        """With a previous run available, the parsed JSON object from the model is returned."""
+        with patch("extensions.ext_database.db.session.query") as mock_query:
+            last_run = MagicMock()
+            last_run.query = "q"
+            last_run.answer = "a"
+            last_run.error = "e"
+            mock_query.return_value.where.return_value.order_by.return_value.first.return_value = last_run
+
+            mock_model_instance.invoke_llm.return_value = self._text_response('{"modified": "prompt"}')
+
+            result = self._modify_legacy(model_config_entity)
+            assert result == {"modified": "prompt"}
+
+    def test_instruction_modify_workflow_app_not_found(self):
+        """A missing app raises ``ValueError``."""
+        with patch("extensions.ext_database.db.session") as mock_session:
+            mock_session.return_value.query.return_value.where.return_value.first.return_value = None
+            with pytest.raises(ValueError, match="App not found."):
+                LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", MagicMock())
+
+    def test_instruction_modify_workflow_no_workflow(self):
+        """A missing draft workflow raises ``ValueError``."""
+        with patch("extensions.ext_database.db.session") as mock_session:
+            mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
+            workflow_service = MagicMock()
+            workflow_service.get_draft_workflow.return_value = None
+            with pytest.raises(ValueError, match="Workflow not found for the given app model."):
+                LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", workflow_service)
+
+    def test_instruction_modify_workflow_success(self, mock_model_instance, model_config_entity):
+        """An LLM node with a last run and one agent-log entry returns the parsed model output."""
+        with patch("extensions.ext_database.db.session") as mock_session:
+            mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
+            workflow = MagicMock()
+            workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}}
+
+            workflow_service = MagicMock()
+            workflow_service.get_draft_workflow.return_value = workflow
+
+            last_run = MagicMock()
+            last_run.node_type = "llm"
+            last_run.status = "s"
+            last_run.error = "e"
+            # Return regular values, not Mocks
+            last_run.execution_metadata_dict = {"agent_log": [{"status": "s", "error": "e", "data": {}}]}
+            last_run.load_full_inputs.return_value = {"in": "val"}
+
+            workflow_service.get_node_last_run.return_value = last_run
+
+            mock_model_instance.invoke_llm.return_value = self._text_response('{"modified": "workflow"}')
+
+            result = self._modify_workflow(model_config_entity, workflow_service)
+            assert result == {"modified": "workflow"}
+
+    def test_instruction_modify_workflow_no_last_run_fallback(self, mock_model_instance, model_config_entity):
+        """A code node without a last run still returns the parsed model output."""
+        with patch("extensions.ext_database.db.session") as mock_session:
+            mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
+            workflow = MagicMock()
+            workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "code"}}]}}
+
+            workflow_service = MagicMock()
+            workflow_service.get_draft_workflow.return_value = workflow
+            workflow_service.get_node_last_run.return_value = None
+
+            mock_model_instance.invoke_llm.return_value = self._text_response('{"modified": "fallback"}')
+
+            result = self._modify_workflow(model_config_entity, workflow_service)
+            assert result == {"modified": "fallback"}
+
+    def test_instruction_modify_workflow_node_type_fallback(self, mock_model_instance, model_config_entity):
+        """A graph that does not contain the node still returns the parsed model output."""
+        with patch("extensions.ext_database.db.session") as mock_session:
+            mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
+            workflow = MagicMock()
+            # Cause exception in node_type logic
+            workflow.graph_dict = {"graph": {"nodes": []}}
+
+            workflow_service = MagicMock()
+            workflow_service.get_draft_workflow.return_value = workflow
+            workflow_service.get_node_last_run.return_value = None
+
+            mock_model_instance.invoke_llm.return_value = self._text_response('{"modified": "fallback"}')
+
+            result = self._modify_workflow(model_config_entity, workflow_service)
+            assert result == {"modified": "fallback"}
+
+    def test_instruction_modify_workflow_empty_agent_log(self, mock_model_instance, model_config_entity):
+        """An LLM node whose last run has an empty agent log returns the parsed model output."""
+        with patch("extensions.ext_database.db.session") as mock_session:
+            mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
+            workflow = MagicMock()
+            workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}}
+
+            workflow_service = MagicMock()
+            workflow_service.get_draft_workflow.return_value = workflow
+
+            last_run = MagicMock()
+            last_run.node_type = "llm"
+            last_run.status = "s"
+            last_run.error = "e"
+            # Return regular empty list, not a Mock
+            last_run.execution_metadata_dict = {"agent_log": []}
+            last_run.load_full_inputs.return_value = {}
+
+            workflow_service.get_node_last_run.return_value = last_run
+
+            mock_model_instance.invoke_llm.return_value = self._text_response('{"modified": "workflow"}')
+
+            result = self._modify_workflow(model_config_entity, workflow_service)
+            assert result == {"modified": "workflow"}
+
+    def test_instruction_modify_common_placeholders(self, mock_model_instance, model_config_entity):
+        """Placeholders in the instruction are substituted before the prompt reaches the model."""
+        # Testing placeholders replacement via instruction_modify_legacy for convenience
+        with patch("extensions.ext_database.db.session.query") as mock_query:
+            mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
+
+            mock_model_instance.invoke_llm.return_value = self._text_response('{"ok": true}')
+
+            instruction = "Test {{#last_run#}} and {{#current#}} and {{#error_message#}}"
+            self._modify_legacy(model_config_entity, current="current_val", instruction=instruction)
+
+            # Verify the call to invoke_llm contains replaced instruction
+            prompt_messages = mock_model_instance.invoke_llm.call_args.kwargs["prompt_messages"]
+            user_msg_dict = json.loads(prompt_messages[1].content)
+            assert "null" in user_msg_dict["instruction"]  # because last_run is None and current is current_val etc.
+            assert "current_val" in user_msg_dict["instruction"]
+
+    def test_instruction_modify_common_no_braces(self, mock_model_instance, model_config_entity):
+        """A reply without any braces is reported as lacking a valid JSON object."""
+        with patch("extensions.ext_database.db.session.query") as mock_query:
+            mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
+            mock_model_instance.invoke_llm.return_value = self._text_response("No braces here")
+
+            result = self._modify_legacy(model_config_entity)
+            assert "An unexpected error occurred" in result["error"]
+            assert "Could not find a valid JSON object" in result["error"]
+
+    def test_instruction_modify_common_not_dict(self, mock_model_instance, model_config_entity):
+        """A JSON array reply is reported as an unexpected error."""
+        with patch("extensions.ext_database.db.session.query") as mock_query:
+            mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
+            mock_model_instance.invoke_llm.return_value = self._text_response("[1, 2, 3]")
+
+            result = self._modify_legacy(model_config_entity)
+            # NOTE(review): this payload contains no braces, so it may be rejected by the same
+            # check as the no-braces test rather than by a "not a dict" check; only the generic
+            # prefix is asserted here. Confirm which branch this is meant to exercise.
+            assert "An unexpected error occurred" in result["error"]
+
+    def test_instruction_modify_common_other_node_type(self, mock_model_instance, model_config_entity):
+        """A node type other than llm/code still invokes the model and returns the parsed output."""
+        with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager:
+            instance = MagicMock()
+            mock_manager.return_value.get_model_instance.return_value = instance
+            instance.invoke_llm.return_value = self._text_response('{"ok": true}')
+
+            with patch("extensions.ext_database.db.session") as mock_session:
+                mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
+                workflow = MagicMock()
+                workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "other"}}]}}
+
+                workflow_service = MagicMock()
+                workflow_service.get_draft_workflow.return_value = workflow
+                workflow_service.get_node_last_run.return_value = None
+
+                result = self._modify_workflow(model_config_entity, workflow_service)
+
+            # The test previously asserted nothing and so could never fail.
+            instance.invoke_llm.assert_called_once()
+            assert result == {"ok": True}
+
+    def test_instruction_modify_common_invoke_error(self, mock_model_instance, model_config_entity):
+        """An ``InvokeError`` from the model is reported as a code-generation failure."""
+        with patch("extensions.ext_database.db.session.query") as mock_query:
+            mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
+            mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke Failed")
+
+            result = self._modify_legacy(model_config_entity)
+            assert "Failed to generate code" in result["error"]
+
+    def test_instruction_modify_common_exception(self, mock_model_instance, model_config_entity):
+        """A generic exception from the model is reported as an unexpected error."""
+        with patch("extensions.ext_database.db.session.query") as mock_query:
+            mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
+            mock_model_instance.invoke_llm.side_effect = Exception("Random error")
+
+            result = self._modify_legacy(model_config_entity)
+            assert "An unexpected error occurred" in result["error"]
+
+    def test_instruction_modify_common_json_error(self, mock_model_instance, model_config_entity):
+        """A reply containing no JSON is reported as an unexpected error."""
+        # NOTE(review): this is the same scenario as test_instruction_modify_common_no_braces
+        # with a weaker assertion; consider merging the two.
+        with patch("extensions.ext_database.db.session.query") as mock_query:
+            mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
+
+            mock_model_instance.invoke_llm.return_value = self._text_response("No JSON here")
+
+            result = self._modify_legacy(model_config_entity)
+            assert "An unexpected error occurred" in result["error"]