test: add test for core extension, external_data_tool and llm generator (#32468)

This commit is contained in:
mahammadasim 2026-03-12 09:14:38 +05:30 committed by GitHub
parent 7d2054d4f4
commit 245f6b824d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 1990 additions and 2 deletions

View File

@ -193,7 +193,8 @@ class LLMGenerator:
error_step = "generate rule config"
except Exception as e:
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
rule_config["error"] = str(e)
error = str(e)
error_step = "generate rule config"
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@ -279,7 +280,8 @@ class LLMGenerator:
except Exception as e:
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
rule_config["error"] = str(e)
error = str(e)
error_step = "handle unexpected exception"
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""

View File

@ -0,0 +1,137 @@
import httpx
import pytest
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
from models.api_based_extension import APIBasedExtensionPoint
def test_request_success(mocker):
    """A 2xx response is decoded as JSON and returned to the caller verbatim."""
    # Stub httpx.Client so entering the context manager yields our fake client.
    client_cm = mocker.MagicMock()
    mocker.patch("httpx.Client", return_value=client_cm)
    entered_client = client_cm.__enter__.return_value

    # Fake a successful HTTP response.
    ok_response = mocker.MagicMock()
    ok_response.status_code = 200
    ok_response.json.return_value = {"result": "success"}
    entered_client.request.return_value = ok_response

    requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")

    assert requestor.request(APIBasedExtensionPoint.PING, {"foo": "bar"}) == {"result": "success"}

    # The outgoing request must carry the extension point, params and auth header.
    entered_client.request.assert_called_once_with(
        method="POST",
        url="http://example.com",
        json={"point": APIBasedExtensionPoint.PING.value, "params": {"foo": "bar"}},
        headers={"Content-Type": "application/json", "Authorization": "Bearer test_key"},
    )
def test_request_with_ssrf_proxy(mocker):
    """When both SSRF proxy URLs are configured, httpx.Client must be built with
    per-scheme proxy mounts (one HTTPTransport per scheme)."""
    # Configure both proxy URLs on dify_config.
    mocker.patch("configs.dify_config.SSRF_PROXY_HTTP_URL", "http://proxy:8080")
    mocker.patch("configs.dify_config.SSRF_PROXY_HTTPS_URL", "https://proxy:8081")
    # Mock httpx.Client; keep the class mock so constructor kwargs can be inspected.
    mock_client = mocker.MagicMock()
    mock_client_class = mocker.patch("httpx.Client", return_value=mock_client)
    mock_client_instance = mock_client.__enter__.return_value
    # Mock a successful response so request() completes normally.
    mock_response = mocker.MagicMock()
    mock_response.status_code = 200
    mock_response.json.return_value = {"result": "success"}
    mock_client_instance.request.return_value = mock_response
    # Mock HTTPTransport so we can count how many transports were created.
    mock_transport = mocker.patch("httpx.HTTPTransport")
    requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
    requestor.request(APIBasedExtensionPoint.PING, {})
    # Verify httpx.Client was called with mounts for both schemes.
    mock_client_class.assert_called_once()
    kwargs = mock_client_class.call_args.kwargs
    assert "mounts" in kwargs
    assert "http://" in kwargs["mounts"]
    assert "https://" in kwargs["mounts"]
    # One transport per configured proxy scheme.
    assert mock_transport.call_count == 2
def test_request_with_only_one_proxy_config(mocker):
    """If only one of the two SSRF proxy URLs is set, no mounts are configured:
    proxying is all-or-nothing, so httpx.Client falls back to its default."""
    # Mock dify_config with only the HTTP proxy set.
    mocker.patch("configs.dify_config.SSRF_PROXY_HTTP_URL", "http://proxy:8080")
    mocker.patch("configs.dify_config.SSRF_PROXY_HTTPS_URL", None)
    # Mock httpx.Client; keep the class mock so constructor kwargs can be inspected.
    mock_client = mocker.MagicMock()
    mock_client_class = mocker.patch("httpx.Client", return_value=mock_client)
    mock_client_instance = mock_client.__enter__.return_value
    # Mock a successful response so request() completes normally.
    mock_response = mocker.MagicMock()
    mock_response.status_code = 200
    mock_response.json.return_value = {"result": "success"}
    mock_client_instance.request.return_value = mock_response
    requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
    requestor.request(APIBasedExtensionPoint.PING, {})
    # Verify httpx.Client was called with mounts=None (the default).
    mock_client_class.assert_called_once()
    kwargs = mock_client_class.call_args.kwargs
    assert kwargs.get("mounts") is None
def test_request_timeout(mocker):
    """An httpx timeout must surface to the caller as ValueError('request timeout')."""
    client_cm = mocker.MagicMock()
    mocker.patch("httpx.Client", return_value=client_cm)
    # Make the underlying HTTP call time out.
    client_cm.__enter__.return_value.request.side_effect = httpx.TimeoutException("timeout")

    requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
    with pytest.raises(ValueError, match="request timeout"):
        requestor.request(APIBasedExtensionPoint.PING, {})
def test_request_connection_error(mocker):
    """A transport-level httpx error must surface as ValueError('request connection error')."""
    client_cm = mocker.MagicMock()
    mocker.patch("httpx.Client", return_value=client_cm)
    # Make the underlying HTTP call fail at the connection level.
    client_cm.__enter__.return_value.request.side_effect = httpx.RequestError("error")

    requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
    with pytest.raises(ValueError, match="request connection error"):
        requestor.request(APIBasedExtensionPoint.PING, {})
def test_request_error_status_code(mocker):
    """A non-2xx response is reported as a ValueError containing the status code
    and the response body text."""
    mock_client = mocker.MagicMock()
    mock_client_instance = mock_client.__enter__.return_value
    mocker.patch("httpx.Client", return_value=mock_client)
    # Fake a 404 response with a short body.
    mock_response = mocker.MagicMock()
    mock_response.status_code = 404
    mock_response.text = "Not Found"
    mock_client_instance.request.return_value = mock_response
    requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
    with pytest.raises(ValueError, match="request error, status_code: 404, content: Not Found"):
        requestor.request(APIBasedExtensionPoint.PING, {})
def test_request_error_status_code_long_content(mocker):
    """For an error response with a long body, the message embedded in the raised
    ValueError is truncated to the first 100 characters of the body."""
    mock_client = mocker.MagicMock()
    mock_client_instance = mock_client.__enter__.return_value
    mocker.patch("httpx.Client", return_value=mock_client)
    mock_response = mocker.MagicMock()
    mock_response.status_code = 500
    mock_response.text = "A" * 200  # Testing truncation of content
    mock_client_instance.request.return_value = mock_response
    requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
    # Only the first 100 characters of the 200-character body should appear.
    expected_content = "A" * 100
    with pytest.raises(ValueError, match=f"request error, status_code: 500, content: {expected_content}"):
        requestor.request(APIBasedExtensionPoint.PING, {})

View File

@ -0,0 +1,281 @@
import json
import types
from unittest.mock import MagicMock, mock_open, patch
import pytest
from core.extension.extensible import Extensible
class TestExtensible:
    """Unit tests for Extensible and Extensible.scan_extensions.

    The filesystem and import machinery (find_spec, listdir, isdir, exists,
    module_from_spec, ...) are fully patched, so only the control flow inside
    core.extension.extensible is exercised.  NOTE: the ``side_effect`` lists
    below encode the exact call order the implementation makes — keep them in
    sync with the production code.
    """

    def test_init(self):
        # The constructor stores tenant_id and config as-is.
        tenant_id = "tenant_123"
        config = {"key": "value"}
        ext = Extensible(tenant_id, config)
        assert ext.tenant_id == tenant_id
        assert ext.config == config

    @patch("core.extension.extensible.importlib.util.find_spec")
    @patch("core.extension.extensible.os.path.dirname")
    @patch("core.extension.extensible.os.listdir")
    @patch("core.extension.extensible.os.path.isdir")
    @patch("core.extension.extensible.os.path.exists")
    @patch("core.extension.extensible.Path.read_text")
    @patch("core.extension.extensible.importlib.util.module_from_spec")
    @patch("core.extension.extensible.sort_to_dict_by_position_map")
    def test_scan_extensions_success(
        self,
        mock_sort,
        mock_module_from_spec,
        mock_read_text,
        mock_exists,
        mock_isdir,
        mock_listdir,
        mock_dirname,
        mock_find_spec,
    ):
        """Happy path: one builtin extension directory yields one result with
        name, position (from the __builtin__ file) and the extension class."""
        # Setup: first find_spec resolves the package, second the module.
        package_spec = MagicMock()
        package_spec.origin = "/path/to/pkg/__init__.py"
        module_spec = MagicMock()
        module_spec.loader = MagicMock()
        mock_find_spec.side_effect = [package_spec, module_spec]
        mock_dirname.return_value = "/path/to/pkg"
        # First listdir call lists the package dir, second the extension subdir.
        mock_listdir.side_effect = [
            ["ext1"],  # package_dir
            ["ext1.py", "__builtin__"],  # subdir_path
        ]
        mock_isdir.return_value = True
        mock_exists.return_value = True
        # The __builtin__ file content is the extension's sort position.
        mock_read_text.return_value = "10"
        # Use types.ModuleType to avoid MagicMock __dict__ issues
        mock_mod = types.ModuleType("ext1")

        class MockExtension(Extensible):
            pass

        mock_mod.MockExtension = MockExtension
        mock_module_from_spec.return_value = mock_mod
        # Pass-through sort keeps the input order deterministic for assertions.
        mock_sort.side_effect = lambda position_map, data, name_func: data
        # Execute
        results = Extensible.scan_extensions()
        # Assert
        assert len(results) == 1
        assert results[0].name == "ext1"
        assert results[0].position == 10
        assert results[0].builtin is True
        assert results[0].extension_class == MockExtension

    @patch("core.extension.extensible.importlib.util.find_spec")
    def test_scan_extensions_package_not_found(self, mock_find_spec):
        """If the package itself cannot be resolved, scan raises ImportError."""
        mock_find_spec.return_value = None
        with pytest.raises(ImportError, match="Could not find package"):
            Extensible.scan_extensions()

    @patch("core.extension.extensible.importlib.util.find_spec")
    @patch("core.extension.extensible.os.path.dirname")
    @patch("core.extension.extensible.os.listdir")
    @patch("core.extension.extensible.os.path.isdir")
    def test_scan_extensions_skip_subdirs(self, mock_isdir, mock_listdir, mock_dirname, mock_find_spec):
        """__pycache__, non-directories and dirs without a matching .py file
        are all skipped without error."""
        package_spec = MagicMock()
        package_spec.origin = "/path/to/pkg/__init__.py"
        mock_find_spec.return_value = package_spec
        mock_dirname.return_value = "/path/to/pkg"
        mock_listdir.side_effect = [["__pycache__", "not_a_dir", "missing_py_file"], []]
        # isdir answers: not_a_dir -> False, missing_py_file -> True.
        mock_isdir.side_effect = [False, True]
        with patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]):
            results = Extensible.scan_extensions()
        assert len(results) == 0

    @patch("core.extension.extensible.importlib.util.find_spec")
    @patch("core.extension.extensible.os.path.dirname")
    @patch("core.extension.extensible.os.listdir")
    @patch("core.extension.extensible.os.path.isdir")
    @patch("core.extension.extensible.os.path.exists")
    @patch("core.extension.extensible.importlib.util.module_from_spec")
    def test_scan_extensions_not_builtin_success(
        self, mock_module_from_spec, mock_exists, mock_isdir, mock_listdir, mock_dirname, mock_find_spec
    ):
        """A non-builtin extension with a schema.json gets its label and
        form_schema loaded from the schema file."""
        package_spec = MagicMock()
        package_spec.origin = "/path/to/pkg/__init__.py"
        module_spec = MagicMock()
        module_spec.loader = MagicMock()
        mock_find_spec.side_effect = [package_spec, module_spec]
        mock_dirname.return_value = "/path/to/pkg"
        mock_listdir.side_effect = [["ext1"], ["ext1.py", "schema.json"]]
        mock_isdir.return_value = True
        # exists checks: only schema.json needs to exist
        mock_exists.return_value = True
        mock_mod = types.ModuleType("ext1")

        class MockExtension(Extensible):
            pass

        mock_mod.MockExtension = MockExtension
        mock_module_from_spec.return_value = mock_mod
        schema_content = json.dumps({"label": {"en": "Test"}, "form_schema": [{"name": "field1"}]})
        with (
            patch("builtins.open", mock_open(read_data=schema_content)),
            patch(
                "core.extension.extensible.sort_to_dict_by_position_map",
                side_effect=lambda position_map, data, name_func: data,
            ),
        ):
            results = Extensible.scan_extensions()
        assert len(results) == 1
        assert results[0].name == "ext1"
        assert results[0].builtin is False
        assert results[0].label == {"en": "Test"}

    @patch("core.extension.extensible.importlib.util.find_spec")
    @patch("core.extension.extensible.os.path.dirname")
    @patch("core.extension.extensible.os.listdir")
    @patch("core.extension.extensible.os.path.isdir")
    @patch("core.extension.extensible.os.path.exists")
    @patch("core.extension.extensible.importlib.util.module_from_spec")
    def test_scan_extensions_not_builtin_missing_schema(
        self, mock_module_from_spec, mock_exists, mock_isdir, mock_listdir, mock_dirname, mock_find_spec
    ):
        """A non-builtin extension without a schema.json is skipped entirely."""
        package_spec = MagicMock()
        package_spec.origin = "/path/to/pkg/__init__.py"
        module_spec = MagicMock()
        module_spec.loader = MagicMock()
        mock_find_spec.side_effect = [package_spec, module_spec]
        mock_dirname.return_value = "/path/to/pkg"
        mock_listdir.side_effect = [["ext1"], ["ext1.py"]]
        mock_isdir.return_value = True
        # exists: only schema.json checked, and return False
        mock_exists.return_value = False
        mock_mod = types.ModuleType("ext1")

        class MockExtension(Extensible):
            pass

        mock_mod.MockExtension = MockExtension
        mock_module_from_spec.return_value = mock_mod
        with patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]):
            results = Extensible.scan_extensions()
        assert len(results) == 0

    @patch("core.extension.extensible.importlib.util.find_spec")
    @patch("core.extension.extensible.os.path.dirname")
    @patch("core.extension.extensible.os.listdir")
    @patch("core.extension.extensible.os.path.isdir")
    @patch("core.extension.extensible.importlib.util.module_from_spec")
    @patch("core.extension.extensible.os.path.exists")
    def test_scan_extensions_no_extension_class(
        self, mock_exists, mock_module_from_spec, mock_isdir, mock_listdir, mock_dirname, mock_find_spec
    ):
        """A module containing no Extensible subclass produces no result."""
        package_spec = MagicMock()
        package_spec.origin = "/path/to/pkg/__init__.py"
        module_spec = MagicMock()
        module_spec.loader = MagicMock()
        mock_find_spec.side_effect = [package_spec, module_spec]
        mock_dirname.return_value = "/path/to/pkg"
        mock_listdir.side_effect = [["ext1"], ["ext1.py"]]
        mock_isdir.return_value = True
        # Mock not builtin
        mock_exists.return_value = False
        # Module exposes only an unrelated class, so no extension is found.
        mock_mod = types.ModuleType("ext1")
        mock_mod.SomeOtherClass = type("SomeOtherClass", (), {})
        mock_module_from_spec.return_value = mock_mod
        # We need to ensure we don't crash if checking schema (but we won't reach there because class not found)
        with patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]):
            results = Extensible.scan_extensions()
        assert len(results) == 0

    @patch("core.extension.extensible.importlib.util.find_spec")
    @patch("core.extension.extensible.os.path.dirname")
    @patch("core.extension.extensible.os.listdir")
    @patch("core.extension.extensible.os.path.isdir")
    def test_scan_extensions_module_import_error(self, mock_isdir, mock_listdir, mock_dirname, mock_find_spec):
        """If the per-extension module spec cannot be resolved, scan raises
        ImportError with a 'Failed to load module' message."""
        package_spec = MagicMock()
        package_spec.origin = "/path/to/pkg/__init__.py"
        mock_find_spec.side_effect = [package_spec, None]  # No module spec
        mock_dirname.return_value = "/path/to/pkg"
        mock_listdir.side_effect = [["ext1"], ["ext1.py"]]
        mock_isdir.return_value = True
        with pytest.raises(ImportError, match="Failed to load module"):
            Extensible.scan_extensions()

    @patch("core.extension.extensible.importlib.util.find_spec")
    def test_scan_extensions_general_exception(self, mock_find_spec):
        """Unexpected errors during the scan propagate unchanged."""
        mock_find_spec.side_effect = Exception("Unexpected error")
        with pytest.raises(Exception, match="Unexpected error"):
            Extensible.scan_extensions()

    @patch("core.extension.extensible.importlib.util.find_spec")
    @patch("core.extension.extensible.os.path.dirname")
    @patch("core.extension.extensible.os.listdir")
    @patch("core.extension.extensible.os.path.isdir")
    @patch("core.extension.extensible.os.path.exists")
    @patch("core.extension.extensible.Path.read_text")
    @patch("core.extension.extensible.importlib.util.module_from_spec")
    def test_scan_extensions_builtin_without_position_file(
        self, mock_module_from_spec, mock_read_text, mock_exists, mock_isdir, mock_listdir, mock_dirname, mock_find_spec
    ):
        """A builtin extension whose __builtin__ position file is missing
        defaults to position 0."""
        package_spec = MagicMock()
        package_spec.origin = "/path/to/pkg/__init__.py"
        module_spec = MagicMock()
        module_spec.loader = MagicMock()
        mock_find_spec.side_effect = [package_spec, module_spec]
        mock_dirname.return_value = "/path/to/pkg"
        mock_listdir.side_effect = [["ext1"], ["ext1.py", "__builtin__"]]
        mock_isdir.return_value = True
        # builtin exists in listdir, but os.path.exists(builtin_file_path) returns False
        mock_exists.return_value = False
        mock_mod = types.ModuleType("ext1")

        class MockExtension(Extensible):
            pass

        mock_mod.MockExtension = MockExtension
        mock_module_from_spec.return_value = mock_mod
        with patch(
            "core.extension.extensible.sort_to_dict_by_position_map",
            side_effect=lambda position_map, data, name_func: data,
        ):
            results = Extensible.scan_extensions()
        assert len(results) == 1
        assert results[0].position == 0

View File

@ -0,0 +1,90 @@
from unittest.mock import MagicMock, patch
import pytest
from core.extension.extensible import ExtensionModule, ModuleExtension
from core.extension.extension import Extension
class TestExtension:
    """Unit tests for the Extension registry (module name -> extension map).

    The tests poke the name-mangled private class attribute
    ``Extension._Extension__module_extensions`` directly so each scenario can
    set up registry state without running the real scan.
    """

    def setup_method(self):
        # Reset the private class attribute before each test
        # so registry state does not leak between tests.
        Extension._Extension__module_extensions = {}

    def test_init(self):
        """init() should scan both module classes and store their results
        keyed by the ExtensionModule value."""
        # Mock scan_extensions for Moderation and ExternalDataTool
        mock_mod_extensions = {"mod1": ModuleExtension(name="mod1")}
        mock_ext_extensions = {"ext1": ModuleExtension(name="ext1")}
        extension = Extension()
        # We need to mock scan_extensions on the classes defined in Extension.module_classes
        with (
            patch("core.extension.extension.Moderation.scan_extensions", return_value=mock_mod_extensions),
            patch("core.extension.extension.ExternalDataTool.scan_extensions", return_value=mock_ext_extensions),
        ):
            extension.init()
        # Check if internal state is updated
        internal_state = Extension._Extension__module_extensions
        assert internal_state[ExtensionModule.MODERATION.value] == mock_mod_extensions
        assert internal_state[ExtensionModule.EXTERNAL_DATA_TOOL.value] == mock_ext_extensions

    def test_module_extensions_success(self):
        """module_extensions returns all extensions registered for a module."""
        # Setup data
        mock_extensions = {"name1": ModuleExtension(name="name1"), "name2": ModuleExtension(name="name2")}
        Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: mock_extensions}
        extension = Extension()
        result = extension.module_extensions(ExtensionModule.MODERATION.value)
        assert len(result) == 2
        assert any(e.name == "name1" for e in result)
        assert any(e.name == "name2" for e in result)

    def test_module_extensions_not_found(self):
        """An unknown module name raises ValueError."""
        extension = Extension()
        with pytest.raises(ValueError, match="Extension Module unknown not found"):
            extension.module_extensions("unknown")

    def test_module_extension_success(self):
        """module_extension resolves a single extension by module and name."""
        mock_ext = ModuleExtension(name="test_ext")
        Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}}
        extension = Extension()
        result = extension.module_extension(ExtensionModule.MODERATION, "test_ext")
        assert result == mock_ext

    def test_module_extension_module_not_found(self):
        """A module with no registered extensions raises 'Module ... not found'."""
        extension = Extension()
        # ExtensionModule.MODERATION is "moderation"
        with pytest.raises(ValueError, match="Extension Module moderation not found"):
            extension.module_extension(ExtensionModule.MODERATION, "any")

    def test_module_extension_extension_not_found(self):
        """A known module but unknown extension name raises 'Extension ... not found'."""
        # We need a non-empty dict because 'if not module_extensions' in extension.py
        # returns True for an empty dict, which raises the module not found error instead.
        Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"other": MagicMock()}}
        extension = Extension()
        with pytest.raises(ValueError, match="Extension unknown not found"):
            extension.module_extension(ExtensionModule.MODERATION, "unknown")

    def test_extension_class_success(self):
        """extension_class returns the class stored on the ModuleExtension."""

        class MockClass:
            pass

        mock_ext = ModuleExtension(name="test_ext", extension_class=MockClass)
        Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}}
        extension = Extension()
        result = extension.extension_class(ExtensionModule.MODERATION, "test_ext")
        assert result == MockClass

    def test_extension_class_none(self):
        """A ModuleExtension without a class triggers the internal assertion."""
        mock_ext = ModuleExtension(name="test_ext", extension_class=None)
        Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}}
        extension = Extension()
        with pytest.raises(AssertionError):
            extension.extension_class(ExtensionModule.MODERATION, "test_ext")

View File

@ -0,0 +1,145 @@
from unittest.mock import MagicMock, patch
import pytest
from core.external_data_tool.api.api import ApiExternalDataTool
from models.api_based_extension import APIBasedExtensionPoint
def test_api_external_data_tool_name():
    """The tool registers itself under the extension name 'api'."""
    expected_name = "api"
    assert ApiExternalDataTool.name == expected_name
@patch("core.external_data_tool.api.api.db")
def test_validate_config_success(mock_db):
    """validate_config passes silently when the referenced extension row exists."""
    extension_row = MagicMock()
    extension_row.id = "ext_id"
    extension_row.tenant_id = "tenant_id"
    mock_db.session.scalar.return_value = extension_row

    # No exception raised means the config is accepted.
    ApiExternalDataTool.validate_config("tenant_id", {"api_based_extension_id": "ext_id"})
def test_validate_config_missing_id():
    """An empty config must be rejected with a clear error message."""
    empty_config = {}
    with pytest.raises(ValueError, match="api_based_extension_id is required"):
        ApiExternalDataTool.validate_config("tenant_id", empty_config)
@patch("core.external_data_tool.api.api.db")
def test_validate_config_invalid_id(mock_db):
    """An extension id that does not resolve to a DB row must be rejected."""
    mock_db.session.scalar.return_value = None  # no matching row
    with pytest.raises(ValueError, match="api_based_extension_id is invalid"):
        ApiExternalDataTool.validate_config("tenant_id", {"api_based_extension_id": "ext_id"})
@pytest.fixture
def api_tool():
    """A ready-to-use ApiExternalDataTool wired to extension id 'ext_id'."""
    # Use standard kwargs as it inherits from ExternalDataTool which is typically a Pydantic BaseModel
    return ApiExternalDataTool(
        tenant_id="tenant_id", app_id="app_id", variable="var1", config={"api_based_extension_id": "ext_id"}
    )
@patch("core.external_data_tool.api.api.db")
@patch("core.external_data_tool.api.api.encrypter")
@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor")
def test_query_success(mock_requestor_class, mock_encrypter, mock_db, api_tool):
    """Happy path: query() decrypts the API key, calls the extension endpoint
    and returns the 'result' field from the response."""
    # DB returns a valid extension row with endpoint and encrypted key.
    mock_extension = MagicMock()
    mock_extension.id = "ext_id"
    mock_extension.tenant_id = "tenant_id"
    mock_extension.api_endpoint = "http://api"
    mock_extension.api_key = "encrypted_key"
    mock_db.session.scalar.return_value = mock_extension
    mock_encrypter.decrypt_token.return_value = "decrypted_key"
    mock_requestor = mock_requestor_class.return_value
    mock_requestor.request.return_value = {"result": "success_result"}
    res = api_tool.query({"input1": "value1"}, "query_str")
    assert res == "success_result"
    # The requestor must be built with the decrypted key, not the stored one.
    mock_requestor_class.assert_called_once_with(api_endpoint="http://api", api_key="decrypted_key")
    mock_requestor.request.assert_called_once_with(
        point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY,
        params={"app_id": "app_id", "tool_variable": "var1", "inputs": {"input1": "value1"}, "query": "query_str"},
    )
def test_query_missing_config():
    """query() must refuse to run when the tool's config has been cleared."""
    tool = ApiExternalDataTool(tenant_id="tenant_id", app_id="app_id", variable="var1")
    tool.config = None  # explicitly drop any config
    with pytest.raises(ValueError, match="config is required"):
        tool.query({}, "")
def test_query_missing_extension_id():
    """A config lacking api_based_extension_id trips the internal assertion."""
    config_without_id = {"dummy": "value"}
    tool = ApiExternalDataTool(tenant_id="tenant_id", app_id="app_id", variable="var1", config=config_without_id)
    with pytest.raises(AssertionError, match="api_based_extension_id is required"):
        tool.query({}, "")
@patch("core.external_data_tool.api.api.db")
def test_query_invalid_extension(mock_db, api_tool):
    """If the extension id no longer resolves to a DB row, query() fails loudly."""
    mock_db.session.scalar.return_value = None  # extension row deleted / missing
    with pytest.raises(ValueError, match=".*error: api_based_extension_id is invalid"):
        api_tool.query({}, "")
@patch("core.external_data_tool.api.api.db")
@patch("core.external_data_tool.api.api.encrypter")
@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor")
def test_query_requestor_init_error(mock_requestor_class, mock_encrypter, mock_db, api_tool):
    """An exception while constructing the requestor is wrapped in ValueError."""
    # Valid extension row so we get past the lookup.
    mock_extension = MagicMock()
    mock_extension.id = "ext_id"
    mock_extension.tenant_id = "tenant_id"
    mock_extension.api_endpoint = "http://api"
    mock_extension.api_key = "encrypted_key"
    mock_db.session.scalar.return_value = mock_extension
    mock_encrypter.decrypt_token.return_value = "decrypted_key"
    # Requestor construction blows up.
    mock_requestor_class.side_effect = Exception("init error")
    with pytest.raises(ValueError, match=".*error: init error"):
        api_tool.query({}, "")
@patch("core.external_data_tool.api.api.db")
@patch("core.external_data_tool.api.api.encrypter")
@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor")
def test_query_no_result_in_response(mock_requestor_class, mock_encrypter, mock_db, api_tool):
    """A response missing the 'result' key is rejected as malformed."""
    # Valid extension row so we get past the lookup.
    mock_extension = MagicMock()
    mock_extension.id = "ext_id"
    mock_extension.tenant_id = "tenant_id"
    mock_extension.api_endpoint = "http://api"
    mock_extension.api_key = "encrypted_key"
    mock_db.session.scalar.return_value = mock_extension
    mock_encrypter.decrypt_token.return_value = "decrypted_key"
    mock_requestor = mock_requestor_class.return_value
    # Response lacks the mandatory 'result' key.
    mock_requestor.request.return_value = {"other": "value"}
    with pytest.raises(ValueError, match=".*error: result not found in response"):
        api_tool.query({}, "")
@patch("core.external_data_tool.api.api.db")
@patch("core.external_data_tool.api.api.encrypter")
@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor")
def test_query_result_not_string(mock_requestor_class, mock_encrypter, mock_db, api_tool):
    """A non-string 'result' value is rejected as malformed."""
    # Valid extension row so we get past the lookup.
    mock_extension = MagicMock()
    mock_extension.id = "ext_id"
    mock_extension.tenant_id = "tenant_id"
    mock_extension.api_endpoint = "http://api"
    mock_extension.api_key = "encrypted_key"
    mock_db.session.scalar.return_value = mock_extension
    mock_encrypter.decrypt_token.return_value = "decrypted_key"
    mock_requestor = mock_requestor_class.return_value
    mock_requestor.request.return_value = {"result": 123}  # Not a string
    with pytest.raises(ValueError, match=".*error: result is not string"):
        api_tool.query({}, "")

View File

@ -0,0 +1,66 @@
import pytest
from core.extension.extensible import ExtensionModule
from core.external_data_tool.base import ExternalDataTool
class TestExternalDataTool:
    """Unit tests for the ExternalDataTool abstract base class.

    Each test defines a minimal concrete subclass since the base declares
    abstract-style methods that raise NotImplementedError.
    """

    def test_module_attribute(self):
        # The base class binds itself to the EXTERNAL_DATA_TOOL extension module.
        assert ExternalDataTool.module == ExtensionModule.EXTERNAL_DATA_TOOL

    def test_init(self):
        """The constructor stores tenant/app/variable/config verbatim."""

        # Create a concrete subclass to test init
        class ConcreteTool(ExternalDataTool):
            @classmethod
            def validate_config(cls, tenant_id: str, config: dict):
                return super().validate_config(tenant_id, config)

            def query(self, inputs: dict, query: str | None = None) -> str:
                return super().query(inputs, query)

        tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1", config={"key": "value"})
        assert tool.tenant_id == "tenant_1"
        assert tool.app_id == "app_1"
        assert tool.variable == "var_1"
        assert tool.config == {"key": "value"}

    def test_init_without_config(self):
        """config is optional and defaults to None."""

        # Create a concrete subclass to test init
        class ConcreteTool(ExternalDataTool):
            @classmethod
            def validate_config(cls, tenant_id: str, config: dict):
                pass

            def query(self, inputs: dict, query: str | None = None) -> str:
                return ""

        tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1")
        assert tool.tenant_id == "tenant_1"
        assert tool.app_id == "app_1"
        assert tool.variable == "var_1"
        assert tool.config is None

    def test_validate_config_raises_not_implemented(self):
        """Calling through to the base validate_config raises NotImplementedError."""

        class ConcreteTool(ExternalDataTool):
            @classmethod
            def validate_config(cls, tenant_id: str, config: dict):
                return super().validate_config(tenant_id, config)

            def query(self, inputs: dict, query: str | None = None) -> str:
                return ""

        with pytest.raises(NotImplementedError):
            ConcreteTool.validate_config("tenant_1", {})

    def test_query_raises_not_implemented(self):
        """Calling through to the base query raises NotImplementedError."""

        class ConcreteTool(ExternalDataTool):
            @classmethod
            def validate_config(cls, tenant_id: str, config: dict):
                pass

            def query(self, inputs: dict, query: str | None = None) -> str:
                return super().query(inputs, query)

        tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1")
        with pytest.raises(NotImplementedError):
            tool.query({})

View File

@ -0,0 +1,115 @@
from unittest.mock import patch
import pytest
from flask import Flask
from core.app.app_config.entities import ExternalDataVariableEntity
from core.external_data_tool.external_data_fetch import ExternalDataFetch
class TestExternalDataFetch:
    """Unit tests for ExternalDataFetch, which fans external-data-tool queries
    out to a thread pool and merges the results back into the inputs dict."""

    @pytest.fixture
    def app(self):
        # A bare Flask app: fetch passes the real app object into worker threads.
        app = Flask(__name__)
        return app

    def test_fetch_success(self, app):
        """Two tools run concurrently; each result lands under its own variable
        and the original inputs are preserved."""
        with app.app_context():
            fetcher = ExternalDataFetch()
            # Setup mocks
            tool1 = ExternalDataVariableEntity(variable="var1", type="type1", config={"c1": "v1"})
            tool2 = ExternalDataVariableEntity(variable="var2", type="type2", config={"c2": "v2"})
            external_data_tools = [tool1, tool2]
            inputs = {"input_key": "input_value"}
            query = "test query"
            with patch("core.external_data_tool.external_data_fetch.ExternalDataToolFactory") as MockFactory:
                # Create distinct mock instances for each tool so results are
                # deterministic regardless of thread scheduling order.
                from unittest.mock import MagicMock

                def factory_side_effect(*args, **kwargs):
                    variable = kwargs.get("variable")
                    mock_instance = MagicMock()
                    if variable == "var1":
                        mock_instance.query.return_value = "result1"
                    elif variable == "var2":
                        mock_instance.query.return_value = "result2"
                    return mock_instance

                MockFactory.side_effect = factory_side_effect
                result_inputs = fetcher.fetch(
                    tenant_id="tenant1",
                    app_id="app1",
                    external_data_tools=external_data_tools,
                    inputs=inputs,
                    query=query,
                )
                # Each tool gets its deterministic result regardless of thread completion order
                assert result_inputs["var1"] == "result1"
                assert result_inputs["var2"] == "result2"
                assert result_inputs["input_key"] == "input_value"
                assert len(result_inputs) == 3
                # Verify factory calls
                assert MockFactory.call_count == 2
                MockFactory.assert_any_call(
                    name="type1", tenant_id="tenant1", app_id="app1", variable="var1", config={"c1": "v1"}
                )
                MockFactory.assert_any_call(
                    name="type2", tenant_id="tenant1", app_id="app1", variable="var2", config={"c2": "v2"}
                )

    def test_fetch_no_tools(self):
        """With no tools, fetch returns a copy of the inputs untouched.

        NOTE(review): no app context is used here — per the implementation,
        current_app._get_current_object() is only accessed when submitting
        per-tool jobs, which never happens for an empty tool list. Confirm if
        fetch() ever grows an unconditional current_app access.
        """
        fetcher = ExternalDataFetch()
        inputs = {"input_key": "input_value"}
        result_inputs = fetcher.fetch(
            tenant_id="tenant1", app_id="app1", external_data_tools=[], inputs=inputs, query="test query"
        )
        assert result_inputs == inputs
        assert result_inputs is not inputs  # Should be a copy

    def test_fetch_with_none_variable(self, app):
        """A worker returning a None variable contributes nothing to the result."""
        with app.app_context():
            fetcher = ExternalDataFetch()
            tool = ExternalDataVariableEntity(variable="var1", type="type1", config={})
            # Patch _query_external_data_tool to return None variable
            with patch.object(ExternalDataFetch, "_query_external_data_tool") as mock_query:
                mock_query.return_value = (None, "some_result")
                result_inputs = fetcher.fetch(
                    tenant_id="t1", app_id="a1", external_data_tools=[tool], inputs={"in": "val"}, query="q"
                )
                assert "var1" not in result_inputs
                assert result_inputs == {"in": "val"}

    def test_query_external_data_tool(self, app):
        """The worker builds a factory for the tool, queries it, and returns
        (variable, result)."""
        fetcher = ExternalDataFetch()
        tool = ExternalDataVariableEntity(variable="var1", type="type1", config={"k": "v"})
        with patch("core.external_data_tool.external_data_fetch.ExternalDataToolFactory") as MockFactory:
            mock_factory_instance = MockFactory.return_value
            mock_factory_instance.query.return_value = "query_result"
            var, res = fetcher._query_external_data_tool(
                flask_app=app, tenant_id="t1", app_id="a1", external_data_tool=tool, inputs={"i": "v"}, query="q"
            )
            assert var == "var1"
            assert res == "query_result"
            MockFactory.assert_called_once_with(
                name="type1", tenant_id="t1", app_id="a1", variable="var1", config={"k": "v"}
            )
            mock_factory_instance.query.assert_called_once_with(inputs={"i": "v"}, query="q")

View File

@ -0,0 +1,58 @@
from unittest.mock import MagicMock, patch
from core.extension.extensible import ExtensionModule
from core.external_data_tool.factory import ExternalDataToolFactory
def test_external_data_tool_factory_init():
    """Constructing the factory resolves the extension class from the registry
    and instantiates it with the tool's identity and config."""
    with patch("core.external_data_tool.factory.code_based_extension") as registry:
        tool_cls = MagicMock()
        registry.extension_class.return_value = tool_cls

        ExternalDataToolFactory("test_tool", "tenant_123", "app_456", "var_v", {"key": "value"})

        # Lookup by module + name, then instantiation with keyword arguments.
        registry.extension_class.assert_called_once_with(ExtensionModule.EXTERNAL_DATA_TOOL, "test_tool")
        tool_cls.assert_called_once_with(
            tenant_id="tenant_123", app_id="app_456", variable="var_v", config={"key": "value"}
        )
def test_external_data_tool_factory_validate_config():
    """validate_config is a pure delegation to the resolved extension class."""
    with patch("core.external_data_tool.factory.code_based_extension") as registry:
        tool_cls = MagicMock()
        registry.extension_class.return_value = tool_cls

        ExternalDataToolFactory.validate_config("test_tool", "tenant_123", {"key": "value"})

        registry.extension_class.assert_called_once_with(ExtensionModule.EXTERNAL_DATA_TOOL, "test_tool")
        tool_cls.validate_config.assert_called_once_with("tenant_123", {"key": "value"})
def test_external_data_tool_factory_query():
    """query() forwards inputs and query positionally to the underlying
    extension instance and returns its result unchanged."""
    with patch("core.external_data_tool.factory.code_based_extension") as ext_registry:
        tool_instance = MagicMock()
        tool_instance.query.return_value = "query_result"
        tool_cls = MagicMock(return_value=tool_instance)
        ext_registry.extension_class.return_value = tool_cls

        factory = ExternalDataToolFactory("name", "tenant", "app", "var", {})

        assert factory.query({"input_key": "input_value"}, "search_query") == "query_result"
        tool_instance.query.assert_called_once_with({"input_key": "input_value"}, "search_query")

View File

@ -0,0 +1,103 @@
import pytest
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.llm_generator.prompts import (
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE,
)
class TestRuleConfigGeneratorOutputParser:
    """Unit tests for RuleConfigGeneratorOutputParser: the format-instruction
    templates it exposes, successful parsing of a fenced ```json block, and
    each validation failure (missing key, wrong value types).

    NOTE(review): indentation inside the triple-quoted fixtures was
    reconstructed from a mangled extract — the parser is expected to strip
    surrounding whitespace when locating the json block; confirm against the
    parser implementation.
    """

    def test_get_format_instructions(self):
        """get_format_instructions returns the three rule-config prompt templates as a tuple."""
        parser = RuleConfigGeneratorOutputParser()
        instructions = parser.get_format_instructions()
        assert instructions == (
            RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
            RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
            RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE,
        )

    def test_parse_success(self):
        """A fenced ```json block containing all three required keys parses into a dict."""
        parser = RuleConfigGeneratorOutputParser()
        text = """
        ```json
        {
            "prompt": "This is a prompt",
            "variables": ["var1", "var2"],
            "opening_statement": "Hello!"
        }
        ```
        """
        result = parser.parse(text)
        assert result["prompt"] == "This is a prompt"
        assert result["variables"] == ["var1", "var2"]
        assert result["opening_statement"] == "Hello!"

    def test_parse_invalid_json(self):
        """Input with no ```json block raises OutputParserError with a diagnostic message."""
        parser = RuleConfigGeneratorOutputParser()
        text = "invalid json"
        with pytest.raises(OutputParserError) as excinfo:
            parser.parse(text)
        assert "Parsing text" in str(excinfo.value)
        assert "could not find json block in the output" in str(excinfo.value)

    def test_parse_missing_keys(self):
        """Omitting 'opening_statement' raises with a missing-key message."""
        parser = RuleConfigGeneratorOutputParser()
        text = """
        ```json
        {
            "prompt": "This is a prompt",
            "variables": ["var1", "var2"]
        }
        ```
        """
        with pytest.raises(OutputParserError) as excinfo:
            parser.parse(text)
        assert "expected key `opening_statement` to be present" in str(excinfo.value)

    def test_parse_wrong_type_prompt(self):
        """A non-string 'prompt' value is rejected."""
        parser = RuleConfigGeneratorOutputParser()
        text = """
        ```json
        {
            "prompt": 123,
            "variables": ["var1", "var2"],
            "opening_statement": "Hello!"
        }
        ```
        """
        with pytest.raises(OutputParserError) as excinfo:
            parser.parse(text)
        assert "Expected 'prompt' to be a string" in str(excinfo.value)

    def test_parse_wrong_type_variables(self):
        """A non-list 'variables' value is rejected."""
        parser = RuleConfigGeneratorOutputParser()
        text = """
        ```json
        {
            "prompt": "This is a prompt",
            "variables": "not a list",
            "opening_statement": "Hello!"
        }
        ```
        """
        with pytest.raises(OutputParserError) as excinfo:
            parser.parse(text)
        assert "Expected 'variables' to be a list" in str(excinfo.value)

    def test_parse_wrong_type_opening_statement(self):
        """A non-string 'opening_statement' value is rejected."""
        parser = RuleConfigGeneratorOutputParser()
        text = """
        ```json
        {
            "prompt": "This is a prompt",
            "variables": ["var1", "var2"],
            "opening_statement": 123
        }
        ```
        """
        with pytest.raises(OutputParserError) as excinfo:
            parser.parse(text)
        assert "Expected 'opening_statement' to be a str" in str(excinfo.value)

View File

@ -0,0 +1,402 @@
import json
from unittest.mock import MagicMock, patch
import pytest
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import (
ResponseFormat,
_handle_native_json_schema,
_handle_prompt_based_schema,
_parse_structured_output,
_prepare_schema_for_model,
_set_response_format,
convert_boolean_to_string,
invoke_llm_with_structured_output,
remove_additional_properties,
)
from core.model_manager import ModelInstance
from dify_graph.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMResultWithStructuredOutput,
LLMUsage,
)
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType
class TestStructuredOutput:
    """Unit tests for the structured-output helpers: schema preprocessing
    (remove_additional_properties, convert_boolean_to_string), output
    parsing/repair (_parse_structured_output), response-format selection
    (_set_response_format, _handle_native_json_schema,
    _handle_prompt_based_schema, _prepare_schema_for_model), and the
    invoke_llm_with_structured_output entry point in both stream and
    non-stream modes."""

    def test_remove_additional_properties(self):
        """'additionalProperties' keys are stripped recursively from nested dicts and list items."""
        schema = {
            "type": "object",
            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
            "additionalProperties": False,
            "nested": {"type": "object", "additionalProperties": True},
            "items": [{"type": "object", "additionalProperties": False}],
        }
        remove_additional_properties(schema)
        assert "additionalProperties" not in schema
        assert "additionalProperties" not in schema["nested"]
        assert "additionalProperties" not in schema["items"][0]
        # Test with non-dict input
        remove_additional_properties(None)  # Should not raise
        remove_additional_properties([])  # Should not raise

    def test_convert_boolean_to_string(self):
        """'boolean' type declarations are rewritten to 'string' at every nesting level."""
        schema = {
            "type": "object",
            "properties": {
                "is_active": {"type": "boolean"},
                "tags": {"type": "array", "items": {"type": "boolean"}},
                "list_schema": [{"type": "boolean"}],
            },
        }
        convert_boolean_to_string(schema)
        assert schema["properties"]["is_active"]["type"] == "string"
        assert schema["properties"]["tags"]["items"]["type"] == "string"
        assert schema["properties"]["list_schema"][0]["type"] == "string"
        # Test with non-dict input
        convert_boolean_to_string(None)  # Should not raise
        convert_boolean_to_string([])  # Should not raise

    def test_parse_structured_output_valid(self):
        """A valid JSON object string parses straight to a dict."""
        text = '{"key": "value"}'
        assert _parse_structured_output(text) == {"key": "value"}

    def test_parse_structured_output_non_dict_valid_json(self):
        # Even if it's valid JSON, if it's not a dict, it should try repair or fail
        text = '["a", "b"]'
        with patch("json_repair.loads") as mock_repair:
            mock_repair.return_value = {"key": "value"}
            assert _parse_structured_output(text) == {"key": "value"}

    def test_parse_structured_output_not_dict_fail_via_validate(self):
        # Force TypeAdapter to return a non-dict to trigger line 292
        with patch("pydantic.TypeAdapter.validate_json") as mock_validate:
            mock_validate.return_value = ["a list"]
            with pytest.raises(OutputParserError) as excinfo:
                _parse_structured_output('["a list"]')
            assert "Failed to parse structured output" in str(excinfo.value)

    def test_parse_structured_output_repair_success(self):
        """Single-quoted pseudo-JSON is recovered via the json_repair path."""
        text = "{'key': 'value'}"  # Invalid JSON (single quotes)
        # json_repair should handle this
        assert _parse_structured_output(text) == {"key": "value"}

    def test_parse_structured_output_repair_list(self):
        # Deepseek-r1 case: result is a list containing a dict
        text = '[{"key": "value"}]'
        assert _parse_structured_output(text) == {"key": "value"}

    def test_parse_structured_output_repair_list_no_dict(self):
        # Deepseek-r1 case: result is a list with NO dict
        text = "[1, 2, 3]"
        assert _parse_structured_output(text) == {}

    def test_parse_structured_output_repair_fail(self):
        """If even json_repair yields neither dict nor list, OutputParserError is raised."""
        text = "not a json at all"
        with patch("json_repair.loads") as mock_repair:
            mock_repair.return_value = "still not a dict or list"
            with pytest.raises(OutputParserError):
                _parse_structured_output(text)

    def test_set_response_format(self):
        """_set_response_format picks whichever JSON-ish option the model's rule offers."""
        # Test JSON
        params = {}
        rules = [
            ParameterRule(
                name="response_format",
                label={"en_US": ""},
                type=ParameterType.STRING,
                help={"en_US": ""},
                options=[ResponseFormat.JSON],
            )
        ]
        _set_response_format(params, rules)
        assert params["response_format"] == ResponseFormat.JSON
        # Test JSON_OBJECT
        params = {}
        rules = [
            ParameterRule(
                name="response_format",
                label={"en_US": ""},
                type=ParameterType.STRING,
                help={"en_US": ""},
                options=[ResponseFormat.JSON_OBJECT],
            )
        ]
        _set_response_format(params, rules)
        assert params["response_format"] == ResponseFormat.JSON_OBJECT

    def test_handle_native_json_schema(self):
        """Native JSON-schema handling serializes the schema and sets response_format."""
        provider = "openai"
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.model = "gpt-4"
        structured_output_schema = {"type": "object"}
        model_parameters = {}
        rules = [
            ParameterRule(
                name="response_format",
                label={"en_US": ""},
                type=ParameterType.STRING,
                help={"en_US": ""},
                options=[ResponseFormat.JSON_SCHEMA],
            )
        ]
        updated_params = _handle_native_json_schema(
            provider, model_schema, structured_output_schema, model_parameters, rules
        )
        assert "json_schema" in updated_params
        assert json.loads(updated_params["json_schema"]) == {"schema": {"type": "object"}, "name": "llm_response"}
        assert updated_params["response_format"] == ResponseFormat.JSON_SCHEMA

    def test_handle_native_json_schema_no_format_rule(self):
        """Without a response_format rule, the schema is still attached but no format is forced."""
        provider = "openai"
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.model = "gpt-4"
        structured_output_schema = {"type": "object"}
        model_parameters = {}
        rules = []
        updated_params = _handle_native_json_schema(
            provider, model_schema, structured_output_schema, model_parameters, rules
        )
        assert "json_schema" in updated_params
        assert "response_format" not in updated_params

    def test_handle_prompt_based_schema_with_system_prompt(self):
        """An existing system prompt is kept and the schema is appended into its content."""
        prompt_messages = [
            SystemPromptMessage(content="Existing system prompt"),
            UserPromptMessage(content="User question"),
        ]
        schema = {"type": "object"}
        result = _handle_prompt_based_schema(prompt_messages, schema)
        assert len(result) == 2
        assert isinstance(result[0], SystemPromptMessage)
        assert "Existing system prompt" in result[0].content
        assert json.dumps(schema) in result[0].content
        assert isinstance(result[1], UserPromptMessage)

    def test_handle_prompt_based_schema_without_system_prompt(self):
        """Without a system prompt, one is prepended carrying the schema."""
        prompt_messages = [UserPromptMessage(content="User question")]
        schema = {"type": "object"}
        result = _handle_prompt_based_schema(prompt_messages, schema)
        assert len(result) == 2
        assert isinstance(result[0], SystemPromptMessage)
        assert json.dumps(schema) in result[0].content
        assert isinstance(result[1], UserPromptMessage)

    def test_prepare_schema_for_model_gemini(self):
        """Gemini models get additionalProperties stripped from the schema."""
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.model = "gemini-1.5-pro"
        schema = {"type": "object", "additionalProperties": False}
        result = _prepare_schema_for_model("google", model_schema, schema)
        assert "additionalProperties" not in result

    def test_prepare_schema_for_model_ollama(self):
        """Ollama models receive the schema unchanged (no wrapper)."""
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.model = "llama3"
        schema = {"type": "object"}
        result = _prepare_schema_for_model("ollama", model_schema, schema)
        assert result == schema

    def test_prepare_schema_for_model_default(self):
        """Default providers wrap the schema in {'schema': ..., 'name': 'llm_response'}."""
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.model = "gpt-4"
        schema = {"type": "object"}
        result = _prepare_schema_for_model("openai", model_schema, schema)
        assert result == {"schema": schema, "name": "llm_response"}

    def test_invoke_llm_with_structured_output_no_stream_native(self):
        """Non-stream invocation on a model with native structured-output support
        parses the JSON answer and preserves model/usage/fingerprint fields."""
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.support_structure_output = True
        model_schema.parameter_rules = [
            ParameterRule(
                name="response_format",
                label={"en_US": ""},
                type=ParameterType.STRING,
                help={"en_US": ""},
                options=[ResponseFormat.JSON_SCHEMA],
            )
        ]
        model_schema.model = "gpt-4o"
        model_instance = MagicMock(spec=ModelInstance)
        mock_result = MagicMock(spec=LLMResult)
        mock_result.message = AssistantPromptMessage(content='{"result": "success"}')
        mock_result.model = "gpt-4o"
        mock_result.usage = LLMUsage.empty_usage()
        mock_result.system_fingerprint = "fp_native"
        mock_result.prompt_messages = [UserPromptMessage(content="hi")]
        model_instance.invoke_llm.return_value = mock_result
        result = invoke_llm_with_structured_output(
            provider="openai",
            model_schema=model_schema,
            model_instance=model_instance,
            prompt_messages=[UserPromptMessage(content="hi")],
            json_schema={"type": "object"},
            stream=False,
        )
        assert isinstance(result, LLMResultWithStructuredOutput)
        assert result.structured_output == {"result": "success"}
        assert result.system_fingerprint == "fp_native"

    def test_invoke_llm_with_structured_output_no_stream_prompt_based(self):
        """Without native support, the prompt-based fallback still yields structured output."""
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.support_structure_output = False
        model_schema.parameter_rules = [
            ParameterRule(
                name="response_format",
                label={"en_US": ""},
                type=ParameterType.STRING,
                help={"en_US": ""},
                options=[ResponseFormat.JSON],
            )
        ]
        model_schema.model = "claude-3"
        model_instance = MagicMock(spec=ModelInstance)
        mock_result = MagicMock(spec=LLMResult)
        mock_result.message = AssistantPromptMessage(content='{"result": "success"}')
        mock_result.model = "claude-3"
        mock_result.usage = LLMUsage.empty_usage()
        mock_result.system_fingerprint = "fp_prompt"
        mock_result.prompt_messages = []
        model_instance.invoke_llm.return_value = mock_result
        result = invoke_llm_with_structured_output(
            provider="anthropic",
            model_schema=model_schema,
            model_instance=model_instance,
            prompt_messages=[UserPromptMessage(content="hi")],
            json_schema={"type": "object"},
            stream=False,
        )
        assert isinstance(result, LLMResultWithStructuredOutput)
        assert result.structured_output == {"result": "success"}
        assert result.system_fingerprint == "fp_prompt"

    def test_invoke_llm_with_structured_output_no_string_error(self):
        """A non-string LLM message content (content parts list) raises OutputParserError."""
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.support_structure_output = False
        model_schema.parameter_rules = []
        model_instance = MagicMock(spec=ModelInstance)
        mock_result = MagicMock(spec=LLMResult)
        mock_result.message = AssistantPromptMessage(content=[TextPromptMessageContent(data="not a string")])
        model_instance.invoke_llm.return_value = mock_result
        with pytest.raises(OutputParserError) as excinfo:
            invoke_llm_with_structured_output(
                provider="anthropic",
                model_schema=model_schema,
                model_instance=model_instance,
                prompt_messages=[],
                json_schema={},
                stream=False,
            )
        assert "Failed to parse structured output, LLM result is not a string" in str(excinfo.value)

    def test_invoke_llm_with_structured_output_stream(self):
        """Streaming mode re-emits each chunk and appends a final chunk carrying the
        structured output assembled from all deltas (string and content-part deltas)."""
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.support_structure_output = False
        model_schema.parameter_rules = []
        model_schema.model = "gpt-4"
        model_instance = MagicMock(spec=ModelInstance)
        # Mock chunks
        chunk1 = MagicMock(spec=LLMResultChunk)
        chunk1.delta = LLMResultChunkDelta(
            index=0, message=AssistantPromptMessage(content='{"key": '), usage=LLMUsage.empty_usage()
        )
        chunk1.prompt_messages = [UserPromptMessage(content="hi")]
        chunk1.system_fingerprint = "fp1"
        chunk2 = MagicMock(spec=LLMResultChunk)
        chunk2.delta = LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content='"value"}'))
        chunk2.prompt_messages = [UserPromptMessage(content="hi")]
        chunk2.system_fingerprint = "fp1"
        chunk3 = MagicMock(spec=LLMResultChunk)
        chunk3.delta = LLMResultChunkDelta(
            index=0,
            message=AssistantPromptMessage(
                content=[
                    TextPromptMessageContent(data=" "),
                ]
            ),
        )
        chunk3.prompt_messages = [UserPromptMessage(content="hi")]
        chunk3.system_fingerprint = "fp1"
        # event4 is not spec'd as LLMResultChunk — exercises the non-chunk event path
        event4 = MagicMock()
        event4.delta = LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=""))
        model_instance.invoke_llm.return_value = [chunk1, chunk2, chunk3, event4]
        generator = invoke_llm_with_structured_output(
            provider="openai",
            model_schema=model_schema,
            model_instance=model_instance,
            prompt_messages=[UserPromptMessage(content="hi")],
            json_schema={},
            stream=True,
        )
        chunks = list(generator)
        # 4 passthrough events + 1 final structured-output chunk
        assert len(chunks) == 5
        assert chunks[-1].structured_output == {"key": "value"}
        assert chunks[-1].system_fingerprint == "fp1"
        assert chunks[-1].prompt_messages == [UserPromptMessage(content="hi")]

    def test_invoke_llm_with_structured_output_stream_no_id_events(self):
        """An empty stream produces no parseable output and raises OutputParserError."""
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.support_structure_output = False
        model_schema.parameter_rules = []
        model_schema.model = "gpt-4"
        model_instance = MagicMock(spec=ModelInstance)
        model_instance.invoke_llm.return_value = []
        generator = invoke_llm_with_structured_output(
            provider="openai",
            model_schema=model_schema,
            model_instance=model_instance,
            prompt_messages=[],
            json_schema={},
            stream=True,
        )
        with pytest.raises(OutputParserError):
            list(generator)

    def test_parse_structured_output_empty_string(self):
        """An empty string cannot be parsed or repaired and raises OutputParserError."""
        with pytest.raises(OutputParserError):
            _parse_structured_output("")

View File

@ -0,0 +1,589 @@
import json
from unittest.mock import MagicMock, patch
import pytest
from core.app.app_config.entities import ModelConfig
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.llm_generator import LLMGenerator
from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
class TestLLMGenerator:
@pytest.fixture
def mock_model_instance(self):
with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager:
instance = MagicMock()
mock_manager.return_value.get_default_model_instance.return_value = instance
mock_manager.return_value.get_model_instance.return_value = instance
yield instance
@pytest.fixture
def model_config_entity(self):
return ModelConfig(provider="openai", name="gpt-4", mode=LLMMode.CHAT, completion_params={"temperature": 0.7})
def test_generate_conversation_name_success(self, mock_model_instance):
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = json.dumps({"Your Output": "Test Conversation Name"})
mock_model_instance.invoke_llm.return_value = mock_response
with patch("core.llm_generator.llm_generator.TraceQueueManager") as mock_trace:
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
assert name == "Test Conversation Name"
mock_trace.assert_called_once()
def test_generate_conversation_name_truncated(self, mock_model_instance):
long_query = "a" * 2100
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = json.dumps({"Your Output": "Short Name"})
mock_model_instance.invoke_llm.return_value = mock_response
with patch("core.llm_generator.llm_generator.TraceQueueManager"):
name = LLMGenerator.generate_conversation_name("tenant_id", long_query)
assert name == "Short Name"
def test_generate_conversation_name_empty_answer(self, mock_model_instance):
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = ""
mock_model_instance.invoke_llm.return_value = mock_response
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
assert name == ""
def test_generate_conversation_name_json_repair(self, mock_model_instance):
mock_response = MagicMock()
# Invalid JSON that json_repair can fix
mock_response.message.get_text_content.return_value = "{'Your Output': 'Repaired Name'}"
mock_model_instance.invoke_llm.return_value = mock_response
with patch("core.llm_generator.llm_generator.TraceQueueManager"):
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
assert name == "Repaired Name"
def test_generate_conversation_name_not_dict_result(self, mock_model_instance):
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '["not a dict"]'
mock_model_instance.invoke_llm.return_value = mock_response
with patch("core.llm_generator.llm_generator.TraceQueueManager"):
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
assert name == "test query"
def test_generate_conversation_name_no_output_in_dict(self, mock_model_instance):
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '{"something": "else"}'
mock_model_instance.invoke_llm.return_value = mock_response
with patch("core.llm_generator.llm_generator.TraceQueueManager"):
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
assert name == "test query"
def test_generate_conversation_name_long_output(self, mock_model_instance):
long_output = "a" * 100
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = json.dumps({"Your Output": long_output})
mock_model_instance.invoke_llm.return_value = mock_response
with patch("core.llm_generator.llm_generator.TraceQueueManager"):
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
assert len(name) == 78 # 75 + "..."
assert name.endswith("...")
def test_generate_suggested_questions_after_answer_success(self, mock_model_instance):
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '["Question 1?", "Question 2?"]'
mock_model_instance.invoke_llm.return_value = mock_response
questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories")
assert len(questions) == 2
assert questions[0] == "Question 1?"
def test_generate_suggested_questions_after_answer_auth_error(self, mock_model_instance):
with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager:
mock_manager.return_value.get_default_model_instance.side_effect = InvokeAuthorizationError("Auth failed")
questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories")
assert questions == []
def test_generate_suggested_questions_after_answer_invoke_error(self, mock_model_instance):
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed")
questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories")
assert questions == []
def test_generate_suggested_questions_after_answer_exception(self, mock_model_instance):
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories")
assert questions == []
def test_generate_rule_config_no_variable_success(self, mock_model_instance, model_config_entity):
payload = RuleGeneratePayload(
instruction="test instruction", model_config=model_config_entity, no_variable=True
)
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = "Generated Prompt"
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.generate_rule_config("tenant_id", payload)
assert result["prompt"] == "Generated Prompt"
assert result["error"] == ""
def test_generate_rule_config_no_variable_invoke_error(self, mock_model_instance, model_config_entity):
payload = RuleGeneratePayload(
instruction="test instruction", model_config=model_config_entity, no_variable=True
)
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed")
result = LLMGenerator.generate_rule_config("tenant_id", payload)
assert "Failed to generate rule config" in result["error"]
def test_generate_rule_config_no_variable_exception(self, mock_model_instance, model_config_entity):
payload = RuleGeneratePayload(
instruction="test instruction", model_config=model_config_entity, no_variable=True
)
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
result = LLMGenerator.generate_rule_config("tenant_id", payload)
assert "Failed to generate rule config" in result["error"]
assert "Random error" in result["error"]
def test_generate_rule_config_with_variable_success(self, mock_model_instance, model_config_entity):
payload = RuleGeneratePayload(
instruction="test instruction", model_config=model_config_entity, no_variable=False
)
# Mocking 3 calls for invoke_llm
mock_res1 = MagicMock()
mock_res1.message.get_text_content.return_value = "Step 1 Prompt"
mock_res2 = MagicMock()
mock_res2.message.get_text_content.return_value = '"var1", "var2"'
mock_res3 = MagicMock()
mock_res3.message.get_text_content.return_value = "Opening Statement"
mock_model_instance.invoke_llm.side_effect = [mock_res1, mock_res2, mock_res3]
result = LLMGenerator.generate_rule_config("tenant_id", payload)
assert result["prompt"] == "Step 1 Prompt"
assert result["variables"] == ["var1", "var2"]
assert result["opening_statement"] == "Opening Statement"
assert result["error"] == ""
def test_generate_rule_config_with_variable_step1_error(self, mock_model_instance, model_config_entity):
payload = RuleGeneratePayload(
instruction="test instruction", model_config=model_config_entity, no_variable=False
)
mock_model_instance.invoke_llm.side_effect = InvokeError("Step 1 Failed")
result = LLMGenerator.generate_rule_config("tenant_id", payload)
assert "Failed to generate prefix prompt" in result["error"]
def test_generate_rule_config_with_variable_step2_error(self, mock_model_instance, model_config_entity):
payload = RuleGeneratePayload(
instruction="test instruction", model_config=model_config_entity, no_variable=False
)
mock_res1 = MagicMock()
mock_res1.message.get_text_content.return_value = "Step 1 Prompt"
# Step 2 fails
mock_model_instance.invoke_llm.side_effect = [mock_res1, InvokeError("Step 2 Failed"), MagicMock()]
result = LLMGenerator.generate_rule_config("tenant_id", payload)
assert "Failed to generate variables" in result["error"]
def test_generate_rule_config_with_variable_step3_error(self, mock_model_instance, model_config_entity):
payload = RuleGeneratePayload(
instruction="test instruction", model_config=model_config_entity, no_variable=False
)
mock_res1 = MagicMock()
mock_res1.message.get_text_content.return_value = "Step 1 Prompt"
mock_res2 = MagicMock()
mock_res2.message.get_text_content.return_value = '"var1"'
# Step 3 fails
mock_model_instance.invoke_llm.side_effect = [mock_res1, mock_res2, InvokeError("Step 3 Failed")]
result = LLMGenerator.generate_rule_config("tenant_id", payload)
assert "Failed to generate conversation opener" in result["error"]
def test_generate_rule_config_with_variable_exception(self, mock_model_instance, model_config_entity):
payload = RuleGeneratePayload(
instruction="test instruction", model_config=model_config_entity, no_variable=False
)
# Mock any step to throw Exception
mock_model_instance.invoke_llm.side_effect = Exception("Unexpected multi-step error")
result = LLMGenerator.generate_rule_config("tenant_id", payload)
assert "Failed to handle unexpected exception" in result["error"]
assert "Unexpected multi-step error" in result["error"]
def test_generate_code_python_success(self, mock_model_instance, model_config_entity):
payload = RuleCodeGeneratePayload(
instruction="print hello", code_language="python", model_config=model_config_entity
)
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = "print('hello')"
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.generate_code("tenant_id", payload)
assert result["code"] == "print('hello')"
assert result["language"] == "python"
def test_generate_code_javascript_success(self, mock_model_instance, model_config_entity):
payload = RuleCodeGeneratePayload(
instruction="console log hello", code_language="javascript", model_config=model_config_entity
)
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = "console.log('hello')"
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.generate_code("tenant_id", payload)
assert result["code"] == "console.log('hello')"
assert result["language"] == "javascript"
def test_generate_code_invoke_error(self, mock_model_instance, model_config_entity):
payload = RuleCodeGeneratePayload(instruction="error", code_language="python", model_config=model_config_entity)
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed")
result = LLMGenerator.generate_code("tenant_id", payload)
assert "Failed to generate code" in result["error"]
def test_generate_code_exception(self, mock_model_instance, model_config_entity):
payload = RuleCodeGeneratePayload(instruction="error", code_language="python", model_config=model_config_entity)
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
result = LLMGenerator.generate_code("tenant_id", payload)
assert "An unexpected error occurred" in result["error"]
def test_generate_qa_document_success(self, mock_model_instance):
mock_response = MagicMock(spec=LLMResult)
mock_response.message = MagicMock()
mock_response.message.get_text_content.return_value = "QA Document Content"
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.generate_qa_document("tenant_id", "query", "English")
assert result == "QA Document Content"
def test_generate_qa_document_type_error(self, mock_model_instance):
mock_model_instance.invoke_llm.return_value = "Not an LLMResult"
with pytest.raises(TypeError, match="Expected LLMResult when stream=False"):
LLMGenerator.generate_qa_document("tenant_id", "query", "English")
def test_generate_structured_output_success(self, mock_model_instance, model_config_entity):
payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity)
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '{"type": "object", "properties": {}}'
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.generate_structured_output("tenant_id", payload)
parsed_output = json.loads(result["output"])
assert parsed_output["type"] == "object"
assert result["error"] == ""
def test_generate_structured_output_json_repair(self, mock_model_instance, model_config_entity):
payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity)
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = "{'type': 'object'}"
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.generate_structured_output("tenant_id", payload)
parsed_output = json.loads(result["output"])
assert parsed_output["type"] == "object"
def test_generate_structured_output_not_dict_or_list(self, mock_model_instance, model_config_entity):
payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity)
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = "true" # parsed as bool
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.generate_structured_output("tenant_id", payload)
assert "An unexpected error occurred" in result["error"]
assert "Failed to parse structured output" in result["error"]
def test_generate_structured_output_invoke_error(self, mock_model_instance, model_config_entity):
payload = RuleStructuredOutputPayload(instruction="error", model_config=model_config_entity)
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed")
result = LLMGenerator.generate_structured_output("tenant_id", payload)
assert "Failed to generate JSON Schema" in result["error"]
def test_generate_structured_output_exception(self, mock_model_instance, model_config_entity):
payload = RuleStructuredOutputPayload(instruction="error", model_config=model_config_entity)
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
result = LLMGenerator.generate_structured_output("tenant_id", payload)
assert "An unexpected error occurred" in result["error"]
def test_instruction_modify_legacy_no_last_run(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
# Mock __instruction_modify_common call via invoke_llm
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '{"modified": "prompt"}'
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.instruction_modify_legacy(
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
)
assert result == {"modified": "prompt"}
def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.query") as mock_query:
last_run = MagicMock()
last_run.query = "q"
last_run.answer = "a"
last_run.error = "e"
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = last_run
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '{"modified": "prompt"}'
mock_model_instance.invoke_llm.return_value = mock_response
result = LLMGenerator.instruction_modify_legacy(
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
)
assert result == {"modified": "prompt"}
def test_instruction_modify_workflow_app_not_found(self):
    """A missing app record makes the workflow variant raise ValueError."""
    with patch("extensions.ext_database.db.session") as session_mock:
        # No app row matches the lookup.
        session_mock.return_value.query.return_value.where.return_value.first.return_value = None
        with pytest.raises(ValueError, match="App not found."):
            LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", MagicMock())
def test_instruction_modify_workflow_no_workflow(self):
    """If the app exists but has no draft workflow, a ValueError is raised."""
    with patch("extensions.ext_database.db.session") as session_mock:
        session_mock.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
        service = MagicMock()
        # The service cannot find a draft workflow for the app.
        service.get_draft_workflow.return_value = None
        with pytest.raises(ValueError, match="Workflow not found for the given app model."):
            LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", service)
def test_instruction_modify_workflow_success(self, mock_model_instance, model_config_entity):
    """Happy path: an LLM node with a recorded last run yields the modified workflow JSON."""
    with patch("extensions.ext_database.db.session") as session_mock:
        session_mock.return_value.query.return_value.where.return_value.first.return_value = MagicMock()

        # Draft workflow containing the target LLM node.
        draft = MagicMock()
        draft.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}}

        # Last run carries concrete (non-Mock) metadata so JSON serialization works.
        previous = MagicMock()
        previous.node_type = "llm"
        previous.status = "s"
        previous.error = "e"
        previous.execution_metadata_dict = {"agent_log": [{"status": "s", "error": "e", "data": {}}]}
        previous.load_full_inputs.return_value = {"in": "val"}

        service = MagicMock()
        service.get_draft_workflow.return_value = draft
        service.get_node_last_run.return_value = previous

        llm_response = MagicMock()
        llm_response.message.get_text_content.return_value = '{"modified": "workflow"}'
        mock_model_instance.invoke_llm.return_value = llm_response

        outcome = LLMGenerator.instruction_modify_workflow(
            "tenant_id",
            "flow_id",
            "node_id",
            "current",
            "instruction",
            model_config_entity,
            "ideal",
            service,
        )
        assert outcome == {"modified": "workflow"}
def test_instruction_modify_workflow_no_last_run_fallback(self, mock_model_instance, model_config_entity):
    """When no last run is recorded, the code-node path falls back and still succeeds."""
    with patch("extensions.ext_database.db.session") as session_mock:
        session_mock.return_value.query.return_value.where.return_value.first.return_value = MagicMock()

        draft = MagicMock()
        draft.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "code"}}]}}

        service = MagicMock()
        service.get_draft_workflow.return_value = draft
        # No execution history for this node.
        service.get_node_last_run.return_value = None

        llm_response = MagicMock()
        llm_response.message.get_text_content.return_value = '{"modified": "fallback"}'
        mock_model_instance.invoke_llm.return_value = llm_response

        outcome = LLMGenerator.instruction_modify_workflow(
            "tenant_id",
            "flow_id",
            "node_id",
            "current",
            "instruction",
            model_config_entity,
            "ideal",
            service,
        )
        assert outcome == {"modified": "fallback"}
def test_instruction_modify_workflow_node_type_fallback(self, mock_model_instance, model_config_entity):
    """An empty node list forces the node-type resolution fallback; the call still succeeds."""
    with patch("extensions.ext_database.db.session") as session_mock:
        session_mock.return_value.query.return_value.where.return_value.first.return_value = MagicMock()

        # No nodes at all: node-type lookup for "node_id" cannot succeed.
        draft = MagicMock()
        draft.graph_dict = {"graph": {"nodes": []}}

        service = MagicMock()
        service.get_draft_workflow.return_value = draft
        service.get_node_last_run.return_value = None

        llm_response = MagicMock()
        llm_response.message.get_text_content.return_value = '{"modified": "fallback"}'
        mock_model_instance.invoke_llm.return_value = llm_response

        outcome = LLMGenerator.instruction_modify_workflow(
            "tenant_id",
            "flow_id",
            "node_id",
            "current",
            "instruction",
            model_config_entity,
            "ideal",
            service,
        )
        assert outcome == {"modified": "fallback"}
def test_instruction_modify_workflow_empty_agent_log(self, mock_model_instance, model_config_entity):
    """A last run whose agent_log is an empty list is handled without errors."""
    with patch("extensions.ext_database.db.session") as session_mock:
        session_mock.return_value.query.return_value.where.return_value.first.return_value = MagicMock()

        draft = MagicMock()
        draft.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}}

        # Last run exists, but its agent log and inputs are empty concrete values
        # (not Mocks) so they serialize cleanly.
        previous = MagicMock()
        previous.node_type = "llm"
        previous.status = "s"
        previous.error = "e"
        previous.execution_metadata_dict = {"agent_log": []}
        previous.load_full_inputs.return_value = {}

        service = MagicMock()
        service.get_draft_workflow.return_value = draft
        service.get_node_last_run.return_value = previous

        llm_response = MagicMock()
        llm_response.message.get_text_content.return_value = '{"modified": "workflow"}'
        mock_model_instance.invoke_llm.return_value = llm_response

        outcome = LLMGenerator.instruction_modify_workflow(
            "tenant_id",
            "flow_id",
            "node_id",
            "current",
            "instruction",
            model_config_entity,
            "ideal",
            service,
        )
        assert outcome == {"modified": "workflow"}
def test_instruction_modify_common_placeholders(self, mock_model_instance, model_config_entity):
    """Placeholder tokens in the instruction are substituted before the LLM call.

    Exercised through instruction_modify_legacy for convenience; the prompt sent
    to invoke_llm is inspected to confirm the replacements happened.
    """
    with patch("extensions.ext_database.db.session.query") as query_mock:
        query_mock.return_value.where.return_value.order_by.return_value.first.return_value = None
        llm_response = MagicMock()
        llm_response.message.get_text_content.return_value = '{"ok": true}'
        mock_model_instance.invoke_llm.return_value = llm_response

        raw_instruction = "Test {{#last_run#}} and {{#current#}} and {{#error_message#}}"
        LLMGenerator.instruction_modify_legacy(
            "tenant_id", "flow_id", "current_val", raw_instruction, model_config_entity, "ideal"
        )

        # Inspect the user message actually passed to the model.
        _, call_kwargs = mock_model_instance.invoke_llm.call_args
        user_payload = json.loads(call_kwargs["prompt_messages"][1].content)
        # last_run is absent -> serialized as null; current_val replaces {{#current#}}.
        assert "null" in user_payload["instruction"]
        assert "current_val" in user_payload["instruction"]
def test_instruction_modify_common_no_braces(self, mock_model_instance, model_config_entity):
    """A model reply with no JSON object at all produces a descriptive parse error."""
    with patch("extensions.ext_database.db.session.query") as query_mock:
        query_mock.return_value.where.return_value.order_by.return_value.first.return_value = None
        llm_response = MagicMock()
        llm_response.message.get_text_content.return_value = "No braces here"
        mock_model_instance.invoke_llm.return_value = llm_response
        outcome = LLMGenerator.instruction_modify_legacy(
            "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
        )
        assert "An unexpected error occurred" in outcome["error"]
        assert "Could not find a valid JSON object" in outcome["error"]
def test_instruction_modify_common_not_dict(self, mock_model_instance, model_config_entity):
    """A JSON array reply (valid JSON, wrong shape) is rejected as an unexpected error."""
    with patch("extensions.ext_database.db.session.query") as query_mock:
        query_mock.return_value.where.return_value.order_by.return_value.first.return_value = None
        llm_response = MagicMock()
        # Parses as a list, not a dict; generator raises "Expected a JSON object, but got list".
        llm_response.message.get_text_content.return_value = "[1, 2, 3]"
        mock_model_instance.invoke_llm.return_value = llm_response
        outcome = LLMGenerator.instruction_modify_legacy(
            "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
        )
        assert "An unexpected error occurred" in outcome["error"]
def test_instruction_modify_common_other_node_type(self, mock_model_instance, model_config_entity):
    """A node type that is neither "llm" nor "code" takes the generic branch and succeeds.

    The original test invoked instruction_modify_workflow but asserted nothing,
    so it only verified that no exception escaped. The result assertion below
    pins the actual behavior: the mocked '{"ok": true}' reply is parsed and
    returned as a dict.
    """
    with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager:
        instance = MagicMock()
        mock_manager.return_value.get_model_instance.return_value = instance
        mock_response = MagicMock()
        mock_response.message.get_text_content.return_value = '{"ok": true}'
        instance.invoke_llm.return_value = mock_response
        with patch("extensions.ext_database.db.session") as mock_session:
            mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
            workflow = MagicMock()
            workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "other"}}]}}
            workflow_service = MagicMock()
            workflow_service.get_draft_workflow.return_value = workflow
            workflow_service.get_node_last_run.return_value = None
            result = LLMGenerator.instruction_modify_workflow(
                "tenant_id",
                "flow_id",
                "node_id",
                "current",
                "instruction",
                model_config_entity,
                "ideal",
                workflow_service,
            )
            # Missing before: verify the parsed JSON reply is what the caller receives.
            assert result == {"ok": True}
def test_instruction_modify_common_invoke_error(self, mock_model_instance, model_config_entity):
    """An InvokeError from the model is reported with the generation-step prefix."""
    with patch("extensions.ext_database.db.session.query") as query_mock:
        query_mock.return_value.where.return_value.order_by.return_value.first.return_value = None
        mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke Failed")
        outcome = LLMGenerator.instruction_modify_legacy(
            "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
        )
        assert "Failed to generate code" in outcome["error"]
def test_instruction_modify_common_exception(self, mock_model_instance, model_config_entity):
    """A generic exception from the model is reported as an unexpected error."""
    with patch("extensions.ext_database.db.session.query") as query_mock:
        query_mock.return_value.where.return_value.order_by.return_value.first.return_value = None
        mock_model_instance.invoke_llm.side_effect = Exception("Random error")
        outcome = LLMGenerator.instruction_modify_legacy(
            "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
        )
        assert "An unexpected error occurred" in outcome["error"]
def test_instruction_modify_common_json_error(self, mock_model_instance, model_config_entity):
    """A non-JSON model reply surfaces as an unexpected-error message."""
    with patch("extensions.ext_database.db.session.query") as query_mock:
        query_mock.return_value.where.return_value.order_by.return_value.first.return_value = None
        llm_response = MagicMock()
        llm_response.message.get_text_content.return_value = "No JSON here"
        mock_model_instance.invoke_llm.return_value = llm_response
        outcome = LLMGenerator.instruction_modify_legacy(
            "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
        )
        assert "An unexpected error occurred" in outcome["error"]