mirror of https://github.com/langgenius/dify.git
test: add UTs for api/services recommend_app, tools, workflow (#32645)
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
07e19c0748
commit
7d2054d4f4
|
|
@ -0,0 +1,91 @@
|
|||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval
|
||||
from services.recommend_app.recommend_app_type import RecommendAppType
|
||||
|
||||
SAMPLE_BUILTIN_DATA = {
|
||||
"recommended_apps": {
|
||||
"en-US": {"categories": ["writing"], "apps": [{"id": "app-1"}]},
|
||||
"zh-Hans": {"categories": ["search"], "apps": [{"id": "app-2"}]},
|
||||
},
|
||||
"app_details": {
|
||||
"app-1": {"id": "app-1", "name": "Writer", "mode": "chat"},
|
||||
"app-2": {"id": "app-2", "name": "Searcher", "mode": "workflow"},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_cache():
|
||||
BuildInRecommendAppRetrieval.builtin_data = None
|
||||
yield
|
||||
BuildInRecommendAppRetrieval.builtin_data = None
|
||||
|
||||
|
||||
class TestBuildInRecommendAppRetrieval:
|
||||
def test_get_type(self):
|
||||
retrieval = BuildInRecommendAppRetrieval()
|
||||
assert retrieval.get_type() == RecommendAppType.BUILDIN
|
||||
|
||||
def test_get_recommended_apps_and_categories_delegates(self):
|
||||
with patch.object(
|
||||
BuildInRecommendAppRetrieval,
|
||||
"fetch_recommended_apps_from_builtin",
|
||||
return_value={"apps": []},
|
||||
) as mock_fetch:
|
||||
retrieval = BuildInRecommendAppRetrieval()
|
||||
result = retrieval.get_recommended_apps_and_categories("en-US")
|
||||
mock_fetch.assert_called_once_with("en-US")
|
||||
assert result == {"apps": []}
|
||||
|
||||
def test_get_recommend_app_detail_delegates(self):
|
||||
with patch.object(
|
||||
BuildInRecommendAppRetrieval,
|
||||
"fetch_recommended_app_detail_from_builtin",
|
||||
return_value={"id": "app-1"},
|
||||
) as mock_fetch:
|
||||
retrieval = BuildInRecommendAppRetrieval()
|
||||
result = retrieval.get_recommend_app_detail("app-1")
|
||||
mock_fetch.assert_called_once_with("app-1")
|
||||
assert result == {"id": "app-1"}
|
||||
|
||||
def test_get_builtin_data_reads_json_and_caches(self, tmp_path):
|
||||
json_file = tmp_path / "constants" / "recommended_apps.json"
|
||||
json_file.parent.mkdir(parents=True)
|
||||
json_file.write_text(json.dumps(SAMPLE_BUILTIN_DATA))
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.root_path = str(tmp_path)
|
||||
|
||||
with patch(
|
||||
"services.recommend_app.buildin.buildin_retrieval.current_app",
|
||||
mock_app,
|
||||
):
|
||||
first = BuildInRecommendAppRetrieval._get_builtin_data()
|
||||
second = BuildInRecommendAppRetrieval._get_builtin_data()
|
||||
|
||||
assert first == SAMPLE_BUILTIN_DATA
|
||||
assert first is second
|
||||
|
||||
def test_fetch_recommended_apps_from_builtin(self):
|
||||
BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA
|
||||
result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin("en-US")
|
||||
assert result == SAMPLE_BUILTIN_DATA["recommended_apps"]["en-US"]
|
||||
|
||||
def test_fetch_recommended_apps_from_builtin_missing_language(self):
|
||||
BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA
|
||||
result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin("fr-FR")
|
||||
assert result == {}
|
||||
|
||||
def test_fetch_recommended_app_detail_from_builtin(self):
|
||||
BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA
|
||||
result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin("app-1")
|
||||
assert result == {"id": "app-1", "name": "Writer", "mode": "chat"}
|
||||
|
||||
def test_fetch_recommended_app_detail_from_builtin_missing(self):
|
||||
BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA
|
||||
result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin("nonexistent")
|
||||
assert result is None
|
||||
|
|
@ -0,0 +1,145 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval
|
||||
from services.recommend_app.recommend_app_type import RecommendAppType
|
||||
|
||||
|
||||
class TestDatabaseRecommendAppRetrieval:
|
||||
def test_get_type(self):
|
||||
assert DatabaseRecommendAppRetrieval().get_type() == RecommendAppType.DATABASE
|
||||
|
||||
def test_get_recommended_apps_delegates(self):
|
||||
with patch.object(
|
||||
DatabaseRecommendAppRetrieval,
|
||||
"fetch_recommended_apps_from_db",
|
||||
return_value={"recommended_apps": [], "categories": []},
|
||||
) as mock_fetch:
|
||||
result = DatabaseRecommendAppRetrieval().get_recommended_apps_and_categories("en-US")
|
||||
mock_fetch.assert_called_once_with("en-US")
|
||||
assert result == {"recommended_apps": [], "categories": []}
|
||||
|
||||
def test_get_recommend_app_detail_delegates(self):
|
||||
with patch.object(
|
||||
DatabaseRecommendAppRetrieval,
|
||||
"fetch_recommended_app_detail_from_db",
|
||||
return_value={"id": "app-1"},
|
||||
) as mock_fetch:
|
||||
result = DatabaseRecommendAppRetrieval().get_recommend_app_detail("app-1")
|
||||
mock_fetch.assert_called_once_with("app-1")
|
||||
assert result == {"id": "app-1"}
|
||||
|
||||
|
||||
class TestFetchRecommendedAppsFromDb:
|
||||
def _make_recommended_app(self, app_id, category, is_public=True, has_site=True):
|
||||
site = (
|
||||
SimpleNamespace(
|
||||
description="desc",
|
||||
copyright="copy",
|
||||
privacy_policy="pp",
|
||||
custom_disclaimer="cd",
|
||||
)
|
||||
if has_site
|
||||
else None
|
||||
)
|
||||
app = (
|
||||
SimpleNamespace(is_public=is_public, site=site)
|
||||
if is_public
|
||||
else SimpleNamespace(is_public=False, site=site)
|
||||
)
|
||||
return SimpleNamespace(
|
||||
id=f"rec-{app_id}",
|
||||
app=app,
|
||||
app_id=app_id,
|
||||
category=category,
|
||||
position=1,
|
||||
is_listed=True,
|
||||
)
|
||||
|
||||
@patch("services.recommend_app.database.database_retrieval.db")
|
||||
def test_returns_apps_and_sorted_categories(self, mock_db):
|
||||
rec1 = self._make_recommended_app("a1", "writing")
|
||||
rec2 = self._make_recommended_app("a2", "assistant")
|
||||
mock_db.session.scalars.return_value.all.return_value = [rec1, rec2]
|
||||
|
||||
result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US")
|
||||
|
||||
assert len(result["recommended_apps"]) == 2
|
||||
assert result["categories"] == ["assistant", "writing"]
|
||||
|
||||
@patch("services.recommend_app.database.database_retrieval.db")
|
||||
def test_falls_back_to_default_language_when_empty(self, mock_db):
|
||||
mock_db.session.scalars.return_value.all.side_effect = [
|
||||
[],
|
||||
[self._make_recommended_app("a1", "chat")],
|
||||
]
|
||||
|
||||
result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("fr-FR")
|
||||
|
||||
assert len(result["recommended_apps"]) == 1
|
||||
assert mock_db.session.scalars.call_count == 2
|
||||
|
||||
@patch("services.recommend_app.database.database_retrieval.db")
|
||||
def test_skips_non_public_apps(self, mock_db):
|
||||
rec = self._make_recommended_app("a1", "chat", is_public=False)
|
||||
mock_db.session.scalars.return_value.all.return_value = [rec]
|
||||
|
||||
result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US")
|
||||
|
||||
assert result["recommended_apps"] == []
|
||||
|
||||
@patch("services.recommend_app.database.database_retrieval.db")
|
||||
def test_skips_apps_without_site(self, mock_db):
|
||||
rec = self._make_recommended_app("a1", "chat", has_site=False)
|
||||
mock_db.session.scalars.return_value.all.return_value = [rec]
|
||||
|
||||
result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US")
|
||||
|
||||
assert result["recommended_apps"] == []
|
||||
|
||||
|
||||
class TestFetchRecommendedAppDetailFromDb:
|
||||
@patch("services.recommend_app.database.database_retrieval.db")
|
||||
def test_returns_none_when_not_listed(self, mock_db):
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1")
|
||||
|
||||
assert result is None
|
||||
|
||||
@patch("services.recommend_app.database.database_retrieval.AppDslService")
|
||||
@patch("services.recommend_app.database.database_retrieval.db")
|
||||
def test_returns_none_when_app_not_public(self, mock_db, mock_dsl):
|
||||
rec_chain = MagicMock()
|
||||
rec_chain.where.return_value.first.return_value = SimpleNamespace(app_id="app-1")
|
||||
app_chain = MagicMock()
|
||||
app_chain.where.return_value.first.return_value = SimpleNamespace(id="app-1", is_public=False)
|
||||
mock_db.session.query.side_effect = [rec_chain, app_chain]
|
||||
|
||||
result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1")
|
||||
|
||||
assert result is None
|
||||
|
||||
@patch("services.recommend_app.database.database_retrieval.AppDslService")
|
||||
@patch("services.recommend_app.database.database_retrieval.db")
|
||||
def test_returns_detail_on_success(self, mock_db, mock_dsl):
|
||||
app_model = SimpleNamespace(
|
||||
id="app-1",
|
||||
name="My App",
|
||||
icon="icon.png",
|
||||
icon_background="#fff",
|
||||
mode="chat",
|
||||
is_public=True,
|
||||
)
|
||||
rec_chain = MagicMock()
|
||||
rec_chain.where.return_value.first.return_value = SimpleNamespace(app_id="app-1")
|
||||
app_chain = MagicMock()
|
||||
app_chain.where.return_value.first.return_value = app_model
|
||||
mock_db.session.query.side_effect = [rec_chain, app_chain]
|
||||
mock_dsl.export_dsl.return_value = "exported_yaml"
|
||||
|
||||
result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1")
|
||||
|
||||
assert result["id"] == "app-1"
|
||||
assert result["name"] == "My App"
|
||||
assert result["export_data"] == "exported_yaml"
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
import pytest
|
||||
|
||||
from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval
|
||||
from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval
|
||||
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
|
||||
from services.recommend_app.remote.remote_retrieval import RemoteRecommendAppRetrieval
|
||||
|
||||
|
||||
class TestRecommendAppRetrievalFactory:
|
||||
@pytest.mark.parametrize(
|
||||
("mode", "expected_class"),
|
||||
[
|
||||
("remote", RemoteRecommendAppRetrieval),
|
||||
("builtin", BuildInRecommendAppRetrieval),
|
||||
("db", DatabaseRecommendAppRetrieval),
|
||||
],
|
||||
)
|
||||
def test_factory_returns_correct_class(self, mode, expected_class):
|
||||
result = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)
|
||||
assert result is expected_class
|
||||
|
||||
def test_factory_raises_for_unknown_mode(self):
|
||||
with pytest.raises(ValueError, match="invalid fetch recommended apps mode"):
|
||||
RecommendAppRetrievalFactory.get_recommend_app_factory("invalid_mode")
|
||||
|
||||
def test_get_buildin_recommend_app_retrieval(self):
|
||||
result = RecommendAppRetrievalFactory.get_buildin_recommend_app_retrieval()
|
||||
assert result is BuildInRecommendAppRetrieval
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
from services.recommend_app.recommend_app_type import RecommendAppType
|
||||
|
||||
|
||||
def test_enum_values():
|
||||
assert RecommendAppType.REMOTE == "remote"
|
||||
assert RecommendAppType.BUILDIN == "builtin"
|
||||
assert RecommendAppType.DATABASE == "db"
|
||||
|
||||
|
||||
def test_enum_membership():
|
||||
assert "remote" in RecommendAppType.__members__.values()
|
||||
assert "builtin" in RecommendAppType.__members__.values()
|
||||
assert "db" in RecommendAppType.__members__.values()
|
||||
|
||||
|
||||
def test_enum_is_str():
|
||||
for member in RecommendAppType:
|
||||
assert isinstance(member, str)
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.recommend_app.recommend_app_type import RecommendAppType
|
||||
from services.recommend_app.remote.remote_retrieval import RemoteRecommendAppRetrieval
|
||||
|
||||
|
||||
class TestRemoteRecommendAppRetrieval:
|
||||
def test_get_type(self):
|
||||
assert RemoteRecommendAppRetrieval().get_type() == RecommendAppType.REMOTE
|
||||
|
||||
@patch.object(
|
||||
RemoteRecommendAppRetrieval,
|
||||
"fetch_recommended_app_detail_from_dify_official",
|
||||
return_value={"id": "app-1"},
|
||||
)
|
||||
def test_get_recommend_app_detail_success(self, mock_fetch):
|
||||
result = RemoteRecommendAppRetrieval().get_recommend_app_detail("app-1")
|
||||
assert result == {"id": "app-1"}
|
||||
mock_fetch.assert_called_once_with("app-1")
|
||||
|
||||
@patch(
|
||||
"services.recommend_app.remote.remote_retrieval"
|
||||
".BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin",
|
||||
return_value={"id": "fallback"},
|
||||
)
|
||||
@patch.object(
|
||||
RemoteRecommendAppRetrieval,
|
||||
"fetch_recommended_app_detail_from_dify_official",
|
||||
side_effect=ConnectionError("timeout"),
|
||||
)
|
||||
def test_get_recommend_app_detail_falls_back_on_error(self, mock_fetch, mock_builtin):
|
||||
result = RemoteRecommendAppRetrieval().get_recommend_app_detail("app-1")
|
||||
assert result == {"id": "fallback"}
|
||||
mock_builtin.assert_called_once_with("app-1")
|
||||
|
||||
@patch.object(
|
||||
RemoteRecommendAppRetrieval,
|
||||
"fetch_recommended_apps_from_dify_official",
|
||||
return_value={"recommended_apps": [], "categories": []},
|
||||
)
|
||||
def test_get_recommended_apps_success(self, mock_fetch):
|
||||
result = RemoteRecommendAppRetrieval().get_recommended_apps_and_categories("en-US")
|
||||
assert result == {"recommended_apps": [], "categories": []}
|
||||
|
||||
@patch(
|
||||
"services.recommend_app.remote.remote_retrieval"
|
||||
".BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin",
|
||||
return_value={"recommended_apps": [{"id": "builtin"}]},
|
||||
)
|
||||
@patch.object(
|
||||
RemoteRecommendAppRetrieval,
|
||||
"fetch_recommended_apps_from_dify_official",
|
||||
side_effect=ValueError("server error"),
|
||||
)
|
||||
def test_get_recommended_apps_falls_back_on_error(self, mock_fetch, mock_builtin):
|
||||
result = RemoteRecommendAppRetrieval().get_recommended_apps_and_categories("en-US")
|
||||
assert result == {"recommended_apps": [{"id": "builtin"}]}
|
||||
|
||||
|
||||
class TestFetchFromDifyOfficial:
|
||||
@patch("services.recommend_app.remote.remote_retrieval.dify_config")
|
||||
@patch("services.recommend_app.remote.remote_retrieval.httpx.get")
|
||||
def test_detail_returns_json_on_200(self, mock_get, mock_config):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
|
||||
mock_response = MagicMock(status_code=200)
|
||||
mock_response.json.return_value = {"id": "app-1", "name": "Test"}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = RemoteRecommendAppRetrieval.fetch_recommended_app_detail_from_dify_official("app-1")
|
||||
|
||||
assert result == {"id": "app-1", "name": "Test"}
|
||||
mock_get.assert_called_once()
|
||||
|
||||
@patch("services.recommend_app.remote.remote_retrieval.dify_config")
|
||||
@patch("services.recommend_app.remote.remote_retrieval.httpx.get")
|
||||
def test_detail_returns_none_on_non_200(self, mock_get, mock_config):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
|
||||
mock_get.return_value = MagicMock(status_code=404)
|
||||
|
||||
result = RemoteRecommendAppRetrieval.fetch_recommended_app_detail_from_dify_official("app-1")
|
||||
|
||||
assert result is None
|
||||
|
||||
@patch("services.recommend_app.remote.remote_retrieval.dify_config")
|
||||
@patch("services.recommend_app.remote.remote_retrieval.httpx.get")
|
||||
def test_apps_returns_sorted_categories_on_200(self, mock_get, mock_config):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
|
||||
mock_response = MagicMock(status_code=200)
|
||||
mock_response.json.return_value = {
|
||||
"recommended_apps": [],
|
||||
"categories": ["writing", "agent", "chat"],
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US")
|
||||
|
||||
assert result["categories"] == ["agent", "chat", "writing"]
|
||||
|
||||
@patch("services.recommend_app.remote.remote_retrieval.dify_config")
|
||||
@patch("services.recommend_app.remote.remote_retrieval.httpx.get")
|
||||
def test_apps_raises_on_non_200(self, mock_get, mock_config):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
|
||||
mock_get.return_value = MagicMock(status_code=500)
|
||||
|
||||
with pytest.raises(ValueError, match="fetch recommended apps failed"):
|
||||
RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US")
|
||||
|
||||
@patch("services.recommend_app.remote.remote_retrieval.dify_config")
|
||||
@patch("services.recommend_app.remote.remote_retrieval.httpx.get")
|
||||
def test_apps_without_categories_key(self, mock_get, mock_config):
|
||||
mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
|
||||
mock_response = MagicMock(status_code=200)
|
||||
mock_response.json.return_value = {"recommended_apps": []}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US")
|
||||
|
||||
assert "categories" not in result
|
||||
|
|
@ -0,0 +1,455 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
MODULE = "services.tools.builtin_tools_manage_service"
|
||||
|
||||
|
||||
def _mock_session(mock_session_cls):
|
||||
"""Helper: set up a Session context manager mock and return the inner session."""
|
||||
session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__ = MagicMock(return_value=session)
|
||||
mock_session_cls.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return session
|
||||
|
||||
|
||||
class TestDeleteCustomOauthClientParams:
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_deletes_and_returns_success(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
|
||||
result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google")
|
||||
|
||||
assert result == {"result": "success"}
|
||||
session.query.return_value.filter_by.return_value.delete.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestListBuiltinToolProviderTools:
|
||||
@patch(f"{MODULE}.ToolLabelManager")
|
||||
@patch(f"{MODULE}.ToolTransformService")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
def test_transforms_each_tool(self, mock_manager, mock_transform, mock_labels):
|
||||
mock_controller = MagicMock()
|
||||
mock_controller.get_tools.return_value = [MagicMock(), MagicMock()]
|
||||
mock_manager.get_builtin_provider.return_value = mock_controller
|
||||
mock_transform.convert_tool_entity_to_api_entity.return_value = MagicMock()
|
||||
|
||||
result = BuiltinToolManageService.list_builtin_tool_provider_tools("tenant-1", "google")
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
@patch(f"{MODULE}.ToolLabelManager")
|
||||
@patch(f"{MODULE}.ToolTransformService")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
def test_empty_tools(self, mock_manager, mock_transform, mock_labels):
|
||||
mock_controller = MagicMock()
|
||||
mock_controller.get_tools.return_value = []
|
||||
mock_manager.get_builtin_provider.return_value = mock_controller
|
||||
|
||||
assert BuiltinToolManageService.list_builtin_tool_provider_tools("t", "p") == []
|
||||
|
||||
|
||||
class TestGetBuiltinToolProviderInfo:
|
||||
@patch(f"{MODULE}.ToolTransformService")
|
||||
@patch(f"{MODULE}.BuiltinToolManageService.get_builtin_provider")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
def test_raises_when_not_found(self, mock_manager, mock_get, mock_transform):
|
||||
mock_get.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="you have not added provider"):
|
||||
BuiltinToolManageService.get_builtin_tool_provider_info("t", "no")
|
||||
|
||||
@patch(f"{MODULE}.ToolTransformService")
|
||||
@patch(f"{MODULE}.BuiltinToolManageService.get_builtin_provider")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
def test_clears_original_credentials(self, mock_manager, mock_get, mock_transform):
|
||||
mock_get.return_value = MagicMock()
|
||||
entity = MagicMock()
|
||||
mock_transform.builtin_provider_to_user_provider.return_value = entity
|
||||
|
||||
result = BuiltinToolManageService.get_builtin_tool_provider_info("t", "google")
|
||||
|
||||
assert result.original_credentials == {}
|
||||
|
||||
|
||||
class TestListBuiltinProviderCredentialsSchema:
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
def test_returns_schema(self, mock_manager):
|
||||
mock_manager.get_builtin_provider.return_value.get_credentials_schema_by_type.return_value = [{"f": "k"}]
|
||||
|
||||
result = BuiltinToolManageService.list_builtin_provider_credentials_schema("g", "api_key", "t")
|
||||
|
||||
assert result == [{"f": "k"}]
|
||||
|
||||
|
||||
class TestGetBuiltinToolProviderIcon:
|
||||
@patch(f"{MODULE}.Path")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
def test_returns_bytes_and_mime(self, mock_manager, mock_path):
|
||||
mock_manager.get_hardcoded_provider_icon.return_value = ("/icon.svg", "image/svg+xml")
|
||||
mock_path.return_value.read_bytes.return_value = b"<svg/>"
|
||||
|
||||
icon, mime = BuiltinToolManageService.get_builtin_tool_provider_icon("google")
|
||||
|
||||
assert icon == b"<svg/>"
|
||||
assert mime == "image/svg+xml"
|
||||
|
||||
|
||||
class TestIsOauthSystemClientExists:
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_true_when_exists(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
session.query.return_value.filter_by.return_value.first.return_value = MagicMock()
|
||||
|
||||
assert BuiltinToolManageService.is_oauth_system_client_exists("google") is True
|
||||
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_false_when_missing(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
assert BuiltinToolManageService.is_oauth_system_client_exists("google") is False
|
||||
|
||||
|
||||
class TestIsOauthCustomClientEnabled:
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_true_when_enabled(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
session.query.return_value.filter_by.return_value.first.return_value = MagicMock(enabled=True)
|
||||
|
||||
assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is True
|
||||
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_false_when_none(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is False
|
||||
|
||||
|
||||
class TestDeleteBuiltinToolProvider:
|
||||
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_raises_when_not_found(self, mock_db, mock_session_cls, mock_tm, mock_enc):
|
||||
session = _mock_session(mock_session_cls)
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="you have not added provider"):
|
||||
BuiltinToolManageService.delete_builtin_tool_provider("t", "p", "id")
|
||||
|
||||
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_deletes_provider_and_clears_cache(self, mock_db, mock_session_cls, mock_tm, mock_enc):
|
||||
session = _mock_session(mock_session_cls)
|
||||
db_provider = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
mock_cache = MagicMock()
|
||||
mock_enc.return_value = (MagicMock(), mock_cache)
|
||||
|
||||
result = BuiltinToolManageService.delete_builtin_tool_provider("t", "p", "c")
|
||||
|
||||
assert result == {"result": "success"}
|
||||
session.delete.assert_called_once_with(db_provider)
|
||||
session.commit.assert_called_once()
|
||||
mock_cache.delete.assert_called_once()
|
||||
|
||||
|
||||
class TestSetDefaultProvider:
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_raises_when_not_found(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="provider not found"):
|
||||
BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
|
||||
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_sets_default_and_clears_old(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
target = MagicMock()
|
||||
session.query.return_value.filter_by.return_value.first.return_value = target
|
||||
|
||||
result = BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
|
||||
|
||||
assert result == {"result": "success"}
|
||||
assert target.is_default is True
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestUpdateBuiltinToolProvider:
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_raises_when_provider_not_exists(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="you have not added provider"):
|
||||
BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c")
|
||||
|
||||
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
|
||||
@patch(f"{MODULE}.CredentialType")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_updates_credentials_and_commits(self, mock_db, mock_session_cls, mock_tm, mock_cred_type, mock_enc):
|
||||
session = _mock_session(mock_session_cls)
|
||||
db_provider = MagicMock(credential_type="api_key", credentials="{}")
|
||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
|
||||
mock_cred_instance = MagicMock()
|
||||
mock_cred_instance.is_editable.return_value = True
|
||||
mock_cred_instance.is_validate_allowed.return_value = False
|
||||
mock_cred_type.of.return_value = mock_cred_instance
|
||||
|
||||
mock_controller = MagicMock(need_credentials=True)
|
||||
mock_tm.get_builtin_provider.return_value = mock_controller
|
||||
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.decrypt.return_value = {"key": "old"}
|
||||
mock_encrypter.encrypt.return_value = {"key": "new"}
|
||||
mock_cache = MagicMock()
|
||||
mock_enc.return_value = (mock_encrypter, mock_cache)
|
||||
|
||||
result = BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c", credentials={"key": "val"})
|
||||
|
||||
assert result == {"result": "success"}
|
||||
session.commit.assert_called_once()
|
||||
mock_cache.delete.assert_called_once()
|
||||
|
||||
|
||||
class TestGetOauthClientSchema:
|
||||
@patch(f"{MODULE}.BuiltinToolManageService.get_custom_oauth_client_params", return_value={})
|
||||
@patch(f"{MODULE}.BuiltinToolManageService.is_oauth_system_client_exists", return_value=False)
|
||||
@patch(f"{MODULE}.BuiltinToolManageService.is_oauth_custom_client_enabled", return_value=True)
|
||||
@patch(f"{MODULE}.dify_config")
|
||||
@patch(f"{MODULE}.PluginService")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
def test_returns_schema_dict(self, mock_tm, mock_plugin, mock_config, mock_enabled, mock_sys, mock_params):
|
||||
mock_config.CONSOLE_API_URL = "https://api.example.com"
|
||||
mock_controller = MagicMock()
|
||||
mock_controller.get_oauth_client_schema.return_value = []
|
||||
mock_tm.get_builtin_provider.return_value = mock_controller
|
||||
|
||||
result = BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema("t", "google")
|
||||
|
||||
assert "schema" in result
|
||||
assert result["is_oauth_custom_client_enabled"] is True
|
||||
assert "redirect_uri" in result
|
||||
|
||||
|
||||
class TestGetOauthClient:
|
||||
@patch(f"{MODULE}.PluginService")
|
||||
@patch(f"{MODULE}.create_provider_encrypter")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_returns_user_client_params_when_exists(
|
||||
self, mock_db, mock_session_cls, mock_tm, mock_create_enc, mock_plugin
|
||||
):
|
||||
session = _mock_session(mock_session_cls)
|
||||
mock_controller = MagicMock()
|
||||
mock_controller.get_oauth_client_schema.return_value = []
|
||||
mock_tm.get_builtin_provider.return_value = mock_controller
|
||||
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.decrypt.return_value = {"client_id": "id", "client_secret": "secret"}
|
||||
mock_create_enc.return_value = (mock_encrypter, MagicMock())
|
||||
|
||||
user_client = MagicMock(oauth_params='{"encrypted": "data"}')
|
||||
session.query.return_value.filter_by.return_value.first.return_value = user_client
|
||||
|
||||
result = BuiltinToolManageService.get_oauth_client("t", "google")
|
||||
|
||||
assert result == {"client_id": "id", "client_secret": "secret"}
|
||||
|
||||
@patch(f"{MODULE}.decrypt_system_oauth_params", return_value={"sys_key": "sys_val"})
|
||||
@patch(f"{MODULE}.PluginService")
|
||||
@patch(f"{MODULE}.create_provider_encrypter")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_falls_back_to_system_client(
|
||||
self, mock_db, mock_session_cls, mock_tm, mock_create_enc, mock_plugin, mock_decrypt
|
||||
):
|
||||
session = _mock_session(mock_session_cls)
|
||||
mock_controller = MagicMock()
|
||||
mock_controller.get_oauth_client_schema.return_value = []
|
||||
mock_tm.get_builtin_provider.return_value = mock_controller
|
||||
|
||||
mock_create_enc.return_value = (MagicMock(), MagicMock())
|
||||
|
||||
system_client = MagicMock(encrypted_oauth_params="enc")
|
||||
session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
None, # user client
|
||||
system_client, # system client
|
||||
]
|
||||
|
||||
result = BuiltinToolManageService.get_oauth_client("t", "google")
|
||||
|
||||
assert result == {"sys_key": "sys_val"}
|
||||
|
||||
|
||||
class TestSaveCustomOauthClientParams:
|
||||
def test_returns_early_when_no_params(self):
|
||||
result = BuiltinToolManageService.save_custom_oauth_client_params("t", "p")
|
||||
assert result == {"result": "success"}
|
||||
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
def test_raises_when_provider_not_found(self, mock_tm):
|
||||
mock_tm.get_builtin_provider.return_value = None
|
||||
|
||||
with pytest.raises((ValueError, Exception), match="not found|Provider"):
|
||||
BuiltinToolManageService.save_custom_oauth_client_params("t", "p", enable_oauth_custom_client=True)
|
||||
|
||||
|
||||
class TestGetCustomOauthClientParams:
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_returns_empty_when_none(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
result = BuiltinToolManageService.get_custom_oauth_client_params("t", "p")
|
||||
|
||||
assert result == {}
|
||||
|
||||
|
||||
class TestGetBuiltinToolProviderCredentialInfo:
|
||||
@patch(f"{MODULE}.BuiltinToolManageService.is_oauth_custom_client_enabled", return_value=False)
|
||||
@patch(f"{MODULE}.BuiltinToolManageService.get_builtin_tool_provider_credentials", return_value=[])
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
def test_returns_credential_info(self, mock_tm, mock_creds, mock_oauth):
|
||||
mock_tm.get_builtin_provider.return_value.get_supported_credential_types.return_value = ["api-key"]
|
||||
|
||||
result = BuiltinToolManageService.get_builtin_tool_provider_credential_info("t", "google")
|
||||
|
||||
assert result.credentials == []
|
||||
assert result.supported_credential_types == ["api-key"]
|
||||
assert result.is_oauth_custom_client_enabled is False
|
||||
|
||||
|
||||
class TestGetBuiltinToolProviderCredentials:
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_returns_empty_when_no_providers(self, mock_db):
|
||||
mock_db.session.no_autoflush.__enter__ = MagicMock(return_value=None)
|
||||
mock_db.session.no_autoflush.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = []
|
||||
|
||||
result = BuiltinToolManageService.get_builtin_tool_provider_credentials("t", "google")
|
||||
|
||||
assert result == []
|
||||
|
||||
@patch(f"{MODULE}.ToolTransformService")
|
||||
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_returns_credential_entities(self, mock_db, mock_tm, mock_enc, mock_transform):
|
||||
mock_db.session.no_autoflush.__enter__ = MagicMock(return_value=None)
|
||||
mock_db.session.no_autoflush.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
provider = MagicMock(provider="google", is_default=False)
|
||||
mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [provider]
|
||||
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.decrypt.return_value = {"key": "decrypted"}
|
||||
mock_encrypter.mask_plugin_credentials.return_value = {"key": "***"}
|
||||
mock_enc.return_value = (mock_encrypter, MagicMock())
|
||||
|
||||
credential_entity = MagicMock()
|
||||
mock_transform.convert_builtin_provider_to_credential_entity.return_value = credential_entity
|
||||
|
||||
result = BuiltinToolManageService.get_builtin_tool_provider_credentials("t", "google")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0] is credential_entity
|
||||
assert provider.is_default is True
|
||||
|
||||
|
||||
class TestGetBuiltinProvider:
|
||||
@patch(f"{MODULE}.ToolProviderID")
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_returns_none_when_not_found(self, mock_db, mock_session_cls, mock_prov_id):
|
||||
session = _mock_session(mock_session_cls)
|
||||
mock_prov_id.return_value.provider_name = "google"
|
||||
mock_prov_id.return_value.organization = "langgenius"
|
||||
session.query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
|
||||
result = BuiltinToolManageService.get_builtin_provider("google", "t")
|
||||
|
||||
assert result is None
|
||||
|
||||
@patch(f"{MODULE}.ToolProviderID")
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_returns_provider_for_langgenius_org(self, mock_db, mock_session_cls, mock_prov_id):
|
||||
session = _mock_session(mock_session_cls)
|
||||
mock_prov_id.return_value.provider_name = "google"
|
||||
mock_prov_id.return_value.organization = "langgenius"
|
||||
db_provider = MagicMock(provider="google")
|
||||
mock_prov_id_result = MagicMock()
|
||||
mock_prov_id_result.to_string.return_value = "langgenius/google/google"
|
||||
|
||||
def prov_id_side_effect(name):
|
||||
m = MagicMock()
|
||||
m.provider_name = "google"
|
||||
m.organization = "langgenius"
|
||||
m.to_string.return_value = "langgenius/google/google"
|
||||
m.plugin_id = "langgenius/google"
|
||||
return m
|
||||
|
||||
mock_prov_id.side_effect = prov_id_side_effect
|
||||
session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider
|
||||
|
||||
result = BuiltinToolManageService.get_builtin_provider("google", "t")
|
||||
|
||||
assert result is db_provider
|
||||
|
||||
@patch(f"{MODULE}.ToolProviderID")
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_returns_provider_for_non_langgenius_org(self, mock_db, mock_session_cls, mock_prov_id):
|
||||
session = _mock_session(mock_session_cls)
|
||||
|
||||
def prov_id_side_effect(name):
|
||||
m = MagicMock()
|
||||
m.provider_name = "custom-tool"
|
||||
m.organization = "third-party"
|
||||
m.to_string.return_value = "third-party/custom/custom-tool"
|
||||
m.plugin_id = "third-party/custom"
|
||||
return m
|
||||
|
||||
mock_prov_id.side_effect = prov_id_side_effect
|
||||
db_provider = MagicMock(provider="third-party/custom/custom-tool")
|
||||
session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider
|
||||
|
||||
result = BuiltinToolManageService.get_builtin_provider("third-party/custom/custom-tool", "t")
|
||||
|
||||
assert result is db_provider
|
||||
|
||||
@patch(f"{MODULE}.ToolProviderID")
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_falls_back_on_exception(self, mock_db, mock_session_cls, mock_prov_id):
|
||||
session = _mock_session(mock_session_cls)
|
||||
mock_prov_id.side_effect = Exception("parse error")
|
||||
fallback = MagicMock()
|
||||
session.query.return_value.where.return_value.order_by.return_value.first.return_value = fallback
|
||||
|
||||
result = BuiltinToolManageService.get_builtin_provider("old-provider", "t")
|
||||
|
||||
assert result is fallback
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
from services.tools.tool_labels_service import ToolLabelsService
|
||||
|
||||
|
||||
def test_list_tool_labels_returns_default_labels():
|
||||
result = ToolLabelsService.list_tool_labels()
|
||||
assert isinstance(result, list)
|
||||
assert len(result) > 0
|
||||
|
||||
|
||||
def test_list_tool_labels_items_are_tool_labels():
|
||||
from core.tools.entities.tool_entities import ToolLabel
|
||||
|
||||
result = ToolLabelsService.list_tool_labels()
|
||||
for label in result:
|
||||
assert isinstance(label, ToolLabel)
|
||||
|
||||
|
||||
def test_list_tool_labels_matches_default_values():
|
||||
from core.tools.entities.values import default_tool_labels
|
||||
|
||||
assert ToolLabelsService.list_tool_labels() is default_tool_labels
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from services.tools.tools_manage_service import ToolCommonService
|
||||
|
||||
|
||||
class TestToolCommonService:
|
||||
@patch("services.tools.tools_manage_service.ToolTransformService")
|
||||
@patch("services.tools.tools_manage_service.ToolManager")
|
||||
def test_list_tool_providers_transforms_and_returns(self, mock_manager, mock_transform):
|
||||
mock_provider1 = MagicMock()
|
||||
mock_provider1.to_dict.return_value = {"name": "provider1"}
|
||||
mock_provider2 = MagicMock()
|
||||
mock_provider2.to_dict.return_value = {"name": "provider2"}
|
||||
mock_manager.list_providers_from_api.return_value = [mock_provider1, mock_provider2]
|
||||
|
||||
result = ToolCommonService.list_tool_providers("user-1", "tenant-1")
|
||||
|
||||
mock_manager.list_providers_from_api.assert_called_once_with("user-1", "tenant-1", None)
|
||||
assert mock_transform.repack_provider.call_count == 2
|
||||
assert result == [{"name": "provider1"}, {"name": "provider2"}]
|
||||
|
||||
@patch("services.tools.tools_manage_service.ToolTransformService")
|
||||
@patch("services.tools.tools_manage_service.ToolManager")
|
||||
def test_list_tool_providers_with_type_filter(self, mock_manager, mock_transform):
|
||||
mock_manager.list_providers_from_api.return_value = []
|
||||
|
||||
result = ToolCommonService.list_tool_providers("user-1", "tenant-1", typ="builtin")
|
||||
|
||||
mock_manager.list_providers_from_api.assert_called_once_with("user-1", "tenant-1", "builtin")
|
||||
assert result == []
|
||||
|
||||
@patch("services.tools.tools_manage_service.ToolTransformService")
|
||||
@patch("services.tools.tools_manage_service.ToolManager")
|
||||
def test_list_tool_providers_empty(self, mock_manager, mock_transform):
|
||||
mock_manager.list_providers_from_api.return_value = []
|
||||
|
||||
result = ToolCommonService.list_tool_providers("u", "t")
|
||||
|
||||
assert result == []
|
||||
mock_transform.repack_provider.assert_not_called()
|
||||
|
|
@ -0,0 +1,110 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.workflow.queue_dispatcher import (
|
||||
BaseQueueDispatcher,
|
||||
ProfessionalQueueDispatcher,
|
||||
QueueDispatcherManager,
|
||||
QueuePriority,
|
||||
SandboxQueueDispatcher,
|
||||
TeamQueueDispatcher,
|
||||
)
|
||||
|
||||
|
||||
class TestQueuePriority:
|
||||
def test_priority_values(self):
|
||||
assert QueuePriority.PROFESSIONAL == "workflow_professional"
|
||||
assert QueuePriority.TEAM == "workflow_team"
|
||||
assert QueuePriority.SANDBOX == "workflow_sandbox"
|
||||
|
||||
|
||||
class TestDispatchers:
|
||||
def test_professional_dispatcher(self):
|
||||
d = ProfessionalQueueDispatcher()
|
||||
assert d.get_queue_name() == QueuePriority.PROFESSIONAL
|
||||
assert d.get_priority() == 100
|
||||
|
||||
def test_team_dispatcher(self):
|
||||
d = TeamQueueDispatcher()
|
||||
assert d.get_queue_name() == QueuePriority.TEAM
|
||||
assert d.get_priority() == 50
|
||||
|
||||
def test_sandbox_dispatcher(self):
|
||||
d = SandboxQueueDispatcher()
|
||||
assert d.get_queue_name() == QueuePriority.SANDBOX
|
||||
assert d.get_priority() == 10
|
||||
|
||||
def test_base_dispatcher_is_abstract(self):
|
||||
with pytest.raises(TypeError):
|
||||
BaseQueueDispatcher()
|
||||
|
||||
|
||||
class TestQueueDispatcherManager:
|
||||
@patch("services.workflow.queue_dispatcher.BillingService")
|
||||
@patch("services.workflow.queue_dispatcher.dify_config")
|
||||
def test_billing_enabled_professional_plan(self, mock_config, mock_billing):
|
||||
mock_config.BILLING_ENABLED = True
|
||||
mock_billing.get_info.return_value = {"subscription": {"plan": "professional"}}
|
||||
|
||||
dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
|
||||
|
||||
assert isinstance(dispatcher, ProfessionalQueueDispatcher)
|
||||
|
||||
@patch("services.workflow.queue_dispatcher.BillingService")
|
||||
@patch("services.workflow.queue_dispatcher.dify_config")
|
||||
def test_billing_enabled_team_plan(self, mock_config, mock_billing):
|
||||
mock_config.BILLING_ENABLED = True
|
||||
mock_billing.get_info.return_value = {"subscription": {"plan": "team"}}
|
||||
|
||||
dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
|
||||
|
||||
assert isinstance(dispatcher, TeamQueueDispatcher)
|
||||
|
||||
@patch("services.workflow.queue_dispatcher.BillingService")
|
||||
@patch("services.workflow.queue_dispatcher.dify_config")
|
||||
def test_billing_enabled_sandbox_plan(self, mock_config, mock_billing):
|
||||
mock_config.BILLING_ENABLED = True
|
||||
mock_billing.get_info.return_value = {"subscription": {"plan": "sandbox"}}
|
||||
|
||||
dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
|
||||
|
||||
assert isinstance(dispatcher, SandboxQueueDispatcher)
|
||||
|
||||
@patch("services.workflow.queue_dispatcher.BillingService")
|
||||
@patch("services.workflow.queue_dispatcher.dify_config")
|
||||
def test_billing_enabled_unknown_plan_defaults_to_sandbox(self, mock_config, mock_billing):
|
||||
mock_config.BILLING_ENABLED = True
|
||||
mock_billing.get_info.return_value = {"subscription": {"plan": "enterprise"}}
|
||||
|
||||
dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
|
||||
|
||||
assert isinstance(dispatcher, SandboxQueueDispatcher)
|
||||
|
||||
@patch("services.workflow.queue_dispatcher.BillingService")
|
||||
@patch("services.workflow.queue_dispatcher.dify_config")
|
||||
def test_billing_enabled_service_failure_defaults_to_sandbox(self, mock_config, mock_billing):
|
||||
mock_config.BILLING_ENABLED = True
|
||||
mock_billing.get_info.side_effect = Exception("billing unavailable")
|
||||
|
||||
dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
|
||||
|
||||
assert isinstance(dispatcher, SandboxQueueDispatcher)
|
||||
|
||||
@patch("services.workflow.queue_dispatcher.dify_config")
|
||||
def test_billing_disabled_defaults_to_team(self, mock_config):
|
||||
mock_config.BILLING_ENABLED = False
|
||||
|
||||
dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
|
||||
|
||||
assert isinstance(dispatcher, TeamQueueDispatcher)
|
||||
|
||||
@patch("services.workflow.queue_dispatcher.BillingService")
|
||||
@patch("services.workflow.queue_dispatcher.dify_config")
|
||||
def test_missing_subscription_key_defaults_to_sandbox(self, mock_config, mock_billing):
|
||||
mock_config.BILLING_ENABLED = True
|
||||
mock_billing.get_info.return_value = {}
|
||||
|
||||
dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1")
|
||||
|
||||
assert isinstance(dispatcher, SandboxQueueDispatcher)
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
import pytest
|
||||
|
||||
from services.workflow.entities import WorkflowScheduleCFSPlanEntity
|
||||
from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand
|
||||
|
||||
|
||||
class TestSchedulerCommand:
|
||||
def test_enum_values(self):
|
||||
assert SchedulerCommand.RESOURCE_LIMIT_REACHED == "resource_limit_reached"
|
||||
assert SchedulerCommand.NONE == "none"
|
||||
|
||||
def test_enum_is_str(self):
|
||||
for member in SchedulerCommand:
|
||||
assert isinstance(member, str)
|
||||
|
||||
|
||||
class TestCFSPlanScheduler:
|
||||
def test_stores_plan(self):
|
||||
plan = WorkflowScheduleCFSPlanEntity(
|
||||
schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.Nop,
|
||||
granularity=-1,
|
||||
)
|
||||
|
||||
class ConcretePlanScheduler(CFSPlanScheduler):
|
||||
def can_schedule(self):
|
||||
return SchedulerCommand.NONE
|
||||
|
||||
scheduler = ConcretePlanScheduler(plan)
|
||||
|
||||
assert scheduler.plan is plan
|
||||
assert scheduler.plan.schedule_strategy == WorkflowScheduleCFSPlanEntity.Strategy.Nop
|
||||
assert scheduler.plan.granularity == -1
|
||||
|
||||
def test_cannot_instantiate_abstract(self):
|
||||
plan = WorkflowScheduleCFSPlanEntity(
|
||||
schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice,
|
||||
granularity=10,
|
||||
)
|
||||
with pytest.raises(TypeError):
|
||||
CFSPlanScheduler(plan)
|
||||
|
||||
def test_concrete_subclass_can_schedule(self):
|
||||
plan = WorkflowScheduleCFSPlanEntity(
|
||||
schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice,
|
||||
granularity=5,
|
||||
)
|
||||
|
||||
class TimedScheduler(CFSPlanScheduler):
|
||||
def can_schedule(self):
|
||||
if self.plan.granularity > 0:
|
||||
return SchedulerCommand.NONE
|
||||
return SchedulerCommand.RESOURCE_LIMIT_REACHED
|
||||
|
||||
scheduler = TimedScheduler(plan)
|
||||
assert scheduler.can_schedule() == SchedulerCommand.NONE
|
||||
|
||||
def test_concrete_subclass_resource_limit(self):
|
||||
plan = WorkflowScheduleCFSPlanEntity(
|
||||
schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice,
|
||||
granularity=-1,
|
||||
)
|
||||
|
||||
class TimedScheduler(CFSPlanScheduler):
|
||||
def can_schedule(self):
|
||||
if self.plan.granularity > 0:
|
||||
return SchedulerCommand.NONE
|
||||
return SchedulerCommand.RESOURCE_LIMIT_REACHED
|
||||
|
||||
scheduler = TimedScheduler(plan)
|
||||
assert scheduler.can_schedule() == SchedulerCommand.RESOURCE_LIMIT_REACHED
|
||||
|
||||
|
||||
class TestWorkflowScheduleCFSPlanEntity:
|
||||
def test_strategy_values(self):
|
||||
assert WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice == "time-slice"
|
||||
assert WorkflowScheduleCFSPlanEntity.Strategy.Nop == "nop"
|
||||
|
||||
def test_default_granularity(self):
|
||||
plan = WorkflowScheduleCFSPlanEntity(
|
||||
schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.Nop,
|
||||
)
|
||||
assert plan.granularity == -1
|
||||
|
||||
def test_explicit_granularity(self):
|
||||
plan = WorkflowScheduleCFSPlanEntity(
|
||||
schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice,
|
||||
granularity=100,
|
||||
)
|
||||
assert plan.granularity == 100
|
||||
|
|
@ -41,6 +41,23 @@ class TestWorkflowService:
|
|||
workflows.append(workflow)
|
||||
return workflows
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_session_cls(self):
|
||||
class DummySession:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.commit = MagicMock()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def begin(self):
|
||||
return nullcontext()
|
||||
|
||||
return DummySession
|
||||
|
||||
def test_get_all_published_workflow_no_workflow_id(self, workflow_service, mock_app):
|
||||
mock_app.workflow_id = None
|
||||
mock_session = MagicMock()
|
||||
|
|
@ -170,7 +187,10 @@ class TestWorkflowService:
|
|||
mock_session.scalars.assert_called_once()
|
||||
|
||||
def test_submit_human_input_form_preview_uses_rendered_content(
|
||||
self, workflow_service: WorkflowService, monkeypatch: pytest.MonkeyPatch
|
||||
self,
|
||||
workflow_service: WorkflowService,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
dummy_session_cls,
|
||||
) -> None:
|
||||
service = workflow_service
|
||||
node_data = HumanInputNodeData(
|
||||
|
|
@ -197,19 +217,6 @@ class TestWorkflowService:
|
|||
|
||||
saved_outputs: dict[str, object] = {}
|
||||
|
||||
class DummySession:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.commit = MagicMock()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def begin(self):
|
||||
return nullcontext()
|
||||
|
||||
class DummySaver:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
|
@ -217,7 +224,7 @@ class TestWorkflowService:
|
|||
def save(self, outputs, process_data):
|
||||
saved_outputs.update(outputs)
|
||||
|
||||
monkeypatch.setattr(workflow_service_module, "Session", DummySession)
|
||||
monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls)
|
||||
monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", DummySaver)
|
||||
monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
|
|
@ -278,7 +285,6 @@ class TestWorkflowService:
|
|||
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
account = SimpleNamespace(id="account-1")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
service.submit_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
|
|
@ -290,3 +296,115 @@ class TestWorkflowService:
|
|||
)
|
||||
|
||||
assert "Missing required inputs" in str(exc_info.value)
|
||||
|
||||
def test_run_draft_workflow_node_successful_behavior(
|
||||
self, workflow_service, mock_app, monkeypatch, dummy_session_cls
|
||||
):
|
||||
"""Behavior: When a basic workflow node runs, it correctly sets up context,
|
||||
executes the node, and saves outputs."""
|
||||
service = workflow_service
|
||||
account = SimpleNamespace(id="account-1")
|
||||
mock_workflow = MagicMock()
|
||||
mock_workflow.id = "wf-1"
|
||||
mock_workflow.tenant_id = "tenant-1"
|
||||
mock_workflow.environment_variables = []
|
||||
mock_workflow.conversation_variables = []
|
||||
|
||||
# Mock node config
|
||||
mock_workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "llm"}}
|
||||
mock_workflow.get_enclosing_node_type_and_id.return_value = None
|
||||
|
||||
# Mock class methods
|
||||
monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock())
|
||||
monkeypatch.setattr(workflow_service_module, "DraftVarLoader", MagicMock())
|
||||
|
||||
# Mock workflow entry execution
|
||||
mock_node_exec = MagicMock()
|
||||
mock_node_exec.id = "exec-1"
|
||||
mock_node_exec.process_data = {}
|
||||
mock_run = MagicMock()
|
||||
monkeypatch.setattr(workflow_service_module.WorkflowEntry, "single_step_run", mock_run)
|
||||
|
||||
# Mock execution handling
|
||||
service._handle_single_step_result = MagicMock(return_value=mock_node_exec)
|
||||
|
||||
# Mock repository
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_execution_by_id.return_value = mock_node_exec
|
||||
mock_repo_factory = MagicMock(return_value=mock_repo)
|
||||
monkeypatch.setattr(
|
||||
workflow_service_module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_node_execution_repository",
|
||||
mock_repo_factory,
|
||||
)
|
||||
service._node_execution_service_repo = mock_repo
|
||||
|
||||
# Set up node execution service repo mock to return our exec node
|
||||
mock_node_exec.load_full_outputs.return_value = {"output_var": "result_value"}
|
||||
mock_node_exec.node_id = "node-1"
|
||||
mock_node_exec.node_type = "llm"
|
||||
|
||||
# Mock draft variable saver
|
||||
mock_saver = MagicMock()
|
||||
monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", MagicMock(return_value=mock_saver))
|
||||
|
||||
# Mock DB
|
||||
monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls)
|
||||
|
||||
# Act
|
||||
result = service.run_draft_workflow_node(
|
||||
app_model=mock_app,
|
||||
draft_workflow=mock_workflow,
|
||||
node_id="node-1",
|
||||
user_inputs={"input_val": "test"},
|
||||
account=account,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == mock_node_exec
|
||||
service._handle_single_step_result.assert_called_once()
|
||||
mock_repo.save.assert_called_once_with(mock_node_exec)
|
||||
mock_saver.save.assert_called_once_with(process_data={}, outputs={"output_var": "result_value"})
|
||||
|
||||
def test_run_draft_workflow_node_failure_behavior(self, workflow_service, mock_app, monkeypatch, dummy_session_cls):
|
||||
"""Behavior: If retrieving the saved execution fails, an appropriate error bubble matches expectations."""
|
||||
service = workflow_service
|
||||
account = SimpleNamespace(id="account-1")
|
||||
mock_workflow = MagicMock()
|
||||
mock_workflow.tenant_id = "tenant-1"
|
||||
mock_workflow.environment_variables = []
|
||||
mock_workflow.conversation_variables = []
|
||||
mock_workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "llm"}}
|
||||
mock_workflow.get_enclosing_node_type_and_id.return_value = None
|
||||
|
||||
monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock())
|
||||
monkeypatch.setattr(workflow_service_module, "DraftVarLoader", MagicMock())
|
||||
monkeypatch.setattr(workflow_service_module.WorkflowEntry, "single_step_run", MagicMock())
|
||||
|
||||
mock_node_exec = MagicMock()
|
||||
mock_node_exec.id = "exec-invalid"
|
||||
service._handle_single_step_result = MagicMock(return_value=mock_node_exec)
|
||||
|
||||
mock_repo = MagicMock()
|
||||
mock_repo_factory = MagicMock(return_value=mock_repo)
|
||||
monkeypatch.setattr(
|
||||
workflow_service_module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_node_execution_repository",
|
||||
mock_repo_factory,
|
||||
)
|
||||
service._node_execution_service_repo = mock_repo
|
||||
|
||||
# Simulate failure to retrieve the saved execution
|
||||
mock_repo.get_execution_by_id.return_value = None
|
||||
|
||||
monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="WorkflowNodeExecution with id exec-invalid not found after saving"):
|
||||
service.run_draft_workflow_node(
|
||||
app_model=mock_app, draft_workflow=mock_workflow, node_id="node-1", user_inputs={}, account=account
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue