# Source: dify/api/core/datasource/datasource_manager.py (Python, 364 lines, 16 KiB)

import logging
from collections.abc import Generator
from threading import Lock
from typing import Any, cast
from sqlalchemy import select
import contexts
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.entities.datasource_entities import (
DatasourceMessage,
DatasourceProviderType,
GetOnlineDocumentPageContentRequest,
OnlineDriveDownloadFileRequest,
)
from core.datasource.errors import DatasourceProviderNotFoundError
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
from core.db.session_factory import session_factory
from core.plugin.impl.datasource import PluginDatasourceManager
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
from dify_graph.file import File
from dify_graph.file.enums import FileTransferMethod, FileType
from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from dify_graph.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam
from factories import file_factory
from models.model import UploadFile
from models.tools import ToolFile
from services.datasource_provider_service import DatasourceProviderService
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class DatasourceManager:
@classmethod
def get_datasource_plugin_provider(
cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType
) -> DatasourcePluginProviderController:
"""
get the datasource plugin provider
"""
# check if context is set
try:
contexts.datasource_plugin_providers.get()
except LookupError:
contexts.datasource_plugin_providers.set({})
contexts.datasource_plugin_providers_lock.set(Lock())
with contexts.datasource_plugin_providers_lock.get():
datasource_plugin_providers = contexts.datasource_plugin_providers.get()
if provider_id in datasource_plugin_providers:
return datasource_plugin_providers[provider_id]
manager = PluginDatasourceManager()
provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
if not provider_entity:
raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")
controller: DatasourcePluginProviderController | None = None
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT:
controller = OnlineDocumentDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case DatasourceProviderType.ONLINE_DRIVE:
controller = OnlineDriveDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case DatasourceProviderType.WEBSITE_CRAWL:
controller = WebsiteCrawlDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case DatasourceProviderType.LOCAL_FILE:
controller = LocalFileDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case _:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
if controller:
datasource_plugin_providers[provider_id] = controller
if controller is None:
raise DatasourceProviderNotFoundError(f"Datasource provider {provider_id} not found.")
return controller
@classmethod
def get_datasource_runtime(
cls,
provider_id: str,
datasource_name: str,
tenant_id: str,
datasource_type: DatasourceProviderType,
) -> DatasourcePlugin:
"""
get the datasource runtime
:param provider_type: the type of the provider
:param provider_id: the id of the provider
:param datasource_name: the name of the datasource
:param tenant_id: the tenant id
:return: the datasource plugin
"""
return cls.get_datasource_plugin_provider(
provider_id,
tenant_id,
datasource_type,
).get_datasource(datasource_name)
@classmethod
def get_icon_url(cls, provider_id: str, tenant_id: str, datasource_name: str, datasource_type: str) -> str:
datasource_runtime = cls.get_datasource_runtime(
provider_id=provider_id,
datasource_name=datasource_name,
tenant_id=tenant_id,
datasource_type=DatasourceProviderType.value_of(datasource_type),
)
return datasource_runtime.get_icon_url(tenant_id)
@classmethod
def stream_online_results(
cls,
*,
user_id: str,
datasource_name: str,
datasource_type: str,
provider_id: str,
tenant_id: str,
provider: str,
plugin_id: str,
credential_id: str,
datasource_param: DatasourceParameter | None = None,
online_drive_request: OnlineDriveDownloadFileParam | None = None,
) -> Generator[DatasourceMessage, None, Any]:
"""
Pull-based streaming of domain messages from datasource plugins.
Returns a generator that yields DatasourceMessage and finally returns a minimal final payload.
Only ONLINE_DOCUMENT and ONLINE_DRIVE are streamable here; other types are handled by nodes directly.
"""
ds_type = DatasourceProviderType.value_of(datasource_type)
runtime = cls.get_datasource_runtime(
provider_id=provider_id,
datasource_name=datasource_name,
tenant_id=tenant_id,
datasource_type=ds_type,
)
dsp_service = DatasourceProviderService()
credentials = dsp_service.get_datasource_credentials(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
credential_id=credential_id,
)
if ds_type == DatasourceProviderType.ONLINE_DOCUMENT:
doc_runtime = cast(OnlineDocumentDatasourcePlugin, runtime)
if credentials:
doc_runtime.runtime.credentials = credentials
if datasource_param is None:
raise ValueError("datasource_param is required for ONLINE_DOCUMENT streaming")
inner_gen: Generator[DatasourceMessage, None, None] = doc_runtime.get_online_document_page_content(
user_id=user_id,
datasource_parameters=GetOnlineDocumentPageContentRequest(
workspace_id=datasource_param.workspace_id,
page_id=datasource_param.page_id,
type=datasource_param.type,
),
provider_type=ds_type,
)
elif ds_type == DatasourceProviderType.ONLINE_DRIVE:
drive_runtime = cast(OnlineDriveDatasourcePlugin, runtime)
if credentials:
drive_runtime.runtime.credentials = credentials
if online_drive_request is None:
raise ValueError("online_drive_request is required for ONLINE_DRIVE streaming")
inner_gen = drive_runtime.online_drive_download_file(
user_id=user_id,
request=OnlineDriveDownloadFileRequest(
id=online_drive_request.id,
bucket=online_drive_request.bucket,
),
provider_type=ds_type,
)
else:
raise ValueError(f"Unsupported datasource type for streaming: {ds_type}")
# Bridge through to caller while preserving generator return contract
yield from inner_gen
# No structured final data here; node/adapter will assemble outputs
return {}
@classmethod
def stream_node_events(
cls,
*,
node_id: str,
user_id: str,
datasource_name: str,
datasource_type: str,
provider_id: str,
tenant_id: str,
provider: str,
plugin_id: str,
credential_id: str,
parameters_for_log: dict[str, Any],
datasource_info: dict[str, Any],
variable_pool: Any,
datasource_param: DatasourceParameter | None = None,
online_drive_request: OnlineDriveDownloadFileParam | None = None,
) -> Generator[StreamChunkEvent | StreamCompletedEvent, None, None]:
ds_type = DatasourceProviderType.value_of(datasource_type)
messages = cls.stream_online_results(
user_id=user_id,
datasource_name=datasource_name,
datasource_type=datasource_type,
provider_id=provider_id,
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
credential_id=credential_id,
datasource_param=datasource_param,
online_drive_request=online_drive_request,
)
transformed = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=messages, user_id=user_id, tenant_id=tenant_id, conversation_id=None
)
variables: dict[str, Any] = {}
file_out: File | None = None
for message in transformed:
mtype = message.type
if mtype in {
DatasourceMessage.MessageType.IMAGE_LINK,
DatasourceMessage.MessageType.BINARY_LINK,
DatasourceMessage.MessageType.IMAGE,
}:
wanted_ds_type = ds_type in {
DatasourceProviderType.ONLINE_DRIVE,
DatasourceProviderType.ONLINE_DOCUMENT,
}
if wanted_ds_type and isinstance(message.message, DatasourceMessage.TextMessage):
url = message.message.text
datasource_file_id = str(url).split("/")[-1].split(".")[0]
with session_factory.create_session() as session:
stmt = select(ToolFile).where(
ToolFile.id == datasource_file_id, ToolFile.tenant_id == tenant_id
)
datasource_file = session.scalar(stmt)
if not datasource_file:
raise ValueError(
f"ToolFile not found for file_id={datasource_file_id}, tenant_id={tenant_id}"
)
mime_type = datasource_file.mimetype
if datasource_file is not None:
mapping = {
"tool_file_id": datasource_file_id,
"type": file_factory.get_file_type_by_mime_type(mime_type),
"transfer_method": FileTransferMethod.TOOL_FILE,
"url": url,
}
file_out = file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id)
elif mtype == DatasourceMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceMessage.TextMessage)
yield StreamChunkEvent(selector=[node_id, "text"], chunk=message.message.text, is_final=False)
elif mtype == DatasourceMessage.MessageType.LINK:
assert isinstance(message.message, DatasourceMessage.TextMessage)
yield StreamChunkEvent(
selector=[node_id, "text"], chunk=f"Link: {message.message.text}\n", is_final=False
)
elif mtype == DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
name = message.message.variable_name
value = message.message.variable_value
if message.message.stream:
assert isinstance(value, str), "stream variable_value must be str"
variables[name] = variables.get(name, "") + value
yield StreamChunkEvent(selector=[node_id, name], chunk=value, is_final=False)
else:
variables[name] = value
elif mtype == DatasourceMessage.MessageType.FILE:
if ds_type == DatasourceProviderType.ONLINE_DRIVE and message.meta:
f = message.meta.get("file")
if isinstance(f, File):
file_out = f
else:
pass
yield StreamChunkEvent(selector=[node_id, "text"], chunk="", is_final=True)
if ds_type == DatasourceProviderType.ONLINE_DRIVE and file_out is not None:
variable_pool.add([node_id, "file"], file_out)
if ds_type == DatasourceProviderType.ONLINE_DOCUMENT:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={**variables},
)
)
else:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"file": file_out,
"datasource_type": ds_type,
},
)
)
@classmethod
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
with session_factory.create_session() as session:
upload_file = (
session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first()
)
if not upload_file:
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
file_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=tenant_id,
type=FileType.CUSTOM,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
url=upload_file.source_url,
)
return file_info