mirror of https://github.com/langgenius/dify.git
test: add test for api core datasource (#32414)
Co-authored-by: rajatagarwal-oss <rajat.agarwal@infocusp.com>
This commit is contained in:
parent
86b6868772
commit
1083f5c46a
|
|
@ -59,8 +59,6 @@ class DatasourcePluginProviderController(ABC):
|
|||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
credentials_schema = dict[str, ProviderConfig]()
|
||||
if credentials_schema is None:
|
||||
return
|
||||
|
||||
for credential in self.entity.credentials_schema:
|
||||
credentials_schema[credential.name] = credential
|
||||
|
|
|
|||
|
|
@ -0,0 +1,90 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from configs import dify_config
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceProviderType
|
||||
|
||||
|
||||
class ConcreteDatasourcePlugin(DatasourcePlugin):
|
||||
"""
|
||||
Concrete implementation of DatasourcePlugin for testing purposes.
|
||||
Since DatasourcePlugin is an ABC, we need a concrete class to instantiate it.
|
||||
"""
|
||||
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
|
||||
class TestDatasourcePlugin:
|
||||
def test_init(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
icon = "test-icon.png"
|
||||
|
||||
# Act
|
||||
plugin = ConcreteDatasourcePlugin(entity=entity, runtime=runtime, icon=icon)
|
||||
|
||||
# Assert
|
||||
assert plugin.entity == entity
|
||||
assert plugin.runtime == runtime
|
||||
assert plugin.icon == icon
|
||||
|
||||
def test_datasource_provider_type(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
icon = "test-icon.png"
|
||||
plugin = ConcreteDatasourcePlugin(entity=entity, runtime=runtime, icon=icon)
|
||||
|
||||
# Act
|
||||
provider_type = plugin.datasource_provider_type()
|
||||
# Call the base class method to ensure it's covered
|
||||
base_provider_type = DatasourcePlugin.datasource_provider_type(plugin)
|
||||
|
||||
# Assert
|
||||
assert provider_type == DatasourceProviderType.LOCAL_FILE
|
||||
assert base_provider_type == DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
def test_fork_datasource_runtime(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceEntity)
|
||||
mock_entity_copy = MagicMock(spec=DatasourceEntity)
|
||||
mock_entity.model_copy.return_value = mock_entity_copy
|
||||
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
new_runtime = MagicMock(spec=DatasourceRuntime)
|
||||
icon = "test-icon.png"
|
||||
|
||||
plugin = ConcreteDatasourcePlugin(entity=mock_entity, runtime=runtime, icon=icon)
|
||||
|
||||
# Act
|
||||
new_plugin = plugin.fork_datasource_runtime(new_runtime)
|
||||
|
||||
# Assert
|
||||
assert isinstance(new_plugin, ConcreteDatasourcePlugin)
|
||||
assert new_plugin.entity == mock_entity_copy
|
||||
assert new_plugin.runtime == new_runtime
|
||||
assert new_plugin.icon == icon
|
||||
mock_entity.model_copy.assert_called_once()
|
||||
|
||||
def test_get_icon_url(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
icon = "test-icon.png"
|
||||
tenant_id = "test-tenant-id"
|
||||
|
||||
plugin = ConcreteDatasourcePlugin(entity=entity, runtime=runtime, icon=icon)
|
||||
|
||||
# Mocking dify_config.CONSOLE_API_URL
|
||||
with patch.object(dify_config, "CONSOLE_API_URL", "https://api.dify.ai"):
|
||||
# Act
|
||||
icon_url = plugin.get_icon_url(tenant_id)
|
||||
|
||||
# Assert
|
||||
expected_url = (
|
||||
f"https://api.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={icon}"
|
||||
)
|
||||
assert icon_url == expected_url
|
||||
|
|
@ -0,0 +1,265 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceProviderEntityWithPlugin,
|
||||
DatasourceProviderType,
|
||||
)
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class ConcreteDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||
"""
|
||||
Concrete implementation of DatasourcePluginProviderController for testing purposes.
|
||||
"""
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> DatasourcePlugin:
|
||||
return MagicMock(spec=DatasourcePlugin)
|
||||
|
||||
|
||||
class TestDatasourcePluginProviderController:
|
||||
def test_init(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
tenant_id = "test-tenant-id"
|
||||
|
||||
# Act
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id=tenant_id)
|
||||
|
||||
# Assert
|
||||
assert controller.entity == mock_entity
|
||||
assert controller.tenant_id == tenant_id
|
||||
|
||||
def test_need_credentials(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
tenant_id = "test-tenant-id"
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id=tenant_id)
|
||||
|
||||
# Case 1: credentials_schema is None
|
||||
mock_entity.credentials_schema = None
|
||||
assert controller.need_credentials is False
|
||||
|
||||
# Case 2: credentials_schema is empty
|
||||
mock_entity.credentials_schema = []
|
||||
assert controller.need_credentials is False
|
||||
|
||||
# Case 3: credentials_schema has items
|
||||
mock_entity.credentials_schema = [MagicMock()]
|
||||
assert controller.need_credentials is True
|
||||
|
||||
@patch("core.datasource.__base.datasource_provider.PluginToolManager")
|
||||
def test_validate_credentials(self, mock_manager_class):
|
||||
# Arrange
|
||||
mock_manager = mock_manager_class.return_value
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.identity = MagicMock()
|
||||
mock_entity.identity.name = "test-provider"
|
||||
tenant_id = "test-tenant-id"
|
||||
user_id = "test-user-id"
|
||||
credentials = {"api_key": "secret"}
|
||||
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id=tenant_id)
|
||||
|
||||
# Act: Successful validation
|
||||
mock_manager.validate_datasource_credentials.return_value = True
|
||||
controller._validate_credentials(user_id, credentials)
|
||||
|
||||
mock_manager.validate_datasource_credentials.assert_called_once_with(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider="test-provider",
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
# Act: Failed validation
|
||||
mock_manager.validate_datasource_credentials.return_value = False
|
||||
with pytest.raises(ToolProviderCredentialValidationError, match="Invalid credentials"):
|
||||
controller._validate_credentials(user_id, credentials)
|
||||
|
||||
def test_provider_type(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
|
||||
# Act & Assert
|
||||
assert controller.provider_type == DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
def test_validate_credentials_format_empty_schema(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.credentials_schema = []
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
credentials = {}
|
||||
|
||||
# Act & Assert (Should not raise anything)
|
||||
controller.validate_credentials_format(credentials)
|
||||
|
||||
def test_validate_credentials_format_unknown_credential(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.identity = MagicMock()
|
||||
mock_entity.identity.name = "test-provider"
|
||||
mock_entity.credentials_schema = []
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
credentials = {"unknown": "value"}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(
|
||||
ToolProviderCredentialValidationError, match="credential unknown not found in provider test-provider"
|
||||
):
|
||||
controller.validate_credentials_format(credentials)
|
||||
|
||||
def test_validate_credentials_format_required_missing(self):
|
||||
# Arrange
|
||||
mock_config = MagicMock(spec=ProviderConfig)
|
||||
mock_config.name = "api_key"
|
||||
mock_config.required = True
|
||||
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.credentials_schema = [mock_config]
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ToolProviderCredentialValidationError, match="credential api_key is required"):
|
||||
controller.validate_credentials_format({})
|
||||
|
||||
def test_validate_credentials_format_not_required_null(self):
|
||||
# Arrange
|
||||
mock_config = MagicMock(spec=ProviderConfig)
|
||||
mock_config.name = "optional"
|
||||
mock_config.required = False
|
||||
mock_config.default = None
|
||||
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.credentials_schema = [mock_config]
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
|
||||
# Act & Assert
|
||||
credentials = {"optional": None}
|
||||
controller.validate_credentials_format(credentials)
|
||||
assert credentials["optional"] is None
|
||||
|
||||
def test_validate_credentials_format_type_mismatch_text(self):
|
||||
# Arrange
|
||||
mock_config = MagicMock(spec=ProviderConfig)
|
||||
mock_config.name = "text_field"
|
||||
mock_config.required = True
|
||||
mock_config.type = ProviderConfig.Type.TEXT_INPUT
|
||||
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.credentials_schema = [mock_config]
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ToolProviderCredentialValidationError, match="credential text_field should be string"):
|
||||
controller.validate_credentials_format({"text_field": 123})
|
||||
|
||||
def test_validate_credentials_format_select_validation(self):
|
||||
# Arrange
|
||||
mock_option = MagicMock()
|
||||
mock_option.value = "opt1"
|
||||
|
||||
mock_config = MagicMock(spec=ProviderConfig)
|
||||
mock_config.name = "select_field"
|
||||
mock_config.required = True
|
||||
mock_config.type = ProviderConfig.Type.SELECT
|
||||
mock_config.options = [mock_option]
|
||||
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.credentials_schema = [mock_config]
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
|
||||
# Case 1: Value not string
|
||||
with pytest.raises(ToolProviderCredentialValidationError, match="credential select_field should be string"):
|
||||
controller.validate_credentials_format({"select_field": 123})
|
||||
|
||||
# Case 2: Options not list
|
||||
mock_config.options = "invalid"
|
||||
with pytest.raises(
|
||||
ToolProviderCredentialValidationError, match="credential select_field options should be list"
|
||||
):
|
||||
controller.validate_credentials_format({"select_field": "opt1"})
|
||||
|
||||
# Case 3: Value not in options
|
||||
mock_config.options = [mock_option]
|
||||
with pytest.raises(ToolProviderCredentialValidationError, match="credential select_field should be one of"):
|
||||
controller.validate_credentials_format({"select_field": "invalid_opt"})
|
||||
|
||||
def test_get_datasource_base(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
|
||||
# Act
|
||||
result = DatasourcePluginProviderController.get_datasource(controller, "test")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_validate_credentials_format_hits_pop(self):
|
||||
# Arrange
|
||||
mock_config = MagicMock(spec=ProviderConfig)
|
||||
mock_config.name = "valid_field"
|
||||
mock_config.required = True
|
||||
mock_config.type = ProviderConfig.Type.TEXT_INPUT
|
||||
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.credentials_schema = [mock_config]
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
|
||||
# Act
|
||||
credentials = {"valid_field": "valid_value"}
|
||||
controller.validate_credentials_format(credentials)
|
||||
|
||||
# Assert
|
||||
assert "valid_field" in credentials
|
||||
assert credentials["valid_field"] == "valid_value"
|
||||
|
||||
def test_validate_credentials_format_hits_continue(self):
|
||||
# Arrange
|
||||
mock_config = MagicMock(spec=ProviderConfig)
|
||||
mock_config.name = "optional_field"
|
||||
mock_config.required = False
|
||||
mock_config.default = None
|
||||
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.credentials_schema = [mock_config]
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
|
||||
# Act
|
||||
credentials = {"optional_field": None}
|
||||
controller.validate_credentials_format(credentials)
|
||||
|
||||
# Assert
|
||||
assert credentials["optional_field"] is None
|
||||
|
||||
def test_validate_credentials_format_default_values(self):
|
||||
# Arrange
|
||||
mock_config_text = MagicMock(spec=ProviderConfig)
|
||||
mock_config_text.name = "text_def"
|
||||
mock_config_text.required = False
|
||||
mock_config_text.type = ProviderConfig.Type.TEXT_INPUT
|
||||
mock_config_text.default = 123 # Int default, should be converted to str
|
||||
|
||||
mock_config_other = MagicMock(spec=ProviderConfig)
|
||||
mock_config_other.name = "other_def"
|
||||
mock_config_other.required = False
|
||||
mock_config_other.type = "OTHER"
|
||||
mock_config_other.default = "fallback"
|
||||
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.credentials_schema = [mock_config_text, mock_config_other]
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
|
||||
# Act
|
||||
credentials = {}
|
||||
controller.validate_credentials_format(credentials)
|
||||
|
||||
# Assert
|
||||
assert credentials["text_def"] == "123"
|
||||
assert credentials["other_def"] == "fallback"
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime, FakeDatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
|
||||
|
||||
|
||||
class TestDatasourceRuntime:
|
||||
def test_init(self):
|
||||
runtime = DatasourceRuntime(
|
||||
tenant_id="test-tenant",
|
||||
datasource_id="test-ds",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE,
|
||||
credentials={"key": "val"},
|
||||
runtime_parameters={"p": "v"},
|
||||
)
|
||||
assert runtime.tenant_id == "test-tenant"
|
||||
assert runtime.datasource_id == "test-ds"
|
||||
assert runtime.credentials["key"] == "val"
|
||||
|
||||
def test_fake_datasource_runtime(self):
|
||||
# This covers the FakeDatasourceRuntime class and its __init__
|
||||
runtime = FakeDatasourceRuntime()
|
||||
assert runtime.tenant_id == "fake_tenant_id"
|
||||
assert runtime.datasource_id == "fake_datasource_id"
|
||||
assert runtime.invoke_from == InvokeFrom.DEBUGGER
|
||||
assert runtime.datasource_invoke_from == DatasourceInvokeFrom.RAG_PIPELINE
|
||||
|
|
@ -0,0 +1,150 @@
|
|||
from core.datasource.entities.api_entities import DatasourceApiEntity, DatasourceProviderApiEntity
|
||||
from core.datasource.entities.datasource_entities import DatasourceParameter
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
def test_datasource_api_entity():
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
description = I18nObject(en_US="desc", zh_Hans="描述")
|
||||
|
||||
entity = DatasourceApiEntity(
|
||||
author="author", name="name", label=label, description=description, labels=["l1", "l2"]
|
||||
)
|
||||
|
||||
assert entity.author == "author"
|
||||
assert entity.name == "name"
|
||||
assert entity.label == label
|
||||
assert entity.description == description
|
||||
assert entity.labels == ["l1", "l2"]
|
||||
assert entity.parameters is None
|
||||
assert entity.output_schema is None
|
||||
|
||||
|
||||
def test_datasource_provider_api_entity_defaults():
|
||||
description = I18nObject(en_US="desc", zh_Hans="描述")
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
|
||||
entity = DatasourceProviderApiEntity(
|
||||
id="id", author="author", name="name", description=description, icon="icon", label=label, type="type"
|
||||
)
|
||||
|
||||
assert entity.id == "id"
|
||||
assert entity.datasources == []
|
||||
assert entity.is_team_authorization is False
|
||||
assert entity.allow_delete is True
|
||||
assert entity.plugin_id == ""
|
||||
assert entity.plugin_unique_identifier == ""
|
||||
assert entity.labels == []
|
||||
|
||||
|
||||
def test_datasource_provider_api_entity_convert_none_to_empty_list():
|
||||
description = I18nObject(en_US="desc", zh_Hans="描述")
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
|
||||
# Implicitly testing the field_validator "convert_none_to_empty_list"
|
||||
entity = DatasourceProviderApiEntity(
|
||||
id="id",
|
||||
author="author",
|
||||
name="name",
|
||||
description=description,
|
||||
icon="icon",
|
||||
label=label,
|
||||
type="type",
|
||||
datasources=None, # type: ignore
|
||||
)
|
||||
|
||||
assert entity.datasources == []
|
||||
|
||||
|
||||
def test_datasource_provider_api_entity_to_dict():
|
||||
description = I18nObject(en_US="desc", zh_Hans="描述")
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
|
||||
# Create a parameter that should be converted
|
||||
param = DatasourceParameter.get_simple_instance(
|
||||
name="test_param", typ=DatasourceParameter.DatasourceParameterType.SYSTEM_FILES, required=True
|
||||
)
|
||||
|
||||
ds_entity = DatasourceApiEntity(
|
||||
author="author", name="ds_name", label=label, description=description, parameters=[param]
|
||||
)
|
||||
|
||||
provider_entity = DatasourceProviderApiEntity(
|
||||
id="id",
|
||||
author="author",
|
||||
name="name",
|
||||
description=description,
|
||||
icon="icon",
|
||||
label=label,
|
||||
type="type",
|
||||
masked_credentials={"key": "masked"},
|
||||
datasources=[ds_entity],
|
||||
labels=["l1"],
|
||||
)
|
||||
|
||||
result = provider_entity.to_dict()
|
||||
|
||||
assert result["id"] == "id"
|
||||
assert result["author"] == "author"
|
||||
assert result["name"] == "name"
|
||||
assert result["description"] == description.to_dict()
|
||||
assert result["icon"] == "icon"
|
||||
assert result["label"] == label.to_dict()
|
||||
assert result["type"] == "type"
|
||||
assert result["team_credentials"] == {"key": "masked"}
|
||||
assert result["is_team_authorization"] is False
|
||||
assert result["allow_delete"] is True
|
||||
assert result["labels"] == ["l1"]
|
||||
|
||||
# Check if parameter type was converted from SYSTEM_FILES to files
|
||||
assert result["datasources"][0]["parameters"][0]["type"] == "files"
|
||||
|
||||
|
||||
def test_datasource_provider_api_entity_to_dict_no_params():
|
||||
description = I18nObject(en_US="desc", zh_Hans="描述")
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
|
||||
ds_entity = DatasourceApiEntity(
|
||||
author="author", name="ds_name", label=label, description=description, parameters=None
|
||||
)
|
||||
|
||||
provider_entity = DatasourceProviderApiEntity(
|
||||
id="id",
|
||||
author="author",
|
||||
name="name",
|
||||
description=description,
|
||||
icon="icon",
|
||||
label=label,
|
||||
type="type",
|
||||
datasources=[ds_entity],
|
||||
)
|
||||
|
||||
result = provider_entity.to_dict()
|
||||
assert result["datasources"][0]["parameters"] is None
|
||||
|
||||
|
||||
def test_datasource_provider_api_entity_to_dict_other_param_type():
|
||||
description = I18nObject(en_US="desc", zh_Hans="描述")
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
|
||||
param = DatasourceParameter.get_simple_instance(
|
||||
name="test_param", typ=DatasourceParameter.DatasourceParameterType.STRING, required=True
|
||||
)
|
||||
|
||||
ds_entity = DatasourceApiEntity(
|
||||
author="author", name="ds_name", label=label, description=description, parameters=[param]
|
||||
)
|
||||
|
||||
provider_entity = DatasourceProviderApiEntity(
|
||||
id="id",
|
||||
author="author",
|
||||
name="name",
|
||||
description=description,
|
||||
icon="icon",
|
||||
label=label,
|
||||
type="type",
|
||||
datasources=[ds_entity],
|
||||
)
|
||||
|
||||
result = provider_entity.to_dict()
|
||||
assert result["datasources"][0]["parameters"][0]["type"] == "string"
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
from core.datasource.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
def test_i18n_object_fallback():
|
||||
# Only en_US provided
|
||||
obj = I18nObject(en_US="Hello")
|
||||
assert obj.en_US == "Hello"
|
||||
assert obj.zh_Hans == "Hello"
|
||||
assert obj.pt_BR == "Hello"
|
||||
assert obj.ja_JP == "Hello"
|
||||
|
||||
# Some fields provided
|
||||
obj = I18nObject(en_US="Hello", zh_Hans="你好")
|
||||
assert obj.en_US == "Hello"
|
||||
assert obj.zh_Hans == "你好"
|
||||
assert obj.pt_BR == "Hello"
|
||||
assert obj.ja_JP == "Hello"
|
||||
|
||||
|
||||
def test_i18n_object_all_fields():
|
||||
obj = I18nObject(en_US="Hello", zh_Hans="你好", pt_BR="Olá", ja_JP="こんにちは")
|
||||
assert obj.en_US == "Hello"
|
||||
assert obj.zh_Hans == "你好"
|
||||
assert obj.pt_BR == "Olá"
|
||||
assert obj.ja_JP == "こんにちは"
|
||||
|
||||
|
||||
def test_i18n_object_to_dict():
|
||||
obj = I18nObject(en_US="Hello", zh_Hans="你好", pt_BR="Olá", ja_JP="こんにちは")
|
||||
expected_dict = {"en_US": "Hello", "zh_Hans": "你好", "pt_BR": "Olá", "ja_JP": "こんにちは"}
|
||||
assert obj.to_dict() == expected_dict
|
||||
|
|
@ -0,0 +1,275 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceIdentity,
|
||||
DatasourceInvokeMeta,
|
||||
DatasourceLabel,
|
||||
DatasourceMessage,
|
||||
DatasourceParameter,
|
||||
DatasourceProviderEntity,
|
||||
DatasourceProviderEntityWithPlugin,
|
||||
DatasourceProviderIdentity,
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
GetOnlineDocumentPageContentResponse,
|
||||
GetWebsiteCrawlRequest,
|
||||
OnlineDocumentInfo,
|
||||
OnlineDocumentPage,
|
||||
OnlineDocumentPageContent,
|
||||
OnlineDocumentPagesMessage,
|
||||
OnlineDriveBrowseFilesRequest,
|
||||
OnlineDriveBrowseFilesResponse,
|
||||
OnlineDriveDownloadFileRequest,
|
||||
OnlineDriveFile,
|
||||
OnlineDriveFileBucket,
|
||||
WebsiteCrawlMessage,
|
||||
WebSiteInfo,
|
||||
WebSiteInfoDetail,
|
||||
)
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolLabelEnum
|
||||
|
||||
|
||||
def test_datasource_provider_type():
|
||||
assert DatasourceProviderType.value_of("online_document") == DatasourceProviderType.ONLINE_DOCUMENT
|
||||
assert DatasourceProviderType.value_of("local_file") == DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
with pytest.raises(ValueError, match="invalid mode value invalid"):
|
||||
DatasourceProviderType.value_of("invalid")
|
||||
|
||||
|
||||
def test_datasource_parameter_type():
|
||||
param_type = DatasourceParameter.DatasourceParameterType.STRING
|
||||
assert param_type.as_normal_type() == "string"
|
||||
assert param_type.cast_value("test") == "test"
|
||||
|
||||
param_type = DatasourceParameter.DatasourceParameterType.NUMBER
|
||||
assert param_type.cast_value("123") == 123
|
||||
|
||||
|
||||
def test_datasource_parameter():
|
||||
param = DatasourceParameter.get_simple_instance(
|
||||
name="test_param",
|
||||
typ=DatasourceParameter.DatasourceParameterType.STRING,
|
||||
required=True,
|
||||
options=["opt1", "opt2"],
|
||||
)
|
||||
assert param.name == "test_param"
|
||||
assert param.type == DatasourceParameter.DatasourceParameterType.STRING
|
||||
assert param.required is True
|
||||
assert len(param.options) == 2
|
||||
assert param.options[0].value == "opt1"
|
||||
|
||||
param_no_options = DatasourceParameter.get_simple_instance(
|
||||
name="test_param_2", typ=DatasourceParameter.DatasourceParameterType.NUMBER, required=False
|
||||
)
|
||||
assert param_no_options.options == []
|
||||
|
||||
# Test init_frontend_parameter
|
||||
# For STRING, it should just return the value as is (or cast to str)
|
||||
frontend_param = param.init_frontend_parameter("val")
|
||||
assert frontend_param == "val"
|
||||
|
||||
# Test parameter type methods
|
||||
assert DatasourceParameter.DatasourceParameterType.STRING.as_normal_type() == "string"
|
||||
assert DatasourceParameter.DatasourceParameterType.NUMBER.as_normal_type() == "number"
|
||||
assert DatasourceParameter.DatasourceParameterType.SECRET_INPUT.as_normal_type() == "string"
|
||||
|
||||
assert DatasourceParameter.DatasourceParameterType.NUMBER.cast_value("10.5") == 10.5
|
||||
assert DatasourceParameter.DatasourceParameterType.BOOLEAN.cast_value("true") is True
|
||||
assert DatasourceParameter.DatasourceParameterType.FILES.cast_value(["f1", "f2"]) == ["f1", "f2"]
|
||||
|
||||
|
||||
def test_datasource_identity():
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
identity = DatasourceIdentity(author="author", name="name", label=label, provider="provider", icon="icon")
|
||||
assert identity.author == "author"
|
||||
assert identity.name == "name"
|
||||
assert identity.label == label
|
||||
assert identity.provider == "provider"
|
||||
assert identity.icon == "icon"
|
||||
|
||||
|
||||
def test_datasource_entity():
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
identity = DatasourceIdentity(author="author", name="name", label=label, provider="provider")
|
||||
description = I18nObject(en_US="desc", zh_Hans="描述")
|
||||
|
||||
entity = DatasourceEntity(
|
||||
identity=identity,
|
||||
description=description,
|
||||
parameters=None, # Should be handled by validator
|
||||
)
|
||||
assert entity.parameters == []
|
||||
|
||||
param = DatasourceParameter.get_simple_instance("p1", DatasourceParameter.DatasourceParameterType.STRING, True)
|
||||
entity_with_params = DatasourceEntity(identity=identity, description=description, parameters=[param])
|
||||
assert entity_with_params.parameters == [param]
|
||||
|
||||
|
||||
def test_datasource_provider_identity():
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
description = I18nObject(en_US="desc", zh_Hans="描述")
|
||||
identity = DatasourceProviderIdentity(
|
||||
author="author", name="name", description=description, icon="icon.png", label=label, tags=[ToolLabelEnum.SEARCH]
|
||||
)
|
||||
|
||||
assert identity.author == "author"
|
||||
assert identity.name == "name"
|
||||
assert identity.description == description
|
||||
assert identity.icon == "icon.png"
|
||||
assert identity.label == label
|
||||
assert identity.tags == [ToolLabelEnum.SEARCH]
|
||||
|
||||
# Test generate_datasource_icon_url
|
||||
with patch("core.datasource.entities.datasource_entities.dify_config") as mock_config:
|
||||
mock_config.CONSOLE_API_URL = "http://api.example.com"
|
||||
url = identity.generate_datasource_icon_url("tenant123")
|
||||
assert "http://api.example.com/console/api/workspaces/current/plugin/icon" in url
|
||||
assert "tenant_id=tenant123" in url
|
||||
assert "filename=icon.png" in url
|
||||
|
||||
# Test hardcoded icon
|
||||
identity.icon = "https://assets.dify.ai/images/File%20Upload.svg"
|
||||
assert identity.generate_datasource_icon_url("tenant123") == identity.icon
|
||||
|
||||
# Test with empty CONSOLE_API_URL
|
||||
identity.icon = "test.png"
|
||||
with patch("core.datasource.entities.datasource_entities.dify_config") as mock_config:
|
||||
mock_config.CONSOLE_API_URL = None
|
||||
url = identity.generate_datasource_icon_url("tenant123")
|
||||
assert url.startswith("/console/api/workspaces/current/plugin/icon")
|
||||
|
||||
|
||||
def test_datasource_provider_entity():
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
description = I18nObject(en_US="desc", zh_Hans="描述")
|
||||
identity = DatasourceProviderIdentity(
|
||||
author="author", name="name", description=description, icon="icon", label=label
|
||||
)
|
||||
|
||||
entity = DatasourceProviderEntity(
|
||||
identity=identity,
|
||||
provider_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
credentials_schema=[],
|
||||
oauth_schema=None,
|
||||
)
|
||||
assert entity.identity == identity
|
||||
assert entity.provider_type == DatasourceProviderType.ONLINE_DOCUMENT
|
||||
assert entity.credentials_schema == []
|
||||
|
||||
|
||||
def test_datasource_provider_entity_with_plugin():
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
description = I18nObject(en_US="desc", zh_Hans="描述")
|
||||
identity = DatasourceProviderIdentity(
|
||||
author="author", name="name", description=description, icon="icon", label=label
|
||||
)
|
||||
|
||||
entity = DatasourceProviderEntityWithPlugin(
|
||||
identity=identity, provider_type=DatasourceProviderType.ONLINE_DOCUMENT, datasources=[]
|
||||
)
|
||||
assert entity.datasources == []
|
||||
|
||||
|
||||
def test_datasource_invoke_meta():
|
||||
meta = DatasourceInvokeMeta(time_cost=1.5, error="some error", tool_config={"k": "v"})
|
||||
assert meta.time_cost == 1.5
|
||||
assert meta.error == "some error"
|
||||
assert meta.tool_config == {"k": "v"}
|
||||
|
||||
d = meta.to_dict()
|
||||
assert d == {"time_cost": 1.5, "error": "some error", "tool_config": {"k": "v"}}
|
||||
|
||||
empty_meta = DatasourceInvokeMeta.empty()
|
||||
assert empty_meta.time_cost == 0.0
|
||||
assert empty_meta.error is None
|
||||
assert empty_meta.tool_config == {}
|
||||
|
||||
error_meta = DatasourceInvokeMeta.error_instance("fatal error")
|
||||
assert error_meta.time_cost == 0.0
|
||||
assert error_meta.error == "fatal error"
|
||||
assert error_meta.tool_config == {}
|
||||
|
||||
|
||||
def test_datasource_label():
|
||||
label_obj = I18nObject(en_US="label", zh_Hans="标签")
|
||||
ds_label = DatasourceLabel(name="name", label=label_obj, icon="icon")
|
||||
assert ds_label.name == "name"
|
||||
assert ds_label.label == label_obj
|
||||
assert ds_label.icon == "icon"
|
||||
|
||||
|
||||
def test_online_document_models():
|
||||
page = OnlineDocumentPage(
|
||||
page_id="p1",
|
||||
page_name="name",
|
||||
page_icon={"type": "emoji"},
|
||||
type="page",
|
||||
last_edited_time="2023-01-01",
|
||||
parent_id=None,
|
||||
)
|
||||
assert page.page_id == "p1"
|
||||
|
||||
info = OnlineDocumentInfo(workspace_id="w1", workspace_name="name", workspace_icon="icon", total=1, pages=[page])
|
||||
assert info.total == 1
|
||||
|
||||
msg = OnlineDocumentPagesMessage(result=[info])
|
||||
assert msg.result == [info]
|
||||
|
||||
req = GetOnlineDocumentPageContentRequest(workspace_id="w1", page_id="p1", type="page")
|
||||
assert req.workspace_id == "w1"
|
||||
|
||||
content = OnlineDocumentPageContent(workspace_id="w1", page_id="p1", content="hello")
|
||||
assert content.content == "hello"
|
||||
|
||||
resp = GetOnlineDocumentPageContentResponse(result=content)
|
||||
assert resp.result == content
|
||||
|
||||
|
||||
def test_website_crawl_models():
|
||||
req = GetWebsiteCrawlRequest(crawl_parameters={"url": "http://test.com"})
|
||||
assert req.crawl_parameters == {"url": "http://test.com"}
|
||||
|
||||
detail = WebSiteInfoDetail(source_url="http://test.com", content="content", title="title", description="desc")
|
||||
assert detail.title == "title"
|
||||
|
||||
info = WebSiteInfo(status="completed", web_info_list=[detail], total=1, completed=1)
|
||||
assert info.status == "completed"
|
||||
|
||||
msg = WebsiteCrawlMessage(result=info)
|
||||
assert msg.result == info
|
||||
|
||||
# Test default values
|
||||
msg_default = WebsiteCrawlMessage()
|
||||
assert msg_default.result.status == ""
|
||||
assert msg_default.result.web_info_list == []
|
||||
|
||||
|
||||
def test_online_drive_models():
|
||||
file = OnlineDriveFile(id="f1", name="file.txt", size=100, type="file")
|
||||
assert file.name == "file.txt"
|
||||
|
||||
bucket = OnlineDriveFileBucket(bucket="b1", files=[file], is_truncated=False, next_page_parameters=None)
|
||||
assert bucket.bucket == "b1"
|
||||
|
||||
req = OnlineDriveBrowseFilesRequest(bucket="b1", prefix="folder1", max_keys=10, next_page_parameters=None)
|
||||
assert req.prefix == "folder1"
|
||||
|
||||
resp = OnlineDriveBrowseFilesResponse(result=[bucket])
|
||||
assert resp.result == [bucket]
|
||||
|
||||
dl_req = OnlineDriveDownloadFileRequest(id="f1", bucket="b1")
|
||||
assert dl_req.id == "f1"
|
||||
|
||||
|
||||
def test_datasource_message():
|
||||
# Use proper dict for message to avoid Pydantic Union validation ambiguity/crashes
|
||||
msg = DatasourceMessage(type="text", message={"text": "hello"})
|
||||
assert msg.message.text == "hello"
|
||||
|
||||
msg_json = DatasourceMessage(type="json", message={"json_object": {"k": "v"}})
|
||||
assert msg_json.message.json_object == {"k": "v"}
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceProviderType,
|
||||
)
|
||||
from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin
|
||||
|
||||
|
||||
class TestLocalFileDatasourcePlugin:
|
||||
def test_init(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceEntity)
|
||||
mock_runtime = MagicMock(spec=DatasourceRuntime)
|
||||
tenant_id = "test-tenant-id"
|
||||
icon = "test-icon"
|
||||
plugin_unique_identifier = "test-plugin-id"
|
||||
|
||||
# Act
|
||||
plugin = LocalFileDatasourcePlugin(
|
||||
entity=mock_entity,
|
||||
runtime=mock_runtime,
|
||||
tenant_id=tenant_id,
|
||||
icon=icon,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert plugin.tenant_id == tenant_id
|
||||
assert plugin.plugin_unique_identifier == plugin_unique_identifier
|
||||
assert plugin.entity == mock_entity
|
||||
assert plugin.runtime == mock_runtime
|
||||
assert plugin.icon == icon
|
||||
|
||||
def test_datasource_provider_type(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceEntity)
|
||||
mock_runtime = MagicMock(spec=DatasourceRuntime)
|
||||
plugin = LocalFileDatasourcePlugin(
|
||||
entity=mock_entity, runtime=mock_runtime, tenant_id="test", icon="test", plugin_unique_identifier="test"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
assert plugin.datasource_provider_type() == DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
def test_get_icon_url(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceEntity)
|
||||
mock_runtime = MagicMock(spec=DatasourceRuntime)
|
||||
icon = "test-icon"
|
||||
plugin = LocalFileDatasourcePlugin(
|
||||
entity=mock_entity, runtime=mock_runtime, tenant_id="test", icon=icon, plugin_unique_identifier="test"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
assert plugin.get_icon_url("any-tenant-id") == icon
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceProviderEntityWithPlugin,
|
||||
DatasourceProviderType,
|
||||
)
|
||||
from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin
|
||||
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
|
||||
|
||||
|
||||
class TestLocalFileDatasourcePluginProviderController:
|
||||
def test_init(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
plugin_id = "test_plugin_id"
|
||||
plugin_unique_identifier = "test_plugin_unique_identifier"
|
||||
tenant_id = "test_tenant_id"
|
||||
|
||||
# Act
|
||||
controller = LocalFileDatasourcePluginProviderController(
|
||||
entity=mock_entity,
|
||||
plugin_id=plugin_id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert controller.entity == mock_entity
|
||||
assert controller.plugin_id == plugin_id
|
||||
assert controller.plugin_unique_identifier == plugin_unique_identifier
|
||||
assert controller.tenant_id == tenant_id
|
||||
|
||||
def test_provider_type(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
controller = LocalFileDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
assert controller.provider_type == DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
def test_validate_credentials(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
controller = LocalFileDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
# Should not raise any exception
|
||||
controller._validate_credentials("user_id", {"key": "value"})
|
||||
|
||||
def test_get_datasource_success(self):
|
||||
# Arrange
|
||||
mock_datasource_entity = MagicMock()
|
||||
mock_datasource_entity.identity.name = "test_datasource"
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.datasources = [mock_datasource_entity]
|
||||
mock_entity.identity.icon = "test_icon"
|
||||
|
||||
plugin_unique_identifier = "test_plugin_unique_identifier"
|
||||
tenant_id = "test_tenant_id"
|
||||
|
||||
controller = LocalFileDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="id", plugin_unique_identifier=plugin_unique_identifier, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# Act
|
||||
datasource = controller.get_datasource("test_datasource")
|
||||
|
||||
# Assert
|
||||
assert isinstance(datasource, LocalFileDatasourcePlugin)
|
||||
assert datasource.entity == mock_datasource_entity
|
||||
assert datasource.tenant_id == tenant_id
|
||||
assert datasource.icon == "test_icon"
|
||||
assert datasource.plugin_unique_identifier == plugin_unique_identifier
|
||||
|
||||
def test_get_datasource_not_found(self):
|
||||
# Arrange
|
||||
mock_datasource_entity = MagicMock()
|
||||
mock_datasource_entity.identity.name = "other_datasource"
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.datasources = [mock_datasource_entity]
|
||||
|
||||
controller = LocalFileDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Datasource with name test_datasource not found"):
|
||||
controller.get_datasource("test_datasource")
|
||||
|
|
@ -0,0 +1,151 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceIdentity,
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
)
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
|
||||
|
||||
class TestOnlineDocumentDatasourcePlugin:
|
||||
def test_init(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
tenant_id = "test_tenant"
|
||||
icon = "test_icon"
|
||||
plugin_unique_identifier = "test_plugin_id"
|
||||
|
||||
# Act
|
||||
plugin = OnlineDocumentDatasourcePlugin(
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
tenant_id=tenant_id,
|
||||
icon=icon,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert plugin.entity == entity
|
||||
assert plugin.runtime == runtime
|
||||
assert plugin.tenant_id == tenant_id
|
||||
assert plugin.icon == icon
|
||||
assert plugin.plugin_unique_identifier == plugin_unique_identifier
|
||||
|
||||
def test_get_online_document_pages(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
identity = MagicMock(spec=DatasourceIdentity)
|
||||
entity.identity = identity
|
||||
identity.provider = "test_provider"
|
||||
identity.name = "test_name"
|
||||
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
runtime.credentials = {"api_key": "test_key"}
|
||||
|
||||
tenant_id = "test_tenant"
|
||||
icon = "test_icon"
|
||||
plugin_unique_identifier = "test_plugin_id"
|
||||
|
||||
plugin = OnlineDocumentDatasourcePlugin(
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
tenant_id=tenant_id,
|
||||
icon=icon,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
)
|
||||
|
||||
user_id = "test_user"
|
||||
datasource_parameters = {"param": "value"}
|
||||
provider_type = "test_type"
|
||||
|
||||
mock_generator = MagicMock()
|
||||
|
||||
# Patch PluginDatasourceManager to isolate plugin behavior from external dependencies
|
||||
with patch("core.datasource.online_document.online_document_plugin.PluginDatasourceManager") as MockManager:
|
||||
mock_manager_instance = MockManager.return_value
|
||||
mock_manager_instance.get_online_document_pages.return_value = mock_generator
|
||||
|
||||
# Act
|
||||
result = plugin.get_online_document_pages(
|
||||
user_id=user_id, datasource_parameters=datasource_parameters, provider_type=provider_type
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == mock_generator
|
||||
mock_manager_instance.get_online_document_pages.assert_called_once_with(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider="test_provider",
|
||||
datasource_name="test_name",
|
||||
credentials=runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def test_get_online_document_page_content(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
identity = MagicMock(spec=DatasourceIdentity)
|
||||
entity.identity = identity
|
||||
identity.provider = "test_provider"
|
||||
identity.name = "test_name"
|
||||
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
runtime.credentials = {"api_key": "test_key"}
|
||||
|
||||
tenant_id = "test_tenant"
|
||||
icon = "test_icon"
|
||||
plugin_unique_identifier = "test_plugin_id"
|
||||
|
||||
plugin = OnlineDocumentDatasourcePlugin(
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
tenant_id=tenant_id,
|
||||
icon=icon,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
)
|
||||
|
||||
user_id = "test_user"
|
||||
datasource_parameters = MagicMock(spec=GetOnlineDocumentPageContentRequest)
|
||||
provider_type = "test_type"
|
||||
|
||||
mock_generator = MagicMock()
|
||||
|
||||
with patch("core.datasource.online_document.online_document_plugin.PluginDatasourceManager") as MockManager:
|
||||
mock_manager_instance = MockManager.return_value
|
||||
mock_manager_instance.get_online_document_page_content.return_value = mock_generator
|
||||
|
||||
# Act
|
||||
result = plugin.get_online_document_page_content(
|
||||
user_id=user_id, datasource_parameters=datasource_parameters, provider_type=provider_type
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == mock_generator
|
||||
mock_manager_instance.get_online_document_page_content.assert_called_once_with(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider="test_provider",
|
||||
datasource_name="test_name",
|
||||
credentials=runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def test_datasource_provider_type(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
plugin = OnlineDocumentDatasourcePlugin(
|
||||
entity=entity, runtime=runtime, tenant_id="test", icon="test", plugin_unique_identifier="test"
|
||||
)
|
||||
|
||||
# Act
|
||||
result = plugin.datasource_provider_type()
|
||||
|
||||
# Assert
|
||||
assert result == DatasourceProviderType.ONLINE_DOCUMENT
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceProviderEntityWithPlugin,
|
||||
DatasourceProviderType,
|
||||
)
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
|
||||
|
||||
|
||||
class TestOnlineDocumentDatasourcePluginProviderController:
|
||||
def test_init(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
plugin_id = "test_plugin_id"
|
||||
plugin_unique_identifier = "test_plugin_uid"
|
||||
tenant_id = "test_tenant_id"
|
||||
|
||||
# Act
|
||||
controller = OnlineDocumentDatasourcePluginProviderController(
|
||||
entity=mock_entity,
|
||||
plugin_id=plugin_id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert controller.entity == mock_entity
|
||||
assert controller.plugin_id == plugin_id
|
||||
assert controller.plugin_unique_identifier == plugin_unique_identifier
|
||||
assert controller.tenant_id == tenant_id
|
||||
|
||||
def test_provider_type(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
controller = OnlineDocumentDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="test", plugin_unique_identifier="test", tenant_id="test"
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert controller.provider_type == DatasourceProviderType.ONLINE_DOCUMENT
|
||||
|
||||
def test_get_datasource_success(self):
|
||||
# Arrange
|
||||
from core.datasource.entities.datasource_entities import DatasourceIdentity
|
||||
|
||||
mock_datasource_entity = MagicMock(spec=DatasourceEntity)
|
||||
mock_datasource_entity.identity = MagicMock(spec=DatasourceIdentity)
|
||||
mock_datasource_entity.identity.name = "target_datasource"
|
||||
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.datasources = [mock_datasource_entity]
|
||||
mock_entity.identity = MagicMock()
|
||||
mock_entity.identity.icon = "test_icon"
|
||||
|
||||
plugin_unique_identifier = "test_plugin_uid"
|
||||
tenant_id = "test_tenant_id"
|
||||
|
||||
controller = OnlineDocumentDatasourcePluginProviderController(
|
||||
entity=mock_entity,
|
||||
plugin_id="test_plugin_id",
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = controller.get_datasource("target_datasource")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, OnlineDocumentDatasourcePlugin)
|
||||
assert result.entity == mock_datasource_entity
|
||||
assert result.tenant_id == tenant_id
|
||||
assert result.icon == "test_icon"
|
||||
assert result.plugin_unique_identifier == plugin_unique_identifier
|
||||
assert result.runtime.tenant_id == tenant_id
|
||||
|
||||
def test_get_datasource_not_found(self):
|
||||
# Arrange
|
||||
from core.datasource.entities.datasource_entities import DatasourceIdentity
|
||||
|
||||
mock_datasource_entity = MagicMock(spec=DatasourceEntity)
|
||||
mock_datasource_entity.identity = MagicMock(spec=DatasourceIdentity)
|
||||
mock_datasource_entity.identity.name = "other_datasource"
|
||||
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.datasources = [mock_datasource_entity]
|
||||
|
||||
controller = OnlineDocumentDatasourcePluginProviderController(
|
||||
entity=mock_entity,
|
||||
plugin_id="test_plugin_id",
|
||||
plugin_unique_identifier="test_plugin_uid",
|
||||
tenant_id="test_tenant_id",
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Datasource with name missing_datasource not found"):
|
||||
controller.get_datasource("missing_datasource")
|
||||
|
|
@ -0,0 +1,147 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceIdentity,
|
||||
DatasourceProviderType,
|
||||
OnlineDriveBrowseFilesRequest,
|
||||
OnlineDriveDownloadFileRequest,
|
||||
)
|
||||
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||
|
||||
|
||||
class TestOnlineDriveDatasourcePlugin:
|
||||
def test_init(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
tenant_id = "test_tenant"
|
||||
icon = "test_icon"
|
||||
plugin_unique_identifier = "test_plugin_id"
|
||||
|
||||
# Act
|
||||
plugin = OnlineDriveDatasourcePlugin(
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
tenant_id=tenant_id,
|
||||
icon=icon,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert plugin.entity == entity
|
||||
assert plugin.runtime == runtime
|
||||
assert plugin.tenant_id == tenant_id
|
||||
assert plugin.icon == icon
|
||||
assert plugin.plugin_unique_identifier == plugin_unique_identifier
|
||||
|
||||
def test_online_drive_browse_files(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
identity = MagicMock(spec=DatasourceIdentity)
|
||||
entity.identity = identity
|
||||
identity.provider = "test_provider"
|
||||
identity.name = "test_name"
|
||||
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
runtime.credentials = {"token": "test_token"}
|
||||
|
||||
tenant_id = "test_tenant"
|
||||
icon = "test_icon"
|
||||
plugin_unique_identifier = "test_plugin_id"
|
||||
|
||||
plugin = OnlineDriveDatasourcePlugin(
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
tenant_id=tenant_id,
|
||||
icon=icon,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
)
|
||||
|
||||
user_id = "test_user"
|
||||
request = MagicMock(spec=OnlineDriveBrowseFilesRequest)
|
||||
provider_type = "test_type"
|
||||
|
||||
mock_generator = MagicMock()
|
||||
|
||||
with patch("core.datasource.online_drive.online_drive_plugin.PluginDatasourceManager") as MockManager:
|
||||
mock_manager_instance = MockManager.return_value
|
||||
mock_manager_instance.online_drive_browse_files.return_value = mock_generator
|
||||
|
||||
# Act
|
||||
result = plugin.online_drive_browse_files(user_id=user_id, request=request, provider_type=provider_type)
|
||||
|
||||
# Assert
|
||||
assert result == mock_generator
|
||||
mock_manager_instance.online_drive_browse_files.assert_called_once_with(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider="test_provider",
|
||||
datasource_name="test_name",
|
||||
credentials=runtime.credentials,
|
||||
request=request,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def test_online_drive_download_file(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
identity = MagicMock(spec=DatasourceIdentity)
|
||||
entity.identity = identity
|
||||
identity.provider = "test_provider"
|
||||
identity.name = "test_name"
|
||||
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
runtime.credentials = {"token": "test_token"}
|
||||
|
||||
tenant_id = "test_tenant"
|
||||
icon = "test_icon"
|
||||
plugin_unique_identifier = "test_plugin_id"
|
||||
|
||||
plugin = OnlineDriveDatasourcePlugin(
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
tenant_id=tenant_id,
|
||||
icon=icon,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
)
|
||||
|
||||
user_id = "test_user"
|
||||
request = MagicMock(spec=OnlineDriveDownloadFileRequest)
|
||||
provider_type = "test_type"
|
||||
|
||||
mock_generator = MagicMock()
|
||||
|
||||
with patch("core.datasource.online_drive.online_drive_plugin.PluginDatasourceManager") as MockManager:
|
||||
mock_manager_instance = MockManager.return_value
|
||||
mock_manager_instance.online_drive_download_file.return_value = mock_generator
|
||||
|
||||
# Act
|
||||
result = plugin.online_drive_download_file(user_id=user_id, request=request, provider_type=provider_type)
|
||||
|
||||
# Assert
|
||||
assert result == mock_generator
|
||||
mock_manager_instance.online_drive_download_file.assert_called_once_with(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider="test_provider",
|
||||
datasource_name="test_name",
|
||||
credentials=runtime.credentials,
|
||||
request=request,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def test_datasource_provider_type(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
plugin = OnlineDriveDatasourcePlugin(
|
||||
entity=entity, runtime=runtime, tenant_id="test", icon="test", plugin_unique_identifier="test"
|
||||
)
|
||||
|
||||
# Act
|
||||
result = plugin.datasource_provider_type()
|
||||
|
||||
# Assert
|
||||
assert result == DatasourceProviderType.ONLINE_DRIVE
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||
from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController
|
||||
|
||||
|
||||
class TestOnlineDriveDatasourcePluginProviderController:
|
||||
def test_init(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
plugin_id = "test_plugin_id"
|
||||
plugin_unique_identifier = "test_plugin_unique_identifier"
|
||||
tenant_id = "test_tenant_id"
|
||||
|
||||
# Act
|
||||
controller = OnlineDriveDatasourcePluginProviderController(
|
||||
entity=mock_entity,
|
||||
plugin_id=plugin_id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert controller.entity == mock_entity
|
||||
assert controller.plugin_id == plugin_id
|
||||
assert controller.plugin_unique_identifier == plugin_unique_identifier
|
||||
assert controller.tenant_id == tenant_id
|
||||
|
||||
def test_provider_type(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
controller = OnlineDriveDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
assert controller.provider_type == DatasourceProviderType.ONLINE_DRIVE
|
||||
|
||||
def test_get_datasource_success(self):
|
||||
# Arrange
|
||||
mock_datasource_entity = MagicMock()
|
||||
mock_datasource_entity.identity.name = "test_datasource"
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.datasources = [mock_datasource_entity]
|
||||
mock_entity.identity.icon = "test_icon"
|
||||
|
||||
plugin_unique_identifier = "test_plugin_unique_identifier"
|
||||
tenant_id = "test_tenant_id"
|
||||
|
||||
controller = OnlineDriveDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="id", plugin_unique_identifier=plugin_unique_identifier, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# Act
|
||||
datasource = controller.get_datasource("test_datasource")
|
||||
|
||||
# Assert
|
||||
assert isinstance(datasource, OnlineDriveDatasourcePlugin)
|
||||
assert datasource.entity == mock_datasource_entity
|
||||
assert datasource.tenant_id == tenant_id
|
||||
assert datasource.icon == "test_icon"
|
||||
assert datasource.plugin_unique_identifier == plugin_unique_identifier
|
||||
assert datasource.runtime.tenant_id == tenant_id
|
||||
|
||||
def test_get_datasource_not_found(self):
|
||||
# Arrange
|
||||
mock_datasource_entity = MagicMock()
|
||||
mock_datasource_entity.identity.name = "other_datasource"
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.datasources = [mock_datasource_entity]
|
||||
|
||||
controller = OnlineDriveDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Datasource with name test_datasource not found"):
|
||||
controller.get_datasource("test_datasource")
|
||||
|
|
@ -0,0 +1,409 @@
|
|||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from core.datasource.datasource_file_manager import DatasourceFileManager
|
||||
from models.model import MessageFile, UploadFile
|
||||
from models.tools import ToolFile
|
||||
|
||||
|
||||
class TestDatasourceFileManager:
|
||||
@patch("core.datasource.datasource_file_manager.time.time")
|
||||
@patch("core.datasource.datasource_file_manager.os.urandom")
|
||||
@patch("core.datasource.datasource_file_manager.dify_config")
|
||||
def test_sign_file(self, mock_config, mock_urandom, mock_time):
|
||||
# Setup
|
||||
mock_config.FILES_URL = "http://localhost:5001"
|
||||
mock_config.SECRET_KEY = "test_secret"
|
||||
mock_time.return_value = 1700000000
|
||||
mock_urandom.return_value = b"1234567890abcdef" # 16 bytes
|
||||
|
||||
datasource_file_id = "file_id_123"
|
||||
extension = ".png"
|
||||
|
||||
# Execute
|
||||
signed_url = DatasourceFileManager.sign_file(datasource_file_id, extension)
|
||||
|
||||
# Verify
|
||||
assert signed_url.startswith("http://localhost:5001/files/datasources/file_id_123.png?")
|
||||
assert "timestamp=1700000000" in signed_url
|
||||
assert f"nonce={mock_urandom.return_value.hex()}" in signed_url
|
||||
assert "sign=" in signed_url
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.time.time")
|
||||
@patch("core.datasource.datasource_file_manager.os.urandom")
|
||||
@patch("core.datasource.datasource_file_manager.dify_config")
|
||||
def test_sign_file_empty_secret(self, mock_config, mock_urandom, mock_time):
|
||||
# Setup
|
||||
mock_config.FILES_URL = "http://localhost:5001"
|
||||
mock_config.SECRET_KEY = None # Empty secret
|
||||
mock_time.return_value = 1700000000
|
||||
mock_urandom.return_value = b"1234567890abcdef"
|
||||
|
||||
# Execute
|
||||
signed_url = DatasourceFileManager.sign_file("file_id", ".png")
|
||||
assert "sign=" in signed_url
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.time.time")
|
||||
@patch("core.datasource.datasource_file_manager.dify_config")
|
||||
def test_verify_file(self, mock_config, mock_time):
|
||||
# Setup
|
||||
mock_config.SECRET_KEY = "test_secret"
|
||||
mock_config.FILES_ACCESS_TIMEOUT = 300
|
||||
mock_time.return_value = 1700000000
|
||||
|
||||
datasource_file_id = "file_id_123"
|
||||
timestamp = "1699999800" # 200 seconds ago
|
||||
nonce = "some_nonce"
|
||||
|
||||
# Manually calculate sign
|
||||
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = b"test_secret"
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
# Execute & Verify Success
|
||||
assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is True
|
||||
|
||||
# Verify Failure - Wrong Sign
|
||||
assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, "wrong_sign") is False
|
||||
|
||||
# Verify Failure - Timeout
|
||||
mock_time.return_value = 1700000500 # 700 seconds after timestamp (300 is timeout)
|
||||
assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is False
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.time.time")
|
||||
@patch("core.datasource.datasource_file_manager.dify_config")
|
||||
def test_verify_file_empty_secret(self, mock_config, mock_time):
|
||||
# Setup
|
||||
mock_config.SECRET_KEY = "" # Empty string secret
|
||||
mock_config.FILES_ACCESS_TIMEOUT = 300
|
||||
mock_time.return_value = 1700000000
|
||||
|
||||
datasource_file_id = "file_id_123"
|
||||
timestamp = "1699999800"
|
||||
nonce = "some_nonce"
|
||||
|
||||
# Calculate with empty secret
|
||||
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
|
||||
sign = hmac.new(b"", data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is True
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
@patch("core.datasource.datasource_file_manager.uuid4")
|
||||
@patch("core.datasource.datasource_file_manager.dify_config")
|
||||
def test_create_file_by_raw(self, mock_config, mock_uuid, mock_storage, mock_db):
|
||||
# Setup
|
||||
mock_uuid.return_value = MagicMock(hex="unique_hex")
|
||||
mock_config.STORAGE_TYPE = "local"
|
||||
|
||||
user_id = "user_123"
|
||||
tenant_id = "tenant_456"
|
||||
file_binary = b"fake binary data"
|
||||
mimetype = "image/png"
|
||||
|
||||
# Execute
|
||||
upload_file = DatasourceFileManager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
file_binary=file_binary,
|
||||
mimetype=mimetype,
|
||||
filename="test.png",
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert upload_file.tenant_id == tenant_id
|
||||
assert upload_file.name == "test.png"
|
||||
assert upload_file.size == len(file_binary)
|
||||
assert upload_file.mime_type == mimetype
|
||||
assert upload_file.key == f"datasources/{tenant_id}/unique_hex.png"
|
||||
|
||||
mock_storage.save.assert_called_once_with(upload_file.key, file_binary)
|
||||
mock_db.session.add.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
@patch("core.datasource.datasource_file_manager.uuid4")
|
||||
@patch("core.datasource.datasource_file_manager.dify_config")
|
||||
def test_create_file_by_raw_filename_no_extension(self, mock_config, mock_uuid, mock_storage, mock_db):
|
||||
# Setup
|
||||
mock_uuid.return_value = MagicMock(hex="unique_hex")
|
||||
mock_config.STORAGE_TYPE = "local"
|
||||
|
||||
user_id = "user_123"
|
||||
tenant_id = "tenant_456"
|
||||
file_binary = b"fake binary data"
|
||||
mimetype = "image/png"
|
||||
|
||||
# Execute
|
||||
upload_file = DatasourceFileManager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
file_binary=file_binary,
|
||||
mimetype=mimetype,
|
||||
filename="test", # No extension
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert upload_file.name == "test.png" # Should append extension
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
@patch("core.datasource.datasource_file_manager.uuid4")
|
||||
@patch("core.datasource.datasource_file_manager.dify_config")
|
||||
@patch("core.datasource.datasource_file_manager.guess_extension")
|
||||
def test_create_file_by_raw_unknown_extension(self, mock_guess_ext, mock_config, mock_uuid, mock_storage, mock_db):
|
||||
# Setup
|
||||
mock_guess_ext.return_value = None # Cannot guess
|
||||
mock_uuid.return_value = MagicMock(hex="unique_hex")
|
||||
|
||||
# Execute
|
||||
upload_file = DatasourceFileManager.create_file_by_raw(
|
||||
user_id="user",
|
||||
tenant_id="tenant",
|
||||
conversation_id=None,
|
||||
file_binary=b"data",
|
||||
mimetype="application/x-unknown",
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert upload_file.extension == ".bin"
|
||||
assert upload_file.name == "unique_hex.bin"
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
@patch("core.datasource.datasource_file_manager.uuid4")
|
||||
@patch("core.datasource.datasource_file_manager.dify_config")
|
||||
def test_create_file_by_raw_no_filename(self, mock_config, mock_uuid, mock_storage, mock_db):
|
||||
# Setup
|
||||
mock_uuid.return_value = MagicMock(hex="unique_hex")
|
||||
mock_config.STORAGE_TYPE = "local"
|
||||
|
||||
# Execute
|
||||
upload_file = DatasourceFileManager.create_file_by_raw(
|
||||
user_id="user_123",
|
||||
tenant_id="tenant_456",
|
||||
conversation_id=None,
|
||||
file_binary=b"data",
|
||||
mimetype="application/pdf",
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert upload_file.name == "unique_hex.pdf"
|
||||
assert upload_file.extension == ".pdf"
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.ssrf_proxy")
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
@patch("core.datasource.datasource_file_manager.uuid4")
|
||||
def test_create_file_by_url_mimetype_from_guess(self, mock_uuid, mock_storage, mock_db, mock_ssrf):
|
||||
# Setup
|
||||
mock_uuid.return_value = MagicMock(hex="unique_hex")
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"bits"
|
||||
mock_response.headers = {} # No content-type in headers
|
||||
mock_ssrf.get.return_value = mock_response
|
||||
|
||||
# Execute
|
||||
tool_file = DatasourceFileManager.create_file_by_url(
|
||||
user_id="user_123", tenant_id="tenant_456", file_url="https://example.com/photo.png"
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert tool_file.mimetype == "image/png" # Guessed from .png in URL
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.ssrf_proxy")
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
@patch("core.datasource.datasource_file_manager.uuid4")
|
||||
def test_create_file_by_url_mimetype_default(self, mock_uuid, mock_storage, mock_db, mock_ssrf):
|
||||
# Setup
|
||||
mock_uuid.return_value = MagicMock(hex="unique_hex")
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"bits"
|
||||
mock_response.headers = {}
|
||||
mock_ssrf.get.return_value = mock_response
|
||||
|
||||
# Execute
|
||||
tool_file = DatasourceFileManager.create_file_by_url(
|
||||
user_id="user_123",
|
||||
tenant_id="tenant_456",
|
||||
file_url="https://example.com/unknown", # No extension, no headers
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert tool_file.mimetype == "application/octet-stream"
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.ssrf_proxy")
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
@patch("core.datasource.datasource_file_manager.uuid4")
|
||||
def test_create_file_by_url_success(self, mock_uuid, mock_storage, mock_db, mock_ssrf):
|
||||
# Setup
|
||||
mock_uuid.return_value = MagicMock(hex="unique_hex")
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"downloaded bits"
|
||||
mock_response.headers = {"Content-Type": "image/jpeg"}
|
||||
mock_ssrf.get.return_value = mock_response
|
||||
|
||||
# Execute
|
||||
tool_file = DatasourceFileManager.create_file_by_url(
|
||||
user_id="user_123", tenant_id="tenant_456", file_url="https://example.com/photo.jpg"
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert tool_file.mimetype == "image/jpeg"
|
||||
assert tool_file.size == len(b"downloaded bits")
|
||||
assert tool_file.file_key == "tools/tenant_456/unique_hex.jpg"
|
||||
mock_storage.save.assert_called_once()
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.ssrf_proxy")
|
||||
def test_create_file_by_url_timeout(self, mock_ssrf):
|
||||
# Setup
|
||||
mock_ssrf.get.side_effect = httpx.TimeoutException("Timeout")
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match="timeout when downloading file"):
|
||||
DatasourceFileManager.create_file_by_url(
|
||||
user_id="user_123", tenant_id="tenant_456", file_url="https://example.com/large.file"
|
||||
)
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
def test_get_file_binary(self, mock_storage, mock_db):
|
||||
# Setup
|
||||
mock_upload_file = MagicMock(spec=UploadFile)
|
||||
mock_upload_file.key = "some_key"
|
||||
mock_upload_file.mime_type = "image/png"
|
||||
|
||||
mock_query = mock_db.session.query.return_value
|
||||
mock_where = mock_query.where.return_value
|
||||
mock_where.first.return_value = mock_upload_file
|
||||
|
||||
mock_storage.load_once.return_value = b"file content"
|
||||
|
||||
# Execute
|
||||
result = DatasourceFileManager.get_file_binary("file_id")
|
||||
|
||||
# Verify
|
||||
assert result == (b"file content", "image/png")
|
||||
|
||||
# Case: Not found
|
||||
mock_where.first.return_value = None
|
||||
assert DatasourceFileManager.get_file_binary("unknown") is None
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
def test_get_file_binary_by_message_file_id(self, mock_storage, mock_db):
|
||||
# Setup
|
||||
mock_message_file = MagicMock(spec=MessageFile)
|
||||
mock_message_file.url = "http://localhost/files/tools/tool_id.png"
|
||||
|
||||
mock_tool_file = MagicMock(spec=ToolFile)
|
||||
mock_tool_file.file_key = "tool_key"
|
||||
mock_tool_file.mimetype = "image/png"
|
||||
|
||||
# Mock query sequence
|
||||
def mock_query(model):
|
||||
m = MagicMock()
|
||||
if model == MessageFile:
|
||||
m.where.return_value.first.return_value = mock_message_file
|
||||
elif model == ToolFile:
|
||||
m.where.return_value.first.return_value = mock_tool_file
|
||||
return m
|
||||
|
||||
mock_db.session.query.side_effect = mock_query
|
||||
mock_storage.load_once.return_value = b"tool content"
|
||||
|
||||
# Execute
|
||||
result = DatasourceFileManager.get_file_binary_by_message_file_id("msg_file_id")
|
||||
|
||||
# Verify
|
||||
assert result == (b"tool content", "image/png")
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
def test_get_file_binary_by_message_file_id_with_extension(self, mock_storage, mock_db):
|
||||
# Test that it correctly parses tool_id even with extension in URL
|
||||
mock_message_file = MagicMock(spec=MessageFile)
|
||||
mock_message_file.url = "http://localhost/files/tools/abcdef.png"
|
||||
|
||||
mock_tool_file = MagicMock(spec=ToolFile)
|
||||
mock_tool_file.id = "abcdef"
|
||||
mock_tool_file.file_key = "tk"
|
||||
mock_tool_file.mimetype = "image/png"
|
||||
|
||||
def mock_query(model):
|
||||
m = MagicMock()
|
||||
if model == MessageFile:
|
||||
m.where.return_value.first.return_value = mock_message_file
|
||||
else:
|
||||
m.where.return_value.first.return_value = mock_tool_file
|
||||
return m
|
||||
|
||||
mock_db.session.query.side_effect = mock_query
|
||||
mock_storage.load_once.return_value = b"bits"
|
||||
|
||||
result = DatasourceFileManager.get_file_binary_by_message_file_id("m")
|
||||
assert result == (b"bits", "image/png")
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
def test_get_file_binary_by_message_file_id_failures(self, mock_storage, mock_db):
|
||||
# Setup common mock
|
||||
mock_query_obj = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query_obj
|
||||
mock_query_obj.where.return_value.first.return_value = None
|
||||
|
||||
# Case 1: Message file not found
|
||||
assert DatasourceFileManager.get_file_binary_by_message_file_id("none") is None
|
||||
|
||||
# Case 2: Message file found but tool file not found
|
||||
mock_message_file = MagicMock(spec=MessageFile)
|
||||
mock_message_file.url = None
|
||||
|
||||
def mock_query_v2(model):
|
||||
m = MagicMock()
|
||||
if model == MessageFile:
|
||||
m.where.return_value.first.return_value = mock_message_file
|
||||
else:
|
||||
m.where.return_value.first.return_value = None
|
||||
return m
|
||||
|
||||
mock_db.session.query.side_effect = mock_query_v2
|
||||
assert DatasourceFileManager.get_file_binary_by_message_file_id("msg_id") is None
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
def test_get_file_generator_by_upload_file_id(self, mock_storage, mock_db):
|
||||
# Setup
|
||||
mock_upload_file = MagicMock(spec=UploadFile)
|
||||
mock_upload_file.key = "upload_key"
|
||||
mock_upload_file.mime_type = "text/plain"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_upload_file
|
||||
|
||||
mock_storage.load_stream.return_value = iter([b"chunk1", b"chunk2"])
|
||||
|
||||
# Execute
|
||||
stream, mimetype = DatasourceFileManager.get_file_generator_by_upload_file_id("upload_id")
|
||||
|
||||
# Verify
|
||||
assert mimetype == "text/plain"
|
||||
assert list(stream) == [b"chunk1", b"chunk2"]
|
||||
|
||||
# Case: Not found
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
stream, mimetype = DatasourceFileManager.get_file_generator_by_upload_file_id("none")
|
||||
assert stream is None
|
||||
assert mimetype is None
|
||||
|
|
@ -1,9 +1,15 @@
|
|||
import types
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
from core.datasource.entities.datasource_entities import DatasourceMessage
|
||||
from core.datasource.entities.datasource_entities import DatasourceMessage, DatasourceProviderType
|
||||
from core.datasource.errors import DatasourceProviderNotFoundError
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from dify_graph.file import File
|
||||
from dify_graph.file.enums import FileTransferMethod, FileType
|
||||
from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent
|
||||
|
||||
|
||||
|
|
@ -15,6 +21,22 @@ def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, Non
|
|||
)
|
||||
|
||||
|
||||
def _drain_generator(gen: Generator[DatasourceMessage, None, object]) -> tuple[list[DatasourceMessage], object | None]:
|
||||
messages: list[DatasourceMessage] = []
|
||||
try:
|
||||
while True:
|
||||
messages.append(next(gen))
|
||||
except StopIteration as e:
|
||||
return messages, e.value
|
||||
|
||||
|
||||
def _invalidate_recyclable_contextvars() -> None:
|
||||
"""
|
||||
Ensure RecyclableContextVar.get() raises LookupError until reset by code under test.
|
||||
"""
|
||||
RecyclableContextVar.increment_thread_recycles()
|
||||
|
||||
|
||||
def test_get_icon_url_calls_runtime(mocker):
|
||||
fake_runtime = mocker.Mock()
|
||||
fake_runtime.get_icon_url.return_value = "https://icon"
|
||||
|
|
@ -30,6 +52,119 @@ def test_get_icon_url_calls_runtime(mocker):
|
|||
DatasourceManager.get_datasource_runtime.assert_called_once()
|
||||
|
||||
|
||||
def test_get_datasource_runtime_delegates_to_provider_controller(mocker):
|
||||
provider_controller = mocker.Mock()
|
||||
provider_controller.get_datasource.return_value = object()
|
||||
mocker.patch.object(DatasourceManager, "get_datasource_plugin_provider", return_value=provider_controller)
|
||||
|
||||
runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id="prov/x",
|
||||
datasource_name="ds",
|
||||
tenant_id="t1",
|
||||
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
)
|
||||
assert runtime is provider_controller.get_datasource.return_value
|
||||
provider_controller.get_datasource.assert_called_once_with("ds")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("datasource_type", "controller_path"),
|
||||
[
|
||||
(
|
||||
DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
"core.datasource.datasource_manager.OnlineDocumentDatasourcePluginProviderController",
|
||||
),
|
||||
(
|
||||
DatasourceProviderType.ONLINE_DRIVE,
|
||||
"core.datasource.datasource_manager.OnlineDriveDatasourcePluginProviderController",
|
||||
),
|
||||
(
|
||||
DatasourceProviderType.WEBSITE_CRAWL,
|
||||
"core.datasource.datasource_manager.WebsiteCrawlDatasourcePluginProviderController",
|
||||
),
|
||||
(
|
||||
DatasourceProviderType.LOCAL_FILE,
|
||||
"core.datasource.datasource_manager.LocalFileDatasourcePluginProviderController",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_datasource_plugin_provider_creates_controller_and_caches(mocker, datasource_type, controller_path):
|
||||
_invalidate_recyclable_contextvars()
|
||||
|
||||
provider_entity = types.SimpleNamespace(declaration=object(), plugin_id="plugin", plugin_unique_identifier="uniq")
|
||||
fetch = mocker.patch(
|
||||
"core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider",
|
||||
return_value=provider_entity,
|
||||
)
|
||||
ctrl_cls = mocker.patch(controller_path)
|
||||
|
||||
first = DatasourceManager.get_datasource_plugin_provider(
|
||||
provider_id=f"prov/{datasource_type.value}",
|
||||
tenant_id="t1",
|
||||
datasource_type=datasource_type,
|
||||
)
|
||||
second = DatasourceManager.get_datasource_plugin_provider(
|
||||
provider_id=f"prov/{datasource_type.value}",
|
||||
tenant_id="t1",
|
||||
datasource_type=datasource_type,
|
||||
)
|
||||
|
||||
assert first is second
|
||||
assert fetch.call_count == 1
|
||||
assert ctrl_cls.call_count == 1
|
||||
|
||||
|
||||
def test_get_datasource_plugin_provider_raises_when_provider_entity_missing(mocker):
|
||||
_invalidate_recyclable_contextvars()
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
with pytest.raises(DatasourceProviderNotFoundError, match="plugin provider prov/notfound not found"):
|
||||
DatasourceManager.get_datasource_plugin_provider(
|
||||
provider_id="prov/notfound",
|
||||
tenant_id="t1",
|
||||
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
)
|
||||
|
||||
|
||||
def test_get_datasource_plugin_provider_raises_for_unsupported_type(mocker):
|
||||
_invalidate_recyclable_contextvars()
|
||||
provider_entity = types.SimpleNamespace(declaration=object(), plugin_id="plugin", plugin_unique_identifier="uniq")
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider",
|
||||
return_value=provider_entity,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported datasource type"):
|
||||
DatasourceManager.get_datasource_plugin_provider(
|
||||
provider_id="prov/x",
|
||||
tenant_id="t1",
|
||||
datasource_type=types.SimpleNamespace(), # not a DatasourceProviderType at runtime
|
||||
)
|
||||
|
||||
|
||||
def test_get_datasource_plugin_provider_raises_when_controller_none(mocker):
|
||||
_invalidate_recyclable_contextvars()
|
||||
provider_entity = types.SimpleNamespace(declaration=object(), plugin_id="plugin", plugin_unique_identifier="uniq")
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider",
|
||||
return_value=provider_entity,
|
||||
)
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.OnlineDocumentDatasourcePluginProviderController",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
with pytest.raises(DatasourceProviderNotFoundError, match="Datasource provider prov/x not found"):
|
||||
DatasourceManager.get_datasource_plugin_provider(
|
||||
provider_id="prov/x",
|
||||
tenant_id="t1",
|
||||
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
)
|
||||
|
||||
|
||||
def test_stream_online_results_yields_messages_online_document(mocker):
|
||||
# stub runtime to yield a text message
|
||||
def _doc_messages(**_):
|
||||
|
|
@ -60,6 +195,148 @@ def test_stream_online_results_yields_messages_online_document(mocker):
|
|||
assert msgs[0].message.text == "hello"
|
||||
|
||||
|
||||
def test_stream_online_results_sets_credentials_and_returns_empty_dict_online_document(mocker):
|
||||
class _Runtime:
|
||||
def __init__(self) -> None:
|
||||
self.runtime = types.SimpleNamespace(credentials=None)
|
||||
|
||||
def get_online_document_page_content(self, **_kwargs):
|
||||
yield from _gen_messages_text_only("hello")
|
||||
|
||||
runtime = _Runtime()
|
||||
mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=runtime)
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value={"token": "t"},
|
||||
)
|
||||
|
||||
gen = DatasourceManager.stream_online_results(
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_document",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="cred",
|
||||
datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"),
|
||||
online_drive_request=None,
|
||||
)
|
||||
messages, final_value = _drain_generator(gen)
|
||||
|
||||
assert runtime.runtime.credentials == {"token": "t"}
|
||||
assert [m.message.text for m in messages] == ["hello"]
|
||||
assert final_value == {}
|
||||
|
||||
|
||||
def test_stream_online_results_raises_when_missing_params(mocker):
|
||||
class _Runtime:
|
||||
def __init__(self) -> None:
|
||||
self.runtime = types.SimpleNamespace(credentials=None)
|
||||
|
||||
def get_online_document_page_content(self, **_kwargs):
|
||||
yield from _gen_messages_text_only("never")
|
||||
|
||||
def online_drive_download_file(self, **_kwargs):
|
||||
yield from _gen_messages_text_only("never")
|
||||
|
||||
mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=_Runtime())
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value={},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="datasource_param is required for ONLINE_DOCUMENT streaming"):
|
||||
list(
|
||||
DatasourceManager.stream_online_results(
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_document",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="",
|
||||
datasource_param=None,
|
||||
online_drive_request=None,
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="online_drive_request is required for ONLINE_DRIVE streaming"):
|
||||
list(
|
||||
DatasourceManager.stream_online_results(
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_drive",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="",
|
||||
datasource_param=None,
|
||||
online_drive_request=None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_stream_online_results_yields_messages_and_returns_empty_dict_online_drive(mocker):
|
||||
class _Runtime:
|
||||
def __init__(self) -> None:
|
||||
self.runtime = types.SimpleNamespace(credentials=None)
|
||||
|
||||
def online_drive_download_file(self, **_kwargs):
|
||||
yield from _gen_messages_text_only("drive")
|
||||
|
||||
runtime = _Runtime()
|
||||
mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=runtime)
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value={"token": "t"},
|
||||
)
|
||||
|
||||
gen = DatasourceManager.stream_online_results(
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_drive",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="cred",
|
||||
datasource_param=None,
|
||||
online_drive_request=types.SimpleNamespace(id="fid", bucket="b"),
|
||||
)
|
||||
messages, final_value = _drain_generator(gen)
|
||||
|
||||
assert runtime.runtime.credentials == {"token": "t"}
|
||||
assert [m.message.text for m in messages] == ["drive"]
|
||||
assert final_value == {}
|
||||
|
||||
|
||||
def test_stream_online_results_raises_for_unsupported_stream_type(mocker):
|
||||
mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=mocker.Mock())
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value={},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported datasource type for streaming"):
|
||||
list(
|
||||
DatasourceManager.stream_online_results(
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="website_crawl",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="",
|
||||
datasource_param=None,
|
||||
online_drive_request=None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_stream_node_events_emits_events_online_document(mocker):
|
||||
# make manager's low-level stream produce TEXT only
|
||||
mocker.patch.object(
|
||||
|
|
@ -93,6 +370,260 @@ def test_stream_node_events_emits_events_online_document(mocker):
|
|||
assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
|
||||
|
||||
def test_stream_node_events_builds_file_and_variables_from_messages(mocker):
|
||||
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
|
||||
|
||||
def _transformed(**_kwargs):
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text="/files/datasources/tool_file_1.png"),
|
||||
meta={},
|
||||
)
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.TEXT,
|
||||
message=DatasourceMessage.TextMessage(text="hello"),
|
||||
meta=None,
|
||||
)
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.LINK,
|
||||
message=DatasourceMessage.TextMessage(text="http://example.com"),
|
||||
meta=None,
|
||||
)
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.VARIABLE,
|
||||
message=DatasourceMessage.VariableMessage(variable_name="v", variable_value="a", stream=True),
|
||||
meta=None,
|
||||
)
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.VARIABLE,
|
||||
message=DatasourceMessage.VariableMessage(variable_name="v", variable_value="b", stream=True),
|
||||
meta=None,
|
||||
)
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.VARIABLE,
|
||||
message=DatasourceMessage.VariableMessage(variable_name="x", variable_value=1, stream=False),
|
||||
meta=None,
|
||||
)
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.JSON,
|
||||
message=DatasourceMessage.JsonMessage(json_object={"k": "v"}),
|
||||
meta=None,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages",
|
||||
side_effect=_transformed,
|
||||
)
|
||||
|
||||
fake_tool_file = types.SimpleNamespace(mimetype="image/png")
|
||||
|
||||
class _Session:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
return False
|
||||
|
||||
def scalar(self, _stmt):
|
||||
return fake_tool_file
|
||||
|
||||
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session())
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.file_factory.get_file_type_by_mime_type", return_value=FileType.IMAGE
|
||||
)
|
||||
built = File(
|
||||
tenant_id="t1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id="tool_file_1",
|
||||
extension=".png",
|
||||
mime_type="image/png",
|
||||
storage_key="k",
|
||||
)
|
||||
build_from_mapping = mocker.patch(
|
||||
"core.datasource.datasource_manager.file_factory.build_from_mapping",
|
||||
return_value=built,
|
||||
)
|
||||
|
||||
variable_pool = mocker.Mock()
|
||||
|
||||
events = list(
|
||||
DatasourceManager.stream_node_events(
|
||||
node_id="nodeA",
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_document",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="",
|
||||
parameters_for_log={"k": "v"},
|
||||
datasource_info={"info": "x"},
|
||||
variable_pool=variable_pool,
|
||||
datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"),
|
||||
online_drive_request=None,
|
||||
)
|
||||
)
|
||||
|
||||
build_from_mapping.assert_called_once()
|
||||
variable_pool.add.assert_not_called()
|
||||
|
||||
assert any(isinstance(e, StreamChunkEvent) and e.chunk == "hello" for e in events)
|
||||
assert any(isinstance(e, StreamChunkEvent) and e.chunk.startswith("Link: http") for e in events)
|
||||
assert any(isinstance(e, StreamChunkEvent) and e.selector == ["nodeA", "v"] and e.chunk == "a" for e in events)
|
||||
assert any(isinstance(e, StreamChunkEvent) and e.selector == ["nodeA", "v"] and e.chunk == "b" for e in events)
|
||||
assert isinstance(events[-2], StreamChunkEvent)
|
||||
assert events[-2].is_final is True
|
||||
|
||||
assert isinstance(events[-1], StreamCompletedEvent)
|
||||
assert events[-1].node_run_result.outputs["v"] == "ab"
|
||||
assert events[-1].node_run_result.outputs["x"] == 1
|
||||
|
||||
|
||||
def test_stream_node_events_raises_when_toolfile_missing(mocker):
|
||||
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
|
||||
|
||||
def _transformed(**_kwargs):
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text="/files/datasources/missing.png"),
|
||||
meta={},
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages",
|
||||
side_effect=_transformed,
|
||||
)
|
||||
|
||||
class _Session:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
return False
|
||||
|
||||
def scalar(self, _stmt):
|
||||
return None
|
||||
|
||||
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session())
|
||||
|
||||
with pytest.raises(ValueError, match="ToolFile not found for file_id=missing, tenant_id=t1"):
|
||||
list(
|
||||
DatasourceManager.stream_node_events(
|
||||
node_id="nodeA",
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_document",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="",
|
||||
parameters_for_log={},
|
||||
datasource_info={},
|
||||
variable_pool=mocker.Mock(),
|
||||
datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"),
|
||||
online_drive_request=None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(mocker):
|
||||
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
|
||||
|
||||
file_in = File(
|
||||
tenant_id="t1",
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id="tf",
|
||||
extension=".pdf",
|
||||
mime_type="application/pdf",
|
||||
storage_key="k",
|
||||
)
|
||||
|
||||
def _transformed(**_kwargs):
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.FILE,
|
||||
message=DatasourceMessage.FileMessage(file_marker="file_marker"),
|
||||
meta={"file": file_in},
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages",
|
||||
side_effect=_transformed,
|
||||
)
|
||||
|
||||
variable_pool = mocker.Mock()
|
||||
events = list(
|
||||
DatasourceManager.stream_node_events(
|
||||
node_id="nodeA",
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_drive",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="",
|
||||
parameters_for_log={},
|
||||
datasource_info={"k": "v"},
|
||||
variable_pool=variable_pool,
|
||||
datasource_param=None,
|
||||
online_drive_request=types.SimpleNamespace(id="id", bucket="b"),
|
||||
)
|
||||
)
|
||||
|
||||
variable_pool.add.assert_called_once()
|
||||
assert variable_pool.add.call_args[0][0] == ["nodeA", "file"]
|
||||
assert variable_pool.add.call_args[0][1] == file_in
|
||||
|
||||
completed = events[-1]
|
||||
assert isinstance(completed, StreamCompletedEvent)
|
||||
assert completed.node_run_result.outputs["file"] == file_in
|
||||
assert completed.node_run_result.outputs["datasource_type"] == DatasourceProviderType.ONLINE_DRIVE
|
||||
|
||||
|
||||
def test_stream_node_events_skips_file_build_for_non_online_types(mocker):
|
||||
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
|
||||
|
||||
def _transformed(**_kwargs):
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text="/files/datasources/tool_file_1.png"),
|
||||
meta={},
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages",
|
||||
side_effect=_transformed,
|
||||
)
|
||||
build_from_mapping = mocker.patch("core.datasource.datasource_manager.file_factory.build_from_mapping")
|
||||
|
||||
events = list(
|
||||
DatasourceManager.stream_node_events(
|
||||
node_id="nodeA",
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="website_crawl",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="",
|
||||
parameters_for_log={},
|
||||
datasource_info={},
|
||||
variable_pool=mocker.Mock(),
|
||||
datasource_param=None,
|
||||
online_drive_request=None,
|
||||
)
|
||||
)
|
||||
|
||||
build_from_mapping.assert_not_called()
|
||||
assert isinstance(events[-1], StreamCompletedEvent)
|
||||
assert events[-1].node_run_result.outputs["file"] is None
|
||||
|
||||
|
||||
def test_get_upload_file_by_id_builds_file(mocker):
|
||||
# fake UploadFile row
|
||||
fake_row = types.SimpleNamespace(
|
||||
|
|
@ -133,3 +664,27 @@ def test_get_upload_file_by_id_builds_file(mocker):
|
|||
f = DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1")
|
||||
assert f.related_id == "fid"
|
||||
assert f.extension == ".txt"
|
||||
|
||||
|
||||
def test_get_upload_file_by_id_raises_when_missing(mocker):
|
||||
class _Q:
|
||||
def where(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return None
|
||||
|
||||
class _S:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
return False
|
||||
|
||||
def query(self, *_):
|
||||
return _Q()
|
||||
|
||||
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S())
|
||||
|
||||
with pytest.raises(ValueError, match="UploadFile not found for file_id=fid, tenant_id=t1"):
|
||||
DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,64 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceInvokeMeta
|
||||
from core.datasource.errors import (
|
||||
DatasourceApiSchemaError,
|
||||
DatasourceEngineInvokeError,
|
||||
DatasourceInvokeError,
|
||||
DatasourceNotFoundError,
|
||||
DatasourceNotSupportedError,
|
||||
DatasourceParameterValidationError,
|
||||
DatasourceProviderCredentialValidationError,
|
||||
DatasourceProviderNotFoundError,
|
||||
)
|
||||
|
||||
|
||||
class TestErrors:
|
||||
def test_datasource_provider_not_found_error(self):
|
||||
error = DatasourceProviderNotFoundError("Provider not found")
|
||||
assert str(error) == "Provider not found"
|
||||
assert isinstance(error, ValueError)
|
||||
|
||||
def test_datasource_not_found_error(self):
|
||||
error = DatasourceNotFoundError("Datasource not found")
|
||||
assert str(error) == "Datasource not found"
|
||||
assert isinstance(error, ValueError)
|
||||
|
||||
def test_datasource_parameter_validation_error(self):
|
||||
error = DatasourceParameterValidationError("Validation failed")
|
||||
assert str(error) == "Validation failed"
|
||||
assert isinstance(error, ValueError)
|
||||
|
||||
def test_datasource_provider_credential_validation_error(self):
|
||||
error = DatasourceProviderCredentialValidationError("Credential validation failed")
|
||||
assert str(error) == "Credential validation failed"
|
||||
assert isinstance(error, ValueError)
|
||||
|
||||
def test_datasource_not_supported_error(self):
|
||||
error = DatasourceNotSupportedError("Not supported")
|
||||
assert str(error) == "Not supported"
|
||||
assert isinstance(error, ValueError)
|
||||
|
||||
def test_datasource_invoke_error(self):
|
||||
error = DatasourceInvokeError("Invoke error")
|
||||
assert str(error) == "Invoke error"
|
||||
assert isinstance(error, ValueError)
|
||||
|
||||
def test_datasource_api_schema_error(self):
|
||||
error = DatasourceApiSchemaError("API schema error")
|
||||
assert str(error) == "API schema error"
|
||||
assert isinstance(error, ValueError)
|
||||
|
||||
def test_datasource_engine_invoke_error(self):
|
||||
mock_meta = MagicMock(spec=DatasourceInvokeMeta)
|
||||
error = DatasourceEngineInvokeError(meta=mock_meta)
|
||||
assert error.meta == mock_meta
|
||||
assert isinstance(error, Exception)
|
||||
|
||||
def test_datasource_engine_invoke_error_init(self):
|
||||
# Test initialization with meta
|
||||
meta = DatasourceInvokeMeta(time_cost=1.5, error="Engine failed")
|
||||
error = DatasourceEngineInvokeError(meta=meta)
|
||||
assert error.meta == meta
|
||||
assert error.meta.time_cost == 1.5
|
||||
assert error.meta.error == "Engine failed"
|
||||
|
|
@ -0,0 +1,337 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceMessage
|
||||
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
|
||||
from dify_graph.file import File
|
||||
from dify_graph.file.enums import FileTransferMethod, FileType
|
||||
from models.tools import ToolFile
|
||||
|
||||
|
||||
class TestDatasourceFileMessageTransformer:
|
||||
def test_transform_text_and_link_messages(self):
|
||||
# Setup
|
||||
messages = [
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.TEXT, message=DatasourceMessage.TextMessage(text="hello")
|
||||
),
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.LINK,
|
||||
message=DatasourceMessage.TextMessage(text="https://example.com"),
|
||||
),
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
assert result[0].type == DatasourceMessage.MessageType.TEXT
|
||||
assert result[0].message.text == "hello"
|
||||
assert result[1].type == DatasourceMessage.MessageType.LINK
|
||||
assert result[1].message.text == "https://example.com"
|
||||
|
||||
@patch("core.datasource.utils.message_transformer.ToolFileManager")
|
||||
@patch("core.datasource.utils.message_transformer.guess_extension")
|
||||
def test_transform_image_message_success(self, mock_guess_ext, mock_tool_file_manager_cls):
|
||||
# Setup
|
||||
mock_manager = mock_tool_file_manager_cls.return_value
|
||||
mock_tool_file = MagicMock(spec=ToolFile)
|
||||
mock_tool_file.id = "file_id_123"
|
||||
mock_tool_file.mimetype = "image/png"
|
||||
mock_manager.create_file_by_url.return_value = mock_tool_file
|
||||
mock_guess_ext.return_value = ".png"
|
||||
|
||||
messages = [
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE,
|
||||
message=DatasourceMessage.TextMessage(text="https://example.com/image.png"),
|
||||
meta={"some": "meta"},
|
||||
)
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1", conversation_id="conv1"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0].type == DatasourceMessage.MessageType.IMAGE_LINK
|
||||
assert result[0].message.text == "/files/datasources/file_id_123.png"
|
||||
assert result[0].meta == {"some": "meta"}
|
||||
mock_manager.create_file_by_url.assert_called_once_with(
|
||||
user_id="user1", tenant_id="tenant1", file_url="https://example.com/image.png", conversation_id="conv1"
|
||||
)
|
||||
|
||||
@patch("core.datasource.utils.message_transformer.ToolFileManager")
|
||||
def test_transform_image_message_failure(self, mock_tool_file_manager_cls):
|
||||
# Setup
|
||||
mock_manager = mock_tool_file_manager_cls.return_value
|
||||
mock_manager.create_file_by_url.side_effect = Exception("Download failed")
|
||||
|
||||
messages = [
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE,
|
||||
message=DatasourceMessage.TextMessage(text="https://example.com/image.png"),
|
||||
)
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0].type == DatasourceMessage.MessageType.TEXT
|
||||
assert "Failed to download image" in result[0].message.text
|
||||
assert "Download failed" in result[0].message.text
|
||||
|
||||
@patch("core.datasource.utils.message_transformer.ToolFileManager")
|
||||
@patch("core.datasource.utils.message_transformer.guess_extension")
|
||||
def test_transform_blob_message_image(self, mock_guess_ext, mock_tool_file_manager_cls):
|
||||
# Setup
|
||||
mock_manager = mock_tool_file_manager_cls.return_value
|
||||
mock_tool_file = MagicMock(spec=ToolFile)
|
||||
mock_tool_file.id = "blob_id_456"
|
||||
mock_tool_file.mimetype = "image/jpeg"
|
||||
mock_manager.create_file_by_raw.return_value = mock_tool_file
|
||||
mock_guess_ext.return_value = ".jpg"
|
||||
|
||||
blob_data = b"fake-image-bits"
|
||||
messages = [
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.BLOB,
|
||||
message=DatasourceMessage.BlobMessage(blob=blob_data),
|
||||
meta={"mime_type": "image/jpeg", "file_name": "test.jpg"},
|
||||
)
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0].type == DatasourceMessage.MessageType.IMAGE_LINK
|
||||
assert result[0].message.text == "/files/datasources/blob_id_456.jpg"
|
||||
mock_manager.create_file_by_raw.assert_called_once()
|
||||
|
||||
@patch("core.datasource.utils.message_transformer.ToolFileManager")
|
||||
@patch("core.datasource.utils.message_transformer.guess_extension")
|
||||
@patch("core.datasource.utils.message_transformer.guess_type")
|
||||
def test_transform_blob_message_binary_guess_mimetype(
|
||||
self, mock_guess_type, mock_guess_ext, mock_tool_file_manager_cls
|
||||
):
|
||||
# Setup
|
||||
mock_manager = mock_tool_file_manager_cls.return_value
|
||||
mock_tool_file = MagicMock(spec=ToolFile)
|
||||
mock_tool_file.id = "blob_id_789"
|
||||
mock_tool_file.mimetype = "application/pdf"
|
||||
mock_manager.create_file_by_raw.return_value = mock_tool_file
|
||||
mock_guess_type.return_value = ("application/pdf", None)
|
||||
mock_guess_ext.return_value = ".pdf"
|
||||
|
||||
blob_data = b"fake-pdf-bits"
|
||||
messages = [
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.BLOB,
|
||||
message=DatasourceMessage.BlobMessage(blob=blob_data),
|
||||
meta={"file_name": "test.pdf"},
|
||||
)
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0].type == DatasourceMessage.MessageType.BINARY_LINK
|
||||
assert result[0].message.text == "/files/datasources/blob_id_789.pdf"
|
||||
|
||||
def test_transform_blob_message_invalid_type(self):
|
||||
# Setup
|
||||
messages = [
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.BLOB, message=DatasourceMessage.TextMessage(text="not a blob")
|
||||
)
|
||||
]
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match="unexpected message type"):
|
||||
list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
def test_transform_file_tool_file_image(self):
|
||||
# Setup
|
||||
mock_file = MagicMock(spec=File)
|
||||
mock_file.transfer_method = FileTransferMethod.TOOL_FILE
|
||||
mock_file.related_id = "related_123"
|
||||
mock_file.extension = ".png"
|
||||
mock_file.type = FileType.IMAGE
|
||||
|
||||
messages = [
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.FILE,
|
||||
message=DatasourceMessage.TextMessage(text="ignored"),
|
||||
meta={"file": mock_file},
|
||||
)
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0].type == DatasourceMessage.MessageType.IMAGE_LINK
|
||||
assert result[0].message.text == "/files/datasources/related_123.png"
|
||||
|
||||
def test_transform_file_tool_file_binary(self):
|
||||
# Setup
|
||||
mock_file = MagicMock(spec=File)
|
||||
mock_file.transfer_method = FileTransferMethod.TOOL_FILE
|
||||
mock_file.related_id = "related_456"
|
||||
mock_file.extension = ".txt"
|
||||
mock_file.type = FileType.DOCUMENT
|
||||
|
||||
messages = [
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.FILE,
|
||||
message=DatasourceMessage.TextMessage(text="ignored"),
|
||||
meta={"file": mock_file},
|
||||
)
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0].type == DatasourceMessage.MessageType.LINK
|
||||
assert result[0].message.text == "/files/datasources/related_456.txt"
|
||||
|
||||
def test_transform_file_other_transfer_method(self):
|
||||
# Setup
|
||||
mock_file = MagicMock(spec=File)
|
||||
mock_file.transfer_method = FileTransferMethod.REMOTE_URL
|
||||
|
||||
msg = DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.FILE,
|
||||
message=DatasourceMessage.TextMessage(text="remote image"),
|
||||
meta={"file": mock_file},
|
||||
)
|
||||
messages = [msg]
|
||||
|
||||
# Execute
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0] == msg
|
||||
|
||||
def test_transform_other_message_type(self):
|
||||
# JSON type is yielded by the default 'else' block or the 'yield message' at the end
|
||||
msg = DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.JSON, message=DatasourceMessage.JsonMessage(json_object={"k": "v"})
|
||||
)
|
||||
messages = [msg]
|
||||
|
||||
# Execute
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0] == msg
|
||||
|
||||
def test_get_datasource_file_url(self):
|
||||
# Test with extension
|
||||
url = DatasourceFileMessageTransformer.get_datasource_file_url("file1", ".jpg")
|
||||
assert url == "/files/datasources/file1.jpg"
|
||||
|
||||
# Test without extension
|
||||
url = DatasourceFileMessageTransformer.get_datasource_file_url("file2", None)
|
||||
assert url == "/files/datasources/file2.bin"
|
||||
|
||||
def test_transform_blob_message_no_meta_filename(self):
|
||||
# This tests line 70 where filename might be None
|
||||
with patch("core.datasource.utils.message_transformer.ToolFileManager") as mock_tool_file_manager_cls:
|
||||
mock_manager = mock_tool_file_manager_cls.return_value
|
||||
mock_tool_file = MagicMock(spec=ToolFile)
|
||||
mock_tool_file.id = "blob_id_no_name"
|
||||
mock_tool_file.mimetype = "application/octet-stream"
|
||||
mock_manager.create_file_by_raw.return_value = mock_tool_file
|
||||
|
||||
messages = [
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.BLOB,
|
||||
message=DatasourceMessage.BlobMessage(blob=b"data"),
|
||||
meta={}, # No mime_type, no file_name
|
||||
)
|
||||
]
|
||||
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].type == DatasourceMessage.MessageType.BINARY_LINK
|
||||
assert result[0].message.text == "/files/datasources/blob_id_no_name.bin"
|
||||
|
||||
@patch("core.datasource.utils.message_transformer.ToolFileManager")
|
||||
def test_transform_image_message_not_text_message(self, mock_tool_file_manager_cls):
|
||||
# This tests line 24-26 where it checks if message is instance of TextMessage
|
||||
messages = [
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE, message=DatasourceMessage.BlobMessage(blob=b"not-text")
|
||||
)
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify - should yield unchanged if it's not a TextMessage
|
||||
assert len(result) == 1
|
||||
assert result[0].type == DatasourceMessage.MessageType.IMAGE
|
||||
assert isinstance(result[0].message, DatasourceMessage.BlobMessage)
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceProviderType,
|
||||
WebsiteCrawlMessage,
|
||||
)
|
||||
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
|
||||
|
||||
|
||||
class TestWebsiteCrawlDatasourcePlugin:
|
||||
@pytest.fixture
|
||||
def mock_entity(self):
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
entity.identity = MagicMock()
|
||||
entity.identity.provider = "test-provider"
|
||||
entity.identity.name = "test-name"
|
||||
return entity
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runtime(self):
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
runtime.credentials = {"api_key": "test-key"}
|
||||
return runtime
|
||||
|
||||
def test_init(self, mock_entity, mock_runtime):
|
||||
# Arrange
|
||||
tenant_id = "test-tenant-id"
|
||||
icon = "test-icon"
|
||||
plugin_unique_identifier = "test-plugin-id"
|
||||
|
||||
# Act
|
||||
plugin = WebsiteCrawlDatasourcePlugin(
|
||||
entity=mock_entity,
|
||||
runtime=mock_runtime,
|
||||
tenant_id=tenant_id,
|
||||
icon=icon,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert plugin.tenant_id == tenant_id
|
||||
assert plugin.plugin_unique_identifier == plugin_unique_identifier
|
||||
assert plugin.entity == mock_entity
|
||||
assert plugin.runtime == mock_runtime
|
||||
assert plugin.icon == icon
|
||||
|
||||
def test_datasource_provider_type(self, mock_entity, mock_runtime):
|
||||
# Arrange
|
||||
plugin = WebsiteCrawlDatasourcePlugin(
|
||||
entity=mock_entity, runtime=mock_runtime, tenant_id="test", icon="test", plugin_unique_identifier="test"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
assert plugin.datasource_provider_type() == DatasourceProviderType.WEBSITE_CRAWL
|
||||
|
||||
def test_get_website_crawl(self, mock_entity, mock_runtime):
|
||||
# Arrange
|
||||
plugin = WebsiteCrawlDatasourcePlugin(
|
||||
entity=mock_entity,
|
||||
runtime=mock_runtime,
|
||||
tenant_id="test-tenant-id",
|
||||
icon="test-icon",
|
||||
plugin_unique_identifier="test-plugin-id",
|
||||
)
|
||||
|
||||
user_id = "test-user-id"
|
||||
datasource_parameters = {"url": "https://example.com"}
|
||||
provider_type = "firecrawl"
|
||||
|
||||
mock_message = MagicMock(spec=WebsiteCrawlMessage)
|
||||
|
||||
# Mock PluginDatasourceManager
|
||||
with patch("core.datasource.website_crawl.website_crawl_plugin.PluginDatasourceManager") as mock_manager_class:
|
||||
mock_manager = mock_manager_class.return_value
|
||||
mock_manager.get_website_crawl.return_value = (msg for msg in [mock_message])
|
||||
|
||||
# Act
|
||||
result = plugin.get_website_crawl(
|
||||
user_id=user_id, datasource_parameters=datasource_parameters, provider_type=provider_type
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, Generator)
|
||||
messages = list(result)
|
||||
assert len(messages) == 1
|
||||
assert messages[0] == mock_message
|
||||
|
||||
mock_manager.get_website_crawl.assert_called_once_with(
|
||||
tenant_id="test-tenant-id",
|
||||
user_id=user_id,
|
||||
datasource_provider="test-provider",
|
||||
datasource_name="test-name",
|
||||
credentials={"api_key": "test-key"},
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceProviderEntityWithPlugin,
|
||||
DatasourceProviderType,
|
||||
)
|
||||
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
|
||||
|
||||
|
||||
class TestWebsiteCrawlDatasourcePluginProviderController:
|
||||
@pytest.fixture
|
||||
def mock_entity(self):
|
||||
entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
entity.datasources = []
|
||||
entity.identity = MagicMock()
|
||||
entity.identity.icon = "test-icon"
|
||||
return entity
|
||||
|
||||
def test_init(self, mock_entity):
|
||||
# Arrange
|
||||
plugin_id = "test-plugin-id"
|
||||
plugin_unique_identifier = "test-unique-id"
|
||||
tenant_id = "test-tenant-id"
|
||||
|
||||
# Act
|
||||
controller = WebsiteCrawlDatasourcePluginProviderController(
|
||||
entity=mock_entity,
|
||||
plugin_id=plugin_id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert controller.entity == mock_entity
|
||||
assert controller.plugin_id == plugin_id
|
||||
assert controller.plugin_unique_identifier == plugin_unique_identifier
|
||||
assert controller.tenant_id == tenant_id
|
||||
|
||||
def test_provider_type(self, mock_entity):
|
||||
# Arrange
|
||||
controller = WebsiteCrawlDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="test", plugin_unique_identifier="test", tenant_id="test"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
assert controller.provider_type == DatasourceProviderType.WEBSITE_CRAWL
|
||||
|
||||
def test_get_datasource_success(self, mock_entity):
|
||||
# Arrange
|
||||
datasource_name = "test-datasource"
|
||||
tenant_id = "test-tenant-id"
|
||||
plugin_unique_identifier = "test-unique-id"
|
||||
|
||||
mock_datasource_entity = MagicMock()
|
||||
mock_datasource_entity.identity = MagicMock()
|
||||
mock_datasource_entity.identity.name = datasource_name
|
||||
mock_entity.datasources = [mock_datasource_entity]
|
||||
|
||||
controller = WebsiteCrawlDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="test", plugin_unique_identifier=plugin_unique_identifier, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
"core.datasource.website_crawl.website_crawl_provider.WebsiteCrawlDatasourcePlugin"
|
||||
) as mock_plugin_class:
|
||||
mock_plugin_instance = mock_plugin_class.return_value
|
||||
result = controller.get_datasource(datasource_name)
|
||||
|
||||
# Assert
|
||||
assert result == mock_plugin_instance
|
||||
mock_plugin_class.assert_called_once()
|
||||
args, kwargs = mock_plugin_class.call_args
|
||||
assert kwargs["entity"] == mock_datasource_entity
|
||||
assert isinstance(kwargs["runtime"], DatasourceRuntime)
|
||||
assert kwargs["runtime"].tenant_id == tenant_id
|
||||
assert kwargs["tenant_id"] == tenant_id
|
||||
assert kwargs["icon"] == "test-icon"
|
||||
assert kwargs["plugin_unique_identifier"] == plugin_unique_identifier
|
||||
|
||||
def test_get_datasource_not_found(self, mock_entity):
|
||||
# Arrange
|
||||
datasource_name = "non-existent"
|
||||
mock_entity.datasources = []
|
||||
|
||||
controller = WebsiteCrawlDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="test", plugin_unique_identifier="test", tenant_id="test"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match=f"Datasource with name {datasource_name} not found"):
|
||||
controller.get_datasource(datasource_name)
|
||||
Loading…
Reference in New Issue