diff --git a/api/core/datasource/__base/datasource_provider.py b/api/core/datasource/__base/datasource_provider.py index bae39dc8c7..4b47777f0b 100644 --- a/api/core/datasource/__base/datasource_provider.py +++ b/api/core/datasource/__base/datasource_provider.py @@ -59,8 +59,6 @@ class DatasourcePluginProviderController(ABC): :param credentials: the credentials of the tool """ credentials_schema = dict[str, ProviderConfig]() - if credentials_schema is None: - return for credential in self.entity.credentials_schema: credentials_schema[credential.name] = credential diff --git a/api/tests/unit_tests/core/datasource/__base/test_datasource_plugin.py b/api/tests/unit_tests/core/datasource/__base/test_datasource_plugin.py new file mode 100644 index 0000000000..5482b4db52 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/__base/test_datasource_plugin.py @@ -0,0 +1,90 @@ +from unittest.mock import MagicMock, patch + +from configs import dify_config +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceProviderType + + +class ConcreteDatasourcePlugin(DatasourcePlugin): + """ + Concrete implementation of DatasourcePlugin for testing purposes. + Since DatasourcePlugin is an ABC, we need a concrete class to instantiate it. + """ + + def datasource_provider_type(self) -> str: + return DatasourceProviderType.LOCAL_FILE + + +class TestDatasourcePlugin: + def test_init(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + icon = "test-icon.png" + + # Act + plugin = ConcreteDatasourcePlugin(entity=entity, runtime=runtime, icon=icon) + + # Assert + assert plugin.entity == entity + assert plugin.runtime == runtime + assert plugin.icon == icon + + def test_datasource_provider_type(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + icon = "test-icon.png" + plugin = ConcreteDatasourcePlugin(entity=entity, runtime=runtime, icon=icon) + + # Act + provider_type = plugin.datasource_provider_type() + # Call the base class method to ensure it's covered + base_provider_type = DatasourcePlugin.datasource_provider_type(plugin) + + # Assert + assert provider_type == DatasourceProviderType.LOCAL_FILE + assert base_provider_type == DatasourceProviderType.LOCAL_FILE + + def test_fork_datasource_runtime(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceEntity) + mock_entity_copy = MagicMock(spec=DatasourceEntity) + mock_entity.model_copy.return_value = mock_entity_copy + + runtime = MagicMock(spec=DatasourceRuntime) + new_runtime = MagicMock(spec=DatasourceRuntime) + icon = "test-icon.png" + + plugin = ConcreteDatasourcePlugin(entity=mock_entity, runtime=runtime, icon=icon) + + # Act + new_plugin = plugin.fork_datasource_runtime(new_runtime) + + # Assert + assert isinstance(new_plugin, ConcreteDatasourcePlugin) + assert new_plugin.entity == mock_entity_copy + assert new_plugin.runtime == new_runtime + assert new_plugin.icon == icon + mock_entity.model_copy.assert_called_once() + + def test_get_icon_url(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + icon = "test-icon.png" + tenant_id = "test-tenant-id" + + plugin = ConcreteDatasourcePlugin(entity=entity, runtime=runtime, icon=icon) + + # Mocking dify_config.CONSOLE_API_URL + with patch.object(dify_config, "CONSOLE_API_URL", "https://api.dify.ai"): + # Act + icon_url = plugin.get_icon_url(tenant_id) + + # Assert + expected_url = ( + f"https://api.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={icon}" + ) + assert icon_url == expected_url diff --git a/api/tests/unit_tests/core/datasource/__base/test_datasource_provider.py b/api/tests/unit_tests/core/datasource/__base/test_datasource_provider.py new file mode 100644 index 0000000000..6a3d21a33d --- /dev/null +++ b/api/tests/unit_tests/core/datasource/__base/test_datasource_provider.py @@ -0,0 +1,265 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController +from core.datasource.entities.datasource_entities import ( + DatasourceProviderEntityWithPlugin, + DatasourceProviderType, +) +from core.entities.provider_entities import ProviderConfig +from core.tools.errors import ToolProviderCredentialValidationError + + +class ConcreteDatasourcePluginProviderController(DatasourcePluginProviderController): + """ + Concrete implementation of DatasourcePluginProviderController for testing purposes. + """ + + def get_datasource(self, datasource_name: str) -> DatasourcePlugin: + return MagicMock(spec=DatasourcePlugin) + + +class TestDatasourcePluginProviderController: + def test_init(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + tenant_id = "test-tenant-id" + + # Act + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id=tenant_id) + + # Assert + assert controller.entity == mock_entity + assert controller.tenant_id == tenant_id + + def test_need_credentials(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + tenant_id = "test-tenant-id" + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id=tenant_id) + + # Case 1: credentials_schema is None + mock_entity.credentials_schema = None + assert controller.need_credentials is False + + # Case 2: credentials_schema is empty + mock_entity.credentials_schema = [] + assert controller.need_credentials is False + + # Case 3: credentials_schema has items + mock_entity.credentials_schema = [MagicMock()] + assert controller.need_credentials is True + + @patch("core.datasource.__base.datasource_provider.PluginToolManager") + def test_validate_credentials(self, mock_manager_class): + # Arrange + mock_manager = mock_manager_class.return_value + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.identity = MagicMock() + mock_entity.identity.name = "test-provider" + tenant_id = "test-tenant-id" + user_id = "test-user-id" + credentials = {"api_key": "secret"} + + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id=tenant_id) + + # Act: Successful validation + mock_manager.validate_datasource_credentials.return_value = True + controller._validate_credentials(user_id, credentials) + + mock_manager.validate_datasource_credentials.assert_called_once_with( + tenant_id=tenant_id, + user_id=user_id, + provider="test-provider", + credentials=credentials, + ) + + # Act: Failed validation + mock_manager.validate_datasource_credentials.return_value = False + with pytest.raises(ToolProviderCredentialValidationError, match="Invalid credentials"): + controller._validate_credentials(user_id, credentials) + + def test_provider_type(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act & Assert + assert controller.provider_type == DatasourceProviderType.LOCAL_FILE + + def test_validate_credentials_format_empty_schema(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + credentials = {} + + # Act & Assert (Should not raise anything) + controller.validate_credentials_format(credentials) + + def test_validate_credentials_format_unknown_credential(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.identity = MagicMock() + mock_entity.identity.name = "test-provider" + mock_entity.credentials_schema = [] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + credentials = {"unknown": "value"} + + # Act & Assert + with pytest.raises( + ToolProviderCredentialValidationError, match="credential unknown not found in provider test-provider" + ): + controller.validate_credentials_format(credentials) + + def test_validate_credentials_format_required_missing(self): + # Arrange + mock_config = MagicMock(spec=ProviderConfig) + mock_config.name = "api_key" + mock_config.required = True + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act & Assert + with pytest.raises(ToolProviderCredentialValidationError, match="credential api_key is required"): + controller.validate_credentials_format({}) + + def test_validate_credentials_format_not_required_null(self): + # Arrange + mock_config = MagicMock(spec=ProviderConfig) + mock_config.name = "optional" + mock_config.required = False + mock_config.default = None + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act & Assert + credentials = {"optional": None} + controller.validate_credentials_format(credentials) + assert credentials["optional"] is None + + def test_validate_credentials_format_type_mismatch_text(self): + # Arrange + mock_config = MagicMock(spec=ProviderConfig) + mock_config.name = "text_field" + mock_config.required = True + mock_config.type = ProviderConfig.Type.TEXT_INPUT + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act & Assert + with pytest.raises(ToolProviderCredentialValidationError, match="credential text_field should be string"): + controller.validate_credentials_format({"text_field": 123}) + + def test_validate_credentials_format_select_validation(self): + # Arrange + mock_option = MagicMock() + mock_option.value = "opt1" + + mock_config = MagicMock(spec=ProviderConfig) + mock_config.name = "select_field" + mock_config.required = True + mock_config.type = ProviderConfig.Type.SELECT + mock_config.options = [mock_option] + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Case 1: Value not string + with pytest.raises(ToolProviderCredentialValidationError, match="credential select_field should be string"): + controller.validate_credentials_format({"select_field": 123}) + + # Case 2: Options not list + mock_config.options = "invalid" + with pytest.raises( + ToolProviderCredentialValidationError, match="credential select_field options should be list" + ): + controller.validate_credentials_format({"select_field": "opt1"}) + + # Case 3: Value not in options + mock_config.options = [mock_option] + with pytest.raises(ToolProviderCredentialValidationError, match="credential select_field should be one of"): + controller.validate_credentials_format({"select_field": "invalid_opt"}) + + def test_get_datasource_base(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act + result = DatasourcePluginProviderController.get_datasource(controller, "test") + + # Assert + assert result is None + + def test_validate_credentials_format_hits_pop(self): + # Arrange + mock_config = MagicMock(spec=ProviderConfig) + mock_config.name = "valid_field" + mock_config.required = True + mock_config.type = ProviderConfig.Type.TEXT_INPUT + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act + credentials = {"valid_field": "valid_value"} + controller.validate_credentials_format(credentials) + + # Assert + assert "valid_field" in credentials + assert credentials["valid_field"] == "valid_value" + + def test_validate_credentials_format_hits_continue(self): + # Arrange + mock_config = MagicMock(spec=ProviderConfig) + mock_config.name = "optional_field" + mock_config.required = False + mock_config.default = None + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act + credentials = {"optional_field": None} + controller.validate_credentials_format(credentials) + + # Assert + assert credentials["optional_field"] is None + + def test_validate_credentials_format_default_values(self): + # Arrange + mock_config_text = MagicMock(spec=ProviderConfig) + mock_config_text.name = "text_def" + mock_config_text.required = False + mock_config_text.type = ProviderConfig.Type.TEXT_INPUT + mock_config_text.default = 123 # Int default, should be converted to str + + mock_config_other = MagicMock(spec=ProviderConfig) + mock_config_other.name = "other_def" + mock_config_other.required = False + mock_config_other.type = "OTHER" + mock_config_other.default = "fallback" + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config_text, mock_config_other] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act + credentials = {} + controller.validate_credentials_format(credentials) + + # Assert + assert credentials["text_def"] == "123" + assert credentials["other_def"] == "fallback" diff --git a/api/tests/unit_tests/core/datasource/__base/test_datasource_runtime.py b/api/tests/unit_tests/core/datasource/__base/test_datasource_runtime.py new file mode 100644 index 0000000000..2bca9155e9 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/__base/test_datasource_runtime.py @@ -0,0 +1,26 @@ +from core.app.entities.app_invoke_entities import InvokeFrom +from core.datasource.__base.datasource_runtime import DatasourceRuntime, FakeDatasourceRuntime +from core.datasource.entities.datasource_entities import DatasourceInvokeFrom + + +class TestDatasourceRuntime: + def test_init(self): + runtime = DatasourceRuntime( + tenant_id="test-tenant", + datasource_id="test-ds", + invoke_from=InvokeFrom.DEBUGGER, + datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE, + credentials={"key": "val"}, + runtime_parameters={"p": "v"}, + ) + assert runtime.tenant_id == "test-tenant" + assert runtime.datasource_id == "test-ds" + assert runtime.credentials["key"] == "val" + + def test_fake_datasource_runtime(self): + # This covers the FakeDatasourceRuntime class and its __init__ + runtime = FakeDatasourceRuntime() + assert runtime.tenant_id == "fake_tenant_id" + assert runtime.datasource_id == "fake_datasource_id" + assert runtime.invoke_from == InvokeFrom.DEBUGGER + assert runtime.datasource_invoke_from == DatasourceInvokeFrom.RAG_PIPELINE diff --git a/api/tests/unit_tests/core/datasource/entities/test_api_entities.py b/api/tests/unit_tests/core/datasource/entities/test_api_entities.py new file mode 100644 index 0000000000..9855b4040a --- /dev/null +++ b/api/tests/unit_tests/core/datasource/entities/test_api_entities.py @@ -0,0 +1,150 @@ +from core.datasource.entities.api_entities import DatasourceApiEntity, DatasourceProviderApiEntity +from core.datasource.entities.datasource_entities import DatasourceParameter +from core.tools.entities.common_entities import I18nObject + + +def test_datasource_api_entity(): + label = I18nObject(en_US="label", zh_Hans="标签") + description = I18nObject(en_US="desc", zh_Hans="描述") + + entity = DatasourceApiEntity( + author="author", name="name", label=label, description=description, labels=["l1", "l2"] + ) + + assert entity.author == "author" + assert entity.name == "name" + assert entity.label == label + assert entity.description == description + assert entity.labels == ["l1", "l2"] + assert entity.parameters is None + assert entity.output_schema is None + + +def test_datasource_provider_api_entity_defaults(): + description = I18nObject(en_US="desc", zh_Hans="描述") + label = I18nObject(en_US="label", zh_Hans="标签") + + entity = DatasourceProviderApiEntity( + id="id", author="author", name="name", description=description, icon="icon", label=label, type="type" + ) + + assert entity.id == "id" + assert entity.datasources == [] + assert entity.is_team_authorization is False + assert entity.allow_delete is True + assert entity.plugin_id == "" + assert entity.plugin_unique_identifier == "" + assert entity.labels == [] + + +def test_datasource_provider_api_entity_convert_none_to_empty_list(): + description = I18nObject(en_US="desc", zh_Hans="描述") + label = I18nObject(en_US="label", zh_Hans="标签") + + # Implicitly testing the field_validator "convert_none_to_empty_list" + entity = DatasourceProviderApiEntity( + id="id", + author="author", + name="name", + description=description, + icon="icon", + label=label, + type="type", + datasources=None, # type: ignore + ) + + assert entity.datasources == [] + + +def test_datasource_provider_api_entity_to_dict(): + description = I18nObject(en_US="desc", zh_Hans="描述") + label = I18nObject(en_US="label", zh_Hans="标签") + + # Create a parameter that should be converted + param = DatasourceParameter.get_simple_instance( + name="test_param", typ=DatasourceParameter.DatasourceParameterType.SYSTEM_FILES, required=True + ) + + ds_entity = DatasourceApiEntity( + author="author", name="ds_name", label=label, description=description, parameters=[param] + ) + + provider_entity = DatasourceProviderApiEntity( + id="id", + author="author", + name="name", + description=description, + icon="icon", + label=label, + type="type", + masked_credentials={"key": "masked"}, + datasources=[ds_entity], + labels=["l1"], + ) + + result = provider_entity.to_dict() + + assert result["id"] == "id" + assert result["author"] == "author" + assert result["name"] == "name" + assert result["description"] == description.to_dict() + assert result["icon"] == "icon" + assert result["label"] == label.to_dict() + assert result["type"] == "type" + assert result["team_credentials"] == {"key": "masked"} + assert result["is_team_authorization"] is False + assert result["allow_delete"] is True + assert result["labels"] == ["l1"] + + # Check if parameter type was converted from SYSTEM_FILES to files + assert result["datasources"][0]["parameters"][0]["type"] == "files" + + +def test_datasource_provider_api_entity_to_dict_no_params(): + description = I18nObject(en_US="desc", zh_Hans="描述") + label = I18nObject(en_US="label", zh_Hans="标签") + + ds_entity = DatasourceApiEntity( + author="author", name="ds_name", label=label, description=description, parameters=None + ) + + provider_entity = DatasourceProviderApiEntity( + id="id", + author="author", + name="name", + description=description, + icon="icon", + label=label, + type="type", + datasources=[ds_entity], + ) + + result = provider_entity.to_dict() + assert result["datasources"][0]["parameters"] is None + + +def test_datasource_provider_api_entity_to_dict_other_param_type(): + description = I18nObject(en_US="desc", zh_Hans="描述") + label = I18nObject(en_US="label", zh_Hans="标签") + + param = DatasourceParameter.get_simple_instance( + name="test_param", typ=DatasourceParameter.DatasourceParameterType.STRING, required=True + ) + + ds_entity = DatasourceApiEntity( + author="author", name="ds_name", label=label, description=description, parameters=[param] + ) + + provider_entity = DatasourceProviderApiEntity( + id="id", + author="author", + name="name", + description=description, + icon="icon", + label=label, + type="type", + datasources=[ds_entity], + ) + + result = provider_entity.to_dict() + assert result["datasources"][0]["parameters"][0]["type"] == "string" diff --git a/api/tests/unit_tests/core/datasource/entities/test_common_entities.py b/api/tests/unit_tests/core/datasource/entities/test_common_entities.py new file mode 100644 index 0000000000..0ee4928105 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/entities/test_common_entities.py @@ -0,0 +1,31 @@ +from core.datasource.entities.common_entities import I18nObject + + +def test_i18n_object_fallback(): + # Only en_US provided + obj = I18nObject(en_US="Hello") + assert obj.en_US == "Hello" + assert obj.zh_Hans == "Hello" + assert obj.pt_BR == "Hello" + assert obj.ja_JP == "Hello" + + # Some fields provided + obj = I18nObject(en_US="Hello", zh_Hans="你好") + assert obj.en_US == "Hello" + assert obj.zh_Hans == "你好" + assert obj.pt_BR == "Hello" + assert obj.ja_JP == "Hello" + + +def test_i18n_object_all_fields(): + obj = I18nObject(en_US="Hello", zh_Hans="你好", pt_BR="Olá", ja_JP="こんにちは") + assert obj.en_US == "Hello" + assert obj.zh_Hans == "你好" + assert obj.pt_BR == "Olá" + assert obj.ja_JP == "こんにちは" + + +def test_i18n_object_to_dict(): + obj = I18nObject(en_US="Hello", zh_Hans="你好", pt_BR="Olá", ja_JP="こんにちは") + expected_dict = {"en_US": "Hello", "zh_Hans": "你好", "pt_BR": "Olá", "ja_JP": "こんにちは"} + assert obj.to_dict() == expected_dict diff --git a/api/tests/unit_tests/core/datasource/entities/test_datasource_entities.py b/api/tests/unit_tests/core/datasource/entities/test_datasource_entities.py new file mode 100644 index 0000000000..a8c8d31537 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/entities/test_datasource_entities.py @@ -0,0 +1,275 @@ +from unittest.mock import patch + +import pytest + +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceIdentity, + DatasourceInvokeMeta, + DatasourceLabel, + DatasourceMessage, + DatasourceParameter, + DatasourceProviderEntity, + DatasourceProviderEntityWithPlugin, + DatasourceProviderIdentity, + DatasourceProviderType, + GetOnlineDocumentPageContentRequest, + GetOnlineDocumentPageContentResponse, + GetWebsiteCrawlRequest, + OnlineDocumentInfo, + OnlineDocumentPage, + OnlineDocumentPageContent, + OnlineDocumentPagesMessage, + OnlineDriveBrowseFilesRequest, + OnlineDriveBrowseFilesResponse, + OnlineDriveDownloadFileRequest, + OnlineDriveFile, + OnlineDriveFileBucket, + WebsiteCrawlMessage, + WebSiteInfo, + WebSiteInfoDetail, +) +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolLabelEnum + + +def test_datasource_provider_type(): + assert DatasourceProviderType.value_of("online_document") == DatasourceProviderType.ONLINE_DOCUMENT + assert DatasourceProviderType.value_of("local_file") == DatasourceProviderType.LOCAL_FILE + + with pytest.raises(ValueError, match="invalid mode value invalid"): + DatasourceProviderType.value_of("invalid") + + +def test_datasource_parameter_type(): + param_type = DatasourceParameter.DatasourceParameterType.STRING + assert param_type.as_normal_type() == "string" + assert param_type.cast_value("test") == "test" + + param_type = DatasourceParameter.DatasourceParameterType.NUMBER + assert param_type.cast_value("123") == 123 + + +def test_datasource_parameter(): + param = DatasourceParameter.get_simple_instance( + name="test_param", + typ=DatasourceParameter.DatasourceParameterType.STRING, + required=True, + options=["opt1", "opt2"], + ) + assert param.name == "test_param" + assert param.type == DatasourceParameter.DatasourceParameterType.STRING + assert param.required is True + assert len(param.options) == 2 + assert param.options[0].value == "opt1" + + param_no_options = DatasourceParameter.get_simple_instance( + name="test_param_2", typ=DatasourceParameter.DatasourceParameterType.NUMBER, required=False + ) + assert param_no_options.options == [] + + # Test init_frontend_parameter + # For STRING, it should just return the value as is (or cast to str) + frontend_param = param.init_frontend_parameter("val") + assert frontend_param == "val" + + # Test parameter type methods + assert DatasourceParameter.DatasourceParameterType.STRING.as_normal_type() == "string" + assert DatasourceParameter.DatasourceParameterType.NUMBER.as_normal_type() == "number" + assert DatasourceParameter.DatasourceParameterType.SECRET_INPUT.as_normal_type() == "string" + + assert DatasourceParameter.DatasourceParameterType.NUMBER.cast_value("10.5") == 10.5 + assert DatasourceParameter.DatasourceParameterType.BOOLEAN.cast_value("true") is True + assert DatasourceParameter.DatasourceParameterType.FILES.cast_value(["f1", "f2"]) == ["f1", "f2"] + + +def test_datasource_identity(): + label = I18nObject(en_US="label", zh_Hans="标签") + identity = DatasourceIdentity(author="author", name="name", label=label, provider="provider", icon="icon") + assert identity.author == "author" + assert identity.name == "name" + assert identity.label == label + assert identity.provider == "provider" + assert identity.icon == "icon" + + +def test_datasource_entity(): + label = I18nObject(en_US="label", zh_Hans="标签") + identity = DatasourceIdentity(author="author", name="name", label=label, provider="provider") + description = I18nObject(en_US="desc", zh_Hans="描述") + + entity = DatasourceEntity( + identity=identity, + description=description, + parameters=None, # Should be handled by validator + ) + assert entity.parameters == [] + + param = DatasourceParameter.get_simple_instance("p1", DatasourceParameter.DatasourceParameterType.STRING, True) + entity_with_params = DatasourceEntity(identity=identity, description=description, parameters=[param]) + assert entity_with_params.parameters == [param] + + +def test_datasource_provider_identity(): + label = I18nObject(en_US="label", zh_Hans="标签") + description = I18nObject(en_US="desc", zh_Hans="描述") + identity = DatasourceProviderIdentity( + author="author", name="name", description=description, icon="icon.png", label=label, tags=[ToolLabelEnum.SEARCH] + ) + + assert identity.author == "author" + assert identity.name == "name" + assert identity.description == description + assert identity.icon == "icon.png" + assert identity.label == label + assert identity.tags == [ToolLabelEnum.SEARCH] + + # Test generate_datasource_icon_url + with patch("core.datasource.entities.datasource_entities.dify_config") as mock_config: + mock_config.CONSOLE_API_URL = "http://api.example.com" + url = identity.generate_datasource_icon_url("tenant123") + assert "http://api.example.com/console/api/workspaces/current/plugin/icon" in url + assert "tenant_id=tenant123" in url + assert "filename=icon.png" in url + + # Test hardcoded icon + identity.icon = "https://assets.dify.ai/images/File%20Upload.svg" + assert identity.generate_datasource_icon_url("tenant123") == identity.icon + + # Test with empty CONSOLE_API_URL + identity.icon = "test.png" + with patch("core.datasource.entities.datasource_entities.dify_config") as mock_config: + mock_config.CONSOLE_API_URL = None + url = identity.generate_datasource_icon_url("tenant123") + assert url.startswith("/console/api/workspaces/current/plugin/icon") + + +def test_datasource_provider_entity(): + label = I18nObject(en_US="label", zh_Hans="标签") + description = I18nObject(en_US="desc", zh_Hans="描述") + identity = DatasourceProviderIdentity( + author="author", name="name", description=description, icon="icon", label=label + ) + + entity = DatasourceProviderEntity( + identity=identity, + provider_type=DatasourceProviderType.ONLINE_DOCUMENT, + credentials_schema=[], + oauth_schema=None, + ) + assert entity.identity == identity + assert entity.provider_type == DatasourceProviderType.ONLINE_DOCUMENT + assert entity.credentials_schema == [] + + +def test_datasource_provider_entity_with_plugin(): + label = I18nObject(en_US="label", zh_Hans="标签") + description = I18nObject(en_US="desc", zh_Hans="描述") + identity = DatasourceProviderIdentity( + author="author", name="name", description=description, icon="icon", label=label + ) + + entity = DatasourceProviderEntityWithPlugin( + identity=identity, provider_type=DatasourceProviderType.ONLINE_DOCUMENT, datasources=[] + ) + assert entity.datasources == [] + + +def test_datasource_invoke_meta(): + meta = DatasourceInvokeMeta(time_cost=1.5, error="some error", tool_config={"k": "v"}) + assert meta.time_cost == 1.5 + assert meta.error == "some error" + assert meta.tool_config == {"k": "v"} + + d = meta.to_dict() + assert d == {"time_cost": 1.5, "error": "some error", "tool_config": {"k": "v"}} + + empty_meta = DatasourceInvokeMeta.empty() + assert empty_meta.time_cost == 0.0 + assert empty_meta.error is None + assert empty_meta.tool_config == {} + + error_meta = DatasourceInvokeMeta.error_instance("fatal error") + assert error_meta.time_cost == 0.0 + assert error_meta.error == "fatal error" + assert error_meta.tool_config == {} + + +def test_datasource_label(): + label_obj = I18nObject(en_US="label", zh_Hans="标签") + ds_label = DatasourceLabel(name="name", label=label_obj, icon="icon") + assert ds_label.name == "name" + assert ds_label.label == label_obj + assert ds_label.icon == "icon" + + +def test_online_document_models(): + page = OnlineDocumentPage( + page_id="p1", + page_name="name", + page_icon={"type": "emoji"}, + type="page", + last_edited_time="2023-01-01", + parent_id=None, + ) + assert page.page_id == "p1" + + info = OnlineDocumentInfo(workspace_id="w1", workspace_name="name", workspace_icon="icon", total=1, pages=[page]) + assert info.total == 1 + + msg = OnlineDocumentPagesMessage(result=[info]) + assert msg.result == [info] + + req = GetOnlineDocumentPageContentRequest(workspace_id="w1", page_id="p1", type="page") + assert req.workspace_id == "w1" + + content = OnlineDocumentPageContent(workspace_id="w1", page_id="p1", content="hello") + assert content.content == "hello" + + resp = GetOnlineDocumentPageContentResponse(result=content) + assert resp.result == content + + +def test_website_crawl_models(): + req = GetWebsiteCrawlRequest(crawl_parameters={"url": "http://test.com"}) + assert req.crawl_parameters == {"url": "http://test.com"} + + detail = WebSiteInfoDetail(source_url="http://test.com", content="content", title="title", description="desc") + assert detail.title == "title" + + info = WebSiteInfo(status="completed", web_info_list=[detail], total=1, completed=1) + assert info.status == "completed" + + msg = WebsiteCrawlMessage(result=info) + assert msg.result == info + + # Test default values + msg_default = WebsiteCrawlMessage() + assert msg_default.result.status == "" + assert msg_default.result.web_info_list == [] + + +def test_online_drive_models(): + file = OnlineDriveFile(id="f1", name="file.txt", size=100, type="file") + assert file.name == "file.txt" + + bucket = OnlineDriveFileBucket(bucket="b1", files=[file], is_truncated=False, next_page_parameters=None) + assert bucket.bucket == "b1" + + req = OnlineDriveBrowseFilesRequest(bucket="b1", prefix="folder1", max_keys=10, next_page_parameters=None) + assert req.prefix == "folder1" + + resp = OnlineDriveBrowseFilesResponse(result=[bucket]) + assert resp.result == [bucket] + + dl_req = OnlineDriveDownloadFileRequest(id="f1", bucket="b1") + assert dl_req.id == "f1" + + +def test_datasource_message(): + # Use proper dict for message to avoid Pydantic Union validation ambiguity/crashes + msg = DatasourceMessage(type="text", message={"text": "hello"}) + assert msg.message.text == "hello" + + msg_json = DatasourceMessage(type="json", message={"json_object": {"k": "v"}}) + assert msg_json.message.json_object == {"k": "v"} diff --git a/api/tests/unit_tests/core/datasource/local_file/test_local_file_plugin.py b/api/tests/unit_tests/core/datasource/local_file/test_local_file_plugin.py new file mode 100644 index 0000000000..5bf7362a8a --- /dev/null +++ b/api/tests/unit_tests/core/datasource/local_file/test_local_file_plugin.py @@ -0,0 +1,57 @@ +from unittest.mock import MagicMock + +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceProviderType, +) +from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin + + +class TestLocalFileDatasourcePlugin: + def test_init(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceEntity) + mock_runtime = MagicMock(spec=DatasourceRuntime) + tenant_id = "test-tenant-id" + icon = "test-icon" + plugin_unique_identifier = "test-plugin-id" + + # Act + plugin = LocalFileDatasourcePlugin( + entity=mock_entity, + runtime=mock_runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + # Assert + assert plugin.tenant_id == tenant_id + assert plugin.plugin_unique_identifier == plugin_unique_identifier + assert plugin.entity == mock_entity + assert plugin.runtime == mock_runtime + assert plugin.icon == icon + + def test_datasource_provider_type(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceEntity) + mock_runtime = MagicMock(spec=DatasourceRuntime) + plugin = LocalFileDatasourcePlugin( + entity=mock_entity, runtime=mock_runtime, tenant_id="test", icon="test", plugin_unique_identifier="test" + ) + + # Act & Assert + assert plugin.datasource_provider_type() == DatasourceProviderType.LOCAL_FILE + + def test_get_icon_url(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceEntity) + mock_runtime = MagicMock(spec=DatasourceRuntime) + icon = "test-icon" + plugin = LocalFileDatasourcePlugin( + entity=mock_entity, runtime=mock_runtime, tenant_id="test", icon=icon, plugin_unique_identifier="test" + ) + + # Act & Assert + assert plugin.get_icon_url("any-tenant-id") == icon diff --git a/api/tests/unit_tests/core/datasource/local_file/test_local_file_provider.py b/api/tests/unit_tests/core/datasource/local_file/test_local_file_provider.py new file mode 100644 index 0000000000..af2369ac4e --- /dev/null +++ b/api/tests/unit_tests/core/datasource/local_file/test_local_file_provider.py @@ -0,0 +1,96 @@ +from unittest.mock import MagicMock + +import pytest + +from core.datasource.entities.datasource_entities import ( + DatasourceProviderEntityWithPlugin, + DatasourceProviderType, +) +from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin +from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController + + +class TestLocalFileDatasourcePluginProviderController: + def test_init(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + plugin_id = "test_plugin_id" + plugin_unique_identifier = "test_plugin_unique_identifier" + tenant_id = "test_tenant_id" + + # Act + controller = LocalFileDatasourcePluginProviderController( + entity=mock_entity, + plugin_id=plugin_id, + plugin_unique_identifier=plugin_unique_identifier, + tenant_id=tenant_id, + ) + + # Assert + assert controller.entity == mock_entity + assert controller.plugin_id == plugin_id + assert controller.plugin_unique_identifier == plugin_unique_identifier + assert controller.tenant_id == tenant_id + + def test_provider_type(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + controller = LocalFileDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant" + ) + + # Act & Assert + assert controller.provider_type == DatasourceProviderType.LOCAL_FILE + + def test_validate_credentials(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + controller = LocalFileDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant" + ) + + # Act & Assert + # Should not raise any exception + controller._validate_credentials("user_id", {"key": "value"}) + + def test_get_datasource_success(self): + # Arrange + mock_datasource_entity = MagicMock() + mock_datasource_entity.identity.name = "test_datasource" + + mock_entity = MagicMock() + mock_entity.datasources = [mock_datasource_entity] + mock_entity.identity.icon = "test_icon" + + plugin_unique_identifier = "test_plugin_unique_identifier" + tenant_id = "test_tenant_id" + + controller = LocalFileDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier=plugin_unique_identifier, tenant_id=tenant_id + ) + + # Act + datasource = controller.get_datasource("test_datasource") + + # Assert + assert isinstance(datasource, LocalFileDatasourcePlugin) + assert datasource.entity == mock_datasource_entity + assert datasource.tenant_id == tenant_id + assert datasource.icon == "test_icon" + assert datasource.plugin_unique_identifier == plugin_unique_identifier + + def test_get_datasource_not_found(self): + # Arrange + mock_datasource_entity = MagicMock() + mock_datasource_entity.identity.name = "other_datasource" + + mock_entity = MagicMock() + mock_entity.datasources = [mock_datasource_entity] + + controller = LocalFileDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant" + ) + + # Act & Assert + with pytest.raises(ValueError, match="Datasource with name test_datasource not found"): + controller.get_datasource("test_datasource") diff --git a/api/tests/unit_tests/core/datasource/online_document/test_online_document_plugin.py b/api/tests/unit_tests/core/datasource/online_document/test_online_document_plugin.py new file mode 100644 index 0000000000..e3a217725a --- /dev/null +++ b/api/tests/unit_tests/core/datasource/online_document/test_online_document_plugin.py @@ -0,0 +1,151 @@ +from unittest.mock import MagicMock, patch + +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceIdentity, + DatasourceProviderType, + GetOnlineDocumentPageContentRequest, +) +from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin + + +class TestOnlineDocumentDatasourcePlugin: + def test_init(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + tenant_id = "test_tenant" + icon = "test_icon" + plugin_unique_identifier = "test_plugin_id" + + # Act + plugin = OnlineDocumentDatasourcePlugin( + entity=entity, + runtime=runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + # Assert + assert plugin.entity == entity + assert plugin.runtime == runtime + assert plugin.tenant_id == tenant_id + assert plugin.icon == icon + assert plugin.plugin_unique_identifier == plugin_unique_identifier + + def test_get_online_document_pages(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + identity = MagicMock(spec=DatasourceIdentity) + entity.identity = identity + identity.provider = "test_provider" + identity.name = "test_name" + + runtime = MagicMock(spec=DatasourceRuntime) + runtime.credentials = {"api_key": "test_key"} + + tenant_id = "test_tenant" + icon = "test_icon" + plugin_unique_identifier = "test_plugin_id" + + plugin = OnlineDocumentDatasourcePlugin( + entity=entity, + runtime=runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + user_id = "test_user" + datasource_parameters = {"param": "value"} + provider_type = "test_type" + + mock_generator = MagicMock() + + # Patch PluginDatasourceManager to isolate plugin behavior from external dependencies + with patch("core.datasource.online_document.online_document_plugin.PluginDatasourceManager") as MockManager: + mock_manager_instance = MockManager.return_value + mock_manager_instance.get_online_document_pages.return_value = mock_generator + + # Act + result = plugin.get_online_document_pages( + user_id=user_id, datasource_parameters=datasource_parameters, provider_type=provider_type + ) + + # Assert + assert result == mock_generator + mock_manager_instance.get_online_document_pages.assert_called_once_with( + tenant_id=tenant_id, + user_id=user_id, + datasource_provider="test_provider", + datasource_name="test_name", + credentials=runtime.credentials, + datasource_parameters=datasource_parameters, + provider_type=provider_type, + ) + + def test_get_online_document_page_content(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + identity = MagicMock(spec=DatasourceIdentity) + entity.identity = identity + identity.provider = "test_provider" + identity.name = "test_name" + + runtime = MagicMock(spec=DatasourceRuntime) + runtime.credentials = {"api_key": "test_key"} + + tenant_id = "test_tenant" + icon = "test_icon" + plugin_unique_identifier = "test_plugin_id" + + plugin = OnlineDocumentDatasourcePlugin( + entity=entity, + runtime=runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + user_id = "test_user" + datasource_parameters = MagicMock(spec=GetOnlineDocumentPageContentRequest) + provider_type = "test_type" + + mock_generator = MagicMock() + + with patch("core.datasource.online_document.online_document_plugin.PluginDatasourceManager") as MockManager: + mock_manager_instance = MockManager.return_value + mock_manager_instance.get_online_document_page_content.return_value = mock_generator + + # Act + result = plugin.get_online_document_page_content( + user_id=user_id, datasource_parameters=datasource_parameters, provider_type=provider_type + ) + + # Assert + assert result == mock_generator + mock_manager_instance.get_online_document_page_content.assert_called_once_with( + tenant_id=tenant_id, + user_id=user_id, + datasource_provider="test_provider", + datasource_name="test_name", + credentials=runtime.credentials, + datasource_parameters=datasource_parameters, + provider_type=provider_type, + ) + + def test_datasource_provider_type(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + plugin = OnlineDocumentDatasourcePlugin( + entity=entity, runtime=runtime, tenant_id="test", icon="test", plugin_unique_identifier="test" + ) + + # Act + result = plugin.datasource_provider_type() + + # Assert + assert result == DatasourceProviderType.ONLINE_DOCUMENT diff --git a/api/tests/unit_tests/core/datasource/online_document/test_online_document_provider.py b/api/tests/unit_tests/core/datasource/online_document/test_online_document_provider.py new file mode 100644 index 0000000000..cfdd05e0b2 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/online_document/test_online_document_provider.py @@ -0,0 +1,100 @@ +from unittest.mock import MagicMock + +import pytest + +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceProviderEntityWithPlugin, + DatasourceProviderType, +) +from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin +from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController + + +class TestOnlineDocumentDatasourcePluginProviderController: + def test_init(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + plugin_id = "test_plugin_id" + plugin_unique_identifier = "test_plugin_uid" + tenant_id = "test_tenant_id" + + # Act + controller = OnlineDocumentDatasourcePluginProviderController( + entity=mock_entity, + plugin_id=plugin_id, + plugin_unique_identifier=plugin_unique_identifier, + tenant_id=tenant_id, + ) + + # Assert + assert controller.entity == mock_entity + assert controller.plugin_id == plugin_id + assert controller.plugin_unique_identifier == plugin_unique_identifier + assert controller.tenant_id == tenant_id + + def test_provider_type(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + controller = OnlineDocumentDatasourcePluginProviderController( + entity=mock_entity, plugin_id="test", plugin_unique_identifier="test", tenant_id="test" + ) + + # Assert + assert controller.provider_type == DatasourceProviderType.ONLINE_DOCUMENT + + def test_get_datasource_success(self): + # Arrange + from core.datasource.entities.datasource_entities import DatasourceIdentity + + mock_datasource_entity = MagicMock(spec=DatasourceEntity) + mock_datasource_entity.identity = MagicMock(spec=DatasourceIdentity) + mock_datasource_entity.identity.name = "target_datasource" + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.datasources = [mock_datasource_entity] + mock_entity.identity = MagicMock() + mock_entity.identity.icon = "test_icon" + + plugin_unique_identifier = "test_plugin_uid" + tenant_id = "test_tenant_id" + + controller = OnlineDocumentDatasourcePluginProviderController( + entity=mock_entity, + plugin_id="test_plugin_id", + plugin_unique_identifier=plugin_unique_identifier, + tenant_id=tenant_id, + ) + + # Act + result = controller.get_datasource("target_datasource") + + # Assert + assert isinstance(result, OnlineDocumentDatasourcePlugin) + assert result.entity == mock_datasource_entity + assert result.tenant_id == tenant_id + assert result.icon == "test_icon" + assert result.plugin_unique_identifier == plugin_unique_identifier + assert result.runtime.tenant_id == tenant_id + + def test_get_datasource_not_found(self): + # Arrange + from core.datasource.entities.datasource_entities import DatasourceIdentity + + mock_datasource_entity = MagicMock(spec=DatasourceEntity) + mock_datasource_entity.identity = MagicMock(spec=DatasourceIdentity) + mock_datasource_entity.identity.name = "other_datasource" + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.datasources = [mock_datasource_entity] + + controller = OnlineDocumentDatasourcePluginProviderController( + entity=mock_entity, + plugin_id="test_plugin_id", + plugin_unique_identifier="test_plugin_uid", + tenant_id="test_tenant_id", + ) + + # Act & Assert + with pytest.raises(ValueError, match="Datasource with name missing_datasource not found"): + controller.get_datasource("missing_datasource") diff --git a/api/tests/unit_tests/core/datasource/online_drive/test_online_drive_plugin.py b/api/tests/unit_tests/core/datasource/online_drive/test_online_drive_plugin.py new file mode 100644 index 0000000000..6c8b644871 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/online_drive/test_online_drive_plugin.py @@ -0,0 +1,147 @@ +from unittest.mock import MagicMock, patch + +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceIdentity, + DatasourceProviderType, + OnlineDriveBrowseFilesRequest, + OnlineDriveDownloadFileRequest, +) +from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin + + +class TestOnlineDriveDatasourcePlugin: + def test_init(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + tenant_id = "test_tenant" + icon = "test_icon" + plugin_unique_identifier = "test_plugin_id" + + # Act + plugin = OnlineDriveDatasourcePlugin( + entity=entity, + runtime=runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + # Assert + assert plugin.entity == entity + assert plugin.runtime == runtime + assert plugin.tenant_id == tenant_id + assert plugin.icon == icon + assert plugin.plugin_unique_identifier == plugin_unique_identifier + + def test_online_drive_browse_files(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + identity = MagicMock(spec=DatasourceIdentity) + entity.identity = identity + identity.provider = "test_provider" + identity.name = "test_name" + + runtime = MagicMock(spec=DatasourceRuntime) + runtime.credentials = {"token": "test_token"} + + tenant_id = "test_tenant" + icon = "test_icon" + plugin_unique_identifier = "test_plugin_id" + + plugin = OnlineDriveDatasourcePlugin( + entity=entity, + runtime=runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + user_id = "test_user" + request = MagicMock(spec=OnlineDriveBrowseFilesRequest) + provider_type = "test_type" + + mock_generator = MagicMock() + + with patch("core.datasource.online_drive.online_drive_plugin.PluginDatasourceManager") as MockManager: + mock_manager_instance = MockManager.return_value + mock_manager_instance.online_drive_browse_files.return_value = mock_generator + + # Act + result = plugin.online_drive_browse_files(user_id=user_id, request=request, provider_type=provider_type) + + # Assert + assert result == mock_generator + mock_manager_instance.online_drive_browse_files.assert_called_once_with( + tenant_id=tenant_id, + user_id=user_id, + datasource_provider="test_provider", + datasource_name="test_name", + credentials=runtime.credentials, + request=request, + provider_type=provider_type, + ) + + def test_online_drive_download_file(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + identity = MagicMock(spec=DatasourceIdentity) + entity.identity = identity + identity.provider = "test_provider" + identity.name = "test_name" + + runtime = MagicMock(spec=DatasourceRuntime) + runtime.credentials = {"token": "test_token"} + + tenant_id = "test_tenant" + icon = "test_icon" + plugin_unique_identifier = "test_plugin_id" + + plugin = OnlineDriveDatasourcePlugin( + entity=entity, + runtime=runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + user_id = "test_user" + request = MagicMock(spec=OnlineDriveDownloadFileRequest) + provider_type = "test_type" + + mock_generator = MagicMock() + + with patch("core.datasource.online_drive.online_drive_plugin.PluginDatasourceManager") as MockManager: + mock_manager_instance = MockManager.return_value + mock_manager_instance.online_drive_download_file.return_value = mock_generator + + # Act + result = plugin.online_drive_download_file(user_id=user_id, request=request, provider_type=provider_type) + + # Assert + assert result == mock_generator + mock_manager_instance.online_drive_download_file.assert_called_once_with( + tenant_id=tenant_id, + user_id=user_id, + datasource_provider="test_provider", + datasource_name="test_name", + credentials=runtime.credentials, + request=request, + provider_type=provider_type, + ) + + def test_datasource_provider_type(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + plugin = OnlineDriveDatasourcePlugin( + entity=entity, runtime=runtime, tenant_id="test", icon="test", plugin_unique_identifier="test" + ) + + # Act + result = plugin.datasource_provider_type() + + # Assert + assert result == DatasourceProviderType.ONLINE_DRIVE diff --git a/api/tests/unit_tests/core/datasource/online_drive/test_online_drive_provider.py b/api/tests/unit_tests/core/datasource/online_drive/test_online_drive_provider.py new file mode 100644 index 0000000000..2824ddd8ed --- /dev/null +++ b/api/tests/unit_tests/core/datasource/online_drive/test_online_drive_provider.py @@ -0,0 +1,83 @@ +from unittest.mock import MagicMock + +import pytest + +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType +from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin +from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController + + +class TestOnlineDriveDatasourcePluginProviderController: + def test_init(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + plugin_id = "test_plugin_id" + plugin_unique_identifier = "test_plugin_unique_identifier" + tenant_id = "test_tenant_id" + + # Act + controller = OnlineDriveDatasourcePluginProviderController( + entity=mock_entity, + plugin_id=plugin_id, + plugin_unique_identifier=plugin_unique_identifier, + tenant_id=tenant_id, + ) + + # Assert + assert controller.entity == mock_entity + assert controller.plugin_id == plugin_id + assert controller.plugin_unique_identifier == plugin_unique_identifier + assert controller.tenant_id == tenant_id + + def test_provider_type(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + controller = OnlineDriveDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant" + ) + + # Act & Assert + assert controller.provider_type == DatasourceProviderType.ONLINE_DRIVE + + def test_get_datasource_success(self): + # Arrange + mock_datasource_entity = MagicMock() + mock_datasource_entity.identity.name = "test_datasource" + + mock_entity = MagicMock() + mock_entity.datasources = [mock_datasource_entity] + mock_entity.identity.icon = "test_icon" + + plugin_unique_identifier = "test_plugin_unique_identifier" + tenant_id = "test_tenant_id" + + controller = OnlineDriveDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier=plugin_unique_identifier, tenant_id=tenant_id + ) + + # Act + datasource = controller.get_datasource("test_datasource") + + # Assert + assert isinstance(datasource, OnlineDriveDatasourcePlugin) + assert datasource.entity == mock_datasource_entity + assert datasource.tenant_id == tenant_id + assert datasource.icon == "test_icon" + assert datasource.plugin_unique_identifier == plugin_unique_identifier + assert datasource.runtime.tenant_id == tenant_id + + def test_get_datasource_not_found(self): + # Arrange + mock_datasource_entity = MagicMock() + mock_datasource_entity.identity.name = "other_datasource" + + mock_entity = MagicMock() + mock_entity.datasources = [mock_datasource_entity] + + controller = OnlineDriveDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant" + ) + + # Act & Assert + with pytest.raises(ValueError, match="Datasource with name test_datasource not found"): + controller.get_datasource("test_datasource") diff --git a/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py new file mode 100644 index 0000000000..a7c93242cd --- /dev/null +++ b/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py @@ -0,0 +1,409 @@ +import base64 +import hashlib +import hmac +from unittest.mock import MagicMock, patch + +import httpx +import pytest + +from core.datasource.datasource_file_manager import DatasourceFileManager +from models.model import MessageFile, UploadFile +from models.tools import ToolFile + + +class TestDatasourceFileManager: + @patch("core.datasource.datasource_file_manager.time.time") + @patch("core.datasource.datasource_file_manager.os.urandom") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_sign_file(self, mock_config, mock_urandom, mock_time): + # Setup + mock_config.FILES_URL = "http://localhost:5001" + mock_config.SECRET_KEY = "test_secret" + mock_time.return_value = 1700000000 + mock_urandom.return_value = b"1234567890abcdef" # 16 bytes + + datasource_file_id = "file_id_123" + extension = ".png" + + # Execute + signed_url = DatasourceFileManager.sign_file(datasource_file_id, extension) + + # Verify + assert signed_url.startswith("http://localhost:5001/files/datasources/file_id_123.png?") + assert "timestamp=1700000000" in signed_url + assert f"nonce={mock_urandom.return_value.hex()}" in signed_url + assert "sign=" in signed_url + + @patch("core.datasource.datasource_file_manager.time.time") + @patch("core.datasource.datasource_file_manager.os.urandom") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_sign_file_empty_secret(self, mock_config, mock_urandom, mock_time): + # Setup + mock_config.FILES_URL = "http://localhost:5001" + mock_config.SECRET_KEY = None # Empty secret + mock_time.return_value = 1700000000 + mock_urandom.return_value = b"1234567890abcdef" + + # Execute + signed_url = DatasourceFileManager.sign_file("file_id", ".png") + assert "sign=" in signed_url + + @patch("core.datasource.datasource_file_manager.time.time") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_verify_file(self, mock_config, mock_time): + # Setup + mock_config.SECRET_KEY = "test_secret" + mock_config.FILES_ACCESS_TIMEOUT = 300 + mock_time.return_value = 1700000000 + + datasource_file_id = "file_id_123" + timestamp = "1699999800" # 200 seconds ago + nonce = "some_nonce" + + # Manually calculate sign + data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}" + secret_key = b"test_secret" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + # Execute & Verify Success + assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is True + + # Verify Failure - Wrong Sign + assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, "wrong_sign") is False + + # Verify Failure - Timeout + mock_time.return_value = 1700000500 # 700 seconds after timestamp (300 is timeout) + assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is False + + @patch("core.datasource.datasource_file_manager.time.time") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_verify_file_empty_secret(self, mock_config, mock_time): + # Setup + mock_config.SECRET_KEY = "" # Empty string secret + mock_config.FILES_ACCESS_TIMEOUT = 300 + mock_time.return_value = 1700000000 + + datasource_file_id = "file_id_123" + timestamp = "1699999800" + nonce = "some_nonce" + + # Calculate with empty secret + data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}" + sign = hmac.new(b"", data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is True + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_create_file_by_raw(self, mock_config, mock_uuid, mock_storage, mock_db): + # Setup + mock_uuid.return_value = MagicMock(hex="unique_hex") + mock_config.STORAGE_TYPE = "local" + + user_id = "user_123" + tenant_id = "tenant_456" + file_binary = b"fake binary data" + mimetype = "image/png" + + # Execute + upload_file = DatasourceFileManager.create_file_by_raw( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + file_binary=file_binary, + mimetype=mimetype, + filename="test.png", + ) + + # Verify + assert upload_file.tenant_id == tenant_id + assert upload_file.name == "test.png" + assert upload_file.size == len(file_binary) + assert upload_file.mime_type == mimetype + assert upload_file.key == f"datasources/{tenant_id}/unique_hex.png" + + mock_storage.save.assert_called_once_with(upload_file.key, file_binary) + mock_db.session.add.assert_called_once() + mock_db.session.commit.assert_called_once() + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_create_file_by_raw_filename_no_extension(self, mock_config, mock_uuid, mock_storage, mock_db): + # Setup + mock_uuid.return_value = MagicMock(hex="unique_hex") + mock_config.STORAGE_TYPE = "local" + + user_id = "user_123" + tenant_id = "tenant_456" + file_binary = b"fake binary data" + mimetype = "image/png" + + # Execute + upload_file = DatasourceFileManager.create_file_by_raw( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + file_binary=file_binary, + mimetype=mimetype, + filename="test", # No extension + ) + + # Verify + assert upload_file.name == "test.png" # Should append extension + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + @patch("core.datasource.datasource_file_manager.dify_config") + @patch("core.datasource.datasource_file_manager.guess_extension") + def test_create_file_by_raw_unknown_extension(self, mock_guess_ext, mock_config, mock_uuid, mock_storage, mock_db): + # Setup + mock_guess_ext.return_value = None # Cannot guess + mock_uuid.return_value = MagicMock(hex="unique_hex") + + # Execute + upload_file = DatasourceFileManager.create_file_by_raw( + user_id="user", + tenant_id="tenant", + conversation_id=None, + file_binary=b"data", + mimetype="application/x-unknown", + ) + + # Verify + assert upload_file.extension == ".bin" + assert upload_file.name == "unique_hex.bin" + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_create_file_by_raw_no_filename(self, mock_config, mock_uuid, mock_storage, mock_db): + # Setup + mock_uuid.return_value = MagicMock(hex="unique_hex") + mock_config.STORAGE_TYPE = "local" + + # Execute + upload_file = DatasourceFileManager.create_file_by_raw( + user_id="user_123", + tenant_id="tenant_456", + conversation_id=None, + file_binary=b"data", + mimetype="application/pdf", + ) + + # Verify + assert upload_file.name == "unique_hex.pdf" + assert upload_file.extension == ".pdf" + + @patch("core.datasource.datasource_file_manager.ssrf_proxy") + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + def test_create_file_by_url_mimetype_from_guess(self, mock_uuid, mock_storage, mock_db, mock_ssrf): + # Setup + mock_uuid.return_value = MagicMock(hex="unique_hex") + mock_response = MagicMock() + mock_response.content = b"bits" + mock_response.headers = {} # No content-type in headers + mock_ssrf.get.return_value = mock_response + + # Execute + tool_file = DatasourceFileManager.create_file_by_url( + user_id="user_123", tenant_id="tenant_456", file_url="https://example.com/photo.png" + ) + + # Verify + assert tool_file.mimetype == "image/png" # Guessed from .png in URL + + @patch("core.datasource.datasource_file_manager.ssrf_proxy") + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + def test_create_file_by_url_mimetype_default(self, mock_uuid, mock_storage, mock_db, mock_ssrf): + # Setup + mock_uuid.return_value = MagicMock(hex="unique_hex") + mock_response = MagicMock() + mock_response.content = b"bits" + mock_response.headers = {} + mock_ssrf.get.return_value = mock_response + + # Execute + tool_file = DatasourceFileManager.create_file_by_url( + user_id="user_123", + tenant_id="tenant_456", + file_url="https://example.com/unknown", # No extension, no headers + ) + + # Verify + assert tool_file.mimetype == "application/octet-stream" + + @patch("core.datasource.datasource_file_manager.ssrf_proxy") + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + def test_create_file_by_url_success(self, mock_uuid, mock_storage, mock_db, mock_ssrf): + # Setup + mock_uuid.return_value = MagicMock(hex="unique_hex") + mock_response = MagicMock() + mock_response.content = b"downloaded bits" + mock_response.headers = {"Content-Type": "image/jpeg"} + mock_ssrf.get.return_value = mock_response + + # Execute + tool_file = DatasourceFileManager.create_file_by_url( + user_id="user_123", tenant_id="tenant_456", file_url="https://example.com/photo.jpg" + ) + + # Verify + assert tool_file.mimetype == "image/jpeg" + assert tool_file.size == len(b"downloaded bits") + assert tool_file.file_key == "tools/tenant_456/unique_hex.jpg" + mock_storage.save.assert_called_once() + + @patch("core.datasource.datasource_file_manager.ssrf_proxy") + def test_create_file_by_url_timeout(self, mock_ssrf): + # Setup + mock_ssrf.get.side_effect = httpx.TimeoutException("Timeout") + + # Execute & Verify + with pytest.raises(ValueError, match="timeout when downloading file"): + DatasourceFileManager.create_file_by_url( + user_id="user_123", tenant_id="tenant_456", file_url="https://example.com/large.file" + ) + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + def test_get_file_binary(self, mock_storage, mock_db): + # Setup + mock_upload_file = MagicMock(spec=UploadFile) + mock_upload_file.key = "some_key" + mock_upload_file.mime_type = "image/png" + + mock_query = mock_db.session.query.return_value + mock_where = mock_query.where.return_value + mock_where.first.return_value = mock_upload_file + + mock_storage.load_once.return_value = b"file content" + + # Execute + result = DatasourceFileManager.get_file_binary("file_id") + + # Verify + assert result == (b"file content", "image/png") + + # Case: Not found + mock_where.first.return_value = None + assert DatasourceFileManager.get_file_binary("unknown") is None + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + def test_get_file_binary_by_message_file_id(self, mock_storage, mock_db): + # Setup + mock_message_file = MagicMock(spec=MessageFile) + mock_message_file.url = "http://localhost/files/tools/tool_id.png" + + mock_tool_file = MagicMock(spec=ToolFile) + mock_tool_file.file_key = "tool_key" + mock_tool_file.mimetype = "image/png" + + # Mock query sequence + def mock_query(model): + m = MagicMock() + if model == MessageFile: + m.where.return_value.first.return_value = mock_message_file + elif model == ToolFile: + m.where.return_value.first.return_value = mock_tool_file + return m + + mock_db.session.query.side_effect = mock_query + mock_storage.load_once.return_value = b"tool content" + + # Execute + result = DatasourceFileManager.get_file_binary_by_message_file_id("msg_file_id") + + # Verify + assert result == (b"tool content", "image/png") + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + def test_get_file_binary_by_message_file_id_with_extension(self, mock_storage, mock_db): + # Test that it correctly parses tool_id even with extension in URL + mock_message_file = MagicMock(spec=MessageFile) + mock_message_file.url = "http://localhost/files/tools/abcdef.png" + + mock_tool_file = MagicMock(spec=ToolFile) + mock_tool_file.id = "abcdef" + mock_tool_file.file_key = "tk" + mock_tool_file.mimetype = "image/png" + + def mock_query(model): + m = MagicMock() + if model == MessageFile: + m.where.return_value.first.return_value = mock_message_file + else: + m.where.return_value.first.return_value = mock_tool_file + return m + + mock_db.session.query.side_effect = mock_query + mock_storage.load_once.return_value = b"bits" + + result = DatasourceFileManager.get_file_binary_by_message_file_id("m") + assert result == (b"bits", "image/png") + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + def test_get_file_binary_by_message_file_id_failures(self, mock_storage, mock_db): + # Setup common mock + mock_query_obj = MagicMock() + mock_db.session.query.return_value = mock_query_obj + mock_query_obj.where.return_value.first.return_value = None + + # Case 1: Message file not found + assert DatasourceFileManager.get_file_binary_by_message_file_id("none") is None + + # Case 2: Message file found but tool file not found + mock_message_file = MagicMock(spec=MessageFile) + mock_message_file.url = None + + def mock_query_v2(model): + m = MagicMock() + if model == MessageFile: + m.where.return_value.first.return_value = mock_message_file + else: + m.where.return_value.first.return_value = None + return m + + mock_db.session.query.side_effect = mock_query_v2 + assert DatasourceFileManager.get_file_binary_by_message_file_id("msg_id") is None + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + def test_get_file_generator_by_upload_file_id(self, mock_storage, mock_db): + # Setup + mock_upload_file = MagicMock(spec=UploadFile) + mock_upload_file.key = "upload_key" + mock_upload_file.mime_type = "text/plain" + + mock_db.session.query.return_value.where.return_value.first.return_value = mock_upload_file + + mock_storage.load_stream.return_value = iter([b"chunk1", b"chunk2"]) + + # Execute + stream, mimetype = DatasourceFileManager.get_file_generator_by_upload_file_id("upload_id") + + # Verify + assert mimetype == "text/plain" + assert list(stream) == [b"chunk1", b"chunk2"] + + # Case: Not found + mock_db.session.query.return_value.where.return_value.first.return_value = None + stream, mimetype = DatasourceFileManager.get_file_generator_by_upload_file_id("none") + assert stream is None + assert mimetype is None diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index 52c91fb8c9..d5eeae912c 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -1,9 +1,15 @@ import types from collections.abc import Generator +import pytest + +from contexts.wrapper import RecyclableContextVar from core.datasource.datasource_manager import DatasourceManager -from core.datasource.entities.datasource_entities import DatasourceMessage +from core.datasource.entities.datasource_entities import DatasourceMessage, DatasourceProviderType +from core.datasource.errors import DatasourceProviderNotFoundError from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.file import File +from dify_graph.file.enums import FileTransferMethod, FileType from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent @@ -15,6 +21,22 @@ def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, Non ) +def _drain_generator(gen: Generator[DatasourceMessage, None, object]) -> tuple[list[DatasourceMessage], object | None]: + messages: list[DatasourceMessage] = [] + try: + while True: + messages.append(next(gen)) + except StopIteration as e: + return messages, e.value + + +def _invalidate_recyclable_contextvars() -> None: + """ + Ensure RecyclableContextVar.get() raises LookupError until reset by code under test. + """ + RecyclableContextVar.increment_thread_recycles() + + def test_get_icon_url_calls_runtime(mocker): fake_runtime = mocker.Mock() fake_runtime.get_icon_url.return_value = "https://icon" @@ -30,6 +52,119 @@ def test_get_icon_url_calls_runtime(mocker): DatasourceManager.get_datasource_runtime.assert_called_once() +def test_get_datasource_runtime_delegates_to_provider_controller(mocker): + provider_controller = mocker.Mock() + provider_controller.get_datasource.return_value = object() + mocker.patch.object(DatasourceManager, "get_datasource_plugin_provider", return_value=provider_controller) + + runtime = DatasourceManager.get_datasource_runtime( + provider_id="prov/x", + datasource_name="ds", + tenant_id="t1", + datasource_type=DatasourceProviderType.ONLINE_DOCUMENT, + ) + assert runtime is provider_controller.get_datasource.return_value + provider_controller.get_datasource.assert_called_once_with("ds") + + +@pytest.mark.parametrize( + ("datasource_type", "controller_path"), + [ + ( + DatasourceProviderType.ONLINE_DOCUMENT, + "core.datasource.datasource_manager.OnlineDocumentDatasourcePluginProviderController", + ), + ( + DatasourceProviderType.ONLINE_DRIVE, + "core.datasource.datasource_manager.OnlineDriveDatasourcePluginProviderController", + ), + ( + DatasourceProviderType.WEBSITE_CRAWL, + "core.datasource.datasource_manager.WebsiteCrawlDatasourcePluginProviderController", + ), + ( + DatasourceProviderType.LOCAL_FILE, + "core.datasource.datasource_manager.LocalFileDatasourcePluginProviderController", + ), + ], +) +def test_get_datasource_plugin_provider_creates_controller_and_caches(mocker, datasource_type, controller_path): + _invalidate_recyclable_contextvars() + + provider_entity = types.SimpleNamespace(declaration=object(), plugin_id="plugin", plugin_unique_identifier="uniq") + fetch = mocker.patch( + "core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider", + return_value=provider_entity, + ) + ctrl_cls = mocker.patch(controller_path) + + first = DatasourceManager.get_datasource_plugin_provider( + provider_id=f"prov/{datasource_type.value}", + tenant_id="t1", + datasource_type=datasource_type, + ) + second = DatasourceManager.get_datasource_plugin_provider( + provider_id=f"prov/{datasource_type.value}", + tenant_id="t1", + datasource_type=datasource_type, + ) + + assert first is second + assert fetch.call_count == 1 + assert ctrl_cls.call_count == 1 + + +def test_get_datasource_plugin_provider_raises_when_provider_entity_missing(mocker): + _invalidate_recyclable_contextvars() + mocker.patch( + "core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider", + return_value=None, + ) + + with pytest.raises(DatasourceProviderNotFoundError, match="plugin provider prov/notfound not found"): + DatasourceManager.get_datasource_plugin_provider( + provider_id="prov/notfound", + tenant_id="t1", + datasource_type=DatasourceProviderType.ONLINE_DOCUMENT, + ) + + +def test_get_datasource_plugin_provider_raises_for_unsupported_type(mocker): + _invalidate_recyclable_contextvars() + provider_entity = types.SimpleNamespace(declaration=object(), plugin_id="plugin", plugin_unique_identifier="uniq") + mocker.patch( + "core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider", + return_value=provider_entity, + ) + + with pytest.raises(ValueError, match="Unsupported datasource type"): + DatasourceManager.get_datasource_plugin_provider( + provider_id="prov/x", + tenant_id="t1", + datasource_type=types.SimpleNamespace(), # not a DatasourceProviderType at runtime + ) + + +def test_get_datasource_plugin_provider_raises_when_controller_none(mocker): + _invalidate_recyclable_contextvars() + provider_entity = types.SimpleNamespace(declaration=object(), plugin_id="plugin", plugin_unique_identifier="uniq") + mocker.patch( + "core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider", + return_value=provider_entity, + ) + mocker.patch( + "core.datasource.datasource_manager.OnlineDocumentDatasourcePluginProviderController", + return_value=None, + ) + + with pytest.raises(DatasourceProviderNotFoundError, match="Datasource provider prov/x not found"): + DatasourceManager.get_datasource_plugin_provider( + provider_id="prov/x", + tenant_id="t1", + datasource_type=DatasourceProviderType.ONLINE_DOCUMENT, + ) + + def test_stream_online_results_yields_messages_online_document(mocker): # stub runtime to yield a text message def _doc_messages(**_): @@ -60,6 +195,148 @@ def test_stream_online_results_yields_messages_online_document(mocker): assert msgs[0].message.text == "hello" +def test_stream_online_results_sets_credentials_and_returns_empty_dict_online_document(mocker): + class _Runtime: + def __init__(self) -> None: + self.runtime = types.SimpleNamespace(credentials=None) + + def get_online_document_page_content(self, **_kwargs): + yield from _gen_messages_text_only("hello") + + runtime = _Runtime() + mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=runtime) + mocker.patch( + "core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials", + return_value={"token": "t"}, + ) + + gen = DatasourceManager.stream_online_results( + user_id="u1", + datasource_name="ds", + datasource_type="online_document", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="cred", + datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"), + online_drive_request=None, + ) + messages, final_value = _drain_generator(gen) + + assert runtime.runtime.credentials == {"token": "t"} + assert [m.message.text for m in messages] == ["hello"] + assert final_value == {} + + +def test_stream_online_results_raises_when_missing_params(mocker): + class _Runtime: + def __init__(self) -> None: + self.runtime = types.SimpleNamespace(credentials=None) + + def get_online_document_page_content(self, **_kwargs): + yield from _gen_messages_text_only("never") + + def online_drive_download_file(self, **_kwargs): + yield from _gen_messages_text_only("never") + + mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=_Runtime()) + mocker.patch( + "core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials", + return_value={}, + ) + + with pytest.raises(ValueError, match="datasource_param is required for ONLINE_DOCUMENT streaming"): + list( + DatasourceManager.stream_online_results( + user_id="u1", + datasource_name="ds", + datasource_type="online_document", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + datasource_param=None, + online_drive_request=None, + ) + ) + + with pytest.raises(ValueError, match="online_drive_request is required for ONLINE_DRIVE streaming"): + list( + DatasourceManager.stream_online_results( + user_id="u1", + datasource_name="ds", + datasource_type="online_drive", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + datasource_param=None, + online_drive_request=None, + ) + ) + + +def test_stream_online_results_yields_messages_and_returns_empty_dict_online_drive(mocker): + class _Runtime: + def __init__(self) -> None: + self.runtime = types.SimpleNamespace(credentials=None) + + def online_drive_download_file(self, **_kwargs): + yield from _gen_messages_text_only("drive") + + runtime = _Runtime() + mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=runtime) + mocker.patch( + "core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials", + return_value={"token": "t"}, + ) + + gen = DatasourceManager.stream_online_results( + user_id="u1", + datasource_name="ds", + datasource_type="online_drive", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="cred", + datasource_param=None, + online_drive_request=types.SimpleNamespace(id="fid", bucket="b"), + ) + messages, final_value = _drain_generator(gen) + + assert runtime.runtime.credentials == {"token": "t"} + assert [m.message.text for m in messages] == ["drive"] + assert final_value == {} + + +def test_stream_online_results_raises_for_unsupported_stream_type(mocker): + mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=mocker.Mock()) + mocker.patch( + "core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials", + return_value={}, + ) + + with pytest.raises(ValueError, match="Unsupported datasource type for streaming"): + list( + DatasourceManager.stream_online_results( + user_id="u1", + datasource_name="ds", + datasource_type="website_crawl", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + datasource_param=None, + online_drive_request=None, + ) + ) + + def test_stream_node_events_emits_events_online_document(mocker): # make manager's low-level stream produce TEXT only mocker.patch.object( @@ -93,6 +370,260 @@ def test_stream_node_events_emits_events_online_document(mocker): assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED +def test_stream_node_events_builds_file_and_variables_from_messages(mocker): + mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) + + def _transformed(**_kwargs): + yield DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE_LINK, + message=DatasourceMessage.TextMessage(text="/files/datasources/tool_file_1.png"), + meta={}, + ) + yield DatasourceMessage( + type=DatasourceMessage.MessageType.TEXT, + message=DatasourceMessage.TextMessage(text="hello"), + meta=None, + ) + yield DatasourceMessage( + type=DatasourceMessage.MessageType.LINK, + message=DatasourceMessage.TextMessage(text="http://example.com"), + meta=None, + ) + yield DatasourceMessage( + type=DatasourceMessage.MessageType.VARIABLE, + message=DatasourceMessage.VariableMessage(variable_name="v", variable_value="a", stream=True), + meta=None, + ) + yield DatasourceMessage( + type=DatasourceMessage.MessageType.VARIABLE, + message=DatasourceMessage.VariableMessage(variable_name="v", variable_value="b", stream=True), + meta=None, + ) + yield DatasourceMessage( + type=DatasourceMessage.MessageType.VARIABLE, + message=DatasourceMessage.VariableMessage(variable_name="x", variable_value=1, stream=False), + meta=None, + ) + yield DatasourceMessage( + type=DatasourceMessage.MessageType.JSON, + message=DatasourceMessage.JsonMessage(json_object={"k": "v"}), + meta=None, + ) + + mocker.patch( + "core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages", + side_effect=_transformed, + ) + + fake_tool_file = types.SimpleNamespace(mimetype="image/png") + + class _Session: + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + def scalar(self, _stmt): + return fake_tool_file + + mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session()) + mocker.patch( + "core.datasource.datasource_manager.file_factory.get_file_type_by_mime_type", return_value=FileType.IMAGE + ) + built = File( + tenant_id="t1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id="tool_file_1", + extension=".png", + mime_type="image/png", + storage_key="k", + ) + build_from_mapping = mocker.patch( + "core.datasource.datasource_manager.file_factory.build_from_mapping", + return_value=built, + ) + + variable_pool = mocker.Mock() + + events = list( + DatasourceManager.stream_node_events( + node_id="nodeA", + user_id="u1", + datasource_name="ds", + datasource_type="online_document", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + parameters_for_log={"k": "v"}, + datasource_info={"info": "x"}, + variable_pool=variable_pool, + datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"), + online_drive_request=None, + ) + ) + + build_from_mapping.assert_called_once() + variable_pool.add.assert_not_called() + + assert any(isinstance(e, StreamChunkEvent) and e.chunk == "hello" for e in events) + assert any(isinstance(e, StreamChunkEvent) and e.chunk.startswith("Link: http") for e in events) + assert any(isinstance(e, StreamChunkEvent) and e.selector == ["nodeA", "v"] and e.chunk == "a" for e in events) + assert any(isinstance(e, StreamChunkEvent) and e.selector == ["nodeA", "v"] and e.chunk == "b" for e in events) + assert isinstance(events[-2], StreamChunkEvent) + assert events[-2].is_final is True + + assert isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.outputs["v"] == "ab" + assert events[-1].node_run_result.outputs["x"] == 1 + + +def test_stream_node_events_raises_when_toolfile_missing(mocker): + mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) + + def _transformed(**_kwargs): + yield DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE_LINK, + message=DatasourceMessage.TextMessage(text="/files/datasources/missing.png"), + meta={}, + ) + + mocker.patch( + "core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages", + side_effect=_transformed, + ) + + class _Session: + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + def scalar(self, _stmt): + return None + + mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session()) + + with pytest.raises(ValueError, match="ToolFile not found for file_id=missing, tenant_id=t1"): + list( + DatasourceManager.stream_node_events( + node_id="nodeA", + user_id="u1", + datasource_name="ds", + datasource_type="online_document", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + parameters_for_log={}, + datasource_info={}, + variable_pool=mocker.Mock(), + datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"), + online_drive_request=None, + ) + ) + + +def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(mocker): + mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) + + file_in = File( + tenant_id="t1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id="tf", + extension=".pdf", + mime_type="application/pdf", + storage_key="k", + ) + + def _transformed(**_kwargs): + yield DatasourceMessage( + type=DatasourceMessage.MessageType.FILE, + message=DatasourceMessage.FileMessage(file_marker="file_marker"), + meta={"file": file_in}, + ) + + mocker.patch( + "core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages", + side_effect=_transformed, + ) + + variable_pool = mocker.Mock() + events = list( + DatasourceManager.stream_node_events( + node_id="nodeA", + user_id="u1", + datasource_name="ds", + datasource_type="online_drive", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + parameters_for_log={}, + datasource_info={"k": "v"}, + variable_pool=variable_pool, + datasource_param=None, + online_drive_request=types.SimpleNamespace(id="id", bucket="b"), + ) + ) + + variable_pool.add.assert_called_once() + assert variable_pool.add.call_args[0][0] == ["nodeA", "file"] + assert variable_pool.add.call_args[0][1] == file_in + + completed = events[-1] + assert isinstance(completed, StreamCompletedEvent) + assert completed.node_run_result.outputs["file"] == file_in + assert completed.node_run_result.outputs["datasource_type"] == DatasourceProviderType.ONLINE_DRIVE + + +def test_stream_node_events_skips_file_build_for_non_online_types(mocker): + mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) + + def _transformed(**_kwargs): + yield DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE_LINK, + message=DatasourceMessage.TextMessage(text="/files/datasources/tool_file_1.png"), + meta={}, + ) + + mocker.patch( + "core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages", + side_effect=_transformed, + ) + build_from_mapping = mocker.patch("core.datasource.datasource_manager.file_factory.build_from_mapping") + + events = list( + DatasourceManager.stream_node_events( + node_id="nodeA", + user_id="u1", + datasource_name="ds", + datasource_type="website_crawl", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + parameters_for_log={}, + datasource_info={}, + variable_pool=mocker.Mock(), + datasource_param=None, + online_drive_request=None, + ) + ) + + build_from_mapping.assert_not_called() + assert isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.outputs["file"] is None + + def test_get_upload_file_by_id_builds_file(mocker): # fake UploadFile row fake_row = types.SimpleNamespace( @@ -133,3 +664,27 @@ def test_get_upload_file_by_id_builds_file(mocker): f = DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1") assert f.related_id == "fid" assert f.extension == ".txt" + + +def test_get_upload_file_by_id_raises_when_missing(mocker): + class _Q: + def where(self, *_args, **_kwargs): + return self + + def first(self): + return None + + class _S: + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + def query(self, *_): + return _Q() + + mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S()) + + with pytest.raises(ValueError, match="UploadFile not found for file_id=fid, tenant_id=t1"): + DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1") diff --git a/api/tests/unit_tests/core/datasource/test_errors.py b/api/tests/unit_tests/core/datasource/test_errors.py new file mode 100644 index 0000000000..95986415b1 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/test_errors.py @@ -0,0 +1,64 @@ +from unittest.mock import MagicMock + +from core.datasource.entities.datasource_entities import DatasourceInvokeMeta +from core.datasource.errors import ( + DatasourceApiSchemaError, + DatasourceEngineInvokeError, + DatasourceInvokeError, + DatasourceNotFoundError, + DatasourceNotSupportedError, + DatasourceParameterValidationError, + DatasourceProviderCredentialValidationError, + DatasourceProviderNotFoundError, +) + + +class TestErrors: + def test_datasource_provider_not_found_error(self): + error = DatasourceProviderNotFoundError("Provider not found") + assert str(error) == "Provider not found" + assert isinstance(error, ValueError) + + def test_datasource_not_found_error(self): + error = DatasourceNotFoundError("Datasource not found") + assert str(error) == "Datasource not found" + assert isinstance(error, ValueError) + + def test_datasource_parameter_validation_error(self): + error = DatasourceParameterValidationError("Validation failed") + assert str(error) == "Validation failed" + assert isinstance(error, ValueError) + + def test_datasource_provider_credential_validation_error(self): + error = DatasourceProviderCredentialValidationError("Credential validation failed") + assert str(error) == "Credential validation failed" + assert isinstance(error, ValueError) + + def test_datasource_not_supported_error(self): + error = DatasourceNotSupportedError("Not supported") + assert str(error) == "Not supported" + assert isinstance(error, ValueError) + + def test_datasource_invoke_error(self): + error = DatasourceInvokeError("Invoke error") + assert str(error) == "Invoke error" + assert isinstance(error, ValueError) + + def test_datasource_api_schema_error(self): + error = DatasourceApiSchemaError("API schema error") + assert str(error) == "API schema error" + assert isinstance(error, ValueError) + + def test_datasource_engine_invoke_error(self): + mock_meta = MagicMock(spec=DatasourceInvokeMeta) + error = DatasourceEngineInvokeError(meta=mock_meta) + assert error.meta == mock_meta + assert isinstance(error, Exception) + + def test_datasource_engine_invoke_error_init(self): + # Test initialization with meta + meta = DatasourceInvokeMeta(time_cost=1.5, error="Engine failed") + error = DatasourceEngineInvokeError(meta=meta) + assert error.meta == meta + assert error.meta.time_cost == 1.5 + assert error.meta.error == "Engine failed" diff --git a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py new file mode 100644 index 0000000000..43f582feb7 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py @@ -0,0 +1,337 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.datasource.entities.datasource_entities import DatasourceMessage +from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer +from dify_graph.file import File +from dify_graph.file.enums import FileTransferMethod, FileType +from models.tools import ToolFile + + +class TestDatasourceFileMessageTransformer: + def test_transform_text_and_link_messages(self): + # Setup + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.TEXT, message=DatasourceMessage.TextMessage(text="hello") + ), + DatasourceMessage( + type=DatasourceMessage.MessageType.LINK, + message=DatasourceMessage.TextMessage(text="https://example.com"), + ), + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 2 + assert result[0].type == DatasourceMessage.MessageType.TEXT + assert result[0].message.text == "hello" + assert result[1].type == DatasourceMessage.MessageType.LINK + assert result[1].message.text == "https://example.com" + + @patch("core.datasource.utils.message_transformer.ToolFileManager") + @patch("core.datasource.utils.message_transformer.guess_extension") + def test_transform_image_message_success(self, mock_guess_ext, mock_tool_file_manager_cls): + # Setup + mock_manager = mock_tool_file_manager_cls.return_value + mock_tool_file = MagicMock(spec=ToolFile) + mock_tool_file.id = "file_id_123" + mock_tool_file.mimetype = "image/png" + mock_manager.create_file_by_url.return_value = mock_tool_file + mock_guess_ext.return_value = ".png" + + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE, + message=DatasourceMessage.TextMessage(text="https://example.com/image.png"), + meta={"some": "meta"}, + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1", conversation_id="conv1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.IMAGE_LINK + assert result[0].message.text == "/files/datasources/file_id_123.png" + assert result[0].meta == {"some": "meta"} + mock_manager.create_file_by_url.assert_called_once_with( + user_id="user1", tenant_id="tenant1", file_url="https://example.com/image.png", conversation_id="conv1" + ) + + @patch("core.datasource.utils.message_transformer.ToolFileManager") + def test_transform_image_message_failure(self, mock_tool_file_manager_cls): + # Setup + mock_manager = mock_tool_file_manager_cls.return_value + mock_manager.create_file_by_url.side_effect = Exception("Download failed") + + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE, + message=DatasourceMessage.TextMessage(text="https://example.com/image.png"), + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.TEXT + assert "Failed to download image" in result[0].message.text + assert "Download failed" in result[0].message.text + + @patch("core.datasource.utils.message_transformer.ToolFileManager") + @patch("core.datasource.utils.message_transformer.guess_extension") + def test_transform_blob_message_image(self, mock_guess_ext, mock_tool_file_manager_cls): + # Setup + mock_manager = mock_tool_file_manager_cls.return_value + mock_tool_file = MagicMock(spec=ToolFile) + mock_tool_file.id = "blob_id_456" + mock_tool_file.mimetype = "image/jpeg" + mock_manager.create_file_by_raw.return_value = mock_tool_file + mock_guess_ext.return_value = ".jpg" + + blob_data = b"fake-image-bits" + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.BLOB, + message=DatasourceMessage.BlobMessage(blob=blob_data), + meta={"mime_type": "image/jpeg", "file_name": "test.jpg"}, + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.IMAGE_LINK + assert result[0].message.text == "/files/datasources/blob_id_456.jpg" + mock_manager.create_file_by_raw.assert_called_once() + + @patch("core.datasource.utils.message_transformer.ToolFileManager") + @patch("core.datasource.utils.message_transformer.guess_extension") + @patch("core.datasource.utils.message_transformer.guess_type") + def test_transform_blob_message_binary_guess_mimetype( + self, mock_guess_type, mock_guess_ext, mock_tool_file_manager_cls + ): + # Setup + mock_manager = mock_tool_file_manager_cls.return_value + mock_tool_file = MagicMock(spec=ToolFile) + mock_tool_file.id = "blob_id_789" + mock_tool_file.mimetype = "application/pdf" + mock_manager.create_file_by_raw.return_value = mock_tool_file + mock_guess_type.return_value = ("application/pdf", None) + mock_guess_ext.return_value = ".pdf" + + blob_data = b"fake-pdf-bits" + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.BLOB, + message=DatasourceMessage.BlobMessage(blob=blob_data), + meta={"file_name": "test.pdf"}, + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.BINARY_LINK + assert result[0].message.text == "/files/datasources/blob_id_789.pdf" + + def test_transform_blob_message_invalid_type(self): + # Setup + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.BLOB, message=DatasourceMessage.TextMessage(text="not a blob") + ) + ] + + # Execute & Verify + with pytest.raises(ValueError, match="unexpected message type"): + list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + def test_transform_file_tool_file_image(self): + # Setup + mock_file = MagicMock(spec=File) + mock_file.transfer_method = FileTransferMethod.TOOL_FILE + mock_file.related_id = "related_123" + mock_file.extension = ".png" + mock_file.type = FileType.IMAGE + + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.FILE, + message=DatasourceMessage.TextMessage(text="ignored"), + meta={"file": mock_file}, + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.IMAGE_LINK + assert result[0].message.text == "/files/datasources/related_123.png" + + def test_transform_file_tool_file_binary(self): + # Setup + mock_file = MagicMock(spec=File) + mock_file.transfer_method = FileTransferMethod.TOOL_FILE + mock_file.related_id = "related_456" + mock_file.extension = ".txt" + mock_file.type = FileType.DOCUMENT + + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.FILE, + message=DatasourceMessage.TextMessage(text="ignored"), + meta={"file": mock_file}, + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.LINK + assert result[0].message.text == "/files/datasources/related_456.txt" + + def test_transform_file_other_transfer_method(self): + # Setup + mock_file = MagicMock(spec=File) + mock_file.transfer_method = FileTransferMethod.REMOTE_URL + + msg = DatasourceMessage( + type=DatasourceMessage.MessageType.FILE, + message=DatasourceMessage.TextMessage(text="remote image"), + meta={"file": mock_file}, + ) + messages = [msg] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0] == msg + + def test_transform_other_message_type(self): + # JSON type is yielded by the default 'else' block or the 'yield message' at the end + msg = DatasourceMessage( + type=DatasourceMessage.MessageType.JSON, message=DatasourceMessage.JsonMessage(json_object={"k": "v"}) + ) + messages = [msg] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0] == msg + + def test_get_datasource_file_url(self): + # Test with extension + url = DatasourceFileMessageTransformer.get_datasource_file_url("file1", ".jpg") + assert url == "/files/datasources/file1.jpg" + + # Test without extension + url = DatasourceFileMessageTransformer.get_datasource_file_url("file2", None) + assert url == "/files/datasources/file2.bin" + + def test_transform_blob_message_no_meta_filename(self): + # This tests line 70 where filename might be None + with patch("core.datasource.utils.message_transformer.ToolFileManager") as mock_tool_file_manager_cls: + mock_manager = mock_tool_file_manager_cls.return_value + mock_tool_file = MagicMock(spec=ToolFile) + mock_tool_file.id = "blob_id_no_name" + mock_tool_file.mimetype = "application/octet-stream" + mock_manager.create_file_by_raw.return_value = mock_tool_file + + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.BLOB, + message=DatasourceMessage.BlobMessage(blob=b"data"), + meta={}, # No mime_type, no file_name + ) + ] + + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.BINARY_LINK + assert result[0].message.text == "/files/datasources/blob_id_no_name.bin" + + @patch("core.datasource.utils.message_transformer.ToolFileManager") + def test_transform_image_message_not_text_message(self, mock_tool_file_manager_cls): + # This tests line 24-26 where it checks if message is instance of TextMessage + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE, message=DatasourceMessage.BlobMessage(blob=b"not-text") + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify - should yield unchanged if it's not a TextMessage + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.IMAGE + assert isinstance(result[0].message, DatasourceMessage.BlobMessage) diff --git a/api/tests/unit_tests/core/datasource/website_crawl/test_website_crawl_plugin.py b/api/tests/unit_tests/core/datasource/website_crawl/test_website_crawl_plugin.py new file mode 100644 index 0000000000..2945eb5523 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/website_crawl/test_website_crawl_plugin.py @@ -0,0 +1,101 @@ +from collections.abc import Generator +from unittest.mock import MagicMock, patch + +import pytest + +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceProviderType, + WebsiteCrawlMessage, +) +from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin + + +class TestWebsiteCrawlDatasourcePlugin: + @pytest.fixture + def mock_entity(self): + entity = MagicMock(spec=DatasourceEntity) + entity.identity = MagicMock() + entity.identity.provider = "test-provider" + entity.identity.name = "test-name" + return entity + + @pytest.fixture + def mock_runtime(self): + runtime = MagicMock(spec=DatasourceRuntime) + runtime.credentials = {"api_key": "test-key"} + return runtime + + def test_init(self, mock_entity, mock_runtime): + # Arrange + tenant_id = "test-tenant-id" + icon = "test-icon" + plugin_unique_identifier = "test-plugin-id" + + # Act + plugin = WebsiteCrawlDatasourcePlugin( + entity=mock_entity, + runtime=mock_runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + # Assert + assert plugin.tenant_id == tenant_id + assert plugin.plugin_unique_identifier == plugin_unique_identifier + assert plugin.entity == mock_entity + assert plugin.runtime == mock_runtime + assert plugin.icon == icon + + def test_datasource_provider_type(self, mock_entity, mock_runtime): + # Arrange + plugin = WebsiteCrawlDatasourcePlugin( + entity=mock_entity, runtime=mock_runtime, tenant_id="test", icon="test", plugin_unique_identifier="test" + ) + + # Act & Assert + assert plugin.datasource_provider_type() == DatasourceProviderType.WEBSITE_CRAWL + + def test_get_website_crawl(self, mock_entity, mock_runtime): + # Arrange + plugin = WebsiteCrawlDatasourcePlugin( + entity=mock_entity, + runtime=mock_runtime, + tenant_id="test-tenant-id", + icon="test-icon", + plugin_unique_identifier="test-plugin-id", + ) + + user_id = "test-user-id" + datasource_parameters = {"url": "https://example.com"} + provider_type = "firecrawl" + + mock_message = MagicMock(spec=WebsiteCrawlMessage) + + # Mock PluginDatasourceManager + with patch("core.datasource.website_crawl.website_crawl_plugin.PluginDatasourceManager") as mock_manager_class: + mock_manager = mock_manager_class.return_value + mock_manager.get_website_crawl.return_value = (msg for msg in [mock_message]) + + # Act + result = plugin.get_website_crawl( + user_id=user_id, datasource_parameters=datasource_parameters, provider_type=provider_type + ) + + # Assert + assert isinstance(result, Generator) + messages = list(result) + assert len(messages) == 1 + assert messages[0] == mock_message + + mock_manager.get_website_crawl.assert_called_once_with( + tenant_id="test-tenant-id", + user_id=user_id, + datasource_provider="test-provider", + datasource_name="test-name", + credentials={"api_key": "test-key"}, + datasource_parameters=datasource_parameters, + provider_type=provider_type, + ) diff --git a/api/tests/unit_tests/core/datasource/website_crawl/test_website_crawl_provider.py b/api/tests/unit_tests/core/datasource/website_crawl/test_website_crawl_provider.py new file mode 100644 index 0000000000..b7822ba800 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/website_crawl/test_website_crawl_provider.py @@ -0,0 +1,95 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceProviderEntityWithPlugin, + DatasourceProviderType, +) +from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController + + +class TestWebsiteCrawlDatasourcePluginProviderController: + @pytest.fixture + def mock_entity(self): + entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + entity.datasources = [] + entity.identity = MagicMock() + entity.identity.icon = "test-icon" + return entity + + def test_init(self, mock_entity): + # Arrange + plugin_id = "test-plugin-id" + plugin_unique_identifier = "test-unique-id" + tenant_id = "test-tenant-id" + + # Act + controller = WebsiteCrawlDatasourcePluginProviderController( + entity=mock_entity, + plugin_id=plugin_id, + plugin_unique_identifier=plugin_unique_identifier, + tenant_id=tenant_id, + ) + + # Assert + assert controller.entity == mock_entity + assert controller.plugin_id == plugin_id + assert controller.plugin_unique_identifier == plugin_unique_identifier + assert controller.tenant_id == tenant_id + + def test_provider_type(self, mock_entity): + # Arrange + controller = WebsiteCrawlDatasourcePluginProviderController( + entity=mock_entity, plugin_id="test", plugin_unique_identifier="test", tenant_id="test" + ) + + # Act & Assert + assert controller.provider_type == DatasourceProviderType.WEBSITE_CRAWL + + def test_get_datasource_success(self, mock_entity): + # Arrange + datasource_name = "test-datasource" + tenant_id = "test-tenant-id" + plugin_unique_identifier = "test-unique-id" + + mock_datasource_entity = MagicMock() + mock_datasource_entity.identity = MagicMock() + mock_datasource_entity.identity.name = datasource_name + mock_entity.datasources = [mock_datasource_entity] + + controller = WebsiteCrawlDatasourcePluginProviderController( + entity=mock_entity, plugin_id="test", plugin_unique_identifier=plugin_unique_identifier, tenant_id=tenant_id + ) + + # Act + with patch( + "core.datasource.website_crawl.website_crawl_provider.WebsiteCrawlDatasourcePlugin" + ) as mock_plugin_class: + mock_plugin_instance = mock_plugin_class.return_value + result = controller.get_datasource(datasource_name) + + # Assert + assert result == mock_plugin_instance + mock_plugin_class.assert_called_once() + args, kwargs = mock_plugin_class.call_args + assert kwargs["entity"] == mock_datasource_entity + assert isinstance(kwargs["runtime"], DatasourceRuntime) + assert kwargs["runtime"].tenant_id == tenant_id + assert kwargs["tenant_id"] == tenant_id + assert kwargs["icon"] == "test-icon" + assert kwargs["plugin_unique_identifier"] == plugin_unique_identifier + + def test_get_datasource_not_found(self, mock_entity): + # Arrange + datasource_name = "non-existent" + mock_entity.datasources = [] + + controller = WebsiteCrawlDatasourcePluginProviderController( + entity=mock_entity, plugin_id="test", plugin_unique_identifier="test", tenant_id="test" + ) + + # Act & Assert + with pytest.raises(ValueError, match=f"Datasource with name {datasource_name} not found"): + controller.get_datasource(datasource_name)