diff --git a/api/commands/vector.py b/api/commands/vector.py index bef18bf73b..cb7eb7c452 100644 --- a/api/commands/vector.py +++ b/api/commands/vector.py @@ -10,7 +10,7 @@ from configs import dify_config from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment @@ -86,7 +86,7 @@ def migrate_annotation_vector_database(): dataset = Dataset( id=app.id, tenant_id=app.tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, @@ -178,7 +178,9 @@ def migrate_knowledge_vector_database(): while True: try: stmt = ( - select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) + select(Dataset) + .where(Dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY) + .order_by(Dataset.created_at.desc()) ) datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index fb98932269..507f22dcd8 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -29,6 +29,7 @@ from core.provider_manager import ProviderManager from core.rag.datasource.vdb.vector_type import VectorType from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db @@ -355,7 +356,7 @@ class DatasetListApi(Resource): for item in data: # convert embedding_model_provider to plugin standard format - if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: + if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]: item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" if item_model in model_names: @@ -436,7 +437,7 @@ class DatasetApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if dataset.embedding_model_provider: provider_id = ModelProviderID(dataset.embedding_model_provider) data["embedding_model_provider"] = str(provider_id) @@ -454,7 +455,7 @@ class DatasetApi(Resource): for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data["indexing_technique"] == "high_quality": + if data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" if item_model in model_names: data["embedding_available"] = True @@ -485,7 +486,7 @@ class DatasetApi(Resource): current_user, current_tenant_id = current_account_with_tenant() # check embedding model setting if ( - payload.indexing_technique == "high_quality" + payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY and payload.embedding_model_provider is not None and payload.embedding_model is not None ): diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 074694e7ea..897724182f 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -27,6 +27,7 @@ from core.model_manager import ModelManager from core.plugin.impl.exc import PluginDaemonClientSideError from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo +from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db @@ -449,7 +450,7 @@ class DatasetInitApi(Resource): raise Forbidden() knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {}) - if knowledge_config.indexing_technique == "high_quality": + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: raise ValueError("embedding model and embedding model provider are required for high quality indexing.") try: @@ -463,7 +464,7 @@ class DatasetInitApi(Resource): is_multimodal = DatasetService.check_is_multimodal_model( current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model ) - knowledge_config.is_multimodal = is_multimodal + knowledge_config.is_multimodal = is_multimodal # pyrefly: ignore[bad-assignment] except InvokeAuthorizationError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." @@ -1337,7 +1338,7 @@ class DocumentGenerateSummaryApi(Resource): raise BadRequest("document_list cannot be empty.") # Check if dataset configuration supports summary generation - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: raise ValueError( f"Summary generation is only available for 'high_quality' indexing technique. " f"Current indexing technique: {dataset.indexing_technique}" diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index fa9bc7f159..1f27989885 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -26,6 +26,7 @@ from controllers.console.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -279,7 +280,7 @@ class DatasetDocumentSegmentApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: model_manager = ModelManager() @@ -333,7 +334,7 @@ class DatasetDocumentSegmentAddApi(Resource): if not current_user.is_dataset_editor: raise Forbidden() # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( @@ -383,7 +384,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: model_manager = ModelManager() @@ -569,7 +570,7 @@ class ChildChunkAddApi(Resource): if not current_user.is_dataset_editor: raise Forbidden() # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 89be847cd3..25b6436a71 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -15,6 +15,7 @@ from controllers.service_api.wraps import ( cloud_edition_billing_rate_limit_check, ) from core.provider_manager import ProviderManager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import DataSetTag @@ -153,9 +154,14 @@ class DatasetListApi(DatasetApiResource): data = marshal(datasets, dataset_detail_fields) for item in data: - if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: # type: ignore - item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) # type: ignore - item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # type: ignore + if ( + item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY # pyrefly: ignore[bad-index] + and item["embedding_model_provider"] # pyrefly: ignore[bad-index] + ): + item["embedding_model_provider"] = str( # pyrefly: ignore[unsupported-operation] + ModelProviderID(item["embedding_model_provider"]) # pyrefly: ignore[bad-index] + ) + item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # pyrefly: ignore[bad-index] if item_model in model_names: item["embedding_available"] = True # type: ignore else: @@ -265,7 +271,7 @@ class DatasetApi(DatasetApiResource): for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data.get("indexing_technique") == "high_quality": + if data.get("indexing_technique") == IndexTechniqueType.HIGH_QUALITY: item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}" if item_model in model_names: data["embedding_available"] = True @@ -315,7 +321,7 @@ class DatasetApi(DatasetApiResource): # check embedding model setting embedding_model_provider = payload.embedding_model_provider embedding_model = payload.embedding_model - if payload.indexing_technique == "high_quality" or embedding_model_provider: + if payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY or embedding_model_provider: if embedding_model_provider and embedding_model: DatasetService.check_embedding_model_setting( dataset.tenant_id, embedding_model_provider, embedding_model diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 2e3b7fd85e..595b01a9f2 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -17,6 +17,7 @@ from controllers.service_api.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from fields.segment_fields import child_chunk_fields, segment_fields @@ -103,7 +104,7 @@ class SegmentApi(DatasetApiResource): if not document.enabled: raise NotFound("Document is disabled.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( @@ -157,7 +158,7 @@ class SegmentApi(DatasetApiResource): if not document: raise NotFound("Document not found.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( @@ -262,7 +263,7 @@ class DatasetSegmentApi(DatasetApiResource): document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: model_manager = ModelManager() @@ -358,7 +359,7 @@ class ChildChunkApi(DatasetApiResource): raise NotFound("Segment not found.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 87d4772815..0bd904811a 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -4,6 +4,7 @@ from sqlalchemy import select from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from models.dataset import Dataset from models.enums import CollectionBindingType, ConversationFromSource @@ -50,7 +51,7 @@ class AnnotationReplyFeature: dataset = Dataset( id=app_record.id, tenant_id=app_record.tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 52776ee626..06bc366081 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import ChildDocument, Document @@ -271,7 +271,7 @@ class IndexingRunner: doc_form: str | None = None, doc_language: str = "English", dataset_id: str | None = None, - indexing_technique: str = "economy", + indexing_technique: str = IndexTechniqueType.ECONOMY, ) -> IndexingEstimate: """ Estimate the indexing for the document. @@ -289,7 +289,7 @@ class IndexingRunner: dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() if not dataset: raise ValueError("Dataset not found.") - if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality": + if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}: if dataset.embedding_model_provider: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=tenant_id, @@ -303,7 +303,7 @@ class IndexingRunner: model_type=ModelType.TEXT_EMBEDDING, ) else: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: embedding_model_instance = self.model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, @@ -573,7 +573,7 @@ class IndexingRunner: """ embedding_model_instance = None - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, @@ -587,7 +587,7 @@ class IndexingRunner: create_keyword_thread = None if ( dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX - and dataset.indexing_technique == "economy" + and dataset.indexing_technique == IndexTechniqueType.ECONOMY ): # create keyword index create_keyword_thread = threading.Thread( @@ -597,7 +597,7 @@ class IndexingRunner: create_keyword_thread.start() max_workers = 10 - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] @@ -628,7 +628,7 @@ class IndexingRunner: tokens += future.result() if ( dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX - and dataset.indexing_technique == "economy" + and dataset.indexing_technique == IndexTechniqueType.ECONOMY and create_keyword_thread is not None ): create_keyword_thread.join() @@ -654,7 +654,7 @@ class IndexingRunner: raise ValueError("no dataset found") keyword = Keyword(dataset) keyword.create(documents) - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: document_ids = [document.metadata["doc_id"] for document in documents] db.session.query(DocumentSegment).where( DocumentSegment.document_id == document_id, @@ -764,7 +764,7 @@ class IndexingRunner: ) -> list[Document]: # get embedding model instance embedding_model_instance = None - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if dataset.embedding_model_provider: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=dataset.tenant_id, diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 16a5588024..cd27113245 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -6,6 +6,7 @@ from typing import Any from sqlalchemy import func, select from core.model_manager import ModelManager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import AttachmentDocument, Document from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db @@ -71,7 +72,7 @@ class DatasetDocumentStore: if max_position is None: max_position = 0 embedding_model = None - if self._dataset.indexing_technique == "high_quality": + if self._dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=self._dataset.tenant_id, diff --git a/api/core/rag/index_processor/index_processor.py b/api/core/rag/index_processor/index_processor.py index d9145023ac..a6d1db214b 100644 --- a/api/core/rag/index_processor/index_processor.py +++ b/api/core/rag/index_processor/index_processor.py @@ -9,6 +9,7 @@ from flask import current_app from sqlalchemy import delete, func, select from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview @@ -159,7 +160,7 @@ class IndexProcessor: tenant_id = dataset.tenant_id preview_output = self.format_preview(chunk_structure, chunks) - if indexing_technique != "high_quality": + if indexing_technique != IndexTechniqueType.HIGH_QUALITY: return preview_output if not summary_index_setting or not summary_index_setting.get("enable"): diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 80163b1707..726cc062f6 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -22,7 +22,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -117,7 +117,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if multimodal_documents and dataset.is_multimodal: @@ -155,7 +155,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): # Delete all summaries for the dataset SummaryIndexService.delete_summaries_for_segments(dataset, None) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) @@ -253,12 +253,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor): doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) # add document segments doc_store.add_documents(docs=documents, save_child=False) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if all_multimodal_documents and dataset.is_multimodal: vector.create_multimodal(all_multimodal_documents) - elif dataset.indexing_technique == "economy": + elif dataset.indexing_technique == IndexTechniqueType.ECONOMY: keyword = Keyword(dataset) keyword.add_texts(documents) diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index df0761ca73..70504e6e50 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -18,7 +18,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -128,7 +128,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) for document in documents: child_documents = document.children @@ -166,7 +166,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): # Delete all summaries for the dataset SummaryIndexService.delete_summaries_for_segments(dataset, None) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: delete_child_chunks = kwargs.get("delete_child_chunks") or False precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids") vector = Vector(dataset) @@ -332,7 +332,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) # add document segments doc_store.add_documents(docs=documents, save_child=True) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: all_child_documents = [] all_multimodal_documents = [] for doc in documents: diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 62f88b7760..6874603a83 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -21,7 +21,7 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -141,7 +141,7 @@ class QAIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if multimodal_documents and dataset.is_multimodal: @@ -224,7 +224,7 @@ class QAIndexProcessor(BaseIndexProcessor): # save node to document segment doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) doc_store.add_documents(docs=documents, save_child=False) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) else: diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 78a97f79a5..52061fd93d 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -675,7 +675,7 @@ class DatasetRetrieval: # get top k top_k = retrieval_model_config["top_k"] # get retrieval method - if selected_dataset.indexing_technique == "economy": + if selected_dataset.indexing_technique == IndexTechniqueType.ECONOMY: retrieval_method = RetrievalMethod.KEYWORD_SEARCH else: retrieval_method = retrieval_model_config["search_method"] @@ -752,7 +752,7 @@ class DatasetRetrieval: "The configured knowledge base list have different indexing technique, please set reranking model." ) index_type = available_datasets[0].indexing_technique - if index_type == "high_quality": + if index_type == IndexTechniqueType.HIGH_QUALITY: embedding_model_check = all( item.embedding_model == available_datasets[0].embedding_model for item in available_datasets ) @@ -1068,7 +1068,7 @@ class DatasetRetrieval: else default_retrieval_model ) - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, diff --git a/api/core/rag/summary_index/summary_index.py b/api/core/rag/summary_index/summary_index.py index 31d21dbeee..6f120bd471 100644 --- a/api/core/rag/summary_index/summary_index.py +++ b/api/core/rag/summary_index/summary_index.py @@ -2,6 +2,7 @@ import concurrent.futures import logging from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary from services.summary_index_service import SummaryIndexService @@ -21,7 +22,7 @@ class SummaryIndex: if is_preview: with session_factory.create_session() as session: dataset = session.query(Dataset).filter_by(id=dataset_id).first() - if not dataset or dataset.indexing_technique != "high_quality": + if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return if summary_index_setting is None: diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index c2b520fa99..75b923fd8b 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa from core.model_manager import ModelManager from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document as RagDocument from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -169,7 +170,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): # get retrieval model , if the model is not setting , using default retrieval_model = dataset.retrieval_model or default_retrieval_model - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 429b7e6622..f3d390ed59 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -8,6 +8,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict, from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document as RetrievalDocument from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -140,7 +141,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): # get retrieval model , if the model is not setting , using default retrieval_model = dataset.retrieval_model or default_retrieval_model retrieval_resource_list: list[RetrievalSourceMetadata] = [] - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, @@ -173,7 +174,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): for hit_callback in self.hit_callbacks: hit_callback.on_tool_end(documents) document_score_list = {} - if dataset.indexing_technique != "economy": + if dataset.indexing_technique != IndexTechniqueType.ECONOMY: for item in documents: if item.metadata is not None and item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] diff --git a/api/models/dataset.py b/api/models/dataset.py index b4fb03a7f4..e323ccfd7f 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -20,7 +20,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.constant.query_type import QueryType from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file @@ -137,7 +137,7 @@ class Dataset(Base): default=DatasetPermissionEnum.ONLY_ME, ) data_source_type = mapped_column(EnumText(DataSourceType, length=255)) - indexing_technique: Mapped[str | None] = mapped_column(String(255)) + indexing_technique: Mapped[IndexTechniqueType | None] = mapped_column(EnumText(IndexTechniqueType, length=255)) index_struct = mapped_column(LongText, nullable=True) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 65e112f1e9..969ca68545 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -21,7 +21,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.helper.name_generator import generate_incremental_name from core.model_manager import ModelManager from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from dify_graph.file import helpers as file_helpers from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType @@ -228,7 +228,7 @@ class DatasetService: if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first(): raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") embedding_model = None - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() if embedding_model_provider and embedding_model_name: # check if embedding model setting is valid @@ -254,7 +254,10 @@ class DatasetService: retrieval_model.reranking_model.reranking_provider_name, retrieval_model.reranking_model.reranking_model_name, ) - dataset = Dataset(name=name, indexing_technique=indexing_technique) + dataset = Dataset( + name=name, + indexing_technique=IndexTechniqueType(indexing_technique) if indexing_technique else None, + ) # dataset = Dataset(name=name, provider=provider, config=config) dataset.description = description dataset.created_by = account.id @@ -349,7 +352,7 @@ class DatasetService: @staticmethod def check_dataset_model_setting(dataset): - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: model_manager = ModelManager() model_manager.get_model_instance( @@ -717,13 +720,13 @@ class DatasetService: if "indexing_technique" not in data: return None if dataset.indexing_technique != data["indexing_technique"]: - if data["indexing_technique"] == "economy": + if data["indexing_technique"] == IndexTechniqueType.ECONOMY: # Remove embedding model configuration for economy mode filtered_data["embedding_model"] = None filtered_data["embedding_model_provider"] = None filtered_data["collection_binding_id"] = None return "remove" - elif data["indexing_technique"] == "high_quality": + elif data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: # Configure embedding model for high quality mode DatasetService._configure_embedding_model_for_high_quality(data, filtered_data) return "add" @@ -953,8 +956,8 @@ class DatasetService: dataset = session.merge(dataset) if not has_published: dataset.chunk_structure = knowledge_configuration.chunk_structure - dataset.indexing_technique = knowledge_configuration.indexing_technique - if knowledge_configuration.indexing_technique == "high_quality": + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, # ignore type error @@ -976,7 +979,7 @@ class DatasetService: embedding_model_name, ) dataset.collection_binding_id = dataset_collection_binding.id - elif knowledge_configuration.indexing_technique == "economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number else: raise ValueError("Invalid index method") @@ -991,9 +994,9 @@ class DatasetService: action = None if dataset.indexing_technique != knowledge_configuration.indexing_technique: # if update indexing_technique - if knowledge_configuration.indexing_technique == "economy": + if knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.") - elif knowledge_configuration.indexing_technique == "high_quality": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: action = "add" # get embedding model setting try: @@ -1018,7 +1021,7 @@ class DatasetService: ) dataset.is_multimodal = is_multimodal dataset.collection_binding_id = dataset_collection_binding.id - dataset.indexing_technique = knowledge_configuration.indexing_technique + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) except LLMBadRequestError: raise ValueError( "No Embedding Model available. Please configure a valid provider " @@ -1029,7 +1032,7 @@ class DatasetService: else: # add default plugin id to both setting sets, to make sure the plugin model provider is consistent # Skip embedding model checks if not provided in the update request - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: skip_embedding_update = False try: # Handle existing model provider @@ -1089,7 +1092,7 @@ class DatasetService: ) except ProviderTokenNotInitError as ex: raise ValueError(ex.description) - elif dataset.indexing_technique == "economy": + elif dataset.indexing_technique == IndexTechniqueType.ECONOMY: if dataset.keyword_number != knowledge_configuration.keyword_number: dataset.keyword_number = knowledge_configuration.keyword_number dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() @@ -1907,8 +1910,8 @@ class DocumentService: if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: raise ValueError("Indexing technique is invalid") - dataset.indexing_technique = knowledge_config.indexing_technique - if knowledge_config.indexing_technique == "high_quality": + dataset.indexing_technique = IndexTechniqueType(knowledge_config.indexing_technique) + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: dataset_embedding_model = knowledge_config.embedding_model @@ -2689,7 +2692,7 @@ class DocumentService: dataset_collection_binding_id = None retrieval_model = None - if knowledge_config.indexing_technique == "high_quality": + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: assert knowledge_config.embedding_model_provider assert knowledge_config.embedding_model dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( @@ -2712,7 +2715,7 @@ class DocumentService: tenant_id=tenant_id, name="", data_source_type=knowledge_config.data_source.info_list.data_source_type, - indexing_technique=knowledge_config.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_config.indexing_technique), created_by=account.id, embedding_model=knowledge_config.embedding_model, embedding_model_provider=knowledge_config.embedding_model_provider, @@ -3125,7 +3128,7 @@ class SegmentService: doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, @@ -3208,7 +3211,7 @@ class SegmentService: try: with redis_client.lock(lock_name, timeout=600): embedding_model = None - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, @@ -3230,7 +3233,7 @@ class SegmentService: doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality" and embedding_model: + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY and embedding_model: # calc embedding use tokens if document.doc_form == IndexStructureType.QA_INDEX: tokens = embedding_model.get_text_embedding_num_tokens( @@ -3345,7 +3348,7 @@ class SegmentService: if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # regenerate child chunks # get embedding model instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting model_manager = ModelManager() @@ -3382,7 +3385,7 @@ class SegmentService: # When user manually provides summary, allow saving even if summary_index_setting doesn't exist # summary_index_setting is only needed for LLM generation, not for manual summary vectorization # Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # Query existing summary from database from models.dataset import DocumentSegmentSummary @@ -3409,7 +3412,7 @@ class SegmentService: else: segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, @@ -3449,7 +3452,7 @@ class SegmentService: db.session.commit() if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # get embedding model instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting model_manager = ModelManager() @@ -3481,7 +3484,7 @@ class SegmentService: # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) # Handle summary index when content changed - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: from models.dataset import DocumentSegmentSummary existing_summary = ( diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index deb59da8d3..fd66d55c1a 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -22,6 +22,7 @@ from sqlalchemy.orm import Session from core.helper import ssrf_proxy from core.helper.name_generator import generate_incremental_name from core.plugin.entities.plugin import PluginDependency +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.workflow.nodes.datasource.entities import DatasourceNodeData from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData @@ -311,13 +312,13 @@ class RagPipelineDslService: "icon_background": icon_background, "icon_url": icon_url, }, - indexing_technique=knowledge_configuration.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique), created_by=account.id, retrieval_model=knowledge_configuration.retrieval_model.model_dump(), runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, chunk_structure=knowledge_configuration.chunk_structure, ) - if knowledge_configuration.indexing_technique == "high_quality": + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: dataset_collection_binding = ( self._session.query(DatasetCollectionBinding) .where( @@ -343,7 +344,7 @@ class RagPipelineDslService: dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = knowledge_configuration.embedding_model dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider - elif knowledge_configuration.indexing_technique == "economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number # Update summary_index_setting if provided if knowledge_configuration.summary_index_setting is not None: @@ -443,18 +444,18 @@ class RagPipelineDslService: "icon_background": icon_background, "icon_url": icon_url, }, - indexing_technique=knowledge_configuration.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique), created_by=account.id, retrieval_model=knowledge_configuration.retrieval_model.model_dump(), runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, chunk_structure=knowledge_configuration.chunk_structure, ) else: - dataset.indexing_technique = knowledge_configuration.indexing_technique + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE dataset.chunk_structure = knowledge_configuration.chunk_structure - if knowledge_configuration.indexing_technique == "high_quality": + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: dataset_collection_binding = ( self._session.query(DatasetCollectionBinding) .where( @@ -480,7 +481,7 @@ class RagPipelineDslService: dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = knowledge_configuration.embedding_model dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider - elif knowledge_configuration.indexing_technique == "economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number # Update summary_index_setting if provided if knowledge_configuration.summary_index_setting is not None: @@ -772,7 +773,7 @@ class RagPipelineDslService: ) case _ if typ == KNOWLEDGE_INDEX_NODE_TYPE: knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"]) - if knowledge_index_entity.indexing_technique == "high_quality": + if knowledge_index_entity.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if knowledge_index_entity.embedding_model_provider: dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 7dcfecdd1d..215a8c8528 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -9,7 +9,7 @@ from flask_login import current_user from constants import DOCUMENT_EXTENSIONS from core.plugin.impl.plugin import PluginInstaller -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from factories import variable_factory @@ -105,29 +105,29 @@ class RagPipelineTransformService: if doc_form == IndexStructureType.PARAGRAPH_INDEX: match datasource_type: case DataSourceType.UPLOAD_FILE: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: # get graph from transform.file-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.file-general-economy.yml with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) case DataSourceType.NOTION_IMPORT: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: # get graph from transform.notion-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.notion-general-economy.yml with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) case DataSourceType.WEBSITE_CRAWL: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: # get graph from transform.website-crawl-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.website-crawl-general-economy.yml with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) @@ -170,11 +170,11 @@ class RagPipelineTransformService: ): knowledge_configuration_dict = node.get("data", {}) - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: knowledge_configuration.embedding_model = dataset.embedding_model knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider if retrieval_model: - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: retrieval_model.search_method = RetrievalMethod.KEYWORD_SEARCH knowledge_configuration.retrieval_model = retrieval_model else: diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index 943dfc972b..ed7a33feae 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -12,6 +12,7 @@ from core.db.session_factory import session_factory from core.model_manager import ModelManager from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.models.document import Document from dify_graph.model_runtime.entities.llm_entities import LLMUsage @@ -140,7 +141,7 @@ class SummaryIndexService: session: Optional SQLAlchemy session. If provided, uses this session instead of creating a new one. If not provided, creates a new session and commits automatically. """ - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.warning( "Summary vectorization skipped for dataset %s: indexing_technique is not high_quality", dataset.id, @@ -724,7 +725,7 @@ class SummaryIndexService: List of created DocumentSegmentSummary instances """ # Only generate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( "Skipping summary generation for dataset %s: indexing_technique is %s, not 'high_quality'", dataset.id, @@ -851,7 +852,7 @@ class SummaryIndexService: ) # Remove from vector database (but keep records) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] if summary_node_ids: try: @@ -889,7 +890,7 @@ class SummaryIndexService: segment_ids: List of segment IDs to enable summaries for. If None, enable all. """ # Only enable summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return with session_factory.create_session() as session: @@ -981,7 +982,7 @@ class SummaryIndexService: return # Delete from vector database - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] if summary_node_ids: vector = Vector(dataset) @@ -1012,7 +1013,7 @@ class SummaryIndexService: Updated DocumentSegmentSummary instance, or None if indexing technique is not high_quality """ # Only update summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return None # When user manually provides summary, allow saving even if summary_index_setting doesn't exist diff --git a/api/services/vector_service.py b/api/services/vector_service.py index b66fdd7a20..bb94a03ba3 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -4,7 +4,7 @@ from core.model_manager import ModelInstance, ModelManager from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, Document @@ -45,7 +45,7 @@ class VectorService: if not processing_rule: raise ValueError("No processing rule found.") # get embedding model instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting model_manager = ModelManager() @@ -112,7 +112,7 @@ class VectorService: "dataset_id": segment.dataset_id, }, ) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # update vector index vector = Vector(dataset=dataset) vector.delete_by_ids([segment.index_node_id]) @@ -197,7 +197,7 @@ class VectorService: "dataset_id": child_segment.dataset_id, }, ) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # save vector index vector = Vector(dataset=dataset) vector.add_texts([child_document], duplicate_check=True) @@ -237,7 +237,7 @@ class VectorService: delete_node_ids.append(update_child_chunk.index_node_id) for delete_child_chunk in delete_child_chunks: delete_node_ids.append(delete_child_chunk.index_node_id) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # update vector index vector = Vector(dataset=dataset) if delete_node_ids: @@ -252,7 +252,7 @@ class VectorService: @classmethod def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset): - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return attachments = segment.attachments diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index a9a8b892c2..dafa36cc34 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -36,7 +37,7 @@ def add_annotation_to_index_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index fc6bf03454..c734e1321b 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -67,7 +68,7 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index 432732af95..c9aa8fadb7 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -26,7 +27,7 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, collection_binding_id=dataset_collection_binding.id, ) diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index 7b5cd46b00..41cf7ccbf6 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -7,6 +7,7 @@ from sqlalchemy import exists, select from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_redis import redis_client from models.dataset import Dataset from models.model import App, AppAnnotationSetting, MessageAnnotation @@ -44,7 +45,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, collection_binding_id=app_annotation_setting.collection_binding_id, ) diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 1fe43c3d62..2c07fe0f31 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -7,6 +7,7 @@ from sqlalchemy import select from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -64,7 +65,7 @@ def enable_annotation_reply_task( old_dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=old_dataset_collection_binding.provider_name, embedding_model=old_dataset_collection_binding.model_name, collection_binding_id=old_dataset_collection_binding.id, @@ -93,7 +94,7 @@ def enable_annotation_reply_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index 6ff34c0e74..f41da1d373 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -37,7 +38,7 @@ def update_annotation_to_index_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 7f810129ef..dd58378e0e 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -11,7 +11,7 @@ from sqlalchemy import func from core.db.session_factory import session_factory from core.model_manager import ModelManager -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_redis import redis_client from extensions.ext_storage import storage @@ -120,7 +120,7 @@ def batch_create_segment_to_index_task( document_segments = [] embedding_model = None - if dataset_config["indexing_technique"] == "high_quality": + if dataset_config["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=dataset_config["tenant_id"], diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index b5794e33e2..23a80fa106 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -10,7 +10,7 @@ from configs import dify_config from core.db.session_factory import session_factory from core.entities.document_task import DocumentTask from core.indexing_runner import DocumentIsPausedError, IndexingRunner -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now @@ -127,7 +127,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): logger.warning("Dataset %s not found after indexing", dataset_id) return - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: summary_index_setting = dataset.summary_index_setting if summary_index_setting and summary_index_setting.get("enable"): # expire all session to get latest document's indexing status diff --git a/api/tasks/generate_summary_index_task.py b/api/tasks/generate_summary_index_task.py index 6493833edc..e3d82d2851 100644 --- a/api/tasks/generate_summary_index_task.py +++ b/api/tasks/generate_summary_index_task.py @@ -7,6 +7,7 @@ import click from celery import shared_task from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument from services.summary_index_service import SummaryIndexService @@ -59,7 +60,7 @@ def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: return # Only generate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( click.style( f"Skipping summary generation for dataset {dataset_id}: " diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py index ac5d23408a..6f490ab7ea 100644 --- a/api/tasks/regenerate_summary_index_task.py +++ b/api/tasks/regenerate_summary_index_task.py @@ -9,7 +9,7 @@ from celery import shared_task from sqlalchemy import or_, select from core.db.session_factory import session_factory -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument from services.summary_index_service import SummaryIndexService @@ -53,7 +53,7 @@ def regenerate_summary_index_task( return # Only regenerate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( click.style( f"Skipping summary regeneration for dataset {dataset_id}: " diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py index ea8d04502a..00d7496a40 100644 --- a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -4,7 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest from models.dataset import Dataset, Document @@ -39,7 +39,7 @@ class TestGetAvailableDatasetsIntegration: provider="dify", data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -460,7 +460,7 @@ class TestKnowledgeRetrievalIntegration: provider="dify", data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py index 6b35f867d7..02c3d1a80e 100644 --- a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py +++ b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py @@ -13,6 +13,7 @@ import pytest from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum from models.enums import DataSourceType @@ -74,7 +75,7 @@ class DatasetUpdateDeleteTestDataFactory: name=name, description="Test description", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py index 55bfb64e18..71c8874f79 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py @@ -9,6 +9,7 @@ from uuid import uuid4 import pytest +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( @@ -69,7 +70,7 @@ class DatasetPermissionTestDataFactory: name=name, description="desc", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index c4d20bc02c..0702680f5c 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -11,7 +11,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -63,7 +63,7 @@ class DatasetServiceIntegrationDataFactory: name: str = "Test Dataset", description: str | None = "Test description", provider: str = "vendor", - indexing_technique: str | None = "high_quality", + indexing_technique: str | None = IndexTechniqueType.HIGH_QUALITY, permission: str = DatasetPermissionEnum.ONLY_ME, retrieval_model: dict | None = None, embedding_model_provider: str | None = None, @@ -157,13 +157,13 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Economy Dataset", description=None, - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, account=account, ) # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "economy" + assert result.indexing_technique == IndexTechniqueType.ECONOMY assert result.embedding_model_provider is None assert result.embedding_model is None @@ -181,13 +181,13 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="High Quality Dataset", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, ) # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "high_quality" + assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert result.embedding_model_provider == embedding_model.provider assert result.embedding_model == embedding_model.model_name mock_model_manager.return_value.get_default_model_instance.assert_called_once_with( @@ -273,7 +273,7 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Dataset With Reranking", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, retrieval_model=retrieval_model, ) @@ -306,7 +306,7 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Custom Embedding Dataset", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, embedding_model_provider=embedding_provider, embedding_model_name=embedding_model_name, @@ -314,7 +314,7 @@ class TestDatasetServiceCreateDataset: # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "high_quality" + assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert result.embedding_model_provider == embedding_provider assert result.embedding_model == embedding_model_name mock_check_embedding.assert_called_once_with(tenant.id, embedding_provider, embedding_model_name) @@ -589,7 +589,7 @@ class TestDatasetServiceUpdateAndDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure="text_model", ) DatasetServiceIntegrationDataFactory.create_document( @@ -685,14 +685,14 @@ class TestDatasetServiceRetrievalConfiguration: db_session_with_containers, tenant_id=tenant.id, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, retrieval_model={"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0}, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=str(uuid4()), ) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": { "search_method": "full_text_search", "top_k": 10, diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py index 807d18322c..3cac964d89 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py @@ -3,7 +3,7 @@ from unittest.mock import patch from uuid import uuid4 -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom @@ -109,7 +109,7 @@ class TestDatasetServiceDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure=None, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid4()), @@ -208,7 +208,7 @@ class TestDatasetServiceDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure=None, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid4()), diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py index c4b3a57bb2..87239b2cb3 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py @@ -12,6 +12,7 @@ from uuid import uuid4 from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom @@ -64,7 +65,7 @@ class SegmentServiceTestDataFactory: name=f"Test Dataset {uuid4()}", description="Test description", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=DatasetPermissionEnum.ONLY_ME, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py index 3021d8984d..2f90d16176 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py @@ -15,6 +15,7 @@ from uuid import uuid4 from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -102,7 +103,7 @@ class DatasetRetrievalTestDataFactory: name=name, description="desc", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index fd81948247..2899d5b8a5 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -4,6 +4,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, ExternalKnowledgeBindings @@ -53,7 +54,7 @@ class DatasetUpdateTestDataFactory: provider: str = "vendor", name: str = "old_name", description: str = "old_description", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, retrieval_model: str = "old_model", permission: str = "only_me", embedding_model_provider: str | None = None, @@ -241,7 +242,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -250,7 +251,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", "description": "new_description", - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", "embedding_model_provider": "openai", "embedding_model": "text-embedding-ada-002", @@ -261,7 +262,7 @@ class TestDatasetServiceUpdateDataset: assert dataset.name == "new_name" assert dataset.description == "new_description" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.retrieval_model == "new_model" assert dataset.embedding_model_provider == "openai" assert dataset.embedding_model == "text-embedding-ada-002" @@ -276,7 +277,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -285,7 +286,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", "description": None, - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", "embedding_model_provider": None, "embedding_model": None, @@ -312,14 +313,14 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, ) update_data = { - "indexing_technique": "economy", + "indexing_technique": IndexTechniqueType.ECONOMY, "retrieval_model": "new_model", } @@ -328,7 +329,7 @@ class TestDatasetServiceUpdateDataset: mock_task.delay.assert_called_once_with(dataset.id, "remove") db_session_with_containers.refresh(dataset) - assert dataset.indexing_technique == "economy" + assert dataset.indexing_technique == IndexTechniqueType.ECONOMY assert dataset.embedding_model is None assert dataset.embedding_model_provider is None assert dataset.collection_binding_id is None @@ -343,7 +344,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) embedding_model = Mock() @@ -354,7 +355,7 @@ class TestDatasetServiceUpdateDataset: binding.id = str(uuid4()) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "openai", "embedding_model": "text-embedding-ada-002", "retrieval_model": "new_model", @@ -383,7 +384,7 @@ class TestDatasetServiceUpdateDataset: mock_task.delay.assert_called_once_with(dataset.id, "add") db_session_with_containers.refresh(dataset) - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model_provider == "openai" assert dataset.collection_binding_id == binding.id @@ -403,7 +404,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -411,7 +412,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", } @@ -419,7 +420,7 @@ class TestDatasetServiceUpdateDataset: db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model_provider == "openai" assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.collection_binding_id == existing_binding_id @@ -435,7 +436,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -449,7 +450,7 @@ class TestDatasetServiceUpdateDataset: binding.id = str(uuid4()) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "openai", "embedding_model": "text-embedding-3-small", "retrieval_model": "new_model", @@ -531,11 +532,11 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "invalid_provider", "embedding_model": "invalid_model", "retrieval_model": "new_model", diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index 1a72e3b6c2..f504f35589 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -7,6 +7,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset from models.enums import DataSourceType, TagType @@ -102,7 +103,7 @@ class TestTagService: provider="vendor", permission="only_me", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, tenant_id=tenant_id, created_by=mock_external_service_dependencies["current_user"].id, ) diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 94173c34bf..4b04c1accb 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment @@ -81,7 +81,7 @@ class TestAddDocumentToIndexTask: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 5ebf141828..d2e343ef52 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -19,7 +19,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.storage.storage_type import StorageType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -142,7 +142,7 @@ class TestBatchCreateSegmentToIndexTask: name=fake.company(), description=fake.text(), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model="text-embedding-ada-002", embedding_model_provider="openai", created_by=account.id, diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 9449fee0af..1dd37fbc92 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -18,7 +18,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.storage.storage_type import StorageType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( @@ -154,7 +154,7 @@ class TestCleanDatasetTask: tenant_id=tenant.id, name="test_dataset", description="Test dataset for cleanup testing", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid.uuid4()), created_by=account.id, @@ -870,7 +870,7 @@ class TestCleanDatasetTask: tenant_id=tenant.id, name=long_name, description=long_description, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph", "max_length": 10000}', collection_binding_id=str(uuid.uuid4()), created_by=account.id, diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 979435282b..9f8e37fc9e 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -12,7 +12,7 @@ from uuid import uuid4 import pytest from faker import Faker -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -121,7 +121,7 @@ class TestCreateSegmentToIndexTask: description=fake.text(max_nb_chars=100), tenant_id=tenant_id, data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", created_by=account_id, diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 67f9dc7011..13ea94348a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -8,6 +8,7 @@ import pytest from faker import Faker from core.indexing_runner import DocumentIsPausedError +from core.rag.index_processor.constant.index_type import IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document @@ -141,7 +142,7 @@ class TestDatasetIndexingTaskIntegration: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 6fc2a53f9c..8a69707b38 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch from faker import Faker -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Dataset, Document, DocumentSegment, Tenant from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus from tasks.delete_segment_from_index_task import delete_segment_from_index_task @@ -108,7 +108,7 @@ class TestDeleteSegmentFromIndexTask: dataset.provider = "vendor" dataset.permission = "only_me" dataset.data_source_type = DataSourceType.UPLOAD_FILE - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.index_struct = '{"type": "paragraph"}' dataset.created_by = account.id dataset.created_at = fake.date_time_this_year() diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py index d21f1daf23..5bdf7d1389 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py @@ -15,7 +15,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -100,7 +100,7 @@ class TestDisableSegmentFromIndexTask: name=fake.sentence(nb_words=3), description=fake.text(max_nb_chars=200), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index fbcb7b5264..3e9a0c8f7f 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -11,7 +11,7 @@ from unittest.mock import MagicMock, patch from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Dataset, DocumentSegment from models import Document as DatasetDocument from models.dataset import DatasetProcessRule @@ -103,7 +103,7 @@ class TestDisableSegmentsFromIndexTask: provider="vendor", permission="only_me", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, updated_by=account.id, embedding_model="text-embedding-ada-002", diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py index 10d97919fb..d4021143ef 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -14,7 +14,7 @@ from uuid import uuid4 import pytest from core.indexing_runner import DocumentIsPausedError, IndexingRunner -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus @@ -57,7 +57,7 @@ class DocumentIndexingSyncTaskTestDataFactory: name=f"dataset-{uuid4()}", description="sync test dataset", data_source_type=DataSourceType.NOTION_IMPORT, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index 9421b07285..cf1a8666f3 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -5,6 +5,7 @@ import pytest from faker import Faker from core.entities.document_task import DocumentTask +from core.rag.index_processor.constant.index_type import IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document @@ -99,7 +100,7 @@ class TestDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -181,7 +182,7 @@ class TestDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py index c650d56091..d94abf2b40 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus @@ -64,7 +64,7 @@ class TestDocumentIndexingUpdateTask: name=fake.company(), description=fake.text(max_nb_chars=64), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index 76b6a8ae73..6a8e186958 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from core.indexing_runner import DocumentIsPausedError -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -110,7 +110,7 @@ class TestDuplicateDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -245,7 +245,7 @@ class TestDuplicateDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py index 54b50016a8..e2f35067e3 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -81,7 +81,7 @@ class TestEnableSegmentsToIndexTask: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset)