mirror of https://github.com/langgenius/dify.git
test: add test for core extension, external_data_tool and llm generator (#32468)
This commit is contained in:
parent
7d2054d4f4
commit
245f6b824d
|
|
@ -193,7 +193,8 @@ class LLMGenerator:
|
||||||
error_step = "generate rule config"
|
error_step = "generate rule config"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
|
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
|
||||||
rule_config["error"] = str(e)
|
error = str(e)
|
||||||
|
error_step = "generate rule config"
|
||||||
|
|
||||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||||
|
|
||||||
|
|
@ -279,7 +280,8 @@ class LLMGenerator:
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
|
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
|
||||||
rule_config["error"] = str(e)
|
error = str(e)
|
||||||
|
error_step = "handle unexpected exception"
|
||||||
|
|
||||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,137 @@
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
|
||||||
|
from models.api_based_extension import APIBasedExtensionPoint
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_success(mocker):
|
||||||
|
# Mock httpx.Client and its context manager
|
||||||
|
mock_client = mocker.MagicMock()
|
||||||
|
mock_client_instance = mock_client.__enter__.return_value
|
||||||
|
mocker.patch("httpx.Client", return_value=mock_client)
|
||||||
|
|
||||||
|
mock_response = mocker.MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {"result": "success"}
|
||||||
|
mock_client_instance.request.return_value = mock_response
|
||||||
|
|
||||||
|
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
|
||||||
|
result = requestor.request(APIBasedExtensionPoint.PING, {"foo": "bar"})
|
||||||
|
|
||||||
|
assert result == {"result": "success"}
|
||||||
|
mock_client_instance.request.assert_called_once_with(
|
||||||
|
method="POST",
|
||||||
|
url="http://example.com",
|
||||||
|
json={"point": APIBasedExtensionPoint.PING.value, "params": {"foo": "bar"}},
|
||||||
|
headers={"Content-Type": "application/json", "Authorization": "Bearer test_key"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_with_ssrf_proxy(mocker):
|
||||||
|
# Mock dify_config
|
||||||
|
mocker.patch("configs.dify_config.SSRF_PROXY_HTTP_URL", "http://proxy:8080")
|
||||||
|
mocker.patch("configs.dify_config.SSRF_PROXY_HTTPS_URL", "https://proxy:8081")
|
||||||
|
|
||||||
|
# Mock httpx.Client
|
||||||
|
mock_client = mocker.MagicMock()
|
||||||
|
mock_client_class = mocker.patch("httpx.Client", return_value=mock_client)
|
||||||
|
mock_client_instance = mock_client.__enter__.return_value
|
||||||
|
|
||||||
|
# Mock response
|
||||||
|
mock_response = mocker.MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {"result": "success"}
|
||||||
|
mock_client_instance.request.return_value = mock_response
|
||||||
|
|
||||||
|
# Mock HTTPTransport
|
||||||
|
mock_transport = mocker.patch("httpx.HTTPTransport")
|
||||||
|
|
||||||
|
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
|
||||||
|
requestor.request(APIBasedExtensionPoint.PING, {})
|
||||||
|
|
||||||
|
# Verify httpx.Client was called with mounts
|
||||||
|
mock_client_class.assert_called_once()
|
||||||
|
kwargs = mock_client_class.call_args.kwargs
|
||||||
|
assert "mounts" in kwargs
|
||||||
|
assert "http://" in kwargs["mounts"]
|
||||||
|
assert "https://" in kwargs["mounts"]
|
||||||
|
assert mock_transport.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_with_only_one_proxy_config(mocker):
|
||||||
|
# Mock dify_config with only one proxy
|
||||||
|
mocker.patch("configs.dify_config.SSRF_PROXY_HTTP_URL", "http://proxy:8080")
|
||||||
|
mocker.patch("configs.dify_config.SSRF_PROXY_HTTPS_URL", None)
|
||||||
|
|
||||||
|
# Mock httpx.Client
|
||||||
|
mock_client = mocker.MagicMock()
|
||||||
|
mock_client_class = mocker.patch("httpx.Client", return_value=mock_client)
|
||||||
|
mock_client_instance = mock_client.__enter__.return_value
|
||||||
|
|
||||||
|
# Mock response
|
||||||
|
mock_response = mocker.MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {"result": "success"}
|
||||||
|
mock_client_instance.request.return_value = mock_response
|
||||||
|
|
||||||
|
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
|
||||||
|
requestor.request(APIBasedExtensionPoint.PING, {})
|
||||||
|
|
||||||
|
# Verify httpx.Client was called with mounts=None (default)
|
||||||
|
mock_client_class.assert_called_once()
|
||||||
|
kwargs = mock_client_class.call_args.kwargs
|
||||||
|
assert kwargs.get("mounts") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_timeout(mocker):
|
||||||
|
mock_client = mocker.MagicMock()
|
||||||
|
mock_client_instance = mock_client.__enter__.return_value
|
||||||
|
mocker.patch("httpx.Client", return_value=mock_client)
|
||||||
|
mock_client_instance.request.side_effect = httpx.TimeoutException("timeout")
|
||||||
|
|
||||||
|
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
|
||||||
|
with pytest.raises(ValueError, match="request timeout"):
|
||||||
|
requestor.request(APIBasedExtensionPoint.PING, {})
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_connection_error(mocker):
|
||||||
|
mock_client = mocker.MagicMock()
|
||||||
|
mock_client_instance = mock_client.__enter__.return_value
|
||||||
|
mocker.patch("httpx.Client", return_value=mock_client)
|
||||||
|
mock_client_instance.request.side_effect = httpx.RequestError("error")
|
||||||
|
|
||||||
|
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
|
||||||
|
with pytest.raises(ValueError, match="request connection error"):
|
||||||
|
requestor.request(APIBasedExtensionPoint.PING, {})
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_error_status_code(mocker):
|
||||||
|
mock_client = mocker.MagicMock()
|
||||||
|
mock_client_instance = mock_client.__enter__.return_value
|
||||||
|
mocker.patch("httpx.Client", return_value=mock_client)
|
||||||
|
|
||||||
|
mock_response = mocker.MagicMock()
|
||||||
|
mock_response.status_code = 404
|
||||||
|
mock_response.text = "Not Found"
|
||||||
|
mock_client_instance.request.return_value = mock_response
|
||||||
|
|
||||||
|
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
|
||||||
|
with pytest.raises(ValueError, match="request error, status_code: 404, content: Not Found"):
|
||||||
|
requestor.request(APIBasedExtensionPoint.PING, {})
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_error_status_code_long_content(mocker):
|
||||||
|
mock_client = mocker.MagicMock()
|
||||||
|
mock_client_instance = mock_client.__enter__.return_value
|
||||||
|
mocker.patch("httpx.Client", return_value=mock_client)
|
||||||
|
|
||||||
|
mock_response = mocker.MagicMock()
|
||||||
|
mock_response.status_code = 500
|
||||||
|
mock_response.text = "A" * 200 # Testing truncation of content
|
||||||
|
mock_client_instance.request.return_value = mock_response
|
||||||
|
|
||||||
|
requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key")
|
||||||
|
expected_content = "A" * 100
|
||||||
|
with pytest.raises(ValueError, match=f"request error, status_code: 500, content: {expected_content}"):
|
||||||
|
requestor.request(APIBasedExtensionPoint.PING, {})
|
||||||
|
|
@ -0,0 +1,281 @@
|
||||||
|
import json
|
||||||
|
import types
|
||||||
|
from unittest.mock import MagicMock, mock_open, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.extension.extensible import Extensible
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtensible:
|
||||||
|
def test_init(self):
|
||||||
|
tenant_id = "tenant_123"
|
||||||
|
config = {"key": "value"}
|
||||||
|
ext = Extensible(tenant_id, config)
|
||||||
|
assert ext.tenant_id == tenant_id
|
||||||
|
assert ext.config == config
|
||||||
|
|
||||||
|
@patch("core.extension.extensible.importlib.util.find_spec")
|
||||||
|
@patch("core.extension.extensible.os.path.dirname")
|
||||||
|
@patch("core.extension.extensible.os.listdir")
|
||||||
|
@patch("core.extension.extensible.os.path.isdir")
|
||||||
|
@patch("core.extension.extensible.os.path.exists")
|
||||||
|
@patch("core.extension.extensible.Path.read_text")
|
||||||
|
@patch("core.extension.extensible.importlib.util.module_from_spec")
|
||||||
|
@patch("core.extension.extensible.sort_to_dict_by_position_map")
|
||||||
|
def test_scan_extensions_success(
|
||||||
|
self,
|
||||||
|
mock_sort,
|
||||||
|
mock_module_from_spec,
|
||||||
|
mock_read_text,
|
||||||
|
mock_exists,
|
||||||
|
mock_isdir,
|
||||||
|
mock_listdir,
|
||||||
|
mock_dirname,
|
||||||
|
mock_find_spec,
|
||||||
|
):
|
||||||
|
# Setup
|
||||||
|
package_spec = MagicMock()
|
||||||
|
package_spec.origin = "/path/to/pkg/__init__.py"
|
||||||
|
|
||||||
|
module_spec = MagicMock()
|
||||||
|
module_spec.loader = MagicMock()
|
||||||
|
|
||||||
|
mock_find_spec.side_effect = [package_spec, module_spec]
|
||||||
|
mock_dirname.return_value = "/path/to/pkg"
|
||||||
|
|
||||||
|
mock_listdir.side_effect = [
|
||||||
|
["ext1"], # package_dir
|
||||||
|
["ext1.py", "__builtin__"], # subdir_path
|
||||||
|
]
|
||||||
|
mock_isdir.return_value = True
|
||||||
|
|
||||||
|
mock_exists.return_value = True
|
||||||
|
mock_read_text.return_value = "10"
|
||||||
|
|
||||||
|
# Use types.ModuleType to avoid MagicMock __dict__ issues
|
||||||
|
mock_mod = types.ModuleType("ext1")
|
||||||
|
|
||||||
|
class MockExtension(Extensible):
|
||||||
|
pass
|
||||||
|
|
||||||
|
mock_mod.MockExtension = MockExtension
|
||||||
|
mock_module_from_spec.return_value = mock_mod
|
||||||
|
|
||||||
|
mock_sort.side_effect = lambda position_map, data, name_func: data
|
||||||
|
|
||||||
|
# Execute
|
||||||
|
results = Extensible.scan_extensions()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].name == "ext1"
|
||||||
|
assert results[0].position == 10
|
||||||
|
assert results[0].builtin is True
|
||||||
|
assert results[0].extension_class == MockExtension
|
||||||
|
|
||||||
|
@patch("core.extension.extensible.importlib.util.find_spec")
|
||||||
|
def test_scan_extensions_package_not_found(self, mock_find_spec):
|
||||||
|
mock_find_spec.return_value = None
|
||||||
|
with pytest.raises(ImportError, match="Could not find package"):
|
||||||
|
Extensible.scan_extensions()
|
||||||
|
|
||||||
|
@patch("core.extension.extensible.importlib.util.find_spec")
|
||||||
|
@patch("core.extension.extensible.os.path.dirname")
|
||||||
|
@patch("core.extension.extensible.os.listdir")
|
||||||
|
@patch("core.extension.extensible.os.path.isdir")
|
||||||
|
def test_scan_extensions_skip_subdirs(self, mock_isdir, mock_listdir, mock_dirname, mock_find_spec):
|
||||||
|
package_spec = MagicMock()
|
||||||
|
package_spec.origin = "/path/to/pkg/__init__.py"
|
||||||
|
mock_find_spec.return_value = package_spec
|
||||||
|
mock_dirname.return_value = "/path/to/pkg"
|
||||||
|
|
||||||
|
mock_listdir.side_effect = [["__pycache__", "not_a_dir", "missing_py_file"], []]
|
||||||
|
|
||||||
|
mock_isdir.side_effect = [False, True]
|
||||||
|
|
||||||
|
with patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]):
|
||||||
|
results = Extensible.scan_extensions()
|
||||||
|
assert len(results) == 0
|
||||||
|
|
||||||
|
@patch("core.extension.extensible.importlib.util.find_spec")
|
||||||
|
@patch("core.extension.extensible.os.path.dirname")
|
||||||
|
@patch("core.extension.extensible.os.listdir")
|
||||||
|
@patch("core.extension.extensible.os.path.isdir")
|
||||||
|
@patch("core.extension.extensible.os.path.exists")
|
||||||
|
@patch("core.extension.extensible.importlib.util.module_from_spec")
|
||||||
|
def test_scan_extensions_not_builtin_success(
|
||||||
|
self, mock_module_from_spec, mock_exists, mock_isdir, mock_listdir, mock_dirname, mock_find_spec
|
||||||
|
):
|
||||||
|
package_spec = MagicMock()
|
||||||
|
package_spec.origin = "/path/to/pkg/__init__.py"
|
||||||
|
|
||||||
|
module_spec = MagicMock()
|
||||||
|
module_spec.loader = MagicMock()
|
||||||
|
|
||||||
|
mock_find_spec.side_effect = [package_spec, module_spec]
|
||||||
|
mock_dirname.return_value = "/path/to/pkg"
|
||||||
|
|
||||||
|
mock_listdir.side_effect = [["ext1"], ["ext1.py", "schema.json"]]
|
||||||
|
mock_isdir.return_value = True
|
||||||
|
|
||||||
|
# exists checks: only schema.json needs to exist
|
||||||
|
mock_exists.return_value = True
|
||||||
|
|
||||||
|
mock_mod = types.ModuleType("ext1")
|
||||||
|
|
||||||
|
class MockExtension(Extensible):
|
||||||
|
pass
|
||||||
|
|
||||||
|
mock_mod.MockExtension = MockExtension
|
||||||
|
mock_module_from_spec.return_value = mock_mod
|
||||||
|
|
||||||
|
schema_content = json.dumps({"label": {"en": "Test"}, "form_schema": [{"name": "field1"}]})
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("builtins.open", mock_open(read_data=schema_content)),
|
||||||
|
patch(
|
||||||
|
"core.extension.extensible.sort_to_dict_by_position_map",
|
||||||
|
side_effect=lambda position_map, data, name_func: data,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
results = Extensible.scan_extensions()
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].name == "ext1"
|
||||||
|
assert results[0].builtin is False
|
||||||
|
assert results[0].label == {"en": "Test"}
|
||||||
|
|
||||||
|
@patch("core.extension.extensible.importlib.util.find_spec")
|
||||||
|
@patch("core.extension.extensible.os.path.dirname")
|
||||||
|
@patch("core.extension.extensible.os.listdir")
|
||||||
|
@patch("core.extension.extensible.os.path.isdir")
|
||||||
|
@patch("core.extension.extensible.os.path.exists")
|
||||||
|
@patch("core.extension.extensible.importlib.util.module_from_spec")
|
||||||
|
def test_scan_extensions_not_builtin_missing_schema(
|
||||||
|
self, mock_module_from_spec, mock_exists, mock_isdir, mock_listdir, mock_dirname, mock_find_spec
|
||||||
|
):
|
||||||
|
package_spec = MagicMock()
|
||||||
|
package_spec.origin = "/path/to/pkg/__init__.py"
|
||||||
|
|
||||||
|
module_spec = MagicMock()
|
||||||
|
module_spec.loader = MagicMock()
|
||||||
|
|
||||||
|
mock_find_spec.side_effect = [package_spec, module_spec]
|
||||||
|
mock_dirname.return_value = "/path/to/pkg"
|
||||||
|
|
||||||
|
mock_listdir.side_effect = [["ext1"], ["ext1.py"]]
|
||||||
|
mock_isdir.return_value = True
|
||||||
|
|
||||||
|
# exists: only schema.json checked, and return False
|
||||||
|
mock_exists.return_value = False
|
||||||
|
|
||||||
|
mock_mod = types.ModuleType("ext1")
|
||||||
|
|
||||||
|
class MockExtension(Extensible):
|
||||||
|
pass
|
||||||
|
|
||||||
|
mock_mod.MockExtension = MockExtension
|
||||||
|
mock_module_from_spec.return_value = mock_mod
|
||||||
|
|
||||||
|
with patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]):
|
||||||
|
results = Extensible.scan_extensions()
|
||||||
|
|
||||||
|
assert len(results) == 0
|
||||||
|
|
||||||
|
@patch("core.extension.extensible.importlib.util.find_spec")
|
||||||
|
@patch("core.extension.extensible.os.path.dirname")
|
||||||
|
@patch("core.extension.extensible.os.listdir")
|
||||||
|
@patch("core.extension.extensible.os.path.isdir")
|
||||||
|
@patch("core.extension.extensible.importlib.util.module_from_spec")
|
||||||
|
@patch("core.extension.extensible.os.path.exists")
|
||||||
|
def test_scan_extensions_no_extension_class(
|
||||||
|
self, mock_exists, mock_module_from_spec, mock_isdir, mock_listdir, mock_dirname, mock_find_spec
|
||||||
|
):
|
||||||
|
package_spec = MagicMock()
|
||||||
|
package_spec.origin = "/path/to/pkg/__init__.py"
|
||||||
|
module_spec = MagicMock()
|
||||||
|
module_spec.loader = MagicMock()
|
||||||
|
|
||||||
|
mock_find_spec.side_effect = [package_spec, module_spec]
|
||||||
|
mock_dirname.return_value = "/path/to/pkg"
|
||||||
|
|
||||||
|
mock_listdir.side_effect = [["ext1"], ["ext1.py"]]
|
||||||
|
mock_isdir.return_value = True
|
||||||
|
|
||||||
|
# Mock not builtin
|
||||||
|
mock_exists.return_value = False
|
||||||
|
|
||||||
|
mock_mod = types.ModuleType("ext1")
|
||||||
|
mock_mod.SomeOtherClass = type("SomeOtherClass", (), {})
|
||||||
|
mock_module_from_spec.return_value = mock_mod
|
||||||
|
|
||||||
|
# We need to ensure we don't crash if checking schema (but we won't reach there because class not found)
|
||||||
|
|
||||||
|
with patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]):
|
||||||
|
results = Extensible.scan_extensions()
|
||||||
|
|
||||||
|
assert len(results) == 0
|
||||||
|
|
||||||
|
@patch("core.extension.extensible.importlib.util.find_spec")
|
||||||
|
@patch("core.extension.extensible.os.path.dirname")
|
||||||
|
@patch("core.extension.extensible.os.listdir")
|
||||||
|
@patch("core.extension.extensible.os.path.isdir")
|
||||||
|
def test_scan_extensions_module_import_error(self, mock_isdir, mock_listdir, mock_dirname, mock_find_spec):
|
||||||
|
package_spec = MagicMock()
|
||||||
|
package_spec.origin = "/path/to/pkg/__init__.py"
|
||||||
|
|
||||||
|
mock_find_spec.side_effect = [package_spec, None] # No module spec
|
||||||
|
mock_dirname.return_value = "/path/to/pkg"
|
||||||
|
|
||||||
|
mock_listdir.side_effect = [["ext1"], ["ext1.py"]]
|
||||||
|
mock_isdir.return_value = True
|
||||||
|
|
||||||
|
with pytest.raises(ImportError, match="Failed to load module"):
|
||||||
|
Extensible.scan_extensions()
|
||||||
|
|
||||||
|
@patch("core.extension.extensible.importlib.util.find_spec")
|
||||||
|
def test_scan_extensions_general_exception(self, mock_find_spec):
|
||||||
|
mock_find_spec.side_effect = Exception("Unexpected error")
|
||||||
|
with pytest.raises(Exception, match="Unexpected error"):
|
||||||
|
Extensible.scan_extensions()
|
||||||
|
|
||||||
|
@patch("core.extension.extensible.importlib.util.find_spec")
|
||||||
|
@patch("core.extension.extensible.os.path.dirname")
|
||||||
|
@patch("core.extension.extensible.os.listdir")
|
||||||
|
@patch("core.extension.extensible.os.path.isdir")
|
||||||
|
@patch("core.extension.extensible.os.path.exists")
|
||||||
|
@patch("core.extension.extensible.Path.read_text")
|
||||||
|
@patch("core.extension.extensible.importlib.util.module_from_spec")
|
||||||
|
def test_scan_extensions_builtin_without_position_file(
|
||||||
|
self, mock_module_from_spec, mock_read_text, mock_exists, mock_isdir, mock_listdir, mock_dirname, mock_find_spec
|
||||||
|
):
|
||||||
|
package_spec = MagicMock()
|
||||||
|
package_spec.origin = "/path/to/pkg/__init__.py"
|
||||||
|
module_spec = MagicMock()
|
||||||
|
module_spec.loader = MagicMock()
|
||||||
|
|
||||||
|
mock_find_spec.side_effect = [package_spec, module_spec]
|
||||||
|
mock_dirname.return_value = "/path/to/pkg"
|
||||||
|
mock_listdir.side_effect = [["ext1"], ["ext1.py", "__builtin__"]]
|
||||||
|
mock_isdir.return_value = True
|
||||||
|
|
||||||
|
# builtin exists in listdir, but os.path.exists(builtin_file_path) returns False
|
||||||
|
mock_exists.return_value = False
|
||||||
|
|
||||||
|
mock_mod = types.ModuleType("ext1")
|
||||||
|
|
||||||
|
class MockExtension(Extensible):
|
||||||
|
pass
|
||||||
|
|
||||||
|
mock_mod.MockExtension = MockExtension
|
||||||
|
mock_module_from_spec.return_value = mock_mod
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"core.extension.extensible.sort_to_dict_by_position_map",
|
||||||
|
side_effect=lambda position_map, data, name_func: data,
|
||||||
|
):
|
||||||
|
results = Extensible.scan_extensions()
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].position == 0
|
||||||
|
|
@ -0,0 +1,90 @@
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.extension.extensible import ExtensionModule, ModuleExtension
|
||||||
|
from core.extension.extension import Extension
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtension:
|
||||||
|
def setup_method(self):
|
||||||
|
# Reset the private class attribute before each test
|
||||||
|
Extension._Extension__module_extensions = {}
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
# Mock scan_extensions for Moderation and ExternalDataTool
|
||||||
|
mock_mod_extensions = {"mod1": ModuleExtension(name="mod1")}
|
||||||
|
mock_ext_extensions = {"ext1": ModuleExtension(name="ext1")}
|
||||||
|
|
||||||
|
extension = Extension()
|
||||||
|
|
||||||
|
# We need to mock scan_extensions on the classes defined in Extension.module_classes
|
||||||
|
with (
|
||||||
|
patch("core.extension.extension.Moderation.scan_extensions", return_value=mock_mod_extensions),
|
||||||
|
patch("core.extension.extension.ExternalDataTool.scan_extensions", return_value=mock_ext_extensions),
|
||||||
|
):
|
||||||
|
extension.init()
|
||||||
|
|
||||||
|
# Check if internal state is updated
|
||||||
|
internal_state = Extension._Extension__module_extensions
|
||||||
|
assert internal_state[ExtensionModule.MODERATION.value] == mock_mod_extensions
|
||||||
|
assert internal_state[ExtensionModule.EXTERNAL_DATA_TOOL.value] == mock_ext_extensions
|
||||||
|
|
||||||
|
def test_module_extensions_success(self):
|
||||||
|
# Setup data
|
||||||
|
mock_extensions = {"name1": ModuleExtension(name="name1"), "name2": ModuleExtension(name="name2")}
|
||||||
|
Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: mock_extensions}
|
||||||
|
|
||||||
|
extension = Extension()
|
||||||
|
result = extension.module_extensions(ExtensionModule.MODERATION.value)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
assert any(e.name == "name1" for e in result)
|
||||||
|
assert any(e.name == "name2" for e in result)
|
||||||
|
|
||||||
|
def test_module_extensions_not_found(self):
|
||||||
|
extension = Extension()
|
||||||
|
with pytest.raises(ValueError, match="Extension Module unknown not found"):
|
||||||
|
extension.module_extensions("unknown")
|
||||||
|
|
||||||
|
def test_module_extension_success(self):
|
||||||
|
mock_ext = ModuleExtension(name="test_ext")
|
||||||
|
Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}}
|
||||||
|
|
||||||
|
extension = Extension()
|
||||||
|
result = extension.module_extension(ExtensionModule.MODERATION, "test_ext")
|
||||||
|
assert result == mock_ext
|
||||||
|
|
||||||
|
def test_module_extension_module_not_found(self):
|
||||||
|
extension = Extension()
|
||||||
|
# ExtensionModule.MODERATION is "moderation"
|
||||||
|
with pytest.raises(ValueError, match="Extension Module moderation not found"):
|
||||||
|
extension.module_extension(ExtensionModule.MODERATION, "any")
|
||||||
|
|
||||||
|
def test_module_extension_extension_not_found(self):
|
||||||
|
# We need a non-empty dict because 'if not module_extensions' in extension.py
|
||||||
|
# returns True for an empty dict, which raises the module not found error instead.
|
||||||
|
Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"other": MagicMock()}}
|
||||||
|
|
||||||
|
extension = Extension()
|
||||||
|
with pytest.raises(ValueError, match="Extension unknown not found"):
|
||||||
|
extension.module_extension(ExtensionModule.MODERATION, "unknown")
|
||||||
|
|
||||||
|
def test_extension_class_success(self):
|
||||||
|
class MockClass:
|
||||||
|
pass
|
||||||
|
|
||||||
|
mock_ext = ModuleExtension(name="test_ext", extension_class=MockClass)
|
||||||
|
Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}}
|
||||||
|
|
||||||
|
extension = Extension()
|
||||||
|
result = extension.extension_class(ExtensionModule.MODERATION, "test_ext")
|
||||||
|
assert result == MockClass
|
||||||
|
|
||||||
|
def test_extension_class_none(self):
|
||||||
|
mock_ext = ModuleExtension(name="test_ext", extension_class=None)
|
||||||
|
Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}}
|
||||||
|
|
||||||
|
extension = Extension()
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
extension.extension_class(ExtensionModule.MODERATION, "test_ext")
|
||||||
|
|
@ -0,0 +1,145 @@
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.external_data_tool.api.api import ApiExternalDataTool
|
||||||
|
from models.api_based_extension import APIBasedExtensionPoint
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_external_data_tool_name():
|
||||||
|
assert ApiExternalDataTool.name == "api"
|
||||||
|
|
||||||
|
|
||||||
|
@patch("core.external_data_tool.api.api.db")
|
||||||
|
def test_validate_config_success(mock_db):
|
||||||
|
mock_extension = MagicMock()
|
||||||
|
mock_extension.id = "ext_id"
|
||||||
|
mock_extension.tenant_id = "tenant_id"
|
||||||
|
mock_db.session.scalar.return_value = mock_extension
|
||||||
|
|
||||||
|
# Should not raise exception
|
||||||
|
ApiExternalDataTool.validate_config("tenant_id", {"api_based_extension_id": "ext_id"})
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_config_missing_id():
|
||||||
|
with pytest.raises(ValueError, match="api_based_extension_id is required"):
|
||||||
|
ApiExternalDataTool.validate_config("tenant_id", {})
|
||||||
|
|
||||||
|
|
||||||
|
@patch("core.external_data_tool.api.api.db")
|
||||||
|
def test_validate_config_invalid_id(mock_db):
|
||||||
|
mock_db.session.scalar.return_value = None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="api_based_extension_id is invalid"):
|
||||||
|
ApiExternalDataTool.validate_config("tenant_id", {"api_based_extension_id": "ext_id"})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def api_tool():
|
||||||
|
# Use standard kwargs as it inherits from ExternalDataTool which is typically a Pydantic BaseModel
|
||||||
|
return ApiExternalDataTool(
|
||||||
|
tenant_id="tenant_id", app_id="app_id", variable="var1", config={"api_based_extension_id": "ext_id"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("core.external_data_tool.api.api.db")
|
||||||
|
@patch("core.external_data_tool.api.api.encrypter")
|
||||||
|
@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor")
|
||||||
|
def test_query_success(mock_requestor_class, mock_encrypter, mock_db, api_tool):
|
||||||
|
mock_extension = MagicMock()
|
||||||
|
mock_extension.id = "ext_id"
|
||||||
|
mock_extension.tenant_id = "tenant_id"
|
||||||
|
mock_extension.api_endpoint = "http://api"
|
||||||
|
mock_extension.api_key = "encrypted_key"
|
||||||
|
mock_db.session.scalar.return_value = mock_extension
|
||||||
|
mock_encrypter.decrypt_token.return_value = "decrypted_key"
|
||||||
|
|
||||||
|
mock_requestor = mock_requestor_class.return_value
|
||||||
|
mock_requestor.request.return_value = {"result": "success_result"}
|
||||||
|
|
||||||
|
res = api_tool.query({"input1": "value1"}, "query_str")
|
||||||
|
|
||||||
|
assert res == "success_result"
|
||||||
|
|
||||||
|
mock_requestor_class.assert_called_once_with(api_endpoint="http://api", api_key="decrypted_key")
|
||||||
|
mock_requestor.request.assert_called_once_with(
|
||||||
|
point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY,
|
||||||
|
params={"app_id": "app_id", "tool_variable": "var1", "inputs": {"input1": "value1"}, "query": "query_str"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_missing_config():
|
||||||
|
api_tool = ApiExternalDataTool(tenant_id="tenant_id", app_id="app_id", variable="var1")
|
||||||
|
api_tool.config = None # Force None
|
||||||
|
with pytest.raises(ValueError, match="config is required"):
|
||||||
|
api_tool.query({}, "")
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_missing_extension_id():
|
||||||
|
api_tool = ApiExternalDataTool(tenant_id="tenant_id", app_id="app_id", variable="var1", config={"dummy": "value"})
|
||||||
|
with pytest.raises(AssertionError, match="api_based_extension_id is required"):
|
||||||
|
api_tool.query({}, "")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("core.external_data_tool.api.api.db")
|
||||||
|
def test_query_invalid_extension(mock_db, api_tool):
|
||||||
|
mock_db.session.scalar.return_value = None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=".*error: api_based_extension_id is invalid"):
|
||||||
|
api_tool.query({}, "")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("core.external_data_tool.api.api.db")
|
||||||
|
@patch("core.external_data_tool.api.api.encrypter")
|
||||||
|
@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor")
|
||||||
|
def test_query_requestor_init_error(mock_requestor_class, mock_encrypter, mock_db, api_tool):
|
||||||
|
mock_extension = MagicMock()
|
||||||
|
mock_extension.id = "ext_id"
|
||||||
|
mock_extension.tenant_id = "tenant_id"
|
||||||
|
mock_extension.api_endpoint = "http://api"
|
||||||
|
mock_extension.api_key = "encrypted_key"
|
||||||
|
mock_db.session.scalar.return_value = mock_extension
|
||||||
|
mock_encrypter.decrypt_token.return_value = "decrypted_key"
|
||||||
|
|
||||||
|
mock_requestor_class.side_effect = Exception("init error")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=".*error: init error"):
|
||||||
|
api_tool.query({}, "")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("core.external_data_tool.api.api.db")
|
||||||
|
@patch("core.external_data_tool.api.api.encrypter")
|
||||||
|
@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor")
|
||||||
|
def test_query_no_result_in_response(mock_requestor_class, mock_encrypter, mock_db, api_tool):
|
||||||
|
mock_extension = MagicMock()
|
||||||
|
mock_extension.id = "ext_id"
|
||||||
|
mock_extension.tenant_id = "tenant_id"
|
||||||
|
mock_extension.api_endpoint = "http://api"
|
||||||
|
mock_extension.api_key = "encrypted_key"
|
||||||
|
mock_db.session.scalar.return_value = mock_extension
|
||||||
|
mock_encrypter.decrypt_token.return_value = "decrypted_key"
|
||||||
|
|
||||||
|
mock_requestor = mock_requestor_class.return_value
|
||||||
|
mock_requestor.request.return_value = {"other": "value"}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=".*error: result not found in response"):
|
||||||
|
api_tool.query({}, "")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("core.external_data_tool.api.api.db")
|
||||||
|
@patch("core.external_data_tool.api.api.encrypter")
|
||||||
|
@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor")
|
||||||
|
def test_query_result_not_string(mock_requestor_class, mock_encrypter, mock_db, api_tool):
|
||||||
|
mock_extension = MagicMock()
|
||||||
|
mock_extension.id = "ext_id"
|
||||||
|
mock_extension.tenant_id = "tenant_id"
|
||||||
|
mock_extension.api_endpoint = "http://api"
|
||||||
|
mock_extension.api_key = "encrypted_key"
|
||||||
|
mock_db.session.scalar.return_value = mock_extension
|
||||||
|
mock_encrypter.decrypt_token.return_value = "decrypted_key"
|
||||||
|
|
||||||
|
mock_requestor = mock_requestor_class.return_value
|
||||||
|
mock_requestor.request.return_value = {"result": 123} # Not a string
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=".*error: result is not string"):
|
||||||
|
api_tool.query({}, "")
|
||||||
|
|
@ -0,0 +1,66 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.extension.extensible import ExtensionModule
|
||||||
|
from core.external_data_tool.base import ExternalDataTool
|
||||||
|
|
||||||
|
|
||||||
|
class TestExternalDataTool:
|
||||||
|
def test_module_attribute(self):
|
||||||
|
assert ExternalDataTool.module == ExtensionModule.EXTERNAL_DATA_TOOL
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
# Create a concrete subclass to test init
|
||||||
|
class ConcreteTool(ExternalDataTool):
|
||||||
|
@classmethod
|
||||||
|
def validate_config(cls, tenant_id: str, config: dict):
|
||||||
|
return super().validate_config(tenant_id, config)
|
||||||
|
|
||||||
|
def query(self, inputs: dict, query: str | None = None) -> str:
|
||||||
|
return super().query(inputs, query)
|
||||||
|
|
||||||
|
tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1", config={"key": "value"})
|
||||||
|
assert tool.tenant_id == "tenant_1"
|
||||||
|
assert tool.app_id == "app_1"
|
||||||
|
assert tool.variable == "var_1"
|
||||||
|
assert tool.config == {"key": "value"}
|
||||||
|
|
||||||
|
def test_init_without_config(self):
|
||||||
|
# Create a concrete subclass to test init
|
||||||
|
class ConcreteTool(ExternalDataTool):
|
||||||
|
@classmethod
|
||||||
|
def validate_config(cls, tenant_id: str, config: dict):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def query(self, inputs: dict, query: str | None = None) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1")
|
||||||
|
assert tool.tenant_id == "tenant_1"
|
||||||
|
assert tool.app_id == "app_1"
|
||||||
|
assert tool.variable == "var_1"
|
||||||
|
assert tool.config is None
|
||||||
|
|
||||||
|
def test_validate_config_raises_not_implemented(self):
|
||||||
|
class ConcreteTool(ExternalDataTool):
|
||||||
|
@classmethod
|
||||||
|
def validate_config(cls, tenant_id: str, config: dict):
|
||||||
|
return super().validate_config(tenant_id, config)
|
||||||
|
|
||||||
|
def query(self, inputs: dict, query: str | None = None) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
ConcreteTool.validate_config("tenant_1", {})
|
||||||
|
|
||||||
|
def test_query_raises_not_implemented(self):
|
||||||
|
class ConcreteTool(ExternalDataTool):
|
||||||
|
@classmethod
|
||||||
|
def validate_config(cls, tenant_id: str, config: dict):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def query(self, inputs: dict, query: str | None = None) -> str:
|
||||||
|
return super().query(inputs, query)
|
||||||
|
|
||||||
|
tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1")
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
tool.query({})
|
||||||
|
|
@ -0,0 +1,115 @@
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from flask import Flask
|
||||||
|
|
||||||
|
from core.app.app_config.entities import ExternalDataVariableEntity
|
||||||
|
from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
||||||
|
|
||||||
|
|
||||||
|
class TestExternalDataFetch:
|
||||||
|
@pytest.fixture
|
||||||
|
def app(self):
|
||||||
|
app = Flask(__name__)
|
||||||
|
return app
|
||||||
|
|
||||||
|
def test_fetch_success(self, app):
|
||||||
|
with app.app_context():
|
||||||
|
fetcher = ExternalDataFetch()
|
||||||
|
|
||||||
|
# Setup mocks
|
||||||
|
tool1 = ExternalDataVariableEntity(variable="var1", type="type1", config={"c1": "v1"})
|
||||||
|
tool2 = ExternalDataVariableEntity(variable="var2", type="type2", config={"c2": "v2"})
|
||||||
|
|
||||||
|
external_data_tools = [tool1, tool2]
|
||||||
|
inputs = {"input_key": "input_value"}
|
||||||
|
query = "test query"
|
||||||
|
|
||||||
|
with patch("core.external_data_tool.external_data_fetch.ExternalDataToolFactory") as MockFactory:
|
||||||
|
# Create distinct mock instances for each tool to ensure deterministic results
|
||||||
|
# This approach is robust regardless of thread scheduling order
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
def factory_side_effect(*args, **kwargs):
|
||||||
|
variable = kwargs.get("variable")
|
||||||
|
mock_instance = MagicMock()
|
||||||
|
if variable == "var1":
|
||||||
|
mock_instance.query.return_value = "result1"
|
||||||
|
elif variable == "var2":
|
||||||
|
mock_instance.query.return_value = "result2"
|
||||||
|
return mock_instance
|
||||||
|
|
||||||
|
MockFactory.side_effect = factory_side_effect
|
||||||
|
|
||||||
|
result_inputs = fetcher.fetch(
|
||||||
|
tenant_id="tenant1",
|
||||||
|
app_id="app1",
|
||||||
|
external_data_tools=external_data_tools,
|
||||||
|
inputs=inputs,
|
||||||
|
query=query,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Each tool gets its deterministic result regardless of thread completion order
|
||||||
|
assert result_inputs["var1"] == "result1"
|
||||||
|
assert result_inputs["var2"] == "result2"
|
||||||
|
assert result_inputs["input_key"] == "input_value"
|
||||||
|
assert len(result_inputs) == 3
|
||||||
|
|
||||||
|
# Verify factory calls
|
||||||
|
assert MockFactory.call_count == 2
|
||||||
|
MockFactory.assert_any_call(
|
||||||
|
name="type1", tenant_id="tenant1", app_id="app1", variable="var1", config={"c1": "v1"}
|
||||||
|
)
|
||||||
|
MockFactory.assert_any_call(
|
||||||
|
name="type2", tenant_id="tenant1", app_id="app1", variable="var2", config={"c2": "v2"}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_fetch_no_tools(self):
|
||||||
|
# We don't necessarily need app_context if there are no tools,
|
||||||
|
# but fetch calls current_app._get_current_object() only inside the loop.
|
||||||
|
# Wait, let's look at the code.
|
||||||
|
# for tool in external_data_tools:
|
||||||
|
# executor.submit(..., current_app._get_current_object(), ...)
|
||||||
|
# So if external_data_tools is empty, it shouldn't access current_app.
|
||||||
|
fetcher = ExternalDataFetch()
|
||||||
|
inputs = {"input_key": "input_value"}
|
||||||
|
result_inputs = fetcher.fetch(
|
||||||
|
tenant_id="tenant1", app_id="app1", external_data_tools=[], inputs=inputs, query="test query"
|
||||||
|
)
|
||||||
|
assert result_inputs == inputs
|
||||||
|
assert result_inputs is not inputs # Should be a copy
|
||||||
|
|
||||||
|
def test_fetch_with_none_variable(self, app):
|
||||||
|
with app.app_context():
|
||||||
|
fetcher = ExternalDataFetch()
|
||||||
|
tool = ExternalDataVariableEntity(variable="var1", type="type1", config={})
|
||||||
|
|
||||||
|
# Patch _query_external_data_tool to return None variable
|
||||||
|
with patch.object(ExternalDataFetch, "_query_external_data_tool") as mock_query:
|
||||||
|
mock_query.return_value = (None, "some_result")
|
||||||
|
|
||||||
|
result_inputs = fetcher.fetch(
|
||||||
|
tenant_id="t1", app_id="a1", external_data_tools=[tool], inputs={"in": "val"}, query="q"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "var1" not in result_inputs
|
||||||
|
assert result_inputs == {"in": "val"}
|
||||||
|
|
||||||
|
def test_query_external_data_tool(self, app):
|
||||||
|
fetcher = ExternalDataFetch()
|
||||||
|
tool = ExternalDataVariableEntity(variable="var1", type="type1", config={"k": "v"})
|
||||||
|
|
||||||
|
with patch("core.external_data_tool.external_data_fetch.ExternalDataToolFactory") as MockFactory:
|
||||||
|
mock_factory_instance = MockFactory.return_value
|
||||||
|
mock_factory_instance.query.return_value = "query_result"
|
||||||
|
|
||||||
|
var, res = fetcher._query_external_data_tool(
|
||||||
|
flask_app=app, tenant_id="t1", app_id="a1", external_data_tool=tool, inputs={"i": "v"}, query="q"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert var == "var1"
|
||||||
|
assert res == "query_result"
|
||||||
|
MockFactory.assert_called_once_with(
|
||||||
|
name="type1", tenant_id="t1", app_id="a1", variable="var1", config={"k": "v"}
|
||||||
|
)
|
||||||
|
mock_factory_instance.query.assert_called_once_with(inputs={"i": "v"}, query="q")
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from core.extension.extensible import ExtensionModule
|
||||||
|
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||||
|
|
||||||
|
|
||||||
|
def test_external_data_tool_factory_init():
|
||||||
|
with patch("core.external_data_tool.factory.code_based_extension") as mock_code_based_extension:
|
||||||
|
mock_extension_class = MagicMock()
|
||||||
|
mock_code_based_extension.extension_class.return_value = mock_extension_class
|
||||||
|
|
||||||
|
name = "test_tool"
|
||||||
|
tenant_id = "tenant_123"
|
||||||
|
app_id = "app_456"
|
||||||
|
variable = "var_v"
|
||||||
|
config = {"key": "value"}
|
||||||
|
|
||||||
|
factory = ExternalDataToolFactory(name, tenant_id, app_id, variable, config)
|
||||||
|
|
||||||
|
mock_code_based_extension.extension_class.assert_called_once_with(ExtensionModule.EXTERNAL_DATA_TOOL, name)
|
||||||
|
mock_extension_class.assert_called_once_with(
|
||||||
|
tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_external_data_tool_factory_validate_config():
|
||||||
|
with patch("core.external_data_tool.factory.code_based_extension") as mock_code_based_extension:
|
||||||
|
mock_extension_class = MagicMock()
|
||||||
|
mock_code_based_extension.extension_class.return_value = mock_extension_class
|
||||||
|
|
||||||
|
name = "test_tool"
|
||||||
|
tenant_id = "tenant_123"
|
||||||
|
config = {"key": "value"}
|
||||||
|
|
||||||
|
ExternalDataToolFactory.validate_config(name, tenant_id, config)
|
||||||
|
|
||||||
|
mock_code_based_extension.extension_class.assert_called_once_with(ExtensionModule.EXTERNAL_DATA_TOOL, name)
|
||||||
|
mock_extension_class.validate_config.assert_called_once_with(tenant_id, config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_external_data_tool_factory_query():
|
||||||
|
with patch("core.external_data_tool.factory.code_based_extension") as mock_code_based_extension:
|
||||||
|
mock_extension_class = MagicMock()
|
||||||
|
mock_extension_instance = MagicMock()
|
||||||
|
mock_extension_class.return_value = mock_extension_instance
|
||||||
|
mock_code_based_extension.extension_class.return_value = mock_extension_class
|
||||||
|
|
||||||
|
mock_extension_instance.query.return_value = "query_result"
|
||||||
|
|
||||||
|
factory = ExternalDataToolFactory("name", "tenant", "app", "var", {})
|
||||||
|
|
||||||
|
inputs = {"input_key": "input_value"}
|
||||||
|
query = "search_query"
|
||||||
|
|
||||||
|
result = factory.query(inputs, query)
|
||||||
|
|
||||||
|
assert result == "query_result"
|
||||||
|
mock_extension_instance.query.assert_called_once_with(inputs, query)
|
||||||
|
|
@ -0,0 +1,103 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.llm_generator.output_parser.errors import OutputParserError
|
||||||
|
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
|
||||||
|
from core.llm_generator.prompts import (
|
||||||
|
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
|
||||||
|
RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
|
||||||
|
RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRuleConfigGeneratorOutputParser:
|
||||||
|
def test_get_format_instructions(self):
|
||||||
|
parser = RuleConfigGeneratorOutputParser()
|
||||||
|
instructions = parser.get_format_instructions()
|
||||||
|
assert instructions == (
|
||||||
|
RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
|
||||||
|
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
|
||||||
|
RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_parse_success(self):
|
||||||
|
parser = RuleConfigGeneratorOutputParser()
|
||||||
|
text = """
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"prompt": "This is a prompt",
|
||||||
|
"variables": ["var1", "var2"],
|
||||||
|
"opening_statement": "Hello!"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
result = parser.parse(text)
|
||||||
|
assert result["prompt"] == "This is a prompt"
|
||||||
|
assert result["variables"] == ["var1", "var2"]
|
||||||
|
assert result["opening_statement"] == "Hello!"
|
||||||
|
|
||||||
|
def test_parse_invalid_json(self):
|
||||||
|
parser = RuleConfigGeneratorOutputParser()
|
||||||
|
text = "invalid json"
|
||||||
|
with pytest.raises(OutputParserError) as excinfo:
|
||||||
|
parser.parse(text)
|
||||||
|
assert "Parsing text" in str(excinfo.value)
|
||||||
|
assert "could not find json block in the output" in str(excinfo.value)
|
||||||
|
|
||||||
|
def test_parse_missing_keys(self):
|
||||||
|
parser = RuleConfigGeneratorOutputParser()
|
||||||
|
text = """
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"prompt": "This is a prompt",
|
||||||
|
"variables": ["var1", "var2"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
with pytest.raises(OutputParserError) as excinfo:
|
||||||
|
parser.parse(text)
|
||||||
|
assert "expected key `opening_statement` to be present" in str(excinfo.value)
|
||||||
|
|
||||||
|
def test_parse_wrong_type_prompt(self):
|
||||||
|
parser = RuleConfigGeneratorOutputParser()
|
||||||
|
text = """
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"prompt": 123,
|
||||||
|
"variables": ["var1", "var2"],
|
||||||
|
"opening_statement": "Hello!"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
with pytest.raises(OutputParserError) as excinfo:
|
||||||
|
parser.parse(text)
|
||||||
|
assert "Expected 'prompt' to be a string" in str(excinfo.value)
|
||||||
|
|
||||||
|
def test_parse_wrong_type_variables(self):
|
||||||
|
parser = RuleConfigGeneratorOutputParser()
|
||||||
|
text = """
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"prompt": "This is a prompt",
|
||||||
|
"variables": "not a list",
|
||||||
|
"opening_statement": "Hello!"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
with pytest.raises(OutputParserError) as excinfo:
|
||||||
|
parser.parse(text)
|
||||||
|
assert "Expected 'variables' to be a list" in str(excinfo.value)
|
||||||
|
|
||||||
|
def test_parse_wrong_type_opening_statement(self):
|
||||||
|
parser = RuleConfigGeneratorOutputParser()
|
||||||
|
text = """
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"prompt": "This is a prompt",
|
||||||
|
"variables": ["var1", "var2"],
|
||||||
|
"opening_statement": 123
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
with pytest.raises(OutputParserError) as excinfo:
|
||||||
|
parser.parse(text)
|
||||||
|
assert "Expected 'opening_statement' to be a str" in str(excinfo.value)
|
||||||
|
|
@ -0,0 +1,402 @@
|
||||||
|
import json
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.llm_generator.output_parser.errors import OutputParserError
|
||||||
|
from core.llm_generator.output_parser.structured_output import (
|
||||||
|
ResponseFormat,
|
||||||
|
_handle_native_json_schema,
|
||||||
|
_handle_prompt_based_schema,
|
||||||
|
_parse_structured_output,
|
||||||
|
_prepare_schema_for_model,
|
||||||
|
_set_response_format,
|
||||||
|
convert_boolean_to_string,
|
||||||
|
invoke_llm_with_structured_output,
|
||||||
|
remove_additional_properties,
|
||||||
|
)
|
||||||
|
from core.model_manager import ModelInstance
|
||||||
|
from dify_graph.model_runtime.entities.llm_entities import (
|
||||||
|
LLMResult,
|
||||||
|
LLMResultChunk,
|
||||||
|
LLMResultChunkDelta,
|
||||||
|
LLMResultWithStructuredOutput,
|
||||||
|
LLMUsage,
|
||||||
|
)
|
||||||
|
from dify_graph.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
SystemPromptMessage,
|
||||||
|
TextPromptMessageContent,
|
||||||
|
UserPromptMessage,
|
||||||
|
)
|
||||||
|
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType
|
||||||
|
|
||||||
|
|
||||||
|
class TestStructuredOutput:
|
||||||
|
def test_remove_additional_properties(self):
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
|
||||||
|
"additionalProperties": False,
|
||||||
|
"nested": {"type": "object", "additionalProperties": True},
|
||||||
|
"items": [{"type": "object", "additionalProperties": False}],
|
||||||
|
}
|
||||||
|
remove_additional_properties(schema)
|
||||||
|
assert "additionalProperties" not in schema
|
||||||
|
assert "additionalProperties" not in schema["nested"]
|
||||||
|
assert "additionalProperties" not in schema["items"][0]
|
||||||
|
|
||||||
|
# Test with non-dict input
|
||||||
|
remove_additional_properties(None) # Should not raise
|
||||||
|
remove_additional_properties([]) # Should not raise
|
||||||
|
|
||||||
|
def test_convert_boolean_to_string(self):
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"is_active": {"type": "boolean"},
|
||||||
|
"tags": {"type": "array", "items": {"type": "boolean"}},
|
||||||
|
"list_schema": [{"type": "boolean"}],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
convert_boolean_to_string(schema)
|
||||||
|
assert schema["properties"]["is_active"]["type"] == "string"
|
||||||
|
assert schema["properties"]["tags"]["items"]["type"] == "string"
|
||||||
|
assert schema["properties"]["list_schema"][0]["type"] == "string"
|
||||||
|
|
||||||
|
# Test with non-dict input
|
||||||
|
convert_boolean_to_string(None) # Should not raise
|
||||||
|
convert_boolean_to_string([]) # Should not raise
|
||||||
|
|
||||||
|
def test_parse_structured_output_valid(self):
|
||||||
|
text = '{"key": "value"}'
|
||||||
|
assert _parse_structured_output(text) == {"key": "value"}
|
||||||
|
|
||||||
|
def test_parse_structured_output_non_dict_valid_json(self):
|
||||||
|
# Even if it's valid JSON, if it's not a dict, it should try repair or fail
|
||||||
|
text = '["a", "b"]'
|
||||||
|
with patch("json_repair.loads") as mock_repair:
|
||||||
|
mock_repair.return_value = {"key": "value"}
|
||||||
|
assert _parse_structured_output(text) == {"key": "value"}
|
||||||
|
|
||||||
|
def test_parse_structured_output_not_dict_fail_via_validate(self):
|
||||||
|
# Force TypeAdapter to return a non-dict to trigger line 292
|
||||||
|
with patch("pydantic.TypeAdapter.validate_json") as mock_validate:
|
||||||
|
mock_validate.return_value = ["a list"]
|
||||||
|
with pytest.raises(OutputParserError) as excinfo:
|
||||||
|
_parse_structured_output('["a list"]')
|
||||||
|
assert "Failed to parse structured output" in str(excinfo.value)
|
||||||
|
|
||||||
|
def test_parse_structured_output_repair_success(self):
|
||||||
|
text = "{'key': 'value'}" # Invalid JSON (single quotes)
|
||||||
|
# json_repair should handle this
|
||||||
|
assert _parse_structured_output(text) == {"key": "value"}
|
||||||
|
|
||||||
|
def test_parse_structured_output_repair_list(self):
|
||||||
|
# Deepseek-r1 case: result is a list containing a dict
|
||||||
|
text = '[{"key": "value"}]'
|
||||||
|
assert _parse_structured_output(text) == {"key": "value"}
|
||||||
|
|
||||||
|
def test_parse_structured_output_repair_list_no_dict(self):
|
||||||
|
# Deepseek-r1 case: result is a list with NO dict
|
||||||
|
text = "[1, 2, 3]"
|
||||||
|
assert _parse_structured_output(text) == {}
|
||||||
|
|
||||||
|
def test_parse_structured_output_repair_fail(self):
|
||||||
|
text = "not a json at all"
|
||||||
|
with patch("json_repair.loads") as mock_repair:
|
||||||
|
mock_repair.return_value = "still not a dict or list"
|
||||||
|
with pytest.raises(OutputParserError):
|
||||||
|
_parse_structured_output(text)
|
||||||
|
|
||||||
|
def test_set_response_format(self):
|
||||||
|
# Test JSON
|
||||||
|
params = {}
|
||||||
|
rules = [
|
||||||
|
ParameterRule(
|
||||||
|
name="response_format",
|
||||||
|
label={"en_US": ""},
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
help={"en_US": ""},
|
||||||
|
options=[ResponseFormat.JSON],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
_set_response_format(params, rules)
|
||||||
|
assert params["response_format"] == ResponseFormat.JSON
|
||||||
|
|
||||||
|
# Test JSON_OBJECT
|
||||||
|
params = {}
|
||||||
|
rules = [
|
||||||
|
ParameterRule(
|
||||||
|
name="response_format",
|
||||||
|
label={"en_US": ""},
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
help={"en_US": ""},
|
||||||
|
options=[ResponseFormat.JSON_OBJECT],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
_set_response_format(params, rules)
|
||||||
|
assert params["response_format"] == ResponseFormat.JSON_OBJECT
|
||||||
|
|
||||||
|
def test_handle_native_json_schema(self):
|
||||||
|
provider = "openai"
|
||||||
|
model_schema = MagicMock(spec=AIModelEntity)
|
||||||
|
model_schema.model = "gpt-4"
|
||||||
|
structured_output_schema = {"type": "object"}
|
||||||
|
model_parameters = {}
|
||||||
|
rules = [
|
||||||
|
ParameterRule(
|
||||||
|
name="response_format",
|
||||||
|
label={"en_US": ""},
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
help={"en_US": ""},
|
||||||
|
options=[ResponseFormat.JSON_SCHEMA],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
updated_params = _handle_native_json_schema(
|
||||||
|
provider, model_schema, structured_output_schema, model_parameters, rules
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "json_schema" in updated_params
|
||||||
|
assert json.loads(updated_params["json_schema"]) == {"schema": {"type": "object"}, "name": "llm_response"}
|
||||||
|
assert updated_params["response_format"] == ResponseFormat.JSON_SCHEMA
|
||||||
|
|
||||||
|
def test_handle_native_json_schema_no_format_rule(self):
|
||||||
|
provider = "openai"
|
||||||
|
model_schema = MagicMock(spec=AIModelEntity)
|
||||||
|
model_schema.model = "gpt-4"
|
||||||
|
structured_output_schema = {"type": "object"}
|
||||||
|
model_parameters = {}
|
||||||
|
rules = []
|
||||||
|
|
||||||
|
updated_params = _handle_native_json_schema(
|
||||||
|
provider, model_schema, structured_output_schema, model_parameters, rules
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "json_schema" in updated_params
|
||||||
|
assert "response_format" not in updated_params
|
||||||
|
|
||||||
|
def test_handle_prompt_based_schema_with_system_prompt(self):
|
||||||
|
prompt_messages = [
|
||||||
|
SystemPromptMessage(content="Existing system prompt"),
|
||||||
|
UserPromptMessage(content="User question"),
|
||||||
|
]
|
||||||
|
schema = {"type": "object"}
|
||||||
|
|
||||||
|
result = _handle_prompt_based_schema(prompt_messages, schema)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
assert isinstance(result[0], SystemPromptMessage)
|
||||||
|
assert "Existing system prompt" in result[0].content
|
||||||
|
assert json.dumps(schema) in result[0].content
|
||||||
|
assert isinstance(result[1], UserPromptMessage)
|
||||||
|
|
||||||
|
def test_handle_prompt_based_schema_without_system_prompt(self):
|
||||||
|
prompt_messages = [UserPromptMessage(content="User question")]
|
||||||
|
schema = {"type": "object"}
|
||||||
|
|
||||||
|
result = _handle_prompt_based_schema(prompt_messages, schema)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
assert isinstance(result[0], SystemPromptMessage)
|
||||||
|
assert json.dumps(schema) in result[0].content
|
||||||
|
assert isinstance(result[1], UserPromptMessage)
|
||||||
|
|
||||||
|
def test_prepare_schema_for_model_gemini(self):
|
||||||
|
model_schema = MagicMock(spec=AIModelEntity)
|
||||||
|
model_schema.model = "gemini-1.5-pro"
|
||||||
|
schema = {"type": "object", "additionalProperties": False}
|
||||||
|
|
||||||
|
result = _prepare_schema_for_model("google", model_schema, schema)
|
||||||
|
assert "additionalProperties" not in result
|
||||||
|
|
||||||
|
def test_prepare_schema_for_model_ollama(self):
|
||||||
|
model_schema = MagicMock(spec=AIModelEntity)
|
||||||
|
model_schema.model = "llama3"
|
||||||
|
schema = {"type": "object"}
|
||||||
|
|
||||||
|
result = _prepare_schema_for_model("ollama", model_schema, schema)
|
||||||
|
assert result == schema
|
||||||
|
|
||||||
|
def test_prepare_schema_for_model_default(self):
|
||||||
|
model_schema = MagicMock(spec=AIModelEntity)
|
||||||
|
model_schema.model = "gpt-4"
|
||||||
|
schema = {"type": "object"}
|
||||||
|
|
||||||
|
result = _prepare_schema_for_model("openai", model_schema, schema)
|
||||||
|
assert result == {"schema": schema, "name": "llm_response"}
|
||||||
|
|
||||||
|
def test_invoke_llm_with_structured_output_no_stream_native(self):
|
||||||
|
model_schema = MagicMock(spec=AIModelEntity)
|
||||||
|
model_schema.support_structure_output = True
|
||||||
|
model_schema.parameter_rules = [
|
||||||
|
ParameterRule(
|
||||||
|
name="response_format",
|
||||||
|
label={"en_US": ""},
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
help={"en_US": ""},
|
||||||
|
options=[ResponseFormat.JSON_SCHEMA],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
model_schema.model = "gpt-4o"
|
||||||
|
|
||||||
|
model_instance = MagicMock(spec=ModelInstance)
|
||||||
|
mock_result = MagicMock(spec=LLMResult)
|
||||||
|
mock_result.message = AssistantPromptMessage(content='{"result": "success"}')
|
||||||
|
mock_result.model = "gpt-4o"
|
||||||
|
mock_result.usage = LLMUsage.empty_usage()
|
||||||
|
mock_result.system_fingerprint = "fp_native"
|
||||||
|
mock_result.prompt_messages = [UserPromptMessage(content="hi")]
|
||||||
|
|
||||||
|
model_instance.invoke_llm.return_value = mock_result
|
||||||
|
|
||||||
|
result = invoke_llm_with_structured_output(
|
||||||
|
provider="openai",
|
||||||
|
model_schema=model_schema,
|
||||||
|
model_instance=model_instance,
|
||||||
|
prompt_messages=[UserPromptMessage(content="hi")],
|
||||||
|
json_schema={"type": "object"},
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, LLMResultWithStructuredOutput)
|
||||||
|
assert result.structured_output == {"result": "success"}
|
||||||
|
assert result.system_fingerprint == "fp_native"
|
||||||
|
|
||||||
|
def test_invoke_llm_with_structured_output_no_stream_prompt_based(self):
|
||||||
|
model_schema = MagicMock(spec=AIModelEntity)
|
||||||
|
model_schema.support_structure_output = False
|
||||||
|
model_schema.parameter_rules = [
|
||||||
|
ParameterRule(
|
||||||
|
name="response_format",
|
||||||
|
label={"en_US": ""},
|
||||||
|
type=ParameterType.STRING,
|
||||||
|
help={"en_US": ""},
|
||||||
|
options=[ResponseFormat.JSON],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
model_schema.model = "claude-3"
|
||||||
|
|
||||||
|
model_instance = MagicMock(spec=ModelInstance)
|
||||||
|
mock_result = MagicMock(spec=LLMResult)
|
||||||
|
mock_result.message = AssistantPromptMessage(content='{"result": "success"}')
|
||||||
|
mock_result.model = "claude-3"
|
||||||
|
mock_result.usage = LLMUsage.empty_usage()
|
||||||
|
mock_result.system_fingerprint = "fp_prompt"
|
||||||
|
mock_result.prompt_messages = []
|
||||||
|
|
||||||
|
model_instance.invoke_llm.return_value = mock_result
|
||||||
|
|
||||||
|
result = invoke_llm_with_structured_output(
|
||||||
|
provider="anthropic",
|
||||||
|
model_schema=model_schema,
|
||||||
|
model_instance=model_instance,
|
||||||
|
prompt_messages=[UserPromptMessage(content="hi")],
|
||||||
|
json_schema={"type": "object"},
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, LLMResultWithStructuredOutput)
|
||||||
|
assert result.structured_output == {"result": "success"}
|
||||||
|
assert result.system_fingerprint == "fp_prompt"
|
||||||
|
|
||||||
|
def test_invoke_llm_with_structured_output_no_string_error(self):
|
||||||
|
model_schema = MagicMock(spec=AIModelEntity)
|
||||||
|
model_schema.support_structure_output = False
|
||||||
|
model_schema.parameter_rules = []
|
||||||
|
|
||||||
|
model_instance = MagicMock(spec=ModelInstance)
|
||||||
|
mock_result = MagicMock(spec=LLMResult)
|
||||||
|
mock_result.message = AssistantPromptMessage(content=[TextPromptMessageContent(data="not a string")])
|
||||||
|
|
||||||
|
model_instance.invoke_llm.return_value = mock_result
|
||||||
|
|
||||||
|
with pytest.raises(OutputParserError) as excinfo:
|
||||||
|
invoke_llm_with_structured_output(
|
||||||
|
provider="anthropic",
|
||||||
|
model_schema=model_schema,
|
||||||
|
model_instance=model_instance,
|
||||||
|
prompt_messages=[],
|
||||||
|
json_schema={},
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
assert "Failed to parse structured output, LLM result is not a string" in str(excinfo.value)
|
||||||
|
|
||||||
|
def test_invoke_llm_with_structured_output_stream(self):
|
||||||
|
model_schema = MagicMock(spec=AIModelEntity)
|
||||||
|
model_schema.support_structure_output = False
|
||||||
|
model_schema.parameter_rules = []
|
||||||
|
model_schema.model = "gpt-4"
|
||||||
|
|
||||||
|
model_instance = MagicMock(spec=ModelInstance)
|
||||||
|
|
||||||
|
# Mock chunks
|
||||||
|
chunk1 = MagicMock(spec=LLMResultChunk)
|
||||||
|
chunk1.delta = LLMResultChunkDelta(
|
||||||
|
index=0, message=AssistantPromptMessage(content='{"key": '), usage=LLMUsage.empty_usage()
|
||||||
|
)
|
||||||
|
chunk1.prompt_messages = [UserPromptMessage(content="hi")]
|
||||||
|
chunk1.system_fingerprint = "fp1"
|
||||||
|
|
||||||
|
chunk2 = MagicMock(spec=LLMResultChunk)
|
||||||
|
chunk2.delta = LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content='"value"}'))
|
||||||
|
chunk2.prompt_messages = [UserPromptMessage(content="hi")]
|
||||||
|
chunk2.system_fingerprint = "fp1"
|
||||||
|
|
||||||
|
chunk3 = MagicMock(spec=LLMResultChunk)
|
||||||
|
chunk3.delta = LLMResultChunkDelta(
|
||||||
|
index=0,
|
||||||
|
message=AssistantPromptMessage(
|
||||||
|
content=[
|
||||||
|
TextPromptMessageContent(data=" "),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
chunk3.prompt_messages = [UserPromptMessage(content="hi")]
|
||||||
|
chunk3.system_fingerprint = "fp1"
|
||||||
|
|
||||||
|
event4 = MagicMock()
|
||||||
|
event4.delta = LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=""))
|
||||||
|
|
||||||
|
model_instance.invoke_llm.return_value = [chunk1, chunk2, chunk3, event4]
|
||||||
|
|
||||||
|
generator = invoke_llm_with_structured_output(
|
||||||
|
provider="openai",
|
||||||
|
model_schema=model_schema,
|
||||||
|
model_instance=model_instance,
|
||||||
|
prompt_messages=[UserPromptMessage(content="hi")],
|
||||||
|
json_schema={},
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = list(generator)
|
||||||
|
assert len(chunks) == 5
|
||||||
|
assert chunks[-1].structured_output == {"key": "value"}
|
||||||
|
assert chunks[-1].system_fingerprint == "fp1"
|
||||||
|
assert chunks[-1].prompt_messages == [UserPromptMessage(content="hi")]
|
||||||
|
|
||||||
|
def test_invoke_llm_with_structured_output_stream_no_id_events(self):
|
||||||
|
model_schema = MagicMock(spec=AIModelEntity)
|
||||||
|
model_schema.support_structure_output = False
|
||||||
|
model_schema.parameter_rules = []
|
||||||
|
model_schema.model = "gpt-4"
|
||||||
|
|
||||||
|
model_instance = MagicMock(spec=ModelInstance)
|
||||||
|
model_instance.invoke_llm.return_value = []
|
||||||
|
|
||||||
|
generator = invoke_llm_with_structured_output(
|
||||||
|
provider="openai",
|
||||||
|
model_schema=model_schema,
|
||||||
|
model_instance=model_instance,
|
||||||
|
prompt_messages=[],
|
||||||
|
json_schema={},
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(OutputParserError):
|
||||||
|
list(generator)
|
||||||
|
|
||||||
|
def test_parse_structured_output_empty_string(self):
|
||||||
|
with pytest.raises(OutputParserError):
|
||||||
|
_parse_structured_output("")
|
||||||
|
|
@ -0,0 +1,589 @@
|
||||||
|
import json
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.app.app_config.entities import ModelConfig
|
||||||
|
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
|
||||||
|
from core.llm_generator.llm_generator import LLMGenerator
|
||||||
|
from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult
|
||||||
|
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMGenerator:
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_model_instance(self):
|
||||||
|
with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager:
|
||||||
|
instance = MagicMock()
|
||||||
|
mock_manager.return_value.get_default_model_instance.return_value = instance
|
||||||
|
mock_manager.return_value.get_model_instance.return_value = instance
|
||||||
|
yield instance
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def model_config_entity(self):
|
||||||
|
return ModelConfig(provider="openai", name="gpt-4", mode=LLMMode.CHAT, completion_params={"temperature": 0.7})
|
||||||
|
|
||||||
|
def test_generate_conversation_name_success(self, mock_model_instance):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = json.dumps({"Your Output": "Test Conversation Name"})
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
with patch("core.llm_generator.llm_generator.TraceQueueManager") as mock_trace:
|
||||||
|
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
|
||||||
|
assert name == "Test Conversation Name"
|
||||||
|
mock_trace.assert_called_once()
|
||||||
|
|
||||||
|
def test_generate_conversation_name_truncated(self, mock_model_instance):
|
||||||
|
long_query = "a" * 2100
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = json.dumps({"Your Output": "Short Name"})
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
with patch("core.llm_generator.llm_generator.TraceQueueManager"):
|
||||||
|
name = LLMGenerator.generate_conversation_name("tenant_id", long_query)
|
||||||
|
assert name == "Short Name"
|
||||||
|
|
||||||
|
def test_generate_conversation_name_empty_answer(self, mock_model_instance):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = ""
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
|
||||||
|
assert name == ""
|
||||||
|
|
||||||
|
def test_generate_conversation_name_json_repair(self, mock_model_instance):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
# Invalid JSON that json_repair can fix
|
||||||
|
mock_response.message.get_text_content.return_value = "{'Your Output': 'Repaired Name'}"
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
with patch("core.llm_generator.llm_generator.TraceQueueManager"):
|
||||||
|
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
|
||||||
|
assert name == "Repaired Name"
|
||||||
|
|
||||||
|
def test_generate_conversation_name_not_dict_result(self, mock_model_instance):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = '["not a dict"]'
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
with patch("core.llm_generator.llm_generator.TraceQueueManager"):
|
||||||
|
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
|
||||||
|
assert name == "test query"
|
||||||
|
|
||||||
|
def test_generate_conversation_name_no_output_in_dict(self, mock_model_instance):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = '{"something": "else"}'
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
with patch("core.llm_generator.llm_generator.TraceQueueManager"):
|
||||||
|
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
|
||||||
|
assert name == "test query"
|
||||||
|
|
||||||
|
def test_generate_conversation_name_long_output(self, mock_model_instance):
|
||||||
|
long_output = "a" * 100
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = json.dumps({"Your Output": long_output})
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
with patch("core.llm_generator.llm_generator.TraceQueueManager"):
|
||||||
|
name = LLMGenerator.generate_conversation_name("tenant_id", "test query")
|
||||||
|
assert len(name) == 78 # 75 + "..."
|
||||||
|
assert name.endswith("...")
|
||||||
|
|
||||||
|
def test_generate_suggested_questions_after_answer_success(self, mock_model_instance):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = '["Question 1?", "Question 2?"]'
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories")
|
||||||
|
assert len(questions) == 2
|
||||||
|
assert questions[0] == "Question 1?"
|
||||||
|
|
||||||
|
def test_generate_suggested_questions_after_answer_auth_error(self, mock_model_instance):
|
||||||
|
with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager:
|
||||||
|
mock_manager.return_value.get_default_model_instance.side_effect = InvokeAuthorizationError("Auth failed")
|
||||||
|
questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories")
|
||||||
|
assert questions == []
|
||||||
|
|
||||||
|
def test_generate_suggested_questions_after_answer_invoke_error(self, mock_model_instance):
|
||||||
|
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed")
|
||||||
|
questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories")
|
||||||
|
assert questions == []
|
||||||
|
|
||||||
|
def test_generate_suggested_questions_after_answer_exception(self, mock_model_instance):
|
||||||
|
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
|
||||||
|
questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories")
|
||||||
|
assert questions == []
|
||||||
|
|
||||||
|
def test_generate_rule_config_no_variable_success(self, mock_model_instance, model_config_entity):
|
||||||
|
payload = RuleGeneratePayload(
|
||||||
|
instruction="test instruction", model_config=model_config_entity, no_variable=True
|
||||||
|
)
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = "Generated Prompt"
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
result = LLMGenerator.generate_rule_config("tenant_id", payload)
|
||||||
|
assert result["prompt"] == "Generated Prompt"
|
||||||
|
assert result["error"] == ""
|
||||||
|
|
||||||
|
def test_generate_rule_config_no_variable_invoke_error(self, mock_model_instance, model_config_entity):
|
||||||
|
payload = RuleGeneratePayload(
|
||||||
|
instruction="test instruction", model_config=model_config_entity, no_variable=True
|
||||||
|
)
|
||||||
|
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed")
|
||||||
|
|
||||||
|
result = LLMGenerator.generate_rule_config("tenant_id", payload)
|
||||||
|
assert "Failed to generate rule config" in result["error"]
|
||||||
|
|
||||||
|
def test_generate_rule_config_no_variable_exception(self, mock_model_instance, model_config_entity):
|
||||||
|
payload = RuleGeneratePayload(
|
||||||
|
instruction="test instruction", model_config=model_config_entity, no_variable=True
|
||||||
|
)
|
||||||
|
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
|
||||||
|
|
||||||
|
result = LLMGenerator.generate_rule_config("tenant_id", payload)
|
||||||
|
assert "Failed to generate rule config" in result["error"]
|
||||||
|
assert "Random error" in result["error"]
|
||||||
|
|
||||||
|
def test_generate_rule_config_with_variable_success(self, mock_model_instance, model_config_entity):
|
||||||
|
payload = RuleGeneratePayload(
|
||||||
|
instruction="test instruction", model_config=model_config_entity, no_variable=False
|
||||||
|
)
|
||||||
|
# Mocking 3 calls for invoke_llm
|
||||||
|
mock_res1 = MagicMock()
|
||||||
|
mock_res1.message.get_text_content.return_value = "Step 1 Prompt"
|
||||||
|
|
||||||
|
mock_res2 = MagicMock()
|
||||||
|
mock_res2.message.get_text_content.return_value = '"var1", "var2"'
|
||||||
|
|
||||||
|
mock_res3 = MagicMock()
|
||||||
|
mock_res3.message.get_text_content.return_value = "Opening Statement"
|
||||||
|
|
||||||
|
mock_model_instance.invoke_llm.side_effect = [mock_res1, mock_res2, mock_res3]
|
||||||
|
|
||||||
|
result = LLMGenerator.generate_rule_config("tenant_id", payload)
|
||||||
|
assert result["prompt"] == "Step 1 Prompt"
|
||||||
|
assert result["variables"] == ["var1", "var2"]
|
||||||
|
assert result["opening_statement"] == "Opening Statement"
|
||||||
|
assert result["error"] == ""
|
||||||
|
|
||||||
|
def test_generate_rule_config_with_variable_step1_error(self, mock_model_instance, model_config_entity):
|
||||||
|
payload = RuleGeneratePayload(
|
||||||
|
instruction="test instruction", model_config=model_config_entity, no_variable=False
|
||||||
|
)
|
||||||
|
mock_model_instance.invoke_llm.side_effect = InvokeError("Step 1 Failed")
|
||||||
|
|
||||||
|
result = LLMGenerator.generate_rule_config("tenant_id", payload)
|
||||||
|
assert "Failed to generate prefix prompt" in result["error"]
|
||||||
|
|
||||||
|
def test_generate_rule_config_with_variable_step2_error(self, mock_model_instance, model_config_entity):
|
||||||
|
payload = RuleGeneratePayload(
|
||||||
|
instruction="test instruction", model_config=model_config_entity, no_variable=False
|
||||||
|
)
|
||||||
|
mock_res1 = MagicMock()
|
||||||
|
mock_res1.message.get_text_content.return_value = "Step 1 Prompt"
|
||||||
|
|
||||||
|
# Step 2 fails
|
||||||
|
mock_model_instance.invoke_llm.side_effect = [mock_res1, InvokeError("Step 2 Failed"), MagicMock()]
|
||||||
|
|
||||||
|
result = LLMGenerator.generate_rule_config("tenant_id", payload)
|
||||||
|
assert "Failed to generate variables" in result["error"]
|
||||||
|
|
||||||
|
def test_generate_rule_config_with_variable_step3_error(self, mock_model_instance, model_config_entity):
|
||||||
|
payload = RuleGeneratePayload(
|
||||||
|
instruction="test instruction", model_config=model_config_entity, no_variable=False
|
||||||
|
)
|
||||||
|
mock_res1 = MagicMock()
|
||||||
|
mock_res1.message.get_text_content.return_value = "Step 1 Prompt"
|
||||||
|
|
||||||
|
mock_res2 = MagicMock()
|
||||||
|
mock_res2.message.get_text_content.return_value = '"var1"'
|
||||||
|
|
||||||
|
# Step 3 fails
|
||||||
|
mock_model_instance.invoke_llm.side_effect = [mock_res1, mock_res2, InvokeError("Step 3 Failed")]
|
||||||
|
|
||||||
|
result = LLMGenerator.generate_rule_config("tenant_id", payload)
|
||||||
|
assert "Failed to generate conversation opener" in result["error"]
|
||||||
|
|
||||||
|
def test_generate_rule_config_with_variable_exception(self, mock_model_instance, model_config_entity):
|
||||||
|
payload = RuleGeneratePayload(
|
||||||
|
instruction="test instruction", model_config=model_config_entity, no_variable=False
|
||||||
|
)
|
||||||
|
# Mock any step to throw Exception
|
||||||
|
mock_model_instance.invoke_llm.side_effect = Exception("Unexpected multi-step error")
|
||||||
|
|
||||||
|
result = LLMGenerator.generate_rule_config("tenant_id", payload)
|
||||||
|
assert "Failed to handle unexpected exception" in result["error"]
|
||||||
|
assert "Unexpected multi-step error" in result["error"]
|
||||||
|
|
||||||
|
def test_generate_code_python_success(self, mock_model_instance, model_config_entity):
|
||||||
|
payload = RuleCodeGeneratePayload(
|
||||||
|
instruction="print hello", code_language="python", model_config=model_config_entity
|
||||||
|
)
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = "print('hello')"
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
result = LLMGenerator.generate_code("tenant_id", payload)
|
||||||
|
assert result["code"] == "print('hello')"
|
||||||
|
assert result["language"] == "python"
|
||||||
|
|
||||||
|
def test_generate_code_javascript_success(self, mock_model_instance, model_config_entity):
|
||||||
|
payload = RuleCodeGeneratePayload(
|
||||||
|
instruction="console log hello", code_language="javascript", model_config=model_config_entity
|
||||||
|
)
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = "console.log('hello')"
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
result = LLMGenerator.generate_code("tenant_id", payload)
|
||||||
|
assert result["code"] == "console.log('hello')"
|
||||||
|
assert result["language"] == "javascript"
|
||||||
|
|
||||||
|
def test_generate_code_invoke_error(self, mock_model_instance, model_config_entity):
|
||||||
|
payload = RuleCodeGeneratePayload(instruction="error", code_language="python", model_config=model_config_entity)
|
||||||
|
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed")
|
||||||
|
|
||||||
|
result = LLMGenerator.generate_code("tenant_id", payload)
|
||||||
|
assert "Failed to generate code" in result["error"]
|
||||||
|
|
||||||
|
def test_generate_code_exception(self, mock_model_instance, model_config_entity):
|
||||||
|
payload = RuleCodeGeneratePayload(instruction="error", code_language="python", model_config=model_config_entity)
|
||||||
|
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
|
||||||
|
|
||||||
|
result = LLMGenerator.generate_code("tenant_id", payload)
|
||||||
|
assert "An unexpected error occurred" in result["error"]
|
||||||
|
|
||||||
|
def test_generate_qa_document_success(self, mock_model_instance):
|
||||||
|
mock_response = MagicMock(spec=LLMResult)
|
||||||
|
mock_response.message = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = "QA Document Content"
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
result = LLMGenerator.generate_qa_document("tenant_id", "query", "English")
|
||||||
|
assert result == "QA Document Content"
|
||||||
|
|
||||||
|
def test_generate_qa_document_type_error(self, mock_model_instance):
|
||||||
|
mock_model_instance.invoke_llm.return_value = "Not an LLMResult"
|
||||||
|
|
||||||
|
with pytest.raises(TypeError, match="Expected LLMResult when stream=False"):
|
||||||
|
LLMGenerator.generate_qa_document("tenant_id", "query", "English")
|
||||||
|
|
||||||
|
def test_generate_structured_output_success(self, mock_model_instance, model_config_entity):
|
||||||
|
payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity)
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = '{"type": "object", "properties": {}}'
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
result = LLMGenerator.generate_structured_output("tenant_id", payload)
|
||||||
|
parsed_output = json.loads(result["output"])
|
||||||
|
assert parsed_output["type"] == "object"
|
||||||
|
assert result["error"] == ""
|
||||||
|
|
||||||
|
def test_generate_structured_output_json_repair(self, mock_model_instance, model_config_entity):
|
||||||
|
payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity)
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = "{'type': 'object'}"
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
result = LLMGenerator.generate_structured_output("tenant_id", payload)
|
||||||
|
parsed_output = json.loads(result["output"])
|
||||||
|
assert parsed_output["type"] == "object"
|
||||||
|
|
||||||
|
def test_generate_structured_output_not_dict_or_list(self, mock_model_instance, model_config_entity):
|
||||||
|
payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity)
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = "true" # parsed as bool
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
result = LLMGenerator.generate_structured_output("tenant_id", payload)
|
||||||
|
assert "An unexpected error occurred" in result["error"]
|
||||||
|
assert "Failed to parse structured output" in result["error"]
|
||||||
|
|
||||||
|
def test_generate_structured_output_invoke_error(self, mock_model_instance, model_config_entity):
|
||||||
|
payload = RuleStructuredOutputPayload(instruction="error", model_config=model_config_entity)
|
||||||
|
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed")
|
||||||
|
|
||||||
|
result = LLMGenerator.generate_structured_output("tenant_id", payload)
|
||||||
|
assert "Failed to generate JSON Schema" in result["error"]
|
||||||
|
|
||||||
|
def test_generate_structured_output_exception(self, mock_model_instance, model_config_entity):
|
||||||
|
payload = RuleStructuredOutputPayload(instruction="error", model_config=model_config_entity)
|
||||||
|
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
|
||||||
|
|
||||||
|
result = LLMGenerator.generate_structured_output("tenant_id", payload)
|
||||||
|
assert "An unexpected error occurred" in result["error"]
|
||||||
|
|
||||||
|
def test_instruction_modify_legacy_no_last_run(self, mock_model_instance, model_config_entity):
|
||||||
|
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||||
|
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||||
|
|
||||||
|
# Mock __instruction_modify_common call via invoke_llm
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = '{"modified": "prompt"}'
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
result = LLMGenerator.instruction_modify_legacy(
|
||||||
|
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
|
||||||
|
)
|
||||||
|
assert result == {"modified": "prompt"}
|
||||||
|
|
||||||
|
def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity):
|
||||||
|
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||||
|
last_run = MagicMock()
|
||||||
|
last_run.query = "q"
|
||||||
|
last_run.answer = "a"
|
||||||
|
last_run.error = "e"
|
||||||
|
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = last_run
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = '{"modified": "prompt"}'
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
result = LLMGenerator.instruction_modify_legacy(
|
||||||
|
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
|
||||||
|
)
|
||||||
|
assert result == {"modified": "prompt"}
|
||||||
|
|
||||||
|
def test_instruction_modify_workflow_app_not_found(self):
|
||||||
|
with patch("extensions.ext_database.db.session") as mock_session:
|
||||||
|
mock_session.return_value.query.return_value.where.return_value.first.return_value = None
|
||||||
|
with pytest.raises(ValueError, match="App not found."):
|
||||||
|
LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", MagicMock())
|
||||||
|
|
||||||
|
def test_instruction_modify_workflow_no_workflow(self):
|
||||||
|
with patch("extensions.ext_database.db.session") as mock_session:
|
||||||
|
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||||
|
workflow_service = MagicMock()
|
||||||
|
workflow_service.get_draft_workflow.return_value = None
|
||||||
|
with pytest.raises(ValueError, match="Workflow not found for the given app model."):
|
||||||
|
LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", workflow_service)
|
||||||
|
|
||||||
|
def test_instruction_modify_workflow_success(self, mock_model_instance, model_config_entity):
|
||||||
|
with patch("extensions.ext_database.db.session") as mock_session:
|
||||||
|
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||||
|
workflow = MagicMock()
|
||||||
|
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}}
|
||||||
|
|
||||||
|
workflow_service = MagicMock()
|
||||||
|
workflow_service.get_draft_workflow.return_value = workflow
|
||||||
|
|
||||||
|
last_run = MagicMock()
|
||||||
|
last_run.node_type = "llm"
|
||||||
|
last_run.status = "s"
|
||||||
|
last_run.error = "e"
|
||||||
|
# Return regular values, not Mocks
|
||||||
|
last_run.execution_metadata_dict = {"agent_log": [{"status": "s", "error": "e", "data": {}}]}
|
||||||
|
last_run.load_full_inputs.return_value = {"in": "val"}
|
||||||
|
|
||||||
|
workflow_service.get_node_last_run.return_value = last_run
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = '{"modified": "workflow"}'
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
result = LLMGenerator.instruction_modify_workflow(
|
||||||
|
"tenant_id",
|
||||||
|
"flow_id",
|
||||||
|
"node_id",
|
||||||
|
"current",
|
||||||
|
"instruction",
|
||||||
|
model_config_entity,
|
||||||
|
"ideal",
|
||||||
|
workflow_service,
|
||||||
|
)
|
||||||
|
assert result == {"modified": "workflow"}
|
||||||
|
|
||||||
|
def test_instruction_modify_workflow_no_last_run_fallback(self, mock_model_instance, model_config_entity):
|
||||||
|
with patch("extensions.ext_database.db.session") as mock_session:
|
||||||
|
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||||
|
workflow = MagicMock()
|
||||||
|
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "code"}}]}}
|
||||||
|
|
||||||
|
workflow_service = MagicMock()
|
||||||
|
workflow_service.get_draft_workflow.return_value = workflow
|
||||||
|
workflow_service.get_node_last_run.return_value = None
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = '{"modified": "fallback"}'
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
result = LLMGenerator.instruction_modify_workflow(
|
||||||
|
"tenant_id",
|
||||||
|
"flow_id",
|
||||||
|
"node_id",
|
||||||
|
"current",
|
||||||
|
"instruction",
|
||||||
|
model_config_entity,
|
||||||
|
"ideal",
|
||||||
|
workflow_service,
|
||||||
|
)
|
||||||
|
assert result == {"modified": "fallback"}
|
||||||
|
|
||||||
|
def test_instruction_modify_workflow_node_type_fallback(self, mock_model_instance, model_config_entity):
|
||||||
|
with patch("extensions.ext_database.db.session") as mock_session:
|
||||||
|
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||||
|
workflow = MagicMock()
|
||||||
|
# Cause exception in node_type logic
|
||||||
|
workflow.graph_dict = {"graph": {"nodes": []}}
|
||||||
|
|
||||||
|
workflow_service = MagicMock()
|
||||||
|
workflow_service.get_draft_workflow.return_value = workflow
|
||||||
|
workflow_service.get_node_last_run.return_value = None
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = '{"modified": "fallback"}'
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
result = LLMGenerator.instruction_modify_workflow(
|
||||||
|
"tenant_id",
|
||||||
|
"flow_id",
|
||||||
|
"node_id",
|
||||||
|
"current",
|
||||||
|
"instruction",
|
||||||
|
model_config_entity,
|
||||||
|
"ideal",
|
||||||
|
workflow_service,
|
||||||
|
)
|
||||||
|
assert result == {"modified": "fallback"}
|
||||||
|
|
||||||
|
def test_instruction_modify_workflow_empty_agent_log(self, mock_model_instance, model_config_entity):
|
||||||
|
with patch("extensions.ext_database.db.session") as mock_session:
|
||||||
|
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||||
|
workflow = MagicMock()
|
||||||
|
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}}
|
||||||
|
|
||||||
|
workflow_service = MagicMock()
|
||||||
|
workflow_service.get_draft_workflow.return_value = workflow
|
||||||
|
|
||||||
|
last_run = MagicMock()
|
||||||
|
last_run.node_type = "llm"
|
||||||
|
last_run.status = "s"
|
||||||
|
last_run.error = "e"
|
||||||
|
# Return regular empty list, not a Mock
|
||||||
|
last_run.execution_metadata_dict = {"agent_log": []}
|
||||||
|
last_run.load_full_inputs.return_value = {}
|
||||||
|
|
||||||
|
workflow_service.get_node_last_run.return_value = last_run
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = '{"modified": "workflow"}'
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
result = LLMGenerator.instruction_modify_workflow(
|
||||||
|
"tenant_id",
|
||||||
|
"flow_id",
|
||||||
|
"node_id",
|
||||||
|
"current",
|
||||||
|
"instruction",
|
||||||
|
model_config_entity,
|
||||||
|
"ideal",
|
||||||
|
workflow_service,
|
||||||
|
)
|
||||||
|
assert result == {"modified": "workflow"}
|
||||||
|
|
||||||
|
def test_instruction_modify_common_placeholders(self, mock_model_instance, model_config_entity):
|
||||||
|
# Testing placeholders replacement via instruction_modify_legacy for convenience
|
||||||
|
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||||
|
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = '{"ok": true}'
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
instruction = "Test {{#last_run#}} and {{#current#}} and {{#error_message#}}"
|
||||||
|
LLMGenerator.instruction_modify_legacy(
|
||||||
|
"tenant_id", "flow_id", "current_val", instruction, model_config_entity, "ideal"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the call to invoke_llm contains replaced instruction
|
||||||
|
args, kwargs = mock_model_instance.invoke_llm.call_args
|
||||||
|
prompt_messages = kwargs["prompt_messages"]
|
||||||
|
user_msg = prompt_messages[1].content
|
||||||
|
user_msg_dict = json.loads(user_msg)
|
||||||
|
assert "null" in user_msg_dict["instruction"] # because last_run is None and current is current_val etc.
|
||||||
|
assert "current_val" in user_msg_dict["instruction"]
|
||||||
|
|
||||||
|
def test_instruction_modify_common_no_braces(self, mock_model_instance, model_config_entity):
|
||||||
|
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||||
|
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = "No braces here"
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
result = LLMGenerator.instruction_modify_legacy(
|
||||||
|
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
|
||||||
|
)
|
||||||
|
assert "An unexpected error occurred" in result["error"]
|
||||||
|
assert "Could not find a valid JSON object" in result["error"]
|
||||||
|
|
||||||
|
def test_instruction_modify_common_not_dict(self, mock_model_instance, model_config_entity):
|
||||||
|
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||||
|
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = "[1, 2, 3]"
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
result = LLMGenerator.instruction_modify_legacy(
|
||||||
|
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
|
||||||
|
)
|
||||||
|
# The exception message is "Expected a JSON object, but got list"
|
||||||
|
assert "An unexpected error occurred" in result["error"]
|
||||||
|
|
||||||
|
def test_instruction_modify_common_other_node_type(self, mock_model_instance, model_config_entity):
|
||||||
|
with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager:
|
||||||
|
instance = MagicMock()
|
||||||
|
mock_manager.return_value.get_model_instance.return_value = instance
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = '{"ok": true}'
|
||||||
|
instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
with patch("extensions.ext_database.db.session") as mock_session:
|
||||||
|
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||||
|
workflow = MagicMock()
|
||||||
|
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "other"}}]}}
|
||||||
|
|
||||||
|
workflow_service = MagicMock()
|
||||||
|
workflow_service.get_draft_workflow.return_value = workflow
|
||||||
|
workflow_service.get_node_last_run.return_value = None
|
||||||
|
|
||||||
|
LLMGenerator.instruction_modify_workflow(
|
||||||
|
"tenant_id",
|
||||||
|
"flow_id",
|
||||||
|
"node_id",
|
||||||
|
"current",
|
||||||
|
"instruction",
|
||||||
|
model_config_entity,
|
||||||
|
"ideal",
|
||||||
|
workflow_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_instruction_modify_common_invoke_error(self, mock_model_instance, model_config_entity):
|
||||||
|
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||||
|
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||||
|
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke Failed")
|
||||||
|
|
||||||
|
result = LLMGenerator.instruction_modify_legacy(
|
||||||
|
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
|
||||||
|
)
|
||||||
|
assert "Failed to generate code" in result["error"]
|
||||||
|
|
||||||
|
def test_instruction_modify_common_exception(self, mock_model_instance, model_config_entity):
|
||||||
|
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||||
|
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||||
|
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
|
||||||
|
|
||||||
|
result = LLMGenerator.instruction_modify_legacy(
|
||||||
|
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
|
||||||
|
)
|
||||||
|
assert "An unexpected error occurred" in result["error"]
|
||||||
|
|
||||||
|
def test_instruction_modify_common_json_error(self, mock_model_instance, model_config_entity):
|
||||||
|
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||||
|
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.message.get_text_content.return_value = "No JSON here"
|
||||||
|
mock_model_instance.invoke_llm.return_value = mock_response
|
||||||
|
|
||||||
|
result = LLMGenerator.instruction_modify_legacy(
|
||||||
|
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
|
||||||
|
)
|
||||||
|
assert "An unexpected error occurred" in result["error"]
|
||||||
Loading…
Reference in New Issue