diff --git a/api/tests/unit_tests/services/recommend_app/__init__.py b/api/tests/unit_tests/services/recommend_app/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/services/recommend_app/test_buildin_retrieval.py b/api/tests/unit_tests/services/recommend_app/test_buildin_retrieval.py
new file mode 100644
index 0000000000..770344aa39
--- /dev/null
+++ b/api/tests/unit_tests/services/recommend_app/test_buildin_retrieval.py
@@ -0,0 +1,91 @@
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval
+from services.recommend_app.recommend_app_type import RecommendAppType
+
+SAMPLE_BUILTIN_DATA = {
+ "recommended_apps": {
+ "en-US": {"categories": ["writing"], "apps": [{"id": "app-1"}]},
+ "zh-Hans": {"categories": ["search"], "apps": [{"id": "app-2"}]},
+ },
+ "app_details": {
+ "app-1": {"id": "app-1", "name": "Writer", "mode": "chat"},
+ "app-2": {"id": "app-2", "name": "Searcher", "mode": "workflow"},
+ },
+}
+
+
+@pytest.fixture(autouse=True)
+def _reset_cache():
+ BuildInRecommendAppRetrieval.builtin_data = None
+ yield
+ BuildInRecommendAppRetrieval.builtin_data = None
+
+
+class TestBuildInRecommendAppRetrieval:
+ def test_get_type(self):
+ retrieval = BuildInRecommendAppRetrieval()
+ assert retrieval.get_type() == RecommendAppType.BUILDIN
+
+ def test_get_recommended_apps_and_categories_delegates(self):
+ with patch.object(
+ BuildInRecommendAppRetrieval,
+ "fetch_recommended_apps_from_builtin",
+ return_value={"apps": []},
+ ) as mock_fetch:
+ retrieval = BuildInRecommendAppRetrieval()
+ result = retrieval.get_recommended_apps_and_categories("en-US")
+ mock_fetch.assert_called_once_with("en-US")
+ assert result == {"apps": []}
+
+ def test_get_recommend_app_detail_delegates(self):
+ with patch.object(
+ BuildInRecommendAppRetrieval,
+ "fetch_recommended_app_detail_from_builtin",
+ return_value={"id": "app-1"},
+ ) as mock_fetch:
+ retrieval = BuildInRecommendAppRetrieval()
+ result = retrieval.get_recommend_app_detail("app-1")
+ mock_fetch.assert_called_once_with("app-1")
+ assert result == {"id": "app-1"}
+
+ def test_get_builtin_data_reads_json_and_caches(self, tmp_path):
+ json_file = tmp_path / "constants" / "recommended_apps.json"
+ json_file.parent.mkdir(parents=True)
+ json_file.write_text(json.dumps(SAMPLE_BUILTIN_DATA))
+
+ mock_app = MagicMock()
+ mock_app.root_path = str(tmp_path)
+
+ with patch(
+ "services.recommend_app.buildin.buildin_retrieval.current_app",
+ mock_app,
+ ):
+ first = BuildInRecommendAppRetrieval._get_builtin_data()
+ second = BuildInRecommendAppRetrieval._get_builtin_data()
+
+ assert first == SAMPLE_BUILTIN_DATA
+ assert first is second
+
+ def test_fetch_recommended_apps_from_builtin(self):
+ BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA
+ result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin("en-US")
+ assert result == SAMPLE_BUILTIN_DATA["recommended_apps"]["en-US"]
+
+ def test_fetch_recommended_apps_from_builtin_missing_language(self):
+ BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA
+ result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin("fr-FR")
+ assert result == {}
+
+ def test_fetch_recommended_app_detail_from_builtin(self):
+ BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA
+ result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin("app-1")
+ assert result == {"id": "app-1", "name": "Writer", "mode": "chat"}
+
+ def test_fetch_recommended_app_detail_from_builtin_missing(self):
+ BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA
+ result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin("nonexistent")
+ assert result is None
diff --git a/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py b/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py
new file mode 100644
index 0000000000..5d21665f75
--- /dev/null
+++ b/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py
@@ -0,0 +1,145 @@
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval
+from services.recommend_app.recommend_app_type import RecommendAppType
+
+
+class TestDatabaseRecommendAppRetrieval:
+ def test_get_type(self):
+ assert DatabaseRecommendAppRetrieval().get_type() == RecommendAppType.DATABASE
+
+ def test_get_recommended_apps_delegates(self):
+ with patch.object(
+ DatabaseRecommendAppRetrieval,
+ "fetch_recommended_apps_from_db",
+ return_value={"recommended_apps": [], "categories": []},
+ ) as mock_fetch:
+ result = DatabaseRecommendAppRetrieval().get_recommended_apps_and_categories("en-US")
+ mock_fetch.assert_called_once_with("en-US")
+ assert result == {"recommended_apps": [], "categories": []}
+
+ def test_get_recommend_app_detail_delegates(self):
+ with patch.object(
+ DatabaseRecommendAppRetrieval,
+ "fetch_recommended_app_detail_from_db",
+ return_value={"id": "app-1"},
+ ) as mock_fetch:
+ result = DatabaseRecommendAppRetrieval().get_recommend_app_detail("app-1")
+ mock_fetch.assert_called_once_with("app-1")
+ assert result == {"id": "app-1"}
+
+
+class TestFetchRecommendedAppsFromDb:
+ def _make_recommended_app(self, app_id, category, is_public=True, has_site=True):
+ site = (
+ SimpleNamespace(
+ description="desc",
+ copyright="copy",
+ privacy_policy="pp",
+ custom_disclaimer="cd",
+ )
+ if has_site
+ else None
+ )
+ app = (
+ SimpleNamespace(is_public=is_public, site=site)
+ if is_public
+ else SimpleNamespace(is_public=False, site=site)
+ )
+ return SimpleNamespace(
+ id=f"rec-{app_id}",
+ app=app,
+ app_id=app_id,
+ category=category,
+ position=1,
+ is_listed=True,
+ )
+
+ @patch("services.recommend_app.database.database_retrieval.db")
+ def test_returns_apps_and_sorted_categories(self, mock_db):
+ rec1 = self._make_recommended_app("a1", "writing")
+ rec2 = self._make_recommended_app("a2", "assistant")
+ mock_db.session.scalars.return_value.all.return_value = [rec1, rec2]
+
+ result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US")
+
+ assert len(result["recommended_apps"]) == 2
+ assert result["categories"] == ["assistant", "writing"]
+
+ @patch("services.recommend_app.database.database_retrieval.db")
+ def test_falls_back_to_default_language_when_empty(self, mock_db):
+ mock_db.session.scalars.return_value.all.side_effect = [
+ [],
+ [self._make_recommended_app("a1", "chat")],
+ ]
+
+ result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("fr-FR")
+
+ assert len(result["recommended_apps"]) == 1
+ assert mock_db.session.scalars.call_count == 2
+
+ @patch("services.recommend_app.database.database_retrieval.db")
+ def test_skips_non_public_apps(self, mock_db):
+ rec = self._make_recommended_app("a1", "chat", is_public=False)
+ mock_db.session.scalars.return_value.all.return_value = [rec]
+
+ result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US")
+
+ assert result["recommended_apps"] == []
+
+ @patch("services.recommend_app.database.database_retrieval.db")
+ def test_skips_apps_without_site(self, mock_db):
+ rec = self._make_recommended_app("a1", "chat", has_site=False)
+ mock_db.session.scalars.return_value.all.return_value = [rec]
+
+ result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US")
+
+ assert result["recommended_apps"] == []
+
+
+class TestFetchRecommendedAppDetailFromDb:
+ @patch("services.recommend_app.database.database_retrieval.db")
+ def test_returns_none_when_not_listed(self, mock_db):
+ mock_db.session.query.return_value.where.return_value.first.return_value = None
+
+ result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1")
+
+ assert result is None
+
+ @patch("services.recommend_app.database.database_retrieval.AppDslService")
+ @patch("services.recommend_app.database.database_retrieval.db")
+ def test_returns_none_when_app_not_public(self, mock_db, mock_dsl):
+ rec_chain = MagicMock()
+ rec_chain.where.return_value.first.return_value = SimpleNamespace(app_id="app-1")
+ app_chain = MagicMock()
+ app_chain.where.return_value.first.return_value = SimpleNamespace(id="app-1", is_public=False)
+ mock_db.session.query.side_effect = [rec_chain, app_chain]
+
+ result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1")
+
+ assert result is None
+
+ @patch("services.recommend_app.database.database_retrieval.AppDslService")
+ @patch("services.recommend_app.database.database_retrieval.db")
+ def test_returns_detail_on_success(self, mock_db, mock_dsl):
+ app_model = SimpleNamespace(
+ id="app-1",
+ name="My App",
+ icon="icon.png",
+ icon_background="#fff",
+ mode="chat",
+ is_public=True,
+ )
+ rec_chain = MagicMock()
+ rec_chain.where.return_value.first.return_value = SimpleNamespace(app_id="app-1")
+ app_chain = MagicMock()
+ app_chain.where.return_value.first.return_value = app_model
+ mock_db.session.query.side_effect = [rec_chain, app_chain]
+ mock_dsl.export_dsl.return_value = "exported_yaml"
+
+ result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1")
+
+ assert result["id"] == "app-1"
+ assert result["name"] == "My App"
+ assert result["export_data"] == "exported_yaml"
diff --git a/api/tests/unit_tests/services/recommend_app/test_recommend_app_factory.py b/api/tests/unit_tests/services/recommend_app/test_recommend_app_factory.py
new file mode 100644
index 0000000000..036cba0cc0
--- /dev/null
+++ b/api/tests/unit_tests/services/recommend_app/test_recommend_app_factory.py
@@ -0,0 +1,28 @@
+import pytest
+
+from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval
+from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval
+from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
+from services.recommend_app.remote.remote_retrieval import RemoteRecommendAppRetrieval
+
+
+class TestRecommendAppRetrievalFactory:
+ @pytest.mark.parametrize(
+ ("mode", "expected_class"),
+ [
+ ("remote", RemoteRecommendAppRetrieval),
+ ("builtin", BuildInRecommendAppRetrieval),
+ ("db", DatabaseRecommendAppRetrieval),
+ ],
+ )
+ def test_factory_returns_correct_class(self, mode, expected_class):
+ result = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)
+ assert result is expected_class
+
+ def test_factory_raises_for_unknown_mode(self):
+ with pytest.raises(ValueError, match="invalid fetch recommended apps mode"):
+ RecommendAppRetrievalFactory.get_recommend_app_factory("invalid_mode")
+
+ def test_get_buildin_recommend_app_retrieval(self):
+ result = RecommendAppRetrievalFactory.get_buildin_recommend_app_retrieval()
+ assert result is BuildInRecommendAppRetrieval
diff --git a/api/tests/unit_tests/services/recommend_app/test_recommend_app_type.py b/api/tests/unit_tests/services/recommend_app/test_recommend_app_type.py
new file mode 100644
index 0000000000..08f72a6f77
--- /dev/null
+++ b/api/tests/unit_tests/services/recommend_app/test_recommend_app_type.py
@@ -0,0 +1,18 @@
+from services.recommend_app.recommend_app_type import RecommendAppType
+
+
+def test_enum_values():
+ assert RecommendAppType.REMOTE == "remote"
+ assert RecommendAppType.BUILDIN == "builtin"
+ assert RecommendAppType.DATABASE == "db"
+
+
+def test_enum_membership():
+ assert "remote" in RecommendAppType.__members__.values()
+ assert "builtin" in RecommendAppType.__members__.values()
+ assert "db" in RecommendAppType.__members__.values()
+
+
+def test_enum_is_str():
+ for member in RecommendAppType:
+ assert isinstance(member, str)
diff --git a/api/tests/unit_tests/services/recommend_app/test_remote_retrieval.py b/api/tests/unit_tests/services/recommend_app/test_remote_retrieval.py
new file mode 100644
index 0000000000..e322fbed4c
--- /dev/null
+++ b/api/tests/unit_tests/services/recommend_app/test_remote_retrieval.py
@@ -0,0 +1,120 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from services.recommend_app.recommend_app_type import RecommendAppType
+from services.recommend_app.remote.remote_retrieval import RemoteRecommendAppRetrieval
+
+
+class TestRemoteRecommendAppRetrieval:
+ def test_get_type(self):
+ assert RemoteRecommendAppRetrieval().get_type() == RecommendAppType.REMOTE
+
+ @patch.object(
+ RemoteRecommendAppRetrieval,
+ "fetch_recommended_app_detail_from_dify_official",
+ return_value={"id": "app-1"},
+ )
+ def test_get_recommend_app_detail_success(self, mock_fetch):
+ result = RemoteRecommendAppRetrieval().get_recommend_app_detail("app-1")
+ assert result == {"id": "app-1"}
+ mock_fetch.assert_called_once_with("app-1")
+
+ @patch(
+ "services.recommend_app.remote.remote_retrieval"
+ ".BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin",
+ return_value={"id": "fallback"},
+ )
+ @patch.object(
+ RemoteRecommendAppRetrieval,
+ "fetch_recommended_app_detail_from_dify_official",
+ side_effect=ConnectionError("timeout"),
+ )
+ def test_get_recommend_app_detail_falls_back_on_error(self, mock_fetch, mock_builtin):
+ result = RemoteRecommendAppRetrieval().get_recommend_app_detail("app-1")
+ assert result == {"id": "fallback"}
+ mock_builtin.assert_called_once_with("app-1")
+
+ @patch.object(
+ RemoteRecommendAppRetrieval,
+ "fetch_recommended_apps_from_dify_official",
+ return_value={"recommended_apps": [], "categories": []},
+ )
+ def test_get_recommended_apps_success(self, mock_fetch):
+ result = RemoteRecommendAppRetrieval().get_recommended_apps_and_categories("en-US")
+ assert result == {"recommended_apps": [], "categories": []}
+
+ @patch(
+ "services.recommend_app.remote.remote_retrieval"
+ ".BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin",
+ return_value={"recommended_apps": [{"id": "builtin"}]},
+ )
+ @patch.object(
+ RemoteRecommendAppRetrieval,
+ "fetch_recommended_apps_from_dify_official",
+ side_effect=ValueError("server error"),
+ )
+ def test_get_recommended_apps_falls_back_on_error(self, mock_fetch, mock_builtin):
+ result = RemoteRecommendAppRetrieval().get_recommended_apps_and_categories("en-US")
+ assert result == {"recommended_apps": [{"id": "builtin"}]}
+
+
+class TestFetchFromDifyOfficial:
+ @patch("services.recommend_app.remote.remote_retrieval.dify_config")
+ @patch("services.recommend_app.remote.remote_retrieval.httpx.get")
+ def test_detail_returns_json_on_200(self, mock_get, mock_config):
+ mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
+ mock_response = MagicMock(status_code=200)
+ mock_response.json.return_value = {"id": "app-1", "name": "Test"}
+ mock_get.return_value = mock_response
+
+ result = RemoteRecommendAppRetrieval.fetch_recommended_app_detail_from_dify_official("app-1")
+
+ assert result == {"id": "app-1", "name": "Test"}
+ mock_get.assert_called_once()
+
+ @patch("services.recommend_app.remote.remote_retrieval.dify_config")
+ @patch("services.recommend_app.remote.remote_retrieval.httpx.get")
+ def test_detail_returns_none_on_non_200(self, mock_get, mock_config):
+ mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
+ mock_get.return_value = MagicMock(status_code=404)
+
+ result = RemoteRecommendAppRetrieval.fetch_recommended_app_detail_from_dify_official("app-1")
+
+ assert result is None
+
+ @patch("services.recommend_app.remote.remote_retrieval.dify_config")
+ @patch("services.recommend_app.remote.remote_retrieval.httpx.get")
+ def test_apps_returns_sorted_categories_on_200(self, mock_get, mock_config):
+ mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
+ mock_response = MagicMock(status_code=200)
+ mock_response.json.return_value = {
+ "recommended_apps": [],
+ "categories": ["writing", "agent", "chat"],
+ }
+ mock_get.return_value = mock_response
+
+ result = RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US")
+
+ assert result["categories"] == ["agent", "chat", "writing"]
+
+ @patch("services.recommend_app.remote.remote_retrieval.dify_config")
+ @patch("services.recommend_app.remote.remote_retrieval.httpx.get")
+ def test_apps_raises_on_non_200(self, mock_get, mock_config):
+ mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
+ mock_get.return_value = MagicMock(status_code=500)
+
+ with pytest.raises(ValueError, match="fetch recommended apps failed"):
+ RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US")
+
+ @patch("services.recommend_app.remote.remote_retrieval.dify_config")
+ @patch("services.recommend_app.remote.remote_retrieval.httpx.get")
+ def test_apps_without_categories_key(self, mock_get, mock_config):
+ mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
+ mock_response = MagicMock(status_code=200)
+ mock_response.json.return_value = {"recommended_apps": []}
+ mock_get.return_value = mock_response
+
+ result = RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US")
+
+ assert "categories" not in result
diff --git a/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py
new file mode 100644
index 0000000000..439d203c58
--- /dev/null
+++ b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py
@@ -0,0 +1,455 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from services.tools.builtin_tools_manage_service import BuiltinToolManageService
+
+MODULE = "services.tools.builtin_tools_manage_service"
+
+
+def _mock_session(mock_session_cls):
+ """Helper: set up a Session context manager mock and return the inner session."""
+ session = MagicMock()
+ mock_session_cls.return_value.__enter__ = MagicMock(return_value=session)
+ mock_session_cls.return_value.__exit__ = MagicMock(return_value=False)
+ return session
+
+
+class TestDeleteCustomOauthClientParams:
+ @patch(f"{MODULE}.Session")
+ @patch(f"{MODULE}.db")
+ def test_deletes_and_returns_success(self, mock_db, mock_session_cls):
+ session = _mock_session(mock_session_cls)
+
+ result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google")
+
+ assert result == {"result": "success"}
+ session.query.return_value.filter_by.return_value.delete.assert_called_once()
+ session.commit.assert_called_once()
+
+
+class TestListBuiltinToolProviderTools:
+ @patch(f"{MODULE}.ToolLabelManager")
+ @patch(f"{MODULE}.ToolTransformService")
+ @patch(f"{MODULE}.ToolManager")
+ def test_transforms_each_tool(self, mock_manager, mock_transform, mock_labels):
+ mock_controller = MagicMock()
+ mock_controller.get_tools.return_value = [MagicMock(), MagicMock()]
+ mock_manager.get_builtin_provider.return_value = mock_controller
+ mock_transform.convert_tool_entity_to_api_entity.return_value = MagicMock()
+
+ result = BuiltinToolManageService.list_builtin_tool_provider_tools("tenant-1", "google")
+
+ assert len(result) == 2
+
+ @patch(f"{MODULE}.ToolLabelManager")
+ @patch(f"{MODULE}.ToolTransformService")
+ @patch(f"{MODULE}.ToolManager")
+ def test_empty_tools(self, mock_manager, mock_transform, mock_labels):
+ mock_controller = MagicMock()
+ mock_controller.get_tools.return_value = []
+ mock_manager.get_builtin_provider.return_value = mock_controller
+
+ assert BuiltinToolManageService.list_builtin_tool_provider_tools("t", "p") == []
+
+
+class TestGetBuiltinToolProviderInfo:
+ @patch(f"{MODULE}.ToolTransformService")
+ @patch(f"{MODULE}.BuiltinToolManageService.get_builtin_provider")
+ @patch(f"{MODULE}.ToolManager")
+ def test_raises_when_not_found(self, mock_manager, mock_get, mock_transform):
+ mock_get.return_value = None
+
+ with pytest.raises(ValueError, match="you have not added provider"):
+ BuiltinToolManageService.get_builtin_tool_provider_info("t", "no")
+
+ @patch(f"{MODULE}.ToolTransformService")
+ @patch(f"{MODULE}.BuiltinToolManageService.get_builtin_provider")
+ @patch(f"{MODULE}.ToolManager")
+ def test_clears_original_credentials(self, mock_manager, mock_get, mock_transform):
+ mock_get.return_value = MagicMock()
+ entity = MagicMock()
+ mock_transform.builtin_provider_to_user_provider.return_value = entity
+
+ result = BuiltinToolManageService.get_builtin_tool_provider_info("t", "google")
+
+ assert result.original_credentials == {}
+
+
+class TestListBuiltinProviderCredentialsSchema:
+ @patch(f"{MODULE}.ToolManager")
+ def test_returns_schema(self, mock_manager):
+ mock_manager.get_builtin_provider.return_value.get_credentials_schema_by_type.return_value = [{"f": "k"}]
+
+ result = BuiltinToolManageService.list_builtin_provider_credentials_schema("g", "api_key", "t")
+
+ assert result == [{"f": "k"}]
+
+
+class TestGetBuiltinToolProviderIcon:
+ @patch(f"{MODULE}.Path")
+ @patch(f"{MODULE}.ToolManager")
+ def test_returns_bytes_and_mime(self, mock_manager, mock_path):
+ mock_manager.get_hardcoded_provider_icon.return_value = ("/icon.svg", "image/svg+xml")
+ mock_path.return_value.read_bytes.return_value = b""
+
+ icon, mime = BuiltinToolManageService.get_builtin_tool_provider_icon("google")
+
+ assert icon == b""
+ assert mime == "image/svg+xml"
+
+
+class TestIsOauthSystemClientExists:
+ @patch(f"{MODULE}.Session")
+ @patch(f"{MODULE}.db")
+ def test_true_when_exists(self, mock_db, mock_session_cls):
+ session = _mock_session(mock_session_cls)
+ session.query.return_value.filter_by.return_value.first.return_value = MagicMock()
+
+ assert BuiltinToolManageService.is_oauth_system_client_exists("google") is True
+
+ @patch(f"{MODULE}.Session")
+ @patch(f"{MODULE}.db")
+ def test_false_when_missing(self, mock_db, mock_session_cls):
+ session = _mock_session(mock_session_cls)
+ session.query.return_value.filter_by.return_value.first.return_value = None
+
+ assert BuiltinToolManageService.is_oauth_system_client_exists("google") is False
+
+
+class TestIsOauthCustomClientEnabled:
+ @patch(f"{MODULE}.Session")
+ @patch(f"{MODULE}.db")
+ def test_true_when_enabled(self, mock_db, mock_session_cls):
+ session = _mock_session(mock_session_cls)
+ session.query.return_value.filter_by.return_value.first.return_value = MagicMock(enabled=True)
+
+ assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is True
+
+ @patch(f"{MODULE}.Session")
+ @patch(f"{MODULE}.db")
+ def test_false_when_none(self, mock_db, mock_session_cls):
+ session = _mock_session(mock_session_cls)
+ session.query.return_value.filter_by.return_value.first.return_value = None
+
+ assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is False
+
+
+class TestDeleteBuiltinToolProvider:
+ @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
+ @patch(f"{MODULE}.ToolManager")
+ @patch(f"{MODULE}.Session")
+ @patch(f"{MODULE}.db")
+ def test_raises_when_not_found(self, mock_db, mock_session_cls, mock_tm, mock_enc):
+ session = _mock_session(mock_session_cls)
+ session.query.return_value.where.return_value.first.return_value = None
+
+ with pytest.raises(ValueError, match="you have not added provider"):
+ BuiltinToolManageService.delete_builtin_tool_provider("t", "p", "id")
+
+ @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
+ @patch(f"{MODULE}.ToolManager")
+ @patch(f"{MODULE}.Session")
+ @patch(f"{MODULE}.db")
+ def test_deletes_provider_and_clears_cache(self, mock_db, mock_session_cls, mock_tm, mock_enc):
+ session = _mock_session(mock_session_cls)
+ db_provider = MagicMock()
+ session.query.return_value.where.return_value.first.return_value = db_provider
+ mock_cache = MagicMock()
+ mock_enc.return_value = (MagicMock(), mock_cache)
+
+ result = BuiltinToolManageService.delete_builtin_tool_provider("t", "p", "c")
+
+ assert result == {"result": "success"}
+ session.delete.assert_called_once_with(db_provider)
+ session.commit.assert_called_once()
+ mock_cache.delete.assert_called_once()
+
+
+class TestSetDefaultProvider:
+ @patch(f"{MODULE}.Session")
+ @patch(f"{MODULE}.db")
+ def test_raises_when_not_found(self, mock_db, mock_session_cls):
+ session = _mock_session(mock_session_cls)
+ session.query.return_value.filter_by.return_value.first.return_value = None
+
+ with pytest.raises(ValueError, match="provider not found"):
+ BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
+
+ @patch(f"{MODULE}.Session")
+ @patch(f"{MODULE}.db")
+ def test_sets_default_and_clears_old(self, mock_db, mock_session_cls):
+ session = _mock_session(mock_session_cls)
+ target = MagicMock()
+ session.query.return_value.filter_by.return_value.first.return_value = target
+
+ result = BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
+
+ assert result == {"result": "success"}
+ assert target.is_default is True
+ session.commit.assert_called_once()
+
+
+class TestUpdateBuiltinToolProvider:
+ @patch(f"{MODULE}.Session")
+ @patch(f"{MODULE}.db")
+ def test_raises_when_provider_not_exists(self, mock_db, mock_session_cls):
+ session = _mock_session(mock_session_cls)
+ session.query.return_value.where.return_value.first.return_value = None
+
+ with pytest.raises(ValueError, match="you have not added provider"):
+ BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c")
+
+ @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
+ @patch(f"{MODULE}.CredentialType")
+ @patch(f"{MODULE}.ToolManager")
+ @patch(f"{MODULE}.Session")
+ @patch(f"{MODULE}.db")
+ def test_updates_credentials_and_commits(self, mock_db, mock_session_cls, mock_tm, mock_cred_type, mock_enc):
+ session = _mock_session(mock_session_cls)
+ db_provider = MagicMock(credential_type="api_key", credentials="{}")
+ session.query.return_value.where.return_value.first.return_value = db_provider
+
+ mock_cred_instance = MagicMock()
+ mock_cred_instance.is_editable.return_value = True
+ mock_cred_instance.is_validate_allowed.return_value = False
+ mock_cred_type.of.return_value = mock_cred_instance
+
+ mock_controller = MagicMock(need_credentials=True)
+ mock_tm.get_builtin_provider.return_value = mock_controller
+
+ mock_encrypter = MagicMock()
+ mock_encrypter.decrypt.return_value = {"key": "old"}
+ mock_encrypter.encrypt.return_value = {"key": "new"}
+ mock_cache = MagicMock()
+ mock_enc.return_value = (mock_encrypter, mock_cache)
+
+ result = BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c", credentials={"key": "val"})
+
+ assert result == {"result": "success"}
+ session.commit.assert_called_once()
+ mock_cache.delete.assert_called_once()
+
+
+class TestGetOauthClientSchema:
+ @patch(f"{MODULE}.BuiltinToolManageService.get_custom_oauth_client_params", return_value={})
+ @patch(f"{MODULE}.BuiltinToolManageService.is_oauth_system_client_exists", return_value=False)
+ @patch(f"{MODULE}.BuiltinToolManageService.is_oauth_custom_client_enabled", return_value=True)
+ @patch(f"{MODULE}.dify_config")
+ @patch(f"{MODULE}.PluginService")
+ @patch(f"{MODULE}.ToolManager")
+ def test_returns_schema_dict(self, mock_tm, mock_plugin, mock_config, mock_enabled, mock_sys, mock_params):
+ mock_config.CONSOLE_API_URL = "https://api.example.com"
+ mock_controller = MagicMock()
+ mock_controller.get_oauth_client_schema.return_value = []
+ mock_tm.get_builtin_provider.return_value = mock_controller
+
+ result = BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema("t", "google")
+
+ assert "schema" in result
+ assert result["is_oauth_custom_client_enabled"] is True
+ assert "redirect_uri" in result
+
+
+class TestGetOauthClient:
+ @patch(f"{MODULE}.PluginService")
+ @patch(f"{MODULE}.create_provider_encrypter")
+ @patch(f"{MODULE}.ToolManager")
+ @patch(f"{MODULE}.Session")
+ @patch(f"{MODULE}.db")
+ def test_returns_user_client_params_when_exists(
+ self, mock_db, mock_session_cls, mock_tm, mock_create_enc, mock_plugin
+ ):
+ session = _mock_session(mock_session_cls)
+ mock_controller = MagicMock()
+ mock_controller.get_oauth_client_schema.return_value = []
+ mock_tm.get_builtin_provider.return_value = mock_controller
+
+ mock_encrypter = MagicMock()
+ mock_encrypter.decrypt.return_value = {"client_id": "id", "client_secret": "secret"}
+ mock_create_enc.return_value = (mock_encrypter, MagicMock())
+
+ user_client = MagicMock(oauth_params='{"encrypted": "data"}')
+ session.query.return_value.filter_by.return_value.first.return_value = user_client
+
+ result = BuiltinToolManageService.get_oauth_client("t", "google")
+
+ assert result == {"client_id": "id", "client_secret": "secret"}
+
+ @patch(f"{MODULE}.decrypt_system_oauth_params", return_value={"sys_key": "sys_val"})
+ @patch(f"{MODULE}.PluginService")
+ @patch(f"{MODULE}.create_provider_encrypter")
+ @patch(f"{MODULE}.ToolManager")
+ @patch(f"{MODULE}.Session")
+ @patch(f"{MODULE}.db")
+ def test_falls_back_to_system_client(
+ self, mock_db, mock_session_cls, mock_tm, mock_create_enc, mock_plugin, mock_decrypt
+ ):
+ session = _mock_session(mock_session_cls)
+ mock_controller = MagicMock()
+ mock_controller.get_oauth_client_schema.return_value = []
+ mock_tm.get_builtin_provider.return_value = mock_controller
+
+ mock_create_enc.return_value = (MagicMock(), MagicMock())
+
+ system_client = MagicMock(encrypted_oauth_params="enc")
+ session.query.return_value.filter_by.return_value.first.side_effect = [
+ None, # user client
+ system_client, # system client
+ ]
+
+ result = BuiltinToolManageService.get_oauth_client("t", "google")
+
+ assert result == {"sys_key": "sys_val"}
+
+
+class TestSaveCustomOauthClientParams:
+ def test_returns_early_when_no_params(self):
+ result = BuiltinToolManageService.save_custom_oauth_client_params("t", "p")
+ assert result == {"result": "success"}
+
+ @patch(f"{MODULE}.ToolManager")
+ def test_raises_when_provider_not_found(self, mock_tm):
+ mock_tm.get_builtin_provider.return_value = None
+
+ with pytest.raises((ValueError, Exception), match="not found|Provider"):
+ BuiltinToolManageService.save_custom_oauth_client_params("t", "p", enable_oauth_custom_client=True)
+
+
+class TestGetCustomOauthClientParams:
+ @patch(f"{MODULE}.Session")
+ @patch(f"{MODULE}.db")
+ def test_returns_empty_when_none(self, mock_db, mock_session_cls):
+ session = _mock_session(mock_session_cls)
+ session.query.return_value.filter_by.return_value.first.return_value = None
+
+ result = BuiltinToolManageService.get_custom_oauth_client_params("t", "p")
+
+ assert result == {}
+
+
+class TestGetBuiltinToolProviderCredentialInfo:
+ @patch(f"{MODULE}.BuiltinToolManageService.is_oauth_custom_client_enabled", return_value=False)
+ @patch(f"{MODULE}.BuiltinToolManageService.get_builtin_tool_provider_credentials", return_value=[])
+ @patch(f"{MODULE}.ToolManager")
+ def test_returns_credential_info(self, mock_tm, mock_creds, mock_oauth):
+ mock_tm.get_builtin_provider.return_value.get_supported_credential_types.return_value = ["api-key"]
+
+ result = BuiltinToolManageService.get_builtin_tool_provider_credential_info("t", "google")
+
+ assert result.credentials == []
+ assert result.supported_credential_types == ["api-key"]
+ assert result.is_oauth_custom_client_enabled is False
+
+
+class TestGetBuiltinToolProviderCredentials:
+ @patch(f"{MODULE}.db")
+ def test_returns_empty_when_no_providers(self, mock_db):
+ mock_db.session.no_autoflush.__enter__ = MagicMock(return_value=None)
+ mock_db.session.no_autoflush.__exit__ = MagicMock(return_value=False)
+ mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = []
+
+ result = BuiltinToolManageService.get_builtin_tool_provider_credentials("t", "google")
+
+ assert result == []
+
+ @patch(f"{MODULE}.ToolTransformService")
+ @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
+ @patch(f"{MODULE}.ToolManager")
+ @patch(f"{MODULE}.db")
+ def test_returns_credential_entities(self, mock_db, mock_tm, mock_enc, mock_transform):
+ mock_db.session.no_autoflush.__enter__ = MagicMock(return_value=None)
+ mock_db.session.no_autoflush.__exit__ = MagicMock(return_value=False)
+
+ provider = MagicMock(provider="google", is_default=False)
+ mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [provider]
+
+ mock_encrypter = MagicMock()
+ mock_encrypter.decrypt.return_value = {"key": "decrypted"}
+ mock_encrypter.mask_plugin_credentials.return_value = {"key": "***"}
+ mock_enc.return_value = (mock_encrypter, MagicMock())
+
+ credential_entity = MagicMock()
+ mock_transform.convert_builtin_provider_to_credential_entity.return_value = credential_entity
+
+ result = BuiltinToolManageService.get_builtin_tool_provider_credentials("t", "google")
+
+ assert len(result) == 1
+ assert result[0] is credential_entity
+ assert provider.is_default is True
+
+
+class TestGetBuiltinProvider:
+ @patch(f"{MODULE}.ToolProviderID")
+ @patch(f"{MODULE}.Session")
+ @patch(f"{MODULE}.db")
+ def test_returns_none_when_not_found(self, mock_db, mock_session_cls, mock_prov_id):
+ session = _mock_session(mock_session_cls)
+ mock_prov_id.return_value.provider_name = "google"
+ mock_prov_id.return_value.organization = "langgenius"
+ session.query.return_value.where.return_value.order_by.return_value.first.return_value = None
+
+ result = BuiltinToolManageService.get_builtin_provider("google", "t")
+
+ assert result is None
+
+ @patch(f"{MODULE}.ToolProviderID")
+ @patch(f"{MODULE}.Session")
+ @patch(f"{MODULE}.db")
+ def test_returns_provider_for_langgenius_org(self, mock_db, mock_session_cls, mock_prov_id):
+ session = _mock_session(mock_session_cls)
+ mock_prov_id.return_value.provider_name = "google"
+ mock_prov_id.return_value.organization = "langgenius"
+ db_provider = MagicMock(provider="google")
+ mock_prov_id_result = MagicMock()
+ mock_prov_id_result.to_string.return_value = "langgenius/google/google"
+
+ def prov_id_side_effect(name):
+ m = MagicMock()
+ m.provider_name = "google"
+ m.organization = "langgenius"
+ m.to_string.return_value = "langgenius/google/google"
+ m.plugin_id = "langgenius/google"
+ return m
+
+ mock_prov_id.side_effect = prov_id_side_effect
+ session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider
+
+ result = BuiltinToolManageService.get_builtin_provider("google", "t")
+
+ assert result is db_provider
+
+ @patch(f"{MODULE}.ToolProviderID")
+ @patch(f"{MODULE}.Session")
+ @patch(f"{MODULE}.db")
+ def test_returns_provider_for_non_langgenius_org(self, mock_db, mock_session_cls, mock_prov_id):
+ session = _mock_session(mock_session_cls)
+
+ def prov_id_side_effect(name):
+ m = MagicMock()
+ m.provider_name = "custom-tool"
+ m.organization = "third-party"
+ m.to_string.return_value = "third-party/custom/custom-tool"
+ m.plugin_id = "third-party/custom"
+ return m
+
+ mock_prov_id.side_effect = prov_id_side_effect
+ db_provider = MagicMock(provider="third-party/custom/custom-tool")
+ session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider
+
+ result = BuiltinToolManageService.get_builtin_provider("third-party/custom/custom-tool", "t")
+
+ assert result is db_provider
+
+ @patch(f"{MODULE}.ToolProviderID")
+ @patch(f"{MODULE}.Session")
+ @patch(f"{MODULE}.db")
+ def test_falls_back_on_exception(self, mock_db, mock_session_cls, mock_prov_id):
+ session = _mock_session(mock_session_cls)
+ mock_prov_id.side_effect = Exception("parse error")
+ fallback = MagicMock()
+ session.query.return_value.where.return_value.order_by.return_value.first.return_value = fallback
+
+ result = BuiltinToolManageService.get_builtin_provider("old-provider", "t")
+
+ assert result is fallback
diff --git a/api/tests/unit_tests/services/tools/test_tool_labels_service.py b/api/tests/unit_tests/services/tools/test_tool_labels_service.py
new file mode 100644
index 0000000000..6acdbb7901
--- /dev/null
+++ b/api/tests/unit_tests/services/tools/test_tool_labels_service.py
@@ -0,0 +1,21 @@
+from services.tools.tool_labels_service import ToolLabelsService
+
+
+def test_list_tool_labels_returns_default_labels():
+ result = ToolLabelsService.list_tool_labels()
+ assert isinstance(result, list)
+ assert len(result) > 0
+
+
+def test_list_tool_labels_items_are_tool_labels():
+ from core.tools.entities.tool_entities import ToolLabel
+
+ result = ToolLabelsService.list_tool_labels()
+ for label in result:
+ assert isinstance(label, ToolLabel)
+
+
+def test_list_tool_labels_matches_default_values():
+ from core.tools.entities.values import default_tool_labels
+
+ assert ToolLabelsService.list_tool_labels() is default_tool_labels
diff --git a/api/tests/unit_tests/services/tools/test_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_tools_manage_service.py
new file mode 100644
index 0000000000..73ac9a10c6
--- /dev/null
+++ b/api/tests/unit_tests/services/tools/test_tools_manage_service.py
@@ -0,0 +1,40 @@
+from unittest.mock import MagicMock, patch
+
+from services.tools.tools_manage_service import ToolCommonService
+
+
+class TestToolCommonService:
+ @patch("services.tools.tools_manage_service.ToolTransformService")
+ @patch("services.tools.tools_manage_service.ToolManager")
+ def test_list_tool_providers_transforms_and_returns(self, mock_manager, mock_transform):
+ mock_provider1 = MagicMock()
+ mock_provider1.to_dict.return_value = {"name": "provider1"}
+ mock_provider2 = MagicMock()
+ mock_provider2.to_dict.return_value = {"name": "provider2"}
+ mock_manager.list_providers_from_api.return_value = [mock_provider1, mock_provider2]
+
+ result = ToolCommonService.list_tool_providers("user-1", "tenant-1")
+
+ mock_manager.list_providers_from_api.assert_called_once_with("user-1", "tenant-1", None)
+ assert mock_transform.repack_provider.call_count == 2
+ assert result == [{"name": "provider1"}, {"name": "provider2"}]
+
+ @patch("services.tools.tools_manage_service.ToolTransformService")
+ @patch("services.tools.tools_manage_service.ToolManager")
+ def test_list_tool_providers_with_type_filter(self, mock_manager, mock_transform):
+ mock_manager.list_providers_from_api.return_value = []
+
+ result = ToolCommonService.list_tool_providers("user-1", "tenant-1", typ="builtin")
+
+ mock_manager.list_providers_from_api.assert_called_once_with("user-1", "tenant-1", "builtin")
+ assert result == []
+
+ @patch("services.tools.tools_manage_service.ToolTransformService")
+ @patch("services.tools.tools_manage_service.ToolManager")
+ def test_list_tool_providers_empty(self, mock_manager, mock_transform):
+ mock_manager.list_providers_from_api.return_value = []
+
+ result = ToolCommonService.list_tool_providers("u", "t")
+
+ assert result == []
+ mock_transform.repack_provider.assert_not_called()
diff --git a/api/tests/unit_tests/services/workflow/test_queue_dispatcher.py b/api/tests/unit_tests/services/workflow/test_queue_dispatcher.py
new file mode 100644
index 0000000000..bbfc1cc294
--- /dev/null
+++ b/api/tests/unit_tests/services/workflow/test_queue_dispatcher.py
@@ -0,0 +1,110 @@
+from unittest.mock import patch
+
+import pytest
+
+from services.workflow.queue_dispatcher import (
+ BaseQueueDispatcher,
+ ProfessionalQueueDispatcher,
+ QueueDispatcherManager,
+ QueuePriority,
+ SandboxQueueDispatcher,
+ TeamQueueDispatcher,
+)
+
+
+class TestQueuePriority:
+ def test_priority_values(self):
+ assert QueuePriority.PROFESSIONAL == "workflow_professional"
+ assert QueuePriority.TEAM == "workflow_team"
+ assert QueuePriority.SANDBOX == "workflow_sandbox"
+
+
+class TestDispatchers:
+ def test_professional_dispatcher(self):
+ d = ProfessionalQueueDispatcher()
+ assert d.get_queue_name() == QueuePriority.PROFESSIONAL
+ assert d.get_priority() == 100
+
+ def test_team_dispatcher(self):
+ d = TeamQueueDispatcher()
+ assert d.get_queue_name() == QueuePriority.TEAM
+ assert d.get_priority() == 50
+
+ def test_sandbox_dispatcher(self):
+ d = SandboxQueueDispatcher()
+ assert d.get_queue_name() == QueuePriority.SANDBOX
+ assert d.get_priority() == 10
+
+ def test_base_dispatcher_is_abstract(self):
+ with pytest.raises(TypeError):
+ BaseQueueDispatcher()
+
+
+class TestQueueDispatcherManager:
+ @patch("services.workflow.queue_dispatcher.BillingService")
+ @patch("services.workflow.queue_dispatcher.dify_config")
+ def test_billing_enabled_professional_plan(self, mock_config, mock_billing):
+ mock_config.BILLING_ENABLED = True
+ mock_billing.get_info.return_value = {"subscription": {"plan": "professional"}}
+
+ dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
+
+ assert isinstance(dispatcher, ProfessionalQueueDispatcher)
+
+ @patch("services.workflow.queue_dispatcher.BillingService")
+ @patch("services.workflow.queue_dispatcher.dify_config")
+ def test_billing_enabled_team_plan(self, mock_config, mock_billing):
+ mock_config.BILLING_ENABLED = True
+ mock_billing.get_info.return_value = {"subscription": {"plan": "team"}}
+
+ dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
+
+ assert isinstance(dispatcher, TeamQueueDispatcher)
+
+ @patch("services.workflow.queue_dispatcher.BillingService")
+ @patch("services.workflow.queue_dispatcher.dify_config")
+ def test_billing_enabled_sandbox_plan(self, mock_config, mock_billing):
+ mock_config.BILLING_ENABLED = True
+ mock_billing.get_info.return_value = {"subscription": {"plan": "sandbox"}}
+
+ dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
+
+ assert isinstance(dispatcher, SandboxQueueDispatcher)
+
+ @patch("services.workflow.queue_dispatcher.BillingService")
+ @patch("services.workflow.queue_dispatcher.dify_config")
+ def test_billing_enabled_unknown_plan_defaults_to_sandbox(self, mock_config, mock_billing):
+ mock_config.BILLING_ENABLED = True
+ mock_billing.get_info.return_value = {"subscription": {"plan": "enterprise"}}
+
+ dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
+
+ assert isinstance(dispatcher, SandboxQueueDispatcher)
+
+ @patch("services.workflow.queue_dispatcher.BillingService")
+ @patch("services.workflow.queue_dispatcher.dify_config")
+ def test_billing_enabled_service_failure_defaults_to_sandbox(self, mock_config, mock_billing):
+ mock_config.BILLING_ENABLED = True
+ mock_billing.get_info.side_effect = Exception("billing unavailable")
+
+ dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
+
+ assert isinstance(dispatcher, SandboxQueueDispatcher)
+
+ @patch("services.workflow.queue_dispatcher.dify_config")
+ def test_billing_disabled_defaults_to_team(self, mock_config):
+ mock_config.BILLING_ENABLED = False
+
+ dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
+
+ assert isinstance(dispatcher, TeamQueueDispatcher)
+
+ @patch("services.workflow.queue_dispatcher.BillingService")
+ @patch("services.workflow.queue_dispatcher.dify_config")
+ def test_missing_subscription_key_defaults_to_sandbox(self, mock_config, mock_billing):
+ mock_config.BILLING_ENABLED = True
+ mock_billing.get_info.return_value = {}
+
+ dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
+
+ assert isinstance(dispatcher, SandboxQueueDispatcher)
diff --git a/api/tests/unit_tests/services/workflow/test_scheduler.py b/api/tests/unit_tests/services/workflow/test_scheduler.py
new file mode 100644
index 0000000000..90b6cb2d8b
--- /dev/null
+++ b/api/tests/unit_tests/services/workflow/test_scheduler.py
@@ -0,0 +1,89 @@
+import pytest
+
+from services.workflow.entities import WorkflowScheduleCFSPlanEntity
+from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand
+
+
+class TestSchedulerCommand:
+ def test_enum_values(self):
+ assert SchedulerCommand.RESOURCE_LIMIT_REACHED == "resource_limit_reached"
+ assert SchedulerCommand.NONE == "none"
+
+ def test_enum_is_str(self):
+ for member in SchedulerCommand:
+ assert isinstance(member, str)
+
+
+class TestCFSPlanScheduler:
+ def test_stores_plan(self):
+ plan = WorkflowScheduleCFSPlanEntity(
+ schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.Nop,
+ granularity=-1,
+ )
+
+ class ConcretePlanScheduler(CFSPlanScheduler):
+ def can_schedule(self):
+ return SchedulerCommand.NONE
+
+ scheduler = ConcretePlanScheduler(plan)
+
+ assert scheduler.plan is plan
+ assert scheduler.plan.schedule_strategy == WorkflowScheduleCFSPlanEntity.Strategy.Nop
+ assert scheduler.plan.granularity == -1
+
+ def test_cannot_instantiate_abstract(self):
+ plan = WorkflowScheduleCFSPlanEntity(
+ schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice,
+ granularity=10,
+ )
+ with pytest.raises(TypeError):
+ CFSPlanScheduler(plan)
+
+ def test_concrete_subclass_can_schedule(self):
+ plan = WorkflowScheduleCFSPlanEntity(
+ schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice,
+ granularity=5,
+ )
+
+ class TimedScheduler(CFSPlanScheduler):
+ def can_schedule(self):
+ if self.plan.granularity > 0:
+ return SchedulerCommand.NONE
+ return SchedulerCommand.RESOURCE_LIMIT_REACHED
+
+ scheduler = TimedScheduler(plan)
+ assert scheduler.can_schedule() == SchedulerCommand.NONE
+
+ def test_concrete_subclass_resource_limit(self):
+ plan = WorkflowScheduleCFSPlanEntity(
+ schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice,
+ granularity=-1,
+ )
+
+ class TimedScheduler(CFSPlanScheduler):
+ def can_schedule(self):
+ if self.plan.granularity > 0:
+ return SchedulerCommand.NONE
+ return SchedulerCommand.RESOURCE_LIMIT_REACHED
+
+ scheduler = TimedScheduler(plan)
+ assert scheduler.can_schedule() == SchedulerCommand.RESOURCE_LIMIT_REACHED
+
+
+class TestWorkflowScheduleCFSPlanEntity:
+ def test_strategy_values(self):
+ assert WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice == "time-slice"
+ assert WorkflowScheduleCFSPlanEntity.Strategy.Nop == "nop"
+
+ def test_default_granularity(self):
+ plan = WorkflowScheduleCFSPlanEntity(
+ schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.Nop,
+ )
+ assert plan.granularity == -1
+
+ def test_explicit_granularity(self):
+ plan = WorkflowScheduleCFSPlanEntity(
+ schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice,
+ granularity=100,
+ )
+ assert plan.granularity == 100
diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py
index e428c603a4..3953248c47 100644
--- a/api/tests/unit_tests/services/workflow/test_workflow_service.py
+++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py
@@ -41,6 +41,23 @@ class TestWorkflowService:
workflows.append(workflow)
return workflows
+ @pytest.fixture
+ def dummy_session_cls(self):
+ class DummySession:
+ def __init__(self, *args, **kwargs):
+ self.commit = MagicMock()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc, tb):
+ return False
+
+ def begin(self):
+ return nullcontext()
+
+ return DummySession
+
def test_get_all_published_workflow_no_workflow_id(self, workflow_service, mock_app):
mock_app.workflow_id = None
mock_session = MagicMock()
@@ -170,7 +187,10 @@ class TestWorkflowService:
mock_session.scalars.assert_called_once()
def test_submit_human_input_form_preview_uses_rendered_content(
- self, workflow_service: WorkflowService, monkeypatch: pytest.MonkeyPatch
+ self,
+ workflow_service: WorkflowService,
+ monkeypatch: pytest.MonkeyPatch,
+ dummy_session_cls,
) -> None:
service = workflow_service
node_data = HumanInputNodeData(
@@ -197,19 +217,6 @@ class TestWorkflowService:
saved_outputs: dict[str, object] = {}
- class DummySession:
- def __init__(self, *args, **kwargs):
- self.commit = MagicMock()
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc, tb):
- return False
-
- def begin(self):
- return nullcontext()
-
class DummySaver:
def __init__(self, *args, **kwargs):
pass
@@ -217,7 +224,7 @@ class TestWorkflowService:
def save(self, outputs, process_data):
saved_outputs.update(outputs)
- monkeypatch.setattr(workflow_service_module, "Session", DummySession)
+ monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls)
monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", DummySaver)
monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock()))
@@ -278,7 +285,6 @@ class TestWorkflowService:
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
account = SimpleNamespace(id="account-1")
-
with pytest.raises(ValueError) as exc_info:
service.submit_human_input_form_preview(
app_model=app_model,
@@ -290,3 +296,115 @@ class TestWorkflowService:
)
assert "Missing required inputs" in str(exc_info.value)
+
+ def test_run_draft_workflow_node_successful_behavior(
+ self, workflow_service, mock_app, monkeypatch, dummy_session_cls
+ ):
+ """Behavior: When a basic workflow node runs, it correctly sets up context,
+ executes the node, and saves outputs."""
+ service = workflow_service
+ account = SimpleNamespace(id="account-1")
+ mock_workflow = MagicMock()
+ mock_workflow.id = "wf-1"
+ mock_workflow.tenant_id = "tenant-1"
+ mock_workflow.environment_variables = []
+ mock_workflow.conversation_variables = []
+
+ # Mock node config
+ mock_workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "llm"}}
+ mock_workflow.get_enclosing_node_type_and_id.return_value = None
+
+ # Mock class methods
+ monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock())
+ monkeypatch.setattr(workflow_service_module, "DraftVarLoader", MagicMock())
+
+ # Mock workflow entry execution
+ mock_node_exec = MagicMock()
+ mock_node_exec.id = "exec-1"
+ mock_node_exec.process_data = {}
+ mock_run = MagicMock()
+ monkeypatch.setattr(workflow_service_module.WorkflowEntry, "single_step_run", mock_run)
+
+ # Mock execution handling
+ service._handle_single_step_result = MagicMock(return_value=mock_node_exec)
+
+ # Mock repository
+ mock_repo = MagicMock()
+ mock_repo.get_execution_by_id.return_value = mock_node_exec
+ mock_repo_factory = MagicMock(return_value=mock_repo)
+ monkeypatch.setattr(
+ workflow_service_module.DifyCoreRepositoryFactory,
+ "create_workflow_node_execution_repository",
+ mock_repo_factory,
+ )
+ service._node_execution_service_repo = mock_repo
+
+ # Set up node execution service repo mock to return our exec node
+ mock_node_exec.load_full_outputs.return_value = {"output_var": "result_value"}
+ mock_node_exec.node_id = "node-1"
+ mock_node_exec.node_type = "llm"
+
+ # Mock draft variable saver
+ mock_saver = MagicMock()
+ monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", MagicMock(return_value=mock_saver))
+
+ # Mock DB
+ monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock()))
+
+ monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls)
+
+ # Act
+ result = service.run_draft_workflow_node(
+ app_model=mock_app,
+ draft_workflow=mock_workflow,
+ node_id="node-1",
+ user_inputs={"input_val": "test"},
+ account=account,
+ )
+
+ # Assert
+ assert result == mock_node_exec
+ service._handle_single_step_result.assert_called_once()
+ mock_repo.save.assert_called_once_with(mock_node_exec)
+ mock_saver.save.assert_called_once_with(process_data={}, outputs={"output_var": "result_value"})
+
+ def test_run_draft_workflow_node_failure_behavior(self, workflow_service, mock_app, monkeypatch, dummy_session_cls):
+ """Behavior: If retrieving the saved execution fails, an appropriate error bubble matches expectations."""
+ service = workflow_service
+ account = SimpleNamespace(id="account-1")
+ mock_workflow = MagicMock()
+ mock_workflow.tenant_id = "tenant-1"
+ mock_workflow.environment_variables = []
+ mock_workflow.conversation_variables = []
+ mock_workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "llm"}}
+ mock_workflow.get_enclosing_node_type_and_id.return_value = None
+
+ monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock())
+ monkeypatch.setattr(workflow_service_module, "DraftVarLoader", MagicMock())
+ monkeypatch.setattr(workflow_service_module.WorkflowEntry, "single_step_run", MagicMock())
+
+ mock_node_exec = MagicMock()
+ mock_node_exec.id = "exec-invalid"
+ service._handle_single_step_result = MagicMock(return_value=mock_node_exec)
+
+ mock_repo = MagicMock()
+ mock_repo_factory = MagicMock(return_value=mock_repo)
+ monkeypatch.setattr(
+ workflow_service_module.DifyCoreRepositoryFactory,
+ "create_workflow_node_execution_repository",
+ mock_repo_factory,
+ )
+ service._node_execution_service_repo = mock_repo
+
+ # Simulate failure to retrieve the saved execution
+ mock_repo.get_execution_by_id.return_value = None
+
+ monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock()))
+
+ monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls)
+
+ # Act & Assert
+ with pytest.raises(ValueError, match="WorkflowNodeExecution with id exec-invalid not found after saving"):
+ service.run_draft_workflow_node(
+ app_model=mock_app, draft_workflow=mock_workflow, node_id="node-1", user_inputs={}, account=account
+ )