test: add test for api core datasource (#32414)

Co-authored-by: rajatagarwal-oss <rajat.agarwal@infocusp.com>
This commit is contained in:
mahammadasim 2026-03-11 00:42:46 +05:30 committed by GitHub
parent 86b6868772
commit 1083f5c46a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 3033 additions and 3 deletions

View File

@ -59,8 +59,6 @@ class DatasourcePluginProviderController(ABC):
:param credentials: the credentials of the tool
"""
credentials_schema = dict[str, ProviderConfig]()
if credentials_schema is None:
return
for credential in self.entity.credentials_schema:
credentials_schema[credential.name] = credential

View File

@ -0,0 +1,90 @@
from unittest.mock import MagicMock, patch
from configs import dify_config
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceProviderType
class ConcreteDatasourcePlugin(DatasourcePlugin):
"""
Concrete implementation of DatasourcePlugin for testing purposes.
Since DatasourcePlugin is an ABC, we need a concrete class to instantiate it.
"""
def datasource_provider_type(self) -> str:
return DatasourceProviderType.LOCAL_FILE
class TestDatasourcePlugin:
def test_init(self):
# Arrange
entity = MagicMock(spec=DatasourceEntity)
runtime = MagicMock(spec=DatasourceRuntime)
icon = "test-icon.png"
# Act
plugin = ConcreteDatasourcePlugin(entity=entity, runtime=runtime, icon=icon)
# Assert
assert plugin.entity == entity
assert plugin.runtime == runtime
assert plugin.icon == icon
def test_datasource_provider_type(self):
# Arrange
entity = MagicMock(spec=DatasourceEntity)
runtime = MagicMock(spec=DatasourceRuntime)
icon = "test-icon.png"
plugin = ConcreteDatasourcePlugin(entity=entity, runtime=runtime, icon=icon)
# Act
provider_type = plugin.datasource_provider_type()
# Call the base class method to ensure it's covered
base_provider_type = DatasourcePlugin.datasource_provider_type(plugin)
# Assert
assert provider_type == DatasourceProviderType.LOCAL_FILE
assert base_provider_type == DatasourceProviderType.LOCAL_FILE
def test_fork_datasource_runtime(self):
# Arrange
mock_entity = MagicMock(spec=DatasourceEntity)
mock_entity_copy = MagicMock(spec=DatasourceEntity)
mock_entity.model_copy.return_value = mock_entity_copy
runtime = MagicMock(spec=DatasourceRuntime)
new_runtime = MagicMock(spec=DatasourceRuntime)
icon = "test-icon.png"
plugin = ConcreteDatasourcePlugin(entity=mock_entity, runtime=runtime, icon=icon)
# Act
new_plugin = plugin.fork_datasource_runtime(new_runtime)
# Assert
assert isinstance(new_plugin, ConcreteDatasourcePlugin)
assert new_plugin.entity == mock_entity_copy
assert new_plugin.runtime == new_runtime
assert new_plugin.icon == icon
mock_entity.model_copy.assert_called_once()
def test_get_icon_url(self):
# Arrange
entity = MagicMock(spec=DatasourceEntity)
runtime = MagicMock(spec=DatasourceRuntime)
icon = "test-icon.png"
tenant_id = "test-tenant-id"
plugin = ConcreteDatasourcePlugin(entity=entity, runtime=runtime, icon=icon)
# Mocking dify_config.CONSOLE_API_URL
with patch.object(dify_config, "CONSOLE_API_URL", "https://api.dify.ai"):
# Act
icon_url = plugin.get_icon_url(tenant_id)
# Assert
expected_url = (
f"https://api.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={icon}"
)
assert icon_url == expected_url

View File

@ -0,0 +1,265 @@
from unittest.mock import MagicMock, patch
import pytest
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.entities.datasource_entities import (
DatasourceProviderEntityWithPlugin,
DatasourceProviderType,
)
from core.entities.provider_entities import ProviderConfig
from core.tools.errors import ToolProviderCredentialValidationError
class ConcreteDatasourcePluginProviderController(DatasourcePluginProviderController):
"""
Concrete implementation of DatasourcePluginProviderController for testing purposes.
"""
def get_datasource(self, datasource_name: str) -> DatasourcePlugin:
return MagicMock(spec=DatasourcePlugin)
class TestDatasourcePluginProviderController:
def test_init(self):
# Arrange
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
tenant_id = "test-tenant-id"
# Act
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id=tenant_id)
# Assert
assert controller.entity == mock_entity
assert controller.tenant_id == tenant_id
def test_need_credentials(self):
# Arrange
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
tenant_id = "test-tenant-id"
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id=tenant_id)
# Case 1: credentials_schema is None
mock_entity.credentials_schema = None
assert controller.need_credentials is False
# Case 2: credentials_schema is empty
mock_entity.credentials_schema = []
assert controller.need_credentials is False
# Case 3: credentials_schema has items
mock_entity.credentials_schema = [MagicMock()]
assert controller.need_credentials is True
@patch("core.datasource.__base.datasource_provider.PluginToolManager")
def test_validate_credentials(self, mock_manager_class):
# Arrange
mock_manager = mock_manager_class.return_value
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
mock_entity.identity = MagicMock()
mock_entity.identity.name = "test-provider"
tenant_id = "test-tenant-id"
user_id = "test-user-id"
credentials = {"api_key": "secret"}
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id=tenant_id)
# Act: Successful validation
mock_manager.validate_datasource_credentials.return_value = True
controller._validate_credentials(user_id, credentials)
mock_manager.validate_datasource_credentials.assert_called_once_with(
tenant_id=tenant_id,
user_id=user_id,
provider="test-provider",
credentials=credentials,
)
# Act: Failed validation
mock_manager.validate_datasource_credentials.return_value = False
with pytest.raises(ToolProviderCredentialValidationError, match="Invalid credentials"):
controller._validate_credentials(user_id, credentials)
def test_provider_type(self):
# Arrange
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
# Act & Assert
assert controller.provider_type == DatasourceProviderType.LOCAL_FILE
def test_validate_credentials_format_empty_schema(self):
# Arrange
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
mock_entity.credentials_schema = []
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
credentials = {}
# Act & Assert (Should not raise anything)
controller.validate_credentials_format(credentials)
def test_validate_credentials_format_unknown_credential(self):
# Arrange
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
mock_entity.identity = MagicMock()
mock_entity.identity.name = "test-provider"
mock_entity.credentials_schema = []
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
credentials = {"unknown": "value"}
# Act & Assert
with pytest.raises(
ToolProviderCredentialValidationError, match="credential unknown not found in provider test-provider"
):
controller.validate_credentials_format(credentials)
def test_validate_credentials_format_required_missing(self):
# Arrange
mock_config = MagicMock(spec=ProviderConfig)
mock_config.name = "api_key"
mock_config.required = True
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
mock_entity.credentials_schema = [mock_config]
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
# Act & Assert
with pytest.raises(ToolProviderCredentialValidationError, match="credential api_key is required"):
controller.validate_credentials_format({})
def test_validate_credentials_format_not_required_null(self):
# Arrange
mock_config = MagicMock(spec=ProviderConfig)
mock_config.name = "optional"
mock_config.required = False
mock_config.default = None
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
mock_entity.credentials_schema = [mock_config]
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
# Act & Assert
credentials = {"optional": None}
controller.validate_credentials_format(credentials)
assert credentials["optional"] is None
def test_validate_credentials_format_type_mismatch_text(self):
# Arrange
mock_config = MagicMock(spec=ProviderConfig)
mock_config.name = "text_field"
mock_config.required = True
mock_config.type = ProviderConfig.Type.TEXT_INPUT
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
mock_entity.credentials_schema = [mock_config]
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
# Act & Assert
with pytest.raises(ToolProviderCredentialValidationError, match="credential text_field should be string"):
controller.validate_credentials_format({"text_field": 123})
def test_validate_credentials_format_select_validation(self):
# Arrange
mock_option = MagicMock()
mock_option.value = "opt1"
mock_config = MagicMock(spec=ProviderConfig)
mock_config.name = "select_field"
mock_config.required = True
mock_config.type = ProviderConfig.Type.SELECT
mock_config.options = [mock_option]
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
mock_entity.credentials_schema = [mock_config]
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
# Case 1: Value not string
with pytest.raises(ToolProviderCredentialValidationError, match="credential select_field should be string"):
controller.validate_credentials_format({"select_field": 123})
# Case 2: Options not list
mock_config.options = "invalid"
with pytest.raises(
ToolProviderCredentialValidationError, match="credential select_field options should be list"
):
controller.validate_credentials_format({"select_field": "opt1"})
# Case 3: Value not in options
mock_config.options = [mock_option]
with pytest.raises(ToolProviderCredentialValidationError, match="credential select_field should be one of"):
controller.validate_credentials_format({"select_field": "invalid_opt"})
def test_get_datasource_base(self):
# Arrange
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
# Act
result = DatasourcePluginProviderController.get_datasource(controller, "test")
# Assert
assert result is None
def test_validate_credentials_format_hits_pop(self):
# Arrange
mock_config = MagicMock(spec=ProviderConfig)
mock_config.name = "valid_field"
mock_config.required = True
mock_config.type = ProviderConfig.Type.TEXT_INPUT
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
mock_entity.credentials_schema = [mock_config]
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
# Act
credentials = {"valid_field": "valid_value"}
controller.validate_credentials_format(credentials)
# Assert
assert "valid_field" in credentials
assert credentials["valid_field"] == "valid_value"
def test_validate_credentials_format_hits_continue(self):
# Arrange
mock_config = MagicMock(spec=ProviderConfig)
mock_config.name = "optional_field"
mock_config.required = False
mock_config.default = None
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
mock_entity.credentials_schema = [mock_config]
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
# Act
credentials = {"optional_field": None}
controller.validate_credentials_format(credentials)
# Assert
assert credentials["optional_field"] is None
def test_validate_credentials_format_default_values(self):
# Arrange
mock_config_text = MagicMock(spec=ProviderConfig)
mock_config_text.name = "text_def"
mock_config_text.required = False
mock_config_text.type = ProviderConfig.Type.TEXT_INPUT
mock_config_text.default = 123 # Int default, should be converted to str
mock_config_other = MagicMock(spec=ProviderConfig)
mock_config_other.name = "other_def"
mock_config_other.required = False
mock_config_other.type = "OTHER"
mock_config_other.default = "fallback"
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
mock_entity.credentials_schema = [mock_config_text, mock_config_other]
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
# Act
credentials = {}
controller.validate_credentials_format(credentials)
# Assert
assert credentials["text_def"] == "123"
assert credentials["other_def"] == "fallback"

View File

@ -0,0 +1,26 @@
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.__base.datasource_runtime import DatasourceRuntime, FakeDatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
class TestDatasourceRuntime:
def test_init(self):
runtime = DatasourceRuntime(
tenant_id="test-tenant",
datasource_id="test-ds",
invoke_from=InvokeFrom.DEBUGGER,
datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE,
credentials={"key": "val"},
runtime_parameters={"p": "v"},
)
assert runtime.tenant_id == "test-tenant"
assert runtime.datasource_id == "test-ds"
assert runtime.credentials["key"] == "val"
def test_fake_datasource_runtime(self):
# This covers the FakeDatasourceRuntime class and its __init__
runtime = FakeDatasourceRuntime()
assert runtime.tenant_id == "fake_tenant_id"
assert runtime.datasource_id == "fake_datasource_id"
assert runtime.invoke_from == InvokeFrom.DEBUGGER
assert runtime.datasource_invoke_from == DatasourceInvokeFrom.RAG_PIPELINE

View File

@ -0,0 +1,150 @@
from core.datasource.entities.api_entities import DatasourceApiEntity, DatasourceProviderApiEntity
from core.datasource.entities.datasource_entities import DatasourceParameter
from core.tools.entities.common_entities import I18nObject
def test_datasource_api_entity():
label = I18nObject(en_US="label", zh_Hans="标签")
description = I18nObject(en_US="desc", zh_Hans="描述")
entity = DatasourceApiEntity(
author="author", name="name", label=label, description=description, labels=["l1", "l2"]
)
assert entity.author == "author"
assert entity.name == "name"
assert entity.label == label
assert entity.description == description
assert entity.labels == ["l1", "l2"]
assert entity.parameters is None
assert entity.output_schema is None
def test_datasource_provider_api_entity_defaults():
description = I18nObject(en_US="desc", zh_Hans="描述")
label = I18nObject(en_US="label", zh_Hans="标签")
entity = DatasourceProviderApiEntity(
id="id", author="author", name="name", description=description, icon="icon", label=label, type="type"
)
assert entity.id == "id"
assert entity.datasources == []
assert entity.is_team_authorization is False
assert entity.allow_delete is True
assert entity.plugin_id == ""
assert entity.plugin_unique_identifier == ""
assert entity.labels == []
def test_datasource_provider_api_entity_convert_none_to_empty_list():
description = I18nObject(en_US="desc", zh_Hans="描述")
label = I18nObject(en_US="label", zh_Hans="标签")
# Implicitly testing the field_validator "convert_none_to_empty_list"
entity = DatasourceProviderApiEntity(
id="id",
author="author",
name="name",
description=description,
icon="icon",
label=label,
type="type",
datasources=None, # type: ignore
)
assert entity.datasources == []
def test_datasource_provider_api_entity_to_dict():
description = I18nObject(en_US="desc", zh_Hans="描述")
label = I18nObject(en_US="label", zh_Hans="标签")
# Create a parameter that should be converted
param = DatasourceParameter.get_simple_instance(
name="test_param", typ=DatasourceParameter.DatasourceParameterType.SYSTEM_FILES, required=True
)
ds_entity = DatasourceApiEntity(
author="author", name="ds_name", label=label, description=description, parameters=[param]
)
provider_entity = DatasourceProviderApiEntity(
id="id",
author="author",
name="name",
description=description,
icon="icon",
label=label,
type="type",
masked_credentials={"key": "masked"},
datasources=[ds_entity],
labels=["l1"],
)
result = provider_entity.to_dict()
assert result["id"] == "id"
assert result["author"] == "author"
assert result["name"] == "name"
assert result["description"] == description.to_dict()
assert result["icon"] == "icon"
assert result["label"] == label.to_dict()
assert result["type"] == "type"
assert result["team_credentials"] == {"key": "masked"}
assert result["is_team_authorization"] is False
assert result["allow_delete"] is True
assert result["labels"] == ["l1"]
# Check if parameter type was converted from SYSTEM_FILES to files
assert result["datasources"][0]["parameters"][0]["type"] == "files"
def test_datasource_provider_api_entity_to_dict_no_params():
description = I18nObject(en_US="desc", zh_Hans="描述")
label = I18nObject(en_US="label", zh_Hans="标签")
ds_entity = DatasourceApiEntity(
author="author", name="ds_name", label=label, description=description, parameters=None
)
provider_entity = DatasourceProviderApiEntity(
id="id",
author="author",
name="name",
description=description,
icon="icon",
label=label,
type="type",
datasources=[ds_entity],
)
result = provider_entity.to_dict()
assert result["datasources"][0]["parameters"] is None
def test_datasource_provider_api_entity_to_dict_other_param_type():
description = I18nObject(en_US="desc", zh_Hans="描述")
label = I18nObject(en_US="label", zh_Hans="标签")
param = DatasourceParameter.get_simple_instance(
name="test_param", typ=DatasourceParameter.DatasourceParameterType.STRING, required=True
)
ds_entity = DatasourceApiEntity(
author="author", name="ds_name", label=label, description=description, parameters=[param]
)
provider_entity = DatasourceProviderApiEntity(
id="id",
author="author",
name="name",
description=description,
icon="icon",
label=label,
type="type",
datasources=[ds_entity],
)
result = provider_entity.to_dict()
assert result["datasources"][0]["parameters"][0]["type"] == "string"

View File

@ -0,0 +1,31 @@
from core.datasource.entities.common_entities import I18nObject
def test_i18n_object_fallback():
# Only en_US provided
obj = I18nObject(en_US="Hello")
assert obj.en_US == "Hello"
assert obj.zh_Hans == "Hello"
assert obj.pt_BR == "Hello"
assert obj.ja_JP == "Hello"
# Some fields provided
obj = I18nObject(en_US="Hello", zh_Hans="你好")
assert obj.en_US == "Hello"
assert obj.zh_Hans == "你好"
assert obj.pt_BR == "Hello"
assert obj.ja_JP == "Hello"
def test_i18n_object_all_fields():
obj = I18nObject(en_US="Hello", zh_Hans="你好", pt_BR="Olá", ja_JP="こんにちは")
assert obj.en_US == "Hello"
assert obj.zh_Hans == "你好"
assert obj.pt_BR == "Olá"
assert obj.ja_JP == "こんにちは"
def test_i18n_object_to_dict():
obj = I18nObject(en_US="Hello", zh_Hans="你好", pt_BR="Olá", ja_JP="こんにちは")
expected_dict = {"en_US": "Hello", "zh_Hans": "你好", "pt_BR": "Olá", "ja_JP": "こんにちは"}
assert obj.to_dict() == expected_dict

View File

@ -0,0 +1,275 @@
from unittest.mock import patch
import pytest
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceIdentity,
DatasourceInvokeMeta,
DatasourceLabel,
DatasourceMessage,
DatasourceParameter,
DatasourceProviderEntity,
DatasourceProviderEntityWithPlugin,
DatasourceProviderIdentity,
DatasourceProviderType,
GetOnlineDocumentPageContentRequest,
GetOnlineDocumentPageContentResponse,
GetWebsiteCrawlRequest,
OnlineDocumentInfo,
OnlineDocumentPage,
OnlineDocumentPageContent,
OnlineDocumentPagesMessage,
OnlineDriveBrowseFilesRequest,
OnlineDriveBrowseFilesResponse,
OnlineDriveDownloadFileRequest,
OnlineDriveFile,
OnlineDriveFileBucket,
WebsiteCrawlMessage,
WebSiteInfo,
WebSiteInfoDetail,
)
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolLabelEnum
def test_datasource_provider_type():
assert DatasourceProviderType.value_of("online_document") == DatasourceProviderType.ONLINE_DOCUMENT
assert DatasourceProviderType.value_of("local_file") == DatasourceProviderType.LOCAL_FILE
with pytest.raises(ValueError, match="invalid mode value invalid"):
DatasourceProviderType.value_of("invalid")
def test_datasource_parameter_type():
param_type = DatasourceParameter.DatasourceParameterType.STRING
assert param_type.as_normal_type() == "string"
assert param_type.cast_value("test") == "test"
param_type = DatasourceParameter.DatasourceParameterType.NUMBER
assert param_type.cast_value("123") == 123
def test_datasource_parameter():
param = DatasourceParameter.get_simple_instance(
name="test_param",
typ=DatasourceParameter.DatasourceParameterType.STRING,
required=True,
options=["opt1", "opt2"],
)
assert param.name == "test_param"
assert param.type == DatasourceParameter.DatasourceParameterType.STRING
assert param.required is True
assert len(param.options) == 2
assert param.options[0].value == "opt1"
param_no_options = DatasourceParameter.get_simple_instance(
name="test_param_2", typ=DatasourceParameter.DatasourceParameterType.NUMBER, required=False
)
assert param_no_options.options == []
# Test init_frontend_parameter
# For STRING, it should just return the value as is (or cast to str)
frontend_param = param.init_frontend_parameter("val")
assert frontend_param == "val"
# Test parameter type methods
assert DatasourceParameter.DatasourceParameterType.STRING.as_normal_type() == "string"
assert DatasourceParameter.DatasourceParameterType.NUMBER.as_normal_type() == "number"
assert DatasourceParameter.DatasourceParameterType.SECRET_INPUT.as_normal_type() == "string"
assert DatasourceParameter.DatasourceParameterType.NUMBER.cast_value("10.5") == 10.5
assert DatasourceParameter.DatasourceParameterType.BOOLEAN.cast_value("true") is True
assert DatasourceParameter.DatasourceParameterType.FILES.cast_value(["f1", "f2"]) == ["f1", "f2"]
def test_datasource_identity():
label = I18nObject(en_US="label", zh_Hans="标签")
identity = DatasourceIdentity(author="author", name="name", label=label, provider="provider", icon="icon")
assert identity.author == "author"
assert identity.name == "name"
assert identity.label == label
assert identity.provider == "provider"
assert identity.icon == "icon"
def test_datasource_entity():
label = I18nObject(en_US="label", zh_Hans="标签")
identity = DatasourceIdentity(author="author", name="name", label=label, provider="provider")
description = I18nObject(en_US="desc", zh_Hans="描述")
entity = DatasourceEntity(
identity=identity,
description=description,
parameters=None, # Should be handled by validator
)
assert entity.parameters == []
param = DatasourceParameter.get_simple_instance("p1", DatasourceParameter.DatasourceParameterType.STRING, True)
entity_with_params = DatasourceEntity(identity=identity, description=description, parameters=[param])
assert entity_with_params.parameters == [param]
def test_datasource_provider_identity():
label = I18nObject(en_US="label", zh_Hans="标签")
description = I18nObject(en_US="desc", zh_Hans="描述")
identity = DatasourceProviderIdentity(
author="author", name="name", description=description, icon="icon.png", label=label, tags=[ToolLabelEnum.SEARCH]
)
assert identity.author == "author"
assert identity.name == "name"
assert identity.description == description
assert identity.icon == "icon.png"
assert identity.label == label
assert identity.tags == [ToolLabelEnum.SEARCH]
# Test generate_datasource_icon_url
with patch("core.datasource.entities.datasource_entities.dify_config") as mock_config:
mock_config.CONSOLE_API_URL = "http://api.example.com"
url = identity.generate_datasource_icon_url("tenant123")
assert "http://api.example.com/console/api/workspaces/current/plugin/icon" in url
assert "tenant_id=tenant123" in url
assert "filename=icon.png" in url
# Test hardcoded icon
identity.icon = "https://assets.dify.ai/images/File%20Upload.svg"
assert identity.generate_datasource_icon_url("tenant123") == identity.icon
# Test with empty CONSOLE_API_URL
identity.icon = "test.png"
with patch("core.datasource.entities.datasource_entities.dify_config") as mock_config:
mock_config.CONSOLE_API_URL = None
url = identity.generate_datasource_icon_url("tenant123")
assert url.startswith("/console/api/workspaces/current/plugin/icon")
def test_datasource_provider_entity():
label = I18nObject(en_US="label", zh_Hans="标签")
description = I18nObject(en_US="desc", zh_Hans="描述")
identity = DatasourceProviderIdentity(
author="author", name="name", description=description, icon="icon", label=label
)
entity = DatasourceProviderEntity(
identity=identity,
provider_type=DatasourceProviderType.ONLINE_DOCUMENT,
credentials_schema=[],
oauth_schema=None,
)
assert entity.identity == identity
assert entity.provider_type == DatasourceProviderType.ONLINE_DOCUMENT
assert entity.credentials_schema == []
def test_datasource_provider_entity_with_plugin():
label = I18nObject(en_US="label", zh_Hans="标签")
description = I18nObject(en_US="desc", zh_Hans="描述")
identity = DatasourceProviderIdentity(
author="author", name="name", description=description, icon="icon", label=label
)
entity = DatasourceProviderEntityWithPlugin(
identity=identity, provider_type=DatasourceProviderType.ONLINE_DOCUMENT, datasources=[]
)
assert entity.datasources == []
def test_datasource_invoke_meta():
meta = DatasourceInvokeMeta(time_cost=1.5, error="some error", tool_config={"k": "v"})
assert meta.time_cost == 1.5
assert meta.error == "some error"
assert meta.tool_config == {"k": "v"}
d = meta.to_dict()
assert d == {"time_cost": 1.5, "error": "some error", "tool_config": {"k": "v"}}
empty_meta = DatasourceInvokeMeta.empty()
assert empty_meta.time_cost == 0.0
assert empty_meta.error is None
assert empty_meta.tool_config == {}
error_meta = DatasourceInvokeMeta.error_instance("fatal error")
assert error_meta.time_cost == 0.0
assert error_meta.error == "fatal error"
assert error_meta.tool_config == {}
def test_datasource_label():
label_obj = I18nObject(en_US="label", zh_Hans="标签")
ds_label = DatasourceLabel(name="name", label=label_obj, icon="icon")
assert ds_label.name == "name"
assert ds_label.label == label_obj
assert ds_label.icon == "icon"
def test_online_document_models():
page = OnlineDocumentPage(
page_id="p1",
page_name="name",
page_icon={"type": "emoji"},
type="page",
last_edited_time="2023-01-01",
parent_id=None,
)
assert page.page_id == "p1"
info = OnlineDocumentInfo(workspace_id="w1", workspace_name="name", workspace_icon="icon", total=1, pages=[page])
assert info.total == 1
msg = OnlineDocumentPagesMessage(result=[info])
assert msg.result == [info]
req = GetOnlineDocumentPageContentRequest(workspace_id="w1", page_id="p1", type="page")
assert req.workspace_id == "w1"
content = OnlineDocumentPageContent(workspace_id="w1", page_id="p1", content="hello")
assert content.content == "hello"
resp = GetOnlineDocumentPageContentResponse(result=content)
assert resp.result == content
def test_website_crawl_models():
req = GetWebsiteCrawlRequest(crawl_parameters={"url": "http://test.com"})
assert req.crawl_parameters == {"url": "http://test.com"}
detail = WebSiteInfoDetail(source_url="http://test.com", content="content", title="title", description="desc")
assert detail.title == "title"
info = WebSiteInfo(status="completed", web_info_list=[detail], total=1, completed=1)
assert info.status == "completed"
msg = WebsiteCrawlMessage(result=info)
assert msg.result == info
# Test default values
msg_default = WebsiteCrawlMessage()
assert msg_default.result.status == ""
assert msg_default.result.web_info_list == []
def test_online_drive_models():
file = OnlineDriveFile(id="f1", name="file.txt", size=100, type="file")
assert file.name == "file.txt"
bucket = OnlineDriveFileBucket(bucket="b1", files=[file], is_truncated=False, next_page_parameters=None)
assert bucket.bucket == "b1"
req = OnlineDriveBrowseFilesRequest(bucket="b1", prefix="folder1", max_keys=10, next_page_parameters=None)
assert req.prefix == "folder1"
resp = OnlineDriveBrowseFilesResponse(result=[bucket])
assert resp.result == [bucket]
dl_req = OnlineDriveDownloadFileRequest(id="f1", bucket="b1")
assert dl_req.id == "f1"
def test_datasource_message():
# Use proper dict for message to avoid Pydantic Union validation ambiguity/crashes
msg = DatasourceMessage(type="text", message={"text": "hello"})
assert msg.message.text == "hello"
msg_json = DatasourceMessage(type="json", message={"json_object": {"k": "v"}})
assert msg_json.message.json_object == {"k": "v"}

View File

@ -0,0 +1,57 @@
from unittest.mock import MagicMock
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceProviderType,
)
from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin
class TestLocalFileDatasourcePlugin:
def test_init(self):
# Arrange
mock_entity = MagicMock(spec=DatasourceEntity)
mock_runtime = MagicMock(spec=DatasourceRuntime)
tenant_id = "test-tenant-id"
icon = "test-icon"
plugin_unique_identifier = "test-plugin-id"
# Act
plugin = LocalFileDatasourcePlugin(
entity=mock_entity,
runtime=mock_runtime,
tenant_id=tenant_id,
icon=icon,
plugin_unique_identifier=plugin_unique_identifier,
)
# Assert
assert plugin.tenant_id == tenant_id
assert plugin.plugin_unique_identifier == plugin_unique_identifier
assert plugin.entity == mock_entity
assert plugin.runtime == mock_runtime
assert plugin.icon == icon
def test_datasource_provider_type(self):
# Arrange
mock_entity = MagicMock(spec=DatasourceEntity)
mock_runtime = MagicMock(spec=DatasourceRuntime)
plugin = LocalFileDatasourcePlugin(
entity=mock_entity, runtime=mock_runtime, tenant_id="test", icon="test", plugin_unique_identifier="test"
)
# Act & Assert
assert plugin.datasource_provider_type() == DatasourceProviderType.LOCAL_FILE
def test_get_icon_url(self):
# Arrange
mock_entity = MagicMock(spec=DatasourceEntity)
mock_runtime = MagicMock(spec=DatasourceRuntime)
icon = "test-icon"
plugin = LocalFileDatasourcePlugin(
entity=mock_entity, runtime=mock_runtime, tenant_id="test", icon=icon, plugin_unique_identifier="test"
)
# Act & Assert
assert plugin.get_icon_url("any-tenant-id") == icon

View File

@ -0,0 +1,96 @@
from unittest.mock import MagicMock
import pytest
from core.datasource.entities.datasource_entities import (
DatasourceProviderEntityWithPlugin,
DatasourceProviderType,
)
from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
class TestLocalFileDatasourcePluginProviderController:
def test_init(self):
# Arrange
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
plugin_id = "test_plugin_id"
plugin_unique_identifier = "test_plugin_unique_identifier"
tenant_id = "test_tenant_id"
# Act
controller = LocalFileDatasourcePluginProviderController(
entity=mock_entity,
plugin_id=plugin_id,
plugin_unique_identifier=plugin_unique_identifier,
tenant_id=tenant_id,
)
# Assert
assert controller.entity == mock_entity
assert controller.plugin_id == plugin_id
assert controller.plugin_unique_identifier == plugin_unique_identifier
assert controller.tenant_id == tenant_id
def test_provider_type(self):
# Arrange
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
controller = LocalFileDatasourcePluginProviderController(
entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant"
)
# Act & Assert
assert controller.provider_type == DatasourceProviderType.LOCAL_FILE
def test_validate_credentials(self):
# Arrange
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
controller = LocalFileDatasourcePluginProviderController(
entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant"
)
# Act & Assert
# Should not raise any exception
controller._validate_credentials("user_id", {"key": "value"})
def test_get_datasource_success(self):
# Arrange
mock_datasource_entity = MagicMock()
mock_datasource_entity.identity.name = "test_datasource"
mock_entity = MagicMock()
mock_entity.datasources = [mock_datasource_entity]
mock_entity.identity.icon = "test_icon"
plugin_unique_identifier = "test_plugin_unique_identifier"
tenant_id = "test_tenant_id"
controller = LocalFileDatasourcePluginProviderController(
entity=mock_entity, plugin_id="id", plugin_unique_identifier=plugin_unique_identifier, tenant_id=tenant_id
)
# Act
datasource = controller.get_datasource("test_datasource")
# Assert
assert isinstance(datasource, LocalFileDatasourcePlugin)
assert datasource.entity == mock_datasource_entity
assert datasource.tenant_id == tenant_id
assert datasource.icon == "test_icon"
assert datasource.plugin_unique_identifier == plugin_unique_identifier
def test_get_datasource_not_found(self):
# Arrange
mock_datasource_entity = MagicMock()
mock_datasource_entity.identity.name = "other_datasource"
mock_entity = MagicMock()
mock_entity.datasources = [mock_datasource_entity]
controller = LocalFileDatasourcePluginProviderController(
entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant"
)
# Act & Assert
with pytest.raises(ValueError, match="Datasource with name test_datasource not found"):
controller.get_datasource("test_datasource")

View File

@ -0,0 +1,151 @@
from unittest.mock import MagicMock, patch
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceIdentity,
DatasourceProviderType,
GetOnlineDocumentPageContentRequest,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
class TestOnlineDocumentDatasourcePlugin:
def test_init(self):
# Arrange
entity = MagicMock(spec=DatasourceEntity)
runtime = MagicMock(spec=DatasourceRuntime)
tenant_id = "test_tenant"
icon = "test_icon"
plugin_unique_identifier = "test_plugin_id"
# Act
plugin = OnlineDocumentDatasourcePlugin(
entity=entity,
runtime=runtime,
tenant_id=tenant_id,
icon=icon,
plugin_unique_identifier=plugin_unique_identifier,
)
# Assert
assert plugin.entity == entity
assert plugin.runtime == runtime
assert plugin.tenant_id == tenant_id
assert plugin.icon == icon
assert plugin.plugin_unique_identifier == plugin_unique_identifier
def test_get_online_document_pages(self):
# Arrange
entity = MagicMock(spec=DatasourceEntity)
identity = MagicMock(spec=DatasourceIdentity)
entity.identity = identity
identity.provider = "test_provider"
identity.name = "test_name"
runtime = MagicMock(spec=DatasourceRuntime)
runtime.credentials = {"api_key": "test_key"}
tenant_id = "test_tenant"
icon = "test_icon"
plugin_unique_identifier = "test_plugin_id"
plugin = OnlineDocumentDatasourcePlugin(
entity=entity,
runtime=runtime,
tenant_id=tenant_id,
icon=icon,
plugin_unique_identifier=plugin_unique_identifier,
)
user_id = "test_user"
datasource_parameters = {"param": "value"}
provider_type = "test_type"
mock_generator = MagicMock()
# Patch PluginDatasourceManager to isolate plugin behavior from external dependencies
with patch("core.datasource.online_document.online_document_plugin.PluginDatasourceManager") as MockManager:
mock_manager_instance = MockManager.return_value
mock_manager_instance.get_online_document_pages.return_value = mock_generator
# Act
result = plugin.get_online_document_pages(
user_id=user_id, datasource_parameters=datasource_parameters, provider_type=provider_type
)
# Assert
assert result == mock_generator
mock_manager_instance.get_online_document_pages.assert_called_once_with(
tenant_id=tenant_id,
user_id=user_id,
datasource_provider="test_provider",
datasource_name="test_name",
credentials=runtime.credentials,
datasource_parameters=datasource_parameters,
provider_type=provider_type,
)
def test_get_online_document_page_content(self):
# Arrange
entity = MagicMock(spec=DatasourceEntity)
identity = MagicMock(spec=DatasourceIdentity)
entity.identity = identity
identity.provider = "test_provider"
identity.name = "test_name"
runtime = MagicMock(spec=DatasourceRuntime)
runtime.credentials = {"api_key": "test_key"}
tenant_id = "test_tenant"
icon = "test_icon"
plugin_unique_identifier = "test_plugin_id"
plugin = OnlineDocumentDatasourcePlugin(
entity=entity,
runtime=runtime,
tenant_id=tenant_id,
icon=icon,
plugin_unique_identifier=plugin_unique_identifier,
)
user_id = "test_user"
datasource_parameters = MagicMock(spec=GetOnlineDocumentPageContentRequest)
provider_type = "test_type"
mock_generator = MagicMock()
with patch("core.datasource.online_document.online_document_plugin.PluginDatasourceManager") as MockManager:
mock_manager_instance = MockManager.return_value
mock_manager_instance.get_online_document_page_content.return_value = mock_generator
# Act
result = plugin.get_online_document_page_content(
user_id=user_id, datasource_parameters=datasource_parameters, provider_type=provider_type
)
# Assert
assert result == mock_generator
mock_manager_instance.get_online_document_page_content.assert_called_once_with(
tenant_id=tenant_id,
user_id=user_id,
datasource_provider="test_provider",
datasource_name="test_name",
credentials=runtime.credentials,
datasource_parameters=datasource_parameters,
provider_type=provider_type,
)
def test_datasource_provider_type(self):
# Arrange
entity = MagicMock(spec=DatasourceEntity)
runtime = MagicMock(spec=DatasourceRuntime)
plugin = OnlineDocumentDatasourcePlugin(
entity=entity, runtime=runtime, tenant_id="test", icon="test", plugin_unique_identifier="test"
)
# Act
result = plugin.datasource_provider_type()
# Assert
assert result == DatasourceProviderType.ONLINE_DOCUMENT

View File

@ -0,0 +1,100 @@
from unittest.mock import MagicMock
import pytest
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceProviderEntityWithPlugin,
DatasourceProviderType,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
class TestOnlineDocumentDatasourcePluginProviderController:
def test_init(self):
# Arrange
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
plugin_id = "test_plugin_id"
plugin_unique_identifier = "test_plugin_uid"
tenant_id = "test_tenant_id"
# Act
controller = OnlineDocumentDatasourcePluginProviderController(
entity=mock_entity,
plugin_id=plugin_id,
plugin_unique_identifier=plugin_unique_identifier,
tenant_id=tenant_id,
)
# Assert
assert controller.entity == mock_entity
assert controller.plugin_id == plugin_id
assert controller.plugin_unique_identifier == plugin_unique_identifier
assert controller.tenant_id == tenant_id
def test_provider_type(self):
# Arrange
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
controller = OnlineDocumentDatasourcePluginProviderController(
entity=mock_entity, plugin_id="test", plugin_unique_identifier="test", tenant_id="test"
)
# Assert
assert controller.provider_type == DatasourceProviderType.ONLINE_DOCUMENT
def test_get_datasource_success(self):
# Arrange
from core.datasource.entities.datasource_entities import DatasourceIdentity
mock_datasource_entity = MagicMock(spec=DatasourceEntity)
mock_datasource_entity.identity = MagicMock(spec=DatasourceIdentity)
mock_datasource_entity.identity.name = "target_datasource"
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
mock_entity.datasources = [mock_datasource_entity]
mock_entity.identity = MagicMock()
mock_entity.identity.icon = "test_icon"
plugin_unique_identifier = "test_plugin_uid"
tenant_id = "test_tenant_id"
controller = OnlineDocumentDatasourcePluginProviderController(
entity=mock_entity,
plugin_id="test_plugin_id",
plugin_unique_identifier=plugin_unique_identifier,
tenant_id=tenant_id,
)
# Act
result = controller.get_datasource("target_datasource")
# Assert
assert isinstance(result, OnlineDocumentDatasourcePlugin)
assert result.entity == mock_datasource_entity
assert result.tenant_id == tenant_id
assert result.icon == "test_icon"
assert result.plugin_unique_identifier == plugin_unique_identifier
assert result.runtime.tenant_id == tenant_id
def test_get_datasource_not_found(self):
# Arrange
from core.datasource.entities.datasource_entities import DatasourceIdentity
mock_datasource_entity = MagicMock(spec=DatasourceEntity)
mock_datasource_entity.identity = MagicMock(spec=DatasourceIdentity)
mock_datasource_entity.identity.name = "other_datasource"
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
mock_entity.datasources = [mock_datasource_entity]
controller = OnlineDocumentDatasourcePluginProviderController(
entity=mock_entity,
plugin_id="test_plugin_id",
plugin_unique_identifier="test_plugin_uid",
tenant_id="test_tenant_id",
)
# Act & Assert
with pytest.raises(ValueError, match="Datasource with name missing_datasource not found"):
controller.get_datasource("missing_datasource")

View File

@ -0,0 +1,147 @@
from unittest.mock import MagicMock, patch
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceIdentity,
DatasourceProviderType,
OnlineDriveBrowseFilesRequest,
OnlineDriveDownloadFileRequest,
)
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
class TestOnlineDriveDatasourcePlugin:
def test_init(self):
# Arrange
entity = MagicMock(spec=DatasourceEntity)
runtime = MagicMock(spec=DatasourceRuntime)
tenant_id = "test_tenant"
icon = "test_icon"
plugin_unique_identifier = "test_plugin_id"
# Act
plugin = OnlineDriveDatasourcePlugin(
entity=entity,
runtime=runtime,
tenant_id=tenant_id,
icon=icon,
plugin_unique_identifier=plugin_unique_identifier,
)
# Assert
assert plugin.entity == entity
assert plugin.runtime == runtime
assert plugin.tenant_id == tenant_id
assert plugin.icon == icon
assert plugin.plugin_unique_identifier == plugin_unique_identifier
def test_online_drive_browse_files(self):
# Arrange
entity = MagicMock(spec=DatasourceEntity)
identity = MagicMock(spec=DatasourceIdentity)
entity.identity = identity
identity.provider = "test_provider"
identity.name = "test_name"
runtime = MagicMock(spec=DatasourceRuntime)
runtime.credentials = {"token": "test_token"}
tenant_id = "test_tenant"
icon = "test_icon"
plugin_unique_identifier = "test_plugin_id"
plugin = OnlineDriveDatasourcePlugin(
entity=entity,
runtime=runtime,
tenant_id=tenant_id,
icon=icon,
plugin_unique_identifier=plugin_unique_identifier,
)
user_id = "test_user"
request = MagicMock(spec=OnlineDriveBrowseFilesRequest)
provider_type = "test_type"
mock_generator = MagicMock()
with patch("core.datasource.online_drive.online_drive_plugin.PluginDatasourceManager") as MockManager:
mock_manager_instance = MockManager.return_value
mock_manager_instance.online_drive_browse_files.return_value = mock_generator
# Act
result = plugin.online_drive_browse_files(user_id=user_id, request=request, provider_type=provider_type)
# Assert
assert result == mock_generator
mock_manager_instance.online_drive_browse_files.assert_called_once_with(
tenant_id=tenant_id,
user_id=user_id,
datasource_provider="test_provider",
datasource_name="test_name",
credentials=runtime.credentials,
request=request,
provider_type=provider_type,
)
def test_online_drive_download_file(self):
# Arrange
entity = MagicMock(spec=DatasourceEntity)
identity = MagicMock(spec=DatasourceIdentity)
entity.identity = identity
identity.provider = "test_provider"
identity.name = "test_name"
runtime = MagicMock(spec=DatasourceRuntime)
runtime.credentials = {"token": "test_token"}
tenant_id = "test_tenant"
icon = "test_icon"
plugin_unique_identifier = "test_plugin_id"
plugin = OnlineDriveDatasourcePlugin(
entity=entity,
runtime=runtime,
tenant_id=tenant_id,
icon=icon,
plugin_unique_identifier=plugin_unique_identifier,
)
user_id = "test_user"
request = MagicMock(spec=OnlineDriveDownloadFileRequest)
provider_type = "test_type"
mock_generator = MagicMock()
with patch("core.datasource.online_drive.online_drive_plugin.PluginDatasourceManager") as MockManager:
mock_manager_instance = MockManager.return_value
mock_manager_instance.online_drive_download_file.return_value = mock_generator
# Act
result = plugin.online_drive_download_file(user_id=user_id, request=request, provider_type=provider_type)
# Assert
assert result == mock_generator
mock_manager_instance.online_drive_download_file.assert_called_once_with(
tenant_id=tenant_id,
user_id=user_id,
datasource_provider="test_provider",
datasource_name="test_name",
credentials=runtime.credentials,
request=request,
provider_type=provider_type,
)
def test_datasource_provider_type(self):
# Arrange
entity = MagicMock(spec=DatasourceEntity)
runtime = MagicMock(spec=DatasourceRuntime)
plugin = OnlineDriveDatasourcePlugin(
entity=entity, runtime=runtime, tenant_id="test", icon="test", plugin_unique_identifier="test"
)
# Act
result = plugin.datasource_provider_type()
# Assert
assert result == DatasourceProviderType.ONLINE_DRIVE

View File

@ -0,0 +1,83 @@
from unittest.mock import MagicMock
import pytest
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController
class TestOnlineDriveDatasourcePluginProviderController:
def test_init(self):
# Arrange
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
plugin_id = "test_plugin_id"
plugin_unique_identifier = "test_plugin_unique_identifier"
tenant_id = "test_tenant_id"
# Act
controller = OnlineDriveDatasourcePluginProviderController(
entity=mock_entity,
plugin_id=plugin_id,
plugin_unique_identifier=plugin_unique_identifier,
tenant_id=tenant_id,
)
# Assert
assert controller.entity == mock_entity
assert controller.plugin_id == plugin_id
assert controller.plugin_unique_identifier == plugin_unique_identifier
assert controller.tenant_id == tenant_id
def test_provider_type(self):
# Arrange
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
controller = OnlineDriveDatasourcePluginProviderController(
entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant"
)
# Act & Assert
assert controller.provider_type == DatasourceProviderType.ONLINE_DRIVE
def test_get_datasource_success(self):
# Arrange
mock_datasource_entity = MagicMock()
mock_datasource_entity.identity.name = "test_datasource"
mock_entity = MagicMock()
mock_entity.datasources = [mock_datasource_entity]
mock_entity.identity.icon = "test_icon"
plugin_unique_identifier = "test_plugin_unique_identifier"
tenant_id = "test_tenant_id"
controller = OnlineDriveDatasourcePluginProviderController(
entity=mock_entity, plugin_id="id", plugin_unique_identifier=plugin_unique_identifier, tenant_id=tenant_id
)
# Act
datasource = controller.get_datasource("test_datasource")
# Assert
assert isinstance(datasource, OnlineDriveDatasourcePlugin)
assert datasource.entity == mock_datasource_entity
assert datasource.tenant_id == tenant_id
assert datasource.icon == "test_icon"
assert datasource.plugin_unique_identifier == plugin_unique_identifier
assert datasource.runtime.tenant_id == tenant_id
def test_get_datasource_not_found(self):
# Arrange
mock_datasource_entity = MagicMock()
mock_datasource_entity.identity.name = "other_datasource"
mock_entity = MagicMock()
mock_entity.datasources = [mock_datasource_entity]
controller = OnlineDriveDatasourcePluginProviderController(
entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant"
)
# Act & Assert
with pytest.raises(ValueError, match="Datasource with name test_datasource not found"):
controller.get_datasource("test_datasource")

View File

@ -0,0 +1,409 @@
import base64
import hashlib
import hmac
from unittest.mock import MagicMock, patch
import httpx
import pytest
from core.datasource.datasource_file_manager import DatasourceFileManager
from models.model import MessageFile, UploadFile
from models.tools import ToolFile
class TestDatasourceFileManager:
@patch("core.datasource.datasource_file_manager.time.time")
@patch("core.datasource.datasource_file_manager.os.urandom")
@patch("core.datasource.datasource_file_manager.dify_config")
def test_sign_file(self, mock_config, mock_urandom, mock_time):
# Setup
mock_config.FILES_URL = "http://localhost:5001"
mock_config.SECRET_KEY = "test_secret"
mock_time.return_value = 1700000000
mock_urandom.return_value = b"1234567890abcdef" # 16 bytes
datasource_file_id = "file_id_123"
extension = ".png"
# Execute
signed_url = DatasourceFileManager.sign_file(datasource_file_id, extension)
# Verify
assert signed_url.startswith("http://localhost:5001/files/datasources/file_id_123.png?")
assert "timestamp=1700000000" in signed_url
assert f"nonce={mock_urandom.return_value.hex()}" in signed_url
assert "sign=" in signed_url
@patch("core.datasource.datasource_file_manager.time.time")
@patch("core.datasource.datasource_file_manager.os.urandom")
@patch("core.datasource.datasource_file_manager.dify_config")
def test_sign_file_empty_secret(self, mock_config, mock_urandom, mock_time):
# Setup
mock_config.FILES_URL = "http://localhost:5001"
mock_config.SECRET_KEY = None # Empty secret
mock_time.return_value = 1700000000
mock_urandom.return_value = b"1234567890abcdef"
# Execute
signed_url = DatasourceFileManager.sign_file("file_id", ".png")
assert "sign=" in signed_url
@patch("core.datasource.datasource_file_manager.time.time")
@patch("core.datasource.datasource_file_manager.dify_config")
def test_verify_file(self, mock_config, mock_time):
# Setup
mock_config.SECRET_KEY = "test_secret"
mock_config.FILES_ACCESS_TIMEOUT = 300
mock_time.return_value = 1700000000
datasource_file_id = "file_id_123"
timestamp = "1699999800" # 200 seconds ago
nonce = "some_nonce"
# Manually calculate sign
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
secret_key = b"test_secret"
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
# Execute & Verify Success
assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is True
# Verify Failure - Wrong Sign
assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, "wrong_sign") is False
# Verify Failure - Timeout
mock_time.return_value = 1700000500 # 700 seconds after timestamp (300 is timeout)
assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is False
@patch("core.datasource.datasource_file_manager.time.time")
@patch("core.datasource.datasource_file_manager.dify_config")
def test_verify_file_empty_secret(self, mock_config, mock_time):
# Setup
mock_config.SECRET_KEY = "" # Empty string secret
mock_config.FILES_ACCESS_TIMEOUT = 300
mock_time.return_value = 1700000000
datasource_file_id = "file_id_123"
timestamp = "1699999800"
nonce = "some_nonce"
# Calculate with empty secret
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
sign = hmac.new(b"", data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is True
@patch("core.datasource.datasource_file_manager.db")
@patch("core.datasource.datasource_file_manager.storage")
@patch("core.datasource.datasource_file_manager.uuid4")
@patch("core.datasource.datasource_file_manager.dify_config")
def test_create_file_by_raw(self, mock_config, mock_uuid, mock_storage, mock_db):
# Setup
mock_uuid.return_value = MagicMock(hex="unique_hex")
mock_config.STORAGE_TYPE = "local"
user_id = "user_123"
tenant_id = "tenant_456"
file_binary = b"fake binary data"
mimetype = "image/png"
# Execute
upload_file = DatasourceFileManager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
file_binary=file_binary,
mimetype=mimetype,
filename="test.png",
)
# Verify
assert upload_file.tenant_id == tenant_id
assert upload_file.name == "test.png"
assert upload_file.size == len(file_binary)
assert upload_file.mime_type == mimetype
assert upload_file.key == f"datasources/{tenant_id}/unique_hex.png"
mock_storage.save.assert_called_once_with(upload_file.key, file_binary)
mock_db.session.add.assert_called_once()
mock_db.session.commit.assert_called_once()
@patch("core.datasource.datasource_file_manager.db")
@patch("core.datasource.datasource_file_manager.storage")
@patch("core.datasource.datasource_file_manager.uuid4")
@patch("core.datasource.datasource_file_manager.dify_config")
def test_create_file_by_raw_filename_no_extension(self, mock_config, mock_uuid, mock_storage, mock_db):
# Setup
mock_uuid.return_value = MagicMock(hex="unique_hex")
mock_config.STORAGE_TYPE = "local"
user_id = "user_123"
tenant_id = "tenant_456"
file_binary = b"fake binary data"
mimetype = "image/png"
# Execute
upload_file = DatasourceFileManager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
file_binary=file_binary,
mimetype=mimetype,
filename="test", # No extension
)
# Verify
assert upload_file.name == "test.png" # Should append extension
@patch("core.datasource.datasource_file_manager.db")
@patch("core.datasource.datasource_file_manager.storage")
@patch("core.datasource.datasource_file_manager.uuid4")
@patch("core.datasource.datasource_file_manager.dify_config")
@patch("core.datasource.datasource_file_manager.guess_extension")
def test_create_file_by_raw_unknown_extension(self, mock_guess_ext, mock_config, mock_uuid, mock_storage, mock_db):
# Setup
mock_guess_ext.return_value = None # Cannot guess
mock_uuid.return_value = MagicMock(hex="unique_hex")
# Execute
upload_file = DatasourceFileManager.create_file_by_raw(
user_id="user",
tenant_id="tenant",
conversation_id=None,
file_binary=b"data",
mimetype="application/x-unknown",
)
# Verify
assert upload_file.extension == ".bin"
assert upload_file.name == "unique_hex.bin"
@patch("core.datasource.datasource_file_manager.db")
@patch("core.datasource.datasource_file_manager.storage")
@patch("core.datasource.datasource_file_manager.uuid4")
@patch("core.datasource.datasource_file_manager.dify_config")
def test_create_file_by_raw_no_filename(self, mock_config, mock_uuid, mock_storage, mock_db):
# Setup
mock_uuid.return_value = MagicMock(hex="unique_hex")
mock_config.STORAGE_TYPE = "local"
# Execute
upload_file = DatasourceFileManager.create_file_by_raw(
user_id="user_123",
tenant_id="tenant_456",
conversation_id=None,
file_binary=b"data",
mimetype="application/pdf",
)
# Verify
assert upload_file.name == "unique_hex.pdf"
assert upload_file.extension == ".pdf"
@patch("core.datasource.datasource_file_manager.ssrf_proxy")
@patch("core.datasource.datasource_file_manager.db")
@patch("core.datasource.datasource_file_manager.storage")
@patch("core.datasource.datasource_file_manager.uuid4")
def test_create_file_by_url_mimetype_from_guess(self, mock_uuid, mock_storage, mock_db, mock_ssrf):
# Setup
mock_uuid.return_value = MagicMock(hex="unique_hex")
mock_response = MagicMock()
mock_response.content = b"bits"
mock_response.headers = {} # No content-type in headers
mock_ssrf.get.return_value = mock_response
# Execute
tool_file = DatasourceFileManager.create_file_by_url(
user_id="user_123", tenant_id="tenant_456", file_url="https://example.com/photo.png"
)
# Verify
assert tool_file.mimetype == "image/png" # Guessed from .png in URL
@patch("core.datasource.datasource_file_manager.ssrf_proxy")
@patch("core.datasource.datasource_file_manager.db")
@patch("core.datasource.datasource_file_manager.storage")
@patch("core.datasource.datasource_file_manager.uuid4")
def test_create_file_by_url_mimetype_default(self, mock_uuid, mock_storage, mock_db, mock_ssrf):
# Setup
mock_uuid.return_value = MagicMock(hex="unique_hex")
mock_response = MagicMock()
mock_response.content = b"bits"
mock_response.headers = {}
mock_ssrf.get.return_value = mock_response
# Execute
tool_file = DatasourceFileManager.create_file_by_url(
user_id="user_123",
tenant_id="tenant_456",
file_url="https://example.com/unknown", # No extension, no headers
)
# Verify
assert tool_file.mimetype == "application/octet-stream"
@patch("core.datasource.datasource_file_manager.ssrf_proxy")
@patch("core.datasource.datasource_file_manager.db")
@patch("core.datasource.datasource_file_manager.storage")
@patch("core.datasource.datasource_file_manager.uuid4")
def test_create_file_by_url_success(self, mock_uuid, mock_storage, mock_db, mock_ssrf):
# Setup
mock_uuid.return_value = MagicMock(hex="unique_hex")
mock_response = MagicMock()
mock_response.content = b"downloaded bits"
mock_response.headers = {"Content-Type": "image/jpeg"}
mock_ssrf.get.return_value = mock_response
# Execute
tool_file = DatasourceFileManager.create_file_by_url(
user_id="user_123", tenant_id="tenant_456", file_url="https://example.com/photo.jpg"
)
# Verify
assert tool_file.mimetype == "image/jpeg"
assert tool_file.size == len(b"downloaded bits")
assert tool_file.file_key == "tools/tenant_456/unique_hex.jpg"
mock_storage.save.assert_called_once()
@patch("core.datasource.datasource_file_manager.ssrf_proxy")
def test_create_file_by_url_timeout(self, mock_ssrf):
# Setup
mock_ssrf.get.side_effect = httpx.TimeoutException("Timeout")
# Execute & Verify
with pytest.raises(ValueError, match="timeout when downloading file"):
DatasourceFileManager.create_file_by_url(
user_id="user_123", tenant_id="tenant_456", file_url="https://example.com/large.file"
)
@patch("core.datasource.datasource_file_manager.db")
@patch("core.datasource.datasource_file_manager.storage")
def test_get_file_binary(self, mock_storage, mock_db):
# Setup
mock_upload_file = MagicMock(spec=UploadFile)
mock_upload_file.key = "some_key"
mock_upload_file.mime_type = "image/png"
mock_query = mock_db.session.query.return_value
mock_where = mock_query.where.return_value
mock_where.first.return_value = mock_upload_file
mock_storage.load_once.return_value = b"file content"
# Execute
result = DatasourceFileManager.get_file_binary("file_id")
# Verify
assert result == (b"file content", "image/png")
# Case: Not found
mock_where.first.return_value = None
assert DatasourceFileManager.get_file_binary("unknown") is None
@patch("core.datasource.datasource_file_manager.db")
@patch("core.datasource.datasource_file_manager.storage")
def test_get_file_binary_by_message_file_id(self, mock_storage, mock_db):
# Setup
mock_message_file = MagicMock(spec=MessageFile)
mock_message_file.url = "http://localhost/files/tools/tool_id.png"
mock_tool_file = MagicMock(spec=ToolFile)
mock_tool_file.file_key = "tool_key"
mock_tool_file.mimetype = "image/png"
# Mock query sequence
def mock_query(model):
m = MagicMock()
if model == MessageFile:
m.where.return_value.first.return_value = mock_message_file
elif model == ToolFile:
m.where.return_value.first.return_value = mock_tool_file
return m
mock_db.session.query.side_effect = mock_query
mock_storage.load_once.return_value = b"tool content"
# Execute
result = DatasourceFileManager.get_file_binary_by_message_file_id("msg_file_id")
# Verify
assert result == (b"tool content", "image/png")
@patch("core.datasource.datasource_file_manager.db")
@patch("core.datasource.datasource_file_manager.storage")
def test_get_file_binary_by_message_file_id_with_extension(self, mock_storage, mock_db):
# Test that it correctly parses tool_id even with extension in URL
mock_message_file = MagicMock(spec=MessageFile)
mock_message_file.url = "http://localhost/files/tools/abcdef.png"
mock_tool_file = MagicMock(spec=ToolFile)
mock_tool_file.id = "abcdef"
mock_tool_file.file_key = "tk"
mock_tool_file.mimetype = "image/png"
def mock_query(model):
m = MagicMock()
if model == MessageFile:
m.where.return_value.first.return_value = mock_message_file
else:
m.where.return_value.first.return_value = mock_tool_file
return m
mock_db.session.query.side_effect = mock_query
mock_storage.load_once.return_value = b"bits"
result = DatasourceFileManager.get_file_binary_by_message_file_id("m")
assert result == (b"bits", "image/png")
@patch("core.datasource.datasource_file_manager.db")
@patch("core.datasource.datasource_file_manager.storage")
def test_get_file_binary_by_message_file_id_failures(self, mock_storage, mock_db):
# Setup common mock
mock_query_obj = MagicMock()
mock_db.session.query.return_value = mock_query_obj
mock_query_obj.where.return_value.first.return_value = None
# Case 1: Message file not found
assert DatasourceFileManager.get_file_binary_by_message_file_id("none") is None
# Case 2: Message file found but tool file not found
mock_message_file = MagicMock(spec=MessageFile)
mock_message_file.url = None
def mock_query_v2(model):
m = MagicMock()
if model == MessageFile:
m.where.return_value.first.return_value = mock_message_file
else:
m.where.return_value.first.return_value = None
return m
mock_db.session.query.side_effect = mock_query_v2
assert DatasourceFileManager.get_file_binary_by_message_file_id("msg_id") is None
@patch("core.datasource.datasource_file_manager.db")
@patch("core.datasource.datasource_file_manager.storage")
def test_get_file_generator_by_upload_file_id(self, mock_storage, mock_db):
# Setup
mock_upload_file = MagicMock(spec=UploadFile)
mock_upload_file.key = "upload_key"
mock_upload_file.mime_type = "text/plain"
mock_db.session.query.return_value.where.return_value.first.return_value = mock_upload_file
mock_storage.load_stream.return_value = iter([b"chunk1", b"chunk2"])
# Execute
stream, mimetype = DatasourceFileManager.get_file_generator_by_upload_file_id("upload_id")
# Verify
assert mimetype == "text/plain"
assert list(stream) == [b"chunk1", b"chunk2"]
# Case: Not found
mock_db.session.query.return_value.where.return_value.first.return_value = None
stream, mimetype = DatasourceFileManager.get_file_generator_by_upload_file_id("none")
assert stream is None
assert mimetype is None

View File

@ -1,9 +1,15 @@
import types
from collections.abc import Generator
import pytest
from contexts.wrapper import RecyclableContextVar
from core.datasource.datasource_manager import DatasourceManager
from core.datasource.entities.datasource_entities import DatasourceMessage
from core.datasource.entities.datasource_entities import DatasourceMessage, DatasourceProviderType
from core.datasource.errors import DatasourceProviderNotFoundError
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.file import File
from dify_graph.file.enums import FileTransferMethod, FileType
from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent
@ -15,6 +21,22 @@ def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, Non
)
def _drain_generator(gen: Generator[DatasourceMessage, None, object]) -> tuple[list[DatasourceMessage], object | None]:
messages: list[DatasourceMessage] = []
try:
while True:
messages.append(next(gen))
except StopIteration as e:
return messages, e.value
def _invalidate_recyclable_contextvars() -> None:
"""
Ensure RecyclableContextVar.get() raises LookupError until reset by code under test.
"""
RecyclableContextVar.increment_thread_recycles()
def test_get_icon_url_calls_runtime(mocker):
fake_runtime = mocker.Mock()
fake_runtime.get_icon_url.return_value = "https://icon"
@ -30,6 +52,119 @@ def test_get_icon_url_calls_runtime(mocker):
DatasourceManager.get_datasource_runtime.assert_called_once()
def test_get_datasource_runtime_delegates_to_provider_controller(mocker):
provider_controller = mocker.Mock()
provider_controller.get_datasource.return_value = object()
mocker.patch.object(DatasourceManager, "get_datasource_plugin_provider", return_value=provider_controller)
runtime = DatasourceManager.get_datasource_runtime(
provider_id="prov/x",
datasource_name="ds",
tenant_id="t1",
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
)
assert runtime is provider_controller.get_datasource.return_value
provider_controller.get_datasource.assert_called_once_with("ds")
@pytest.mark.parametrize(
("datasource_type", "controller_path"),
[
(
DatasourceProviderType.ONLINE_DOCUMENT,
"core.datasource.datasource_manager.OnlineDocumentDatasourcePluginProviderController",
),
(
DatasourceProviderType.ONLINE_DRIVE,
"core.datasource.datasource_manager.OnlineDriveDatasourcePluginProviderController",
),
(
DatasourceProviderType.WEBSITE_CRAWL,
"core.datasource.datasource_manager.WebsiteCrawlDatasourcePluginProviderController",
),
(
DatasourceProviderType.LOCAL_FILE,
"core.datasource.datasource_manager.LocalFileDatasourcePluginProviderController",
),
],
)
def test_get_datasource_plugin_provider_creates_controller_and_caches(mocker, datasource_type, controller_path):
_invalidate_recyclable_contextvars()
provider_entity = types.SimpleNamespace(declaration=object(), plugin_id="plugin", plugin_unique_identifier="uniq")
fetch = mocker.patch(
"core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider",
return_value=provider_entity,
)
ctrl_cls = mocker.patch(controller_path)
first = DatasourceManager.get_datasource_plugin_provider(
provider_id=f"prov/{datasource_type.value}",
tenant_id="t1",
datasource_type=datasource_type,
)
second = DatasourceManager.get_datasource_plugin_provider(
provider_id=f"prov/{datasource_type.value}",
tenant_id="t1",
datasource_type=datasource_type,
)
assert first is second
assert fetch.call_count == 1
assert ctrl_cls.call_count == 1
def test_get_datasource_plugin_provider_raises_when_provider_entity_missing(mocker):
_invalidate_recyclable_contextvars()
mocker.patch(
"core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider",
return_value=None,
)
with pytest.raises(DatasourceProviderNotFoundError, match="plugin provider prov/notfound not found"):
DatasourceManager.get_datasource_plugin_provider(
provider_id="prov/notfound",
tenant_id="t1",
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
)
def test_get_datasource_plugin_provider_raises_for_unsupported_type(mocker):
_invalidate_recyclable_contextvars()
provider_entity = types.SimpleNamespace(declaration=object(), plugin_id="plugin", plugin_unique_identifier="uniq")
mocker.patch(
"core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider",
return_value=provider_entity,
)
with pytest.raises(ValueError, match="Unsupported datasource type"):
DatasourceManager.get_datasource_plugin_provider(
provider_id="prov/x",
tenant_id="t1",
datasource_type=types.SimpleNamespace(), # not a DatasourceProviderType at runtime
)
def test_get_datasource_plugin_provider_raises_when_controller_none(mocker):
_invalidate_recyclable_contextvars()
provider_entity = types.SimpleNamespace(declaration=object(), plugin_id="plugin", plugin_unique_identifier="uniq")
mocker.patch(
"core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider",
return_value=provider_entity,
)
mocker.patch(
"core.datasource.datasource_manager.OnlineDocumentDatasourcePluginProviderController",
return_value=None,
)
with pytest.raises(DatasourceProviderNotFoundError, match="Datasource provider prov/x not found"):
DatasourceManager.get_datasource_plugin_provider(
provider_id="prov/x",
tenant_id="t1",
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
)
def test_stream_online_results_yields_messages_online_document(mocker):
# stub runtime to yield a text message
def _doc_messages(**_):
@ -60,6 +195,148 @@ def test_stream_online_results_yields_messages_online_document(mocker):
assert msgs[0].message.text == "hello"
def test_stream_online_results_sets_credentials_and_returns_empty_dict_online_document(mocker):
class _Runtime:
def __init__(self) -> None:
self.runtime = types.SimpleNamespace(credentials=None)
def get_online_document_page_content(self, **_kwargs):
yield from _gen_messages_text_only("hello")
runtime = _Runtime()
mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=runtime)
mocker.patch(
"core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials",
return_value={"token": "t"},
)
gen = DatasourceManager.stream_online_results(
user_id="u1",
datasource_name="ds",
datasource_type="online_document",
provider_id="p/x",
tenant_id="t1",
provider="prov",
plugin_id="plug",
credential_id="cred",
datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"),
online_drive_request=None,
)
messages, final_value = _drain_generator(gen)
assert runtime.runtime.credentials == {"token": "t"}
assert [m.message.text for m in messages] == ["hello"]
assert final_value == {}
def test_stream_online_results_raises_when_missing_params(mocker):
class _Runtime:
def __init__(self) -> None:
self.runtime = types.SimpleNamespace(credentials=None)
def get_online_document_page_content(self, **_kwargs):
yield from _gen_messages_text_only("never")
def online_drive_download_file(self, **_kwargs):
yield from _gen_messages_text_only("never")
mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=_Runtime())
mocker.patch(
"core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials",
return_value={},
)
with pytest.raises(ValueError, match="datasource_param is required for ONLINE_DOCUMENT streaming"):
list(
DatasourceManager.stream_online_results(
user_id="u1",
datasource_name="ds",
datasource_type="online_document",
provider_id="p/x",
tenant_id="t1",
provider="prov",
plugin_id="plug",
credential_id="",
datasource_param=None,
online_drive_request=None,
)
)
with pytest.raises(ValueError, match="online_drive_request is required for ONLINE_DRIVE streaming"):
list(
DatasourceManager.stream_online_results(
user_id="u1",
datasource_name="ds",
datasource_type="online_drive",
provider_id="p/x",
tenant_id="t1",
provider="prov",
plugin_id="plug",
credential_id="",
datasource_param=None,
online_drive_request=None,
)
)
def test_stream_online_results_yields_messages_and_returns_empty_dict_online_drive(mocker):
class _Runtime:
def __init__(self) -> None:
self.runtime = types.SimpleNamespace(credentials=None)
def online_drive_download_file(self, **_kwargs):
yield from _gen_messages_text_only("drive")
runtime = _Runtime()
mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=runtime)
mocker.patch(
"core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials",
return_value={"token": "t"},
)
gen = DatasourceManager.stream_online_results(
user_id="u1",
datasource_name="ds",
datasource_type="online_drive",
provider_id="p/x",
tenant_id="t1",
provider="prov",
plugin_id="plug",
credential_id="cred",
datasource_param=None,
online_drive_request=types.SimpleNamespace(id="fid", bucket="b"),
)
messages, final_value = _drain_generator(gen)
assert runtime.runtime.credentials == {"token": "t"}
assert [m.message.text for m in messages] == ["drive"]
assert final_value == {}
def test_stream_online_results_raises_for_unsupported_stream_type(mocker):
mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=mocker.Mock())
mocker.patch(
"core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials",
return_value={},
)
with pytest.raises(ValueError, match="Unsupported datasource type for streaming"):
list(
DatasourceManager.stream_online_results(
user_id="u1",
datasource_name="ds",
datasource_type="website_crawl",
provider_id="p/x",
tenant_id="t1",
provider="prov",
plugin_id="plug",
credential_id="",
datasource_param=None,
online_drive_request=None,
)
)
def test_stream_node_events_emits_events_online_document(mocker):
# make manager's low-level stream produce TEXT only
mocker.patch.object(
@ -93,6 +370,260 @@ def test_stream_node_events_emits_events_online_document(mocker):
assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
def test_stream_node_events_builds_file_and_variables_from_messages(mocker):
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
def _transformed(**_kwargs):
yield DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE_LINK,
message=DatasourceMessage.TextMessage(text="/files/datasources/tool_file_1.png"),
meta={},
)
yield DatasourceMessage(
type=DatasourceMessage.MessageType.TEXT,
message=DatasourceMessage.TextMessage(text="hello"),
meta=None,
)
yield DatasourceMessage(
type=DatasourceMessage.MessageType.LINK,
message=DatasourceMessage.TextMessage(text="http://example.com"),
meta=None,
)
yield DatasourceMessage(
type=DatasourceMessage.MessageType.VARIABLE,
message=DatasourceMessage.VariableMessage(variable_name="v", variable_value="a", stream=True),
meta=None,
)
yield DatasourceMessage(
type=DatasourceMessage.MessageType.VARIABLE,
message=DatasourceMessage.VariableMessage(variable_name="v", variable_value="b", stream=True),
meta=None,
)
yield DatasourceMessage(
type=DatasourceMessage.MessageType.VARIABLE,
message=DatasourceMessage.VariableMessage(variable_name="x", variable_value=1, stream=False),
meta=None,
)
yield DatasourceMessage(
type=DatasourceMessage.MessageType.JSON,
message=DatasourceMessage.JsonMessage(json_object={"k": "v"}),
meta=None,
)
mocker.patch(
"core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages",
side_effect=_transformed,
)
fake_tool_file = types.SimpleNamespace(mimetype="image/png")
class _Session:
def __enter__(self):
return self
def __exit__(self, *exc):
return False
def scalar(self, _stmt):
return fake_tool_file
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session())
mocker.patch(
"core.datasource.datasource_manager.file_factory.get_file_type_by_mime_type", return_value=FileType.IMAGE
)
built = File(
tenant_id="t1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="tool_file_1",
extension=".png",
mime_type="image/png",
storage_key="k",
)
build_from_mapping = mocker.patch(
"core.datasource.datasource_manager.file_factory.build_from_mapping",
return_value=built,
)
variable_pool = mocker.Mock()
events = list(
DatasourceManager.stream_node_events(
node_id="nodeA",
user_id="u1",
datasource_name="ds",
datasource_type="online_document",
provider_id="p/x",
tenant_id="t1",
provider="prov",
plugin_id="plug",
credential_id="",
parameters_for_log={"k": "v"},
datasource_info={"info": "x"},
variable_pool=variable_pool,
datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"),
online_drive_request=None,
)
)
build_from_mapping.assert_called_once()
variable_pool.add.assert_not_called()
assert any(isinstance(e, StreamChunkEvent) and e.chunk == "hello" for e in events)
assert any(isinstance(e, StreamChunkEvent) and e.chunk.startswith("Link: http") for e in events)
assert any(isinstance(e, StreamChunkEvent) and e.selector == ["nodeA", "v"] and e.chunk == "a" for e in events)
assert any(isinstance(e, StreamChunkEvent) and e.selector == ["nodeA", "v"] and e.chunk == "b" for e in events)
assert isinstance(events[-2], StreamChunkEvent)
assert events[-2].is_final is True
assert isinstance(events[-1], StreamCompletedEvent)
assert events[-1].node_run_result.outputs["v"] == "ab"
assert events[-1].node_run_result.outputs["x"] == 1
def test_stream_node_events_raises_when_toolfile_missing(mocker):
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
def _transformed(**_kwargs):
yield DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE_LINK,
message=DatasourceMessage.TextMessage(text="/files/datasources/missing.png"),
meta={},
)
mocker.patch(
"core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages",
side_effect=_transformed,
)
class _Session:
def __enter__(self):
return self
def __exit__(self, *exc):
return False
def scalar(self, _stmt):
return None
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session())
with pytest.raises(ValueError, match="ToolFile not found for file_id=missing, tenant_id=t1"):
list(
DatasourceManager.stream_node_events(
node_id="nodeA",
user_id="u1",
datasource_name="ds",
datasource_type="online_document",
provider_id="p/x",
tenant_id="t1",
provider="prov",
plugin_id="plug",
credential_id="",
parameters_for_log={},
datasource_info={},
variable_pool=mocker.Mock(),
datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"),
online_drive_request=None,
)
)
def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(mocker):
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
file_in = File(
tenant_id="t1",
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="tf",
extension=".pdf",
mime_type="application/pdf",
storage_key="k",
)
def _transformed(**_kwargs):
yield DatasourceMessage(
type=DatasourceMessage.MessageType.FILE,
message=DatasourceMessage.FileMessage(file_marker="file_marker"),
meta={"file": file_in},
)
mocker.patch(
"core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages",
side_effect=_transformed,
)
variable_pool = mocker.Mock()
events = list(
DatasourceManager.stream_node_events(
node_id="nodeA",
user_id="u1",
datasource_name="ds",
datasource_type="online_drive",
provider_id="p/x",
tenant_id="t1",
provider="prov",
plugin_id="plug",
credential_id="",
parameters_for_log={},
datasource_info={"k": "v"},
variable_pool=variable_pool,
datasource_param=None,
online_drive_request=types.SimpleNamespace(id="id", bucket="b"),
)
)
variable_pool.add.assert_called_once()
assert variable_pool.add.call_args[0][0] == ["nodeA", "file"]
assert variable_pool.add.call_args[0][1] == file_in
completed = events[-1]
assert isinstance(completed, StreamCompletedEvent)
assert completed.node_run_result.outputs["file"] == file_in
assert completed.node_run_result.outputs["datasource_type"] == DatasourceProviderType.ONLINE_DRIVE
def test_stream_node_events_skips_file_build_for_non_online_types(mocker):
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
def _transformed(**_kwargs):
yield DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE_LINK,
message=DatasourceMessage.TextMessage(text="/files/datasources/tool_file_1.png"),
meta={},
)
mocker.patch(
"core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages",
side_effect=_transformed,
)
build_from_mapping = mocker.patch("core.datasource.datasource_manager.file_factory.build_from_mapping")
events = list(
DatasourceManager.stream_node_events(
node_id="nodeA",
user_id="u1",
datasource_name="ds",
datasource_type="website_crawl",
provider_id="p/x",
tenant_id="t1",
provider="prov",
plugin_id="plug",
credential_id="",
parameters_for_log={},
datasource_info={},
variable_pool=mocker.Mock(),
datasource_param=None,
online_drive_request=None,
)
)
build_from_mapping.assert_not_called()
assert isinstance(events[-1], StreamCompletedEvent)
assert events[-1].node_run_result.outputs["file"] is None
def test_get_upload_file_by_id_builds_file(mocker):
# fake UploadFile row
fake_row = types.SimpleNamespace(
@ -133,3 +664,27 @@ def test_get_upload_file_by_id_builds_file(mocker):
f = DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1")
assert f.related_id == "fid"
assert f.extension == ".txt"
def test_get_upload_file_by_id_raises_when_missing(mocker):
class _Q:
def where(self, *_args, **_kwargs):
return self
def first(self):
return None
class _S:
def __enter__(self):
return self
def __exit__(self, *exc):
return False
def query(self, *_):
return _Q()
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S())
with pytest.raises(ValueError, match="UploadFile not found for file_id=fid, tenant_id=t1"):
DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1")

View File

@ -0,0 +1,64 @@
from unittest.mock import MagicMock
from core.datasource.entities.datasource_entities import DatasourceInvokeMeta
from core.datasource.errors import (
DatasourceApiSchemaError,
DatasourceEngineInvokeError,
DatasourceInvokeError,
DatasourceNotFoundError,
DatasourceNotSupportedError,
DatasourceParameterValidationError,
DatasourceProviderCredentialValidationError,
DatasourceProviderNotFoundError,
)
class TestErrors:
def test_datasource_provider_not_found_error(self):
error = DatasourceProviderNotFoundError("Provider not found")
assert str(error) == "Provider not found"
assert isinstance(error, ValueError)
def test_datasource_not_found_error(self):
error = DatasourceNotFoundError("Datasource not found")
assert str(error) == "Datasource not found"
assert isinstance(error, ValueError)
def test_datasource_parameter_validation_error(self):
error = DatasourceParameterValidationError("Validation failed")
assert str(error) == "Validation failed"
assert isinstance(error, ValueError)
def test_datasource_provider_credential_validation_error(self):
error = DatasourceProviderCredentialValidationError("Credential validation failed")
assert str(error) == "Credential validation failed"
assert isinstance(error, ValueError)
def test_datasource_not_supported_error(self):
error = DatasourceNotSupportedError("Not supported")
assert str(error) == "Not supported"
assert isinstance(error, ValueError)
def test_datasource_invoke_error(self):
error = DatasourceInvokeError("Invoke error")
assert str(error) == "Invoke error"
assert isinstance(error, ValueError)
def test_datasource_api_schema_error(self):
error = DatasourceApiSchemaError("API schema error")
assert str(error) == "API schema error"
assert isinstance(error, ValueError)
def test_datasource_engine_invoke_error(self):
mock_meta = MagicMock(spec=DatasourceInvokeMeta)
error = DatasourceEngineInvokeError(meta=mock_meta)
assert error.meta == mock_meta
assert isinstance(error, Exception)
def test_datasource_engine_invoke_error_init(self):
# Test initialization with meta
meta = DatasourceInvokeMeta(time_cost=1.5, error="Engine failed")
error = DatasourceEngineInvokeError(meta=meta)
assert error.meta == meta
assert error.meta.time_cost == 1.5
assert error.meta.error == "Engine failed"

View File

@ -0,0 +1,337 @@
from unittest.mock import MagicMock, patch
import pytest
from core.datasource.entities.datasource_entities import DatasourceMessage
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
from dify_graph.file import File
from dify_graph.file.enums import FileTransferMethod, FileType
from models.tools import ToolFile
class TestDatasourceFileMessageTransformer:
def test_transform_text_and_link_messages(self):
# Setup
messages = [
DatasourceMessage(
type=DatasourceMessage.MessageType.TEXT, message=DatasourceMessage.TextMessage(text="hello")
),
DatasourceMessage(
type=DatasourceMessage.MessageType.LINK,
message=DatasourceMessage.TextMessage(text="https://example.com"),
),
]
# Execute
result = list(
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=iter(messages), user_id="user1", tenant_id="tenant1"
)
)
# Verify
assert len(result) == 2
assert result[0].type == DatasourceMessage.MessageType.TEXT
assert result[0].message.text == "hello"
assert result[1].type == DatasourceMessage.MessageType.LINK
assert result[1].message.text == "https://example.com"
@patch("core.datasource.utils.message_transformer.ToolFileManager")
@patch("core.datasource.utils.message_transformer.guess_extension")
def test_transform_image_message_success(self, mock_guess_ext, mock_tool_file_manager_cls):
# Setup
mock_manager = mock_tool_file_manager_cls.return_value
mock_tool_file = MagicMock(spec=ToolFile)
mock_tool_file.id = "file_id_123"
mock_tool_file.mimetype = "image/png"
mock_manager.create_file_by_url.return_value = mock_tool_file
mock_guess_ext.return_value = ".png"
messages = [
DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE,
message=DatasourceMessage.TextMessage(text="https://example.com/image.png"),
meta={"some": "meta"},
)
]
# Execute
result = list(
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=iter(messages), user_id="user1", tenant_id="tenant1", conversation_id="conv1"
)
)
# Verify
assert len(result) == 1
assert result[0].type == DatasourceMessage.MessageType.IMAGE_LINK
assert result[0].message.text == "/files/datasources/file_id_123.png"
assert result[0].meta == {"some": "meta"}
mock_manager.create_file_by_url.assert_called_once_with(
user_id="user1", tenant_id="tenant1", file_url="https://example.com/image.png", conversation_id="conv1"
)
@patch("core.datasource.utils.message_transformer.ToolFileManager")
def test_transform_image_message_failure(self, mock_tool_file_manager_cls):
# Setup
mock_manager = mock_tool_file_manager_cls.return_value
mock_manager.create_file_by_url.side_effect = Exception("Download failed")
messages = [
DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE,
message=DatasourceMessage.TextMessage(text="https://example.com/image.png"),
)
]
# Execute
result = list(
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=iter(messages), user_id="user1", tenant_id="tenant1"
)
)
# Verify
assert len(result) == 1
assert result[0].type == DatasourceMessage.MessageType.TEXT
assert "Failed to download image" in result[0].message.text
assert "Download failed" in result[0].message.text
@patch("core.datasource.utils.message_transformer.ToolFileManager")
@patch("core.datasource.utils.message_transformer.guess_extension")
def test_transform_blob_message_image(self, mock_guess_ext, mock_tool_file_manager_cls):
# Setup
mock_manager = mock_tool_file_manager_cls.return_value
mock_tool_file = MagicMock(spec=ToolFile)
mock_tool_file.id = "blob_id_456"
mock_tool_file.mimetype = "image/jpeg"
mock_manager.create_file_by_raw.return_value = mock_tool_file
mock_guess_ext.return_value = ".jpg"
blob_data = b"fake-image-bits"
messages = [
DatasourceMessage(
type=DatasourceMessage.MessageType.BLOB,
message=DatasourceMessage.BlobMessage(blob=blob_data),
meta={"mime_type": "image/jpeg", "file_name": "test.jpg"},
)
]
# Execute
result = list(
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=iter(messages), user_id="user1", tenant_id="tenant1"
)
)
# Verify
assert len(result) == 1
assert result[0].type == DatasourceMessage.MessageType.IMAGE_LINK
assert result[0].message.text == "/files/datasources/blob_id_456.jpg"
mock_manager.create_file_by_raw.assert_called_once()
@patch("core.datasource.utils.message_transformer.ToolFileManager")
@patch("core.datasource.utils.message_transformer.guess_extension")
@patch("core.datasource.utils.message_transformer.guess_type")
def test_transform_blob_message_binary_guess_mimetype(
self, mock_guess_type, mock_guess_ext, mock_tool_file_manager_cls
):
# Setup
mock_manager = mock_tool_file_manager_cls.return_value
mock_tool_file = MagicMock(spec=ToolFile)
mock_tool_file.id = "blob_id_789"
mock_tool_file.mimetype = "application/pdf"
mock_manager.create_file_by_raw.return_value = mock_tool_file
mock_guess_type.return_value = ("application/pdf", None)
mock_guess_ext.return_value = ".pdf"
blob_data = b"fake-pdf-bits"
messages = [
DatasourceMessage(
type=DatasourceMessage.MessageType.BLOB,
message=DatasourceMessage.BlobMessage(blob=blob_data),
meta={"file_name": "test.pdf"},
)
]
# Execute
result = list(
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=iter(messages), user_id="user1", tenant_id="tenant1"
)
)
# Verify
assert len(result) == 1
assert result[0].type == DatasourceMessage.MessageType.BINARY_LINK
assert result[0].message.text == "/files/datasources/blob_id_789.pdf"
def test_transform_blob_message_invalid_type(self):
# Setup
messages = [
DatasourceMessage(
type=DatasourceMessage.MessageType.BLOB, message=DatasourceMessage.TextMessage(text="not a blob")
)
]
# Execute & Verify
with pytest.raises(ValueError, match="unexpected message type"):
list(
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=iter(messages), user_id="user1", tenant_id="tenant1"
)
)
def test_transform_file_tool_file_image(self):
# Setup
mock_file = MagicMock(spec=File)
mock_file.transfer_method = FileTransferMethod.TOOL_FILE
mock_file.related_id = "related_123"
mock_file.extension = ".png"
mock_file.type = FileType.IMAGE
messages = [
DatasourceMessage(
type=DatasourceMessage.MessageType.FILE,
message=DatasourceMessage.TextMessage(text="ignored"),
meta={"file": mock_file},
)
]
# Execute
result = list(
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=iter(messages), user_id="user1", tenant_id="tenant1"
)
)
# Verify
assert len(result) == 1
assert result[0].type == DatasourceMessage.MessageType.IMAGE_LINK
assert result[0].message.text == "/files/datasources/related_123.png"
def test_transform_file_tool_file_binary(self):
# Setup
mock_file = MagicMock(spec=File)
mock_file.transfer_method = FileTransferMethod.TOOL_FILE
mock_file.related_id = "related_456"
mock_file.extension = ".txt"
mock_file.type = FileType.DOCUMENT
messages = [
DatasourceMessage(
type=DatasourceMessage.MessageType.FILE,
message=DatasourceMessage.TextMessage(text="ignored"),
meta={"file": mock_file},
)
]
# Execute
result = list(
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=iter(messages), user_id="user1", tenant_id="tenant1"
)
)
# Verify
assert len(result) == 1
assert result[0].type == DatasourceMessage.MessageType.LINK
assert result[0].message.text == "/files/datasources/related_456.txt"
def test_transform_file_other_transfer_method(self):
# Setup
mock_file = MagicMock(spec=File)
mock_file.transfer_method = FileTransferMethod.REMOTE_URL
msg = DatasourceMessage(
type=DatasourceMessage.MessageType.FILE,
message=DatasourceMessage.TextMessage(text="remote image"),
meta={"file": mock_file},
)
messages = [msg]
# Execute
result = list(
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=iter(messages), user_id="user1", tenant_id="tenant1"
)
)
# Verify
assert len(result) == 1
assert result[0] == msg
def test_transform_other_message_type(self):
# JSON type is yielded by the default 'else' block or the 'yield message' at the end
msg = DatasourceMessage(
type=DatasourceMessage.MessageType.JSON, message=DatasourceMessage.JsonMessage(json_object={"k": "v"})
)
messages = [msg]
# Execute
result = list(
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=iter(messages), user_id="user1", tenant_id="tenant1"
)
)
# Verify
assert len(result) == 1
assert result[0] == msg
def test_get_datasource_file_url(self):
# Test with extension
url = DatasourceFileMessageTransformer.get_datasource_file_url("file1", ".jpg")
assert url == "/files/datasources/file1.jpg"
# Test without extension
url = DatasourceFileMessageTransformer.get_datasource_file_url("file2", None)
assert url == "/files/datasources/file2.bin"
def test_transform_blob_message_no_meta_filename(self):
# This tests line 70 where filename might be None
with patch("core.datasource.utils.message_transformer.ToolFileManager") as mock_tool_file_manager_cls:
mock_manager = mock_tool_file_manager_cls.return_value
mock_tool_file = MagicMock(spec=ToolFile)
mock_tool_file.id = "blob_id_no_name"
mock_tool_file.mimetype = "application/octet-stream"
mock_manager.create_file_by_raw.return_value = mock_tool_file
messages = [
DatasourceMessage(
type=DatasourceMessage.MessageType.BLOB,
message=DatasourceMessage.BlobMessage(blob=b"data"),
meta={}, # No mime_type, no file_name
)
]
result = list(
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=iter(messages), user_id="user1", tenant_id="tenant1"
)
)
assert len(result) == 1
assert result[0].type == DatasourceMessage.MessageType.BINARY_LINK
assert result[0].message.text == "/files/datasources/blob_id_no_name.bin"
@patch("core.datasource.utils.message_transformer.ToolFileManager")
def test_transform_image_message_not_text_message(self, mock_tool_file_manager_cls):
# This tests line 24-26 where it checks if message is instance of TextMessage
messages = [
DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE, message=DatasourceMessage.BlobMessage(blob=b"not-text")
)
]
# Execute
result = list(
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=iter(messages), user_id="user1", tenant_id="tenant1"
)
)
# Verify - should yield unchanged if it's not a TextMessage
assert len(result) == 1
assert result[0].type == DatasourceMessage.MessageType.IMAGE
assert isinstance(result[0].message, DatasourceMessage.BlobMessage)

View File

@ -0,0 +1,101 @@
from collections.abc import Generator
from unittest.mock import MagicMock, patch
import pytest
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceProviderType,
WebsiteCrawlMessage,
)
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
class TestWebsiteCrawlDatasourcePlugin:
@pytest.fixture
def mock_entity(self):
entity = MagicMock(spec=DatasourceEntity)
entity.identity = MagicMock()
entity.identity.provider = "test-provider"
entity.identity.name = "test-name"
return entity
@pytest.fixture
def mock_runtime(self):
runtime = MagicMock(spec=DatasourceRuntime)
runtime.credentials = {"api_key": "test-key"}
return runtime
def test_init(self, mock_entity, mock_runtime):
# Arrange
tenant_id = "test-tenant-id"
icon = "test-icon"
plugin_unique_identifier = "test-plugin-id"
# Act
plugin = WebsiteCrawlDatasourcePlugin(
entity=mock_entity,
runtime=mock_runtime,
tenant_id=tenant_id,
icon=icon,
plugin_unique_identifier=plugin_unique_identifier,
)
# Assert
assert plugin.tenant_id == tenant_id
assert plugin.plugin_unique_identifier == plugin_unique_identifier
assert plugin.entity == mock_entity
assert plugin.runtime == mock_runtime
assert plugin.icon == icon
def test_datasource_provider_type(self, mock_entity, mock_runtime):
# Arrange
plugin = WebsiteCrawlDatasourcePlugin(
entity=mock_entity, runtime=mock_runtime, tenant_id="test", icon="test", plugin_unique_identifier="test"
)
# Act & Assert
assert plugin.datasource_provider_type() == DatasourceProviderType.WEBSITE_CRAWL
def test_get_website_crawl(self, mock_entity, mock_runtime):
# Arrange
plugin = WebsiteCrawlDatasourcePlugin(
entity=mock_entity,
runtime=mock_runtime,
tenant_id="test-tenant-id",
icon="test-icon",
plugin_unique_identifier="test-plugin-id",
)
user_id = "test-user-id"
datasource_parameters = {"url": "https://example.com"}
provider_type = "firecrawl"
mock_message = MagicMock(spec=WebsiteCrawlMessage)
# Mock PluginDatasourceManager
with patch("core.datasource.website_crawl.website_crawl_plugin.PluginDatasourceManager") as mock_manager_class:
mock_manager = mock_manager_class.return_value
mock_manager.get_website_crawl.return_value = (msg for msg in [mock_message])
# Act
result = plugin.get_website_crawl(
user_id=user_id, datasource_parameters=datasource_parameters, provider_type=provider_type
)
# Assert
assert isinstance(result, Generator)
messages = list(result)
assert len(messages) == 1
assert messages[0] == mock_message
mock_manager.get_website_crawl.assert_called_once_with(
tenant_id="test-tenant-id",
user_id=user_id,
datasource_provider="test-provider",
datasource_name="test-name",
credentials={"api_key": "test-key"},
datasource_parameters=datasource_parameters,
provider_type=provider_type,
)

View File

@ -0,0 +1,95 @@
from unittest.mock import MagicMock, patch
import pytest
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceProviderEntityWithPlugin,
DatasourceProviderType,
)
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
class TestWebsiteCrawlDatasourcePluginProviderController:
@pytest.fixture
def mock_entity(self):
entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
entity.datasources = []
entity.identity = MagicMock()
entity.identity.icon = "test-icon"
return entity
def test_init(self, mock_entity):
# Arrange
plugin_id = "test-plugin-id"
plugin_unique_identifier = "test-unique-id"
tenant_id = "test-tenant-id"
# Act
controller = WebsiteCrawlDatasourcePluginProviderController(
entity=mock_entity,
plugin_id=plugin_id,
plugin_unique_identifier=plugin_unique_identifier,
tenant_id=tenant_id,
)
# Assert
assert controller.entity == mock_entity
assert controller.plugin_id == plugin_id
assert controller.plugin_unique_identifier == plugin_unique_identifier
assert controller.tenant_id == tenant_id
def test_provider_type(self, mock_entity):
# Arrange
controller = WebsiteCrawlDatasourcePluginProviderController(
entity=mock_entity, plugin_id="test", plugin_unique_identifier="test", tenant_id="test"
)
# Act & Assert
assert controller.provider_type == DatasourceProviderType.WEBSITE_CRAWL
def test_get_datasource_success(self, mock_entity):
# Arrange
datasource_name = "test-datasource"
tenant_id = "test-tenant-id"
plugin_unique_identifier = "test-unique-id"
mock_datasource_entity = MagicMock()
mock_datasource_entity.identity = MagicMock()
mock_datasource_entity.identity.name = datasource_name
mock_entity.datasources = [mock_datasource_entity]
controller = WebsiteCrawlDatasourcePluginProviderController(
entity=mock_entity, plugin_id="test", plugin_unique_identifier=plugin_unique_identifier, tenant_id=tenant_id
)
# Act
with patch(
"core.datasource.website_crawl.website_crawl_provider.WebsiteCrawlDatasourcePlugin"
) as mock_plugin_class:
mock_plugin_instance = mock_plugin_class.return_value
result = controller.get_datasource(datasource_name)
# Assert
assert result == mock_plugin_instance
mock_plugin_class.assert_called_once()
args, kwargs = mock_plugin_class.call_args
assert kwargs["entity"] == mock_datasource_entity
assert isinstance(kwargs["runtime"], DatasourceRuntime)
assert kwargs["runtime"].tenant_id == tenant_id
assert kwargs["tenant_id"] == tenant_id
assert kwargs["icon"] == "test-icon"
assert kwargs["plugin_unique_identifier"] == plugin_unique_identifier
def test_get_datasource_not_found(self, mock_entity):
# Arrange
datasource_name = "non-existent"
mock_entity.datasources = []
controller = WebsiteCrawlDatasourcePluginProviderController(
entity=mock_entity, plugin_id="test", plugin_unique_identifier="test", tenant_id="test"
)
# Act & Assert
with pytest.raises(ValueError, match=f"Datasource with name {datasource_name} not found"):
controller.get_datasource(datasource_name)