# Source: dify/api/tests/unit_tests/controllers/console/datasets/test_datasets.py
# Repository viewer metadata (not part of the module): 1929 lines, 62 KiB, Python.

import datetime
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
import services
from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.datasets import (
DatasetApi,
DatasetApiBaseUrlApi,
DatasetApiDeleteApi,
DatasetApiKeyApi,
DatasetAutoDisableLogApi,
DatasetEnableApiApi,
DatasetErrorDocs,
DatasetIndexingEstimateApi,
DatasetIndexingStatusApi,
DatasetListApi,
DatasetPermissionUserListApi,
DatasetQueryApi,
DatasetRelatedAppListApi,
DatasetRetrievalSettingApi,
DatasetRetrievalSettingMockApi,
DatasetUseCheckApi,
)
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.provider_manager import ProviderManager
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from models.enums import CreatorUserRole
from models.model import ApiToken, UploadFile
from services.dataset_service import DatasetPermissionService, DatasetService
def unwrap(func):
    """Return the innermost callable beneath a stack of decorators.

    Follows the ``__wrapped__`` attribute (as set by ``functools.wraps``)
    until an object without that attribute is reached.

    Args:
        func: Any object; typically a decorated function or bound method.

    Returns:
        The first object in the chain that has no ``__wrapped__`` attribute.
        An object that was never wrapped is returned unchanged.

    Raises:
        ValueError: If the ``__wrapped__`` chain loops back onto an object
            that was already visited (the walk would otherwise never end).
    """
    # Memoise by id() so unhashable objects are tolerated; storing the objects
    # themselves keeps them alive so an id cannot be recycled mid-walk.
    visited = {id(func): func}
    while hasattr(func, "__wrapped__"):
        func = func.__wrapped__
        if id(func) in visited:
            raise ValueError("wrapper loop detected while unwrapping")
        visited[id(func)] = func
    return func
class TestDatasetList:
    """GET tests for ``DatasetListApi`` with the service layer patched out."""

    def _mock_dataset_dict(self, **overrides):
        """Return a marshalled dataset dict with ``overrides`` applied on top."""
        return {
            "id": "ds-1",
            "indexing_technique": "economy",
            "embedding_model": None,
            "embedding_model_provider": None,
            "permission": "only_me",
            **overrides,
        }

    def _mock_user(self):
        """Return a mock account flagged as a dataset editor."""
        return MagicMock(is_dataset_editor=True)

    def _call_get(self, marshaled, *, service_method="get_datasets", total=1, configurations=None):
        """Invoke the undecorated ``get`` handler with its collaborators patched.

        Must be called while a request context is active.

        Args:
            marshaled: Value the patched ``marshal`` returns.
            service_method: Name of the ``DatasetService`` method to patch.
            total: Total count returned alongside the mocked dataset list.
            configurations: Return value for ``ProviderManager.get_configurations``;
                defaults to a mock whose ``get_models`` yields an empty list.

        Returns:
            ``(response, status, service_mock)``.
        """
        api = DatasetListApi()
        handler = unwrap(api.get)
        if configurations is None:
            configurations = MagicMock(get_models=lambda **_: [])

        with (
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(self._mock_user(), "tenant-1"),
            ),
            patch.object(DatasetService, service_method, return_value=([MagicMock()], total)) as service_mock,
            patch("controllers.console.datasets.datasets.marshal", return_value=marshaled),
            patch.object(ProviderManager, "get_configurations", return_value=configurations),
        ):
            response, status = handler(api)

        return response, status, service_mock

    def test_get_success_basic(self, app):
        """A plain listing returns 200, the total, and an available embedding flag."""
        with app.test_request_context("/datasets"):
            resp, status, _ = self._call_get([self._mock_dataset_dict()])

        assert status == 200
        assert resp["total"] == 1
        assert resp["data"][0]["embedding_available"] is True

    def test_get_with_ids_filter(self, app):
        """Passing ``ids`` routes the lookup through ``get_datasets_by_ids``."""
        with app.test_request_context("/datasets?ids=1&ids=2"):
            resp, status, by_ids_mock = self._call_get(
                [self._mock_dataset_dict()],
                service_method="get_datasets_by_ids",
                total=2,
            )

        by_ids_mock.assert_called_once()
        assert status == 200
        assert resp["total"] == 2

    def test_get_with_tag_ids(self, app):
        """A ``tag_ids`` query parameter is accepted and still yields 200."""
        with app.test_request_context("/datasets?tag_ids=tag1"):
            _, status, _ = self._call_get([self._mock_dataset_dict()])

        assert status == 200

    def test_embedding_available_false(self, app):
        """A high-quality dataset is flagged unavailable when no model is listed."""
        high_quality = self._mock_dataset_dict(
            indexing_technique="high_quality",
            embedding_model="text-embed",
            embedding_model_provider="openai",
        )
        # The provider configuration reports no models at all.
        configurations = MagicMock()
        configurations.get_models.return_value = []

        with app.test_request_context("/datasets"):
            resp, _, _ = self._call_get([high_quality], configurations=configurations)

        assert resp["data"][0]["embedding_available"] is False

    def test_partial_members_permission(self, app):
        """Partial-member datasets carry the member ids read from the database."""
        rows = [("ds-1", "u1")]

        with (
            app.test_request_context("/datasets"),
            patch(
                "controllers.console.datasets.datasets.db.session.execute",
                return_value=MagicMock(all=lambda: rows),
            ),
        ):
            resp, _, _ = self._call_get([self._mock_dataset_dict(permission="partial_members")])

        assert resp["data"][0]["partial_member_list"] == ["u1"]
class TestDatasetListApiPost:
    """POST tests for ``DatasetListApi`` (dataset creation)."""

    def test_post_success(self, app):
        """An editor creating a dataset with a valid payload receives 201."""
        api = DatasetListApi()
        handler = unwrap(api.post)

        body = {
            "name": "My Dataset",
            "description": "desc",
            "indexing_technique": "economy",
            "provider": "vendor",
        }
        editor = MagicMock(is_dataset_editor=True)

        # Minimal set of attributes configured explicitly for response marshalling.
        created = MagicMock(
            embedding_available=True,
            built_in_field_enabled=False,
            is_published=False,
            enable_api=False,
            is_multimodal=False,
            documents=[],
            retrieval_model_dict={},
            tags=[],
            external_knowledge_info=None,
            external_retrieval_model=None,
            doc_metadata=[],
            icon_info=None,
            summary_index_setting=MagicMock(enable=False),
        )

        with (
            app.test_request_context("/datasets", json=body),
            patch.object(type(console_ns), "payload", body),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(editor, "tenant-1"),
            ),
            patch.object(DatasetService, "create_empty_dataset", return_value=created),
        ):
            _, status = handler(api)

        assert status == 201

    def test_post_forbidden(self, app):
        """An account that is not a dataset editor is rejected with ``Forbidden``."""
        api = DatasetListApi()
        handler = unwrap(api.post)
        body = {"name": "test"}
        viewer = MagicMock(is_dataset_editor=False)

        with (
            app.test_request_context("/datasets", json=body),
            patch.object(type(console_ns), "payload", body),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(viewer, "tenant-1"),
            ),
            pytest.raises(Forbidden),
        ):
            handler(api)

    def test_post_duplicate_name(self, app):
        """The service-level duplicate-name error surfaces as the controller error."""
        api = DatasetListApi()
        handler = unwrap(api.post)
        body = {"name": "duplicate"}
        editor = MagicMock(is_dataset_editor=True)

        with (
            app.test_request_context("/datasets", json=body),
            patch.object(type(console_ns), "payload", body),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(editor, "tenant-1"),
            ),
            patch.object(
                DatasetService,
                "create_empty_dataset",
                side_effect=services.errors.dataset.DatasetNameDuplicateError(),
            ),
            pytest.raises(DatasetNameDuplicateError),
        ):
            handler(api)

    def test_post_invalid_payload_missing_name(self, app):
        """An empty payload fails with ``ValueError``."""
        api = DatasetListApi()
        handler = unwrap(api.post)
        body = {}

        with (
            app.test_request_context("/datasets", json=body),
            patch.object(type(console_ns), "payload", body),
            pytest.raises(ValueError),
        ):
            handler(api)

    def test_post_invalid_indexing_technique(self, app):
        """An unknown indexing technique is rejected with a descriptive ``ValueError``."""
        api = DatasetListApi()
        handler = unwrap(api.post)
        body = {
            "name": "bad",
            "indexing_technique": "invalid-tech",
        }

        with (
            app.test_request_context("/datasets", json=body),
            patch.object(type(console_ns), "payload", body),
            pytest.raises(ValueError, match="Invalid indexing technique"),
        ):
            handler(api)

    def test_post_invalid_provider(self, app):
        """An unknown provider is rejected with a descriptive ``ValueError``."""
        api = DatasetListApi()
        handler = unwrap(api.post)
        body = {
            "name": "bad",
            "provider": "unknown",
        }

        with (
            app.test_request_context("/datasets", json=body),
            patch.object(type(console_ns), "payload", body),
            pytest.raises(ValueError, match="Invalid provider"),
        ):
            handler(api)
class TestDatasetApiGet:
    """GET tests for ``DatasetApi`` (single dataset detail)."""

    @staticmethod
    def _make_dataset(dataset_id, **overrides):
        """Build a mock dataset with explicit values for the marshalled fields.

        Defaults describe an ``economy`` dataset with ``only_me`` permission;
        ``overrides`` replace or extend those attributes.
        """
        attributes = {
            "id": dataset_id,
            "indexing_technique": "economy",
            "embedding_model_provider": None,
            "permission": "only_me",
            "embedding_available": True,
            "built_in_field_enabled": False,
            "is_published": False,
            "enable_api": False,
            "is_multimodal": False,
            "documents": [],
            "retrieval_model_dict": {},
            "tags": [],
            "external_knowledge_info": None,
            "external_retrieval_model": None,
            "doc_metadata": [],
            "icon_info": None,
            "summary_index_setting": MagicMock(enable=False),
        }
        attributes.update(overrides)
        return MagicMock(**attributes)

    def _call_get(self, app, dataset_id, dataset, *, account, tenant_id):
        """Run the undecorated ``get`` handler for ``dataset`` and return its result.

        The lookup and permission check are patched to succeed, and the
        module-level ``ProviderManager`` is replaced by a mock whose
        configurations report an empty model list.
        """
        api = DatasetApi()
        handler = unwrap(api.get)

        with (
            app.test_request_context(f"/datasets/{dataset_id}"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(account, tenant_id),
            ),
            patch.object(DatasetService, "get_dataset", return_value=dataset),
            patch.object(DatasetService, "check_dataset_permission", return_value=None),
            patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock,
        ):
            provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = []
            return handler(api, dataset_id)

    def test_get_success_basic(self, app):
        """An economy dataset is returned with 200 and stays embedding-available."""
        dataset_id = "123e4567-e89b-12d3-a456-426614174000"
        dataset = self._make_dataset(dataset_id)

        # The provider mock lists no models, yet the economy dataset is still
        # expected to report embedding_available=True.
        data, status = self._call_get(app, dataset_id, dataset, account=MagicMock(), tenant_id="tenant-1")

        assert status == 200
        assert data["embedding_available"] is True

    def test_get_dataset_not_found(self, app):
        """A missing dataset raises ``NotFound``."""
        api = DatasetApi()
        handler = unwrap(api.get)
        dataset_id = "missing-id"

        with (
            app.test_request_context(f"/datasets/{dataset_id}"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(MagicMock(), "tenant"),
            ),
            patch.object(DatasetService, "get_dataset", return_value=None),
            pytest.raises(NotFound, match="Dataset not found"),
        ):
            handler(api, dataset_id)

    def test_get_permission_denied(self, app):
        """A ``NoPermissionError`` from the service is translated to ``Forbidden``."""
        api = DatasetApi()
        handler = unwrap(api.get)
        dataset_id = "dataset-id"

        with (
            app.test_request_context(f"/datasets/{dataset_id}"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(MagicMock(), "tenant"),
            ),
            patch.object(DatasetService, "get_dataset", return_value=MagicMock()),
            patch.object(
                DatasetService,
                "check_dataset_permission",
                side_effect=services.errors.account.NoPermissionError("no access"),
            ),
            pytest.raises(Forbidden, match="no access"),
        ):
            handler(api, dataset_id)

    def test_get_high_quality_embedding_unavailable(self, app):
        """A high-quality dataset whose model is not listed is flagged unavailable."""
        dataset_id = "dataset-id"
        dataset = self._make_dataset(
            dataset_id,
            indexing_technique="high_quality",
            embedding_model="text-embedding",
            embedding_model_provider="openai",
        )

        # The provider mock lists no models, i.e. the embedding model is not configured.
        data, _ = self._call_get(app, dataset_id, dataset, account=MagicMock(), tenant_id="tenant-1")

        assert data["embedding_available"] is False

    def test_get_partial_members_permission(self, app):
        """Partial-member datasets include the member list from the permission service."""
        dataset_id = "dataset-id"
        dataset = self._make_dataset(dataset_id, permission="partial_members")
        partial_members = [{"id": "u1"}, {"id": "u2"}]

        with patch.object(
            DatasetPermissionService,
            "get_dataset_partial_member_list",
            return_value=partial_members,
        ):
            data, _ = self._call_get(app, dataset_id, dataset, account=MagicMock(), tenant_id="tenant")

        assert data["partial_member_list"] == partial_members
class TestDatasetApiPatch:
    """PATCH tests for ``DatasetApi`` (dataset update)."""

    @staticmethod
    def _make_dataset(dataset_id, **overrides):
        """Build a mock dataset with explicit values for the marshalled fields.

        Defaults describe an ``economy`` dataset with ``only_me`` permission;
        ``overrides`` replace or extend those attributes.
        """
        attributes = {
            "id": dataset_id,
            "permission": "only_me",
            "indexing_technique": "economy",
            "embedding_model_provider": None,
            "embedding_available": True,
            "built_in_field_enabled": False,
            "is_published": False,
            "enable_api": False,
            "is_multimodal": False,
            "documents": [],
            "retrieval_model_dict": {},
            "tags": [],
            "external_knowledge_info": None,
            "external_retrieval_model": None,
            "doc_metadata": [],
            "icon_info": None,
            "summary_index_setting": MagicMock(enable=False),
        }
        attributes.update(overrides)
        return MagicMock(**attributes)

    def _call_patch(self, app, dataset_id, payload, dataset, *, account, tenant_id, members):
        """Run the undecorated ``patch`` handler on its success path.

        The dataset lookup, the permission check and ``update_dataset`` are
        patched to succeed; ``members`` is what the patched
        ``get_dataset_partial_member_list`` returns.
        """
        api = DatasetApi()
        handler = unwrap(api.patch)

        with (
            app.test_request_context(f"/datasets/{dataset_id}"),
            patch.object(type(console_ns), "payload", payload),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(account, tenant_id),
            ),
            patch.object(DatasetService, "get_dataset", return_value=dataset),
            patch.object(DatasetPermissionService, "check_permission", return_value=None),
            patch.object(DatasetService, "update_dataset", return_value=dataset),
            patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=members),
        ):
            return handler(api, dataset_id)

    def test_patch_success_basic(self, app):
        """Updating name and description returns 200 with an empty member list."""
        dataset_id = "dataset-id"
        tenant_id = "tenant-1"
        changes = {
            "name": "updated-name",
            "description": "updated description",
        }
        dataset = self._make_dataset(dataset_id, tenant_id=tenant_id)

        result, status = self._call_patch(
            app,
            dataset_id,
            changes,
            dataset,
            account=MagicMock(),
            tenant_id=tenant_id,
            members=[],
        )

        assert status == 200
        assert result["partial_member_list"] == []

    def test_patch_dataset_not_found(self, app):
        """Patching a dataset that does not exist raises ``NotFound``."""
        api = DatasetApi()
        handler = unwrap(api.patch)

        with (
            app.test_request_context("/datasets/missing"),
            patch.object(DatasetService, "get_dataset", return_value=None),
            pytest.raises(NotFound, match="Dataset not found"),
        ):
            handler(api, "missing")

    def test_patch_permission_denied(self, app):
        """A ``Forbidden`` raised by the permission check propagates to the caller."""
        api = DatasetApi()
        handler = unwrap(api.patch)
        dataset_id = "dataset-id"
        changes = {"name": "x"}

        with (
            app.test_request_context(f"/datasets/{dataset_id}"),
            patch.object(type(console_ns), "payload", changes),
            patch.object(DatasetService, "get_dataset", return_value=MagicMock()),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(MagicMock(), "tenant"),
            ),
            patch.object(DatasetPermissionService, "check_permission", side_effect=Forbidden("no permission")),
            pytest.raises(Forbidden),
        ):
            handler(api, dataset_id)

    def test_patch_partial_members_update(self, app):
        """Switching to ``partial_members`` echoes the submitted member list."""
        dataset_id = "dataset-id"
        changes = {
            "permission": "partial_members",
            "partial_member_list": [{"id": "u1"}, {"id": "u2"}],
        }
        dataset = self._make_dataset(dataset_id, permission="partial_members")

        with patch.object(DatasetPermissionService, "update_partial_member_list", return_value=None):
            result, _ = self._call_patch(
                app,
                dataset_id,
                changes,
                dataset,
                account=MagicMock(),
                tenant_id="tenant",
                members=changes["partial_member_list"],
            )

        assert result["partial_member_list"] == changes["partial_member_list"]

    def test_patch_clear_partial_members(self, app):
        """Switching to ``only_me`` yields an empty member list in the response."""
        dataset_id = "dataset-id"
        changes = {
            "permission": "only_me",
        }
        dataset = self._make_dataset(dataset_id)

        with patch.object(DatasetPermissionService, "clear_partial_member_list", return_value=None):
            result, _ = self._call_patch(
                app,
                dataset_id,
                changes,
                dataset,
                account=MagicMock(),
                tenant_id="tenant",
                members=[],
            )

        assert result["partial_member_list"] == []
class TestDatasetApiDelete:
    """DELETE tests for ``DatasetApi``."""

    # Patch target for the account/tenant resolver used by the controller.
    _ACCOUNT_TARGET = "controllers.console.datasets.datasets.current_account_with_tenant"

    @staticmethod
    def _account(*, can_edit):
        """Return a mock account that is not a dataset operator.

        ``can_edit`` sets the account's ``has_edit_permission`` flag.
        """
        return MagicMock(has_edit_permission=can_edit, is_dataset_operator=False)

    def test_delete_success(self, app):
        """A successful delete returns 204 together with a success payload."""
        api = DatasetApi()
        handler = unwrap(api.delete)
        dataset_id = "dataset-id"

        with (
            app.test_request_context(f"/datasets/{dataset_id}"),
            patch(self._ACCOUNT_TARGET, return_value=(self._account(can_edit=True), "tenant")),
            patch.object(DatasetService, "delete_dataset", return_value=True),
            patch.object(DatasetPermissionService, "clear_partial_member_list", return_value=None),
        ):
            result, status = handler(api, dataset_id)

        assert status == 204
        assert result == {"result": "success"}

    def test_delete_forbidden_no_permission(self, app):
        """An account without edit permission is rejected with ``Forbidden``."""
        api = DatasetApi()
        handler = unwrap(api.delete)
        dataset_id = "dataset-id"

        with (
            app.test_request_context(f"/datasets/{dataset_id}"),
            patch(self._ACCOUNT_TARGET, return_value=(self._account(can_edit=False), "tenant")),
            pytest.raises(Forbidden),
        ):
            handler(api, dataset_id)

    def test_delete_dataset_not_found(self, app):
        """A falsy result from ``delete_dataset`` is reported as ``NotFound``."""
        api = DatasetApi()
        handler = unwrap(api.delete)
        dataset_id = "missing-dataset"

        with (
            app.test_request_context(f"/datasets/{dataset_id}"),
            patch(self._ACCOUNT_TARGET, return_value=(self._account(can_edit=True), "tenant")),
            patch.object(DatasetService, "delete_dataset", return_value=False),
            pytest.raises(NotFound, match="Dataset not found"),
        ):
            handler(api, dataset_id)

    def test_delete_dataset_in_use(self, app):
        """The service-level in-use error surfaces as the controller error."""
        api = DatasetApi()
        handler = unwrap(api.delete)
        dataset_id = "dataset-id"

        with (
            app.test_request_context(f"/datasets/{dataset_id}"),
            patch(self._ACCOUNT_TARGET, return_value=(self._account(can_edit=True), "tenant")),
            patch.object(
                DatasetService,
                "delete_dataset",
                side_effect=services.errors.dataset.DatasetInUseError(),
            ),
            pytest.raises(DatasetInUseError),
        ):
            handler(api, dataset_id)
class TestDatasetUseCheckApi:
    """GET tests for ``DatasetUseCheckApi``."""

    def _call_get(self, app, *, in_use):
        """Run the undecorated ``get`` handler with ``dataset_use_check`` patched.

        ``in_use`` is the value the patched service method returns.
        """
        api = DatasetUseCheckApi()
        handler = unwrap(api.get)
        dataset_id = "dataset-id"

        with (
            app.test_request_context(f"/datasets/{dataset_id}/use-check"),
            patch.object(DatasetService, "dataset_use_check", return_value=in_use),
        ):
            return handler(api, dataset_id)

    def test_get_use_check_true(self, app):
        """A dataset that is in use is reported with ``is_using`` set to True."""
        result, status = self._call_get(app, in_use=True)

        assert status == 200
        assert result == {"is_using": True}

    def test_get_use_check_false(self, app):
        """A dataset that is not in use is reported with ``is_using`` set to False."""
        result, status = self._call_get(app, in_use=False)

        assert status == 200
        assert result == {"is_using": False}
class TestDatasetQueryApi:
    """GET tests for ``DatasetQueryApi`` (query history with pagination)."""

    def _call_get(self, app, queries, total):
        """Run the undecorated ``get`` handler for page 1 with a limit of 20.

        The dataset lookup and permission check are patched to succeed and
        ``get_dataset_queries`` returns ``(queries, total)``.
        """
        api = DatasetQueryApi()
        handler = unwrap(api.get)
        dataset_id = "dataset-id"
        dataset = MagicMock(id=dataset_id)

        with (
            app.test_request_context("/datasets/queries?page=1&limit=20"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(MagicMock(), "tenant-1"),
            ),
            patch.object(DatasetService, "get_dataset", return_value=dataset),
            patch.object(DatasetService, "check_dataset_permission", return_value=None),
            patch.object(DatasetService, "get_dataset_queries", return_value=(queries, total)),
        ):
            return handler(api, dataset_id)

    def test_get_queries_success(self, app):
        """Two queries on one page are returned with matching paging metadata."""
        response, status = self._call_get(app, [MagicMock(), MagicMock()], 2)

        assert status == 200
        assert response["total"] == 2
        assert response["page"] == 1
        assert response["limit"] == 20
        assert response["has_more"] is False
        assert len(response["data"]) == 2

    def test_get_queries_dataset_not_found(self, app):
        """A missing dataset raises ``NotFound``."""
        api = DatasetQueryApi()
        handler = unwrap(api.get)

        with (
            app.test_request_context("/datasets/queries"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(MagicMock(), "tenant-1"),
            ),
            patch.object(DatasetService, "get_dataset", return_value=None),
            pytest.raises(NotFound, match="Dataset not found"),
        ):
            handler(api, "dataset-id")

    def test_get_queries_permission_denied(self, app):
        """A ``NoPermissionError`` from the service is translated to ``Forbidden``."""
        api = DatasetQueryApi()
        handler = unwrap(api.get)

        with (
            app.test_request_context("/datasets/queries"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(MagicMock(), "tenant-1"),
            ),
            patch.object(DatasetService, "get_dataset", return_value=MagicMock()),
            patch.object(
                DatasetService,
                "check_dataset_permission",
                side_effect=services.errors.account.NoPermissionError("no access"),
            ),
            pytest.raises(Forbidden),
        ):
            handler(api, "dataset-id")

    def test_get_queries_pagination_has_more(self, app):
        """A full page of 20 out of 40 total queries sets ``has_more``."""
        full_page = [MagicMock() for _ in range(20)]

        response, status = self._call_get(app, full_page, 40)

        assert status == 200
        assert response["has_more"] is True
        assert len(response["data"]) == 20
class TestDatasetIndexingEstimateApi:
    """POST tests for ``DatasetIndexingEstimateApi``."""

    def _upload_file(self, *, tenant_id: str = "tenant-1", file_id: str = "file-1") -> UploadFile:
        """Build an ``UploadFile`` for a small local text file with the given ids."""
        record = UploadFile(
            tenant_id=tenant_id,
            storage_type=StorageType.LOCAL,
            key="key",
            name="name.txt",
            size=1,
            extension="txt",
            mime_type="text/plain",
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by="user-1",
            created_at=datetime.datetime.now(tz=datetime.UTC),
            used=False,
        )
        record.id = file_id
        return record

    def _base_payload(self):
        """Return an estimate request that references the upload file ``file-1``."""
        file_info = {"file_ids": ["file-1"]}
        return {
            "info_list": {
                "data_source_type": "upload_file",
                "file_info_list": file_info,
            },
            "process_rule": {"chunk_size": 100},
            "indexing_technique": "high_quality",
            "doc_form": IndexStructureType.PARAGRAPH_INDEX,
            "doc_language": "English",
            "dataset_id": None,
        }

    def _call_post(self, app, **estimate_patch):
        """Run the undecorated ``post`` handler with one upload file available.

        ``estimate_patch`` is forwarded to the patch of
        ``IndexingRunner.indexing_estimate`` (for example ``return_value=...``
        or ``side_effect=...``).
        """
        api = DatasetIndexingEstimateApi()
        handler = unwrap(api.post)
        payload = self._base_payload()
        stored_file = self._upload_file()

        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(MagicMock(), "tenant-1"),
            ),
            patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
            patch(
                "controllers.console.datasets.datasets.DocumentService.estimate_args_validate",
                return_value=None,
            ),
            patch(
                "controllers.console.datasets.datasets.db.session.scalars",
                return_value=MagicMock(all=lambda: [stored_file]),
            ),
            patch(
                "controllers.console.datasets.datasets.IndexingRunner.indexing_estimate",
                **estimate_patch,
            ),
        ):
            return handler(api)

    def test_post_success_upload_file(self, app):
        """The estimate's ``model_dump`` output is returned with status 200."""
        estimate = MagicMock()
        estimate.model_dump.return_value = {"tokens": 100}

        response, status = self._call_post(app, return_value=estimate)

        assert status == 200
        assert response == {"tokens": 100}

    def test_post_file_not_found(self, app):
        """``NotFound`` is raised when the file query yields nothing."""
        api = DatasetIndexingEstimateApi()
        handler = unwrap(api.post)
        payload = self._base_payload()

        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(MagicMock(), "tenant-1"),
            ),
            patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
            patch(
                "controllers.console.datasets.datasets.DocumentService.estimate_args_validate",
                return_value=None,
            ),
            patch(
                "controllers.console.datasets.datasets.db.session.scalars",
                return_value=MagicMock(all=lambda: None),
            ),
            pytest.raises(NotFound),
        ):
            handler(api)

    def test_post_llm_bad_request_error(self, app):
        """``LLMBadRequestError`` is reported as ``ProviderNotInitializeError``."""
        with pytest.raises(ProviderNotInitializeError):
            self._call_post(app, side_effect=LLMBadRequestError())

    def test_post_provider_token_not_init(self, app):
        """``ProviderTokenNotInitError`` is reported as ``ProviderNotInitializeError``."""
        with pytest.raises(ProviderNotInitializeError):
            self._call_post(app, side_effect=ProviderTokenNotInitError("token missing"))

    def test_post_generic_exception(self, app):
        """Any other exception is reported as ``IndexingEstimateError``."""
        with pytest.raises(IndexingEstimateError):
            self._call_post(app, side_effect=Exception("boom"))
class TestDatasetRelatedAppListApi:
    """GET tests for ``DatasetRelatedAppListApi``."""

    def _call_get(self, app, joins):
        """Run the undecorated ``get`` handler with ``joins`` as the related-app rows.

        The dataset lookup and the permission check are patched to succeed.
        """
        api = DatasetRelatedAppListApi()
        handler = unwrap(api.get)
        dataset = MagicMock(id="dataset-1")

        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(MagicMock(), "tenant-1"),
            ),
            patch(
                "controllers.console.datasets.datasets.DatasetService.get_dataset",
                return_value=dataset,
            ),
            patch(
                "controllers.console.datasets.datasets.DatasetService.check_dataset_permission",
                return_value=None,
            ),
            patch(
                "controllers.console.datasets.datasets.DatasetService.get_related_apps",
                return_value=joins,
            ),
        ):
            return handler(api, "dataset-1")

    def test_get_success(self, app):
        """Every join row with an app contributes that app to the response."""
        first_app = MagicMock()
        second_app = MagicMock()
        joins = [MagicMock(app=first_app), MagicMock(app=second_app)]

        response, status = self._call_get(app, joins)

        assert status == 200
        assert response["total"] == 2
        assert response["data"] == [first_app, second_app]

    def test_get_dataset_not_found(self, app):
        """A missing dataset raises ``NotFound``."""
        api = DatasetRelatedAppListApi()
        handler = unwrap(api.get)

        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(MagicMock(), "tenant-1"),
            ),
            patch(
                "controllers.console.datasets.datasets.DatasetService.get_dataset",
                return_value=None,
            ),
            pytest.raises(NotFound),
        ):
            handler(api, "dataset-1")

    def test_get_permission_denied(self, app):
        """A ``NoPermissionError`` from the service is translated to ``Forbidden``."""
        api = DatasetRelatedAppListApi()
        handler = unwrap(api.get)

        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(MagicMock(), "tenant-1"),
            ),
            patch(
                "controllers.console.datasets.datasets.DatasetService.get_dataset",
                return_value=MagicMock(),
            ),
            patch(
                "controllers.console.datasets.datasets.DatasetService.check_dataset_permission",
                side_effect=services.errors.account.NoPermissionError("no permission"),
            ),
            pytest.raises(Forbidden),
        ):
            handler(api, "dataset-1")

    def test_get_filters_none_apps(self, app):
        """Join rows whose ``app`` is ``None`` are left out of the response."""
        only_app = MagicMock()
        joins = [MagicMock(app=only_app), MagicMock(app=None)]

        response, status = self._call_get(app, joins)

        assert status == 200
        assert response["total"] == 1
        assert response["data"] == [only_app]
class TestDatasetIndexingStatusApi:
    """Tests for the unwrapped ``DatasetIndexingStatusApi.get`` handler."""

    @staticmethod
    def _make_document(indexing_status):
        """Build a mock document with id ``"doc-1"``; every ``*_at`` field and ``error`` are ``None``."""
        return MagicMock(
            id="doc-1",
            indexing_status=indexing_status,
            processing_started_at=None,
            parsing_completed_at=None,
            cleaning_completed_at=None,
            splitting_completed_at=None,
            completed_at=None,
            paused_at=None,
            error=None,
            stopped_at=None,
        )

    @staticmethod
    def _patch_tenant():
        """Patch ``current_account_with_tenant`` to return a mock account and ``"tenant-1"``."""
        return patch(
            "controllers.console.datasets.datasets.current_account_with_tenant",
            return_value=(MagicMock(), "tenant-1"),
        )

    @staticmethod
    def _patch_scalars(documents):
        """Patch ``db.session.scalars`` so that ``.all()`` yields a fresh copy of ``documents``."""
        return patch(
            "controllers.console.datasets.datasets.db.session.scalars",
            return_value=MagicMock(all=lambda: list(documents)),
        )

    def test_get_success_with_documents(self, app):
        """One document is reported, with both segment counts taken from the stubbed query."""
        resource = DatasetIndexingStatusApi()
        handler = unwrap(resource.get)
        doc = self._make_document("completed")
        # Every ``where(...)`` call yields an object whose ``count()`` is 3.
        segment_query = MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3))

        with (
            app.test_request_context("/"),
            self._patch_tenant(),
            self._patch_scalars([doc]),
            patch(
                "controllers.console.datasets.datasets.db.session.query",
                return_value=segment_query,
            ),
        ):
            body, status_code = handler(resource, "dataset-1")

        assert status_code == 200
        assert "data" in body
        assert len(body["data"]) == 1
        entry = body["data"][0]
        assert entry["completed_segments"] == 3
        assert entry["total_segments"] == 3

    def test_get_success_no_documents(self, app):
        """An empty ``data`` list is returned when the scalars query yields no documents."""
        resource = DatasetIndexingStatusApi()
        handler = unwrap(resource.get)

        with (
            app.test_request_context("/"),
            self._patch_tenant(),
            self._patch_scalars([]),
        ):
            body, status_code = handler(resource, "dataset-1")

        assert status_code == 200
        assert body == {"data": []}

    def test_segment_counts_different_values(self, app):
        """Completed and total segment counts are reported independently."""
        resource = DatasetIndexingStatusApi()
        handler = unwrap(resource.get)
        doc = self._make_document("indexing")

        # Successive ``where`` calls: the first serves the completed count,
        # the second serves the total count.
        segment_query = MagicMock()
        segment_query.where.side_effect = [
            MagicMock(count=lambda: 2),
            MagicMock(count=lambda: 5),
        ]

        with (
            app.test_request_context("/"),
            self._patch_tenant(),
            self._patch_scalars([doc]),
            patch(
                "controllers.console.datasets.datasets.db.session.query",
                return_value=segment_query,
            ),
        ):
            body, status_code = handler(resource, "dataset-1")

        assert status_code == 200
        entry = body["data"][0]
        assert entry["completed_segments"] == 2
        assert entry["total_segments"] == 5
class TestDatasetApiKeyApi:
    """Tests for the unwrapped ``DatasetApiKeyApi`` ``get`` and ``post`` handlers."""

    @staticmethod
    def _patch_tenant():
        """Patch ``current_account_with_tenant`` to return a mock account and ``"tenant-1"``."""
        return patch(
            "controllers.console.datasets.datasets.current_account_with_tenant",
            return_value=(MagicMock(), "tenant-1"),
        )

    @staticmethod
    def _patch_key_count(existing):
        """Patch ``db.session.query`` so that ``.where(...).count()`` returns ``existing``."""
        return patch(
            "controllers.console.datasets.datasets.db.session.query",
            return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: existing)),
        )

    def test_get_api_keys_success(self, app):
        """The keys yielded by the scalars query are returned under ``items``."""
        resource = DatasetApiKeyApi()
        handler = unwrap(resource.get)
        first_key = MagicMock(spec=ApiToken)
        second_key = MagicMock(spec=ApiToken)

        with (
            app.test_request_context("/"),
            self._patch_tenant(),
            patch(
                "controllers.console.datasets.datasets.db.session.scalars",
                return_value=MagicMock(all=lambda: [first_key, second_key]),
            ),
        ):
            body = handler(resource)

        assert "items" in body
        assert body["items"] == [first_key, second_key]

    def test_post_create_api_key_success(self, app):
        """With three existing keys, a new ``dataset`` token is created and returned."""
        resource = DatasetApiKeyApi()
        handler = unwrap(resource.post)

        with (
            app.test_request_context("/"),
            self._patch_tenant(),
            self._patch_key_count(3),
            patch(
                "controllers.console.datasets.datasets.ApiToken.generate_api_key",
                return_value="dataset-abc123",
            ),
            patch(
                "controllers.console.datasets.datasets.db.session.add",
                return_value=None,
            ),
            patch(
                "controllers.console.datasets.datasets.db.session.commit",
                return_value=None,
            ),
        ):
            created, status_code = handler(resource)

        assert status_code == 200
        assert isinstance(created, ApiToken)
        assert created.token == "dataset-abc123"
        assert created.type == "dataset"

    def test_post_exceed_max_keys(self, app):
        """With ten existing keys, ``post`` aborts with a 400 ``max_keys_exceeded`` error."""
        resource = DatasetApiKeyApi()
        handler = unwrap(resource.post)

        with (
            app.test_request_context("/"),
            self._patch_tenant(),
            self._patch_key_count(10),
            pytest.raises(BadRequest) as exc_info,
        ):
            handler(resource)

        # The captured exception is inspected only after the raising block has exited.
        error = exc_info.value
        assert error.code == 400
        assert error.data == {
            "message": "Cannot create more than 10 API keys for this resource type.",
            "custom": "max_keys_exceeded",
        }
class TestDatasetApiDeleteApi:
    """Tests for the unwrapped ``DatasetApiDeleteApi.delete`` handler."""

    @staticmethod
    def _patch_tenant():
        """Patch ``current_account_with_tenant`` to return a mock account and ``"tenant-1"``."""
        return patch(
            "controllers.console.datasets.datasets.current_account_with_tenant",
            return_value=(MagicMock(), "tenant-1"),
        )

    @staticmethod
    def _patch_key_lookup(found):
        """Patch ``db.session.query`` so that ``.where(...).first()`` returns ``found``."""
        return patch(
            "controllers.console.datasets.datasets.db.session.query",
            return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: found)),
        )

    def test_delete_success(self, app):
        """Deleting an existing key returns a success payload with status 204."""
        resource = DatasetApiDeleteApi()
        handler = unwrap(resource.delete)
        existing_key = MagicMock()

        with (
            app.test_request_context("/"),
            self._patch_tenant(),
            self._patch_key_lookup(existing_key),
            patch(
                "controllers.console.datasets.datasets.db.session.commit",
                return_value=None,
            ),
            patch(
                "controllers.console.datasets.datasets.db.session.delete",
                return_value=None,
            ),
        ):
            body, status_code = handler(resource, "api-key-id")

        assert status_code == 204
        assert body["result"] == "success"

    def test_delete_key_not_found(self, app):
        """``NotFound`` is raised when the key lookup returns ``None``."""
        resource = DatasetApiDeleteApi()
        handler = unwrap(resource.delete)

        with (
            app.test_request_context("/"),
            self._patch_tenant(),
            self._patch_key_lookup(None),
            pytest.raises(NotFound),
        ):
            handler(resource, "api-key-id")
class TestDatasetEnableApiApi:
    """Tests for the unwrapped ``DatasetEnableApiApi.post`` handler."""

    def _post_status(self, app, status):
        """Call the handler for ``"dataset-1"`` with the given ``status`` path segment.

        Args:
            app: The Flask application fixture that provides the request context.
            status: The status string passed to the handler (``"enable"`` or ``"disable"``).

        Returns:
            A ``(result, update_status)`` pair: the handler's return value and the
            mock that replaced ``DatasetService.update_dataset_api_status``.
        """
        resource = DatasetEnableApiApi()
        handler = unwrap(resource.post)
        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.datasets.DatasetService.update_dataset_api_status",
                return_value=None,
            ) as update_status,
        ):
            result = handler(resource, "dataset-1", status)
        return result, update_status

    def test_enable_api(self, app):
        """Enabling returns a success payload and delegates to the dataset service."""
        (response, status), update_status = self._post_status(app, "enable")

        assert status == 200
        assert response["result"] == "success"
        # Without this check the test passes even if the handler never reaches the service.
        update_status.assert_called_once()

    def test_disable_api(self, app):
        """Disabling returns a success payload and delegates to the dataset service."""
        (response, status), update_status = self._post_status(app, "disable")

        assert status == 200
        assert response["result"] == "success"
        # Without this check the test passes even if the handler never reaches the service.
        update_status.assert_called_once()
class TestDatasetApiBaseUrlApi:
    """Tests for the unwrapped ``DatasetApiBaseUrlApi.get`` handler."""

    def test_get_api_base_url_from_config(self, app):
        """A configured ``SERVICE_API_URL`` is returned with ``/v1`` appended."""
        resource = DatasetApiBaseUrlApi()
        handler = unwrap(resource.get)

        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.datasets.dify_config.SERVICE_API_URL",
                "https://example.com",
            ),
        ):
            body = handler(resource)

        assert body["api_base_url"] == "https://example.com/v1"

    def test_get_api_base_url_from_request(self, app):
        """With ``SERVICE_API_URL`` set to ``None``, the base URL comes from the request."""
        resource = DatasetApiBaseUrlApi()
        handler = unwrap(resource.get)

        with (
            app.test_request_context("http://localhost:5000/"),
            patch(
                "controllers.console.datasets.datasets.dify_config.SERVICE_API_URL",
                None,
            ),
        ):
            body = handler(resource)

        assert body["api_base_url"] == "http://localhost:5000/v1"
class TestDatasetRetrievalSettingApi:
    """Tests for the unwrapped ``DatasetRetrievalSettingApi.get`` handler."""

    def test_get_success(self, app):
        """The response carries the retrieval methods produced by the patched helper.

        ``dify_config.VECTOR_STORE`` is patched to ``"qdrant"`` and
        ``_get_retrieval_methods_by_vector_type`` is stubbed to return a fixed
        mapping, which the handler is expected to surface.
        """
        resource = DatasetRetrievalSettingApi()
        handler = unwrap(resource.get)

        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.datasets.dify_config.VECTOR_STORE",
                "qdrant",
            ),
            patch(
                "controllers.console.datasets.datasets._get_retrieval_methods_by_vector_type",
                return_value={"retrieval_method": ["semantic", "hybrid"]},
            ),
        ):
            response = handler(resource)

        assert "retrieval_method" in response
        # Pin the value as well as the key, matching the mock-API sibling test;
        # a key-presence check alone would accept any payload under that key.
        assert response["retrieval_method"] == ["semantic", "hybrid"]
class TestDatasetRetrievalSettingMockApi:
    """Tests for the unwrapped ``DatasetRetrievalSettingMockApi.get`` handler."""

    def test_get_success(self, app):
        """For vector type ``"milvus"`` the stubbed helper's retrieval methods are returned."""
        resource = DatasetRetrievalSettingMockApi()
        handler = unwrap(resource.get)
        helper_patch = patch(
            "controllers.console.datasets.datasets._get_retrieval_methods_by_vector_type",
            return_value={"retrieval_method": ["semantic"]},
        )

        with app.test_request_context("/"), helper_patch:
            body = handler(resource, "milvus")

        assert body["retrieval_method"] == ["semantic"]
class TestDatasetErrorDocs:
    """Tests for the unwrapped ``DatasetErrorDocs.get`` handler."""

    @staticmethod
    def _patch_get_dataset(result):
        """Patch ``DatasetService.get_dataset`` so that it returns ``result``."""
        return patch(
            "controllers.console.datasets.datasets.DatasetService.get_dataset",
            return_value=result,
        )

    def test_get_success(self, app):
        """A single error document yields status 200 and a ``total`` of 1."""
        resource = DatasetErrorDocs()
        handler = unwrap(resource.get)
        found_dataset = MagicMock()
        failed_doc = MagicMock()

        with (
            app.test_request_context("/"),
            self._patch_get_dataset(found_dataset),
            patch(
                "controllers.console.datasets.datasets.DocumentService.get_error_documents_by_dataset_id",
                return_value=[failed_doc],
            ),
        ):
            body, status_code = handler(resource, "dataset-1")

        assert status_code == 200
        assert body["total"] == 1

    def test_get_dataset_not_found(self, app):
        """``NotFound`` is raised when ``DatasetService.get_dataset`` returns ``None``."""
        resource = DatasetErrorDocs()
        handler = unwrap(resource.get)

        with (
            app.test_request_context("/"),
            self._patch_get_dataset(None),
            pytest.raises(NotFound),
        ):
            handler(resource, "dataset-1")
class TestDatasetPermissionUserListApi:
    """Tests for the unwrapped ``DatasetPermissionUserListApi.get`` handler."""

    @staticmethod
    def _patch_tenant():
        """Patch ``current_account_with_tenant`` to return a mock account and ``"tenant-1"``."""
        return patch(
            "controllers.console.datasets.datasets.current_account_with_tenant",
            return_value=(MagicMock(), "tenant-1"),
        )

    @staticmethod
    def _patch_get_dataset(result):
        """Patch ``DatasetService.get_dataset`` so that it returns ``result``."""
        return patch(
            "controllers.console.datasets.datasets.DatasetService.get_dataset",
            return_value=result,
        )

    @staticmethod
    def _patch_permission(**mock_kwargs):
        """Patch ``DatasetService.check_dataset_permission``; kwargs configure the replacement mock."""
        return patch(
            "controllers.console.datasets.datasets.DatasetService.check_dataset_permission",
            **mock_kwargs,
        )

    def test_get_success(self, app):
        """The partial member list from the permission service is returned under ``data``."""
        resource = DatasetPermissionUserListApi()
        handler = unwrap(resource.get)
        found_dataset = MagicMock()
        members = [{"id": "u1"}, {"id": "u2"}]

        with (
            app.test_request_context("/"),
            self._patch_tenant(),
            self._patch_get_dataset(found_dataset),
            self._patch_permission(return_value=None),
            patch(
                "controllers.console.datasets.datasets.DatasetPermissionService.get_dataset_partial_member_list",
                return_value=members,
            ),
        ):
            body, status_code = handler(resource, "dataset-1")

        assert status_code == 200
        assert body["data"] == members

    def test_get_permission_denied(self, app):
        """``Forbidden`` is raised when the permission check raises ``NoPermissionError``."""
        resource = DatasetPermissionUserListApi()
        handler = unwrap(resource.get)
        found_dataset = MagicMock()

        with (
            app.test_request_context("/"),
            self._patch_tenant(),
            self._patch_get_dataset(found_dataset),
            self._patch_permission(
                side_effect=services.errors.account.NoPermissionError("no permission"),
            ),
            pytest.raises(Forbidden),
        ):
            handler(resource, "dataset-1")
class TestDatasetAutoDisableLogApi:
    """Tests for the unwrapped ``DatasetAutoDisableLogApi.get`` handler."""

    @staticmethod
    def _patch_get_dataset(result):
        """Patch ``DatasetService.get_dataset`` so that it returns ``result``."""
        return patch(
            "controllers.console.datasets.datasets.DatasetService.get_dataset",
            return_value=result,
        )

    def test_get_success(self, app):
        """The service's auto-disable logs are returned unchanged with status 200."""
        resource = DatasetAutoDisableLogApi()
        handler = unwrap(resource.get)
        found_dataset = MagicMock()
        disable_logs = [{"reason": "quota"}]

        with (
            app.test_request_context("/"),
            self._patch_get_dataset(found_dataset),
            patch(
                "controllers.console.datasets.datasets.DatasetService.get_dataset_auto_disable_logs",
                return_value=disable_logs,
            ),
        ):
            body, status_code = handler(resource, "dataset-1")

        assert status_code == 200
        assert body == disable_logs

    def test_get_dataset_not_found(self, app):
        """``NotFound`` is raised when ``DatasetService.get_dataset`` returns ``None``."""
        resource = DatasetAutoDisableLogApi()
        handler = unwrap(resource.get)

        with (
            app.test_request_context("/"),
            self._patch_get_dataset(None),
            pytest.raises(NotFound),
        ):
            handler(resource, "dataset-1")