test: add UTs for api/services recommend_app, tools, workflow (#32645)

Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Dev Sharma 2026-03-12 09:07:03 +05:30 committed by GitHub
parent 07e19c0748
commit 7d2054d4f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 1251 additions and 16 deletions

View File

@ -0,0 +1,91 @@
import json
from unittest.mock import MagicMock, patch
import pytest
from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval
from services.recommend_app.recommend_app_type import RecommendAppType
SAMPLE_BUILTIN_DATA = {
"recommended_apps": {
"en-US": {"categories": ["writing"], "apps": [{"id": "app-1"}]},
"zh-Hans": {"categories": ["search"], "apps": [{"id": "app-2"}]},
},
"app_details": {
"app-1": {"id": "app-1", "name": "Writer", "mode": "chat"},
"app-2": {"id": "app-2", "name": "Searcher", "mode": "workflow"},
},
}
@pytest.fixture(autouse=True)
def _reset_cache():
BuildInRecommendAppRetrieval.builtin_data = None
yield
BuildInRecommendAppRetrieval.builtin_data = None
class TestBuildInRecommendAppRetrieval:
def test_get_type(self):
retrieval = BuildInRecommendAppRetrieval()
assert retrieval.get_type() == RecommendAppType.BUILDIN
def test_get_recommended_apps_and_categories_delegates(self):
with patch.object(
BuildInRecommendAppRetrieval,
"fetch_recommended_apps_from_builtin",
return_value={"apps": []},
) as mock_fetch:
retrieval = BuildInRecommendAppRetrieval()
result = retrieval.get_recommended_apps_and_categories("en-US")
mock_fetch.assert_called_once_with("en-US")
assert result == {"apps": []}
def test_get_recommend_app_detail_delegates(self):
with patch.object(
BuildInRecommendAppRetrieval,
"fetch_recommended_app_detail_from_builtin",
return_value={"id": "app-1"},
) as mock_fetch:
retrieval = BuildInRecommendAppRetrieval()
result = retrieval.get_recommend_app_detail("app-1")
mock_fetch.assert_called_once_with("app-1")
assert result == {"id": "app-1"}
def test_get_builtin_data_reads_json_and_caches(self, tmp_path):
json_file = tmp_path / "constants" / "recommended_apps.json"
json_file.parent.mkdir(parents=True)
json_file.write_text(json.dumps(SAMPLE_BUILTIN_DATA))
mock_app = MagicMock()
mock_app.root_path = str(tmp_path)
with patch(
"services.recommend_app.buildin.buildin_retrieval.current_app",
mock_app,
):
first = BuildInRecommendAppRetrieval._get_builtin_data()
second = BuildInRecommendAppRetrieval._get_builtin_data()
assert first == SAMPLE_BUILTIN_DATA
assert first is second
def test_fetch_recommended_apps_from_builtin(self):
BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA
result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin("en-US")
assert result == SAMPLE_BUILTIN_DATA["recommended_apps"]["en-US"]
def test_fetch_recommended_apps_from_builtin_missing_language(self):
BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA
result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin("fr-FR")
assert result == {}
def test_fetch_recommended_app_detail_from_builtin(self):
BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA
result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin("app-1")
assert result == {"id": "app-1", "name": "Writer", "mode": "chat"}
def test_fetch_recommended_app_detail_from_builtin_missing(self):
BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA
result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin("nonexistent")
assert result is None

View File

@ -0,0 +1,145 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval
from services.recommend_app.recommend_app_type import RecommendAppType
class TestDatabaseRecommendAppRetrieval:
def test_get_type(self):
assert DatabaseRecommendAppRetrieval().get_type() == RecommendAppType.DATABASE
def test_get_recommended_apps_delegates(self):
with patch.object(
DatabaseRecommendAppRetrieval,
"fetch_recommended_apps_from_db",
return_value={"recommended_apps": [], "categories": []},
) as mock_fetch:
result = DatabaseRecommendAppRetrieval().get_recommended_apps_and_categories("en-US")
mock_fetch.assert_called_once_with("en-US")
assert result == {"recommended_apps": [], "categories": []}
def test_get_recommend_app_detail_delegates(self):
with patch.object(
DatabaseRecommendAppRetrieval,
"fetch_recommended_app_detail_from_db",
return_value={"id": "app-1"},
) as mock_fetch:
result = DatabaseRecommendAppRetrieval().get_recommend_app_detail("app-1")
mock_fetch.assert_called_once_with("app-1")
assert result == {"id": "app-1"}
class TestFetchRecommendedAppsFromDb:
def _make_recommended_app(self, app_id, category, is_public=True, has_site=True):
site = (
SimpleNamespace(
description="desc",
copyright="copy",
privacy_policy="pp",
custom_disclaimer="cd",
)
if has_site
else None
)
app = (
SimpleNamespace(is_public=is_public, site=site)
if is_public
else SimpleNamespace(is_public=False, site=site)
)
return SimpleNamespace(
id=f"rec-{app_id}",
app=app,
app_id=app_id,
category=category,
position=1,
is_listed=True,
)
@patch("services.recommend_app.database.database_retrieval.db")
def test_returns_apps_and_sorted_categories(self, mock_db):
rec1 = self._make_recommended_app("a1", "writing")
rec2 = self._make_recommended_app("a2", "assistant")
mock_db.session.scalars.return_value.all.return_value = [rec1, rec2]
result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US")
assert len(result["recommended_apps"]) == 2
assert result["categories"] == ["assistant", "writing"]
@patch("services.recommend_app.database.database_retrieval.db")
def test_falls_back_to_default_language_when_empty(self, mock_db):
mock_db.session.scalars.return_value.all.side_effect = [
[],
[self._make_recommended_app("a1", "chat")],
]
result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("fr-FR")
assert len(result["recommended_apps"]) == 1
assert mock_db.session.scalars.call_count == 2
@patch("services.recommend_app.database.database_retrieval.db")
def test_skips_non_public_apps(self, mock_db):
rec = self._make_recommended_app("a1", "chat", is_public=False)
mock_db.session.scalars.return_value.all.return_value = [rec]
result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US")
assert result["recommended_apps"] == []
@patch("services.recommend_app.database.database_retrieval.db")
def test_skips_apps_without_site(self, mock_db):
rec = self._make_recommended_app("a1", "chat", has_site=False)
mock_db.session.scalars.return_value.all.return_value = [rec]
result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US")
assert result["recommended_apps"] == []
class TestFetchRecommendedAppDetailFromDb:
@patch("services.recommend_app.database.database_retrieval.db")
def test_returns_none_when_not_listed(self, mock_db):
mock_db.session.query.return_value.where.return_value.first.return_value = None
result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1")
assert result is None
@patch("services.recommend_app.database.database_retrieval.AppDslService")
@patch("services.recommend_app.database.database_retrieval.db")
def test_returns_none_when_app_not_public(self, mock_db, mock_dsl):
rec_chain = MagicMock()
rec_chain.where.return_value.first.return_value = SimpleNamespace(app_id="app-1")
app_chain = MagicMock()
app_chain.where.return_value.first.return_value = SimpleNamespace(id="app-1", is_public=False)
mock_db.session.query.side_effect = [rec_chain, app_chain]
result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1")
assert result is None
@patch("services.recommend_app.database.database_retrieval.AppDslService")
@patch("services.recommend_app.database.database_retrieval.db")
def test_returns_detail_on_success(self, mock_db, mock_dsl):
app_model = SimpleNamespace(
id="app-1",
name="My App",
icon="icon.png",
icon_background="#fff",
mode="chat",
is_public=True,
)
rec_chain = MagicMock()
rec_chain.where.return_value.first.return_value = SimpleNamespace(app_id="app-1")
app_chain = MagicMock()
app_chain.where.return_value.first.return_value = app_model
mock_db.session.query.side_effect = [rec_chain, app_chain]
mock_dsl.export_dsl.return_value = "exported_yaml"
result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1")
assert result["id"] == "app-1"
assert result["name"] == "My App"
assert result["export_data"] == "exported_yaml"

View File

@ -0,0 +1,28 @@
import pytest
from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval
from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
from services.recommend_app.remote.remote_retrieval import RemoteRecommendAppRetrieval
class TestRecommendAppRetrievalFactory:
@pytest.mark.parametrize(
("mode", "expected_class"),
[
("remote", RemoteRecommendAppRetrieval),
("builtin", BuildInRecommendAppRetrieval),
("db", DatabaseRecommendAppRetrieval),
],
)
def test_factory_returns_correct_class(self, mode, expected_class):
result = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)
assert result is expected_class
def test_factory_raises_for_unknown_mode(self):
with pytest.raises(ValueError, match="invalid fetch recommended apps mode"):
RecommendAppRetrievalFactory.get_recommend_app_factory("invalid_mode")
def test_get_buildin_recommend_app_retrieval(self):
result = RecommendAppRetrievalFactory.get_buildin_recommend_app_retrieval()
assert result is BuildInRecommendAppRetrieval

View File

@ -0,0 +1,18 @@
from services.recommend_app.recommend_app_type import RecommendAppType
def test_enum_values():
assert RecommendAppType.REMOTE == "remote"
assert RecommendAppType.BUILDIN == "builtin"
assert RecommendAppType.DATABASE == "db"
def test_enum_membership():
assert "remote" in RecommendAppType.__members__.values()
assert "builtin" in RecommendAppType.__members__.values()
assert "db" in RecommendAppType.__members__.values()
def test_enum_is_str():
for member in RecommendAppType:
assert isinstance(member, str)

View File

@ -0,0 +1,120 @@
from unittest.mock import MagicMock, patch
import pytest
from services.recommend_app.recommend_app_type import RecommendAppType
from services.recommend_app.remote.remote_retrieval import RemoteRecommendAppRetrieval
class TestRemoteRecommendAppRetrieval:
def test_get_type(self):
assert RemoteRecommendAppRetrieval().get_type() == RecommendAppType.REMOTE
@patch.object(
RemoteRecommendAppRetrieval,
"fetch_recommended_app_detail_from_dify_official",
return_value={"id": "app-1"},
)
def test_get_recommend_app_detail_success(self, mock_fetch):
result = RemoteRecommendAppRetrieval().get_recommend_app_detail("app-1")
assert result == {"id": "app-1"}
mock_fetch.assert_called_once_with("app-1")
@patch(
"services.recommend_app.remote.remote_retrieval"
".BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin",
return_value={"id": "fallback"},
)
@patch.object(
RemoteRecommendAppRetrieval,
"fetch_recommended_app_detail_from_dify_official",
side_effect=ConnectionError("timeout"),
)
def test_get_recommend_app_detail_falls_back_on_error(self, mock_fetch, mock_builtin):
result = RemoteRecommendAppRetrieval().get_recommend_app_detail("app-1")
assert result == {"id": "fallback"}
mock_builtin.assert_called_once_with("app-1")
@patch.object(
RemoteRecommendAppRetrieval,
"fetch_recommended_apps_from_dify_official",
return_value={"recommended_apps": [], "categories": []},
)
def test_get_recommended_apps_success(self, mock_fetch):
result = RemoteRecommendAppRetrieval().get_recommended_apps_and_categories("en-US")
assert result == {"recommended_apps": [], "categories": []}
@patch(
"services.recommend_app.remote.remote_retrieval"
".BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin",
return_value={"recommended_apps": [{"id": "builtin"}]},
)
@patch.object(
RemoteRecommendAppRetrieval,
"fetch_recommended_apps_from_dify_official",
side_effect=ValueError("server error"),
)
def test_get_recommended_apps_falls_back_on_error(self, mock_fetch, mock_builtin):
result = RemoteRecommendAppRetrieval().get_recommended_apps_and_categories("en-US")
assert result == {"recommended_apps": [{"id": "builtin"}]}
class TestFetchFromDifyOfficial:
@patch("services.recommend_app.remote.remote_retrieval.dify_config")
@patch("services.recommend_app.remote.remote_retrieval.httpx.get")
def test_detail_returns_json_on_200(self, mock_get, mock_config):
mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
mock_response = MagicMock(status_code=200)
mock_response.json.return_value = {"id": "app-1", "name": "Test"}
mock_get.return_value = mock_response
result = RemoteRecommendAppRetrieval.fetch_recommended_app_detail_from_dify_official("app-1")
assert result == {"id": "app-1", "name": "Test"}
mock_get.assert_called_once()
@patch("services.recommend_app.remote.remote_retrieval.dify_config")
@patch("services.recommend_app.remote.remote_retrieval.httpx.get")
def test_detail_returns_none_on_non_200(self, mock_get, mock_config):
mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
mock_get.return_value = MagicMock(status_code=404)
result = RemoteRecommendAppRetrieval.fetch_recommended_app_detail_from_dify_official("app-1")
assert result is None
@patch("services.recommend_app.remote.remote_retrieval.dify_config")
@patch("services.recommend_app.remote.remote_retrieval.httpx.get")
def test_apps_returns_sorted_categories_on_200(self, mock_get, mock_config):
mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
mock_response = MagicMock(status_code=200)
mock_response.json.return_value = {
"recommended_apps": [],
"categories": ["writing", "agent", "chat"],
}
mock_get.return_value = mock_response
result = RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US")
assert result["categories"] == ["agent", "chat", "writing"]
@patch("services.recommend_app.remote.remote_retrieval.dify_config")
@patch("services.recommend_app.remote.remote_retrieval.httpx.get")
def test_apps_raises_on_non_200(self, mock_get, mock_config):
mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
mock_get.return_value = MagicMock(status_code=500)
with pytest.raises(ValueError, match="fetch recommended apps failed"):
RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US")
@patch("services.recommend_app.remote.remote_retrieval.dify_config")
@patch("services.recommend_app.remote.remote_retrieval.httpx.get")
def test_apps_without_categories_key(self, mock_get, mock_config):
mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
mock_response = MagicMock(status_code=200)
mock_response.json.return_value = {"recommended_apps": []}
mock_get.return_value = mock_response
result = RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US")
assert "categories" not in result

View File

@ -0,0 +1,455 @@
from unittest.mock import MagicMock, patch
import pytest
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
MODULE = "services.tools.builtin_tools_manage_service"
def _mock_session(mock_session_cls):
"""Helper: set up a Session context manager mock and return the inner session."""
session = MagicMock()
mock_session_cls.return_value.__enter__ = MagicMock(return_value=session)
mock_session_cls.return_value.__exit__ = MagicMock(return_value=False)
return session
class TestDeleteCustomOauthClientParams:
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.db")
def test_deletes_and_returns_success(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google")
assert result == {"result": "success"}
session.query.return_value.filter_by.return_value.delete.assert_called_once()
session.commit.assert_called_once()
class TestListBuiltinToolProviderTools:
@patch(f"{MODULE}.ToolLabelManager")
@patch(f"{MODULE}.ToolTransformService")
@patch(f"{MODULE}.ToolManager")
def test_transforms_each_tool(self, mock_manager, mock_transform, mock_labels):
mock_controller = MagicMock()
mock_controller.get_tools.return_value = [MagicMock(), MagicMock()]
mock_manager.get_builtin_provider.return_value = mock_controller
mock_transform.convert_tool_entity_to_api_entity.return_value = MagicMock()
result = BuiltinToolManageService.list_builtin_tool_provider_tools("tenant-1", "google")
assert len(result) == 2
@patch(f"{MODULE}.ToolLabelManager")
@patch(f"{MODULE}.ToolTransformService")
@patch(f"{MODULE}.ToolManager")
def test_empty_tools(self, mock_manager, mock_transform, mock_labels):
mock_controller = MagicMock()
mock_controller.get_tools.return_value = []
mock_manager.get_builtin_provider.return_value = mock_controller
assert BuiltinToolManageService.list_builtin_tool_provider_tools("t", "p") == []
class TestGetBuiltinToolProviderInfo:
@patch(f"{MODULE}.ToolTransformService")
@patch(f"{MODULE}.BuiltinToolManageService.get_builtin_provider")
@patch(f"{MODULE}.ToolManager")
def test_raises_when_not_found(self, mock_manager, mock_get, mock_transform):
mock_get.return_value = None
with pytest.raises(ValueError, match="you have not added provider"):
BuiltinToolManageService.get_builtin_tool_provider_info("t", "no")
@patch(f"{MODULE}.ToolTransformService")
@patch(f"{MODULE}.BuiltinToolManageService.get_builtin_provider")
@patch(f"{MODULE}.ToolManager")
def test_clears_original_credentials(self, mock_manager, mock_get, mock_transform):
mock_get.return_value = MagicMock()
entity = MagicMock()
mock_transform.builtin_provider_to_user_provider.return_value = entity
result = BuiltinToolManageService.get_builtin_tool_provider_info("t", "google")
assert result.original_credentials == {}
class TestListBuiltinProviderCredentialsSchema:
@patch(f"{MODULE}.ToolManager")
def test_returns_schema(self, mock_manager):
mock_manager.get_builtin_provider.return_value.get_credentials_schema_by_type.return_value = [{"f": "k"}]
result = BuiltinToolManageService.list_builtin_provider_credentials_schema("g", "api_key", "t")
assert result == [{"f": "k"}]
class TestGetBuiltinToolProviderIcon:
@patch(f"{MODULE}.Path")
@patch(f"{MODULE}.ToolManager")
def test_returns_bytes_and_mime(self, mock_manager, mock_path):
mock_manager.get_hardcoded_provider_icon.return_value = ("/icon.svg", "image/svg+xml")
mock_path.return_value.read_bytes.return_value = b"<svg/>"
icon, mime = BuiltinToolManageService.get_builtin_tool_provider_icon("google")
assert icon == b"<svg/>"
assert mime == "image/svg+xml"
class TestIsOauthSystemClientExists:
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.db")
def test_true_when_exists(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = MagicMock()
assert BuiltinToolManageService.is_oauth_system_client_exists("google") is True
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.db")
def test_false_when_missing(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = None
assert BuiltinToolManageService.is_oauth_system_client_exists("google") is False
class TestIsOauthCustomClientEnabled:
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.db")
def test_true_when_enabled(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = MagicMock(enabled=True)
assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is True
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.db")
def test_false_when_none(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = None
assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is False
class TestDeleteBuiltinToolProvider:
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
@patch(f"{MODULE}.ToolManager")
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.db")
def test_raises_when_not_found(self, mock_db, mock_session_cls, mock_tm, mock_enc):
session = _mock_session(mock_session_cls)
session.query.return_value.where.return_value.first.return_value = None
with pytest.raises(ValueError, match="you have not added provider"):
BuiltinToolManageService.delete_builtin_tool_provider("t", "p", "id")
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
@patch(f"{MODULE}.ToolManager")
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.db")
def test_deletes_provider_and_clears_cache(self, mock_db, mock_session_cls, mock_tm, mock_enc):
session = _mock_session(mock_session_cls)
db_provider = MagicMock()
session.query.return_value.where.return_value.first.return_value = db_provider
mock_cache = MagicMock()
mock_enc.return_value = (MagicMock(), mock_cache)
result = BuiltinToolManageService.delete_builtin_tool_provider("t", "p", "c")
assert result == {"result": "success"}
session.delete.assert_called_once_with(db_provider)
session.commit.assert_called_once()
mock_cache.delete.assert_called_once()
class TestSetDefaultProvider:
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.db")
def test_raises_when_not_found(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = None
with pytest.raises(ValueError, match="provider not found"):
BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.db")
def test_sets_default_and_clears_old(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
target = MagicMock()
session.query.return_value.filter_by.return_value.first.return_value = target
result = BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
assert result == {"result": "success"}
assert target.is_default is True
session.commit.assert_called_once()
class TestUpdateBuiltinToolProvider:
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.db")
def test_raises_when_provider_not_exists(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.where.return_value.first.return_value = None
with pytest.raises(ValueError, match="you have not added provider"):
BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c")
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
@patch(f"{MODULE}.CredentialType")
@patch(f"{MODULE}.ToolManager")
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.db")
def test_updates_credentials_and_commits(self, mock_db, mock_session_cls, mock_tm, mock_cred_type, mock_enc):
session = _mock_session(mock_session_cls)
db_provider = MagicMock(credential_type="api_key", credentials="{}")
session.query.return_value.where.return_value.first.return_value = db_provider
mock_cred_instance = MagicMock()
mock_cred_instance.is_editable.return_value = True
mock_cred_instance.is_validate_allowed.return_value = False
mock_cred_type.of.return_value = mock_cred_instance
mock_controller = MagicMock(need_credentials=True)
mock_tm.get_builtin_provider.return_value = mock_controller
mock_encrypter = MagicMock()
mock_encrypter.decrypt.return_value = {"key": "old"}
mock_encrypter.encrypt.return_value = {"key": "new"}
mock_cache = MagicMock()
mock_enc.return_value = (mock_encrypter, mock_cache)
result = BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c", credentials={"key": "val"})
assert result == {"result": "success"}
session.commit.assert_called_once()
mock_cache.delete.assert_called_once()
class TestGetOauthClientSchema:
@patch(f"{MODULE}.BuiltinToolManageService.get_custom_oauth_client_params", return_value={})
@patch(f"{MODULE}.BuiltinToolManageService.is_oauth_system_client_exists", return_value=False)
@patch(f"{MODULE}.BuiltinToolManageService.is_oauth_custom_client_enabled", return_value=True)
@patch(f"{MODULE}.dify_config")
@patch(f"{MODULE}.PluginService")
@patch(f"{MODULE}.ToolManager")
def test_returns_schema_dict(self, mock_tm, mock_plugin, mock_config, mock_enabled, mock_sys, mock_params):
mock_config.CONSOLE_API_URL = "https://api.example.com"
mock_controller = MagicMock()
mock_controller.get_oauth_client_schema.return_value = []
mock_tm.get_builtin_provider.return_value = mock_controller
result = BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema("t", "google")
assert "schema" in result
assert result["is_oauth_custom_client_enabled"] is True
assert "redirect_uri" in result
class TestGetOauthClient:
@patch(f"{MODULE}.PluginService")
@patch(f"{MODULE}.create_provider_encrypter")
@patch(f"{MODULE}.ToolManager")
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.db")
def test_returns_user_client_params_when_exists(
self, mock_db, mock_session_cls, mock_tm, mock_create_enc, mock_plugin
):
session = _mock_session(mock_session_cls)
mock_controller = MagicMock()
mock_controller.get_oauth_client_schema.return_value = []
mock_tm.get_builtin_provider.return_value = mock_controller
mock_encrypter = MagicMock()
mock_encrypter.decrypt.return_value = {"client_id": "id", "client_secret": "secret"}
mock_create_enc.return_value = (mock_encrypter, MagicMock())
user_client = MagicMock(oauth_params='{"encrypted": "data"}')
session.query.return_value.filter_by.return_value.first.return_value = user_client
result = BuiltinToolManageService.get_oauth_client("t", "google")
assert result == {"client_id": "id", "client_secret": "secret"}
@patch(f"{MODULE}.decrypt_system_oauth_params", return_value={"sys_key": "sys_val"})
@patch(f"{MODULE}.PluginService")
@patch(f"{MODULE}.create_provider_encrypter")
@patch(f"{MODULE}.ToolManager")
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.db")
def test_falls_back_to_system_client(
self, mock_db, mock_session_cls, mock_tm, mock_create_enc, mock_plugin, mock_decrypt
):
session = _mock_session(mock_session_cls)
mock_controller = MagicMock()
mock_controller.get_oauth_client_schema.return_value = []
mock_tm.get_builtin_provider.return_value = mock_controller
mock_create_enc.return_value = (MagicMock(), MagicMock())
system_client = MagicMock(encrypted_oauth_params="enc")
session.query.return_value.filter_by.return_value.first.side_effect = [
None, # user client
system_client, # system client
]
result = BuiltinToolManageService.get_oauth_client("t", "google")
assert result == {"sys_key": "sys_val"}
class TestSaveCustomOauthClientParams:
def test_returns_early_when_no_params(self):
result = BuiltinToolManageService.save_custom_oauth_client_params("t", "p")
assert result == {"result": "success"}
@patch(f"{MODULE}.ToolManager")
def test_raises_when_provider_not_found(self, mock_tm):
mock_tm.get_builtin_provider.return_value = None
with pytest.raises((ValueError, Exception), match="not found|Provider"):
BuiltinToolManageService.save_custom_oauth_client_params("t", "p", enable_oauth_custom_client=True)
class TestGetCustomOauthClientParams:
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.db")
def test_returns_empty_when_none(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = None
result = BuiltinToolManageService.get_custom_oauth_client_params("t", "p")
assert result == {}
class TestGetBuiltinToolProviderCredentialInfo:
@patch(f"{MODULE}.BuiltinToolManageService.is_oauth_custom_client_enabled", return_value=False)
@patch(f"{MODULE}.BuiltinToolManageService.get_builtin_tool_provider_credentials", return_value=[])
@patch(f"{MODULE}.ToolManager")
def test_returns_credential_info(self, mock_tm, mock_creds, mock_oauth):
mock_tm.get_builtin_provider.return_value.get_supported_credential_types.return_value = ["api-key"]
result = BuiltinToolManageService.get_builtin_tool_provider_credential_info("t", "google")
assert result.credentials == []
assert result.supported_credential_types == ["api-key"]
assert result.is_oauth_custom_client_enabled is False
class TestGetBuiltinToolProviderCredentials:
@patch(f"{MODULE}.db")
def test_returns_empty_when_no_providers(self, mock_db):
mock_db.session.no_autoflush.__enter__ = MagicMock(return_value=None)
mock_db.session.no_autoflush.__exit__ = MagicMock(return_value=False)
mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = []
result = BuiltinToolManageService.get_builtin_tool_provider_credentials("t", "google")
assert result == []
@patch(f"{MODULE}.ToolTransformService")
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
@patch(f"{MODULE}.ToolManager")
@patch(f"{MODULE}.db")
def test_returns_credential_entities(self, mock_db, mock_tm, mock_enc, mock_transform):
mock_db.session.no_autoflush.__enter__ = MagicMock(return_value=None)
mock_db.session.no_autoflush.__exit__ = MagicMock(return_value=False)
provider = MagicMock(provider="google", is_default=False)
mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [provider]
mock_encrypter = MagicMock()
mock_encrypter.decrypt.return_value = {"key": "decrypted"}
mock_encrypter.mask_plugin_credentials.return_value = {"key": "***"}
mock_enc.return_value = (mock_encrypter, MagicMock())
credential_entity = MagicMock()
mock_transform.convert_builtin_provider_to_credential_entity.return_value = credential_entity
result = BuiltinToolManageService.get_builtin_tool_provider_credentials("t", "google")
assert len(result) == 1
assert result[0] is credential_entity
assert provider.is_default is True
class TestGetBuiltinProvider:
@patch(f"{MODULE}.ToolProviderID")
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.db")
def test_returns_none_when_not_found(self, mock_db, mock_session_cls, mock_prov_id):
session = _mock_session(mock_session_cls)
mock_prov_id.return_value.provider_name = "google"
mock_prov_id.return_value.organization = "langgenius"
session.query.return_value.where.return_value.order_by.return_value.first.return_value = None
result = BuiltinToolManageService.get_builtin_provider("google", "t")
assert result is None
@patch(f"{MODULE}.ToolProviderID")
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.db")
def test_returns_provider_for_langgenius_org(self, mock_db, mock_session_cls, mock_prov_id):
session = _mock_session(mock_session_cls)
mock_prov_id.return_value.provider_name = "google"
mock_prov_id.return_value.organization = "langgenius"
db_provider = MagicMock(provider="google")
mock_prov_id_result = MagicMock()
mock_prov_id_result.to_string.return_value = "langgenius/google/google"
def prov_id_side_effect(name):
m = MagicMock()
m.provider_name = "google"
m.organization = "langgenius"
m.to_string.return_value = "langgenius/google/google"
m.plugin_id = "langgenius/google"
return m
mock_prov_id.side_effect = prov_id_side_effect
session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider
result = BuiltinToolManageService.get_builtin_provider("google", "t")
assert result is db_provider
@patch(f"{MODULE}.ToolProviderID")
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.db")
def test_returns_provider_for_non_langgenius_org(self, mock_db, mock_session_cls, mock_prov_id):
session = _mock_session(mock_session_cls)
def prov_id_side_effect(name):
m = MagicMock()
m.provider_name = "custom-tool"
m.organization = "third-party"
m.to_string.return_value = "third-party/custom/custom-tool"
m.plugin_id = "third-party/custom"
return m
mock_prov_id.side_effect = prov_id_side_effect
db_provider = MagicMock(provider="third-party/custom/custom-tool")
session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider
result = BuiltinToolManageService.get_builtin_provider("third-party/custom/custom-tool", "t")
assert result is db_provider
@patch(f"{MODULE}.ToolProviderID")
@patch(f"{MODULE}.Session")
@patch(f"{MODULE}.db")
def test_falls_back_on_exception(self, mock_db, mock_session_cls, mock_prov_id):
session = _mock_session(mock_session_cls)
mock_prov_id.side_effect = Exception("parse error")
fallback = MagicMock()
session.query.return_value.where.return_value.order_by.return_value.first.return_value = fallback
result = BuiltinToolManageService.get_builtin_provider("old-provider", "t")
assert result is fallback

View File

@ -0,0 +1,21 @@
from services.tools.tool_labels_service import ToolLabelsService
def test_list_tool_labels_returns_default_labels():
result = ToolLabelsService.list_tool_labels()
assert isinstance(result, list)
assert len(result) > 0
def test_list_tool_labels_items_are_tool_labels():
from core.tools.entities.tool_entities import ToolLabel
result = ToolLabelsService.list_tool_labels()
for label in result:
assert isinstance(label, ToolLabel)
def test_list_tool_labels_matches_default_values():
from core.tools.entities.values import default_tool_labels
assert ToolLabelsService.list_tool_labels() is default_tool_labels

View File

@ -0,0 +1,40 @@
from unittest.mock import MagicMock, patch
from services.tools.tools_manage_service import ToolCommonService
class TestToolCommonService:
@patch("services.tools.tools_manage_service.ToolTransformService")
@patch("services.tools.tools_manage_service.ToolManager")
def test_list_tool_providers_transforms_and_returns(self, mock_manager, mock_transform):
mock_provider1 = MagicMock()
mock_provider1.to_dict.return_value = {"name": "provider1"}
mock_provider2 = MagicMock()
mock_provider2.to_dict.return_value = {"name": "provider2"}
mock_manager.list_providers_from_api.return_value = [mock_provider1, mock_provider2]
result = ToolCommonService.list_tool_providers("user-1", "tenant-1")
mock_manager.list_providers_from_api.assert_called_once_with("user-1", "tenant-1", None)
assert mock_transform.repack_provider.call_count == 2
assert result == [{"name": "provider1"}, {"name": "provider2"}]
@patch("services.tools.tools_manage_service.ToolTransformService")
@patch("services.tools.tools_manage_service.ToolManager")
def test_list_tool_providers_with_type_filter(self, mock_manager, mock_transform):
mock_manager.list_providers_from_api.return_value = []
result = ToolCommonService.list_tool_providers("user-1", "tenant-1", typ="builtin")
mock_manager.list_providers_from_api.assert_called_once_with("user-1", "tenant-1", "builtin")
assert result == []
@patch("services.tools.tools_manage_service.ToolTransformService")
@patch("services.tools.tools_manage_service.ToolManager")
def test_list_tool_providers_empty(self, mock_manager, mock_transform):
mock_manager.list_providers_from_api.return_value = []
result = ToolCommonService.list_tool_providers("u", "t")
assert result == []
mock_transform.repack_provider.assert_not_called()

View File

@ -0,0 +1,110 @@
from unittest.mock import patch
import pytest
from services.workflow.queue_dispatcher import (
BaseQueueDispatcher,
ProfessionalQueueDispatcher,
QueueDispatcherManager,
QueuePriority,
SandboxQueueDispatcher,
TeamQueueDispatcher,
)
class TestQueuePriority:
def test_priority_values(self):
assert QueuePriority.PROFESSIONAL == "workflow_professional"
assert QueuePriority.TEAM == "workflow_team"
assert QueuePriority.SANDBOX == "workflow_sandbox"
class TestDispatchers:
def test_professional_dispatcher(self):
d = ProfessionalQueueDispatcher()
assert d.get_queue_name() == QueuePriority.PROFESSIONAL
assert d.get_priority() == 100
def test_team_dispatcher(self):
d = TeamQueueDispatcher()
assert d.get_queue_name() == QueuePriority.TEAM
assert d.get_priority() == 50
def test_sandbox_dispatcher(self):
d = SandboxQueueDispatcher()
assert d.get_queue_name() == QueuePriority.SANDBOX
assert d.get_priority() == 10
def test_base_dispatcher_is_abstract(self):
with pytest.raises(TypeError):
BaseQueueDispatcher()
class TestQueueDispatcherManager:
@patch("services.workflow.queue_dispatcher.BillingService")
@patch("services.workflow.queue_dispatcher.dify_config")
def test_billing_enabled_professional_plan(self, mock_config, mock_billing):
mock_config.BILLING_ENABLED = True
mock_billing.get_info.return_value = {"subscription": {"plan": "professional"}}
dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
assert isinstance(dispatcher, ProfessionalQueueDispatcher)
@patch("services.workflow.queue_dispatcher.BillingService")
@patch("services.workflow.queue_dispatcher.dify_config")
def test_billing_enabled_team_plan(self, mock_config, mock_billing):
mock_config.BILLING_ENABLED = True
mock_billing.get_info.return_value = {"subscription": {"plan": "team"}}
dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
assert isinstance(dispatcher, TeamQueueDispatcher)
@patch("services.workflow.queue_dispatcher.BillingService")
@patch("services.workflow.queue_dispatcher.dify_config")
def test_billing_enabled_sandbox_plan(self, mock_config, mock_billing):
mock_config.BILLING_ENABLED = True
mock_billing.get_info.return_value = {"subscription": {"plan": "sandbox"}}
dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
assert isinstance(dispatcher, SandboxQueueDispatcher)
@patch("services.workflow.queue_dispatcher.BillingService")
@patch("services.workflow.queue_dispatcher.dify_config")
def test_billing_enabled_unknown_plan_defaults_to_sandbox(self, mock_config, mock_billing):
mock_config.BILLING_ENABLED = True
mock_billing.get_info.return_value = {"subscription": {"plan": "enterprise"}}
dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
assert isinstance(dispatcher, SandboxQueueDispatcher)
@patch("services.workflow.queue_dispatcher.BillingService")
@patch("services.workflow.queue_dispatcher.dify_config")
def test_billing_enabled_service_failure_defaults_to_sandbox(self, mock_config, mock_billing):
mock_config.BILLING_ENABLED = True
mock_billing.get_info.side_effect = Exception("billing unavailable")
dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
assert isinstance(dispatcher, SandboxQueueDispatcher)
@patch("services.workflow.queue_dispatcher.dify_config")
def test_billing_disabled_defaults_to_team(self, mock_config):
mock_config.BILLING_ENABLED = False
dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
assert isinstance(dispatcher, TeamQueueDispatcher)
@patch("services.workflow.queue_dispatcher.BillingService")
@patch("services.workflow.queue_dispatcher.dify_config")
def test_missing_subscription_key_defaults_to_sandbox(self, mock_config, mock_billing):
mock_config.BILLING_ENABLED = True
mock_billing.get_info.return_value = {}
dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
assert isinstance(dispatcher, SandboxQueueDispatcher)

View File

@ -0,0 +1,89 @@
import pytest
from services.workflow.entities import WorkflowScheduleCFSPlanEntity
from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand
class TestSchedulerCommand:
def test_enum_values(self):
assert SchedulerCommand.RESOURCE_LIMIT_REACHED == "resource_limit_reached"
assert SchedulerCommand.NONE == "none"
def test_enum_is_str(self):
for member in SchedulerCommand:
assert isinstance(member, str)
class TestCFSPlanScheduler:
def test_stores_plan(self):
plan = WorkflowScheduleCFSPlanEntity(
schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.Nop,
granularity=-1,
)
class ConcretePlanScheduler(CFSPlanScheduler):
def can_schedule(self):
return SchedulerCommand.NONE
scheduler = ConcretePlanScheduler(plan)
assert scheduler.plan is plan
assert scheduler.plan.schedule_strategy == WorkflowScheduleCFSPlanEntity.Strategy.Nop
assert scheduler.plan.granularity == -1
def test_cannot_instantiate_abstract(self):
plan = WorkflowScheduleCFSPlanEntity(
schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice,
granularity=10,
)
with pytest.raises(TypeError):
CFSPlanScheduler(plan)
def test_concrete_subclass_can_schedule(self):
plan = WorkflowScheduleCFSPlanEntity(
schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice,
granularity=5,
)
class TimedScheduler(CFSPlanScheduler):
def can_schedule(self):
if self.plan.granularity > 0:
return SchedulerCommand.NONE
return SchedulerCommand.RESOURCE_LIMIT_REACHED
scheduler = TimedScheduler(plan)
assert scheduler.can_schedule() == SchedulerCommand.NONE
def test_concrete_subclass_resource_limit(self):
plan = WorkflowScheduleCFSPlanEntity(
schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice,
granularity=-1,
)
class TimedScheduler(CFSPlanScheduler):
def can_schedule(self):
if self.plan.granularity > 0:
return SchedulerCommand.NONE
return SchedulerCommand.RESOURCE_LIMIT_REACHED
scheduler = TimedScheduler(plan)
assert scheduler.can_schedule() == SchedulerCommand.RESOURCE_LIMIT_REACHED
class TestWorkflowScheduleCFSPlanEntity:
def test_strategy_values(self):
assert WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice == "time-slice"
assert WorkflowScheduleCFSPlanEntity.Strategy.Nop == "nop"
def test_default_granularity(self):
plan = WorkflowScheduleCFSPlanEntity(
schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.Nop,
)
assert plan.granularity == -1
def test_explicit_granularity(self):
plan = WorkflowScheduleCFSPlanEntity(
schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice,
granularity=100,
)
assert plan.granularity == 100

View File

@ -41,6 +41,23 @@ class TestWorkflowService:
workflows.append(workflow)
return workflows
@pytest.fixture
def dummy_session_cls(self):
class DummySession:
def __init__(self, *args, **kwargs):
self.commit = MagicMock()
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def begin(self):
return nullcontext()
return DummySession
def test_get_all_published_workflow_no_workflow_id(self, workflow_service, mock_app):
mock_app.workflow_id = None
mock_session = MagicMock()
@ -170,7 +187,10 @@ class TestWorkflowService:
mock_session.scalars.assert_called_once()
def test_submit_human_input_form_preview_uses_rendered_content(
self, workflow_service: WorkflowService, monkeypatch: pytest.MonkeyPatch
self,
workflow_service: WorkflowService,
monkeypatch: pytest.MonkeyPatch,
dummy_session_cls,
) -> None:
service = workflow_service
node_data = HumanInputNodeData(
@ -197,19 +217,6 @@ class TestWorkflowService:
saved_outputs: dict[str, object] = {}
class DummySession:
def __init__(self, *args, **kwargs):
self.commit = MagicMock()
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def begin(self):
return nullcontext()
class DummySaver:
def __init__(self, *args, **kwargs):
pass
@ -217,7 +224,7 @@ class TestWorkflowService:
def save(self, outputs, process_data):
saved_outputs.update(outputs)
monkeypatch.setattr(workflow_service_module, "Session", DummySession)
monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls)
monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", DummySaver)
monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock()))
@ -278,7 +285,6 @@ class TestWorkflowService:
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
account = SimpleNamespace(id="account-1")
with pytest.raises(ValueError) as exc_info:
service.submit_human_input_form_preview(
app_model=app_model,
@ -290,3 +296,115 @@ class TestWorkflowService:
)
assert "Missing required inputs" in str(exc_info.value)
def test_run_draft_workflow_node_successful_behavior(
self, workflow_service, mock_app, monkeypatch, dummy_session_cls
):
"""Behavior: When a basic workflow node runs, it correctly sets up context,
executes the node, and saves outputs."""
service = workflow_service
account = SimpleNamespace(id="account-1")
mock_workflow = MagicMock()
mock_workflow.id = "wf-1"
mock_workflow.tenant_id = "tenant-1"
mock_workflow.environment_variables = []
mock_workflow.conversation_variables = []
# Mock node config
mock_workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "llm"}}
mock_workflow.get_enclosing_node_type_and_id.return_value = None
# Mock class methods
monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock())
monkeypatch.setattr(workflow_service_module, "DraftVarLoader", MagicMock())
# Mock workflow entry execution
mock_node_exec = MagicMock()
mock_node_exec.id = "exec-1"
mock_node_exec.process_data = {}
mock_run = MagicMock()
monkeypatch.setattr(workflow_service_module.WorkflowEntry, "single_step_run", mock_run)
# Mock execution handling
service._handle_single_step_result = MagicMock(return_value=mock_node_exec)
# Mock repository
mock_repo = MagicMock()
mock_repo.get_execution_by_id.return_value = mock_node_exec
mock_repo_factory = MagicMock(return_value=mock_repo)
monkeypatch.setattr(
workflow_service_module.DifyCoreRepositoryFactory,
"create_workflow_node_execution_repository",
mock_repo_factory,
)
service._node_execution_service_repo = mock_repo
# Set up node execution service repo mock to return our exec node
mock_node_exec.load_full_outputs.return_value = {"output_var": "result_value"}
mock_node_exec.node_id = "node-1"
mock_node_exec.node_type = "llm"
# Mock draft variable saver
mock_saver = MagicMock()
monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", MagicMock(return_value=mock_saver))
# Mock DB
monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock()))
monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls)
# Act
result = service.run_draft_workflow_node(
app_model=mock_app,
draft_workflow=mock_workflow,
node_id="node-1",
user_inputs={"input_val": "test"},
account=account,
)
# Assert
assert result == mock_node_exec
service._handle_single_step_result.assert_called_once()
mock_repo.save.assert_called_once_with(mock_node_exec)
mock_saver.save.assert_called_once_with(process_data={}, outputs={"output_var": "result_value"})
def test_run_draft_workflow_node_failure_behavior(self, workflow_service, mock_app, monkeypatch, dummy_session_cls):
"""Behavior: If retrieving the saved execution fails, an appropriate error bubble matches expectations."""
service = workflow_service
account = SimpleNamespace(id="account-1")
mock_workflow = MagicMock()
mock_workflow.tenant_id = "tenant-1"
mock_workflow.environment_variables = []
mock_workflow.conversation_variables = []
mock_workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "llm"}}
mock_workflow.get_enclosing_node_type_and_id.return_value = None
monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock())
monkeypatch.setattr(workflow_service_module, "DraftVarLoader", MagicMock())
monkeypatch.setattr(workflow_service_module.WorkflowEntry, "single_step_run", MagicMock())
mock_node_exec = MagicMock()
mock_node_exec.id = "exec-invalid"
service._handle_single_step_result = MagicMock(return_value=mock_node_exec)
mock_repo = MagicMock()
mock_repo_factory = MagicMock(return_value=mock_repo)
monkeypatch.setattr(
workflow_service_module.DifyCoreRepositoryFactory,
"create_workflow_node_execution_repository",
mock_repo_factory,
)
service._node_execution_service_repo = mock_repo
# Simulate failure to retrieve the saved execution
mock_repo.get_execution_by_id.return_value = None
monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock()))
monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls)
# Act & Assert
with pytest.raises(ValueError, match="WorkflowNodeExecution with id exec-invalid not found after saving"):
service.run_draft_workflow_node(
app_model=mock_app, draft_workflow=mock_workflow, node_id="node-1", user_inputs={}, account=account
)