diff --git a/api/tests/unit_tests/services/recommend_app/__init__.py b/api/tests/unit_tests/services/recommend_app/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/services/recommend_app/test_buildin_retrieval.py b/api/tests/unit_tests/services/recommend_app/test_buildin_retrieval.py new file mode 100644 index 0000000000..770344aa39 --- /dev/null +++ b/api/tests/unit_tests/services/recommend_app/test_buildin_retrieval.py @@ -0,0 +1,91 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest + +from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval +from services.recommend_app.recommend_app_type import RecommendAppType + +SAMPLE_BUILTIN_DATA = { + "recommended_apps": { + "en-US": {"categories": ["writing"], "apps": [{"id": "app-1"}]}, + "zh-Hans": {"categories": ["search"], "apps": [{"id": "app-2"}]}, + }, + "app_details": { + "app-1": {"id": "app-1", "name": "Writer", "mode": "chat"}, + "app-2": {"id": "app-2", "name": "Searcher", "mode": "workflow"}, + }, +} + + +@pytest.fixture(autouse=True) +def _reset_cache(): + BuildInRecommendAppRetrieval.builtin_data = None + yield + BuildInRecommendAppRetrieval.builtin_data = None + + +class TestBuildInRecommendAppRetrieval: + def test_get_type(self): + retrieval = BuildInRecommendAppRetrieval() + assert retrieval.get_type() == RecommendAppType.BUILDIN + + def test_get_recommended_apps_and_categories_delegates(self): + with patch.object( + BuildInRecommendAppRetrieval, + "fetch_recommended_apps_from_builtin", + return_value={"apps": []}, + ) as mock_fetch: + retrieval = BuildInRecommendAppRetrieval() + result = retrieval.get_recommended_apps_and_categories("en-US") + mock_fetch.assert_called_once_with("en-US") + assert result == {"apps": []} + + def test_get_recommend_app_detail_delegates(self): + with patch.object( + BuildInRecommendAppRetrieval, + "fetch_recommended_app_detail_from_builtin", + return_value={"id": "app-1"}, + ) as mock_fetch: + retrieval = BuildInRecommendAppRetrieval() + result = retrieval.get_recommend_app_detail("app-1") + mock_fetch.assert_called_once_with("app-1") + assert result == {"id": "app-1"} + + def test_get_builtin_data_reads_json_and_caches(self, tmp_path): + json_file = tmp_path / "constants" / "recommended_apps.json" + json_file.parent.mkdir(parents=True) + json_file.write_text(json.dumps(SAMPLE_BUILTIN_DATA)) + + mock_app = MagicMock() + mock_app.root_path = str(tmp_path) + + with patch( + "services.recommend_app.buildin.buildin_retrieval.current_app", + mock_app, + ): + first = BuildInRecommendAppRetrieval._get_builtin_data() + second = BuildInRecommendAppRetrieval._get_builtin_data() + + assert first == SAMPLE_BUILTIN_DATA + assert first is second + + def test_fetch_recommended_apps_from_builtin(self): + BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA + result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin("en-US") + assert result == SAMPLE_BUILTIN_DATA["recommended_apps"]["en-US"] + + def test_fetch_recommended_apps_from_builtin_missing_language(self): + BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA + result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin("fr-FR") + assert result == {} + + def test_fetch_recommended_app_detail_from_builtin(self): + BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA + result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin("app-1") + assert result == {"id": "app-1", "name": "Writer", "mode": "chat"} + + def test_fetch_recommended_app_detail_from_builtin_missing(self): + BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA + result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin("nonexistent") + assert result is None diff --git a/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py b/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py new file mode 100644 index 0000000000..5d21665f75 --- /dev/null +++ b/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py @@ -0,0 +1,145 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval +from services.recommend_app.recommend_app_type import RecommendAppType + + +class TestDatabaseRecommendAppRetrieval: + def test_get_type(self): + assert DatabaseRecommendAppRetrieval().get_type() == RecommendAppType.DATABASE + + def test_get_recommended_apps_delegates(self): + with patch.object( + DatabaseRecommendAppRetrieval, + "fetch_recommended_apps_from_db", + return_value={"recommended_apps": [], "categories": []}, + ) as mock_fetch: + result = DatabaseRecommendAppRetrieval().get_recommended_apps_and_categories("en-US") + mock_fetch.assert_called_once_with("en-US") + assert result == {"recommended_apps": [], "categories": []} + + def test_get_recommend_app_detail_delegates(self): + with patch.object( + DatabaseRecommendAppRetrieval, + "fetch_recommended_app_detail_from_db", + return_value={"id": "app-1"}, + ) as mock_fetch: + result = DatabaseRecommendAppRetrieval().get_recommend_app_detail("app-1") + mock_fetch.assert_called_once_with("app-1") + assert result == {"id": "app-1"} + + +class TestFetchRecommendedAppsFromDb: + def _make_recommended_app(self, app_id, category, is_public=True, has_site=True): + site = ( + SimpleNamespace( + description="desc", + copyright="copy", + privacy_policy="pp", + custom_disclaimer="cd", + ) + if has_site + else None + ) + app = ( + SimpleNamespace(is_public=is_public, site=site) + if is_public + else SimpleNamespace(is_public=False, site=site) + ) + return SimpleNamespace( + id=f"rec-{app_id}", + app=app, + app_id=app_id, + category=category, + position=1, + is_listed=True, + ) + + @patch("services.recommend_app.database.database_retrieval.db") + def test_returns_apps_and_sorted_categories(self, mock_db): + rec1 = self._make_recommended_app("a1", "writing") + rec2 = self._make_recommended_app("a2", "assistant") + mock_db.session.scalars.return_value.all.return_value = [rec1, rec2] + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + assert len(result["recommended_apps"]) == 2 + assert result["categories"] == ["assistant", "writing"] + + @patch("services.recommend_app.database.database_retrieval.db") + def test_falls_back_to_default_language_when_empty(self, mock_db): + mock_db.session.scalars.return_value.all.side_effect = [ + [], + [self._make_recommended_app("a1", "chat")], + ] + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("fr-FR") + + assert len(result["recommended_apps"]) == 1 + assert mock_db.session.scalars.call_count == 2 + + @patch("services.recommend_app.database.database_retrieval.db") + def test_skips_non_public_apps(self, mock_db): + rec = self._make_recommended_app("a1", "chat", is_public=False) + mock_db.session.scalars.return_value.all.return_value = [rec] + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + assert result["recommended_apps"] == [] + + @patch("services.recommend_app.database.database_retrieval.db") + def test_skips_apps_without_site(self, mock_db): + rec = self._make_recommended_app("a1", "chat", has_site=False) + mock_db.session.scalars.return_value.all.return_value = [rec] + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + assert result["recommended_apps"] == [] + + +class TestFetchRecommendedAppDetailFromDb: + @patch("services.recommend_app.database.database_retrieval.db") + def test_returns_none_when_not_listed(self, mock_db): + mock_db.session.query.return_value.where.return_value.first.return_value = None + + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1") + + assert result is None + + @patch("services.recommend_app.database.database_retrieval.AppDslService") + @patch("services.recommend_app.database.database_retrieval.db") + def test_returns_none_when_app_not_public(self, mock_db, mock_dsl): + rec_chain = MagicMock() + rec_chain.where.return_value.first.return_value = SimpleNamespace(app_id="app-1") + app_chain = MagicMock() + app_chain.where.return_value.first.return_value = SimpleNamespace(id="app-1", is_public=False) + mock_db.session.query.side_effect = [rec_chain, app_chain] + + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1") + + assert result is None + + @patch("services.recommend_app.database.database_retrieval.AppDslService") + @patch("services.recommend_app.database.database_retrieval.db") + def test_returns_detail_on_success(self, mock_db, mock_dsl): + app_model = SimpleNamespace( + id="app-1", + name="My App", + icon="icon.png", + icon_background="#fff", + mode="chat", + is_public=True, + ) + rec_chain = MagicMock() + rec_chain.where.return_value.first.return_value = SimpleNamespace(app_id="app-1") + app_chain = MagicMock() + app_chain.where.return_value.first.return_value = app_model + mock_db.session.query.side_effect = [rec_chain, app_chain] + mock_dsl.export_dsl.return_value = "exported_yaml" + + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1") + + assert result["id"] == "app-1" + assert result["name"] == "My App" + assert result["export_data"] == "exported_yaml" diff --git a/api/tests/unit_tests/services/recommend_app/test_recommend_app_factory.py b/api/tests/unit_tests/services/recommend_app/test_recommend_app_factory.py new file mode 100644 index 0000000000..036cba0cc0 --- /dev/null +++ b/api/tests/unit_tests/services/recommend_app/test_recommend_app_factory.py @@ -0,0 +1,28 @@ +import pytest + +from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval +from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval +from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory +from services.recommend_app.remote.remote_retrieval import RemoteRecommendAppRetrieval + + +class TestRecommendAppRetrievalFactory: + @pytest.mark.parametrize( + ("mode", "expected_class"), + [ + ("remote", RemoteRecommendAppRetrieval), + ("builtin", BuildInRecommendAppRetrieval), + ("db", DatabaseRecommendAppRetrieval), + ], + ) + def test_factory_returns_correct_class(self, mode, expected_class): + result = RecommendAppRetrievalFactory.get_recommend_app_factory(mode) + assert result is expected_class + + def test_factory_raises_for_unknown_mode(self): + with pytest.raises(ValueError, match="invalid fetch recommended apps mode"): + RecommendAppRetrievalFactory.get_recommend_app_factory("invalid_mode") + + def test_get_buildin_recommend_app_retrieval(self): + result = RecommendAppRetrievalFactory.get_buildin_recommend_app_retrieval() + assert result is BuildInRecommendAppRetrieval diff --git a/api/tests/unit_tests/services/recommend_app/test_recommend_app_type.py b/api/tests/unit_tests/services/recommend_app/test_recommend_app_type.py new file mode 100644 index 0000000000..08f72a6f77 --- /dev/null +++ b/api/tests/unit_tests/services/recommend_app/test_recommend_app_type.py @@ -0,0 +1,18 @@ +from services.recommend_app.recommend_app_type import RecommendAppType + + +def test_enum_values(): + assert RecommendAppType.REMOTE == "remote" + assert RecommendAppType.BUILDIN == "builtin" + assert RecommendAppType.DATABASE == "db" + + +def test_enum_membership(): + assert "remote" in RecommendAppType.__members__.values() + assert "builtin" in RecommendAppType.__members__.values() + assert "db" in RecommendAppType.__members__.values() + + +def test_enum_is_str(): + for member in RecommendAppType: + assert isinstance(member, str) diff --git a/api/tests/unit_tests/services/recommend_app/test_remote_retrieval.py b/api/tests/unit_tests/services/recommend_app/test_remote_retrieval.py new file mode 100644 index 0000000000..e322fbed4c --- /dev/null +++ b/api/tests/unit_tests/services/recommend_app/test_remote_retrieval.py @@ -0,0 +1,120 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from services.recommend_app.recommend_app_type import RecommendAppType +from services.recommend_app.remote.remote_retrieval import RemoteRecommendAppRetrieval + + +class TestRemoteRecommendAppRetrieval: + def test_get_type(self): + assert RemoteRecommendAppRetrieval().get_type() == RecommendAppType.REMOTE + + @patch.object( + RemoteRecommendAppRetrieval, + "fetch_recommended_app_detail_from_dify_official", + return_value={"id": "app-1"}, + ) + def test_get_recommend_app_detail_success(self, mock_fetch): + result = RemoteRecommendAppRetrieval().get_recommend_app_detail("app-1") + assert result == {"id": "app-1"} + mock_fetch.assert_called_once_with("app-1") + + @patch( + "services.recommend_app.remote.remote_retrieval" + ".BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin", + return_value={"id": "fallback"}, + ) + @patch.object( + RemoteRecommendAppRetrieval, + "fetch_recommended_app_detail_from_dify_official", + side_effect=ConnectionError("timeout"), + ) + def test_get_recommend_app_detail_falls_back_on_error(self, mock_fetch, mock_builtin): + result = RemoteRecommendAppRetrieval().get_recommend_app_detail("app-1") + assert result == {"id": "fallback"} + mock_builtin.assert_called_once_with("app-1") + + @patch.object( + RemoteRecommendAppRetrieval, + "fetch_recommended_apps_from_dify_official", + return_value={"recommended_apps": [], "categories": []}, + ) + def test_get_recommended_apps_success(self, mock_fetch): + result = RemoteRecommendAppRetrieval().get_recommended_apps_and_categories("en-US") + assert result == {"recommended_apps": [], "categories": []} + + @patch( + "services.recommend_app.remote.remote_retrieval" + ".BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin", + return_value={"recommended_apps": [{"id": "builtin"}]}, + ) + @patch.object( + RemoteRecommendAppRetrieval, + "fetch_recommended_apps_from_dify_official", + side_effect=ValueError("server error"), + ) + def test_get_recommended_apps_falls_back_on_error(self, mock_fetch, mock_builtin): + result = RemoteRecommendAppRetrieval().get_recommended_apps_and_categories("en-US") + assert result == {"recommended_apps": [{"id": "builtin"}]} + + +class TestFetchFromDifyOfficial: + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_detail_returns_json_on_200(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_response = MagicMock(status_code=200) + mock_response.json.return_value = {"id": "app-1", "name": "Test"} + mock_get.return_value = mock_response + + result = RemoteRecommendAppRetrieval.fetch_recommended_app_detail_from_dify_official("app-1") + + assert result == {"id": "app-1", "name": "Test"} + mock_get.assert_called_once() + + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_detail_returns_none_on_non_200(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_get.return_value = MagicMock(status_code=404) + + result = RemoteRecommendAppRetrieval.fetch_recommended_app_detail_from_dify_official("app-1") + + assert result is None + + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_apps_returns_sorted_categories_on_200(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_response = MagicMock(status_code=200) + mock_response.json.return_value = { + "recommended_apps": [], + "categories": ["writing", "agent", "chat"], + } + mock_get.return_value = mock_response + + result = RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US") + + assert result["categories"] == ["agent", "chat", "writing"] + + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_apps_raises_on_non_200(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_get.return_value = MagicMock(status_code=500) + + with pytest.raises(ValueError, match="fetch recommended apps failed"): + RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US") + + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_apps_without_categories_key(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_response = MagicMock(status_code=200) + mock_response.json.return_value = {"recommended_apps": []} + mock_get.return_value = mock_response + + result = RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US") + + assert "categories" not in result diff --git a/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py new file mode 100644 index 0000000000..439d203c58 --- /dev/null +++ b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py @@ -0,0 +1,455 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from services.tools.builtin_tools_manage_service import BuiltinToolManageService + +MODULE = "services.tools.builtin_tools_manage_service" + + +def _mock_session(mock_session_cls): + """Helper: set up a Session context manager mock and return the inner session.""" + session = MagicMock() + mock_session_cls.return_value.__enter__ = MagicMock(return_value=session) + mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) + return session + + +class TestDeleteCustomOauthClientParams: + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_deletes_and_returns_success(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + + result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google") + + assert result == {"result": "success"} + session.query.return_value.filter_by.return_value.delete.assert_called_once() + session.commit.assert_called_once() + + +class TestListBuiltinToolProviderTools: + @patch(f"{MODULE}.ToolLabelManager") + @patch(f"{MODULE}.ToolTransformService") + @patch(f"{MODULE}.ToolManager") + def test_transforms_each_tool(self, mock_manager, mock_transform, mock_labels): + mock_controller = MagicMock() + mock_controller.get_tools.return_value = [MagicMock(), MagicMock()] + mock_manager.get_builtin_provider.return_value = mock_controller + mock_transform.convert_tool_entity_to_api_entity.return_value = MagicMock() + + result = BuiltinToolManageService.list_builtin_tool_provider_tools("tenant-1", "google") + + assert len(result) == 2 + + @patch(f"{MODULE}.ToolLabelManager") + @patch(f"{MODULE}.ToolTransformService") + @patch(f"{MODULE}.ToolManager") + def test_empty_tools(self, mock_manager, mock_transform, mock_labels): + mock_controller = MagicMock() + mock_controller.get_tools.return_value = [] + mock_manager.get_builtin_provider.return_value = mock_controller + + assert BuiltinToolManageService.list_builtin_tool_provider_tools("t", "p") == [] + + +class TestGetBuiltinToolProviderInfo: + @patch(f"{MODULE}.ToolTransformService") + @patch(f"{MODULE}.BuiltinToolManageService.get_builtin_provider") + @patch(f"{MODULE}.ToolManager") + def test_raises_when_not_found(self, mock_manager, mock_get, mock_transform): + mock_get.return_value = None + + with pytest.raises(ValueError, match="you have not added provider"): + BuiltinToolManageService.get_builtin_tool_provider_info("t", "no") + + @patch(f"{MODULE}.ToolTransformService") + @patch(f"{MODULE}.BuiltinToolManageService.get_builtin_provider") + @patch(f"{MODULE}.ToolManager") + def test_clears_original_credentials(self, mock_manager, mock_get, mock_transform): + mock_get.return_value = MagicMock() + entity = MagicMock() + mock_transform.builtin_provider_to_user_provider.return_value = entity + + result = BuiltinToolManageService.get_builtin_tool_provider_info("t", "google") + + assert result.original_credentials == {} + + +class TestListBuiltinProviderCredentialsSchema: + @patch(f"{MODULE}.ToolManager") + def test_returns_schema(self, mock_manager): + mock_manager.get_builtin_provider.return_value.get_credentials_schema_by_type.return_value = [{"f": "k"}] + + result = BuiltinToolManageService.list_builtin_provider_credentials_schema("g", "api_key", "t") + + assert result == [{"f": "k"}] + + +class TestGetBuiltinToolProviderIcon: + @patch(f"{MODULE}.Path") + @patch(f"{MODULE}.ToolManager") + def test_returns_bytes_and_mime(self, mock_manager, mock_path): + mock_manager.get_hardcoded_provider_icon.return_value = ("/icon.svg", "image/svg+xml") + mock_path.return_value.read_bytes.return_value = b"" + + icon, mime = BuiltinToolManageService.get_builtin_tool_provider_icon("google") + + assert icon == b"" + assert mime == "image/svg+xml" + + +class TestIsOauthSystemClientExists: + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_true_when_exists(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.filter_by.return_value.first.return_value = MagicMock() + + assert BuiltinToolManageService.is_oauth_system_client_exists("google") is True + + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_false_when_missing(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.filter_by.return_value.first.return_value = None + + assert BuiltinToolManageService.is_oauth_system_client_exists("google") is False + + +class TestIsOauthCustomClientEnabled: + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_true_when_enabled(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.filter_by.return_value.first.return_value = MagicMock(enabled=True) + + assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is True + + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_false_when_none(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.filter_by.return_value.first.return_value = None + + assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is False + + +class TestDeleteBuiltinToolProvider: + @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") + @patch(f"{MODULE}.ToolManager") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_raises_when_not_found(self, mock_db, mock_session_cls, mock_tm, mock_enc): + session = _mock_session(mock_session_cls) + session.query.return_value.where.return_value.first.return_value = None + + with pytest.raises(ValueError, match="you have not added provider"): + BuiltinToolManageService.delete_builtin_tool_provider("t", "p", "id") + + @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") + @patch(f"{MODULE}.ToolManager") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_deletes_provider_and_clears_cache(self, mock_db, mock_session_cls, mock_tm, mock_enc): + session = _mock_session(mock_session_cls) + db_provider = MagicMock() + session.query.return_value.where.return_value.first.return_value = db_provider + mock_cache = MagicMock() + mock_enc.return_value = (MagicMock(), mock_cache) + + result = BuiltinToolManageService.delete_builtin_tool_provider("t", "p", "c") + + assert result == {"result": "success"} + session.delete.assert_called_once_with(db_provider) + session.commit.assert_called_once() + mock_cache.delete.assert_called_once() + + +class TestSetDefaultProvider: + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_raises_when_not_found(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.filter_by.return_value.first.return_value = None + + with pytest.raises(ValueError, match="provider not found"): + BuiltinToolManageService.set_default_provider("t", "u", "p", "id") + + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_sets_default_and_clears_old(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + target = MagicMock() + session.query.return_value.filter_by.return_value.first.return_value = target + + result = BuiltinToolManageService.set_default_provider("t", "u", "p", "id") + + assert result == {"result": "success"} + assert target.is_default is True + session.commit.assert_called_once() + + +class TestUpdateBuiltinToolProvider: + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_raises_when_provider_not_exists(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.where.return_value.first.return_value = None + + with pytest.raises(ValueError, match="you have not added provider"): + BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c") + + @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") + @patch(f"{MODULE}.CredentialType") + @patch(f"{MODULE}.ToolManager") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_updates_credentials_and_commits(self, mock_db, mock_session_cls, mock_tm, mock_cred_type, mock_enc): + session = _mock_session(mock_session_cls) + db_provider = MagicMock(credential_type="api_key", credentials="{}") + session.query.return_value.where.return_value.first.return_value = db_provider + + mock_cred_instance = MagicMock() + mock_cred_instance.is_editable.return_value = True + mock_cred_instance.is_validate_allowed.return_value = False + mock_cred_type.of.return_value = mock_cred_instance + + mock_controller = MagicMock(need_credentials=True) + mock_tm.get_builtin_provider.return_value = mock_controller + + mock_encrypter = MagicMock() + mock_encrypter.decrypt.return_value = {"key": "old"} + mock_encrypter.encrypt.return_value = {"key": "new"} + mock_cache = MagicMock() + mock_enc.return_value = (mock_encrypter, mock_cache) + + result = BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c", credentials={"key": "val"}) + + assert result == {"result": "success"} + session.commit.assert_called_once() + mock_cache.delete.assert_called_once() + + +class TestGetOauthClientSchema: + @patch(f"{MODULE}.BuiltinToolManageService.get_custom_oauth_client_params", return_value={}) + @patch(f"{MODULE}.BuiltinToolManageService.is_oauth_system_client_exists", return_value=False) + @patch(f"{MODULE}.BuiltinToolManageService.is_oauth_custom_client_enabled", return_value=True) + @patch(f"{MODULE}.dify_config") + @patch(f"{MODULE}.PluginService") + @patch(f"{MODULE}.ToolManager") + def test_returns_schema_dict(self, mock_tm, mock_plugin, mock_config, mock_enabled, mock_sys, mock_params): + mock_config.CONSOLE_API_URL = "https://api.example.com" + mock_controller = MagicMock() + mock_controller.get_oauth_client_schema.return_value = [] + mock_tm.get_builtin_provider.return_value = mock_controller + + result = BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema("t", "google") + + assert "schema" in result + assert result["is_oauth_custom_client_enabled"] is True + assert "redirect_uri" in result + + +class TestGetOauthClient: + @patch(f"{MODULE}.PluginService") + @patch(f"{MODULE}.create_provider_encrypter") + @patch(f"{MODULE}.ToolManager") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_returns_user_client_params_when_exists( + self, mock_db, mock_session_cls, mock_tm, mock_create_enc, mock_plugin + ): + session = _mock_session(mock_session_cls) + mock_controller = MagicMock() + mock_controller.get_oauth_client_schema.return_value = [] + mock_tm.get_builtin_provider.return_value = mock_controller + + mock_encrypter = MagicMock() + mock_encrypter.decrypt.return_value = {"client_id": "id", "client_secret": "secret"} + mock_create_enc.return_value = (mock_encrypter, MagicMock()) + + user_client = MagicMock(oauth_params='{"encrypted": "data"}') + session.query.return_value.filter_by.return_value.first.return_value = user_client + + result = BuiltinToolManageService.get_oauth_client("t", "google") + + assert result == {"client_id": "id", "client_secret": "secret"} + + @patch(f"{MODULE}.decrypt_system_oauth_params", return_value={"sys_key": "sys_val"}) + @patch(f"{MODULE}.PluginService") + @patch(f"{MODULE}.create_provider_encrypter") + @patch(f"{MODULE}.ToolManager") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_falls_back_to_system_client( + self, mock_db, mock_session_cls, mock_tm, mock_create_enc, mock_plugin, mock_decrypt + ): + session = _mock_session(mock_session_cls) + mock_controller = MagicMock() + mock_controller.get_oauth_client_schema.return_value = [] + mock_tm.get_builtin_provider.return_value = mock_controller + + mock_create_enc.return_value = (MagicMock(), MagicMock()) + + system_client = MagicMock(encrypted_oauth_params="enc") + session.query.return_value.filter_by.return_value.first.side_effect = [ + None, # user client + system_client, # system client + ] + + result = BuiltinToolManageService.get_oauth_client("t", "google") + + assert result == {"sys_key": "sys_val"} + + +class TestSaveCustomOauthClientParams: + def test_returns_early_when_no_params(self): + result = BuiltinToolManageService.save_custom_oauth_client_params("t", "p") + assert result == {"result": "success"} + + @patch(f"{MODULE}.ToolManager") + def test_raises_when_provider_not_found(self, mock_tm): + mock_tm.get_builtin_provider.return_value = None + + with pytest.raises((ValueError, Exception), match="not found|Provider"): + BuiltinToolManageService.save_custom_oauth_client_params("t", "p", enable_oauth_custom_client=True) + + +class TestGetCustomOauthClientParams: + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_returns_empty_when_none(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.filter_by.return_value.first.return_value = None + + result = BuiltinToolManageService.get_custom_oauth_client_params("t", "p") + + assert result == {} + + +class TestGetBuiltinToolProviderCredentialInfo: + @patch(f"{MODULE}.BuiltinToolManageService.is_oauth_custom_client_enabled", return_value=False) + @patch(f"{MODULE}.BuiltinToolManageService.get_builtin_tool_provider_credentials", return_value=[]) + @patch(f"{MODULE}.ToolManager") + def test_returns_credential_info(self, mock_tm, mock_creds, mock_oauth): + mock_tm.get_builtin_provider.return_value.get_supported_credential_types.return_value = ["api-key"] + + result = BuiltinToolManageService.get_builtin_tool_provider_credential_info("t", "google") + + assert result.credentials == [] + assert result.supported_credential_types == ["api-key"] + assert result.is_oauth_custom_client_enabled is False + + +class TestGetBuiltinToolProviderCredentials: + @patch(f"{MODULE}.db") + def test_returns_empty_when_no_providers(self, mock_db): + mock_db.session.no_autoflush.__enter__ = MagicMock(return_value=None) + mock_db.session.no_autoflush.__exit__ = MagicMock(return_value=False) + mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [] + + result = BuiltinToolManageService.get_builtin_tool_provider_credentials("t", "google") + + assert result == [] + + @patch(f"{MODULE}.ToolTransformService") + @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") + @patch(f"{MODULE}.ToolManager") + @patch(f"{MODULE}.db") + def test_returns_credential_entities(self, mock_db, mock_tm, mock_enc, mock_transform): + mock_db.session.no_autoflush.__enter__ = MagicMock(return_value=None) + mock_db.session.no_autoflush.__exit__ = MagicMock(return_value=False) + + provider = MagicMock(provider="google", is_default=False) + mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [provider] + + mock_encrypter = MagicMock() + mock_encrypter.decrypt.return_value = {"key": "decrypted"} + mock_encrypter.mask_plugin_credentials.return_value = {"key": "***"} + mock_enc.return_value = (mock_encrypter, MagicMock()) + + credential_entity = MagicMock() + mock_transform.convert_builtin_provider_to_credential_entity.return_value = credential_entity + + result = BuiltinToolManageService.get_builtin_tool_provider_credentials("t", "google") + + assert len(result) == 1 + assert result[0] is credential_entity + assert provider.is_default is True + + +class TestGetBuiltinProvider: + @patch(f"{MODULE}.ToolProviderID") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_returns_none_when_not_found(self, mock_db, mock_session_cls, mock_prov_id): + session = _mock_session(mock_session_cls) + mock_prov_id.return_value.provider_name = "google" + mock_prov_id.return_value.organization = "langgenius" + session.query.return_value.where.return_value.order_by.return_value.first.return_value = None + + result = BuiltinToolManageService.get_builtin_provider("google", "t") + + assert result is None + + @patch(f"{MODULE}.ToolProviderID") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_returns_provider_for_langgenius_org(self, mock_db, mock_session_cls, mock_prov_id): + session = _mock_session(mock_session_cls) + mock_prov_id.return_value.provider_name = "google" + mock_prov_id.return_value.organization = "langgenius" + db_provider = MagicMock(provider="google") + mock_prov_id_result = MagicMock() + mock_prov_id_result.to_string.return_value = "langgenius/google/google" + + def prov_id_side_effect(name): + m = MagicMock() + m.provider_name = "google" + m.organization = "langgenius" + m.to_string.return_value = "langgenius/google/google" + m.plugin_id = "langgenius/google" + return m + + mock_prov_id.side_effect = prov_id_side_effect + session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider + + result = BuiltinToolManageService.get_builtin_provider("google", "t") + + assert result is db_provider + + @patch(f"{MODULE}.ToolProviderID") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_returns_provider_for_non_langgenius_org(self, mock_db, mock_session_cls, mock_prov_id): + session = _mock_session(mock_session_cls) + + def prov_id_side_effect(name): + m = MagicMock() + m.provider_name = "custom-tool" + m.organization = "third-party" + m.to_string.return_value = "third-party/custom/custom-tool" + m.plugin_id = "third-party/custom" + return m + + mock_prov_id.side_effect = prov_id_side_effect + db_provider = MagicMock(provider="third-party/custom/custom-tool") + session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider + + result = BuiltinToolManageService.get_builtin_provider("third-party/custom/custom-tool", "t") + + assert result is db_provider + + @patch(f"{MODULE}.ToolProviderID") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_falls_back_on_exception(self, mock_db, mock_session_cls, mock_prov_id): + session = _mock_session(mock_session_cls) + mock_prov_id.side_effect = Exception("parse error") + fallback = MagicMock() + session.query.return_value.where.return_value.order_by.return_value.first.return_value = fallback + + result = BuiltinToolManageService.get_builtin_provider("old-provider", "t") + + assert result is fallback diff --git a/api/tests/unit_tests/services/tools/test_tool_labels_service.py b/api/tests/unit_tests/services/tools/test_tool_labels_service.py new file mode 100644 index 0000000000..6acdbb7901 --- /dev/null +++ b/api/tests/unit_tests/services/tools/test_tool_labels_service.py @@ -0,0 +1,21 @@ +from services.tools.tool_labels_service import ToolLabelsService + + +def test_list_tool_labels_returns_default_labels(): + result = ToolLabelsService.list_tool_labels() + assert isinstance(result, list) + assert len(result) > 0 + + +def test_list_tool_labels_items_are_tool_labels(): + from core.tools.entities.tool_entities import ToolLabel + + result = ToolLabelsService.list_tool_labels() + for label in result: + assert isinstance(label, ToolLabel) + + +def test_list_tool_labels_matches_default_values(): + from core.tools.entities.values import default_tool_labels + + assert ToolLabelsService.list_tool_labels() is default_tool_labels diff --git a/api/tests/unit_tests/services/tools/test_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_tools_manage_service.py new file mode 100644 index 0000000000..73ac9a10c6 --- /dev/null +++ b/api/tests/unit_tests/services/tools/test_tools_manage_service.py @@ -0,0 +1,40 @@ +from unittest.mock import MagicMock, patch + +from services.tools.tools_manage_service import ToolCommonService + + +class TestToolCommonService: + @patch("services.tools.tools_manage_service.ToolTransformService") + @patch("services.tools.tools_manage_service.ToolManager") + def test_list_tool_providers_transforms_and_returns(self, mock_manager, mock_transform): + mock_provider1 = MagicMock() + mock_provider1.to_dict.return_value = {"name": "provider1"} + mock_provider2 = MagicMock() + mock_provider2.to_dict.return_value = {"name": "provider2"} + mock_manager.list_providers_from_api.return_value = [mock_provider1, mock_provider2] + + result = ToolCommonService.list_tool_providers("user-1", "tenant-1") + + mock_manager.list_providers_from_api.assert_called_once_with("user-1", "tenant-1", None) + assert mock_transform.repack_provider.call_count == 2 + assert result == [{"name": "provider1"}, {"name": "provider2"}] + + @patch("services.tools.tools_manage_service.ToolTransformService") + @patch("services.tools.tools_manage_service.ToolManager") + def test_list_tool_providers_with_type_filter(self, mock_manager, mock_transform): + mock_manager.list_providers_from_api.return_value = [] + + result = ToolCommonService.list_tool_providers("user-1", "tenant-1", typ="builtin") + + mock_manager.list_providers_from_api.assert_called_once_with("user-1", "tenant-1", "builtin") + assert result == [] + + @patch("services.tools.tools_manage_service.ToolTransformService") + @patch("services.tools.tools_manage_service.ToolManager") + def test_list_tool_providers_empty(self, mock_manager, mock_transform): + mock_manager.list_providers_from_api.return_value = [] + + result = ToolCommonService.list_tool_providers("u", "t") + + assert result == [] + mock_transform.repack_provider.assert_not_called() diff --git a/api/tests/unit_tests/services/workflow/test_queue_dispatcher.py b/api/tests/unit_tests/services/workflow/test_queue_dispatcher.py new file mode 100644 index 0000000000..bbfc1cc294 --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_queue_dispatcher.py @@ -0,0 +1,110 @@ +from unittest.mock import patch + +import pytest + +from services.workflow.queue_dispatcher import ( + BaseQueueDispatcher, + ProfessionalQueueDispatcher, + QueueDispatcherManager, + QueuePriority, + SandboxQueueDispatcher, + TeamQueueDispatcher, +) + + +class TestQueuePriority: + def test_priority_values(self): + assert QueuePriority.PROFESSIONAL == "workflow_professional" + assert QueuePriority.TEAM == "workflow_team" + assert QueuePriority.SANDBOX == "workflow_sandbox" + + +class TestDispatchers: + def test_professional_dispatcher(self): + d = ProfessionalQueueDispatcher() + assert d.get_queue_name() == QueuePriority.PROFESSIONAL + assert d.get_priority() == 100 + + def test_team_dispatcher(self): + d = TeamQueueDispatcher() + assert d.get_queue_name() == QueuePriority.TEAM + assert d.get_priority() == 50 + + def test_sandbox_dispatcher(self): + d = SandboxQueueDispatcher() + assert d.get_queue_name() == QueuePriority.SANDBOX + assert d.get_priority() == 10 + + def test_base_dispatcher_is_abstract(self): + with pytest.raises(TypeError): + BaseQueueDispatcher() + + +class TestQueueDispatcherManager: + @patch("services.workflow.queue_dispatcher.BillingService") + @patch("services.workflow.queue_dispatcher.dify_config") + def test_billing_enabled_professional_plan(self, mock_config, mock_billing): + mock_config.BILLING_ENABLED = True + mock_billing.get_info.return_value = {"subscription": {"plan": "professional"}} + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, ProfessionalQueueDispatcher) + + @patch("services.workflow.queue_dispatcher.BillingService") + @patch("services.workflow.queue_dispatcher.dify_config") + def test_billing_enabled_team_plan(self, mock_config, mock_billing): + mock_config.BILLING_ENABLED = True + mock_billing.get_info.return_value = {"subscription": {"plan": "team"}} + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, TeamQueueDispatcher) + + @patch("services.workflow.queue_dispatcher.BillingService") + @patch("services.workflow.queue_dispatcher.dify_config") + def test_billing_enabled_sandbox_plan(self, mock_config, mock_billing): + mock_config.BILLING_ENABLED = True + mock_billing.get_info.return_value = {"subscription": {"plan": "sandbox"}} + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, SandboxQueueDispatcher) + + @patch("services.workflow.queue_dispatcher.BillingService") + @patch("services.workflow.queue_dispatcher.dify_config") + def test_billing_enabled_unknown_plan_defaults_to_sandbox(self, mock_config, mock_billing): + mock_config.BILLING_ENABLED = True + mock_billing.get_info.return_value = {"subscription": {"plan": "enterprise"}} + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, SandboxQueueDispatcher) + + @patch("services.workflow.queue_dispatcher.BillingService") + @patch("services.workflow.queue_dispatcher.dify_config") + def test_billing_enabled_service_failure_defaults_to_sandbox(self, mock_config, mock_billing): + mock_config.BILLING_ENABLED = True + mock_billing.get_info.side_effect = Exception("billing unavailable") + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, SandboxQueueDispatcher) + + @patch("services.workflow.queue_dispatcher.dify_config") + def test_billing_disabled_defaults_to_team(self, mock_config): + mock_config.BILLING_ENABLED = False + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, TeamQueueDispatcher) + + @patch("services.workflow.queue_dispatcher.BillingService") + @patch("services.workflow.queue_dispatcher.dify_config") + def test_missing_subscription_key_defaults_to_sandbox(self, mock_config, mock_billing): + mock_config.BILLING_ENABLED = True + mock_billing.get_info.return_value = {} + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, SandboxQueueDispatcher) diff --git a/api/tests/unit_tests/services/workflow/test_scheduler.py b/api/tests/unit_tests/services/workflow/test_scheduler.py new file mode 100644 index 0000000000..90b6cb2d8b --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_scheduler.py @@ -0,0 +1,89 @@ +import pytest + +from services.workflow.entities import WorkflowScheduleCFSPlanEntity +from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand + + +class TestSchedulerCommand: + def test_enum_values(self): + assert SchedulerCommand.RESOURCE_LIMIT_REACHED == "resource_limit_reached" + assert SchedulerCommand.NONE == "none" + + def test_enum_is_str(self): + for member in SchedulerCommand: + assert isinstance(member, str) + + +class TestCFSPlanScheduler: + def test_stores_plan(self): + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.Nop, + granularity=-1, + ) + + class ConcretePlanScheduler(CFSPlanScheduler): + def can_schedule(self): + return SchedulerCommand.NONE + + scheduler = ConcretePlanScheduler(plan) + + assert scheduler.plan is plan + assert scheduler.plan.schedule_strategy == WorkflowScheduleCFSPlanEntity.Strategy.Nop + assert scheduler.plan.granularity == -1 + + def test_cannot_instantiate_abstract(self): + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=10, + ) + with pytest.raises(TypeError): + CFSPlanScheduler(plan) + + def test_concrete_subclass_can_schedule(self): + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=5, + ) + + class TimedScheduler(CFSPlanScheduler): + def can_schedule(self): + if self.plan.granularity > 0: + return SchedulerCommand.NONE + return SchedulerCommand.RESOURCE_LIMIT_REACHED + + scheduler = TimedScheduler(plan) + assert scheduler.can_schedule() == SchedulerCommand.NONE + + def test_concrete_subclass_resource_limit(self): + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=-1, + ) + + class TimedScheduler(CFSPlanScheduler): + def can_schedule(self): + if self.plan.granularity > 0: + return SchedulerCommand.NONE + return SchedulerCommand.RESOURCE_LIMIT_REACHED + + scheduler = TimedScheduler(plan) + assert scheduler.can_schedule() == SchedulerCommand.RESOURCE_LIMIT_REACHED + + +class TestWorkflowScheduleCFSPlanEntity: + def test_strategy_values(self): + assert WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice == "time-slice" + assert WorkflowScheduleCFSPlanEntity.Strategy.Nop == "nop" + + def test_default_granularity(self): + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.Nop, + ) + assert plan.granularity == -1 + + def test_explicit_granularity(self): + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=100, + ) + assert plan.granularity == 100 diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py index e428c603a4..3953248c47 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py @@ -41,6 +41,23 @@ class TestWorkflowService: workflows.append(workflow) return workflows + @pytest.fixture + def dummy_session_cls(self): + class DummySession: + def __init__(self, *args, **kwargs): + self.commit = MagicMock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def begin(self): + return nullcontext() + + return DummySession + def test_get_all_published_workflow_no_workflow_id(self, workflow_service, mock_app): mock_app.workflow_id = None mock_session = MagicMock() @@ -170,7 +187,10 @@ class TestWorkflowService: mock_session.scalars.assert_called_once() def test_submit_human_input_form_preview_uses_rendered_content( - self, workflow_service: WorkflowService, monkeypatch: pytest.MonkeyPatch + self, + workflow_service: WorkflowService, + monkeypatch: pytest.MonkeyPatch, + dummy_session_cls, ) -> None: service = workflow_service node_data = HumanInputNodeData( @@ -197,19 +217,6 @@ class TestWorkflowService: saved_outputs: dict[str, object] = {} - class DummySession: - def __init__(self, *args, **kwargs): - self.commit = MagicMock() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def begin(self): - return nullcontext() - class DummySaver: def __init__(self, *args, **kwargs): pass @@ -217,7 +224,7 @@ class TestWorkflowService: def save(self, outputs, process_data): saved_outputs.update(outputs) - monkeypatch.setattr(workflow_service_module, "Session", DummySession) + monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls) monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", DummySaver) monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) @@ -278,7 +285,6 @@ class TestWorkflowService: app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") account = SimpleNamespace(id="account-1") - with pytest.raises(ValueError) as exc_info: service.submit_human_input_form_preview( app_model=app_model, @@ -290,3 +296,115 @@ class TestWorkflowService: ) assert "Missing required inputs" in str(exc_info.value) + + def test_run_draft_workflow_node_successful_behavior( + self, workflow_service, mock_app, monkeypatch, dummy_session_cls + ): + """Behavior: When a basic workflow node runs, it correctly sets up context, + executes the node, and saves outputs.""" + service = workflow_service + account = SimpleNamespace(id="account-1") + mock_workflow = MagicMock() + mock_workflow.id = "wf-1" + mock_workflow.tenant_id = "tenant-1" + mock_workflow.environment_variables = [] + mock_workflow.conversation_variables = [] + + # Mock node config + mock_workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "llm"}} + mock_workflow.get_enclosing_node_type_and_id.return_value = None + + # Mock class methods + monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock()) + monkeypatch.setattr(workflow_service_module, "DraftVarLoader", MagicMock()) + + # Mock workflow entry execution + mock_node_exec = MagicMock() + mock_node_exec.id = "exec-1" + mock_node_exec.process_data = {} + mock_run = MagicMock() + monkeypatch.setattr(workflow_service_module.WorkflowEntry, "single_step_run", mock_run) + + # Mock execution handling + service._handle_single_step_result = MagicMock(return_value=mock_node_exec) + + # Mock repository + mock_repo = MagicMock() + mock_repo.get_execution_by_id.return_value = mock_node_exec + mock_repo_factory = MagicMock(return_value=mock_repo) + monkeypatch.setattr( + workflow_service_module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + mock_repo_factory, + ) + service._node_execution_service_repo = mock_repo + + # Set up node execution service repo mock to return our exec node + mock_node_exec.load_full_outputs.return_value = {"output_var": "result_value"} + mock_node_exec.node_id = "node-1" + mock_node_exec.node_type = "llm" + + # Mock draft variable saver + mock_saver = MagicMock() + monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", MagicMock(return_value=mock_saver)) + + # Mock DB + monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) + + monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls) + + # Act + result = service.run_draft_workflow_node( + app_model=mock_app, + draft_workflow=mock_workflow, + node_id="node-1", + user_inputs={"input_val": "test"}, + account=account, + ) + + # Assert + assert result == mock_node_exec + service._handle_single_step_result.assert_called_once() + mock_repo.save.assert_called_once_with(mock_node_exec) + mock_saver.save.assert_called_once_with(process_data={}, outputs={"output_var": "result_value"}) + + def test_run_draft_workflow_node_failure_behavior(self, workflow_service, mock_app, monkeypatch, dummy_session_cls): + """Behavior: If retrieving the saved execution fails, an appropriate error bubble matches expectations.""" + service = workflow_service + account = SimpleNamespace(id="account-1") + mock_workflow = MagicMock() + mock_workflow.tenant_id = "tenant-1" + mock_workflow.environment_variables = [] + mock_workflow.conversation_variables = [] + mock_workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "llm"}} + mock_workflow.get_enclosing_node_type_and_id.return_value = None + + monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock()) + monkeypatch.setattr(workflow_service_module, "DraftVarLoader", MagicMock()) + monkeypatch.setattr(workflow_service_module.WorkflowEntry, "single_step_run", MagicMock()) + + mock_node_exec = MagicMock() + mock_node_exec.id = "exec-invalid" + service._handle_single_step_result = MagicMock(return_value=mock_node_exec) + + mock_repo = MagicMock() + mock_repo_factory = MagicMock(return_value=mock_repo) + monkeypatch.setattr( + workflow_service_module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + mock_repo_factory, + ) + service._node_execution_service_repo = mock_repo + + # Simulate failure to retrieve the saved execution + mock_repo.get_execution_by_id.return_value = None + + monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) + + monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls) + + # Act & Assert + with pytest.raises(ValueError, match="WorkflowNodeExecution with id exec-invalid not found after saving"): + service.run_draft_workflow_node( + app_model=mock_app, draft_workflow=mock_workflow, node_id="node-1", user_inputs={}, account=account + )