diff --git a/.agents/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md index 69c099a262..4da070bdbf 100644 --- a/.agents/skills/frontend-testing/SKILL.md +++ b/.agents/skills/frontend-testing/SKILL.md @@ -63,7 +63,8 @@ pnpm analyze-component --review ### File Naming -- Test files: `ComponentName.spec.tsx` (same directory as component) +- Test files: `ComponentName.spec.tsx` inside a same-level `__tests__/` directory +- Placement rule: Component, hook, and utility tests must live in a sibling `__tests__/` folder at the same level as the source under test. For example, `foo/index.tsx` maps to `foo/__tests__/index.spec.tsx`, and `foo/bar.ts` maps to `foo/__tests__/bar.spec.ts`. - Integration tests: `web/__tests__/` directory ## Test Structure Template diff --git a/.agents/skills/frontend-testing/assets/component-test.template.tsx b/.agents/skills/frontend-testing/assets/component-test.template.tsx index 6b7803bd4b..ff38f88d23 100644 --- a/.agents/skills/frontend-testing/assets/component-test.template.tsx +++ b/.agents/skills/frontend-testing/assets/component-test.template.tsx @@ -41,7 +41,7 @@ import userEvent from '@testing-library/user-event' // Router (if component uses useRouter, usePathname, useSearchParams) // WHY: Isolates tests from Next.js routing, enables testing navigation behavior // const mockPush = vi.fn() -// vi.mock('next/navigation', () => ({ +// vi.mock('@/next/navigation', () => ({ // useRouter: () => ({ push: mockPush }), // usePathname: () => '/test-path', // })) diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index 11222146cf..be2595a599 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -29,8 +29,8 @@ jobs: strategy: fail-fast: false matrix: - shardIndex: [1, 2, 3, 4] - shardTotal: [4] + shardIndex: [1, 2, 3, 4, 5, 6] + shardTotal: [6] defaults: run: shell: bash diff --git a/api/commands/vector.py b/api/commands/vector.py index 5f41d469c8..52ce26c26d 100644 --- a/api/commands/vector.py +++ b/api/commands/vector.py @@ -14,6 +14,7 @@ from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment from models.dataset import Document as DatasetDocument +from models.enums import DatasetMetadataType, IndexingStatus, SegmentStatus from models.model import App, AppAnnotationSetting, MessageAnnotation @@ -242,7 +243,7 @@ def migrate_knowledge_vector_database(): dataset_documents = db.session.scalars( select(DatasetDocument).where( DatasetDocument.dataset_id == dataset.id, - DatasetDocument.indexing_status == "completed", + DatasetDocument.indexing_status == IndexingStatus.COMPLETED, DatasetDocument.enabled == True, DatasetDocument.archived == False, ) @@ -254,7 +255,7 @@ def migrate_knowledge_vector_database(): segments = db.session.scalars( select(DocumentSegment).where( DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == "completed", + DocumentSegment.status == SegmentStatus.COMPLETED, DocumentSegment.enabled == True, ) ).all() @@ -430,7 +431,7 @@ def old_metadata_migration(): tenant_id=document.tenant_id, dataset_id=document.dataset_id, name=key, - type="string", + type=DatasetMetadataType.STRING, created_by=document.created_by, ) db.session.add(dataset_metadata) diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index 367cb52731..3b91207545 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -1,4 +1,4 @@ -from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt +from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, field_validator from pydantic_settings import BaseSettings @@ -116,3 +116,13 @@ class RedisConfig(BaseSettings): description="Maximum connections in the Redis connection pool (unset for library default)", default=None, ) + + @field_validator("REDIS_MAX_CONNECTIONS", mode="before") + @classmethod + def _empty_string_to_none_for_max_conns(cls, v): + """Allow empty string in env/.env to mean 'unset' (None).""" + if v is None: + return None + if isinstance(v, str) and v.strip() == "": + return None + return v diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index eebba57fa3..725a8380cd 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -54,6 +54,7 @@ from fields.document_fields import document_status_fields from libs.login import current_account_with_tenant, login_required from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile from models.dataset import DatasetPermission, DatasetPermissionEnum +from models.enums import SegmentStatus from models.provider_ids import ModelProviderID from services.api_token_service import ApiTokenCache from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService @@ -741,13 +742,15 @@ class DatasetIndexingStatusApi(Resource): .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", + DocumentSegment.status != SegmentStatus.RE_SEGMENT, ) .count() ) total_segments = ( db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .where( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + ) .count() ) # Create a dictionary with document attributes and additional fields diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index ee726bc470..0c441553be 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -42,6 +42,7 @@ from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import DatasetProcessRule, Document, DocumentSegment, UploadFile from models.dataset import DocumentPipelineExecutionLog +from models.enums import IndexingStatus, SegmentStatus from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel from services.file_service import FileService @@ -332,13 +333,16 @@ class DatasetDocumentListApi(Resource): .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", + DocumentSegment.status != SegmentStatus.RE_SEGMENT, ) .count() ) total_segments = ( db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) .count() ) document.completed_segments = completed_segments @@ -503,7 +507,7 @@ class DocumentIndexingEstimateApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - if document.indexing_status in {"completed", "error"}: + if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}: raise DocumentAlreadyFinishedError() data_process_rule = document.dataset_process_rule @@ -573,7 +577,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {} extract_settings = [] for document in documents: - if document.indexing_status in {"completed", "error"}: + if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}: raise DocumentAlreadyFinishedError() data_source_info = document.data_source_info_dict match document.data_source_type: @@ -671,19 +675,21 @@ class DocumentBatchIndexingStatusApi(DocumentResource): .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", + DocumentSegment.status != SegmentStatus.RE_SEGMENT, ) .count() ) total_segments = ( db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .where( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + ) .count() ) # Create a dictionary with document attributes and additional fields document_dict = { "id": document.id, - "indexing_status": "paused" if document.is_paused else document.indexing_status, + "indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status, "processing_started_at": document.processing_started_at, "parsing_completed_at": document.parsing_completed_at, "cleaning_completed_at": document.cleaning_completed_at, @@ -720,20 +726,20 @@ class DocumentIndexingStatusApi(DocumentResource): .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document_id), - DocumentSegment.status != "re_segment", + DocumentSegment.status != SegmentStatus.RE_SEGMENT, ) .count() ) total_segments = ( db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment") + .where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT) .count() ) # Create a dictionary with document attributes and additional fields document_dict = { "id": document.id, - "indexing_status": "paused" if document.is_paused else document.indexing_status, + "indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status, "processing_started_at": document.processing_started_at, "parsing_completed_at": document.parsing_completed_at, "cleaning_completed_at": document.cleaning_completed_at, @@ -955,7 +961,7 @@ class DocumentProcessingApi(DocumentResource): match action: case "pause": - if document.indexing_status != "indexing": + if document.indexing_status != IndexingStatus.INDEXING: raise InvalidActionError("Document not in indexing state.") document.paused_by = current_user.id @@ -964,7 +970,7 @@ class DocumentProcessingApi(DocumentResource): db.session.commit() case "resume": - if document.indexing_status not in {"paused", "error"}: + if document.indexing_status not in {IndexingStatus.PAUSED, IndexingStatus.ERROR}: raise InvalidActionError("Document not in paused or error state.") document.paused_by = None @@ -1169,7 +1175,7 @@ class DocumentRetryApi(DocumentResource): raise ArchivedDocumentImmutableError() # 400 if document is completed - if document.indexing_status == "completed": + if document.indexing_status == IndexingStatus.COMPLETED: raise DocumentAlreadyFinishedError() retry_documents.append(document) except Exception: diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 6e0cd31b8d..4f31093cfe 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -46,6 +46,8 @@ class PipelineTemplateDetailApi(Resource): type = request.args.get("type", default="built-in", type=str) rag_pipeline_service = RagPipelineService() pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type) + if pipeline_template is None: + return {"error": "Pipeline template not found from upstream service."}, 404 return pipeline_template, 200 diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 5a1d28ea1d..d34b4124ae 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -36,6 +36,7 @@ from extensions.ext_database import db from fields.document_fields import document_fields, document_status_fields from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment +from models.enums import SegmentStatus from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import ( KnowledgeConfig, @@ -622,13 +623,15 @@ class DocumentIndexingStatusApi(DatasetApiResource): .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", + DocumentSegment.status != SegmentStatus.RE_SEGMENT, ) .count() ) total_segments = ( db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .where( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + ) .count() ) # Create a dictionary with document attributes and additional fields diff --git a/api/controllers/trigger/webhook.py b/api/controllers/trigger/webhook.py index 22b24271c6..eb579da5d4 100644 --- a/api/controllers/trigger/webhook.py +++ b/api/controllers/trigger/webhook.py @@ -70,7 +70,14 @@ def handle_webhook(webhook_id: str): @bp.route("/webhook-debug/", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"]) def handle_webhook_debug(webhook_id: str): - """Handle webhook debug calls without triggering production workflow execution.""" + """Handle webhook debug calls without triggering production workflow execution. + + The debug webhook endpoint is only for draft inspection flows. It never enqueues + Celery work for the published workflow; instead it dispatches an in-memory debug + event to an active Variable Inspector listener. Returning a clear error when no + listener is registered prevents a misleading 200 response for requests that are + effectively dropped. + """ try: webhook_trigger, _, node_config, webhook_data, error = _prepare_webhook_execution(webhook_id, is_debug=True) if error: @@ -94,11 +101,32 @@ def handle_webhook_debug(webhook_id: str): "method": webhook_data.get("method"), }, ) - TriggerDebugEventBus.dispatch( + dispatch_count = TriggerDebugEventBus.dispatch( tenant_id=webhook_trigger.tenant_id, event=event, pool_key=pool_key, ) + if dispatch_count == 0: + logger.warning( + "Webhook debug request dropped without an active listener for webhook %s (tenant=%s, app=%s, node=%s)", + webhook_trigger.webhook_id, + webhook_trigger.tenant_id, + webhook_trigger.app_id, + webhook_trigger.node_id, + ) + return ( + jsonify( + { + "error": "No active debug listener", + "message": ( + "The webhook debug URL only works while the Variable Inspector is listening. " + "Use the published webhook URL to execute the workflow in Celery." + ), + "execution_url": webhook_trigger.webhook_url, + } + ), + 409, + ) response_data, status_code = WebhookService.generate_webhook_response(node_config) return jsonify(response_data), status_code diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 4a8b5f3549..1bdc8df813 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -441,7 +441,7 @@ class BaseAgentRunner(AppRunner): continue result.append(self.organize_agent_user_prompt(message)) - agent_thoughts: list[MessageAgentThought] = message.agent_thoughts + agent_thoughts = message.agent_thoughts if agent_thoughts: for agent_thought in agent_thoughts: tool_names_raw = agent_thought.tool diff --git a/api/core/app/app_config/common/parameters_mapping/__init__.py b/api/core/app/app_config/common/parameters_mapping/__init__.py index 6f1a3bf045..460fdfb3ba 100644 --- a/api/core/app/app_config/common/parameters_mapping/__init__.py +++ b/api/core/app/app_config/common/parameters_mapping/__init__.py @@ -1,13 +1,36 @@ from collections.abc import Mapping -from typing import Any +from typing import Any, TypedDict from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS +class SystemParametersDict(TypedDict): + image_file_size_limit: int + video_file_size_limit: int + audio_file_size_limit: int + file_size_limit: int + workflow_file_upload_limit: int + + +class AppParametersDict(TypedDict): + opening_statement: str | None + suggested_questions: list[str] + suggested_questions_after_answer: dict[str, Any] + speech_to_text: dict[str, Any] + text_to_speech: dict[str, Any] + retriever_resource: dict[str, Any] + annotation_reply: dict[str, Any] + more_like_this: dict[str, Any] + user_input_form: list[dict[str, Any]] + sensitive_word_avoidance: dict[str, Any] + file_upload: dict[str, Any] + system_parameters: SystemParametersDict + + def get_parameters_from_feature_dict( *, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]] -) -> Mapping[str, Any]: +) -> AppParametersDict: """ Mapping from feature dict to webapp parameters """ diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index 70f43b2c83..f04a8df119 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -8,6 +8,7 @@ from core.app.app_config.entities import ( ModelConfig, ) from core.entities.agent_entities import PlanningStrategy +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from models.model import AppMode, AppModelConfigDict from services.dataset_service import DatasetService @@ -117,8 +118,10 @@ class DatasetConfigManager: score_threshold=float(score_threshold_val) if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None else None, - reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None, - weights=weights_val if isinstance(weights_val, dict) else None, + reranking_model=cast(RerankingModelDict, reranking_model_val) + if isinstance(reranking_model_val, dict) + else None, + weights=cast(WeightsDict, weights_val) if isinstance(weights_val, dict) else None, reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)), rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"), metadata_filtering_mode=cast( diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index ac21577d57..95ea70bc40 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -4,6 +4,7 @@ from typing import Any, Literal from pydantic import BaseModel, Field +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from dify_graph.file import FileUploadConfig from dify_graph.model_runtime.entities.llm_entities import LLMMode from dify_graph.model_runtime.entities.message_entities import PromptMessageRole @@ -194,8 +195,8 @@ class DatasetRetrieveConfigEntity(BaseModel): top_k: int | None = None score_threshold: float | None = 0.0 rerank_mode: str | None = "reranking_model" - reranking_model: dict | None = None - weights: dict | None = None + reranking_model: RerankingModelDict | None = None + weights: WeightsDict | None = None reranking_enabled: bool | None = True metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled" metadata_model_config: ModelConfig | None = None diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 5665a2b76c..5509764508 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -3,7 +3,7 @@ import time from collections.abc import Mapping, Sequence from dataclasses import dataclass from datetime import datetime -from typing import Any, NewType, Union +from typing import Any, NewType, TypedDict, Union from sqlalchemy import select from sqlalchemy.orm import Session @@ -76,6 +76,20 @@ NodeExecutionId = NewType("NodeExecutionId", str) logger = logging.getLogger(__name__) +class AccountCreatedByDict(TypedDict): + id: str + name: str + email: str + + +class EndUserCreatedByDict(TypedDict): + id: str + user: str + + +CreatedByDict = AccountCreatedByDict | EndUserCreatedByDict + + @dataclass(slots=True) class _NodeSnapshot: """In-memory cache for node metadata between start and completion events.""" @@ -249,19 +263,19 @@ class WorkflowResponseConverter: outputs_mapping = graph_runtime_state.outputs or {} encoded_outputs = WorkflowRuntimeTypeConverter().to_json_encodable(outputs_mapping) - created_by: Mapping[str, object] | None + created_by: CreatedByDict | dict[str, object] = {} user = self._user if isinstance(user, Account): - created_by = { - "id": user.id, - "name": user.name, - "email": user.email, - } - else: - created_by = { - "id": user.id, - "user": user.session_id, - } + created_by = AccountCreatedByDict( + id=user.id, + name=user.name, + email=user.email, + ) + elif isinstance(user, EndUser): + created_by = EndUserCreatedByDict( + id=user.id, + user=user.session_id, + ) return WorkflowFinishStreamResponse( task_id=task_id, diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 3f9f3da9b2..50aed37163 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -6,6 +6,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.datasource.vdb.vector_factory import Vector from extensions.ext_database import db from models.dataset import Dataset +from models.enums import CollectionBindingType from models.model import App, AppAnnotationSetting, Message, MessageAnnotation from services.annotation_service import AppAnnotationService from services.dataset_service import DatasetCollectionBindingService @@ -43,7 +44,7 @@ class AnnotationReplyFeature: embedding_model_name = collection_binding_detail.model_name dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_provider_name, embedding_model_name, "annotation" + embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION ) dataset = Dataset( diff --git a/api/core/app/task_pipeline/message_file_utils.py b/api/core/app/task_pipeline/message_file_utils.py index 843e9eea30..fc8b6c6b5a 100644 --- a/api/core/app/task_pipeline/message_file_utils.py +++ b/api/core/app/task_pipeline/message_file_utils.py @@ -1,3 +1,5 @@ +from typing import TypedDict + from core.tools.signature import sign_tool_file from dify_graph.file import helpers as file_helpers from dify_graph.file.enums import FileTransferMethod @@ -6,7 +8,20 @@ from models.model import MessageFile, UploadFile MAX_TOOL_FILE_EXTENSION_LENGTH = 10 -def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict: +class MessageFileInfoDict(TypedDict): + related_id: str + extension: str + filename: str + size: int + mime_type: str + transfer_method: str + type: str + url: str + upload_file_id: str + remote_url: str | None + + +def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> MessageFileInfoDict: """ Prepare file dictionary for message end stream response. diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index b054409681..8de5cb1690 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -12,7 +12,7 @@ from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import ChildChunk, DatasetQuery, DocumentSegment from models.dataset import Document as DatasetDocument -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, DatasetQuerySource _logger = logging.getLogger(__name__) @@ -36,7 +36,7 @@ class DatasetIndexToolCallbackHandler: dataset_query = DatasetQuery( dataset_id=dataset_id, content=query, - source="app", + source=DatasetQuerySource.APP, source_app_id=self._app_id, created_by_role=( CreatorUserRole.ACCOUNT diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 0279725ff2..c6a270e470 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -30,6 +30,7 @@ from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from libs.datetime_utils import naive_utc_now from models.engine import db +from models.enums import CredentialSourceType from models.provider import ( LoadBalancingModelConfig, Provider, @@ -546,7 +547,7 @@ class ProviderConfiguration(BaseModel): self._update_load_balancing_configs_with_credential( credential_id=credential_id, credential_record=credential_record, - credential_source="provider", + credential_source=CredentialSourceType.PROVIDER, session=session, ) except Exception: @@ -623,7 +624,7 @@ class ProviderConfiguration(BaseModel): LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.credential_id == credential_id, - LoadBalancingModelConfig.credential_source_type == "provider", + LoadBalancingModelConfig.credential_source_type == CredentialSourceType.PROVIDER, ) lb_configs_using_credential = session.execute(lb_stmt).scalars().all() try: @@ -1043,7 +1044,7 @@ class ProviderConfiguration(BaseModel): self._update_load_balancing_configs_with_credential( credential_id=credential_id, credential_record=credential_record, - credential_source="custom_model", + credential_source=CredentialSourceType.CUSTOM_MODEL, session=session, ) except Exception: @@ -1073,7 +1074,7 @@ class ProviderConfiguration(BaseModel): LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.credential_id == credential_id, - LoadBalancingModelConfig.credential_source_type == "custom_model", + LoadBalancingModelConfig.credential_source_type == CredentialSourceType.CUSTOM_MODEL, ) lb_configs_using_credential = session.execute(lb_stmt).scalars().all() @@ -1711,7 +1712,7 @@ class ProviderConfiguration(BaseModel): provider_model_lb_configs = [ config for config in model_setting.load_balancing_configs - if config.credential_source_type != "custom_model" + if config.credential_source_type != CredentialSourceType.CUSTOM_MODEL ] load_balancing_enabled = model_setting.load_balancing_enabled @@ -1769,7 +1770,7 @@ class ProviderConfiguration(BaseModel): custom_model_lb_configs = [ config for config in model_setting.load_balancing_configs - if config.credential_source_type != "provider" + if config.credential_source_type != CredentialSourceType.PROVIDER ] load_balancing_enabled = model_setting.load_balancing_enabled diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 275c1fc110..52776ee626 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -40,6 +40,7 @@ from libs.datetime_utils import naive_utc_now from models import Account from models.dataset import AutomaticRulesConfig, ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument +from models.enums import DataSourceType, IndexingStatus, ProcessRuleMode, SegmentStatus from models.model import UploadFile from services.feature_service import FeatureService @@ -56,7 +57,7 @@ class IndexingRunner: logger.exception("consume document failed") document = db.session.get(DatasetDocument, document_id) if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR error_message = getattr(error, "description", str(error)) document.error = str(error_message) document.stopped_at = naive_utc_now() @@ -219,7 +220,7 @@ class IndexingRunner: if document_segments: for document_segment in document_segments: # transform segment to node - if document_segment.status != "completed": + if document_segment.status != SegmentStatus.COMPLETED: document = Document( page_content=document_segment.content, metadata={ @@ -382,7 +383,7 @@ class IndexingRunner: data_source_info = dataset_document.data_source_info_dict text_docs = [] match dataset_document.data_source_type: - case "upload_file": + case DataSourceType.UPLOAD_FILE: if not data_source_info or "upload_file_id" not in data_source_info: raise ValueError("no upload file found") stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]) @@ -395,7 +396,7 @@ class IndexingRunner: document_model=dataset_document.doc_form, ) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) - case "notion_import": + case DataSourceType.NOTION_IMPORT: if ( not data_source_info or "notion_workspace_id" not in data_source_info @@ -417,7 +418,7 @@ class IndexingRunner: document_model=dataset_document.doc_form, ) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) - case "website_crawl": + case DataSourceType.WEBSITE_CRAWL: if ( not data_source_info or "provider" not in data_source_info @@ -445,7 +446,7 @@ class IndexingRunner: # update document status to splitting self._update_document_index_status( document_id=dataset_document.id, - after_indexing_status="splitting", + after_indexing_status=IndexingStatus.SPLITTING, extra_update_params={ DatasetDocument.parsing_completed_at: naive_utc_now(), }, @@ -545,7 +546,7 @@ class IndexingRunner: Clean the document text according to the processing rules. """ rules: AutomaticRulesConfig | dict[str, Any] - if processing_rule.mode == "automatic": + if processing_rule.mode == ProcessRuleMode.AUTOMATIC: rules = DatasetProcessRule.AUTOMATIC_RULES else: rules = json.loads(processing_rule.rules) if processing_rule.rules else {} @@ -636,7 +637,7 @@ class IndexingRunner: # update document status to completed self._update_document_index_status( document_id=dataset_document.id, - after_indexing_status="completed", + after_indexing_status=IndexingStatus.COMPLETED, extra_update_params={ DatasetDocument.tokens: tokens, DatasetDocument.completed_at: naive_utc_now(), @@ -659,10 +660,10 @@ class IndexingRunner: DocumentSegment.document_id == document_id, DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id.in_(document_ids), - DocumentSegment.status == "indexing", + DocumentSegment.status == SegmentStatus.INDEXING, ).update( { - DocumentSegment.status: "completed", + DocumentSegment.status: SegmentStatus.COMPLETED, DocumentSegment.enabled: True, DocumentSegment.completed_at: naive_utc_now(), } @@ -703,10 +704,10 @@ class IndexingRunner: DocumentSegment.document_id == dataset_document.id, DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(document_ids), - DocumentSegment.status == "indexing", + DocumentSegment.status == SegmentStatus.INDEXING, ).update( { - DocumentSegment.status: "completed", + DocumentSegment.status: SegmentStatus.COMPLETED, DocumentSegment.enabled: True, DocumentSegment.completed_at: naive_utc_now(), } @@ -725,7 +726,7 @@ class IndexingRunner: @staticmethod def _update_document_index_status( - document_id: str, after_indexing_status: str, extra_update_params: dict | None = None + document_id: str, after_indexing_status: IndexingStatus, extra_update_params: dict | None = None ): """ Update the document indexing status. @@ -803,7 +804,7 @@ class IndexingRunner: cur_time = naive_utc_now() self._update_document_index_status( document_id=dataset_document.id, - after_indexing_status="indexing", + after_indexing_status=IndexingStatus.INDEXING, extra_update_params={ DatasetDocument.cleaning_completed_at: cur_time, DatasetDocument.splitting_completed_at: cur_time, @@ -815,7 +816,7 @@ class IndexingRunner: self._update_segments_by_document( dataset_document_id=dataset_document.id, update_params={ - DocumentSegment.status: "indexing", + DocumentSegment.status: SegmentStatus.INDEXING, DocumentSegment.indexing_at: naive_utc_now(), }, ) diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index 2b73ef5f26..33eb5f963a 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,3 +1,5 @@ +from typing_extensions import TypedDict + from core.model_manager import ModelInstance, ModelManager from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.index_processor.constant.query_type import QueryType @@ -10,6 +12,26 @@ from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +class RerankingModelDict(TypedDict): + reranking_provider_name: str + reranking_model_name: str + + +class VectorSettingDict(TypedDict): + vector_weight: float + embedding_provider_name: str + embedding_model_name: str + + +class KeywordSettingDict(TypedDict): + keyword_weight: float + + +class WeightsDict(TypedDict): + vector_setting: VectorSettingDict + keyword_setting: KeywordSettingDict + + class DataPostProcessor: """Interface for data post-processing document.""" @@ -17,8 +39,8 @@ class DataPostProcessor: self, tenant_id: str, reranking_mode: str, - reranking_model: dict | None = None, - weights: dict | None = None, + reranking_model: RerankingModelDict | None = None, + weights: WeightsDict | None = None, reorder_enabled: bool = False, ): self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights) @@ -45,8 +67,8 @@ class DataPostProcessor: self, reranking_mode: str, tenant_id: str, - reranking_model: dict | None = None, - weights: dict | None = None, + reranking_model: RerankingModelDict | None = None, + weights: WeightsDict | None = None, ) -> BaseRerankRunner | None: if reranking_mode == RerankMode.WEIGHTED_SCORE and weights: runner = RerankRunnerFactory.create_rerank_runner( @@ -79,12 +101,14 @@ class DataPostProcessor: return ReorderRunner() return None - def _get_rerank_model_instance(self, tenant_id: str, reranking_model: dict | None) -> ModelInstance | None: + def _get_rerank_model_instance( + self, tenant_id: str, reranking_model: RerankingModelDict | None + ) -> ModelInstance | None: if reranking_model: try: model_manager = ModelManager() - reranking_provider_name = reranking_model.get("reranking_provider_name") - reranking_model_name = reranking_model.get("reranking_model_name") + reranking_provider_name = reranking_model["reranking_provider_name"] + reranking_model_name = reranking_model["reranking_model_name"] if not reranking_provider_name or not reranking_model_name: return None rerank_model_instance = model_manager.get_model_instance( diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index e8a3a05e19..7f6ecc3d3f 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -1,19 +1,20 @@ import concurrent.futures import logging from concurrent.futures import ThreadPoolExecutor -from typing import Any +from typing import Any, NotRequired from flask import Flask, current_app from sqlalchemy import select from sqlalchemy.orm import Session, load_only +from typing_extensions import TypedDict from configs import dify_config from core.db.session_factory import session_factory from core.model_manager import ModelManager -from core.rag.data_post_processor.data_post_processor import DataPostProcessor +from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments +from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments from core.rag.entities.metadata_entities import MetadataCondition from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType @@ -35,7 +36,46 @@ from models.dataset import Document as DatasetDocument from models.model import UploadFile from services.external_knowledge_service import ExternalDatasetService -default_retrieval_model = { + +class SegmentAttachmentResult(TypedDict): + attachment_info: AttachmentInfoDict + segment_id: str + + +class SegmentAttachmentInfoResult(TypedDict): + attachment_id: str + attachment_info: AttachmentInfoDict + segment_id: str + + +class ChildChunkDetail(TypedDict): + id: str + content: str + position: int + score: float + + +class SegmentChildMapDetail(TypedDict): + max_score: float + child_chunks: list[ChildChunkDetail] + + +class SegmentRecord(TypedDict): + segment: DocumentSegment + score: NotRequired[float] + child_chunks: NotRequired[list[ChildChunkDetail]] + files: NotRequired[list[AttachmentInfoDict]] + + +class DefaultRetrievalModelDict(TypedDict): + search_method: RetrievalMethod | str + reranking_enable: bool + reranking_model: RerankingModelDict + top_k: int + score_threshold_enabled: bool + + +default_retrieval_model: DefaultRetrievalModelDict = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -56,9 +96,9 @@ class RetrievalService: query: str, top_k: int = 4, score_threshold: float | None = 0.0, - reranking_model: dict | None = None, + reranking_model: RerankingModelDict | None = None, reranking_mode: str = "reranking_model", - weights: dict | None = None, + weights: WeightsDict | None = None, document_ids_filter: list[str] | None = None, attachment_ids: list | None = None, ): @@ -235,7 +275,7 @@ class RetrievalService: query: str, top_k: int, score_threshold: float | None, - reranking_model: dict | None, + reranking_model: RerankingModelDict | None, all_documents: list, retrieval_method: RetrievalMethod, exceptions: list, @@ -277,8 +317,8 @@ class RetrievalService: if documents: if ( reranking_model - and reranking_model.get("reranking_model_name") - and reranking_model.get("reranking_provider_name") + and reranking_model["reranking_model_name"] + and reranking_model["reranking_provider_name"] and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH ): data_post_processor = DataPostProcessor( @@ -288,8 +328,8 @@ class RetrievalService: model_manager = ModelManager() is_support_vision = model_manager.check_model_support_vision( tenant_id=dataset.tenant_id, - provider=reranking_model.get("reranking_provider_name") or "", - model=reranking_model.get("reranking_model_name") or "", + provider=reranking_model["reranking_provider_name"], + model=reranking_model["reranking_model_name"], model_type=ModelType.RERANK, ) if is_support_vision: @@ -329,7 +369,7 @@ class RetrievalService: query: str, top_k: int, score_threshold: float | None, - reranking_model: dict | None, + reranking_model: RerankingModelDict | None, all_documents: list, retrieval_method: str, exceptions: list, @@ -349,8 +389,8 @@ class RetrievalService: if documents: if ( reranking_model - and reranking_model.get("reranking_model_name") - and reranking_model.get("reranking_provider_name") + and reranking_model["reranking_model_name"] + and reranking_model["reranking_provider_name"] and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH ): data_post_processor = DataPostProcessor( @@ -459,7 +499,7 @@ class RetrievalService: segment_ids: list[str] = [] index_node_segments: list[DocumentSegment] = [] segments: list[DocumentSegment] = [] - attachment_map: dict[str, list[dict[str, Any]]] = {} + attachment_map: dict[str, list[AttachmentInfoDict]] = {} child_chunk_map: dict[str, list[ChildChunk]] = {} doc_segment_map: dict[str, list[str]] = {} segment_summary_map: dict[str, str] = {} # Map segment_id to summary content @@ -544,12 +584,12 @@ class RetrievalService: segment_summary_map[summary.chunk_id] = summary.summary_content include_segment_ids = set() - segment_child_map: dict[str, dict[str, Any]] = {} - records: list[dict[str, Any]] = [] + segment_child_map: dict[str, SegmentChildMapDetail] = {} + records: list[SegmentRecord] = [] for segment in segments: child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, []) - attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, []) + attachment_infos: list[AttachmentInfoDict] = attachment_map.get(segment.id, []) ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id) if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: @@ -560,14 +600,14 @@ class RetrievalService: max_score = summary_score_map.get(segment.id, 0.0) if child_chunks or attachment_infos: - child_chunk_details = [] + child_chunk_details: list[ChildChunkDetail] = [] for child_chunk in child_chunks: child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id) if child_document: child_score = child_document.metadata.get("score", 0.0) else: child_score = 0.0 - child_chunk_detail = { + child_chunk_detail: ChildChunkDetail = { "id": child_chunk.id, "content": child_chunk.content, "position": child_chunk.position, @@ -580,7 +620,7 @@ class RetrievalService: if file_document: max_score = max(max_score, file_document.metadata.get("score", 0.0)) - map_detail = { + map_detail: SegmentChildMapDetail = { "max_score": max_score, "child_chunks": child_chunk_details, } @@ -593,7 +633,7 @@ class RetrievalService: "max_score": summary_score, "child_chunks": [], } - record: dict[str, Any] = { + record: SegmentRecord = { "segment": segment, } records.append(record) @@ -617,19 +657,19 @@ class RetrievalService: if file_doc: max_score = max(max_score, file_doc.metadata.get("score", 0.0)) - record = { + another_record: SegmentRecord = { "segment": segment, "score": max_score, } - records.append(record) + records.append(another_record) # Add child chunks information to records for record in records: if record["segment"].id in segment_child_map: - record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore - record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore + record["child_chunks"] = segment_child_map[record["segment"].id]["child_chunks"] + record["score"] = segment_child_map[record["segment"].id]["max_score"] if record["segment"].id in attachment_map: - record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment] + record["files"] = attachment_map[record["segment"].id] result: list[RetrievalSegments] = [] for record in records: @@ -693,9 +733,9 @@ class RetrievalService: query: str | None = None, top_k: int = 4, score_threshold: float | None = 0.0, - reranking_model: dict | None = None, + reranking_model: RerankingModelDict | None = None, reranking_mode: str = "reranking_model", - weights: dict | None = None, + weights: WeightsDict | None = None, document_ids_filter: list[str] | None = None, attachment_id: str | None = None, ): @@ -807,7 +847,7 @@ class RetrievalService: @classmethod def get_segment_attachment_info( cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session - ) -> dict[str, Any] | None: + ) -> SegmentAttachmentResult | None: upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first() if upload_file: attachment_binding = ( @@ -816,7 +856,7 @@ class RetrievalService: .first() ) if attachment_binding: - attachment_info = { + attachment_info: AttachmentInfoDict = { "id": upload_file.id, "name": upload_file.name, "extension": "." + upload_file.extension, @@ -828,8 +868,10 @@ class RetrievalService: return None @classmethod - def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]: - attachment_infos = [] + def get_segment_attachment_infos( + cls, attachment_ids: list[str], session: Session + ) -> list[SegmentAttachmentInfoResult]: + attachment_infos: list[SegmentAttachmentInfoResult] = [] upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all() if upload_files: upload_file_ids = [upload_file.id for upload_file in upload_files] @@ -843,7 +885,7 @@ class RetrievalService: if attachment_bindings: for upload_file in upload_files: attachment_binding = attachment_binding_map.get(upload_file.id) - attachment_info = { + info: AttachmentInfoDict = { "id": upload_file.id, "name": upload_file.name, "extension": "." + upload_file.extension, @@ -855,7 +897,7 @@ class RetrievalService: attachment_infos.append( { "attachment_id": attachment_binding.attachment_id, - "attachment_info": attachment_info, + "attachment_info": info, "segment_id": attachment_binding.segment_id, } ) diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 5ab03a1380..d29d62c93f 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -5,6 +5,7 @@ This module provides integration with Weaviate vector database for storing and r document embeddings used in retrieval-augmented generation workflows. """ +import atexit import datetime import json import logging @@ -37,6 +38,32 @@ _weaviate_client: weaviate.WeaviateClient | None = None _weaviate_client_lock = threading.Lock() +def _shutdown_weaviate_client() -> None: + """ + Best-effort shutdown hook to close the module-level Weaviate client. + + This is registered with atexit so that HTTP/gRPC resources are released + when the Python interpreter exits. + """ + global _weaviate_client + + # Ensure thread-safety when accessing the shared client instance + with _weaviate_client_lock: + client = _weaviate_client + _weaviate_client = None + + if client is not None: + try: + client.close() + except Exception: + # Best-effort cleanup; log at debug level and ignore errors. + logger.debug("Failed to close Weaviate client during shutdown", exc_info=True) + + +# Register the shutdown hook once per process. +atexit.register(_shutdown_weaviate_client) + + class WeaviateConfig(BaseModel): """ Configuration model for Weaviate connection settings. @@ -85,18 +112,6 @@ class WeaviateVector(BaseVector): self._client = self._init_client(config) self._attributes = attributes - def __del__(self): - """ - Destructor to properly close the Weaviate client connection. - Prevents connection leaks and resource warnings. - """ - if hasattr(self, "_client") and self._client is not None: - try: - self._client.close() - except Exception as e: - # Ignore errors during cleanup as object is being destroyed - logger.warning("Error closing Weaviate client %s", e, exc_info=True) - def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient: """ Initializes and returns a connected Weaviate client. diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py index f6834ab87b..030237559d 100644 --- a/api/core/rag/embedding/retrieval.py +++ b/api/core/rag/embedding/retrieval.py @@ -1,8 +1,18 @@ from pydantic import BaseModel +from typing_extensions import TypedDict from models.dataset import DocumentSegment +class AttachmentInfoDict(TypedDict): + id: str + name: str + extension: str + mime_type: str + source_url: str + size: int + + class RetrievalChildChunk(BaseModel): """Retrieval segments.""" @@ -19,5 +29,5 @@ class RetrievalSegments(BaseModel): segment: DocumentSegment child_chunks: list[RetrievalChildChunk] | None = None score: float | None = None - files: list[dict[str, str | int]] | None = None + files: list[AttachmentInfoDict] | None = None summary: str | None = None # Summary content if retrieved via summary index diff --git a/api/core/rag/index_processor/index_processor.py b/api/core/rag/index_processor/index_processor.py index a7c42c5a4e..d9145023ac 100644 --- a/api/core/rag/index_processor/index_processor.py +++ b/api/core/rag/index_processor/index_processor.py @@ -9,6 +9,7 @@ from flask import current_app from sqlalchemy import delete, func, select from core.db.session_factory import session_factory +from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview from models.dataset import Dataset, Document, DocumentSegment @@ -51,7 +52,7 @@ class IndexProcessor: original_document_id: str, chunks: Mapping[str, Any], batch: Any, - summary_index_setting: dict | None = None, + summary_index_setting: SummaryIndexSettingDict | None = None, ): with session_factory.create_session() as session: document = session.query(Document).filter_by(id=document_id).first() @@ -131,7 +132,12 @@ class IndexProcessor: } def get_preview_output( - self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None + self, + chunks: Any, + dataset_id: str, + document_id: str, + chunk_structure: str, + summary_index_setting: SummaryIndexSettingDict | None, ) -> Preview: doc_language = None with session_factory.create_session() as session: diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index f2191f3702..a435dfc46a 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -7,14 +7,16 @@ import os import re from abc import ABC, abstractmethod from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, NotRequired, Optional from urllib.parse import unquote, urlparse import httpx +from typing_extensions import TypedDict from configs import dify_config from core.entities.knowledge_entities import PreviewDetail from core.helper import ssrf_proxy +from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.index_processor.constant.doc_type import DocType from core.rag.models.document import AttachmentDocument, Document @@ -35,6 +37,13 @@ if TYPE_CHECKING: from core.model_manager import ModelInstance +class SummaryIndexSettingDict(TypedDict): + enable: bool + model_name: NotRequired[str] + model_provider_name: NotRequired[str] + summary_prompt: NotRequired[str] + + class BaseIndexProcessor(ABC): """Interface for extract files.""" @@ -51,7 +60,7 @@ class BaseIndexProcessor(ABC): self, tenant_id: str, preview_texts: list[PreviewDetail], - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, doc_language: str | None = None, ) -> list[PreviewDetail]: """ @@ -98,7 +107,7 @@ class BaseIndexProcessor(ABC): dataset: Dataset, top_k: int, score_threshold: float, - reranking_model: dict, + reranking_model: RerankingModelDict, ) -> list[Document]: raise NotImplementedError diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 9c21dad488..80163b1707 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -14,6 +14,7 @@ from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT from core.model_manager import ModelInstance from core.provider_manager import ProviderManager from core.rag.cleaner.clean_processor import CleanProcessor +from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector @@ -22,7 +23,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType -from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols @@ -175,7 +176,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): dataset: Dataset, top_k: int, score_threshold: float, - reranking_model: dict, + reranking_model: RerankingModelDict, ) -> list[Document]: # Set search parameters. results = RetrievalService.retrieve( @@ -278,7 +279,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): self, tenant_id: str, preview_texts: list[PreviewDetail], - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, doc_language: str | None = None, ) -> list[PreviewDetail]: """ @@ -362,7 +363,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): def generate_summary( tenant_id: str, text: str, - summary_index_setting: dict | None = None, + summary_index_setting: SummaryIndexSettingDict | None = None, segment_id: str | None = None, document_language: str | None = None, ) -> tuple[str, LLMUsage]: diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 367f0aec00..df0761ca73 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -11,6 +11,7 @@ from core.db.session_factory import session_factory from core.entities.knowledge_entities import PreviewDetail from core.model_manager import ModelInstance from core.rag.cleaner.clean_processor import CleanProcessor +from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore @@ -18,7 +19,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType -from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db @@ -215,7 +216,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): dataset: Dataset, top_k: int, score_threshold: float, - reranking_model: dict, + reranking_model: RerankingModelDict, ) -> list[Document]: # Set search parameters. results = RetrievalService.retrieve( @@ -361,7 +362,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): self, tenant_id: str, preview_texts: list[PreviewDetail], - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, doc_language: str | None = None, ) -> list[PreviewDetail]: """ diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 503cce2132..62f88b7760 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -15,13 +15,14 @@ from core.db.session_factory import session_factory from core.entities.knowledge_entities import PreviewDetail from core.llm_generator.llm_generator import LLMGenerator from core.rag.cleaner.clean_processor import CleanProcessor +from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.index_type import IndexStructureType -from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols @@ -185,7 +186,7 @@ class QAIndexProcessor(BaseIndexProcessor): dataset: Dataset, top_k: int, score_threshold: float, - reranking_model: dict, + reranking_model: RerankingModelDict, ): # Set search parameters. results = RetrievalService.retrieve( @@ -244,7 +245,7 @@ class QAIndexProcessor(BaseIndexProcessor): self, tenant_id: str, preview_texts: list[PreviewDetail], - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, doc_language: str | None = None, ) -> list[PreviewDetail]: """ diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 4c96b63f25..c44e9b847b 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -31,7 +31,7 @@ from core.ops.utils import measure_time from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode -from core.rag.data_post_processor.data_post_processor import DataPostProcessor +from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata @@ -83,7 +83,7 @@ from models.dataset import ( ) from models.dataset import Document as DatasetDocument from models.dataset import Document as DocumentModel -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, DatasetQuerySource from services.external_knowledge_service import ExternalDatasetService from services.feature_service import FeatureService @@ -727,8 +727,8 @@ class DatasetRetrieval: top_k: int, score_threshold: float, reranking_mode: str, - reranking_model: dict | None = None, - weights: dict[str, Any] | None = None, + reranking_model: RerankingModelDict | None = None, + weights: WeightsDict | None = None, reranking_enable: bool = True, message_id: str | None = None, metadata_filter_document_ids: dict[str, list[str]] | None = None, @@ -1008,7 +1008,7 @@ class DatasetRetrieval: dataset_query = DatasetQuery( dataset_id=dataset_id, content=json.dumps(contents), - source="app", + source=DatasetQuerySource.APP, source_app_id=app_id, created_by_role=CreatorUserRole(user_from), created_by=user_id, @@ -1181,8 +1181,8 @@ class DatasetRetrieval: hit_callbacks=[hit_callback], return_resource=return_resource, retriever_from=invoke_from.to_source(), - reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"), - reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"), + reranking_provider_name=retrieve_config.reranking_model["reranking_provider_name"], + reranking_model_name=retrieve_config.reranking_model["reranking_model_name"], ) tools.append(tool) @@ -1685,8 +1685,8 @@ class DatasetRetrieval: tenant_id: str, reranking_enable: bool, reranking_mode: str, - reranking_model: dict | None, - weights: dict[str, Any] | None, + reranking_model: RerankingModelDict | None, + weights: WeightsDict | None, top_k: int, score_threshold: float, query: str | None, diff --git a/api/core/rag/summary_index/summary_index.py b/api/core/rag/summary_index/summary_index.py index 79d7821b4e..31d21dbeee 100644 --- a/api/core/rag/summary_index/summary_index.py +++ b/api/core/rag/summary_index/summary_index.py @@ -2,6 +2,7 @@ import concurrent.futures import logging from core.db.session_factory import session_factory +from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary from services.summary_index_service import SummaryIndexService from tasks.generate_summary_index_task import generate_summary_index_task @@ -11,7 +12,11 @@ logger = logging.getLogger(__name__) class SummaryIndex: def generate_and_vectorize_summary( - self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None + self, + dataset_id: str, + document_id: str, + is_preview: bool, + summary_index_setting: SummaryIndexSettingDict | None = None, ) -> None: if is_preview: with session_factory.create_session() as session: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 7f7787b92a..23a877b7e3 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -72,6 +72,11 @@ class ApiProviderControllerItem(TypedDict): controller: ApiToolProviderController +class EmojiIconDict(TypedDict): + background: str + content: str + + class ToolManager: _builtin_provider_lock = Lock() _hardcoded_providers: dict[str, BuiltinToolProviderController] = {} @@ -916,7 +921,7 @@ class ToolManager: ) @classmethod - def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]: + def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict: try: workflow_provider: WorkflowToolProvider | None = ( db.session.query(WorkflowToolProvider) @@ -933,7 +938,7 @@ class ToolManager: return {"background": "#252525", "content": "\ud83d\ude01"} @classmethod - def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]: + def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict: try: api_provider: ApiToolProvider | None = ( db.session.query(ApiToolProvider) @@ -950,7 +955,7 @@ class ToolManager: return {"background": "#252525", "content": "\ud83d\ude01"} @classmethod - def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str: + def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str: try: with Session(db.engine) as session: mcp_service = MCPToolManageService(session=session) @@ -970,7 +975,7 @@ class ToolManager: tenant_id: str, provider_type: ToolProviderType, provider_id: str, - ) -> str | Mapping[str, str]: + ) -> str | EmojiIconDict | dict[str, str]: """ get the tool icon diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 3dbbbe6563..c2b520fa99 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -1,5 +1,4 @@ import threading -from typing import Any from flask import Flask, current_app from pydantic import BaseModel, Field @@ -13,11 +12,12 @@ from core.rag.models.document import Document as RagDocument from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool +from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -default_retrieval_model: dict[str, Any] = { +default_retrieval_model: DefaultRetrievalModelDict = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 057ec41f65..429b7e6622 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -1,9 +1,10 @@ -from typing import Any, cast +from typing import NotRequired, TypedDict, cast from pydantic import BaseModel, Field from sqlalchemy import select from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext @@ -16,7 +17,19 @@ from models.dataset import Dataset from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService -default_retrieval_model: dict[str, Any] = { + +class DefaultRetrievalModelDict(TypedDict): + search_method: RetrievalMethod + reranking_enable: bool + reranking_model: RerankingModelDict + reranking_mode: NotRequired[str] + weights: NotRequired[WeightsDict | None] + score_threshold: NotRequired[float] + top_k: int + score_threshold_enabled: bool + + +default_retrieval_model: DefaultRetrievalModelDict = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -125,7 +138,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): if metadata_condition and not document_ids_filter: return "" # get retrieval model , if the model is not setting , using default - retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model + retrieval_model = dataset.retrieval_model or default_retrieval_model retrieval_resource_list: list[RetrievalSourceMetadata] = [] if dataset.indexing_technique == "economy": # use keyword table query diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index fc2b41d960..f7484b93fb 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -1,4 +1,5 @@ import re +from collections.abc import Mapping from json import dumps as json_dumps from json import loads as json_loads from json.decoder import JSONDecodeError @@ -20,10 +21,18 @@ class InterfaceDict(TypedDict): operation: dict[str, Any] +class OpenAPISpecDict(TypedDict): + openapi: str + info: dict[str, str] + servers: list[dict[str, Any]] + paths: dict[str, Any] + components: dict[str, Any] + + class ApiBasedToolSchemaParser: @staticmethod def parse_openapi_to_tool_bundle( - openapi: dict, extra_info: dict | None = None, warning: dict | None = None + openapi: Mapping[str, Any], extra_info: dict | None = None, warning: dict | None = None ) -> list[ApiToolBundle]: warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} @@ -277,7 +286,7 @@ class ApiBasedToolSchemaParser: @staticmethod def parse_swagger_to_openapi( swagger: dict, extra_info: dict | None = None, warning: dict | None = None - ) -> dict[str, Any]: + ) -> OpenAPISpecDict: warning = warning or {} """ parse swagger to openapi @@ -293,7 +302,7 @@ class ApiBasedToolSchemaParser: if len(servers) == 0: raise ToolApiSchemaError("No server found in the swagger yaml.") - converted_openapi: dict[str, Any] = { + converted_openapi: OpenAPISpecDict = { "openapi": "3.0.0", "info": { "title": info.get("title", "Swagger"), diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 8b00746268..8d2e9bf3cb 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -2,6 +2,7 @@ from typing import Literal, Union from pydantic import BaseModel +from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from dify_graph.entities.base_node_data import BaseNodeData @@ -161,4 +162,4 @@ class KnowledgeIndexNodeData(BaseNodeData): chunk_structure: str index_chunk_variable_selector: list[str] indexing_technique: str | None = None - summary_index_setting: dict | None = None + summary_index_setting: SummaryIndexSettingDict | None = None diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 0a74847bc1..4ea9091c5b 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -3,6 +3,7 @@ from collections.abc import Mapping from typing import TYPE_CHECKING, Any from core.rag.index_processor.index_processor import IndexProcessor +from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.summary_index.summary_index import SummaryIndex from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from dify_graph.entities.graph_config import NodeConfigDict @@ -127,7 +128,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): is_preview: bool, batch: Any, chunks: Mapping[str, Any], - summary_index_setting: dict | None = None, + summary_index_setting: SummaryIndexSettingDict | None = None, ): if not document_id: raise KnowledgeIndexNodeError("document_id is required.") diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 9c3b9aacbf..80f59140be 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -9,6 +9,7 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDict @@ -201,8 +202,8 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: if node_data.multiple_retrieval_config is None: raise ValueError("multiple_retrieval_config is required") - reranking_model = None - weights = None + reranking_model: RerankingModelDict | None = None + weights: WeightsDict | None = None match node_data.multiple_retrieval_config.reranking_mode: case "reranking_model": if node_data.multiple_retrieval_config.reranking_model: diff --git a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py index f964f79582..e1311ab962 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py +++ b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py @@ -2,6 +2,7 @@ from typing import Any, Literal, Protocol from pydantic import BaseModel, Field +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from dify_graph.model_runtime.entities import LLMUsage from dify_graph.nodes.llm.entities import ModelConfig @@ -75,8 +76,8 @@ class KnowledgeRetrievalRequest(BaseModel): top_k: int = Field(default=0, description="Number of top results to return") score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold") reranking_mode: str = Field(default="reranking_model", description="Reranking strategy") - reranking_model: dict | None = Field(default=None, description="Reranking model configuration") - weights: dict[str, Any] | None = Field(default=None, description="Weights for weighted score reranking") + reranking_model: RerankingModelDict | None = Field(default=None, description="Reranking model configuration") + weights: WeightsDict | None = Field(default=None, description="Weights for weighted score reranking") reranking_enable: bool = Field(default=True, description="Whether reranking is enabled") attachment_ids: list[str] | None = Field(default=None, description="List of attachment file IDs for retrieval") diff --git a/api/dify_graph/nodes/http_request/node.py b/api/dify_graph/nodes/http_request/node.py index b17c820a80..486ae241ee 100644 --- a/api/dify_graph/nodes/http_request/node.py +++ b/api/dify_graph/nodes/http_request/node.py @@ -101,7 +101,6 @@ class HttpRequestNode(Node[HttpRequestNodeData]): timeout=self._get_request_timeout(self.node_data), variable_pool=self.graph_runtime_state.variable_pool, http_request_config=self._http_request_config, - max_retries=0, ssl_verify=self.node_data.ssl_verify, http_client=self._http_client, file_manager=self._file_manager, diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 8778f5cafe..76de5a0740 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -10,6 +10,7 @@ from events.document_index_event import document_index_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Document +from models.enums import IndexingStatus logger = logging.getLogger(__name__) @@ -35,7 +36,7 @@ def handle(sender, **kwargs): if not document: raise NotFound("Document not found") - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() documents.append(document) db.session.add(document) diff --git a/api/models/account.py b/api/models/account.py index 1a43c9ca17..5960ac6564 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -177,13 +177,11 @@ class Account(UserMixin, TypeBase): @classmethod def get_by_openid(cls, provider: str, open_id: str): - account_integrate = ( - db.session.query(AccountIntegrate) - .where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) - .one_or_none() - ) + account_integrate = db.session.execute( + select(AccountIntegrate).where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) + ).scalar_one_or_none() if account_integrate: - return db.session.query(Account).where(Account.id == account_integrate.account_id).one_or_none() + return db.session.scalar(select(Account).where(Account.id == account_integrate.account_id)) return None # check current_user.current_tenant.current_role in ['admin', 'owner'] diff --git a/api/models/dataset.py b/api/models/dataset.py index 8438fda25f..d0163e6984 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -8,6 +8,7 @@ import os import pickle import re import time +from collections.abc import Sequence from datetime import datetime from json import JSONDecodeError from typing import Any, TypedDict, cast @@ -30,7 +31,20 @@ from services.entities.knowledge_entities.knowledge_entities import ParentMode, from .account import Account from .base import Base, TypeBase from .engine import db -from .enums import CreatorUserRole +from .enums import ( + CollectionBindingType, + CreatorUserRole, + DatasetMetadataType, + DatasetQuerySource, + DatasetRuntimeMode, + DataSourceType, + DocumentCreatedFrom, + DocumentDocType, + IndexingStatus, + ProcessRuleMode, + SegmentStatus, + SummaryStatus, +) from .model import App, Tag, TagBinding, UploadFile from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index @@ -120,7 +134,7 @@ class Dataset(Base): server_default=sa.text("'only_me'"), default=DatasetPermissionEnum.ONLY_ME, ) - data_source_type = mapped_column(String(255)) + data_source_type = mapped_column(EnumText(DataSourceType, length=255)) indexing_technique: Mapped[str | None] = mapped_column(String(255)) index_struct = mapped_column(LongText, nullable=True) created_by = mapped_column(StringUUID, nullable=False) @@ -137,7 +151,9 @@ class Dataset(Base): summary_index_setting = mapped_column(AdjustedJSON, nullable=True) built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) icon_info = mapped_column(AdjustedJSON, nullable=True) - runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'")) + runtime_mode = mapped_column( + EnumText(DatasetRuntimeMode, length=255), nullable=True, server_default=sa.text("'general'") + ) pipeline_id = mapped_column(StringUUID, nullable=True) chunk_structure = mapped_column(sa.String(255), nullable=True) enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) @@ -145,30 +161,25 @@ class Dataset(Base): @property def total_documents(self): - return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar() + return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0 @property def total_available_documents(self): return ( - db.session.query(func.count(Document.id)) - .where( - Document.dataset_id == self.id, - Document.indexing_status == "completed", - Document.enabled == True, - Document.archived == False, + db.session.scalar( + select(func.count(Document.id)).where( + Document.dataset_id == self.id, + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + ) ) - .scalar() + or 0 ) @property def dataset_keyword_table(self): - dataset_keyword_table = ( - db.session.query(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id).first() - ) - if dataset_keyword_table: - return dataset_keyword_table - - return None + return db.session.scalar(select(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id)) @property def index_struct_dict(self): @@ -195,64 +206,66 @@ class Dataset(Base): @property def latest_process_rule(self): - return ( - db.session.query(DatasetProcessRule) + return db.session.scalar( + select(DatasetProcessRule) .where(DatasetProcessRule.dataset_id == self.id) .order_by(DatasetProcessRule.created_at.desc()) - .first() + .limit(1) ) @property def app_count(self): return ( - db.session.query(func.count(AppDatasetJoin.id)) - .where(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id) - .scalar() + db.session.scalar( + select(func.count(AppDatasetJoin.id)).where( + AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id + ) + ) + or 0 ) @property def document_count(self): - return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar() + return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0 @property def available_document_count(self): return ( - db.session.query(func.count(Document.id)) - .where( - Document.dataset_id == self.id, - Document.indexing_status == "completed", - Document.enabled == True, - Document.archived == False, + db.session.scalar( + select(func.count(Document.id)).where( + Document.dataset_id == self.id, + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + ) ) - .scalar() + or 0 ) @property def available_segment_count(self): return ( - db.session.query(func.count(DocumentSegment.id)) - .where( - DocumentSegment.dataset_id == self.id, - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.dataset_id == self.id, + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + ) ) - .scalar() + or 0 ) @property def word_count(self): - return ( - db.session.query(Document) - .with_entities(func.coalesce(func.sum(Document.word_count), 0)) - .where(Document.dataset_id == self.id) - .scalar() + return db.session.scalar( + select(func.coalesce(func.sum(Document.word_count), 0)).where(Document.dataset_id == self.id) ) @property def doc_form(self) -> str | None: if self.chunk_structure: return self.chunk_structure - document = db.session.query(Document).where(Document.dataset_id == self.id).first() + document = db.session.scalar(select(Document).where(Document.dataset_id == self.id).limit(1)) if document: return document.doc_form return None @@ -270,8 +283,8 @@ class Dataset(Base): @property def tags(self): - tags = ( - db.session.query(Tag) + tags = db.session.scalars( + select(Tag) .join(TagBinding, Tag.id == TagBinding.tag_id) .where( TagBinding.target_id == self.id, @@ -279,8 +292,7 @@ class Dataset(Base): Tag.tenant_id == self.tenant_id, Tag.type == "knowledge", ) - .all() - ) + ).all() return tags or [] @@ -288,8 +300,8 @@ class Dataset(Base): def external_knowledge_info(self): if self.provider != "external": return None - external_knowledge_binding = ( - db.session.query(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id).first() + external_knowledge_binding = db.session.scalar( + select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id) ) if not external_knowledge_binding: return None @@ -310,7 +322,7 @@ class Dataset(Base): @property def is_published(self): if self.pipeline_id: - pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first() + pipeline = db.session.scalar(select(Pipeline).where(Pipeline.id == self.pipeline_id)) if pipeline: return pipeline.is_published return False @@ -382,7 +394,7 @@ class DatasetProcessRule(Base): # bug id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4())) dataset_id = mapped_column(StringUUID, nullable=False) - mode = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'")) + mode = mapped_column(EnumText(ProcessRuleMode, length=255), nullable=False, server_default=sa.text("'automatic'")) rules = mapped_column(LongText, nullable=True) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -428,12 +440,12 @@ class Document(Base): tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) position: Mapped[int] = mapped_column(sa.Integer, nullable=False) - data_source_type: Mapped[str] = mapped_column(String(255), nullable=False) + data_source_type: Mapped[str] = mapped_column(EnumText(DataSourceType, length=255), nullable=False) data_source_info = mapped_column(LongText, nullable=True) dataset_process_rule_id = mapped_column(StringUUID, nullable=True) batch: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) - created_from: Mapped[str] = mapped_column(String(255), nullable=False) + created_from: Mapped[str] = mapped_column(EnumText(DocumentCreatedFrom, length=255), nullable=False) created_by = mapped_column(StringUUID, nullable=False) created_api_request_id = mapped_column(StringUUID, nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -467,7 +479,9 @@ class Document(Base): stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) # basic fields - indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'")) + indexing_status = mapped_column( + EnumText(IndexingStatus, length=255), nullable=False, server_default=sa.text("'waiting'") + ) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) disabled_by = mapped_column(StringUUID, nullable=True) @@ -478,7 +492,7 @@ class Document(Base): updated_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) - doc_type = mapped_column(String(40), nullable=True) + doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True) doc_metadata = mapped_column(AdjustedJSON, nullable=True) doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'")) doc_language = mapped_column(String(255), nullable=True) @@ -521,10 +535,8 @@ class Document(Base): if self.data_source_info: if self.data_source_type == "upload_file": data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info) - file_detail = ( - db.session.query(UploadFile) - .where(UploadFile.id == data_source_info_dict["upload_file_id"]) - .one_or_none() + file_detail = db.session.scalar( + select(UploadFile).where(UploadFile.id == data_source_info_dict["upload_file_id"]) ) if file_detail: return { @@ -557,24 +569,23 @@ class Document(Base): @property def dataset(self): - return db.session.query(Dataset).where(Dataset.id == self.dataset_id).one_or_none() + return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id)) @property def segment_count(self): - return db.session.query(DocumentSegment).where(DocumentSegment.document_id == self.id).count() + return ( + db.session.scalar(select(func.count(DocumentSegment.id)).where(DocumentSegment.document_id == self.id)) or 0 + ) @property def hit_count(self): - return ( - db.session.query(DocumentSegment) - .with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0)) - .where(DocumentSegment.document_id == self.id) - .scalar() + return db.session.scalar( + select(func.coalesce(func.sum(DocumentSegment.hit_count), 0)).where(DocumentSegment.document_id == self.id) ) @property def uploader(self): - user = db.session.query(Account).where(Account.id == self.created_by).first() + user = db.session.scalar(select(Account).where(Account.id == self.created_by)) return user.name if user else None @property @@ -588,14 +599,13 @@ class Document(Base): @property def doc_metadata_details(self) -> list[DocMetadataDetailItem] | None: if self.doc_metadata: - document_metadatas = ( - db.session.query(DatasetMetadata) + document_metadatas = db.session.scalars( + select(DatasetMetadata) .join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id) .where( DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id ) - .all() - ) + ).all() metadata_list: list[DocMetadataDetailItem] = [] for metadata in document_metadatas: metadata_dict: DocMetadataDetailItem = { @@ -791,7 +801,7 @@ class DocumentSegment(Base): enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) disabled_by = mapped_column(StringUUID, nullable=True) - status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'")) + status: Mapped[str] = mapped_column(EnumText(SegmentStatus, length=255), server_default=sa.text("'waiting'")) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) @@ -826,7 +836,7 @@ class DocumentSegment(Base): ) @property - def child_chunks(self) -> list[Any]: + def child_chunks(self) -> Sequence[Any]: if not self.document: return [] process_rule = self.document.dataset_process_rule @@ -835,16 +845,13 @@ class DocumentSegment(Base): if rules_dict: rules = Rule.model_validate(rules_dict) if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: - child_chunks = ( - db.session.query(ChildChunk) - .where(ChildChunk.segment_id == self.id) - .order_by(ChildChunk.position.asc()) - .all() - ) + child_chunks = db.session.scalars( + select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc()) + ).all() return child_chunks or [] return [] - def get_child_chunks(self) -> list[Any]: + def get_child_chunks(self) -> Sequence[Any]: if not self.document: return [] process_rule = self.document.dataset_process_rule @@ -853,12 +860,9 @@ class DocumentSegment(Base): if rules_dict: rules = Rule.model_validate(rules_dict) if rules.parent_mode: - child_chunks = ( - db.session.query(ChildChunk) - .where(ChildChunk.segment_id == self.id) - .order_by(ChildChunk.position.asc()) - .all() - ) + child_chunks = db.session.scalars( + select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc()) + ).all() return child_chunks or [] return [] @@ -1007,15 +1011,15 @@ class ChildChunk(Base): @property def dataset(self): - return db.session.query(Dataset).where(Dataset.id == self.dataset_id).first() + return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id)) @property def document(self): - return db.session.query(Document).where(Document.id == self.document_id).first() + return db.session.scalar(select(Document).where(Document.id == self.document_id)) @property def segment(self): - return db.session.query(DocumentSegment).where(DocumentSegment.id == self.segment_id).first() + return db.session.scalar(select(DocumentSegment).where(DocumentSegment.id == self.segment_id)) class AppDatasetJoin(TypeBase): @@ -1061,7 +1065,7 @@ class DatasetQuery(TypeBase): ) dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) content: Mapped[str] = mapped_column(LongText, nullable=False) - source: Mapped[str] = mapped_column(String(255), nullable=False) + source: Mapped[str] = mapped_column(EnumText(DatasetQuerySource, length=255), nullable=False) source_app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) @@ -1076,7 +1080,7 @@ class DatasetQuery(TypeBase): if isinstance(queries, list): for query in queries: if query["content_type"] == QueryType.IMAGE_QUERY: - file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first() + file_info = db.session.scalar(select(UploadFile).where(UploadFile.id == query["content"])) if file_info: query["file_info"] = { "id": file_info.id, @@ -1141,7 +1145,7 @@ class DatasetKeywordTable(TypeBase): super().__init__(object_hook=object_hook, *args, **kwargs) # get dataset - dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first() + dataset = db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id)) if not dataset: return None if self.data_source_type == "database": @@ -1206,7 +1210,9 @@ class DatasetCollectionBinding(TypeBase): ) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) - type: Mapped[str] = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False) + type: Mapped[str] = mapped_column( + EnumText(CollectionBindingType, length=40), server_default=sa.text("'dataset'"), nullable=False + ) collection_name: Mapped[str] = mapped_column(String(64), nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False @@ -1433,7 +1439,7 @@ class DatasetMetadata(TypeBase): ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(String(255), nullable=False) + type: Mapped[str] = mapped_column(EnumText(DatasetMetadataType, length=255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False @@ -1535,7 +1541,7 @@ class PipelineCustomizedTemplate(TypeBase): @property def created_user_name(self): - account = db.session.query(Account).where(Account.id == self.created_by).first() + account = db.session.scalar(select(Account).where(Account.id == self.created_by)) if account: return account.name return "" @@ -1570,7 +1576,7 @@ class Pipeline(TypeBase): ) def retrieve_dataset(self, session: Session): - return session.query(Dataset).where(Dataset.pipeline_id == self.id).first() + return session.scalar(select(Dataset).where(Dataset.pipeline_id == self.id)) class DocumentPipelineExecutionLog(TypeBase): @@ -1660,7 +1666,9 @@ class DocumentSegmentSummary(Base): summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True) summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True) tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) - status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'")) + status: Mapped[str] = mapped_column( + EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'") + ) error: Mapped[str] = mapped_column(LongText, nullable=True) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) diff --git a/api/models/enums.py b/api/models/enums.py index eb478fe02c..6499c5b443 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -11,6 +11,13 @@ class CreatorUserRole(StrEnum): ACCOUNT = "account" END_USER = "end_user" + @classmethod + def _missing_(cls, value): + if value == "end-user": + return cls.END_USER + else: + return super()._missing_(value) + class WorkflowRunTriggeredFrom(StrEnum): DEBUGGING = "debugging" @@ -96,3 +103,216 @@ class ConversationStatus(StrEnum): """Conversation Status Enum""" NORMAL = "normal" + + +class DataSourceType(StrEnum): + """Data Source Type for Dataset and Document""" + + UPLOAD_FILE = "upload_file" + NOTION_IMPORT = "notion_import" + WEBSITE_CRAWL = "website_crawl" + LOCAL_FILE = "local_file" + ONLINE_DOCUMENT = "online_document" + + +class ProcessRuleMode(StrEnum): + """Dataset Process Rule Mode""" + + AUTOMATIC = "automatic" + CUSTOM = "custom" + HIERARCHICAL = "hierarchical" + + +class IndexingStatus(StrEnum): + """Document Indexing Status""" + + WAITING = "waiting" + PARSING = "parsing" + CLEANING = "cleaning" + SPLITTING = "splitting" + INDEXING = "indexing" + PAUSED = "paused" + COMPLETED = "completed" + ERROR = "error" + + +class DocumentCreatedFrom(StrEnum): + """Document Created From""" + + WEB = "web" + API = "api" + RAG_PIPELINE = "rag-pipeline" + + +class ConversationFromSource(StrEnum): + """Conversation / Message from_source""" + + API = "api" + CONSOLE = "console" + + +class FeedbackFromSource(StrEnum): + """MessageFeedback from_source""" + + USER = "user" + ADMIN = "admin" + + +class InvokeFrom(StrEnum): + """How a conversation/message was invoked""" + + SERVICE_API = "service-api" + WEB_APP = "web-app" + TRIGGER = "trigger" + EXPLORE = "explore" + DEBUGGER = "debugger" + PUBLISHED_PIPELINE = "published" + VALIDATION = "validation" + + @classmethod + def value_of(cls, value: str) -> "InvokeFrom": + return cls(value) + + def to_source(self) -> str: + source_mapping = { + InvokeFrom.WEB_APP: "web_app", + InvokeFrom.DEBUGGER: "dev", + InvokeFrom.EXPLORE: "explore_app", + InvokeFrom.TRIGGER: "trigger", + InvokeFrom.SERVICE_API: "api", + } + return source_mapping.get(self, "dev") + + +class DocumentDocType(StrEnum): + """Document doc_type classification""" + + BOOK = "book" + WEB_PAGE = "web_page" + PAPER = "paper" + SOCIAL_MEDIA_POST = "social_media_post" + WIKIPEDIA_ENTRY = "wikipedia_entry" + PERSONAL_DOCUMENT = "personal_document" + BUSINESS_DOCUMENT = "business_document" + IM_CHAT_LOG = "im_chat_log" + SYNCED_FROM_NOTION = "synced_from_notion" + SYNCED_FROM_GITHUB = "synced_from_github" + OTHERS = "others" + + +class TagType(StrEnum): + """Tag type""" + + KNOWLEDGE = "knowledge" + APP = "app" + + +class DatasetMetadataType(StrEnum): + """Dataset metadata value type""" + + STRING = "string" + NUMBER = "number" + TIME = "time" + + +class SegmentStatus(StrEnum): + """Document segment status""" + + WAITING = "waiting" + INDEXING = "indexing" + COMPLETED = "completed" + ERROR = "error" + PAUSED = "paused" + RE_SEGMENT = "re_segment" + + +class DatasetRuntimeMode(StrEnum): + """Dataset runtime mode""" + + GENERAL = "general" + RAG_PIPELINE = "rag_pipeline" + + +class CollectionBindingType(StrEnum): + """Dataset collection binding type""" + + DATASET = "dataset" + ANNOTATION = "annotation" + + +class DatasetQuerySource(StrEnum): + """Dataset query source""" + + HIT_TESTING = "hit_testing" + APP = "app" + + +class TidbAuthBindingStatus(StrEnum): + """TiDB auth binding status""" + + CREATING = "CREATING" + ACTIVE = "ACTIVE" + + +class MessageFileBelongsTo(StrEnum): + """MessageFile belongs_to""" + + USER = "user" + ASSISTANT = "assistant" + + +class CredentialSourceType(StrEnum): + """Load balancing credential source type""" + + PROVIDER = "provider" + CUSTOM_MODEL = "custom_model" + + +class PaymentStatus(StrEnum): + """Provider order payment status""" + + WAIT_PAY = "wait_pay" + PAID = "paid" + FAILED = "failed" + REFUNDED = "refunded" + + +class BannerStatus(StrEnum): + """ExporleBanner status""" + + ENABLED = "enabled" + DISABLED = "disabled" + + +class SummaryStatus(StrEnum): + """Document segment summary status""" + + NOT_STARTED = "not_started" + GENERATING = "generating" + COMPLETED = "completed" + ERROR = "error" + TIMEOUT = "timeout" + + +class MessageChainType(StrEnum): + """Message chain type""" + + SYSTEM = "system" + + +class ProviderQuotaType(StrEnum): + PAID = "paid" + """hosted paid quota""" + + FREE = "free" + """third-party free quota""" + + TRIAL = "trial" + """hosted trial quota""" + + @staticmethod + def value_of(value: str) -> "ProviderQuotaType": + for member in ProviderQuotaType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") diff --git a/api/models/model.py b/api/models/model.py index 2e747df2c7..fe70fcd401 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -380,13 +380,12 @@ class App(Base): @property def site(self) -> Site | None: - site = db.session.query(Site).where(Site.app_id == self.id).first() - return site + return db.session.scalar(select(Site).where(Site.app_id == self.id)) @property def app_model_config(self) -> AppModelConfig | None: if self.app_model_config_id: - return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() + return db.session.scalar(select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id)) return None @@ -395,7 +394,7 @@ class App(Base): if self.workflow_id: from .workflow import Workflow - return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first() + return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id)) return None @@ -405,8 +404,7 @@ class App(Base): @property def tenant(self) -> Tenant | None: - tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() - return tenant + return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id)) @property def is_agent(self) -> bool: @@ -546,9 +544,9 @@ class App(Base): return deleted_tools @property - def tags(self) -> list[Tag]: - tags = ( - db.session.query(Tag) + def tags(self) -> Sequence[Tag]: + tags = db.session.scalars( + select(Tag) .join(TagBinding, Tag.id == TagBinding.tag_id) .where( TagBinding.target_id == self.id, @@ -556,15 +554,14 @@ class App(Base): Tag.tenant_id == self.tenant_id, Tag.type == "app", ) - .all() - ) + ).all() return tags or [] @property def author_name(self) -> str | None: if self.created_by: - account = db.session.query(Account).where(Account.id == self.created_by).first() + account = db.session.scalar(select(Account).where(Account.id == self.created_by)) if account: return account.name @@ -616,8 +613,7 @@ class AppModelConfig(TypeBase): @property def app(self) -> App | None: - app = db.session.query(App).where(App.id == self.app_id).first() - return app + return db.session.scalar(select(App).where(App.id == self.app_id)) @property def model_dict(self) -> ModelConfig: @@ -652,8 +648,8 @@ class AppModelConfig(TypeBase): @property def annotation_reply_dict(self) -> AnnotationReplyConfig: - annotation_setting = ( - db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first() + annotation_setting = db.session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id) ) if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail @@ -845,8 +841,7 @@ class RecommendedApp(Base): # bug @property def app(self) -> App | None: - app = db.session.query(App).where(App.id == self.app_id).first() - return app + return db.session.scalar(select(App).where(App.id == self.app_id)) class InstalledApp(TypeBase): @@ -873,13 +868,11 @@ class InstalledApp(TypeBase): @property def app(self) -> App | None: - app = db.session.query(App).where(App.id == self.app_id).first() - return app + return db.session.scalar(select(App).where(App.id == self.app_id)) @property def tenant(self) -> Tenant | None: - tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() - return tenant + return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id)) class TrialApp(Base): @@ -899,8 +892,7 @@ class TrialApp(Base): @property def app(self) -> App | None: - app = db.session.query(App).where(App.id == self.app_id).first() - return app + return db.session.scalar(select(App).where(App.id == self.app_id)) class AccountTrialAppRecord(Base): @@ -919,13 +911,11 @@ class AccountTrialAppRecord(Base): @property def app(self) -> App | None: - app = db.session.query(App).where(App.id == self.app_id).first() - return app + return db.session.scalar(select(App).where(App.id == self.app_id)) @property def user(self) -> Account | None: - user = db.session.query(Account).where(Account.id == self.account_id).first() - return user + return db.session.scalar(select(Account).where(Account.id == self.account_id)) class ExporleBanner(TypeBase): @@ -1117,8 +1107,8 @@ class Conversation(Base): else: model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key] else: - app_model_config = ( - db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() + app_model_config = db.session.scalar( + select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id) ) if app_model_config: model_config = app_model_config.to_dict() @@ -1141,36 +1131,43 @@ class Conversation(Base): @property def annotated(self): - return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).count() > 0 + return ( + db.session.scalar( + select(func.count(MessageAnnotation.id)).where(MessageAnnotation.conversation_id == self.id) + ) + or 0 + ) > 0 @property def annotation(self): - return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).first() + return db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).limit(1)) @property def message_count(self): - return db.session.query(Message).where(Message.conversation_id == self.id).count() + return db.session.scalar(select(func.count(Message.id)).where(Message.conversation_id == self.id)) or 0 @property def user_feedback_stats(self): like = ( - db.session.query(MessageFeedback) - .where( - MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == "user", - MessageFeedback.rating == "like", + db.session.scalar( + select(func.count(MessageFeedback.id)).where( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "user", + MessageFeedback.rating == "like", + ) ) - .count() + or 0 ) dislike = ( - db.session.query(MessageFeedback) - .where( - MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == "user", - MessageFeedback.rating == "dislike", + db.session.scalar( + select(func.count(MessageFeedback.id)).where( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "user", + MessageFeedback.rating == "dislike", + ) ) - .count() + or 0 ) return {"like": like, "dislike": dislike} @@ -1178,23 +1175,25 @@ class Conversation(Base): @property def admin_feedback_stats(self): like = ( - db.session.query(MessageFeedback) - .where( - MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == "admin", - MessageFeedback.rating == "like", + db.session.scalar( + select(func.count(MessageFeedback.id)).where( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "admin", + MessageFeedback.rating == "like", + ) ) - .count() + or 0 ) dislike = ( - db.session.query(MessageFeedback) - .where( - MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == "admin", - MessageFeedback.rating == "dislike", + db.session.scalar( + select(func.count(MessageFeedback.id)).where( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "admin", + MessageFeedback.rating == "dislike", + ) ) - .count() + or 0 ) return {"like": like, "dislike": dislike} @@ -1256,22 +1255,19 @@ class Conversation(Base): @property def first_message(self): - return ( - db.session.query(Message) - .where(Message.conversation_id == self.id) - .order_by(Message.created_at.asc()) - .first() + return db.session.scalar( + select(Message).where(Message.conversation_id == self.id).order_by(Message.created_at.asc()) ) @property def app(self) -> App | None: with Session(db.engine, expire_on_commit=False) as session: - return session.query(App).where(App.id == self.app_id).first() + return session.scalar(select(App).where(App.id == self.app_id)) @property def from_end_user_session_id(self): if self.from_end_user_id: - end_user = db.session.query(EndUser).where(EndUser.id == self.from_end_user_id).first() + end_user = db.session.scalar(select(EndUser).where(EndUser.id == self.from_end_user_id)) if end_user: return end_user.session_id @@ -1280,7 +1276,7 @@ class Conversation(Base): @property def from_account_name(self) -> str | None: if self.from_account_id: - account = db.session.query(Account).where(Account.id == self.from_account_id).first() + account = db.session.scalar(select(Account).where(Account.id == self.from_account_id)) if account: return account.name @@ -1505,21 +1501,15 @@ class Message(Base): @property def user_feedback(self): - feedback = ( - db.session.query(MessageFeedback) - .where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user") - .first() + return db.session.scalar( + select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user") ) - return feedback @property def admin_feedback(self): - feedback = ( - db.session.query(MessageFeedback) - .where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin") - .first() + return db.session.scalar( + select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin") ) - return feedback @property def feedbacks(self): @@ -1528,28 +1518,27 @@ class Message(Base): @property def annotation(self): - annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == self.id).first() + annotation = db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.message_id == self.id)) return annotation @property def annotation_hit_history(self): - annotation_history = ( - db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id).first() + annotation_history = db.session.scalar( + select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id) ) if annotation_history: - annotation = ( - db.session.query(MessageAnnotation) - .where(MessageAnnotation.id == annotation_history.annotation_id) - .first() + return db.session.scalar( + select(MessageAnnotation).where(MessageAnnotation.id == annotation_history.annotation_id) ) - return annotation return None @property def app_model_config(self): - conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first() + conversation = db.session.scalar(select(Conversation).where(Conversation.id == self.conversation_id)) if conversation: - return db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first() + return db.session.scalar( + select(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id) + ) return None @@ -1562,13 +1551,12 @@ class Message(Base): return json.loads(self.message_metadata) if self.message_metadata else {} @property - def agent_thoughts(self) -> list[MessageAgentThought]: - return ( - db.session.query(MessageAgentThought) + def agent_thoughts(self) -> Sequence[MessageAgentThought]: + return db.session.scalars( + select(MessageAgentThought) .where(MessageAgentThought.message_id == self.id) .order_by(MessageAgentThought.position.asc()) - .all() - ) + ).all() @property def retriever_resources(self) -> Any: @@ -1579,7 +1567,7 @@ class Message(Base): from factories import file_factory message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all() - current_app = db.session.query(App).where(App.id == self.app_id).first() + current_app = db.session.scalar(select(App).where(App.id == self.app_id)) if not current_app: raise ValueError(f"App {self.app_id} not found") @@ -1743,8 +1731,7 @@ class MessageFeedback(TypeBase): @property def from_account(self) -> Account | None: - account = db.session.query(Account).where(Account.id == self.from_account_id).first() - return account + return db.session.scalar(select(Account).where(Account.id == self.from_account_id)) def to_dict(self) -> MessageFeedbackDict: return { @@ -1817,13 +1804,11 @@ class MessageAnnotation(Base): @property def account(self): - account = db.session.query(Account).where(Account.id == self.account_id).first() - return account + return db.session.scalar(select(Account).where(Account.id == self.account_id)) @property def annotation_create_account(self): - account = db.session.query(Account).where(Account.id == self.account_id).first() - return account + return db.session.scalar(select(Account).where(Account.id == self.account_id)) class AppAnnotationHitHistory(TypeBase): @@ -1852,18 +1837,15 @@ class AppAnnotationHitHistory(TypeBase): @property def account(self): - account = ( - db.session.query(Account) + return db.session.scalar( + select(Account) .join(MessageAnnotation, MessageAnnotation.account_id == Account.id) .where(MessageAnnotation.id == self.annotation_id) - .first() ) - return account @property def annotation_create_account(self): - account = db.session.query(Account).where(Account.id == self.account_id).first() - return account + return db.session.scalar(select(Account).where(Account.id == self.account_id)) class AppAnnotationSetting(TypeBase): @@ -1896,12 +1878,9 @@ class AppAnnotationSetting(TypeBase): def collection_binding_detail(self): from .dataset import DatasetCollectionBinding - collection_binding_detail = ( - db.session.query(DatasetCollectionBinding) - .where(DatasetCollectionBinding.id == self.collection_binding_id) - .first() + return db.session.scalar( + select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == self.collection_binding_id) ) - return collection_binding_detail class OperationLog(TypeBase): @@ -2007,7 +1986,9 @@ class AppMCPServer(TypeBase): def generate_server_code(n: int) -> str: while True: result = generate_string(n) - while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0: + while ( + db.session.scalar(select(func.count(AppMCPServer.id)).where(AppMCPServer.server_code == result)) or 0 + ) > 0: result = generate_string(n) return result @@ -2068,7 +2049,7 @@ class Site(Base): def generate_code(n: int) -> str: while True: result = generate_string(n) - while db.session.query(Site).where(Site.code == result).count() > 0: + while (db.session.scalar(select(func.count(Site.id)).where(Site.code == result)) or 0) > 0: result = generate_string(n) return result diff --git a/api/models/provider.py b/api/models/provider.py index 18a0fe92c8..4e114bb034 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -6,13 +6,14 @@ from functools import cached_property from uuid import uuid4 import sqlalchemy as sa -from sqlalchemy import DateTime, String, func, text +from sqlalchemy import DateTime, String, func, select, text from sqlalchemy.orm import Mapped, mapped_column from libs.uuid_utils import uuidv7 from .base import TypeBase from .engine import db +from .enums import CredentialSourceType, PaymentStatus from .types import EnumText, LongText, StringUUID @@ -96,7 +97,7 @@ class Provider(TypeBase): @cached_property def credential(self): if self.credential_id: - return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first() + return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id)) @property def credential_name(self): @@ -159,10 +160,8 @@ class ProviderModel(TypeBase): @cached_property def credential(self): if self.credential_id: - return ( - db.session.query(ProviderModelCredential) - .where(ProviderModelCredential.id == self.credential_id) - .first() + return db.session.scalar( + select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id) ) @property @@ -239,7 +238,9 @@ class ProviderOrder(TypeBase): quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1")) currency: Mapped[str | None] = mapped_column(String(40)) total_amount: Mapped[int | None] = mapped_column(sa.Integer) - payment_status: Mapped[str] = mapped_column(String(40), nullable=False, server_default=text("'wait_pay'")) + payment_status: Mapped[PaymentStatus] = mapped_column( + EnumText(PaymentStatus, length=40), nullable=False, server_default=text("'wait_pay'") + ) paid_at: Mapped[datetime | None] = mapped_column(DateTime) pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime) refunded_at: Mapped[datetime | None] = mapped_column(DateTime) @@ -302,7 +303,9 @@ class LoadBalancingModelConfig(TypeBase): name: Mapped[str] = mapped_column(String(255), nullable=False) encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) - credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True, default=None) + credential_source_type: Mapped[CredentialSourceType | None] = mapped_column( + EnumText(CredentialSourceType, length=40), nullable=True, default=None + ) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False diff --git a/api/models/tools.py b/api/models/tools.py index e7b98dcf27..c09f054e7d 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -8,7 +8,7 @@ from uuid import uuid4 import sqlalchemy as sa from deprecated import deprecated -from sqlalchemy import ForeignKey, String, func +from sqlalchemy import ForeignKey, String, func, select from sqlalchemy.orm import Mapped, mapped_column from core.tools.entities.common_entities import I18nObject @@ -184,11 +184,11 @@ class ApiToolProvider(TypeBase): def user(self) -> Account | None: if not self.user_id: return None - return db.session.query(Account).where(Account.id == self.user_id).first() + return db.session.scalar(select(Account).where(Account.id == self.user_id)) @property def tenant(self) -> Tenant | None: - return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() + return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id)) class ToolLabelBinding(TypeBase): @@ -262,11 +262,11 @@ class WorkflowToolProvider(TypeBase): @property def user(self) -> Account | None: - return db.session.query(Account).where(Account.id == self.user_id).first() + return db.session.scalar(select(Account).where(Account.id == self.user_id)) @property def tenant(self) -> Tenant | None: - return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() + return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id)) @property def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]: @@ -277,7 +277,7 @@ class WorkflowToolProvider(TypeBase): @property def app(self) -> App | None: - return db.session.query(App).where(App.id == self.app_id).first() + return db.session.scalar(select(App).where(App.id == self.app_id)) class MCPToolProvider(TypeBase): @@ -334,7 +334,7 @@ class MCPToolProvider(TypeBase): encrypted_headers: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) def load_user(self) -> Account | None: - return db.session.query(Account).where(Account.id == self.user_id).first() + return db.session.scalar(select(Account).where(Account.id == self.user_id)) @property def credentials(self) -> dict[str, Any]: diff --git a/api/models/web.py b/api/models/web.py index a1cc11c375..1fb37340d7 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -2,7 +2,7 @@ from datetime import datetime from uuid import uuid4 import sqlalchemy as sa -from sqlalchemy import DateTime, func +from sqlalchemy import DateTime, func, select from sqlalchemy.orm import Mapped, mapped_column from .base import TypeBase @@ -38,7 +38,7 @@ class SavedMessage(TypeBase): @property def message(self): - return db.session.query(Message).where(Message.id == self.message_id).first() + return db.session.scalar(select(Message).where(Message.id == self.message_id)) class PinnedConversation(TypeBase): diff --git a/api/models/workflow.py b/api/models/workflow.py index f2e8305758..9bb249481f 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -679,14 +679,14 @@ class WorkflowRun(Base): def message(self): from .model import Message - return ( - db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first() + return db.session.scalar( + select(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id) ) @property @deprecated("This method is retained for historical reasons; avoid using it if possible.") def workflow(self): - return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first() + return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id)) def to_dict(self): return { diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt index 9a76de1927..ad3c1e8389 100644 --- a/api/pyrefly-local-excludes.txt +++ b/api/pyrefly-local-excludes.txt @@ -182,4 +182,9 @@ tasks/app_generate/workflow_execute_task.py tasks/regenerate_summary_index_task.py tasks/trigger_processing_tasks.py tasks/workflow_cfs_scheduler/cfs_scheduler.py +tasks/add_document_to_index_task.py +tasks/create_segment_to_index_task.py +tasks/disable_segment_from_index_task.py +tasks/enable_segment_to_index_task.py +tasks/remove_document_from_index_task.py tasks/workflow_execution_tasks.py diff --git a/api/services/agent_service.py b/api/services/agent_service.py index b2db895a5a..2b8a3ee594 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager from extensions.ext_database import db from libs.login import current_user from models import Account -from models.model import App, Conversation, EndUser, Message, MessageAgentThought +from models.model import App, Conversation, EndUser, Message class AgentService: @@ -47,7 +47,7 @@ class AgentService: if not message: raise ValueError(f"Message not found: {message_id}") - agent_thoughts: list[MessageAgentThought] = message.agent_thoughts + agent_thoughts = message.agent_thoughts if conversation.from_end_user_id: # only select name field diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index c527c71d7b..cdab90a3dc 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -51,6 +51,14 @@ from models.dataset import ( Pipeline, SegmentAttachmentBinding, ) +from models.enums import ( + DatasetRuntimeMode, + DataSourceType, + DocumentCreatedFrom, + IndexingStatus, + ProcessRuleMode, + SegmentStatus, +) from models.model import UploadFile from models.provider_ids import ModelProviderID from models.source import DataSourceOauthBinding @@ -319,7 +327,7 @@ class DatasetService: description=rag_pipeline_dataset_create_entity.description, permission=rag_pipeline_dataset_create_entity.permission, provider="vendor", - runtime_mode="rag_pipeline", + runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(), created_by=current_user.id, pipeline_id=pipeline.id, @@ -614,7 +622,7 @@ class DatasetService: """ Update pipeline knowledge base node data. """ - if dataset.runtime_mode != "rag_pipeline": + if dataset.runtime_mode != DatasetRuntimeMode.RAG_PIPELINE: return pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first() @@ -1229,10 +1237,15 @@ class DocumentService: "enabled": "available", } - _INDEXING_STATUSES: tuple[str, ...] = ("parsing", "cleaning", "splitting", "indexing") + _INDEXING_STATUSES: tuple[IndexingStatus, ...] = ( + IndexingStatus.PARSING, + IndexingStatus.CLEANING, + IndexingStatus.SPLITTING, + IndexingStatus.INDEXING, + ) DISPLAY_STATUS_FILTERS: dict[str, tuple[Any, ...]] = { - "queuing": (Document.indexing_status == "waiting",), + "queuing": (Document.indexing_status == IndexingStatus.WAITING,), "indexing": ( Document.indexing_status.in_(_INDEXING_STATUSES), Document.is_paused.is_not(True), @@ -1241,19 +1254,19 @@ class DocumentService: Document.indexing_status.in_(_INDEXING_STATUSES), Document.is_paused.is_(True), ), - "error": (Document.indexing_status == "error",), + "error": (Document.indexing_status == IndexingStatus.ERROR,), "available": ( - Document.indexing_status == "completed", + Document.indexing_status == IndexingStatus.COMPLETED, Document.archived.is_(False), Document.enabled.is_(True), ), "disabled": ( - Document.indexing_status == "completed", + Document.indexing_status == IndexingStatus.COMPLETED, Document.archived.is_(False), Document.enabled.is_(False), ), "archived": ( - Document.indexing_status == "completed", + Document.indexing_status == IndexingStatus.COMPLETED, Document.archived.is_(True), ), } @@ -1536,7 +1549,7 @@ class DocumentService: """ Normalize and validate `Document -> UploadFile` linkage for download flows. """ - if document.data_source_type != "upload_file": + if document.data_source_type != DataSourceType.UPLOAD_FILE: raise NotFound(invalid_source_message) data_source_info: dict[str, Any] = document.data_source_info_dict or {} @@ -1617,7 +1630,7 @@ class DocumentService: select(Document).where( Document.id.in_(document_ids), Document.enabled == True, - Document.indexing_status == "completed", + Document.indexing_status == IndexingStatus.COMPLETED, Document.archived == False, ) ).all() @@ -1640,7 +1653,7 @@ class DocumentService: select(Document).where( Document.dataset_id == dataset_id, Document.enabled == True, - Document.indexing_status == "completed", + Document.indexing_status == IndexingStatus.COMPLETED, Document.archived == False, ) ).all() @@ -1650,7 +1663,10 @@ class DocumentService: @staticmethod def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]: documents = db.session.scalars( - select(Document).where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) + select(Document).where( + Document.dataset_id == dataset_id, + Document.indexing_status.in_([IndexingStatus.ERROR, IndexingStatus.PAUSED]), + ) ).all() return documents @@ -1683,7 +1699,7 @@ class DocumentService: def delete_document(document): # trigger document_was_deleted signal file_id = None - if document.data_source_type == "upload_file": + if document.data_source_type == DataSourceType.UPLOAD_FILE: if document.data_source_info: data_source_info = document.data_source_info_dict if data_source_info and "upload_file_id" in data_source_info: @@ -1704,7 +1720,7 @@ class DocumentService: file_ids = [ document.data_source_info_dict.get("upload_file_id", "") for document in documents - if document.data_source_type == "upload_file" and document.data_source_info_dict + if document.data_source_type == DataSourceType.UPLOAD_FILE and document.data_source_info_dict ] # Delete documents first, then dispatch cleanup task after commit @@ -1753,7 +1769,13 @@ class DocumentService: @staticmethod def pause_document(document): - if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}: + if document.indexing_status not in { + IndexingStatus.WAITING, + IndexingStatus.PARSING, + IndexingStatus.CLEANING, + IndexingStatus.SPLITTING, + IndexingStatus.INDEXING, + }: raise DocumentIndexingError() # update document to be paused assert current_user is not None @@ -1793,7 +1815,7 @@ class DocumentService: if cache_result is not None: raise ValueError("Document is being retried, please try again later") # retry document indexing - document.indexing_status = "waiting" + document.indexing_status = IndexingStatus.WAITING db.session.add(document) db.session.commit() @@ -1812,7 +1834,7 @@ class DocumentService: if cache_result is not None: raise ValueError("Document is being synced, please try again later") # sync document indexing - document.indexing_status = "waiting" + document.indexing_status = IndexingStatus.WAITING data_source_info = document.data_source_info_dict if data_source_info: data_source_info["mode"] = "scrape" @@ -1840,7 +1862,7 @@ class DocumentService: knowledge_config: KnowledgeConfig, account: Account | Any, dataset_process_rule: DatasetProcessRule | None = None, - created_from: str = "web", + created_from: str = DocumentCreatedFrom.WEB, ) -> tuple[list[Document], str]: # check doc_form DatasetService.check_doc_form(dataset, knowledge_config.doc_form) @@ -1932,7 +1954,7 @@ class DocumentService: if not dataset_process_rule: process_rule = knowledge_config.process_rule if process_rule: - if process_rule.mode in ("custom", "hierarchical"): + if process_rule.mode in (ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL): if process_rule.rules: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, @@ -1944,7 +1966,7 @@ class DocumentService: dataset_process_rule = dataset.latest_process_rule if not dataset_process_rule: raise ValueError("No process rule found.") - elif process_rule.mode == "automatic": + elif process_rule.mode == ProcessRuleMode.AUTOMATIC: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule.mode, @@ -1967,7 +1989,7 @@ class DocumentService: if not dataset_process_rule: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode="automatic", + mode=ProcessRuleMode.AUTOMATIC, rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) @@ -2001,7 +2023,7 @@ class DocumentService: .where( Document.dataset_id == dataset.id, Document.tenant_id == current_user.current_tenant_id, - Document.data_source_type == "upload_file", + Document.data_source_type == DataSourceType.UPLOAD_FILE, Document.enabled == True, Document.name.in_(file_names), ) @@ -2021,7 +2043,7 @@ class DocumentService: document.doc_language = knowledge_config.doc_language document.data_source_info = json.dumps(data_source_info) document.batch = batch - document.indexing_status = "waiting" + document.indexing_status = IndexingStatus.WAITING db.session.add(document) documents.append(document) duplicate_document_ids.append(document.id) @@ -2056,7 +2078,7 @@ class DocumentService: .filter_by( dataset_id=dataset.id, tenant_id=current_user.current_tenant_id, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, enabled=True, ) .all() @@ -2507,7 +2529,7 @@ class DocumentService: document_data: KnowledgeConfig, account: Account, dataset_process_rule: DatasetProcessRule | None = None, - created_from: str = "web", + created_from: str = DocumentCreatedFrom.WEB, ): assert isinstance(current_user, Account) @@ -2520,14 +2542,14 @@ class DocumentService: # save process rule if document_data.process_rule: process_rule = document_data.process_rule - if process_rule.mode in {"custom", "hierarchical"}: + if process_rule.mode in {ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL}: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule.mode, rules=process_rule.rules.model_dump_json() if process_rule.rules else None, created_by=account.id, ) - elif process_rule.mode == "automatic": + elif process_rule.mode == ProcessRuleMode.AUTOMATIC: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule.mode, @@ -2609,7 +2631,7 @@ class DocumentService: if document_data.name: document.name = document_data.name # update document to be waiting - document.indexing_status = "waiting" + document.indexing_status = IndexingStatus.WAITING document.completed_at = None document.processing_started_at = None document.parsing_completed_at = None @@ -2623,7 +2645,7 @@ class DocumentService: # update document segment db.session.query(DocumentSegment).filter_by(document_id=document.id).update( - {DocumentSegment.status: "re_segment"} + {DocumentSegment.status: SegmentStatus.RE_SEGMENT} ) db.session.commit() # trigger async task @@ -2754,7 +2776,7 @@ class DocumentService: if knowledge_config.process_rule.mode not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if knowledge_config.process_rule.mode == "automatic": + if knowledge_config.process_rule.mode == ProcessRuleMode.AUTOMATIC: knowledge_config.process_rule.rules = None else: if not knowledge_config.process_rule.rules: @@ -2785,7 +2807,7 @@ class DocumentService: raise ValueError("Process rule segmentation separator is invalid") if not ( - knowledge_config.process_rule.mode == "hierarchical" + knowledge_config.process_rule.mode == ProcessRuleMode.HIERARCHICAL and knowledge_config.process_rule.rules.parent_mode == "full-doc" ): if not knowledge_config.process_rule.rules.segmentation.max_tokens: @@ -2814,7 +2836,7 @@ class DocumentService: if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if args["process_rule"]["mode"] == "automatic": + if args["process_rule"]["mode"] == ProcessRuleMode.AUTOMATIC: args["process_rule"]["rules"] = {} else: if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: @@ -3021,7 +3043,7 @@ class DocumentService: @staticmethod def _prepare_disable_update(document, user, now): """Prepare updates for disabling a document.""" - if not document.completed_at or document.indexing_status != "completed": + if not document.completed_at or document.indexing_status != IndexingStatus.COMPLETED: raise DocumentIndexingError(f"Document: {document.name} is not completed.") if not document.enabled: @@ -3130,7 +3152,7 @@ class SegmentService: content=content, word_count=len(content), tokens=tokens, - status="completed", + status=SegmentStatus.COMPLETED, indexing_at=naive_utc_now(), completed_at=naive_utc_now(), created_by=current_user.id, @@ -3167,7 +3189,7 @@ class SegmentService: logger.exception("create segment index failed") segment_document.enabled = False segment_document.disabled_at = naive_utc_now() - segment_document.status = "error" + segment_document.status = SegmentStatus.ERROR segment_document.error = str(e) db.session.commit() segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first() @@ -3227,7 +3249,7 @@ class SegmentService: word_count=len(content), tokens=tokens, keywords=segment_item.get("keywords", []), - status="completed", + status=SegmentStatus.COMPLETED, indexing_at=naive_utc_now(), completed_at=naive_utc_now(), created_by=current_user.id, @@ -3259,7 +3281,7 @@ class SegmentService: for segment_document in segment_data_list: segment_document.enabled = False segment_document.disabled_at = naive_utc_now() - segment_document.status = "error" + segment_document.status = SegmentStatus.ERROR segment_document.error = str(e) db.session.commit() return segment_data_list @@ -3405,7 +3427,7 @@ class SegmentService: segment.index_node_hash = segment_hash segment.word_count = len(content) segment.tokens = tokens - segment.status = "completed" + segment.status = SegmentStatus.COMPLETED segment.indexing_at = naive_utc_now() segment.completed_at = naive_utc_now() segment.updated_by = current_user.id @@ -3530,7 +3552,7 @@ class SegmentService: logger.exception("update segment index failed") segment.enabled = False segment.disabled_at = naive_utc_now() - segment.status = "error" + segment.status = SegmentStatus.ERROR segment.error = str(e) db.session.commit() new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first() diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index d85b290534..9993d24c70 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -13,7 +13,7 @@ from dify_graph.model_runtime.entities import LLMMode from extensions.ext_database import db from models import Account from models.dataset import Dataset, DatasetQuery -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, DatasetQuerySource logger = logging.getLogger(__name__) @@ -97,7 +97,7 @@ class HitTestingService: dataset_query = DatasetQuery( dataset_id=dataset.id, content=json.dumps(dataset_queries), - source="hit_testing", + source=DatasetQuerySource.HIT_TESTING, source_app_id=None, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, @@ -137,7 +137,7 @@ class HitTestingService: dataset_query = DatasetQuery( dataset_id=dataset.id, content=query, - source="hit_testing", + source=DatasetQuerySource.HIT_TESTING, source_app_id=None, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index 859fc1902b..2f47a647a8 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -7,6 +7,7 @@ from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding +from models.enums import DatasetMetadataType from services.dataset_service import DocumentService from services.entities.knowledge_entities.knowledge_entities import ( MetadataArgs, @@ -130,11 +131,11 @@ class MetadataService: @staticmethod def get_built_in_fields(): return [ - {"name": BuiltInField.document_name, "type": "string"}, - {"name": BuiltInField.uploader, "type": "string"}, - {"name": BuiltInField.upload_date, "type": "time"}, - {"name": BuiltInField.last_update_date, "type": "time"}, - {"name": BuiltInField.source, "type": "string"}, + {"name": BuiltInField.document_name, "type": DatasetMetadataType.STRING}, + {"name": BuiltInField.uploader, "type": DatasetMetadataType.STRING}, + {"name": BuiltInField.upload_date, "type": DatasetMetadataType.TIME}, + {"name": BuiltInField.last_update_date, "type": DatasetMetadataType.TIME}, + {"name": BuiltInField.source, "type": DatasetMetadataType.STRING}, ] @staticmethod diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 2133dc5b3a..bf3b6db3ed 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -19,6 +19,7 @@ from dify_graph.model_runtime.entities.provider_entities import ( from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from extensions.ext_database import db from libs.datetime_utils import naive_utc_now +from models.enums import CredentialSourceType from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential logger = logging.getLogger(__name__) @@ -103,9 +104,9 @@ class ModelLoadBalancingService: is_load_balancing_enabled = True if config_from == "predefined-model": - credential_source_type = "provider" + credential_source_type = CredentialSourceType.PROVIDER else: - credential_source_type = "custom_model" + credential_source_type = CredentialSourceType.CUSTOM_MODEL # Get load balancing configurations load_balancing_configs = ( @@ -421,7 +422,11 @@ class ModelLoadBalancingService: raise ValueError("Invalid load balancing config name") if credential_id: - credential_source = "provider" if config_from == "predefined-model" else "custom_model" + credential_source = ( + CredentialSourceType.PROVIDER + if config_from == "predefined-model" + else CredentialSourceType.CUSTOM_MODEL + ) assert credential_record is not None load_balancing_model_config = LoadBalancingModelConfig( tenant_id=tenant_id, diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py index f397b28283..07e1b8f20e 100644 --- a/api/services/rag_pipeline/pipeline_generate_service.py +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -6,6 +6,7 @@ from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from models.dataset import Document, Pipeline +from models.enums import IndexingStatus from models.model import Account, App, EndUser from models.workflow import Workflow from services.rag_pipeline.rag_pipeline import RagPipelineService @@ -111,6 +112,6 @@ class PipelineGenerateService: """ document = db.session.query(Document).where(Document.id == document_id).first() if document: - document.indexing_status = "waiting" + document.indexing_status = IndexingStatus.WAITING db.session.add(document) db.session.commit() diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py index 571ca6c7a6..f996db11dc 100644 --- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py @@ -15,7 +15,8 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): Retrieval recommended app from dify official """ - def get_pipeline_template_detail(self, template_id: str): + def get_pipeline_template_detail(self, template_id: str) -> dict | None: + result: dict | None try: result = self.fetch_pipeline_template_detail_from_dify_official(template_id) except Exception as e: @@ -35,17 +36,23 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return PipelineTemplateType.REMOTE @classmethod - def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict | None: + def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict: """ Fetch pipeline template detail from dify official. - :param template_id: Pipeline ID - :return: + + :param template_id: Pipeline template ID + :return: Template detail dict + :raises ValueError: When upstream returns a non-200 status code """ domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN url = f"{domain}/pipeline-templates/{template_id}" response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0)) if response.status_code != 200: - return None + raise ValueError( + "fetch pipeline template detail failed," + + f" status_code: {response.status_code}," + + f" response: {response.text[:1000]}" + ) data: dict = response.json() return data diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 2118043a98..f3aedafac9 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -64,7 +64,7 @@ from models.dataset import ( # type: ignore PipelineCustomizedTemplate, PipelineRecommendedPlugin, ) -from models.enums import WorkflowRunTriggeredFrom +from models.enums import IndexingStatus, WorkflowRunTriggeredFrom from models.model import EndUser from models.workflow import ( Workflow, @@ -117,13 +117,21 @@ class RagPipelineService: def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None: """ Get pipeline template detail. + :param template_id: template id - :return: + :param type: template type, "built-in" or "customized" + :return: template detail dict, or None if not found """ if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id) + if built_in_result is None: + logger.warning( + "pipeline template retrieval returned empty result, template_id: %s, mode: %s", + template_id, + mode, + ) return built_in_result else: mode = "customized" @@ -906,7 +914,7 @@ class RagPipelineService: if document_id: document = db.session.query(Document).where(Document.id == document_id.value).first() if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = error db.session.add(document) db.session.commit() diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index c7da1afe1b..deb59da8d3 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -35,6 +35,7 @@ from extensions.ext_redis import redis_client from factories import variable_factory from models import Account from models.dataset import Dataset, DatasetCollectionBinding, Pipeline +from models.enums import CollectionBindingType, DatasetRuntimeMode from models.workflow import Workflow, WorkflowType from services.entities.knowledge_entities.rag_pipeline_entities import ( IconInfo, @@ -313,7 +314,7 @@ class RagPipelineDslService: indexing_technique=knowledge_configuration.indexing_technique, created_by=account.id, retrieval_model=knowledge_configuration.retrieval_model.model_dump(), - runtime_mode="rag_pipeline", + runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, chunk_structure=knowledge_configuration.chunk_structure, ) if knowledge_configuration.indexing_technique == "high_quality": @@ -323,7 +324,7 @@ class RagPipelineDslService: DatasetCollectionBinding.provider_name == knowledge_configuration.embedding_model_provider, DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model, - DatasetCollectionBinding.type == "dataset", + DatasetCollectionBinding.type == CollectionBindingType.DATASET, ) .order_by(DatasetCollectionBinding.created_at) .first() @@ -334,7 +335,7 @@ class RagPipelineDslService: provider_name=knowledge_configuration.embedding_model_provider, model_name=knowledge_configuration.embedding_model, collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), - type="dataset", + type=CollectionBindingType.DATASET, ) self._session.add(dataset_collection_binding) self._session.commit() @@ -445,13 +446,13 @@ class RagPipelineDslService: indexing_technique=knowledge_configuration.indexing_technique, created_by=account.id, retrieval_model=knowledge_configuration.retrieval_model.model_dump(), - runtime_mode="rag_pipeline", + runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, chunk_structure=knowledge_configuration.chunk_structure, ) else: dataset.indexing_technique = knowledge_configuration.indexing_technique dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() - dataset.runtime_mode = "rag_pipeline" + dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE dataset.chunk_structure = knowledge_configuration.chunk_structure if knowledge_configuration.indexing_technique == "high_quality": dataset_collection_binding = ( @@ -460,7 +461,7 @@ class RagPipelineDslService: DatasetCollectionBinding.provider_name == knowledge_configuration.embedding_model_provider, DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model, - DatasetCollectionBinding.type == "dataset", + DatasetCollectionBinding.type == CollectionBindingType.DATASET, ) .order_by(DatasetCollectionBinding.created_at) .first() @@ -471,7 +472,7 @@ class RagPipelineDslService: provider_name=knowledge_configuration.embedding_model_provider, model_name=knowledge_configuration.embedding_model, collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), - type="dataset", + type=CollectionBindingType.DATASET, ) self._session.add(dataset_collection_binding) self._session.commit() diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index cee18387b3..1d0aafd5fd 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -13,6 +13,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from factories import variable_factory from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline +from models.enums import DatasetRuntimeMode, DataSourceType from models.model import UploadFile from models.workflow import Workflow, WorkflowType from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration, RetrievalSetting @@ -27,7 +28,7 @@ class RagPipelineTransformService: dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: raise ValueError("Dataset not found") - if dataset.pipeline_id and dataset.runtime_mode == "rag_pipeline": + if dataset.pipeline_id and dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE: return { "pipeline_id": dataset.pipeline_id, "dataset_id": dataset_id, @@ -85,7 +86,7 @@ class RagPipelineTransformService: else: raise ValueError("Unsupported doc form") - dataset.runtime_mode = "rag_pipeline" + dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE dataset.pipeline_id = pipeline.id # deal document data @@ -102,7 +103,7 @@ class RagPipelineTransformService: pipeline_yaml = {} if doc_form == "text_model": match datasource_type: - case "upload_file": + case DataSourceType.UPLOAD_FILE: if indexing_technique == "high_quality": # get graph from transform.file-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f: @@ -111,7 +112,7 @@ class RagPipelineTransformService: # get graph from transform.file-general-economy.yml with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) - case "notion_import": + case DataSourceType.NOTION_IMPORT: if indexing_technique == "high_quality": # get graph from transform.notion-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f: @@ -120,7 +121,7 @@ class RagPipelineTransformService: # get graph from transform.notion-general-economy.yml with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) - case "website_crawl": + case DataSourceType.WEBSITE_CRAWL: if indexing_technique == "high_quality": # get graph from transform.website-crawl-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f: @@ -133,15 +134,15 @@ class RagPipelineTransformService: raise ValueError("Unsupported datasource type") elif doc_form == "hierarchical_model": match datasource_type: - case "upload_file": + case DataSourceType.UPLOAD_FILE: # get graph from transform.file-parentchild.yml with open(f"{Path(__file__).parent}/transform/file-parentchild.yml") as f: pipeline_yaml = yaml.safe_load(f) - case "notion_import": + case DataSourceType.NOTION_IMPORT: # get graph from transform.notion-parentchild.yml with open(f"{Path(__file__).parent}/transform/notion-parentchild.yml") as f: pipeline_yaml = yaml.safe_load(f) - case "website_crawl": + case DataSourceType.WEBSITE_CRAWL: # get graph from transform.website-crawl-parentchild.yml with open(f"{Path(__file__).parent}/transform/website-crawl-parentchild.yml") as f: pipeline_yaml = yaml.safe_load(f) @@ -287,7 +288,7 @@ class RagPipelineTransformService: db.session.flush() dataset.pipeline_id = pipeline.id - dataset.runtime_mode = "rag_pipeline" + dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE dataset.updated_by = current_user.id dataset.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.add(dataset) @@ -310,8 +311,8 @@ class RagPipelineTransformService: data_source_info_dict = document.data_source_info_dict if not data_source_info_dict: continue - if document.data_source_type == "upload_file": - document.data_source_type = "local_file" + if document.data_source_type == DataSourceType.UPLOAD_FILE: + document.data_source_type = DataSourceType.LOCAL_FILE file_id = data_source_info_dict.get("upload_file_id") if file_id: file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() @@ -331,7 +332,7 @@ class RagPipelineTransformService: document_pipeline_execution_log = DocumentPipelineExecutionLog( document_id=document.id, pipeline_id=dataset.pipeline_id, - datasource_type="local_file", + datasource_type=DataSourceType.LOCAL_FILE, datasource_info=data_source_info, input_data={}, created_by=document.created_by, @@ -340,8 +341,8 @@ class RagPipelineTransformService: document_pipeline_execution_log.created_at = document.created_at db.session.add(document) db.session.add(document_pipeline_execution_log) - elif document.data_source_type == "notion_import": - document.data_source_type = "online_document" + elif document.data_source_type == DataSourceType.NOTION_IMPORT: + document.data_source_type = DataSourceType.ONLINE_DOCUMENT data_source_info = json.dumps( { "workspace_id": data_source_info_dict.get("notion_workspace_id"), @@ -359,7 +360,7 @@ class RagPipelineTransformService: document_pipeline_execution_log = DocumentPipelineExecutionLog( document_id=document.id, pipeline_id=dataset.pipeline_id, - datasource_type="online_document", + datasource_type=DataSourceType.ONLINE_DOCUMENT, datasource_info=data_source_info, input_data={}, created_by=document.created_by, @@ -368,8 +369,7 @@ class RagPipelineTransformService: document_pipeline_execution_log.created_at = document.created_at db.session.add(document) db.session.add(document_pipeline_execution_log) - elif document.data_source_type == "website_crawl": - document.data_source_type = "website_crawl" + elif document.data_source_type == DataSourceType.WEBSITE_CRAWL: data_source_info = json.dumps( { "source_url": data_source_info_dict.get("url"), @@ -388,7 +388,7 @@ class RagPipelineTransformService: document_pipeline_execution_log = DocumentPipelineExecutionLog( document_id=document.id, pipeline_id=dataset.pipeline_id, - datasource_type="website_crawl", + datasource_type=DataSourceType.WEBSITE_CRAWL, datasource_info=data_source_info, input_data={}, created_by=document.created_by, diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index eb78be8f88..943dfc972b 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -12,12 +12,14 @@ from core.db.session_factory import session_factory from core.model_manager import ModelManager from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.models.document import Document from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.model_runtime.entities.model_entities import ModelType from libs import helper from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument +from models.enums import SummaryStatus logger = logging.getLogger(__name__) @@ -29,7 +31,7 @@ class SummaryIndexService: def generate_summary_for_segment( segment: DocumentSegment, dataset: Dataset, - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, ) -> tuple[str, LLMUsage]: """ Generate summary for a single segment. @@ -73,7 +75,7 @@ class SummaryIndexService: segment: DocumentSegment, dataset: Dataset, summary_content: str, - status: str = "generating", + status: SummaryStatus = SummaryStatus.GENERATING, ) -> DocumentSegmentSummary: """ Create or update a DocumentSegmentSummary record. @@ -83,7 +85,7 @@ class SummaryIndexService: segment: DocumentSegment to create summary for dataset: Dataset containing the segment summary_content: Generated summary content - status: Summary status (default: "generating") + status: Summary status (default: SummaryStatus.GENERATING) Returns: Created or updated DocumentSegmentSummary instance @@ -326,7 +328,7 @@ class SummaryIndexService: summary_index_node_id=summary_index_node_id, summary_index_node_hash=summary_hash, tokens=embedding_tokens, - status="completed", + status=SummaryStatus.COMPLETED, enabled=True, ) session.add(summary_record_in_session) @@ -362,7 +364,7 @@ class SummaryIndexService: summary_record_in_session.summary_index_node_id = summary_index_node_id summary_record_in_session.summary_index_node_hash = summary_hash summary_record_in_session.tokens = embedding_tokens # Save embedding tokens - summary_record_in_session.status = "completed" + summary_record_in_session.status = SummaryStatus.COMPLETED # Ensure summary_content is preserved (use the latest from summary_record parameter) # This is critical: use the parameter value, not the database value summary_record_in_session.summary_content = summary_content @@ -400,7 +402,7 @@ class SummaryIndexService: summary_record.summary_index_node_id = summary_index_node_id summary_record.summary_index_node_hash = summary_hash summary_record.tokens = embedding_tokens - summary_record.status = "completed" + summary_record.status = SummaryStatus.COMPLETED summary_record.summary_content = summary_content if summary_record_in_session.updated_at: summary_record.updated_at = summary_record_in_session.updated_at @@ -487,7 +489,7 @@ class SummaryIndexService: ) if summary_record_in_session: - summary_record_in_session.status = "error" + summary_record_in_session.status = SummaryStatus.ERROR summary_record_in_session.error = f"Vectorization failed: {str(e)}" summary_record_in_session.updated_at = datetime.now(UTC).replace(tzinfo=None) error_session.add(summary_record_in_session) @@ -498,7 +500,7 @@ class SummaryIndexService: summary_record_in_session.id, ) # Update the original object for consistency - summary_record.status = "error" + summary_record.status = SummaryStatus.ERROR summary_record.error = summary_record_in_session.error summary_record.updated_at = summary_record_in_session.updated_at else: @@ -514,7 +516,7 @@ class SummaryIndexService: def batch_create_summary_records( segments: list[DocumentSegment], dataset: Dataset, - status: str = "not_started", + status: SummaryStatus = SummaryStatus.NOT_STARTED, ) -> None: """ Batch create summary records for segments with specified status. @@ -523,7 +525,7 @@ class SummaryIndexService: Args: segments: List of DocumentSegment instances dataset: Dataset containing the segments - status: Initial status for the records (default: "not_started") + status: Initial status for the records (default: SummaryStatus.NOT_STARTED) """ segment_ids = [segment.id for segment in segments] if not segment_ids: @@ -588,7 +590,7 @@ class SummaryIndexService: ) if summary_record: - summary_record.status = "error" + summary_record.status = SummaryStatus.ERROR summary_record.error = error session.add(summary_record) session.commit() @@ -599,7 +601,7 @@ class SummaryIndexService: def generate_and_vectorize_summary( segment: DocumentSegment, dataset: Dataset, - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, ) -> DocumentSegmentSummary: """ Generate summary for a segment and vectorize it. @@ -631,14 +633,14 @@ class SummaryIndexService: document_id=segment.document_id, chunk_id=segment.id, summary_content="", - status="generating", + status=SummaryStatus.GENERATING, enabled=True, ) session.add(summary_record_in_session) session.flush() # Update status to "generating" - summary_record_in_session.status = "generating" + summary_record_in_session.status = SummaryStatus.GENERATING summary_record_in_session.error = None # type: ignore[assignment] session.add(summary_record_in_session) # Don't flush here - wait until after vectorization succeeds @@ -681,7 +683,7 @@ class SummaryIndexService: except Exception as vectorize_error: # If vectorization fails, update status to error in current session logger.exception("Failed to vectorize summary for segment %s", segment.id) - summary_record_in_session.status = "error" + summary_record_in_session.status = SummaryStatus.ERROR summary_record_in_session.error = f"Vectorization failed: {str(vectorize_error)}" session.add(summary_record_in_session) session.commit() @@ -694,7 +696,7 @@ class SummaryIndexService: session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() ) if summary_record_in_session: - summary_record_in_session.status = "error" + summary_record_in_session.status = SummaryStatus.ERROR summary_record_in_session.error = str(e) session.add(summary_record_in_session) session.commit() @@ -704,7 +706,7 @@ class SummaryIndexService: def generate_summaries_for_document( dataset: Dataset, document: DatasetDocument, - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, segment_ids: list[str] | None = None, only_parent_chunks: bool = False, ) -> list[DocumentSegmentSummary]: @@ -770,7 +772,7 @@ class SummaryIndexService: SummaryIndexService.batch_create_summary_records( segments=segments, dataset=dataset, - status="not_started", + status=SummaryStatus.NOT_STARTED, ) summary_records = [] @@ -1067,7 +1069,7 @@ class SummaryIndexService: # Update summary content summary_record.summary_content = summary_content - summary_record.status = "generating" + summary_record.status = SummaryStatus.GENERATING summary_record.error = None # type: ignore[assignment] # Clear any previous errors session.add(summary_record) # Flush to ensure summary_content is saved before vectorize_summary queries it @@ -1102,7 +1104,7 @@ class SummaryIndexService: # If vectorization fails, update status to error in current session # Don't raise the exception - just log it and return the record with error status # This allows the segment update to complete even if vectorization fails - summary_record.status = "error" + summary_record.status = SummaryStatus.ERROR summary_record.error = f"Vectorization failed: {str(e)}" session.commit() logger.exception("Failed to vectorize summary for segment %s", segment.id) @@ -1112,7 +1114,7 @@ class SummaryIndexService: else: # Create new summary record if doesn't exist summary_record = SummaryIndexService.create_summary_record( - segment, dataset, summary_content, status="generating" + segment, dataset, summary_content, status=SummaryStatus.GENERATING ) # Re-vectorize summary (this will update status to "completed" and tokens in its own session) # Note: summary_record was created in a different session, @@ -1132,7 +1134,7 @@ class SummaryIndexService: # If vectorization fails, update status to error in current session # Merge the record into current session first error_record = session.merge(summary_record) - error_record.status = "error" + error_record.status = SummaryStatus.ERROR error_record.error = f"Vectorization failed: {str(e)}" session.commit() logger.exception("Failed to vectorize summary for segment %s", segment.id) @@ -1146,7 +1148,7 @@ class SummaryIndexService: session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() ) if summary_record: - summary_record.status = "error" + summary_record.status = SummaryStatus.ERROR summary_record.error = str(e) session.add(summary_record) session.commit() @@ -1266,7 +1268,7 @@ class SummaryIndexService: # Check if there are any "not_started" or "generating" status summaries has_pending_summaries = any( summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True) - and summary_status_map[segment_id] in ("not_started", "generating") + and summary_status_map[segment_id] in (SummaryStatus.NOT_STARTED, SummaryStatus.GENERATING) for segment_id in segment_ids ) @@ -1330,7 +1332,7 @@ class SummaryIndexService: # it means the summary is disabled (enabled=False) or not created yet, ignore it has_pending_summaries = any( summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True) - and summary_status_map[segment_id] in ("not_started", "generating") + and summary_status_map[segment_id] in (SummaryStatus.NOT_STARTED, SummaryStatus.GENERATING) for segment_id in segment_ids ) @@ -1393,17 +1395,17 @@ class SummaryIndexService: # Count statuses status_counts = { - "completed": 0, - "generating": 0, - "error": 0, - "not_started": 0, + SummaryStatus.COMPLETED: 0, + SummaryStatus.GENERATING: 0, + SummaryStatus.ERROR: 0, + SummaryStatus.NOT_STARTED: 0, } summary_list = [] for segment in segments: summary = summary_map.get(segment.id) if summary: - status = summary.status + status = SummaryStatus(summary.status) status_counts[status] = status_counts.get(status, 0) + 1 summary_list.append( { @@ -1421,12 +1423,12 @@ class SummaryIndexService: } ) else: - status_counts["not_started"] += 1 + status_counts[SummaryStatus.NOT_STARTED] += 1 summary_list.append( { "segment_id": segment.id, "segment_position": segment.position, - "status": "not_started", + "status": SummaryStatus.NOT_STARTED, "summary_preview": None, "error": None, "created_at": None, diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 2d3d00cd50..ae55c9ee03 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -13,6 +13,7 @@ from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import DatasetAutoDisableLog, DocumentSegment from models.dataset import Document as DatasetDocument +from models.enums import IndexingStatus, SegmentStatus logger = logging.getLogger(__name__) @@ -34,7 +35,7 @@ def add_document_to_index_task(dataset_document_id: str): logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red")) return - if dataset_document.indexing_status != "completed": + if dataset_document.indexing_status != IndexingStatus.COMPLETED: return indexing_cache_key = f"document_{dataset_document.id}_indexing" @@ -48,7 +49,7 @@ def add_document_to_index_task(dataset_document_id: str): session.query(DocumentSegment) .where( DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == "completed", + DocumentSegment.status == SegmentStatus.COMPLETED, ) .order_by(DocumentSegment.position.asc()) .all() @@ -139,7 +140,7 @@ def add_document_to_index_task(dataset_document_id: str): logger.exception("add document to index failed") dataset_document.enabled = False dataset_document.disabled_at = naive_utc_now() - dataset_document.indexing_status = "error" + dataset_document.indexing_status = IndexingStatus.ERROR dataset_document.error = str(e) session.commit() finally: diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 4f8e2fec7a..1fe43c3d62 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -11,6 +11,7 @@ from core.rag.models.document import Document from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import Dataset +from models.enums import CollectionBindingType from models.model import App, AppAnnotationSetting, MessageAnnotation from services.dataset_service import DatasetCollectionBindingService @@ -47,7 +48,7 @@ def enable_annotation_reply_task( try: documents = [] dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_provider_name, embedding_model_name, "annotation" + embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION ) annotation_setting = ( session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() @@ -56,7 +57,7 @@ def enable_annotation_reply_task( if dataset_collection_binding.id != annotation_setting.collection_binding_id: old_dataset_collection_binding = ( DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - annotation_setting.collection_binding_id, "annotation" + annotation_setting.collection_binding_id, CollectionBindingType.ANNOTATION ) ) if old_dataset_collection_binding and annotations: diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index b5e472d71e..b3cbc73d6e 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -10,6 +10,7 @@ from core.rag.models.document import Document from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import DocumentSegment +from models.enums import IndexingStatus, SegmentStatus logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) return - if segment.status != "waiting": + if segment.status != SegmentStatus.WAITING: return indexing_cache_key = f"segment_{segment.id}_indexing" @@ -40,7 +41,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N # update segment status to indexing session.query(DocumentSegment).filter_by(id=segment.id).update( { - DocumentSegment.status: "indexing", + DocumentSegment.status: SegmentStatus.INDEXING, DocumentSegment.indexing_at: naive_utc_now(), } ) @@ -70,7 +71,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N if ( not dataset_document.enabled or dataset_document.archived - or dataset_document.indexing_status != "completed" + or dataset_document.indexing_status != IndexingStatus.COMPLETED ): logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan")) return @@ -82,7 +83,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N # update segment to completed session.query(DocumentSegment).filter_by(id=segment.id).update( { - DocumentSegment.status: "completed", + DocumentSegment.status: SegmentStatus.COMPLETED, DocumentSegment.completed_at: naive_utc_now(), } ) @@ -94,7 +95,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N logger.exception("create segment to index failed") segment.enabled = False segment.disabled_at = naive_utc_now() - segment.status = "error" + segment.status = SegmentStatus.ERROR segment.error = str(e) session.commit() finally: diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index fddd9199d1..f99e90062f 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -12,6 +12,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment +from models.enums import IndexingStatus from services.datasource_provider_service import DatasourceProviderService logger = logging.getLogger(__name__) @@ -37,7 +38,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): logger.info(click.style(f"Document not found: {document_id}", fg="red")) return - if document.indexing_status == "parsing": + if document.indexing_status == IndexingStatus.PARSING: logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow")) return @@ -88,7 +89,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): with session_factory.create_session() as session, session.begin(): document = session.query(Document).filter_by(id=document_id).first() if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = "Datasource credential not found. Please reconnect your Notion workspace." document.stopped_at = naive_utc_now() return @@ -128,7 +129,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): data_source_info["last_edited_time"] = last_edited_time document.data_source_info = json.dumps(data_source_info) - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id) @@ -151,6 +152,6 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): with session_factory.create_session() as session, session.begin(): document = session.query(Document).filter_by(id=document_id).first() if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(e) document.stopped_at = naive_utc_now() diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index b3f36d8f44..e05d63426c 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -14,6 +14,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document +from models.enums import IndexingStatus from services.feature_service import FeatureService from tasks.generate_summary_index_task import generate_summary_index_task @@ -81,7 +82,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(e) document.stopped_at = naive_utc_now() session.add(document) @@ -96,7 +97,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): for document in documents: if document: - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() session.add(document) # Transaction committed and closed @@ -148,7 +149,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): document.need_summary, ) if ( - document.indexing_status == "completed" + document.indexing_status == IndexingStatus.COMPLETED and document.doc_form != "qa_model" and document.need_summary is True ): diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index c7508c6d05..62bce24de4 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -10,6 +10,7 @@ from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment +from models.enums import IndexingStatus logger = logging.getLogger(__name__) @@ -33,7 +34,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): logger.info(click.style(f"Document not found: {document_id}", fg="red")) return - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 00a963255b..13c651753f 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -15,6 +15,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment +from models.enums import IndexingStatus from services.feature_service import FeatureService logger = logging.getLogger(__name__) @@ -112,7 +113,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st ) for document in documents: if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(e) document.stopped_at = naive_utc_now() session.add(document) @@ -146,7 +147,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st session.execute(segment_delete_stmt) session.commit() - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() session.add(document) session.commit() diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 41ebb0b076..5ad17d75d4 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -12,6 +12,7 @@ from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import DocumentSegment +from models.enums import IndexingStatus, SegmentStatus logger = logging.getLogger(__name__) @@ -33,7 +34,7 @@ def enable_segment_to_index_task(segment_id: str): logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) return - if segment.status != "completed": + if segment.status != SegmentStatus.COMPLETED: logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red")) return @@ -65,7 +66,7 @@ def enable_segment_to_index_task(segment_id: str): if ( not dataset_document.enabled or dataset_document.archived - or dataset_document.indexing_status != "completed" + or dataset_document.indexing_status != IndexingStatus.COMPLETED ): logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan")) return @@ -123,7 +124,7 @@ def enable_segment_to_index_task(segment_id: str): logger.exception("enable segment to index failed") segment.enabled = False segment.disabled_at = naive_utc_now() - segment.status = "error" + segment.status = SegmentStatus.ERROR segment.error = str(e) session.commit() finally: diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index f20b15ac83..4fcb0cf804 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -12,6 +12,7 @@ from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models import Account, Tenant from models.dataset import Dataset, Document, DocumentSegment +from models.enums import IndexingStatus from services.feature_service import FeatureService from services.rag_pipeline.rag_pipeline import RagPipelineService @@ -63,7 +64,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_ .first() ) if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(e) document.stopped_at = naive_utc_now() session.add(document) @@ -95,7 +96,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_ session.execute(segment_delete_stmt) session.commit() - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() session.add(document) session.commit() @@ -108,7 +109,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_ indexing_runner.run([document]) redis_client.delete(retry_indexing_cache_key) except Exception as ex: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(ex) document.stopped_at = naive_utc_now() session.add(document) diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index f1c8c56995..aa6bce958b 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -11,6 +11,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment +from models.enums import IndexingStatus from services.feature_service import FeatureService logger = logging.getLogger(__name__) @@ -48,7 +49,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(e) document.stopped_at = naive_utc_now() session.add(document) @@ -76,7 +77,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): session.execute(segment_delete_stmt) session.commit() - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() session.add(document) session.commit() @@ -85,7 +86,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): indexing_runner.run([document]) redis_client.delete(sync_indexing_cache_key) except Exception as ex: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(ex) document.stopped_at = naive_utc_now() session.add(document) diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py index 75471afef8..781e297fa4 100644 --- a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -7,6 +7,7 @@ from faker import Faker from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest from models.dataset import Dataset, Document +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from services.account_service import AccountService, TenantService from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -35,7 +36,7 @@ class TestGetAvailableDatasetsIntegration: name=fake.company(), description=fake.text(max_nb_chars=100), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, indexing_technique="high_quality", ) @@ -49,14 +50,14 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field name=f"Document {i}", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, ) @@ -94,7 +95,7 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, name=fake.company(), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -106,13 +107,13 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field - created_from="web", + created_from=DocumentCreatedFrom.WEB, name=f"Archived Document {i}", created_by=account.id, doc_form="text_model", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # Archived ) @@ -147,7 +148,7 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, name=fake.company(), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -159,13 +160,13 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field - created_from="web", + created_from=DocumentCreatedFrom.WEB, name=f"Disabled Document {i}", created_by=account.id, doc_form="text_model", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=False, # Disabled archived=False, ) @@ -200,21 +201,21 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, name=fake.company(), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) # Create documents with non-completed status - for i, status in enumerate(["indexing", "parsing", "splitting"]): + for i, status in enumerate([IndexingStatus.INDEXING, IndexingStatus.PARSING, IndexingStatus.SPLITTING]): document = Document( id=str(uuid.uuid4()), tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field - created_from="web", + created_from=DocumentCreatedFrom.WEB, name=f"Document {status}", created_by=account.id, doc_form="text_model", @@ -263,7 +264,7 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, name=fake.company(), provider="external", # External provider - data_source_type="external", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -307,7 +308,7 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant1.id, name="Tenant 1 Dataset", provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account1.id, ) db_session_with_containers.add(dataset1) @@ -318,7 +319,7 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant2.id, name="Tenant 2 Dataset", provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account2.id, ) db_session_with_containers.add(dataset2) @@ -330,13 +331,13 @@ class TestGetAvailableDatasetsIntegration: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=0, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field - created_from="web", + created_from=DocumentCreatedFrom.WEB, name=f"Document for {dataset.name}", created_by=account.id, doc_form="text_model", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, ) @@ -398,7 +399,7 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, name=f"Dataset {i}", provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -410,13 +411,13 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field - created_from="web", + created_from=DocumentCreatedFrom.WEB, name=f"Document {i}", created_by=account.id, doc_form="text_model", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, ) @@ -456,7 +457,7 @@ class TestKnowledgeRetrievalIntegration: tenant_id=tenant.id, name=fake.company(), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, indexing_technique="high_quality", ) @@ -467,12 +468,12 @@ class TestKnowledgeRetrievalIntegration: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field - created_from="web", + created_from=DocumentCreatedFrom.WEB, name=fake.sentence(), created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, doc_form="text_model", @@ -525,7 +526,7 @@ class TestKnowledgeRetrievalIntegration: tenant_id=tenant.id, name=fake.company(), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -572,7 +573,7 @@ class TestKnowledgeRetrievalIntegration: tenant_id=tenant.id, name=fake.company(), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/models/test_dataset_models.py b/api/tests/test_containers_integration_tests/models/test_dataset_models.py index 6c541a8ad2..a3bbf19657 100644 --- a/api/tests/test_containers_integration_tests/models/test_dataset_models.py +++ b/api/tests/test_containers_integration_tests/models/test_dataset_models.py @@ -12,6 +12,7 @@ import pytest from sqlalchemy.orm import Session from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus class TestDatasetDocumentProperties: @@ -29,7 +30,7 @@ class TestDatasetDocumentProperties: created_by = str(uuid4()) dataset = Dataset( - tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by + tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -39,10 +40,10 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=i + 1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name=f"doc_{i}.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(doc) @@ -56,7 +57,7 @@ class TestDatasetDocumentProperties: created_by = str(uuid4()) dataset = Dataset( - tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by + tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -65,12 +66,12 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="available.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, ) @@ -78,12 +79,12 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=2, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="pending.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, archived=False, ) @@ -91,12 +92,12 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=3, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="disabled.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=False, archived=False, ) @@ -111,7 +112,7 @@ class TestDatasetDocumentProperties: created_by = str(uuid4()) dataset = Dataset( - tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by + tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -121,10 +122,10 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=i + 1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name=f"doc_{i}.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, word_count=wc, ) @@ -139,7 +140,7 @@ class TestDatasetDocumentProperties: created_by = str(uuid4()) dataset = Dataset( - tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by + tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -148,10 +149,10 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="doc.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(doc) @@ -166,7 +167,7 @@ class TestDatasetDocumentProperties: content=f"segment {i}", word_count=100, tokens=50, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, created_by=created_by, ) @@ -180,7 +181,7 @@ class TestDatasetDocumentProperties: content="waiting segment", word_count=100, tokens=50, - status="waiting", + status=SegmentStatus.WAITING, enabled=True, created_by=created_by, ) @@ -195,7 +196,7 @@ class TestDatasetDocumentProperties: created_by = str(uuid4()) dataset = Dataset( - tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by + tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -204,10 +205,10 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="doc.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(doc) @@ -235,7 +236,7 @@ class TestDatasetDocumentProperties: created_by = str(uuid4()) dataset = Dataset( - tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by + tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -244,10 +245,10 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="doc.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(doc) @@ -288,7 +289,7 @@ class TestDocumentSegmentNavigationProperties: dataset = Dataset( tenant_id=tenant_id, name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, ) db_session_with_containers.add(dataset) @@ -298,10 +299,10 @@ class TestDocumentSegmentNavigationProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(document) @@ -335,7 +336,7 @@ class TestDocumentSegmentNavigationProperties: dataset = Dataset( tenant_id=tenant_id, name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, ) db_session_with_containers.add(dataset) @@ -345,10 +346,10 @@ class TestDocumentSegmentNavigationProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(document) @@ -382,7 +383,7 @@ class TestDocumentSegmentNavigationProperties: dataset = Dataset( tenant_id=tenant_id, name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, ) db_session_with_containers.add(dataset) @@ -392,10 +393,10 @@ class TestDocumentSegmentNavigationProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(document) @@ -439,7 +440,7 @@ class TestDocumentSegmentNavigationProperties: dataset = Dataset( tenant_id=tenant_id, name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, ) db_session_with_containers.add(dataset) @@ -449,10 +450,10 @@ class TestDocumentSegmentNavigationProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(document) diff --git a/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py b/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py index 191c161613..638a61c815 100644 --- a/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py +++ b/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py @@ -12,6 +12,7 @@ import pytest from sqlalchemy.orm import Session from models.dataset import DatasetCollectionBinding +from models.enums import CollectionBindingType from services.dataset_service import DatasetCollectionBindingService @@ -32,7 +33,7 @@ class DatasetCollectionBindingTestDataFactory: provider_name: str = "openai", model_name: str = "text-embedding-ada-002", collection_name: str = "collection-abc", - collection_type: str = "dataset", + collection_type: str = CollectionBindingType.DATASET, ) -> DatasetCollectionBinding: """ Create a DatasetCollectionBinding with specified attributes. @@ -41,7 +42,7 @@ class DatasetCollectionBindingTestDataFactory: provider_name: Name of the embedding model provider (e.g., "openai", "cohere") model_name: Name of the embedding model (e.g., "text-embedding-ada-002") collection_name: Name of the vector database collection - collection_type: Type of collection (default: "dataset") + collection_type: Type of collection (default: CollectionBindingType.DATASET) Returns: DatasetCollectionBinding instance @@ -76,7 +77,7 @@ class TestDatasetCollectionBindingServiceGetBinding: # Arrange provider_name = "openai" model_name = "text-embedding-ada-002" - collection_type = "dataset" + collection_type = CollectionBindingType.DATASET existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( db_session_with_containers, provider_name=provider_name, @@ -104,7 +105,7 @@ class TestDatasetCollectionBindingServiceGetBinding: # Arrange provider_name = f"provider-{uuid4()}" model_name = f"model-{uuid4()}" - collection_type = "dataset" + collection_type = CollectionBindingType.DATASET # Act result = DatasetCollectionBindingService.get_dataset_collection_binding( @@ -145,7 +146,7 @@ class TestDatasetCollectionBindingServiceGetBinding: result = DatasetCollectionBindingService.get_dataset_collection_binding(provider_name, model_name) # Assert - assert result.type == "dataset" + assert result.type == CollectionBindingType.DATASET assert result.provider_name == provider_name assert result.model_name == model_name @@ -186,18 +187,20 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", - collection_type="dataset", + collection_type=CollectionBindingType.DATASET, ) # Act - result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(binding.id, "dataset") + result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + binding.id, CollectionBindingType.DATASET + ) # Assert assert result.id == binding.id assert result.provider_name == "openai" assert result.model_name == "text-embedding-ada-002" assert result.collection_name == "test-collection" - assert result.type == "dataset" + assert result.type == CollectionBindingType.DATASET def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers: Session): """Test error handling when collection binding is not found by ID and type.""" @@ -206,7 +209,9 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: # Act & Assert with pytest.raises(ValueError, match="Dataset collection binding not found"): - DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(non_existent_id, "dataset") + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + non_existent_id, CollectionBindingType.DATASET + ) def test_get_dataset_collection_binding_by_id_and_type_different_collection_type( self, db_session_with_containers: Session @@ -240,7 +245,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", - collection_type="dataset", + collection_type=CollectionBindingType.DATASET, ) # Act @@ -248,7 +253,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: # Assert assert result.id == binding.id - assert result.type == "dataset" + assert result.type == CollectionBindingType.DATASET def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers: Session): """Test error when binding exists but with wrong collection type.""" @@ -258,7 +263,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", - collection_type="dataset", + collection_type=CollectionBindingType.DATASET, ) # Act & Assert diff --git a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py index 4b98bddd26..6b35f867d7 100644 --- a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py +++ b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py @@ -15,6 +15,7 @@ from werkzeug.exceptions import NotFound from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum +from models.enums import DataSourceType from models.model import App from services.dataset_service import DatasetService from services.errors.account import NoPermissionError @@ -72,7 +73,7 @@ class DatasetUpdateDeleteTestDataFactory: tenant_id=tenant_id, name=name, description="Test description", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=created_by, permission=permission, diff --git a/api/tests/test_containers_integration_tests/services/document_service_status.py b/api/tests/test_containers_integration_tests/services/document_service_status.py index c08ea2a93b..251f17dd03 100644 --- a/api/tests/test_containers_integration_tests/services/document_service_status.py +++ b/api/tests/test_containers_integration_tests/services/document_service_status.py @@ -15,7 +15,7 @@ import pytest from models import Account from models.dataset import Dataset, Document -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus from models.model import UploadFile from services.dataset_service import DocumentService from services.errors.document import DocumentIndexingError @@ -88,7 +88,7 @@ class DocumentStatusTestDataFactory: data_source_info=json.dumps(data_source_info or {}), batch=f"batch-{uuid4()}", name=name, - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, doc_form="text_model", ) @@ -100,7 +100,7 @@ class DocumentStatusTestDataFactory: document.paused_by = paused_by document.paused_at = paused_at document.doc_metadata = doc_metadata or {} - if indexing_status == "completed" and "completed_at" not in kwargs: + if indexing_status == IndexingStatus.COMPLETED and "completed_at" not in kwargs: document.completed_at = FIXED_TIME for key, value in kwargs.items(): @@ -139,7 +139,7 @@ class DocumentStatusTestDataFactory: dataset = Dataset( tenant_id=tenant_id, name=name, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, ) dataset.id = dataset_id @@ -291,7 +291,7 @@ class TestDocumentServicePauseDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, is_paused=False, ) @@ -326,7 +326,7 @@ class TestDocumentServicePauseDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="indexing", + indexing_status=IndexingStatus.INDEXING, is_paused=False, ) @@ -354,7 +354,7 @@ class TestDocumentServicePauseDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="parsing", + indexing_status=IndexingStatus.PARSING, is_paused=False, ) @@ -383,7 +383,7 @@ class TestDocumentServicePauseDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, is_paused=False, ) @@ -412,7 +412,7 @@ class TestDocumentServicePauseDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="error", + indexing_status=IndexingStatus.ERROR, is_paused=False, ) @@ -487,7 +487,7 @@ class TestDocumentServiceRecoverDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="indexing", + indexing_status=IndexingStatus.INDEXING, is_paused=True, paused_by=str(uuid4()), paused_at=paused_time, @@ -526,7 +526,7 @@ class TestDocumentServiceRecoverDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="indexing", + indexing_status=IndexingStatus.INDEXING, is_paused=False, ) @@ -609,7 +609,7 @@ class TestDocumentServiceRetryDocument: dataset_id=dataset.id, tenant_id=dataset.tenant_id, document_id=str(uuid4()), - indexing_status="error", + indexing_status=IndexingStatus.ERROR, ) mock_document_service_dependencies["redis_client"].get.return_value = None @@ -619,7 +619,7 @@ class TestDocumentServiceRetryDocument: # Assert db_session_with_containers.refresh(document) - assert document.indexing_status == "waiting" + assert document.indexing_status == IndexingStatus.WAITING expected_cache_key = f"document_{document.id}_is_retried" mock_document_service_dependencies["redis_client"].setex.assert_called_once_with(expected_cache_key, 600, 1) @@ -646,14 +646,14 @@ class TestDocumentServiceRetryDocument: dataset_id=dataset.id, tenant_id=dataset.tenant_id, document_id=str(uuid4()), - indexing_status="error", + indexing_status=IndexingStatus.ERROR, ) document2 = DocumentStatusTestDataFactory.create_document( db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, document_id=str(uuid4()), - indexing_status="error", + indexing_status=IndexingStatus.ERROR, position=2, ) @@ -665,8 +665,8 @@ class TestDocumentServiceRetryDocument: # Assert db_session_with_containers.refresh(document1) db_session_with_containers.refresh(document2) - assert document1.indexing_status == "waiting" - assert document2.indexing_status == "waiting" + assert document1.indexing_status == IndexingStatus.WAITING + assert document2.indexing_status == IndexingStatus.WAITING mock_document_service_dependencies["retry_task"].delay.assert_called_once_with( dataset.id, [document1.id, document2.id], mock_document_service_dependencies["user_id"] @@ -693,7 +693,7 @@ class TestDocumentServiceRetryDocument: dataset_id=dataset.id, tenant_id=dataset.tenant_id, document_id=str(uuid4()), - indexing_status="error", + indexing_status=IndexingStatus.ERROR, ) mock_document_service_dependencies["redis_client"].get.return_value = "1" @@ -703,7 +703,7 @@ class TestDocumentServiceRetryDocument: DocumentService.retry_document(dataset.id, [document]) db_session_with_containers.refresh(document) - assert document.indexing_status == "error" + assert document.indexing_status == IndexingStatus.ERROR def test_retry_document_missing_current_user_error( self, db_session_with_containers, mock_document_service_dependencies @@ -726,7 +726,7 @@ class TestDocumentServiceRetryDocument: dataset_id=dataset.id, tenant_id=dataset.tenant_id, document_id=str(uuid4()), - indexing_status="error", + indexing_status=IndexingStatus.ERROR, ) mock_document_service_dependencies["redis_client"].get.return_value = None @@ -816,7 +816,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: tenant_id=dataset.tenant_id, document_id=str(uuid4()), enabled=False, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) document2 = DocumentStatusTestDataFactory.create_document( db_session_with_containers, @@ -824,7 +824,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: tenant_id=dataset.tenant_id, document_id=str(uuid4()), enabled=False, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, position=2, ) document_ids = [document1.id, document2.id] @@ -866,7 +866,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: tenant_id=dataset.tenant_id, document_id=str(uuid4()), enabled=True, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, completed_at=FIXED_TIME, ) document_ids = [document.id] @@ -909,7 +909,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: document_id=str(uuid4()), archived=False, enabled=True, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) document_ids = [document.id] @@ -951,7 +951,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: document_id=str(uuid4()), archived=True, enabled=True, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) document_ids = [document.id] @@ -1015,7 +1015,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: dataset_id=dataset.id, tenant_id=dataset.tenant_id, document_id=str(uuid4()), - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) document_ids = [document.id] @@ -1098,7 +1098,7 @@ class TestDocumentServiceRenameDocument: document_id=document_id, dataset_id=dataset.id, tenant_id=tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) # Act @@ -1139,7 +1139,7 @@ class TestDocumentServiceRenameDocument: dataset_id=dataset.id, tenant_id=tenant_id, doc_metadata={"existing_key": "existing_value"}, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) # Act @@ -1187,7 +1187,7 @@ class TestDocumentServiceRenameDocument: dataset_id=dataset.id, tenant_id=tenant_id, data_source_info={"upload_file_id": upload_file.id}, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) # Act @@ -1277,7 +1277,7 @@ class TestDocumentServiceRenameDocument: document_id=document_id, dataset_id=dataset.id, tenant_id=str(uuid4()), - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) # Act & Assert diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py index 44525e0036..975af3d428 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py @@ -16,6 +16,7 @@ from models.dataset import ( DatasetPermission, DatasetPermissionEnum, ) +from models.enums import DataSourceType from services.dataset_service import DatasetPermissionService, DatasetService from services.errors.account import NoPermissionError @@ -67,7 +68,7 @@ class DatasetPermissionTestDataFactory: tenant_id=tenant_id, name=name, description="desc", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=created_by, permission=permission, diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index 102c1a1eb5..ac3d9f9604 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -15,6 +15,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline +from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus from services.dataset_service import DatasetService from services.entities.knowledge_entities.knowledge_entities import RerankingModel, RetrievalModel from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity @@ -74,7 +75,7 @@ class DatasetServiceIntegrationDataFactory: tenant_id=tenant_id, name=name, description=description, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique=indexing_technique, created_by=created_by, provider=provider, @@ -98,13 +99,13 @@ class DatasetServiceIntegrationDataFactory: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info='{"upload_file_id": "upload-file-id"}', batch=str(uuid4()), name=name, - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, doc_form="text_model", ) db_session_with_containers.add(document) @@ -437,7 +438,7 @@ class TestDatasetServiceCreateRagPipelineDataset: created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id) assert created_dataset is not None assert created_dataset.name == entity.name - assert created_dataset.runtime_mode == "rag_pipeline" + assert created_dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE assert created_dataset.created_by == account.id assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME assert created_pipeline is not None diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py index 322b67d373..7983b1cd93 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py @@ -14,6 +14,7 @@ import pytest from sqlalchemy.orm import Session from models.dataset import Dataset, Document +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from services.dataset_service import DocumentService from services.errors.document import DocumentIndexingError @@ -42,7 +43,7 @@ class DocumentBatchUpdateIntegrationDataFactory: dataset = Dataset( tenant_id=tenant_id or str(uuid4()), name=name, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by or str(uuid4()), ) if dataset_id: @@ -72,11 +73,11 @@ class DocumentBatchUpdateIntegrationDataFactory: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=position, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info=json.dumps({"upload_file_id": str(uuid4())}), batch=f"batch-{uuid4()}", name=name, - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by or str(uuid4()), doc_form="text_model", ) @@ -85,7 +86,9 @@ class DocumentBatchUpdateIntegrationDataFactory: document.archived = archived document.indexing_status = indexing_status document.completed_at = ( - completed_at if completed_at is not None else (FIXED_TIME if indexing_status == "completed" else None) + completed_at + if completed_at is not None + else (FIXED_TIME if indexing_status == IndexingStatus.COMPLETED else None) ) for key, value in kwargs.items(): @@ -243,7 +246,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus: dataset=dataset, document_ids=document_ids, enabled=True, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) # Act @@ -277,7 +280,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus: db_session_with_containers, dataset=dataset, enabled=False, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, completed_at=FIXED_TIME, ) @@ -306,7 +309,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus: db_session_with_containers, dataset=dataset, enabled=True, - indexing_status="indexing", + indexing_status=IndexingStatus.INDEXING, completed_at=None, ) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py index c47e35791d..ed070527c9 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py @@ -5,6 +5,7 @@ from uuid import uuid4 from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document +from models.enums import DataSourceType, DocumentCreatedFrom from services.dataset_service import DatasetService @@ -58,7 +59,7 @@ class DatasetDeleteIntegrationDataFactory: dataset = Dataset( tenant_id=tenant_id, name=f"dataset-{uuid4()}", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique=indexing_technique, index_struct=index_struct, created_by=created_by, @@ -84,10 +85,10 @@ class DatasetDeleteIntegrationDataFactory: tenant_id=tenant_id, dataset_id=dataset_id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=f"batch-{uuid4()}", name="Document", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, doc_form=doc_form, ) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py index e78894fcae..c4b3a57bb2 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py @@ -14,6 +14,7 @@ from sqlalchemy.orm import Session from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom from services.dataset_service import SegmentService @@ -62,7 +63,7 @@ class SegmentServiceTestDataFactory: tenant_id=tenant_id, name=f"Test Dataset {uuid4()}", description="Test description", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=created_by, permission=DatasetPermissionEnum.ONLY_ME, @@ -82,10 +83,10 @@ class SegmentServiceTestDataFactory: tenant_id=tenant_id, dataset_id=dataset_id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=f"batch-{uuid4()}", name=f"test-doc-{uuid4()}.txt", - created_from="api", + created_from=DocumentCreatedFrom.API, created_by=created_by, ) db_session_with_containers.add(document) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py index 8bd994937a..3021d8984d 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py @@ -24,6 +24,7 @@ from models.dataset import ( DatasetProcessRule, DatasetQuery, ) +from models.enums import DatasetQuerySource, DataSourceType, ProcessRuleMode from models.model import Tag, TagBinding from services.dataset_service import DatasetService, DocumentService @@ -100,7 +101,7 @@ class DatasetRetrievalTestDataFactory: tenant_id=tenant_id, name=name, description="desc", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=created_by, permission=permission, @@ -149,7 +150,7 @@ class DatasetRetrievalTestDataFactory: dataset_query = DatasetQuery( dataset_id=dataset_id, content=content, - source="web", + source=DatasetQuerySource.APP, source_app_id=None, created_by_role="account", created_by=created_by, @@ -601,7 +602,7 @@ class TestDatasetServiceGetProcessRules: db_session_with_containers, dataset_id=dataset.id, created_by=account.id, - mode="custom", + mode=ProcessRuleMode.CUSTOM, rules=rules_data, ) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index ebaa3b4637..fd81948247 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Session from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, ExternalKnowledgeBindings +from models.enums import DataSourceType from services.dataset_service import DatasetService from services.errors.account import NoPermissionError @@ -64,7 +65,7 @@ class DatasetUpdateTestDataFactory: tenant_id=tenant_id, name=name, description=description, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique=indexing_technique, created_by=created_by, provider=provider, diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py index 124056e10f..c6aa89c733 100644 --- a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py +++ b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py @@ -4,6 +4,7 @@ from uuid import uuid4 from sqlalchemy import select from models.dataset import Dataset, Document +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from services.dataset_service import DocumentService @@ -11,7 +12,7 @@ def _create_dataset(db_session_with_containers) -> Dataset: dataset = Dataset( tenant_id=str(uuid4()), name=f"dataset-{uuid4()}", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), ) dataset.id = str(uuid4()) @@ -35,11 +36,11 @@ def _create_document( tenant_id=tenant_id, dataset_id=dataset_id, position=position, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info="{}", batch=f"batch-{uuid4()}", name=f"doc-{uuid4()}", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), doc_form="text_model", ) @@ -48,7 +49,7 @@ def _create_document( document.enabled = enabled document.archived = archived document.is_paused = is_paused - if indexing_status == "completed": + if indexing_status == IndexingStatus.COMPLETED: document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db_session_with_containers.add(document) @@ -62,7 +63,7 @@ def test_build_display_status_filters_available(db_session_with_containers): db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, position=1, @@ -71,7 +72,7 @@ def test_build_display_status_filters_available(db_session_with_containers): db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=False, archived=False, position=2, @@ -80,7 +81,7 @@ def test_build_display_status_filters_available(db_session_with_containers): db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, position=3, @@ -101,14 +102,14 @@ def test_apply_display_status_filter_applies_when_status_present(db_session_with db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, position=1, ) _create_document( db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, position=2, ) @@ -125,14 +126,14 @@ def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_c db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, position=1, ) doc2 = _create_document( db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, position=2, ) diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py index f641da6576..b159af0090 100644 --- a/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py +++ b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py @@ -9,7 +9,7 @@ import pytest from models import Account from models.dataset import Dataset, Document -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom from models.model import UploadFile from services.dataset_service import DocumentService @@ -33,7 +33,7 @@ def make_dataset(db_session_with_containers, dataset_id=None, tenant_id=None, bu dataset = Dataset( tenant_id=tenant_id, name=f"dataset-{uuid4()}", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), ) dataset.id = dataset_id @@ -62,11 +62,11 @@ def make_document( tenant_id=tenant_id, dataset_id=dataset_id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info=json.dumps(data_source_info or {}), batch=f"batch-{uuid4()}", name=name, - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), doc_form="text_model", ) diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index 6fe40c0744..ef1f31d36b 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import Session from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.enums import DataSourceType from models.model import ( App, AppAnnotationHitHistory, @@ -287,7 +288,7 @@ class TestMessagesCleanServiceIntegration: dataset_name="Test dataset", document_id=str(uuid.uuid4()), document_name="Test document", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, segment_id=str(uuid.uuid4()), score=0.9, content="Test content", diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index 694dc1c1b9..e847329c5b 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Session from core.rag.index_processor.constant.built_in_field import BuiltInField from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document +from models.enums import DatasetMetadataType, DataSourceType, DocumentCreatedFrom from services.entities.knowledge_entities.knowledge_entities import MetadataArgs from services.metadata_service import MetadataService @@ -101,7 +102,7 @@ class TestMetadataService: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, built_in_field_enabled=False, ) @@ -132,11 +133,11 @@ class TestMetadataService: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info="{}", batch="test-batch", name=fake.file_name(), - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text", doc_language="en", @@ -163,7 +164,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id mock_external_service_dependencies["current_user"].id = account.id - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") # Act: Execute the method under test result = MetadataService.create_metadata(dataset.id, metadata_args) @@ -201,7 +202,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id long_name = "a" * 256 # 256 characters, exceeding 255 limit - metadata_args = MetadataArgs(type="string", name=long_name) + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name=long_name) # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): @@ -226,11 +227,11 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create first metadata - first_metadata_args = MetadataArgs(type="string", name="duplicate_name") + first_metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="duplicate_name") MetadataService.create_metadata(dataset.id, first_metadata_args) # Try to create second metadata with same name - second_metadata_args = MetadataArgs(type="number", name="duplicate_name") + second_metadata_args = MetadataArgs(type=DatasetMetadataType.NUMBER, name="duplicate_name") # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name already exists."): @@ -256,7 +257,7 @@ class TestMetadataService: # Try to create metadata with built-in field name built_in_field_name = BuiltInField.document_name - metadata_args = MetadataArgs(type="string", name=built_in_field_name) + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name=built_in_field_name) # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): @@ -281,7 +282,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata first - metadata_args = MetadataArgs(type="string", name="old_name") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Act: Execute the method under test @@ -318,7 +319,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata first - metadata_args = MetadataArgs(type="string", name="old_name") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Try to update with too long name @@ -347,10 +348,10 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create two metadata entries - first_metadata_args = MetadataArgs(type="string", name="first_metadata") + first_metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="first_metadata") first_metadata = MetadataService.create_metadata(dataset.id, first_metadata_args) - second_metadata_args = MetadataArgs(type="number", name="second_metadata") + second_metadata_args = MetadataArgs(type=DatasetMetadataType.NUMBER, name="second_metadata") second_metadata = MetadataService.create_metadata(dataset.id, second_metadata_args) # Try to update first metadata with second metadata's name @@ -376,7 +377,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata first - metadata_args = MetadataArgs(type="string", name="old_name") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Try to update with built-in field name @@ -432,7 +433,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata first - metadata_args = MetadataArgs(type="string", name="to_be_deleted") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="to_be_deleted") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Act: Execute the method under test @@ -496,7 +497,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Create metadata binding @@ -798,7 +799,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Mock DocumentService.get_document @@ -866,7 +867,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Mock DocumentService.get_document @@ -917,7 +918,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Create metadata operation data @@ -1038,7 +1039,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Create document and metadata binding @@ -1101,7 +1102,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Act: Execute the method under test diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index 597ba6b75b..fa6e651529 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -9,6 +9,7 @@ from werkzeug.exceptions import NotFound from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset +from models.enums import DataSourceType from models.model import App, Tag, TagBinding from services.tag_service import TagService @@ -100,7 +101,7 @@ class TestTagService: description=fake.text(max_nb_chars=100), provider="vendor", permission="only_me", - data_source_type="upload", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", tenant_id=tenant_id, created_by=mock_external_service_dependencies["current_user"].id, diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index 8c007877fd..c3fe6a2950 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -510,7 +510,7 @@ class TestWorkflowConverter: retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, top_k=10, score_threshold=0.8, - reranking_model={"provider": "cohere", "model": "rerank-v2"}, + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, reranking_enabled=True, ), ) @@ -543,8 +543,8 @@ class TestWorkflowConverter: multiple_config = node["data"]["multiple_retrieval_config"] assert multiple_config["top_k"] == 10 assert multiple_config["score_threshold"] == 0.8 - assert multiple_config["reranking_model"]["provider"] == "cohere" - assert multiple_config["reranking_model"]["model"] == "rerank-v2" + assert multiple_config["reranking_model"]["reranking_provider_name"] == "cohere" + assert multiple_config["reranking_model"]["reranking_model_name"] == "rerank-v2" # Verify single retrieval config is None for multiple strategy assert node["data"]["single_retrieval_config"] is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index efeb29cf20..94173c34bf 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -8,6 +8,7 @@ from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.add_document_to_index_task import add_document_to_index_task @@ -79,7 +80,7 @@ class TestAddDocumentToIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -92,12 +93,12 @@ class TestAddDocumentToIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, doc_form=IndexStructureType.PARAGRAPH_INDEX, ) @@ -137,7 +138,7 @@ class TestAddDocumentToIndexTask: index_node_id=f"node_{i}", index_node_hash=f"hash_{i}", enabled=False, - status="completed", + status=SegmentStatus.COMPLETED, created_by=document.created_by, ) db_session_with_containers.add(segment) @@ -297,7 +298,7 @@ class TestAddDocumentToIndexTask: ) # Set invalid indexing status - document.indexing_status = "processing" + document.indexing_status = IndexingStatus.INDEXING db_session_with_containers.commit() # Act: Execute the task @@ -339,7 +340,7 @@ class TestAddDocumentToIndexTask: # Assert: Verify error handling db_session_with_containers.refresh(document) assert document.enabled is False - assert document.indexing_status == "error" + assert document.indexing_status == IndexingStatus.ERROR assert document.error is not None assert "doesn't exist" in document.error assert document.disabled_at is not None @@ -434,7 +435,7 @@ class TestAddDocumentToIndexTask: Test document indexing when segments are already enabled. This test verifies: - - Segments with status="completed" are processed regardless of enabled status + - Segments with status=SegmentStatus.COMPLETED are processed regardless of enabled status - Index processing occurs with all completed segments - Auto disable log deletion still occurs - Redis cache is cleared @@ -460,7 +461,7 @@ class TestAddDocumentToIndexTask: index_node_id=f"node_{i}", index_node_hash=f"hash_{i}", enabled=True, # Already enabled - status="completed", + status=SegmentStatus.COMPLETED, created_by=document.created_by, ) db_session_with_containers.add(segment) @@ -482,7 +483,7 @@ class TestAddDocumentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with all completed segments - # (implementation doesn't filter by enabled status, only by status="completed") + # (implementation doesn't filter by enabled status, only by status=SegmentStatus.COMPLETED) call_args = mock_external_service_dependencies["index_processor"].load.call_args assert call_args is not None documents = call_args[0][1] # Second argument should be documents list @@ -594,7 +595,7 @@ class TestAddDocumentToIndexTask: # Assert: Verify error handling db_session_with_containers.refresh(document) assert document.enabled is False - assert document.indexing_status == "error" + assert document.indexing_status == IndexingStatus.ERROR assert document.error is not None assert "Index processing failed" in document.error assert document.disabled_at is not None @@ -614,7 +615,7 @@ class TestAddDocumentToIndexTask: Test segment filtering with various edge cases. This test verifies: - - Only segments with status="completed" are processed (regardless of enabled status) + - Only segments with status=SegmentStatus.COMPLETED are processed (regardless of enabled status) - Segments with status!="completed" are NOT processed - Segments are ordered by position correctly - Mixed segment states are handled properly @@ -630,7 +631,7 @@ class TestAddDocumentToIndexTask: fake = Faker() segments = [] - # Segment 1: Should be processed (enabled=False, status="completed") + # Segment 1: Should be processed (enabled=False, status=SegmentStatus.COMPLETED) segment1 = DocumentSegment( id=fake.uuid4(), tenant_id=document.tenant_id, @@ -643,14 +644,14 @@ class TestAddDocumentToIndexTask: index_node_id="node_0", index_node_hash="hash_0", enabled=False, - status="completed", + status=SegmentStatus.COMPLETED, created_by=document.created_by, ) db_session_with_containers.add(segment1) segments.append(segment1) - # Segment 2: Should be processed (enabled=True, status="completed") - # Note: Implementation doesn't filter by enabled status, only by status="completed" + # Segment 2: Should be processed (enabled=True, status=SegmentStatus.COMPLETED) + # Note: Implementation doesn't filter by enabled status, only by status=SegmentStatus.COMPLETED segment2 = DocumentSegment( id=fake.uuid4(), tenant_id=document.tenant_id, @@ -663,7 +664,7 @@ class TestAddDocumentToIndexTask: index_node_id="node_1", index_node_hash="hash_1", enabled=True, # Already enabled, but will still be processed - status="completed", + status=SegmentStatus.COMPLETED, created_by=document.created_by, ) db_session_with_containers.add(segment2) @@ -682,13 +683,13 @@ class TestAddDocumentToIndexTask: index_node_id="node_2", index_node_hash="hash_2", enabled=False, - status="processing", # Not completed + status=SegmentStatus.INDEXING, # Not completed created_by=document.created_by, ) db_session_with_containers.add(segment3) segments.append(segment3) - # Segment 4: Should be processed (enabled=False, status="completed") + # Segment 4: Should be processed (enabled=False, status=SegmentStatus.COMPLETED) segment4 = DocumentSegment( id=fake.uuid4(), tenant_id=document.tenant_id, @@ -701,7 +702,7 @@ class TestAddDocumentToIndexTask: index_node_id="node_3", index_node_hash="hash_3", enabled=False, - status="completed", + status=SegmentStatus.COMPLETED, created_by=document.created_by, ) db_session_with_containers.add(segment4) @@ -726,7 +727,7 @@ class TestAddDocumentToIndexTask: call_args = mock_external_service_dependencies["index_processor"].load.call_args assert call_args is not None documents = call_args[0][1] # Second argument should be documents list - assert len(documents) == 3 # 3 segments with status="completed" should be processed + assert len(documents) == 3 # 3 segments with status=SegmentStatus.COMPLETED should be processed # Verify correct segments were processed (by position order) # Segments 1, 2, 4 should be processed (positions 0, 1, 3) @@ -799,7 +800,7 @@ class TestAddDocumentToIndexTask: # Assert: Verify consistent error handling db_session_with_containers.refresh(document) assert document.enabled is False, f"Document should be disabled for {error_name}" - assert document.indexing_status == "error", f"Document status should be error for {error_name}" + assert document.indexing_status == IndexingStatus.ERROR, f"Document status should be error for {error_name}" assert document.error is not None, f"Error should be recorded for {error_name}" assert str(exception) in document.error, f"Error message should contain exception for {error_name}" assert document.disabled_at is not None, f"Disabled timestamp should be set for {error_name}" diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index ec789418a8..6adefd59be 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -16,6 +16,7 @@ from sqlalchemy.orm import Session from libs.datetime_utils import naive_utc_now from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from models.model import UploadFile from tasks.batch_clean_document_task import batch_clean_document_task @@ -113,7 +114,7 @@ class TestBatchCleanDocumentTask: tenant_id=account.current_tenant.id, name=fake.word(), description=fake.sentence(), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, embedding_model="text-embedding-ada-002", embedding_model_provider="openai", @@ -144,12 +145,12 @@ class TestBatchCleanDocumentTask: dataset_id=dataset.id, position=0, name=fake.word(), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info=json.dumps({"upload_file_id": str(uuid.uuid4())}), batch="test_batch", - created_from="test", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, doc_form="text_model", ) @@ -183,7 +184,7 @@ class TestBatchCleanDocumentTask: tokens=50, index_node_id=str(uuid.uuid4()), created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) @@ -297,7 +298,7 @@ class TestBatchCleanDocumentTask: tokens=50, index_node_id=str(uuid.uuid4()), created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) @@ -671,7 +672,7 @@ class TestBatchCleanDocumentTask: tokens=25 + i * 5, index_node_id=str(uuid.uuid4()), created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) segments.append(segment) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index a2324979db..ebe5ff1d96 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -21,7 +21,7 @@ from sqlalchemy.orm import Session from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from models.model import UploadFile from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task @@ -139,7 +139,7 @@ class TestBatchCreateSegmentToIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", embedding_model="text-embedding-ada-002", embedding_model_provider="openai", @@ -170,12 +170,12 @@ class TestBatchCreateSegmentToIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, doc_form="text_model", @@ -301,7 +301,7 @@ class TestBatchCreateSegmentToIndexTask: assert segment.dataset_id == dataset.id assert segment.document_id == document.id assert segment.position == i + 1 - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None assert segment.answer is None # text_model doesn't have answers @@ -442,12 +442,12 @@ class TestBatchCreateSegmentToIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name="disabled_document", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=False, # Document is disabled archived=False, doc_form="text_model", @@ -458,12 +458,12 @@ class TestBatchCreateSegmentToIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=2, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name="archived_document", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # Document is archived doc_form="text_model", @@ -474,12 +474,12 @@ class TestBatchCreateSegmentToIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=3, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name="incomplete_document", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="indexing", # Not completed + indexing_status=IndexingStatus.INDEXING, # Not completed enabled=True, archived=False, doc_form="text_model", @@ -643,7 +643,7 @@ class TestBatchCreateSegmentToIndexTask: word_count=len(f"Existing segment {i + 1}"), tokens=10, created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, index_node_id=str(uuid.uuid4()), index_node_hash=f"hash_{i}", ) @@ -694,7 +694,7 @@ class TestBatchCreateSegmentToIndexTask: for i, segment in enumerate(new_segments): expected_position = 4 + i # Should start at position 4 assert segment.position == expected_position - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 41d9fc8a29..638752cf8b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -29,7 +29,14 @@ from models.dataset import ( Document, DocumentSegment, ) -from models.enums import CreatorUserRole +from models.enums import ( + CreatorUserRole, + DatasetMetadataType, + DataSourceType, + DocumentCreatedFrom, + IndexingStatus, + SegmentStatus, +) from models.model import UploadFile from tasks.clean_dataset_task import clean_dataset_task @@ -176,12 +183,12 @@ class TestCleanDatasetTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name="test_document", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, doc_form="paragraph_index", @@ -219,7 +226,7 @@ class TestCleanDatasetTask: word_count=20, tokens=30, created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, index_node_id=str(uuid.uuid4()), index_node_hash="test_hash", created_at=datetime.now(), @@ -373,7 +380,7 @@ class TestCleanDatasetTask: dataset_id=dataset.id, tenant_id=tenant.id, name="test_metadata", - type="string", + type=DatasetMetadataType.STRING, created_by=account.id, ) metadata.id = str(uuid.uuid4()) @@ -587,7 +594,7 @@ class TestCleanDatasetTask: word_count=len(segment_content), tokens=50, created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, index_node_id=str(uuid.uuid4()), index_node_hash="test_hash", created_at=datetime.now(), @@ -686,7 +693,7 @@ class TestCleanDatasetTask: dataset_id=dataset.id, tenant_id=tenant.id, name=f"test_metadata_{i}", - type="string", + type=DatasetMetadataType.STRING, created_by=account.id, ) metadata.id = str(uuid.uuid4()) @@ -880,11 +887,11 @@ class TestCleanDatasetTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info="{}", batch="test_batch", name=f"test_doc_{special_content}", - created_from="test", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, created_at=datetime.now(), updated_at=datetime.now(), @@ -905,7 +912,7 @@ class TestCleanDatasetTask: word_count=len(segment_content.split()), tokens=len(segment_content) // 4, # Rough token estimation created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, index_node_id=str(uuid.uuid4()), index_node_hash="test_hash_" + "x" * 50, # Long hash within limits created_at=datetime.now(), @@ -946,7 +953,7 @@ class TestCleanDatasetTask: dataset_id=dataset.id, tenant_id=tenant.id, name=f"metadata_{special_content}", - type="string", + type=DatasetMetadataType.STRING, created_by=account.id, ) special_metadata.id = str(uuid.uuid4()) diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index 3ce199c602..a2a190fd69 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -13,6 +13,7 @@ import pytest from faker import Faker from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from services.account_service import AccountService, TenantService from tasks.clean_notion_document_task import clean_notion_document_task from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -88,7 +89,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -105,17 +106,17 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"} ), batch="test_batch", name=f"Notion Page {i}", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", # Set doc_form to ensure dataset.doc_form works doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -134,7 +135,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id=f"node_{i}_{j}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) segments.append(segment) @@ -220,7 +221,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -269,7 +270,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=f"{fake.company()}_{index_type}", description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -281,17 +282,17 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"} ), batch="test_batch", name="Test Notion Page", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form=index_type, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -308,7 +309,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id="test_node", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) db_session_with_containers.commit() @@ -357,7 +358,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -369,16 +370,16 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"} ), batch="test_batch", name="Test Notion Page", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -397,7 +398,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id=None, # No index node ID created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) segments.append(segment) @@ -443,7 +444,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -460,16 +461,16 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"} ), batch="test_batch", name=f"Notion Page {i}", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -488,7 +489,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id=f"node_{i}_{j}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) all_segments.append(segment) @@ -558,7 +559,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -570,22 +571,22 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"} ), batch="test_batch", name="Test Notion Page", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() # Create segments with different statuses - segment_statuses = ["waiting", "processing", "completed", "error"] + segment_statuses = [SegmentStatus.WAITING, SegmentStatus.INDEXING, SegmentStatus.COMPLETED, SegmentStatus.ERROR] segments = [] index_node_ids = [] @@ -654,7 +655,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -666,16 +667,16 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"} ), batch="test_batch", name="Test Notion Page", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -692,7 +693,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id="test_node", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) db_session_with_containers.commit() @@ -736,7 +737,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -754,16 +755,16 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"} ), batch="test_batch", name=f"Notion Page {i}", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -783,7 +784,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id=f"node_{i}_{j}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) all_segments.append(segment) @@ -848,7 +849,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=f"{fake.company()}_{i}", description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -866,16 +867,16 @@ class TestCleanNotionDocumentTask: tenant_id=account.current_tenant.id, dataset_id=dataset.id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"} ), batch="test_batch", name=f"Notion Page {i}", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -894,7 +895,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id=f"node_{i}_{j}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) all_segments.append(segment) @@ -963,14 +964,22 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) db_session_with_containers.flush() # Create documents with different indexing statuses - document_statuses = ["waiting", "parsing", "cleaning", "splitting", "indexing", "completed", "error"] + document_statuses = [ + IndexingStatus.WAITING, + IndexingStatus.PARSING, + IndexingStatus.CLEANING, + IndexingStatus.SPLITTING, + IndexingStatus.INDEXING, + IndexingStatus.COMPLETED, + IndexingStatus.ERROR, + ] documents = [] all_segments = [] all_index_node_ids = [] @@ -981,13 +990,13 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"} ), batch="test_batch", name=f"Notion Page {i}", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", indexing_status=status, @@ -1009,7 +1018,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id=f"node_{i}_{j}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) all_segments.append(segment) @@ -1066,7 +1075,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, built_in_field_enabled=True, ) @@ -1079,7 +1088,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( { "notion_workspace_id": "workspace_test", @@ -1091,10 +1100,10 @@ class TestCleanNotionDocumentTask: ), batch="test_batch", name="Test Notion Page with Metadata", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, doc_metadata={ "document_name": "Test Notion Page with Metadata", "uploader": account.name, @@ -1122,7 +1131,7 @@ class TestCleanNotionDocumentTask: tokens=75, index_node_id=f"node_{i}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, keywords={"key1": ["value1", "value2"], "key2": ["value3"]}, ) db_session_with_containers.add(segment) diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 4fa52ff2a9..132f43c320 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -15,6 +15,7 @@ from faker import Faker from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.create_segment_to_index_task import create_segment_to_index_task @@ -118,7 +119,7 @@ class TestCreateSegmentToIndexTask: name=fake.company(), description=fake.text(max_nb_chars=100), tenant_id=tenant_id, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", embedding_model_provider="openai", embedding_model="text-embedding-ada-002", @@ -133,13 +134,13 @@ class TestCreateSegmentToIndexTask: dataset_id=dataset.id, tenant_id=tenant_id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account_id, enabled=True, archived=False, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, doc_form="qa_model", ) db_session_with_containers.add(document) @@ -148,7 +149,7 @@ class TestCreateSegmentToIndexTask: return dataset, document def _create_test_segment( - self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status="waiting" + self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status=SegmentStatus.WAITING ): """ Helper method to create a test document segment for testing. @@ -200,7 +201,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task @@ -208,7 +209,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify segment status changes db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None assert segment.error is None @@ -257,7 +258,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="completed" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.COMPLETED ) # Act: Execute the task @@ -268,7 +269,7 @@ class TestCreateSegmentToIndexTask: # Verify segment status unchanged db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is None # Verify no index processor calls were made @@ -293,20 +294,25 @@ class TestCreateSegmentToIndexTask: dataset_id=invalid_dataset_id, tenant_id=tenant.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, enabled=True, archived=False, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, doc_form="text_model", ) db_session_with_containers.add(document) db_session_with_containers.commit() segment = self._create_test_segment( - db_session_with_containers, invalid_dataset_id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, + invalid_dataset_id, + document.id, + tenant.id, + account.id, + status=SegmentStatus.WAITING, ) # Act: Execute the task @@ -317,7 +323,7 @@ class TestCreateSegmentToIndexTask: # Verify segment status changed to indexing (task updates status before checking document) db_session_with_containers.refresh(segment) - assert segment.status == "indexing" + assert segment.status == SegmentStatus.INDEXING # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() @@ -337,7 +343,12 @@ class TestCreateSegmentToIndexTask: invalid_document_id = str(uuid4()) segment = self._create_test_segment( - db_session_with_containers, dataset.id, invalid_document_id, tenant.id, account.id, status="waiting" + db_session_with_containers, + dataset.id, + invalid_document_id, + tenant.id, + account.id, + status=SegmentStatus.WAITING, ) # Act: Execute the task @@ -348,7 +359,7 @@ class TestCreateSegmentToIndexTask: # Verify segment status changed to indexing (task updates status before checking document) db_session_with_containers.refresh(segment) - assert segment.status == "indexing" + assert segment.status == SegmentStatus.INDEXING # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() @@ -373,7 +384,7 @@ class TestCreateSegmentToIndexTask: db_session_with_containers.commit() segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task @@ -384,7 +395,7 @@ class TestCreateSegmentToIndexTask: # Verify segment status changed to indexing (task updates status before checking document) db_session_with_containers.refresh(segment) - assert segment.status == "indexing" + assert segment.status == SegmentStatus.INDEXING # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() @@ -409,7 +420,7 @@ class TestCreateSegmentToIndexTask: db_session_with_containers.commit() segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task @@ -420,7 +431,7 @@ class TestCreateSegmentToIndexTask: # Verify segment status changed to indexing (task updates status before checking document) db_session_with_containers.refresh(segment) - assert segment.status == "indexing" + assert segment.status == SegmentStatus.INDEXING # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() @@ -445,7 +456,7 @@ class TestCreateSegmentToIndexTask: db_session_with_containers.commit() segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task @@ -456,7 +467,7 @@ class TestCreateSegmentToIndexTask: # Verify segment status changed to indexing (task updates status before checking document) db_session_with_containers.refresh(segment) - assert segment.status == "indexing" + assert segment.status == SegmentStatus.INDEXING # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() @@ -477,7 +488,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Mock processor to raise exception @@ -488,7 +499,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify error handling db_session_with_containers.refresh(segment) - assert segment.status == "error" + assert segment.status == SegmentStatus.ERROR assert segment.enabled is False assert segment.disabled_at is not None assert segment.error == "Processor failed" @@ -512,7 +523,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) custom_keywords = ["custom", "keywords", "test"] @@ -521,7 +532,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -555,7 +566,7 @@ class TestCreateSegmentToIndexTask: db_session_with_containers.commit() segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task @@ -563,7 +574,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED # Verify correct doc_form was passed to factory mock_external_service_dependencies["index_processor_factory"].assert_called_with(doc_form) @@ -583,7 +594,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task and measure time @@ -597,7 +608,7 @@ class TestCreateSegmentToIndexTask: # Verify successful completion db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED def test_create_segment_to_index_concurrent_execution( self, db_session_with_containers, mock_external_service_dependencies @@ -617,7 +628,7 @@ class TestCreateSegmentToIndexTask: segments = [] for i in range(3): segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) segments.append(segment) @@ -629,7 +640,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify all segments processed for segment in segments: db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -665,7 +676,7 @@ class TestCreateSegmentToIndexTask: keywords=["large", "content", "test"], index_node_id=str(uuid4()), index_node_hash=str(uuid4()), - status="waiting", + status=SegmentStatus.WAITING, created_by=account.id, ) db_session_with_containers.add(segment) @@ -681,7 +692,7 @@ class TestCreateSegmentToIndexTask: assert execution_time < 10.0 # Should complete within 10 seconds db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -700,7 +711,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Set up Redis cache key to simulate indexing in progress @@ -718,7 +729,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify indexing still completed successfully despite Redis failure db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -740,7 +751,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Simulate an error during indexing to trigger rollback path @@ -752,7 +763,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify error handling and rollback db_session_with_containers.refresh(segment) - assert segment.status == "error" + assert segment.status == SegmentStatus.ERROR assert segment.enabled is False assert segment.disabled_at is not None assert segment.error is not None @@ -772,7 +783,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task @@ -780,7 +791,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED # Verify index processor was called with correct metadata mock_processor = mock_external_service_dependencies["index_processor"] @@ -814,11 +825,11 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Verify initial state - assert segment.status == "waiting" + assert segment.status == SegmentStatus.WAITING assert segment.indexing_at is None assert segment.completed_at is None @@ -827,7 +838,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify final state db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -861,7 +872,7 @@ class TestCreateSegmentToIndexTask: keywords=[], index_node_id=str(uuid4()), index_node_hash=str(uuid4()), - status="waiting", + status=SegmentStatus.WAITING, created_by=account.id, ) db_session_with_containers.add(segment) @@ -872,7 +883,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -907,7 +918,7 @@ class TestCreateSegmentToIndexTask: keywords=["special", "unicode", "test"], index_node_id=str(uuid4()), index_node_hash=str(uuid4()), - status="waiting", + status=SegmentStatus.WAITING, created_by=account.id, ) db_session_with_containers.add(segment) @@ -918,7 +929,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -937,7 +948,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Create long keyword list @@ -948,7 +959,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -979,10 +990,10 @@ class TestCreateSegmentToIndexTask: ) segment1 = self._create_test_segment( - db_session_with_containers, dataset1.id, document1.id, tenant1.id, account1.id, status="waiting" + db_session_with_containers, dataset1.id, document1.id, tenant1.id, account1.id, status=SegmentStatus.WAITING ) segment2 = self._create_test_segment( - db_session_with_containers, dataset2.id, document2.id, tenant2.id, account2.id, status="waiting" + db_session_with_containers, dataset2.id, document2.id, tenant2.id, account2.id, status=SegmentStatus.WAITING ) # Act: Execute tasks for both tenants @@ -993,8 +1004,8 @@ class TestCreateSegmentToIndexTask: db_session_with_containers.refresh(segment1) db_session_with_containers.refresh(segment2) - assert segment1.status == "completed" - assert segment2.status == "completed" + assert segment1.status == SegmentStatus.COMPLETED + assert segment2.status == SegmentStatus.COMPLETED assert segment1.tenant_id == tenant1.id assert segment2.tenant_id == tenant2.id assert segment1.tenant_id != segment2.tenant_id @@ -1014,7 +1025,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task with None keywords @@ -1022,7 +1033,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -1050,7 +1061,7 @@ class TestCreateSegmentToIndexTask: segments = [] for i in range(5): segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) segments.append(segment) @@ -1067,7 +1078,7 @@ class TestCreateSegmentToIndexTask: # Verify all segments processed successfully for segment in segments: db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None assert segment.error is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 4a62383590..67f9dc7011 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -11,6 +11,7 @@ from core.indexing_runner import DocumentIsPausedError from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from tasks.document_indexing_task import ( _document_indexing, _document_indexing_with_tenant_queue, @@ -139,7 +140,7 @@ class TestDatasetIndexingTaskIntegration: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -155,12 +156,12 @@ class TestDatasetIndexingTaskIntegration: tenant_id=tenant.id, dataset_id=dataset.id, position=position, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=f"doc-{position}.txt", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, ) db_session_with_containers.add(document) @@ -181,7 +182,7 @@ class TestDatasetIndexingTaskIntegration: for document_id in document_ids: updated = self._query_document(db_session_with_containers, document_id) assert updated is not None - assert updated.indexing_status == "parsing" + assert updated.indexing_status == IndexingStatus.PARSING assert updated.processing_started_at is not None def _assert_documents_error_contains( @@ -195,7 +196,7 @@ class TestDatasetIndexingTaskIntegration: for document_id in document_ids: updated = self._query_document(db_session_with_containers, document_id) assert updated is not None - assert updated.indexing_status == "error" + assert updated.indexing_status == IndexingStatus.ERROR assert updated.error is not None assert expected_error_substring in updated.error assert updated.stopped_at is not None diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py index 10c719fb6d..e80b37ac1b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -13,6 +13,7 @@ import pytest from faker import Faker from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from services.account_service import AccountService, TenantService from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -90,7 +91,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -102,13 +103,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -150,7 +151,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -162,13 +163,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -182,13 +183,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -209,7 +210,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -220,7 +221,7 @@ class TestDealDatasetVectorIndexTask: # Verify document status was updated to indexing then completed updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor load method was called mock_factory = mock_index_processor_factory.return_value @@ -251,7 +252,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -263,13 +264,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="parent_child_index", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -283,13 +284,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="parent_child_index", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -310,7 +311,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -321,7 +322,7 @@ class TestDealDatasetVectorIndexTask: # Verify document status was updated to indexing then completed updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor clean and load methods were called mock_factory = mock_index_processor_factory.return_value @@ -367,7 +368,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -399,7 +400,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -411,13 +412,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -430,7 +431,7 @@ class TestDealDatasetVectorIndexTask: # Verify document status was updated to indexing then completed updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify that no index processor load was called since no segments exist mock_factory = mock_index_processor_factory.return_value @@ -455,7 +456,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -488,7 +489,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -500,13 +501,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -520,13 +521,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -547,7 +548,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -563,7 +564,7 @@ class TestDealDatasetVectorIndexTask: # Verify document status was updated to error updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "error" + assert updated_document.indexing_status == IndexingStatus.ERROR assert "Test exception during indexing" in updated_document.error def test_deal_dataset_vector_index_task_with_custom_index_type( @@ -584,7 +585,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -596,13 +597,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="qa_index", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -623,7 +624,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -634,7 +635,7 @@ class TestDealDatasetVectorIndexTask: # Verify document status was updated to indexing then completed updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with custom index type mock_index_processor_factory.assert_called_once_with("qa_index") @@ -660,7 +661,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -672,13 +673,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -699,7 +700,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -710,7 +711,7 @@ class TestDealDatasetVectorIndexTask: # Verify document status was updated to indexing then completed updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with the document's index type mock_index_processor_factory.assert_called_once_with("text_model") @@ -736,7 +737,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -748,13 +749,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -770,13 +771,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name=f"Test Document {i}", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -801,7 +802,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{i}_{j}", index_node_hash=f"hash_{i}_{j}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -814,7 +815,7 @@ class TestDealDatasetVectorIndexTask: # Verify all documents were processed for document in documents: updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor load was called multiple times mock_factory = mock_index_processor_factory.return_value @@ -839,7 +840,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -851,13 +852,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -871,13 +872,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -898,7 +899,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -916,7 +917,7 @@ class TestDealDatasetVectorIndexTask: # Verify final document status updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED def test_deal_dataset_vector_index_task_with_disabled_documents( self, db_session_with_containers, mock_index_processor_factory, account_and_tenant @@ -936,7 +937,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -948,13 +949,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -968,13 +969,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Enabled Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -987,13 +988,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Disabled Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=False, # This document should be skipped archived=False, batch="test_batch", @@ -1015,7 +1016,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -1026,13 +1027,13 @@ class TestDealDatasetVectorIndexTask: # Verify only enabled document was processed updated_enabled_document = db_session_with_containers.query(Document).filter_by(id=enabled_document.id).first() - assert updated_enabled_document.indexing_status == "completed" + assert updated_enabled_document.indexing_status == IndexingStatus.COMPLETED # Verify disabled document status remains unchanged updated_disabled_document = ( db_session_with_containers.query(Document).filter_by(id=disabled_document.id).first() ) - assert updated_disabled_document.indexing_status == "completed" # Should not change + assert updated_disabled_document.indexing_status == IndexingStatus.COMPLETED # Should not change # Verify index processor load was called only once (for enabled document) mock_factory = mock_index_processor_factory.return_value @@ -1057,7 +1058,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -1069,13 +1070,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -1089,13 +1090,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Active Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -1108,13 +1109,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Archived Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # This document should be skipped batch="test_batch", @@ -1136,7 +1137,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -1147,13 +1148,13 @@ class TestDealDatasetVectorIndexTask: # Verify only active document was processed updated_active_document = db_session_with_containers.query(Document).filter_by(id=active_document.id).first() - assert updated_active_document.indexing_status == "completed" + assert updated_active_document.indexing_status == IndexingStatus.COMPLETED # Verify archived document status remains unchanged updated_archived_document = ( db_session_with_containers.query(Document).filter_by(id=archived_document.id).first() ) - assert updated_archived_document.indexing_status == "completed" # Should not change + assert updated_archived_document.indexing_status == IndexingStatus.COMPLETED # Should not change # Verify index processor load was called only once (for active document) mock_factory = mock_index_processor_factory.return_value @@ -1178,7 +1179,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -1190,13 +1191,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -1210,13 +1211,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Completed Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -1229,13 +1230,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Incomplete Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="indexing", # This document should be skipped + indexing_status=IndexingStatus.INDEXING, # This document should be skipped enabled=True, archived=False, batch="test_batch", @@ -1257,7 +1258,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -1270,13 +1271,13 @@ class TestDealDatasetVectorIndexTask: updated_completed_document = ( db_session_with_containers.query(Document).filter_by(id=completed_document.id).first() ) - assert updated_completed_document.indexing_status == "completed" + assert updated_completed_document.indexing_status == IndexingStatus.COMPLETED # Verify incomplete document status remains unchanged updated_incomplete_document = ( db_session_with_containers.query(Document).filter_by(id=incomplete_document.id).first() ) - assert updated_incomplete_document.indexing_status == "indexing" # Should not change + assert updated_incomplete_document.indexing_status == IndexingStatus.INDEXING # Should not change # Verify index processor load was called only once (for completed document) mock_factory = mock_index_processor_factory.return_value diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 69ed5b632d..6fc2a53f9c 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -14,6 +14,7 @@ from faker import Faker from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, Dataset, Document, DocumentSegment, Tenant +from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus from tasks.delete_segment_from_index_task import delete_segment_from_index_task logger = logging.getLogger(__name__) @@ -106,7 +107,7 @@ class TestDeleteSegmentFromIndexTask: dataset.description = fake.text(max_nb_chars=200) dataset.provider = "vendor" dataset.permission = "only_me" - dataset.data_source_type = "upload_file" + dataset.data_source_type = DataSourceType.UPLOAD_FILE dataset.indexing_technique = "high_quality" dataset.index_struct = '{"type": "paragraph"}' dataset.created_by = account.id @@ -145,7 +146,7 @@ class TestDeleteSegmentFromIndexTask: document.data_source_info = kwargs.get("data_source_info", "{}") document.batch = kwargs.get("batch", fake.uuid4()) document.name = kwargs.get("name", f"Test Document {fake.word()}") - document.created_from = kwargs.get("created_from", "api") + document.created_from = kwargs.get("created_from", DocumentCreatedFrom.API) document.created_by = account.id document.created_at = fake.date_time_this_year() document.processing_started_at = kwargs.get("processing_started_at", fake.date_time_this_year()) @@ -162,7 +163,7 @@ class TestDeleteSegmentFromIndexTask: document.enabled = kwargs.get("enabled", True) document.archived = kwargs.get("archived", False) document.updated_at = fake.date_time_this_year() - document.doc_type = kwargs.get("doc_type", "text") + document.doc_type = kwargs.get("doc_type", DocumentDocType.PERSONAL_DOCUMENT) document.doc_metadata = kwargs.get("doc_metadata", {}) document.doc_form = kwargs.get("doc_form", IndexStructureType.PARAGRAPH_INDEX) document.doc_language = kwargs.get("doc_language", "en") @@ -204,7 +205,7 @@ class TestDeleteSegmentFromIndexTask: segment.index_node_hash = fake.sha256() segment.hit_count = 0 segment.enabled = True - segment.status = "completed" + segment.status = SegmentStatus.COMPLETED segment.created_by = account.id segment.created_at = fake.date_time_this_year() segment.updated_by = account.id @@ -386,7 +387,7 @@ class TestDeleteSegmentFromIndexTask: account = self._create_test_account(db_session_with_containers, tenant, fake) dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) document = self._create_test_document( - db_session_with_containers, dataset, account, fake, indexing_status="indexing" + db_session_with_containers, dataset, account, fake, indexing_status=IndexingStatus.INDEXING ) segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py index ab9e5b639a..da42fc7167 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py @@ -18,6 +18,7 @@ from sqlalchemy.orm import Session from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.disable_segment_from_index_task import disable_segment_from_index_task logger = logging.getLogger(__name__) @@ -97,7 +98,7 @@ class TestDisableSegmentFromIndexTask: tenant_id=tenant.id, name=fake.sentence(nb_words=3), description=fake.text(max_nb_chars=200), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -132,12 +133,12 @@ class TestDisableSegmentFromIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=fake.uuid4(), name=fake.file_name(), - created_from="api", + created_from=DocumentCreatedFrom.API, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, doc_form=doc_form, @@ -189,7 +190,7 @@ class TestDisableSegmentFromIndexTask: status=status, enabled=enabled, created_by=account.id, - completed_at=datetime.now(UTC) if status == "completed" else None, + completed_at=datetime.now(UTC) if status == SegmentStatus.COMPLETED else None, ) db_session_with_containers.add(segment) db_session_with_containers.commit() @@ -271,7 +272,7 @@ class TestDisableSegmentFromIndexTask: dataset = self._create_test_dataset(db_session_with_containers, tenant, account) document = self._create_test_document(db_session_with_containers, dataset, tenant, account) segment = self._create_test_segment( - db_session_with_containers, document, dataset, tenant, account, status="indexing", enabled=True + db_session_with_containers, document, dataset, tenant, account, status=SegmentStatus.INDEXING, enabled=True ) # Act: Execute the task diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 6f7d2c28b5..4bc9bb4749 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -14,6 +14,7 @@ from sqlalchemy.orm import Session from models import Account, Dataset, DocumentSegment from models import Document as DatasetDocument from models.dataset import DatasetProcessRule +from models.enums import DataSourceType, DocumentCreatedFrom, ProcessRuleMode, SegmentStatus from tasks.disable_segments_from_index_task import disable_segments_from_index_task @@ -100,7 +101,7 @@ class TestDisableSegmentsFromIndexTask: description=fake.text(max_nb_chars=200), provider="vendor", permission="only_me", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, updated_by=account.id, @@ -134,11 +135,11 @@ class TestDisableSegmentsFromIndexTask: document.tenant_id = dataset.tenant_id document.dataset_id = dataset.id document.position = 1 - document.data_source_type = "upload_file" + document.data_source_type = DataSourceType.UPLOAD_FILE document.data_source_info = '{"upload_file_id": "test_file_id"}' document.batch = fake.uuid4() document.name = f"Test Document {fake.word()}.txt" - document.created_from = "upload_file" + document.created_from = DocumentCreatedFrom.WEB document.created_by = account.id document.created_api_request_id = fake.uuid4() document.processing_started_at = fake.date_time_this_year() @@ -197,7 +198,7 @@ class TestDisableSegmentsFromIndexTask: segment.enabled = True segment.disabled_at = None segment.disabled_by = None - segment.status = "completed" + segment.status = SegmentStatus.COMPLETED segment.created_by = account.id segment.updated_by = account.id segment.indexing_at = fake.date_time_this_year() @@ -230,7 +231,7 @@ class TestDisableSegmentsFromIndexTask: process_rule.id = fake.uuid4() process_rule.tenant_id = dataset.tenant_id process_rule.dataset_id = dataset.id - process_rule.mode = "automatic" + process_rule.mode = ProcessRuleMode.AUTOMATIC process_rule.rules = ( "{" '"mode": "automatic", ' diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py index df5c5dc54b..6a17a19a54 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -16,6 +16,7 @@ import pytest from core.indexing_runner import DocumentIsPausedError, IndexingRunner from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.document_indexing_sync_task import document_indexing_sync_task @@ -54,7 +55,7 @@ class DocumentIndexingSyncTaskTestDataFactory: tenant_id=tenant_id, name=f"dataset-{uuid4()}", description="sync test dataset", - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, indexing_technique="high_quality", created_by=created_by, ) @@ -76,11 +77,11 @@ class DocumentIndexingSyncTaskTestDataFactory: tenant_id=tenant_id, dataset_id=dataset_id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps(data_source_info) if data_source_info is not None else None, batch="test-batch", name=f"doc-{uuid4()}", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, indexing_status=indexing_status, enabled=True, @@ -113,7 +114,7 @@ class DocumentIndexingSyncTaskTestDataFactory: word_count=10, tokens=5, index_node_id=f"node-{document_id}-{i}", - status="completed", + status=SegmentStatus.COMPLETED, created_by=created_by, ) db_session_with_containers.add(segment) @@ -181,7 +182,7 @@ class TestDocumentIndexingSyncTask: dataset_id=dataset.id, created_by=account.id, data_source_info=notion_info, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) segments = DocumentIndexingSyncTaskTestDataFactory.create_segments( @@ -276,7 +277,7 @@ class TestDocumentIndexingSyncTask: db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() ) assert updated_document is not None - assert updated_document.indexing_status == "error" + assert updated_document.indexing_status == IndexingStatus.ERROR assert "Datasource credential not found" in updated_document.error assert updated_document.stopped_at is not None mock_external_dependencies["indexing_runner"].run.assert_not_called() @@ -301,7 +302,7 @@ class TestDocumentIndexingSyncTask: .count() ) assert updated_document is not None - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED assert updated_document.processing_started_at is None assert remaining_segments == 3 mock_external_dependencies["index_processor"].clean.assert_not_called() @@ -327,7 +328,7 @@ class TestDocumentIndexingSyncTask: ) assert updated_document is not None - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None assert updated_document.data_source_info_dict.get("last_edited_time") == "2024-01-02T00:00:00Z" assert remaining_segments == 0 @@ -369,7 +370,7 @@ class TestDocumentIndexingSyncTask: db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() ) assert updated_document is not None - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING mock_external_dependencies["index_processor"].clean.assert_not_called() mock_external_dependencies["indexing_runner"].run.assert_called_once() @@ -393,7 +394,7 @@ class TestDocumentIndexingSyncTask: .count() ) assert updated_document is not None - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert remaining_segments == 0 mock_external_dependencies["indexing_runner"].run.assert_called_once() @@ -412,7 +413,7 @@ class TestDocumentIndexingSyncTask: db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() ) assert updated_document is not None - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.error is None def test_indexing_runner_general_error(self, db_session_with_containers, mock_external_dependencies): @@ -430,7 +431,7 @@ class TestDocumentIndexingSyncTask: db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() ) assert updated_document is not None - assert updated_document.indexing_status == "error" + assert updated_document.indexing_status == IndexingStatus.ERROR assert "Indexing error" in updated_document.error assert updated_document.stopped_at is not None diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index 5dc1f6bee0..9421b07285 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -8,6 +8,7 @@ from core.entities.document_task import DocumentTask from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from tasks.document_indexing_task import ( _document_indexing, # Core function _document_indexing_with_tenant_queue, # Tenant queue wrapper function @@ -97,7 +98,7 @@ class TestDocumentIndexingTasks: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -112,12 +113,12 @@ class TestDocumentIndexingTasks: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, ) db_session_with_containers.add(document) @@ -179,7 +180,7 @@ class TestDocumentIndexingTasks: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -194,12 +195,12 @@ class TestDocumentIndexingTasks: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, ) db_session_with_containers.add(document) @@ -250,7 +251,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify the run method was called with correct documents @@ -320,7 +321,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in existing_document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify the run method was called with only existing documents @@ -367,7 +368,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing close the session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None def test_document_indexing_task_mixed_document_states( @@ -397,12 +398,12 @@ class TestDocumentIndexingTasks: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=2, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=dataset.created_by, - indexing_status="completed", # Already completed + indexing_status=IndexingStatus.COMPLETED, # Already completed enabled=True, ) db_session_with_containers.add(doc1) @@ -414,12 +415,12 @@ class TestDocumentIndexingTasks: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=3, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=dataset.created_by, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=False, # Disabled ) db_session_with_containers.add(doc2) @@ -444,7 +445,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify the run method was called with all documents @@ -482,12 +483,12 @@ class TestDocumentIndexingTasks: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=i + 3, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=dataset.created_by, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, ) db_session_with_containers.add(document) @@ -507,7 +508,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "error" + assert updated_document.indexing_status == IndexingStatus.ERROR assert updated_document.error is not None assert "batch upload" in updated_document.error assert updated_document.stopped_at is not None @@ -548,7 +549,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None def test_document_indexing_task_document_is_paused_error( @@ -591,7 +592,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # ==================== NEW TESTS FOR REFACTORED FUNCTIONS ==================== @@ -702,7 +703,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify the run method was called with correct documents @@ -827,7 +828,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify waiting task was still processed despite core processing error diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py index 9da9a4132e..2fbea1388c 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -5,6 +5,7 @@ from faker import Faker from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.document_indexing_update_task import document_indexing_update_task @@ -61,7 +62,7 @@ class TestDocumentIndexingUpdateTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=64), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -72,12 +73,12 @@ class TestDocumentIndexingUpdateTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, doc_form="text_model", ) @@ -98,7 +99,7 @@ class TestDocumentIndexingUpdateTask: word_count=10, tokens=5, index_node_id=node_id, - status="completed", + status=SegmentStatus.COMPLETED, created_by=account.id, ) db_session_with_containers.add(seg) @@ -122,7 +123,7 @@ class TestDocumentIndexingUpdateTask: # Assert document status updated before reindex updated = db_session_with_containers.query(Document).where(Document.id == document.id).first() - assert updated.indexing_status == "parsing" + assert updated.indexing_status == IndexingStatus.PARSING assert updated.processing_started_at is not None # Segments should be deleted diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index c61e37b1e9..f1f5a4b105 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -7,6 +7,7 @@ from core.indexing_runner import DocumentIsPausedError from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.duplicate_document_indexing_task import ( _duplicate_document_indexing_task, # Core function _duplicate_document_indexing_task_with_tenant_queue, # Tenant queue wrapper function @@ -107,7 +108,7 @@ class TestDuplicateDocumentIndexingTasks: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -122,12 +123,12 @@ class TestDuplicateDocumentIndexingTasks: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, doc_form="text_model", ) @@ -177,7 +178,7 @@ class TestDuplicateDocumentIndexingTasks: content=fake.text(max_nb_chars=200), word_count=50, tokens=100, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, indexing_at=fake.date_time_this_year(), created_by=dataset.created_by, # Add required field @@ -242,7 +243,7 @@ class TestDuplicateDocumentIndexingTasks: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -257,12 +258,12 @@ class TestDuplicateDocumentIndexingTasks: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, doc_form="text_model", ) @@ -316,7 +317,7 @@ class TestDuplicateDocumentIndexingTasks: # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify the run method was called with correct documents @@ -368,7 +369,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were updated to parsing status for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify indexing runner was called @@ -437,7 +438,7 @@ class TestDuplicateDocumentIndexingTasks: # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in existing_document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify the run method was called with only existing documents @@ -484,7 +485,7 @@ class TestDuplicateDocumentIndexingTasks: # Re-query documents from database since _duplicate_document_indexing_task close the session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None def _test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit( @@ -516,12 +517,12 @@ class TestDuplicateDocumentIndexingTasks: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=i + 3, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=dataset.created_by, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, doc_form="text_model", ) @@ -542,7 +543,7 @@ class TestDuplicateDocumentIndexingTasks: # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "error" + assert updated_document.indexing_status == IndexingStatus.ERROR assert updated_document.error is not None assert "batch upload" in updated_document.error.lower() assert updated_document.stopped_at is not None @@ -584,7 +585,7 @@ class TestDuplicateDocumentIndexingTasks: # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "error" + assert updated_document.indexing_status == IndexingStatus.ERROR assert updated_document.error is not None assert "limit" in updated_document.error.lower() assert updated_document.stopped_at is not None @@ -648,7 +649,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were processed for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_normal_duplicate_document_indexing_task_with_tenant_queue( @@ -691,7 +692,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were processed for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_priority_duplicate_document_indexing_task_with_tenant_queue( @@ -735,7 +736,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were processed for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_tenant_queue_wrapper_processes_next_tasks( @@ -851,7 +852,7 @@ class TestDuplicateDocumentIndexingTasks: for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.is_paused is True - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.display_status == "paused" assert updated_document.processing_started_at is not None mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py index bc29395545..54b50016a8 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -8,6 +8,7 @@ from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.enable_segments_to_index_task import enable_segments_to_index_task @@ -79,7 +80,7 @@ class TestEnableSegmentsToIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -92,12 +93,12 @@ class TestEnableSegmentsToIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, doc_form=IndexStructureType.PARAGRAPH_INDEX, ) @@ -110,7 +111,13 @@ class TestEnableSegmentsToIndexTask: return dataset, document def _create_test_segments( - self, db_session_with_containers: Session, document, dataset, count=3, enabled=False, status="completed" + self, + db_session_with_containers: Session, + document, + dataset, + count=3, + enabled=False, + status=SegmentStatus.COMPLETED, ): """ Helper method to create test document segments. @@ -278,7 +285,7 @@ class TestEnableSegmentsToIndexTask: invalid_statuses = [ ("disabled", {"enabled": False}), ("archived", {"archived": True}), - ("not_completed", {"indexing_status": "processing"}), + ("not_completed", {"indexing_status": IndexingStatus.INDEXING}), ] for _, status_attrs in invalid_statuses: @@ -447,7 +454,7 @@ class TestEnableSegmentsToIndexTask: for segment in segments: db_session_with_containers.refresh(segment) assert segment.enabled is False - assert segment.status == "error" + assert segment.status == SegmentStatus.ERROR assert segment.error is not None assert "Index processing failed" in segment.error assert segment.disabled_at is not None diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py index 3b8679f4ec..ebbb34e069 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py @@ -59,6 +59,44 @@ class TestPipelineTemplateDetailApi: assert status == 200 assert response == template + def test_get_returns_404_when_template_not_found(self, app): + api = PipelineTemplateDetailApi() + method = unwrap(api.get) + + service = MagicMock() + service.get_pipeline_template_detail.return_value = None + + with ( + app.test_request_context("/?type=built-in"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService", + return_value=service, + ), + ): + response, status = method(api, "non-existent-id") + + assert status == 404 + assert "error" in response + + def test_get_returns_404_for_customized_type_not_found(self, app): + api = PipelineTemplateDetailApi() + method = unwrap(api.get) + + service = MagicMock() + service.get_pipeline_template_detail.return_value = None + + with ( + app.test_request_context("/?type=customized"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService", + return_value=service, + ), + ): + response, status = method(api, "non-existent-id") + + assert status == 404 + assert "error" in response + class TestCustomizedPipelineTemplateApi: def test_patch_success(self, app): diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py index dbe54ccb99..f23dd5b44a 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -30,6 +30,7 @@ from controllers.console.datasets.error import ( InvalidActionError, InvalidMetadataError, ) +from models.enums import DataSourceType, IndexingStatus def unwrap(func): @@ -62,8 +63,8 @@ def document(): return MagicMock( id="doc-1", tenant_id="tenant-1", - indexing_status="indexing", - data_source_type="upload_file", + indexing_status=IndexingStatus.INDEXING, + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, doc_form="text", archived=False, @@ -407,7 +408,7 @@ class TestDocumentProcessingApi: api = DocumentProcessingApi() method = unwrap(api.patch) - doc = MagicMock(indexing_status="error", is_paused=True) + doc = MagicMock(indexing_status=IndexingStatus.ERROR, is_paused=True) with ( app.test_request_context("/"), @@ -425,7 +426,7 @@ class TestDocumentProcessingApi: api = DocumentProcessingApi() method = unwrap(api.patch) - document = MagicMock(indexing_status="paused", is_paused=True) + document = MagicMock(indexing_status=IndexingStatus.PAUSED, is_paused=True) with ( app.test_request_context("/"), @@ -461,7 +462,7 @@ class TestDocumentProcessingApi: api = DocumentProcessingApi() method = unwrap(api.patch) - document = MagicMock(indexing_status="completed") + document = MagicMock(indexing_status=IndexingStatus.COMPLETED) with app.test_request_context("/"), patch.object(api, "get_document", return_value=document): with pytest.raises(InvalidActionError): @@ -630,7 +631,7 @@ class TestDocumentRetryApi: payload = {"document_ids": ["doc-1"]} - document = MagicMock(indexing_status="indexing", archived=False) + document = MagicMock(indexing_status=IndexingStatus.INDEXING, archived=False) with ( app.test_request_context("/", json=payload), @@ -659,7 +660,7 @@ class TestDocumentRetryApi: payload = {"document_ids": ["doc-1"]} - document = MagicMock(indexing_status="completed", archived=False) + document = MagicMock(indexing_status=IndexingStatus.COMPLETED, archived=False) with ( app.test_request_context("/", json=payload), @@ -817,8 +818,8 @@ class TestDocumentIndexingEstimateApi: method = unwrap(api.get) document = MagicMock( - indexing_status="indexing", - data_source_type="upload_file", + indexing_status=IndexingStatus.INDEXING, + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", doc_form="text", @@ -844,8 +845,8 @@ class TestDocumentIndexingEstimateApi: method = unwrap(api.get) document = MagicMock( - indexing_status="indexing", - data_source_type="upload_file", + indexing_status=IndexingStatus.INDEXING, + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", doc_form="text", @@ -882,7 +883,7 @@ class TestDocumentIndexingEstimateApi: api = DocumentIndexingEstimateApi() method = unwrap(api.get) - document = MagicMock(indexing_status="completed") + document = MagicMock(indexing_status=IndexingStatus.COMPLETED) with app.test_request_context("/"), patch.object(api, "get_document", return_value=document): with pytest.raises(DocumentAlreadyFinishedError): @@ -963,8 +964,8 @@ class TestDocumentBatchIndexingEstimateApi: method = unwrap(api.get) doc = MagicMock( - indexing_status="indexing", - data_source_type="website_crawl", + indexing_status=IndexingStatus.INDEXING, + data_source_type=DataSourceType.WEBSITE_CRAWL, data_source_info_dict={ "provider": "firecrawl", "job_id": "j1", @@ -992,8 +993,8 @@ class TestDocumentBatchIndexingEstimateApi: method = unwrap(api.get) doc = MagicMock( - indexing_status="indexing", - data_source_type="notion_import", + indexing_status=IndexingStatus.INDEXING, + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info_dict={ "credential_id": "c1", "notion_workspace_id": "w1", @@ -1020,7 +1021,7 @@ class TestDocumentBatchIndexingEstimateApi: method = unwrap(api.get) document = MagicMock( - indexing_status="indexing", + indexing_status=IndexingStatus.INDEXING, data_source_type="unknown", data_source_info_dict={}, doc_form="text", @@ -1130,7 +1131,7 @@ class TestDocumentProcessingApiResume: api = DocumentProcessingApi() method = unwrap(api.patch) - document = MagicMock(indexing_status="completed", is_paused=False) + document = MagicMock(indexing_status=IndexingStatus.COMPLETED, is_paused=False) with app.test_request_context("/"), patch.object(api, "get_document", return_value=document): with pytest.raises(InvalidActionError): @@ -1348,8 +1349,8 @@ class TestDocumentIndexingEdgeCases: method = unwrap(api.get) document = MagicMock( - indexing_status="indexing", - data_source_type="upload_file", + indexing_status=IndexingStatus.INDEXING, + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", doc_form="text", diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py index dc651a1627..5c48ef1804 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py @@ -32,6 +32,7 @@ from controllers.service_api.dataset.segment import ( SegmentListQuery, ) from models.dataset import ChildChunk, Dataset, Document, DocumentSegment +from models.enums import IndexingStatus from services.dataset_service import DocumentService, SegmentService @@ -657,12 +658,27 @@ class TestSegmentIndexingRequirements: dataset.indexing_technique = technique assert dataset.indexing_technique in ["high_quality", "economy"] - @pytest.mark.parametrize("status", ["waiting", "parsing", "indexing", "completed", "error"]) + @pytest.mark.parametrize( + "status", + [ + IndexingStatus.WAITING, + IndexingStatus.PARSING, + IndexingStatus.INDEXING, + IndexingStatus.COMPLETED, + IndexingStatus.ERROR, + ], + ) def test_valid_indexing_statuses(self, status): """Test valid document indexing statuses.""" document = Mock(spec=Document) document.indexing_status = status - assert document.indexing_status in ["waiting", "parsing", "indexing", "completed", "error"] + assert document.indexing_status in { + IndexingStatus.WAITING, + IndexingStatus.PARSING, + IndexingStatus.INDEXING, + IndexingStatus.COMPLETED, + IndexingStatus.ERROR, + } def test_completed_status_required_for_segments(self): """Test that completed status is required for segment operations.""" diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index f98109af79..e6e841be19 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -35,6 +35,7 @@ from controllers.service_api.dataset.document import ( InvalidMetadataError, ) from controllers.service_api.dataset.error import ArchivedDocumentImmutableError +from models.enums import IndexingStatus from services.dataset_service import DocumentService from services.entities.knowledge_entities.knowledge_entities import ProcessRule, RetrievalModel @@ -244,23 +245,26 @@ class TestDocumentService: class TestDocumentIndexingStatus: """Test document indexing status values.""" + _VALID_STATUSES = { + IndexingStatus.WAITING, + IndexingStatus.PARSING, + IndexingStatus.INDEXING, + IndexingStatus.COMPLETED, + IndexingStatus.ERROR, + IndexingStatus.PAUSED, + } + def test_completed_status(self): """Test completed status.""" - status = "completed" - valid_statuses = ["waiting", "parsing", "indexing", "completed", "error", "paused"] - assert status in valid_statuses + assert IndexingStatus.COMPLETED in self._VALID_STATUSES def test_indexing_status(self): """Test indexing status.""" - status = "indexing" - valid_statuses = ["waiting", "parsing", "indexing", "completed", "error", "paused"] - assert status in valid_statuses + assert IndexingStatus.INDEXING in self._VALID_STATUSES def test_error_status(self): """Test error status.""" - status = "error" - valid_statuses = ["waiting", "parsing", "indexing", "completed", "error", "paused"] - assert status in valid_statuses + assert IndexingStatus.ERROR in self._VALID_STATUSES class TestDocumentDocForm: diff --git a/api/tests/unit_tests/controllers/trigger/test_webhook.py b/api/tests/unit_tests/controllers/trigger/test_webhook.py index d633365f2b..91c793d292 100644 --- a/api/tests/unit_tests/controllers/trigger/test_webhook.py +++ b/api/tests/unit_tests/controllers/trigger/test_webhook.py @@ -23,6 +23,7 @@ def mock_jsonify(): class DummyWebhookTrigger: webhook_id = "wh-1" + webhook_url = "http://localhost:5001/triggers/webhook/wh-1" tenant_id = "tenant-1" app_id = "app-1" node_id = "node-1" @@ -104,7 +105,32 @@ class TestHandleWebhookDebug: @patch.object(module.WebhookService, "get_webhook_trigger_and_workflow") @patch.object(module.WebhookService, "extract_and_validate_webhook_data") @patch.object(module.WebhookService, "build_workflow_inputs", return_value={"x": 1}) - @patch.object(module.TriggerDebugEventBus, "dispatch") + @patch.object(module.TriggerDebugEventBus, "dispatch", return_value=0) + def test_debug_requires_active_listener( + self, + mock_dispatch, + mock_build_inputs, + mock_extract, + mock_get, + ): + mock_get.return_value = (DummyWebhookTrigger(), None, "node_config") + mock_extract.return_value = {"method": "POST"} + + response, status = module.handle_webhook_debug("wh-1") + + assert status == 409 + assert response["error"] == "No active debug listener" + assert response["message"] == ( + "The webhook debug URL only works while the Variable Inspector is listening. " + "Use the published webhook URL to execute the workflow in Celery." + ) + assert response["execution_url"] == DummyWebhookTrigger.webhook_url + mock_dispatch.assert_called_once() + + @patch.object(module.WebhookService, "get_webhook_trigger_and_workflow") + @patch.object(module.WebhookService, "extract_and_validate_webhook_data") + @patch.object(module.WebhookService, "build_workflow_inputs", return_value={"x": 1}) + @patch.object(module.TriggerDebugEventBus, "dispatch", return_value=1) @patch.object(module.WebhookService, "generate_webhook_response") def test_debug_success( self, diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index 5ebefcd8d2..75473fc89a 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -35,6 +35,7 @@ from dify_graph.model_runtime.entities.provider_entities import ( ProviderCredentialSchema, ProviderEntity, ) +from models.enums import CredentialSourceType from models.provider import ProviderType from models.provider_ids import ModelProviderID @@ -514,7 +515,7 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva id="lb-base", name="LB Base", credentials={}, - credential_source_type="provider", + credential_source_type=CredentialSourceType.PROVIDER, ) ], ), @@ -528,7 +529,7 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva id="lb-custom", name="LB Custom", credentials={}, - credential_source_type="custom_model", + credential_source_type=CredentialSourceType.CUSTOM_MODEL, ) ], ), @@ -826,7 +827,7 @@ def test_update_load_balancing_configs_updates_all_matching_configs() -> None: configuration._update_load_balancing_configs_with_credential( credential_id="cred-1", credential_record=credential_record, - credential_source="provider", + credential_source=CredentialSourceType.PROVIDER, session=session, ) @@ -844,7 +845,7 @@ def test_update_load_balancing_configs_returns_when_no_matching_configs() -> Non configuration._update_load_balancing_configs_with_credential( credential_id="cred-1", credential_record=SimpleNamespace(encrypted_config="{}", credential_name="Main"), - credential_source="provider", + credential_source=CredentialSourceType.PROVIDER, session=session, ) diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index 2451db70b6..e6cc582398 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -236,7 +236,8 @@ class TestParagraphIndexProcessor: "core.rag.index_processor.processor.paragraph_index_processor.RetrievalService.retrieve" ) as mock_retrieve: mock_retrieve.return_value = [accepted, rejected] - docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {}) + reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""} + docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, reranking_model) assert len(docs) == 1 assert docs[0].metadata["score"] == 0.9 diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py index abe40f05d1..5c78cae7c1 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py @@ -307,7 +307,8 @@ class TestParentChildIndexProcessor: "core.rag.index_processor.processor.parent_child_index_processor.RetrievalService.retrieve" ) as mock_retrieve: mock_retrieve.return_value = [ok_result, low_result] - docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, {}) + reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""} + docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, reranking_model) assert len(docs) == 1 assert docs[0].page_content == "keep" diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py index 8596647ef3..99323eeec9 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -262,7 +262,8 @@ class TestQAIndexProcessor: with patch("core.rag.index_processor.processor.qa_index_processor.RetrievalService.retrieve") as mock_retrieve: mock_retrieve.return_value = [result_ok, result_low] - docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {}) + reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""} + docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, reranking_model) assert len(docs) == 1 assert docs[0].page_content == "accepted" diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index d61f01c616..665e98bd9c 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -25,6 +25,7 @@ from core.app.app_config.entities import ModelConfig as WorkflowModelConfig from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus +from core.rag.data_post_processor.data_post_processor import WeightsDict from core.rag.datasource.retrieval_service import RetrievalService from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType @@ -4686,7 +4687,10 @@ class TestSingleAndMultipleRetrieveCoverage: extra={"dataset_name": "Ext", "title": "Ext"}, ) app = Flask(__name__) - weights = {"vector_setting": {}} + weights: WeightsDict = { + "vector_setting": {"vector_weight": 0.5, "embedding_provider_name": "", "embedding_model_name": ""}, + "keyword_setting": {"keyword_weight": 0.5}, + } def fake_multiple_thread(**kwargs): if kwargs["query"]: diff --git a/api/tests/unit_tests/models/test_account_models.py b/api/tests/unit_tests/models/test_account_models.py index 1726fc2e8b..f48db77bb5 100644 --- a/api/tests/unit_tests/models/test_account_models.py +++ b/api/tests/unit_tests/models/test_account_models.py @@ -622,28 +622,10 @@ class TestAccountGetByOpenId: mock_account = Account(name="Test User", email="test@example.com") mock_account.id = account_id - # Mock the query chain - mock_query = MagicMock() - mock_where = MagicMock() - mock_where.one_or_none.return_value = mock_account_integrate - mock_query.where.return_value = mock_where - mock_db.session.query.return_value = mock_query - - # Mock the second query for account - mock_account_query = MagicMock() - mock_account_where = MagicMock() - mock_account_where.one_or_none.return_value = mock_account - mock_account_query.where.return_value = mock_account_where - - # Setup query to return different results based on model - def query_side_effect(model): - if model.__name__ == "AccountIntegrate": - return mock_query - elif model.__name__ == "Account": - return mock_account_query - return MagicMock() - - mock_db.session.query.side_effect = query_side_effect + # Mock db.session.execute().scalar_one_or_none() for AccountIntegrate lookup + mock_db.session.execute.return_value.scalar_one_or_none.return_value = mock_account_integrate + # Mock db.session.scalar() for Account lookup + mock_db.session.scalar.return_value = mock_account # Act result = Account.get_by_openid(provider, open_id) @@ -658,12 +640,8 @@ class TestAccountGetByOpenId: provider = "github" open_id = "github_user_456" - # Mock the query chain to return None - mock_query = MagicMock() - mock_where = MagicMock() - mock_where.one_or_none.return_value = None - mock_query.where.return_value = mock_where - mock_db.session.query.return_value = mock_query + # Mock db.session.execute().scalar_one_or_none() to return None + mock_db.session.execute.return_value.scalar_one_or_none.return_value = None # Act result = Account.get_by_openid(provider, open_id) diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index 6c619dcf98..329fe554ea 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -300,10 +300,8 @@ class TestAppModelConfig: created_by=str(uuid4()), ) - # Mock database query to return None - with patch("models.model.db.session.query", autospec=True) as mock_query: - mock_query.return_value.where.return_value.first.return_value = None - + # Mock database scalar to return None (no annotation setting found) + with patch("models.model.db.session.scalar", return_value=None): # Act result = config.annotation_reply_dict @@ -951,10 +949,8 @@ class TestSiteModel: def test_site_generate_code(self): """Test Site.generate_code static method.""" - # Mock database query to return 0 (no existing codes) - with patch("models.model.db.session.query", autospec=True) as mock_query: - mock_query.return_value.where.return_value.count.return_value = 0 - + # Mock database scalar to return 0 (no existing codes) + with patch("models.model.db.session.scalar", return_value=0): # Act code = Site.generate_code(8) diff --git a/api/tests/unit_tests/models/test_dataset_models.py b/api/tests/unit_tests/models/test_dataset_models.py index 9bb7c05a91..98dd07907a 100644 --- a/api/tests/unit_tests/models/test_dataset_models.py +++ b/api/tests/unit_tests/models/test_dataset_models.py @@ -25,6 +25,13 @@ from models.dataset import ( DocumentSegment, Embedding, ) +from models.enums import ( + DataSourceType, + DocumentCreatedFrom, + IndexingStatus, + ProcessRuleMode, + SegmentStatus, +) class TestDatasetModelValidation: @@ -40,14 +47,14 @@ class TestDatasetModelValidation: dataset = Dataset( tenant_id=tenant_id, name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, ) # Assert assert dataset.name == "Test Dataset" assert dataset.tenant_id == tenant_id - assert dataset.data_source_type == "upload_file" + assert dataset.data_source_type == DataSourceType.UPLOAD_FILE assert dataset.created_by == created_by # Note: Default values are set by database, not by model instantiation @@ -57,7 +64,7 @@ class TestDatasetModelValidation: dataset = Dataset( tenant_id=str(uuid4()), name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), description="Test description", indexing_technique="high_quality", @@ -77,14 +84,14 @@ class TestDatasetModelValidation: dataset_high_quality = Dataset( tenant_id=str(uuid4()), name="High Quality Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), indexing_technique="high_quality", ) dataset_economy = Dataset( tenant_id=str(uuid4()), name="Economy Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), indexing_technique="economy", ) @@ -101,14 +108,14 @@ class TestDatasetModelValidation: dataset_vendor = Dataset( tenant_id=str(uuid4()), name="Vendor Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), provider="vendor", ) dataset_external = Dataset( tenant_id=str(uuid4()), name="External Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), provider="external", ) @@ -126,7 +133,7 @@ class TestDatasetModelValidation: dataset = Dataset( tenant_id=str(uuid4()), name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), index_struct=json.dumps(index_struct_data), ) @@ -145,7 +152,7 @@ class TestDatasetModelValidation: dataset = Dataset( tenant_id=str(uuid4()), name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), ) @@ -161,7 +168,7 @@ class TestDatasetModelValidation: dataset = Dataset( tenant_id=str(uuid4()), name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), ) @@ -178,7 +185,7 @@ class TestDatasetModelValidation: dataset = Dataset( tenant_id=str(uuid4()), name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), ) @@ -218,10 +225,10 @@ class TestDocumentModelRelationships: tenant_id=tenant_id, dataset_id=dataset_id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test_document.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) @@ -229,10 +236,10 @@ class TestDocumentModelRelationships: assert document.tenant_id == tenant_id assert document.dataset_id == dataset_id assert document.position == 1 - assert document.data_source_type == "upload_file" + assert document.data_source_type == DataSourceType.UPLOAD_FILE assert document.batch == "batch_001" assert document.name == "test_document.pdf" - assert document.created_from == "web" + assert document.created_from == DocumentCreatedFrom.WEB assert document.created_by == created_by # Note: Default values are set by database, not by model instantiation @@ -250,12 +257,12 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, ) # Act @@ -271,12 +278,12 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - indexing_status="parsing", + indexing_status=IndexingStatus.PARSING, is_paused=True, ) @@ -289,15 +296,20 @@ class TestDocumentModelRelationships: def test_document_display_status_indexing(self): """Test document display_status property for indexing state.""" # Arrange - for indexing_status in ["parsing", "cleaning", "splitting", "indexing"]: + for indexing_status in [ + IndexingStatus.PARSING, + IndexingStatus.CLEANING, + IndexingStatus.SPLITTING, + IndexingStatus.INDEXING, + ]: document = Document( tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), indexing_status=indexing_status, ) @@ -315,12 +327,12 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - indexing_status="error", + indexing_status=IndexingStatus.ERROR, ) # Act @@ -336,12 +348,12 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, ) @@ -359,12 +371,12 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=False, archived=False, ) @@ -382,12 +394,12 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, archived=True, ) @@ -405,10 +417,10 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), data_source_info=json.dumps(data_source_info), ) @@ -428,10 +440,10 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), ) @@ -448,10 +460,10 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), word_count=1000, ) @@ -471,10 +483,10 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), word_count=0, ) @@ -582,7 +594,7 @@ class TestDocumentSegmentIndexing: word_count=1, tokens=2, created_by=str(uuid4()), - status="waiting", + status=SegmentStatus.WAITING, ) segment_completed = DocumentSegment( tenant_id=str(uuid4()), @@ -593,12 +605,12 @@ class TestDocumentSegmentIndexing: word_count=1, tokens=2, created_by=str(uuid4()), - status="completed", + status=SegmentStatus.COMPLETED, ) # Assert - assert segment_waiting.status == "waiting" - assert segment_completed.status == "completed" + assert segment_waiting.status == SegmentStatus.WAITING + assert segment_completed.status == SegmentStatus.COMPLETED def test_document_segment_enabled_disabled_tracking(self): """Test document segment enabled/disabled state tracking.""" @@ -769,13 +781,13 @@ class TestDatasetProcessRule: # Act process_rule = DatasetProcessRule( dataset_id=dataset_id, - mode="automatic", + mode=ProcessRuleMode.AUTOMATIC, created_by=created_by, ) # Assert assert process_rule.dataset_id == dataset_id - assert process_rule.mode == "automatic" + assert process_rule.mode == ProcessRuleMode.AUTOMATIC assert process_rule.created_by == created_by def test_dataset_process_rule_modes(self): @@ -797,7 +809,7 @@ class TestDatasetProcessRule: } process_rule = DatasetProcessRule( dataset_id=str(uuid4()), - mode="custom", + mode=ProcessRuleMode.CUSTOM, created_by=str(uuid4()), rules=json.dumps(rules_data), ) @@ -817,7 +829,7 @@ class TestDatasetProcessRule: rules_data = {"test": "data"} process_rule = DatasetProcessRule( dataset_id=dataset_id, - mode="automatic", + mode=ProcessRuleMode.AUTOMATIC, created_by=str(uuid4()), rules=json.dumps(rules_data), ) @@ -827,7 +839,7 @@ class TestDatasetProcessRule: # Assert assert result["dataset_id"] == dataset_id - assert result["mode"] == "automatic" + assert result["mode"] == ProcessRuleMode.AUTOMATIC assert result["rules"] == rules_data def test_dataset_process_rule_automatic_rules(self): @@ -969,7 +981,7 @@ class TestModelIntegration: dataset = Dataset( tenant_id=tenant_id, name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, indexing_technique="high_quality", ) @@ -980,10 +992,10 @@ class TestModelIntegration: tenant_id=tenant_id, dataset_id=dataset_id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, word_count=100, ) @@ -999,7 +1011,7 @@ class TestModelIntegration: word_count=3, tokens=5, created_by=created_by, - status="completed", + status=SegmentStatus.COMPLETED, ) # Assert @@ -1009,7 +1021,7 @@ class TestModelIntegration: assert segment.document_id == document_id assert dataset.indexing_technique == "high_quality" assert document.word_count == 100 - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED def test_document_to_dict_serialization(self): """Test document to_dict method for serialization.""" @@ -1022,13 +1034,13 @@ class TestModelIntegration: tenant_id=tenant_id, dataset_id=dataset_id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, word_count=100, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) # Mock segment_count and hit_count @@ -1044,6 +1056,6 @@ class TestModelIntegration: assert result["dataset_id"] == dataset_id assert result["name"] == "test.pdf" assert result["word_count"] == 100 - assert result["indexing_status"] == "completed" + assert result["indexing_status"] == IndexingStatus.COMPLETED assert result["segment_count"] == 5 assert result["hit_count"] == 10 diff --git a/api/tests/unit_tests/models/test_enums_creator_user_role.py b/api/tests/unit_tests/models/test_enums_creator_user_role.py new file mode 100644 index 0000000000..6317166fdc --- /dev/null +++ b/api/tests/unit_tests/models/test_enums_creator_user_role.py @@ -0,0 +1,19 @@ +import pytest + +from models.enums import CreatorUserRole + + +def test_creator_user_role_missing_maps_hyphen_to_enum(): + # given an alias with hyphen + value = "end-user" + + # when converting to enum (invokes StrEnum._missing_ override) + role = CreatorUserRole(value) + + # then it should map to END_USER + assert role is CreatorUserRole.END_USER + + +def test_creator_user_role_missing_raises_for_unknown(): + with pytest.raises(ValueError): + CreatorUserRole("unknown") diff --git a/api/tests/unit_tests/models/test_provider_models.py b/api/tests/unit_tests/models/test_provider_models.py index ec84a61c8e..f628e54a4d 100644 --- a/api/tests/unit_tests/models/test_provider_models.py +++ b/api/tests/unit_tests/models/test_provider_models.py @@ -19,6 +19,7 @@ from uuid import uuid4 import pytest +from models.enums import CredentialSourceType, PaymentStatus from models.provider import ( LoadBalancingModelConfig, Provider, @@ -158,7 +159,7 @@ class TestProviderModel: # Assert assert provider.tenant_id == tenant_id assert provider.provider_name == provider_name - assert provider.provider_type == "custom" + assert provider.provider_type == ProviderType.CUSTOM assert provider.is_valid is False assert provider.quota_used == 0 @@ -172,10 +173,10 @@ class TestProviderModel: provider = Provider( tenant_id=tenant_id, provider_name="anthropic", - provider_type="system", + provider_type=ProviderType.SYSTEM, is_valid=True, credential_id=credential_id, - quota_type="paid", + quota_type=ProviderQuotaType.PAID, quota_limit=10000, quota_used=500, ) @@ -183,10 +184,10 @@ class TestProviderModel: # Assert assert provider.tenant_id == tenant_id assert provider.provider_name == "anthropic" - assert provider.provider_type == "system" + assert provider.provider_type == ProviderType.SYSTEM assert provider.is_valid is True assert provider.credential_id == credential_id - assert provider.quota_type == "paid" + assert provider.quota_type == ProviderQuotaType.PAID assert provider.quota_limit == 10000 assert provider.quota_used == 500 @@ -199,7 +200,7 @@ class TestProviderModel: ) # Assert - assert provider.provider_type == "custom" + assert provider.provider_type == ProviderType.CUSTOM assert provider.is_valid is False assert provider.quota_type == "" assert provider.quota_limit is None @@ -213,7 +214,7 @@ class TestProviderModel: provider = Provider( tenant_id=tenant_id, provider_name="openai", - provider_type="custom", + provider_type=ProviderType.CUSTOM, ) # Act @@ -253,7 +254,7 @@ class TestProviderModel: provider = Provider( tenant_id=str(uuid4()), provider_name="openai", - provider_type=ProviderType.SYSTEM.value, + provider_type=ProviderType.SYSTEM, is_valid=True, ) @@ -266,13 +267,13 @@ class TestProviderModel: provider = Provider( tenant_id=str(uuid4()), provider_name="openai", - quota_type="trial", + quota_type=ProviderQuotaType.TRIAL, quota_limit=1000, quota_used=250, ) # Assert - assert provider.quota_type == "trial" + assert provider.quota_type == ProviderQuotaType.TRIAL assert provider.quota_limit == 1000 assert provider.quota_used == 250 remaining = provider.quota_limit - provider.quota_used @@ -429,13 +430,13 @@ class TestTenantPreferredModelProvider: preferred = TenantPreferredModelProvider( tenant_id=tenant_id, provider_name="openai", - preferred_provider_type="custom", + preferred_provider_type=ProviderType.CUSTOM, ) # Assert assert preferred.tenant_id == tenant_id assert preferred.provider_name == "openai" - assert preferred.preferred_provider_type == "custom" + assert preferred.preferred_provider_type == ProviderType.CUSTOM def test_tenant_preferred_provider_system_type(self): """Test tenant preferred provider with system type.""" @@ -443,11 +444,11 @@ class TestTenantPreferredModelProvider: preferred = TenantPreferredModelProvider( tenant_id=str(uuid4()), provider_name="anthropic", - preferred_provider_type="system", + preferred_provider_type=ProviderType.SYSTEM, ) # Assert - assert preferred.preferred_provider_type == "system" + assert preferred.preferred_provider_type == ProviderType.SYSTEM class TestProviderOrder: @@ -470,7 +471,7 @@ class TestProviderOrder: quantity=1, currency=None, total_amount=None, - payment_status="wait_pay", + payment_status=PaymentStatus.WAIT_PAY, paid_at=None, pay_failed_at=None, refunded_at=None, @@ -481,7 +482,7 @@ class TestProviderOrder: assert order.provider_name == "openai" assert order.account_id == account_id assert order.payment_product_id == "prod_123" - assert order.payment_status == "wait_pay" + assert order.payment_status == PaymentStatus.WAIT_PAY assert order.quantity == 1 def test_provider_order_with_payment_details(self): @@ -502,7 +503,7 @@ class TestProviderOrder: quantity=5, currency="USD", total_amount=9999, - payment_status="paid", + payment_status=PaymentStatus.PAID, paid_at=paid_time, pay_failed_at=None, refunded_at=None, @@ -514,7 +515,7 @@ class TestProviderOrder: assert order.quantity == 5 assert order.currency == "USD" assert order.total_amount == 9999 - assert order.payment_status == "paid" + assert order.payment_status == PaymentStatus.PAID assert order.paid_at == paid_time def test_provider_order_payment_statuses(self): @@ -536,23 +537,23 @@ class TestProviderOrder: } # Act & Assert - Wait pay status - wait_order = ProviderOrder(**base_params, payment_status="wait_pay") - assert wait_order.payment_status == "wait_pay" + wait_order = ProviderOrder(**base_params, payment_status=PaymentStatus.WAIT_PAY) + assert wait_order.payment_status == PaymentStatus.WAIT_PAY # Act & Assert - Paid status - paid_order = ProviderOrder(**base_params, payment_status="paid") - assert paid_order.payment_status == "paid" + paid_order = ProviderOrder(**base_params, payment_status=PaymentStatus.PAID) + assert paid_order.payment_status == PaymentStatus.PAID # Act & Assert - Failed status failed_params = {**base_params, "pay_failed_at": datetime.now(UTC)} - failed_order = ProviderOrder(**failed_params, payment_status="failed") - assert failed_order.payment_status == "failed" + failed_order = ProviderOrder(**failed_params, payment_status=PaymentStatus.FAILED) + assert failed_order.payment_status == PaymentStatus.FAILED assert failed_order.pay_failed_at is not None # Act & Assert - Refunded status refunded_params = {**base_params, "refunded_at": datetime.now(UTC)} - refunded_order = ProviderOrder(**refunded_params, payment_status="refunded") - assert refunded_order.payment_status == "refunded" + refunded_order = ProviderOrder(**refunded_params, payment_status=PaymentStatus.REFUNDED) + assert refunded_order.payment_status == PaymentStatus.REFUNDED assert refunded_order.refunded_at is not None @@ -650,13 +651,13 @@ class TestLoadBalancingModelConfig: name="Secondary API Key", encrypted_config='{"api_key": "encrypted_value"}', credential_id=credential_id, - credential_source_type="custom", + credential_source_type=CredentialSourceType.CUSTOM_MODEL, ) # Assert assert config.encrypted_config == '{"api_key": "encrypted_value"}' assert config.credential_id == credential_id - assert config.credential_source_type == "custom" + assert config.credential_source_type == CredentialSourceType.CUSTOM_MODEL def test_load_balancing_config_disabled(self): """Test disabled load balancing config.""" diff --git a/api/tests/unit_tests/services/test_summary_index_service.py b/api/tests/unit_tests/services/test_summary_index_service.py index c7e1fed21f..be64e431ba 100644 --- a/api/tests/unit_tests/services/test_summary_index_service.py +++ b/api/tests/unit_tests/services/test_summary_index_service.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock import pytest import services.summary_index_service as summary_module +from models.enums import SegmentStatus, SummaryStatus from services.summary_index_service import SummaryIndexService @@ -42,7 +43,7 @@ def _segment(*, has_document: bool = True) -> MagicMock: segment.dataset_id = "dataset-1" segment.content = "hello world" segment.enabled = True - segment.status = "completed" + segment.status = SegmentStatus.COMPLETED segment.position = 1 if has_document: doc = MagicMock(name="document") @@ -64,7 +65,7 @@ def _summary_record(*, summary_content: str = "summary", node_id: str | None = N record.summary_index_node_id = node_id record.summary_index_node_hash = None record.tokens = None - record.status = "generating" + record.status = SummaryStatus.GENERATING record.error = None record.enabled = True record.created_at = datetime(2024, 1, 1, tzinfo=UTC) @@ -133,10 +134,10 @@ def test_create_summary_record_updates_existing_and_reenables(monkeypatch: pytes segment = _segment() dataset = _dataset() - result = SummaryIndexService.create_summary_record(segment, dataset, "new", status="generating") + result = SummaryIndexService.create_summary_record(segment, dataset, "new", status=SummaryStatus.GENERATING) assert result is existing assert existing.summary_content == "new" - assert existing.status == "generating" + assert existing.status == SummaryStatus.GENERATING assert existing.enabled is True assert existing.disabled_at is None assert existing.disabled_by is None @@ -155,7 +156,7 @@ def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> N create_session_mock = MagicMock(return_value=_SessionContext(session)) monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) - record = SummaryIndexService.create_summary_record(_segment(), _dataset(), "new", status="generating") + record = SummaryIndexService.create_summary_record(_segment(), _dataset(), "new", status=SummaryStatus.GENERATING) assert record.dataset_id == "dataset-1" assert record.chunk_id == "seg-1" assert record.summary_content == "new" @@ -204,7 +205,7 @@ def test_vectorize_summary_retries_connection_errors_then_succeeds(monkeypatch: assert vector_instance.add_texts.call_count == 2 summary_module.time.sleep.assert_called_once() # type: ignore[attr-defined] session.flush.assert_called_once() - assert summary.status == "completed" + assert summary.status == SummaryStatus.COMPLETED assert summary.summary_index_node_id == "uuid-1" assert summary.summary_index_node_hash == "hash-1" assert summary.tokens == 5 @@ -245,7 +246,7 @@ def test_vectorize_summary_without_session_creates_record_when_missing(monkeypat create_session_mock.assert_called() session.add.assert_called() session.commit.assert_called_once() - assert summary.status == "completed" + assert summary.status == SummaryStatus.COMPLETED assert summary.summary_index_node_id == "old-node" # reused @@ -275,7 +276,7 @@ def test_vectorize_summary_final_failure_updates_error_status(monkeypatch: pytes with pytest.raises(RuntimeError, match="boom"): SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) - assert summary.status == "error" + assert summary.status == SummaryStatus.ERROR assert "Vectorization failed" in (summary.error or "") error_session.commit.assert_called_once() @@ -310,7 +311,7 @@ def test_batch_create_summary_records_creates_and_updates(monkeypatch: pytest.Mo SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), ) - SummaryIndexService.batch_create_summary_records([s1, s2], dataset, status="not_started") + SummaryIndexService.batch_create_summary_records([s1, s2], dataset, status=SummaryStatus.NOT_STARTED) session.commit.assert_called_once() assert existing.enabled is True @@ -332,7 +333,7 @@ def test_update_summary_record_error_updates_when_exists(monkeypatch: pytest.Mon ) SummaryIndexService.update_summary_record_error(segment, dataset, "err") - assert record.status == "error" + assert record.status == SummaryStatus.ERROR assert record.error == "err" session.commit.assert_called_once() @@ -387,7 +388,7 @@ def test_generate_and_vectorize_summary_vectorize_failure_sets_error(monkeypatch with pytest.raises(RuntimeError, match="boom"): SummaryIndexService.generate_and_vectorize_summary(segment, dataset, {"enable": True}) - assert record.status == "error" + assert record.status == SummaryStatus.ERROR # Outer exception handler overwrites the error with the raw exception message. assert record.error == "boom" @@ -614,7 +615,7 @@ def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(mo monkeypatch.setattr(summary_module, "logger", logger_mock) result = SummaryIndexService.generate_and_vectorize_summary(segment, dataset, {"enable": True}) - assert result.status in {"generating", "completed"} + assert result.status in {SummaryStatus.GENERATING, SummaryStatus.COMPLETED} logger_mock.info.assert_called() @@ -787,7 +788,7 @@ def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pyt segment = _segment() segment.id = summary.chunk_id segment.enabled = True - segment.status = "completed" + segment.status = SegmentStatus.COMPLETED session = MagicMock() summary_query = MagicMock() @@ -850,11 +851,11 @@ def test_enable_summaries_for_segments_skips_segment_or_content_and_handles_vect bad_segment = _segment() bad_segment.enabled = False - bad_segment.status = "completed" + bad_segment.status = SegmentStatus.COMPLETED good_segment = _segment() good_segment.enabled = True - good_segment.status = "completed" + good_segment.status = SegmentStatus.COMPLETED session = MagicMock() summary_query = MagicMock() @@ -1084,7 +1085,7 @@ def test_update_summary_for_segment_existing_vectorize_failure_returns_error_rec out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new") assert out is record - assert out.status == "error" + assert out.status == SummaryStatus.ERROR assert "Vectorization failed" in (out.error or "") @@ -1133,7 +1134,7 @@ def test_update_summary_for_segment_outer_exception_sets_error_and_reraises(monk with pytest.raises(RuntimeError, match="flush boom"): SummaryIndexService.update_summary_for_segment(segment, dataset, "new") - assert record.status == "error" + assert record.status == SummaryStatus.ERROR assert record.error == "flush boom" session.commit.assert_called() @@ -1222,7 +1223,7 @@ def test_get_documents_summary_index_status_no_pending_sets_none(monkeypatch: py monkeypatch.setattr( SummaryIndexService, "get_segments_summaries", - MagicMock(return_value={"seg-1": SimpleNamespace(status="completed")}), + MagicMock(return_value={"seg-1": SimpleNamespace(status=SummaryStatus.COMPLETED)}), ) result = SummaryIndexService.get_documents_summary_index_status(["doc-1"], "dataset-1", "tenant-1") assert result["doc-1"] is None @@ -1254,7 +1255,7 @@ def test_update_summary_for_segment_creates_new_and_vectorize_fails_returns_erro monkeypatch.setattr(SummaryIndexService, "vectorize_summary", vectorize_mock) out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new") - assert out.status == "error" + assert out.status == SummaryStatus.ERROR assert "Vectorization failed" in (out.error or "") @@ -1276,7 +1277,7 @@ def test_get_document_summary_index_status_and_documents_status(monkeypatch: pyt monkeypatch.setattr( SummaryIndexService, "get_segments_summaries", - MagicMock(return_value={"seg-1": SimpleNamespace(status="generating")}), + MagicMock(return_value={"seg-1": SimpleNamespace(status=SummaryStatus.GENERATING)}), ) assert SummaryIndexService.get_document_summary_index_status("doc-1", "dataset-1", "tenant-1") == "SUMMARIZING" @@ -1294,7 +1295,7 @@ def test_get_document_summary_index_status_and_documents_status(monkeypatch: pyt monkeypatch.setattr( SummaryIndexService, "get_segments_summaries", - MagicMock(return_value={"seg-1": SimpleNamespace(status="not_started")}), + MagicMock(return_value={"seg-1": SimpleNamespace(status=SummaryStatus.NOT_STARTED)}), ) result = SummaryIndexService.get_documents_summary_index_status(["doc-1", "doc-2"], "dataset-1", "tenant-1") assert result["doc-1"] == "SUMMARIZING" @@ -1311,7 +1312,7 @@ def test_get_document_summary_status_detail_counts_and_previews(monkeypatch: pyt summary1 = _summary_record(summary_content="x" * 150, node_id="n1") summary1.chunk_id = "seg-1" - summary1.status = "completed" + summary1.status = SummaryStatus.COMPLETED summary1.error = None summary1.created_at = datetime(2024, 1, 1, tzinfo=UTC) summary1.updated_at = datetime(2024, 1, 2, tzinfo=UTC) diff --git a/api/tests/unit_tests/tasks/test_clean_dataset_task.py b/api/tests/unit_tests/tasks/test_clean_dataset_task.py index df33f20c9b..74ba7f9c34 100644 --- a/api/tests/unit_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/unit_tests/tasks/test_clean_dataset_task.py @@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch import pytest +from models.enums import DataSourceType from tasks.clean_dataset_task import clean_dataset_task # ============================================================================ @@ -116,7 +117,7 @@ def mock_document(): doc.id = str(uuid.uuid4()) doc.tenant_id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) - doc.data_source_type = "upload_file" + doc.data_source_type = DataSourceType.UPLOAD_FILE doc.data_source_info = '{"upload_file_id": "test-file-id"}' doc.data_source_info_dict = {"upload_file_id": "test-file-id"} return doc diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index 67e0a8efaf..8a721124d6 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -19,6 +19,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client from models.dataset import Dataset, Document +from models.enums import IndexingStatus from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy from tasks.document_indexing_task import ( _document_indexing, @@ -424,7 +425,7 @@ class TestBatchProcessing: # Assert - All documents should be set to 'parsing' status for doc in mock_documents: - assert doc.indexing_status == "parsing" + assert doc.indexing_status == IndexingStatus.PARSING assert doc.processing_started_at is not None # IndexingRunner should be called with all documents @@ -573,7 +574,7 @@ class TestProgressTracking: # Assert - Status should be 'parsing' for doc in mock_documents: - assert doc.indexing_status == "parsing" + assert doc.indexing_status == IndexingStatus.PARSING assert doc.processing_started_at is not None # Verify commit was called to persist status @@ -1158,7 +1159,7 @@ class TestAdvancedScenarios: # Assert # All documents should be set to parsing (no limit errors) for doc in mock_documents: - assert doc.indexing_status == "parsing" + assert doc.indexing_status == IndexingStatus.PARSING # IndexingRunner should be called with all documents mock_indexing_runner.run.assert_called_once() @@ -1377,7 +1378,7 @@ class TestPerformanceScenarios: # Assert for doc in mock_documents: - assert doc.indexing_status == "parsing" + assert doc.indexing_status == IndexingStatus.PARSING mock_indexing_runner.run.assert_called_once() call_args = mock_indexing_runner.run.call_args[0][0] diff --git a/api/uv.lock b/api/uv.lock index 877c6141ab..ddb70f6b54 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -5154,11 +5154,11 @@ wheels = [ [[package]] name = "pyasn1" -version = "0.6.2" +version = "0.6.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fe/b6/6e630dff89739fcd427e3f72b3d905ce0acb85a45d4ec3e2678718a3487f/pyasn1-0.6.2.tar.gz", hash = "sha256:9b59a2b25ba7e4f8197db7686c09fb33e658b98339fadb826e9512629017833b", size = 146586, upload-time = "2026-01-16T18:04:18.534Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5c/5f/6583902b6f79b399c9c40674ac384fd9cd77805f9e6205075f828ef11fb2/pyasn1-0.6.3.tar.gz", hash = "sha256:697a8ecd6d98891189184ca1fa05d1bb00e2f84b5977c481452050549c8a72cf", size = 148685, upload-time = "2026-03-17T01:06:53.382Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/44/b5/a96872e5184f354da9c84ae119971a0a4c221fe9b27a4d94bd43f2596727/pyasn1-0.6.2-py3-none-any.whl", hash = "sha256:1eb26d860996a18e9b6ed05e7aae0e9fc21619fcee6af91cca9bad4fbea224bf", size = 83371, upload-time = "2026-01-16T18:04:17.174Z" }, + { url = "https://files.pythonhosted.org/packages/5d/a0/7d793dce3fa811fe047d6ae2431c672364b462850c6235ae306c0efd025f/pyasn1-0.6.3-py3-none-any.whl", hash = "sha256:a80184d120f0864a52a073acc6fc642847d0be408e7c7252f31390c0f4eadcde", size = 83997, upload-time = "2026-03-17T01:06:52.036Z" }, ] [[package]] diff --git a/web/__tests__/apps/app-card-operations-flow.test.tsx b/web/__tests__/apps/app-card-operations-flow.test.tsx index c3e8410955..c5766878a1 100644 --- a/web/__tests__/apps/app-card-operations-flow.test.tsx +++ b/web/__tests__/apps/app-card-operations-flow.test.tsx @@ -29,7 +29,7 @@ const mockOnPlanInfoChanged = vi.fn() const mockDeleteAppMutation = vi.fn().mockResolvedValue(undefined) let mockDeleteMutationPending = false -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush, }), @@ -57,7 +57,7 @@ vi.mock('@headlessui/react', async () => { } }) -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: (loader: () => Promise<{ default: React.ComponentType }>) => { let Component: React.ComponentType> | null = null loader().then((mod) => { diff --git a/web/__tests__/apps/app-list-browsing-flow.test.tsx b/web/__tests__/apps/app-list-browsing-flow.test.tsx index 079f667dbc..1be7e56086 100644 --- a/web/__tests__/apps/app-list-browsing-flow.test.tsx +++ b/web/__tests__/apps/app-list-browsing-flow.test.tsx @@ -38,7 +38,7 @@ let mockShowTagManagementModal = false const mockRouterPush = vi.fn() const mockRouterReplace = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush, replace: mockRouterReplace, @@ -46,7 +46,7 @@ vi.mock('next/navigation', () => ({ useSearchParams: () => new URLSearchParams(), })) -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: (_loader: () => Promise<{ default: React.ComponentType }>) => { const LazyComponent = (props: Record) => { return
diff --git a/web/__tests__/apps/create-app-flow.test.tsx b/web/__tests__/apps/create-app-flow.test.tsx index 4ac9824ddd..bc1f7a3a06 100644 --- a/web/__tests__/apps/create-app-flow.test.tsx +++ b/web/__tests__/apps/create-app-flow.test.tsx @@ -35,7 +35,7 @@ const mockRouterPush = vi.fn() const mockRouterReplace = vi.fn() const mockOnPlanInfoChanged = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush, replace: mockRouterReplace, @@ -117,7 +117,7 @@ vi.mock('ahooks', async () => { }) // Mock dynamically loaded modals with test stubs -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: (loader: () => Promise<{ default: React.ComponentType }>) => { let Component: React.ComponentType> | null = null loader().then((mod) => { diff --git a/web/__tests__/billing/billing-integration.test.tsx b/web/__tests__/billing/billing-integration.test.tsx index 4891760df4..64d358cbe6 100644 --- a/web/__tests__/billing/billing-integration.test.tsx +++ b/web/__tests__/billing/billing-integration.test.tsx @@ -64,7 +64,7 @@ vi.mock('@/service/use-education', () => ({ // ─── Navigation mocks ─────────────────────────────────────────────────────── const mockRouterPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush }), usePathname: () => '/billing', useSearchParams: () => new URLSearchParams(), diff --git a/web/__tests__/billing/cloud-plan-payment-flow.test.tsx b/web/__tests__/billing/cloud-plan-payment-flow.test.tsx index e01d9250fd..84653cd68c 100644 --- a/web/__tests__/billing/cloud-plan-payment-flow.test.tsx +++ b/web/__tests__/billing/cloud-plan-payment-flow.test.tsx @@ -11,6 +11,7 @@ import type { BasicPlan } from '@/app/components/billing/type' import { cleanup, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' +import { toast, ToastHost } from '@/app/components/base/ui/toast' import { ALL_PLANS } from '@/app/components/billing/config' import { PlanRange } from '@/app/components/billing/pricing/plan-switcher/plan-range-switcher' import CloudPlanItem from '@/app/components/billing/pricing/plans/cloud-plan-item' @@ -21,7 +22,6 @@ let mockAppCtx: Record = {} const mockFetchSubscriptionUrls = vi.fn() const mockInvoices = vi.fn() const mockOpenAsyncWindow = vi.fn() -const mockToastNotify = vi.fn() // ─── Context mocks ─────────────────────────────────────────────────────────── vi.mock('@/context/app-context', () => ({ @@ -49,12 +49,8 @@ vi.mock('@/hooks/use-async-window-open', () => ({ useAsyncWindowOpen: () => mockOpenAsyncWindow, })) -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: (args: unknown) => mockToastNotify(args) }, -})) - // ─── Navigation mocks ─────────────────────────────────────────────────────── -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), usePathname: () => '/billing', useSearchParams: () => new URLSearchParams(), @@ -82,12 +78,15 @@ const renderCloudPlanItem = ({ canPay = true, }: RenderCloudPlanItemOptions = {}) => { return render( - , + <> + + + , ) } @@ -96,6 +95,7 @@ describe('Cloud Plan Payment Flow', () => { beforeEach(() => { vi.clearAllMocks() cleanup() + toast.close() setupAppContext() mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://pay.example.com/checkout' }) mockInvoices.mockResolvedValue({ url: 'https://billing.example.com/invoices' }) @@ -283,11 +283,7 @@ describe('Cloud Plan Payment Flow', () => { await user.click(button) await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ - type: 'error', - }), - ) + expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument() }) // Should not proceed with payment expect(mockFetchSubscriptionUrls).not.toHaveBeenCalled() diff --git a/web/__tests__/billing/education-verification-flow.test.tsx b/web/__tests__/billing/education-verification-flow.test.tsx index 8c35cd9a8c..707f1d690a 100644 --- a/web/__tests__/billing/education-verification-flow.test.tsx +++ b/web/__tests__/billing/education-verification-flow.test.tsx @@ -63,7 +63,7 @@ vi.mock('@/service/use-billing', () => ({ })) // ─── Navigation mocks ─────────────────────────────────────────────────────── -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush }), usePathname: () => '/billing', useSearchParams: () => new URLSearchParams(), diff --git a/web/__tests__/billing/partner-stack-flow.test.tsx b/web/__tests__/billing/partner-stack-flow.test.tsx index 4f265478cd..fe642ac70b 100644 --- a/web/__tests__/billing/partner-stack-flow.test.tsx +++ b/web/__tests__/billing/partner-stack-flow.test.tsx @@ -18,7 +18,7 @@ let mockSearchParams = new URLSearchParams() const mockMutateAsync = vi.fn() // ─── Module mocks ──────────────────────────────────────────────────────────── -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSearchParams: () => mockSearchParams, useRouter: () => ({ push: vi.fn() }), usePathname: () => '/', diff --git a/web/__tests__/billing/pricing-modal-flow.test.tsx b/web/__tests__/billing/pricing-modal-flow.test.tsx index 6b8fb57f83..2ec7298618 100644 --- a/web/__tests__/billing/pricing-modal-flow.test.tsx +++ b/web/__tests__/billing/pricing-modal-flow.test.tsx @@ -51,7 +51,7 @@ vi.mock('@/hooks/use-async-window-open', () => ({ })) // ─── Navigation mocks ─────────────────────────────────────────────────────── -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), usePathname: () => '/billing', useSearchParams: () => new URLSearchParams(), @@ -295,24 +295,7 @@ describe('Pricing Modal Flow', () => { }) }) - // ─── 6. Close Handling ─────────────────────────────────────────────────── - describe('Close handling', () => { - it('should call onCancel when pressing ESC key', () => { - render() - - // ahooks useKeyPress listens on document for keydown events - document.dispatchEvent(new KeyboardEvent('keydown', { - key: 'Escape', - code: 'Escape', - keyCode: 27, - bubbles: true, - })) - - expect(onCancel).toHaveBeenCalledTimes(1) - }) - }) - - // ─── 7. Pricing URL ───────────────────────────────────────────────────── + // ─── 6. Pricing URL ───────────────────────────────────────────────────── describe('Pricing page URL', () => { it('should render pricing link with correct URL', () => { render() diff --git a/web/__tests__/billing/self-hosted-plan-flow.test.tsx b/web/__tests__/billing/self-hosted-plan-flow.test.tsx index 810d36da8a..0802b760e1 100644 --- a/web/__tests__/billing/self-hosted-plan-flow.test.tsx +++ b/web/__tests__/billing/self-hosted-plan-flow.test.tsx @@ -10,12 +10,12 @@ import { cleanup, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' +import { toast, ToastHost } from '@/app/components/base/ui/toast' import { contactSalesUrl, getStartedWithCommunityUrl, getWithPremiumUrl } from '@/app/components/billing/config' import SelfHostedPlanItem from '@/app/components/billing/pricing/plans/self-hosted-plan-item' import { SelfHostedPlan } from '@/app/components/billing/type' let mockAppCtx: Record = {} -const mockToastNotify = vi.fn() const originalLocation = window.location let assignedHref = '' @@ -40,10 +40,6 @@ vi.mock('@/app/components/base/icons/src/public/billing', () => ({ AwsMarketplaceDark: () => , })) -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: (args: unknown) => mockToastNotify(args) }, -})) - vi.mock('@/app/components/billing/pricing/plans/self-hosted-plan-item/list', () => ({ default: ({ plan }: { plan: string }) => (
Features
@@ -57,10 +53,20 @@ const setupAppContext = (overrides: Record = {}) => { } } +const renderSelfHostedPlanItem = (plan: SelfHostedPlan) => { + return render( + <> + + + , + ) +} + describe('Self-Hosted Plan Flow', () => { beforeEach(() => { vi.clearAllMocks() cleanup() + toast.close() setupAppContext() // Mock window.location with minimal getter/setter (Location props are non-enumerable) @@ -85,14 +91,14 @@ describe('Self-Hosted Plan Flow', () => { // ─── 1. Plan Rendering ────────────────────────────────────────────────── describe('Plan rendering', () => { it('should render community plan with name and description', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.community) expect(screen.getByText(/plans\.community\.name/i)).toBeInTheDocument() expect(screen.getByText(/plans\.community\.description/i)).toBeInTheDocument() }) it('should render premium plan with cloud provider icons', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.premium) expect(screen.getByText(/plans\.premium\.name/i)).toBeInTheDocument() expect(screen.getByTestId('icon-azure')).toBeInTheDocument() @@ -100,39 +106,39 @@ describe('Self-Hosted Plan Flow', () => { }) it('should render enterprise plan without cloud provider icons', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.enterprise) expect(screen.getByText(/plans\.enterprise\.name/i)).toBeInTheDocument() expect(screen.queryByTestId('icon-azure')).not.toBeInTheDocument() }) it('should not show price tip for community (free) plan', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.community) expect(screen.queryByText(/plans\.community\.priceTip/i)).not.toBeInTheDocument() }) it('should show price tip for premium plan', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.premium) expect(screen.getByText(/plans\.premium\.priceTip/i)).toBeInTheDocument() }) it('should render features list for each plan', () => { - const { unmount: unmount1 } = render() + const { unmount: unmount1 } = renderSelfHostedPlanItem(SelfHostedPlan.community) expect(screen.getByTestId('self-hosted-list-community')).toBeInTheDocument() unmount1() - const { unmount: unmount2 } = render() + const { unmount: unmount2 } = renderSelfHostedPlanItem(SelfHostedPlan.premium) expect(screen.getByTestId('self-hosted-list-premium')).toBeInTheDocument() unmount2() - render() + renderSelfHostedPlanItem(SelfHostedPlan.enterprise) expect(screen.getByTestId('self-hosted-list-enterprise')).toBeInTheDocument() }) it('should show AWS marketplace icon for premium plan button', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.premium) expect(screen.getByTestId('icon-aws-light')).toBeInTheDocument() }) @@ -142,7 +148,7 @@ describe('Self-Hosted Plan Flow', () => { describe('Navigation flow', () => { it('should redirect to GitHub when clicking community plan button', async () => { const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.community) const button = screen.getByRole('button') await user.click(button) @@ -152,7 +158,7 @@ describe('Self-Hosted Plan Flow', () => { it('should redirect to AWS Marketplace when clicking premium plan button', async () => { const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.premium) const button = screen.getByRole('button') await user.click(button) @@ -162,7 +168,7 @@ describe('Self-Hosted Plan Flow', () => { it('should redirect to Typeform when clicking enterprise plan button', async () => { const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.enterprise) const button = screen.getByRole('button') await user.click(button) @@ -176,15 +182,13 @@ describe('Self-Hosted Plan Flow', () => { it('should show error toast when non-manager clicks community button', async () => { setupAppContext({ isCurrentWorkspaceManager: false }) const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.community) const button = screen.getByRole('button') await user.click(button) await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument() }) // Should NOT redirect expect(assignedHref).toBe('') @@ -193,15 +197,13 @@ describe('Self-Hosted Plan Flow', () => { it('should show error toast when non-manager clicks premium button', async () => { setupAppContext({ isCurrentWorkspaceManager: false }) const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.premium) const button = screen.getByRole('button') await user.click(button) await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument() }) expect(assignedHref).toBe('') }) @@ -209,15 +211,13 @@ describe('Self-Hosted Plan Flow', () => { it('should show error toast when non-manager clicks enterprise button', async () => { setupAppContext({ isCurrentWorkspaceManager: false }) const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.enterprise) const button = screen.getByRole('button') await user.click(button) await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument() }) expect(assignedHref).toBe('') }) diff --git a/web/__tests__/datasets/document-management.test.tsx b/web/__tests__/datasets/document-management.test.tsx index 8aedd4fc63..f9d80520ed 100644 --- a/web/__tests__/datasets/document-management.test.tsx +++ b/web/__tests__/datasets/document-management.test.tsx @@ -13,7 +13,7 @@ import { DataSourceType } from '@/models/datasets' import { renderHookWithNuqs } from '@/test/nuqs-testing' const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSearchParams: () => new URLSearchParams(''), useRouter: () => ({ push: mockPush }), usePathname: () => '/datasets/ds-1/documents', diff --git a/web/__tests__/document-detail-navigation-fix.test.tsx b/web/__tests__/document-detail-navigation-fix.test.tsx index 6b348cd15b..5cb115830e 100644 --- a/web/__tests__/document-detail-navigation-fix.test.tsx +++ b/web/__tests__/document-detail-navigation-fix.test.tsx @@ -7,12 +7,12 @@ import type { Mock } from 'vitest' */ import { fireEvent, render, screen } from '@testing-library/react' -import { useRouter } from 'next/navigation' +import { useRouter } from '@/next/navigation' import { useDocumentDetail, useDocumentMetadata } from '@/service/knowledge/use-document' // Mock Next.js router const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: vi.fn(() => ({ push: mockPush, })), diff --git a/web/__tests__/embedded-user-id-auth.test.tsx b/web/__tests__/embedded-user-id-auth.test.tsx index 9231ac6199..cacd6331f8 100644 --- a/web/__tests__/embedded-user-id-auth.test.tsx +++ b/web/__tests__/embedded-user-id-auth.test.tsx @@ -8,7 +8,7 @@ const replaceMock = vi.fn() const backMock = vi.fn() const useSearchParamsMock = vi.fn(() => new URLSearchParams()) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: vi.fn(() => '/chatbot/test-app'), useRouter: vi.fn(() => ({ replace: replaceMock, diff --git a/web/__tests__/embedded-user-id-store.test.tsx b/web/__tests__/embedded-user-id-store.test.tsx index 901218e76b..04597ccfeb 100644 --- a/web/__tests__/embedded-user-id-store.test.tsx +++ b/web/__tests__/embedded-user-id-store.test.tsx @@ -4,7 +4,7 @@ import WebAppStoreProvider, { useWebAppStore } from '@/context/web-app-context' import { AccessMode } from '@/models/access-control' -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: vi.fn(() => '/chatbot/sample-app'), useSearchParams: vi.fn(() => { const params = new URLSearchParams() diff --git a/web/__tests__/explore/sidebar-lifecycle-flow.test.tsx b/web/__tests__/explore/sidebar-lifecycle-flow.test.tsx index e2c18bcc4f..77f493ab18 100644 --- a/web/__tests__/explore/sidebar-lifecycle-flow.test.tsx +++ b/web/__tests__/explore/sidebar-lifecycle-flow.test.tsx @@ -19,7 +19,7 @@ const mockUninstall = vi.fn() const mockUpdatePinStatus = vi.fn() let mockInstalledApps: InstalledApp[] = [] -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSelectedLayoutSegments: () => mockSegments, useRouter: () => ({ push: mockPush, diff --git a/web/__tests__/plugins/plugin-card-rendering.test.tsx b/web/__tests__/plugins/plugin-card-rendering.test.tsx index 7abcb01b49..5bd7f0c8bf 100644 --- a/web/__tests__/plugins/plugin-card-rendering.test.tsx +++ b/web/__tests__/plugins/plugin-card-rendering.test.tsx @@ -8,6 +8,8 @@ import { cleanup, render, screen } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' +let mockTheme = 'light' + vi.mock('#i18n', () => ({ useTranslation: () => ({ t: (key: string) => key, @@ -19,16 +21,16 @@ vi.mock('@/context/i18n', () => ({ })) vi.mock('@/hooks/use-theme', () => ({ - default: () => ({ theme: 'light' }), + default: () => ({ theme: mockTheme }), })) vi.mock('@/i18n-config', () => ({ renderI18nObject: (obj: Record, locale: string) => obj[locale] || obj.en_US || '', })) -vi.mock('@/types/app', () => ({ - Theme: { dark: 'dark', light: 'light' }, -})) +vi.mock('@/types/app', async () => { + return vi.importActual('@/types/app') +}) vi.mock('@/utils/classnames', () => ({ cn: (...args: unknown[]) => args.filter(a => typeof a === 'string' && a).join(' '), @@ -100,6 +102,7 @@ type CardPayload = Parameters[0]['payload'] describe('Plugin Card Rendering Integration', () => { beforeEach(() => { cleanup() + mockTheme = 'light' }) const makePayload = (overrides = {}) => ({ @@ -194,9 +197,7 @@ describe('Plugin Card Rendering Integration', () => { }) it('uses dark icon when theme is dark and icon_dark is provided', () => { - vi.doMock('@/hooks/use-theme', () => ({ - default: () => ({ theme: 'dark' }), - })) + mockTheme = 'dark' const payload = makePayload({ icon: 'https://example.com/icon-light.png', @@ -204,7 +205,7 @@ describe('Plugin Card Rendering Integration', () => { }) render() - expect(screen.getByTestId('card-icon')).toBeInTheDocument() + expect(screen.getByTestId('card-icon')).toHaveTextContent('https://example.com/icon-dark.png') }) it('shows loading placeholder when isLoading is true', () => { diff --git a/web/__tests__/share/text-generation-index-flow.test.tsx b/web/__tests__/share/text-generation-index-flow.test.tsx index 3292474bec..2fec054a47 100644 --- a/web/__tests__/share/text-generation-index-flow.test.tsx +++ b/web/__tests__/share/text-generation-index-flow.test.tsx @@ -5,7 +5,7 @@ import TextGeneration from '@/app/components/share/text-generation' const useSearchParamsMock = vi.fn(() => new URLSearchParams()) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSearchParams: () => useSearchParamsMock(), })) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx index fd0bf2c8bd..0c87fd1a4d 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx @@ -13,8 +13,6 @@ import { RiTerminalWindowLine, } from '@remixicon/react' import { useUnmount } from 'ahooks' -import dynamic from 'next/dynamic' -import { usePathname, useRouter } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -26,6 +24,8 @@ import { useStore as useTagStore } from '@/app/components/base/tag-management/st import { useAppContext } from '@/context/app-context' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useDocumentTitle from '@/hooks/use-document-title' +import dynamic from '@/next/dynamic' +import { usePathname, useRouter } from '@/next/navigation' import { fetchAppDetailDirect } from '@/service/apps' import { AppModeEnum } from '@/types/app' import { cn } from '@/utils/classnames' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx index 5e7d98d191..4201d11490 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx @@ -7,7 +7,6 @@ import { RiEqualizer2Line, } from '@remixicon/react' import { useBoolean } from 'ahooks' -import { usePathname } from 'next/navigation' import * as React from 'react' import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -17,6 +16,7 @@ import Loading from '@/app/components/base/loading' import Toast from '@/app/components/base/toast' import Indicator from '@/app/components/header/indicator' import { useAppContext } from '@/context/app-context' +import { usePathname } from '@/next/navigation' import { fetchTracingConfig as doFetchTracingConfig, fetchTracingStatus, updateTracingStatus } from '@/service/apps' import { cn } from '@/utils/classnames' import ConfigButton from './config-button' diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx index 4f3f724e62..730b76ee19 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx @@ -9,7 +9,6 @@ import { RiFocus2Fill, RiFocus2Line, } from '@remixicon/react' -import { usePathname } from 'next/navigation' import * as React from 'react' import { useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -23,6 +22,7 @@ import DatasetDetailContext from '@/context/dataset-detail' import { useEventEmitterContextContext } from '@/context/event-emitter' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useDocumentTitle from '@/hooks/use-document-title' +import { usePathname } from '@/next/navigation' import { useDatasetDetail, useDatasetRelatedApps } from '@/service/knowledge/use-dataset' import { cn } from '@/utils/classnames' diff --git a/web/app/(commonLayout)/datasets/layout.spec.tsx b/web/app/(commonLayout)/datasets/layout.spec.tsx index 5873f344d0..9c01cffba8 100644 --- a/web/app/(commonLayout)/datasets/layout.spec.tsx +++ b/web/app/(commonLayout)/datasets/layout.spec.tsx @@ -6,7 +6,7 @@ import DatasetsLayout from './layout' const mockReplace = vi.fn() const mockUseAppContext = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ replace: mockReplace, }), diff --git a/web/app/(commonLayout)/datasets/layout.tsx b/web/app/(commonLayout)/datasets/layout.tsx index b543c42570..a465f8222b 100644 --- a/web/app/(commonLayout)/datasets/layout.tsx +++ b/web/app/(commonLayout)/datasets/layout.tsx @@ -1,11 +1,11 @@ 'use client' -import { useRouter } from 'next/navigation' import { useEffect } from 'react' import Loading from '@/app/components/base/loading' import { useAppContext } from '@/context/app-context' import { ExternalApiPanelProvider } from '@/context/external-api-panel-context' import { ExternalKnowledgeApiProvider } from '@/context/external-knowledge-api-context' +import { useRouter } from '@/next/navigation' export default function DatasetsLayout({ children }: { children: React.ReactNode }) { const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, currentWorkspace, isLoadingCurrentWorkspace } = useAppContext() diff --git a/web/app/(commonLayout)/education-apply/page.tsx b/web/app/(commonLayout)/education-apply/page.tsx index fce6fe1d5d..44ba5ee8ad 100644 --- a/web/app/(commonLayout)/education-apply/page.tsx +++ b/web/app/(commonLayout)/education-apply/page.tsx @@ -1,15 +1,15 @@ 'use client' -import { - useRouter, - useSearchParams, -} from 'next/navigation' import { useEffect, useMemo, } from 'react' import EducationApplyPage from '@/app/education-apply/education-apply-page' import { useProviderContext } from '@/context/provider-context' +import { + useRouter, + useSearchParams, +} from '@/next/navigation' export default function EducationApply() { const router = useRouter() diff --git a/web/app/(commonLayout)/role-route-guard.spec.tsx b/web/app/(commonLayout)/role-route-guard.spec.tsx index 87bf9be8af..ca1550f0b8 100644 --- a/web/app/(commonLayout)/role-route-guard.spec.tsx +++ b/web/app/(commonLayout)/role-route-guard.spec.tsx @@ -6,7 +6,7 @@ const mockReplace = vi.fn() const mockUseAppContext = vi.fn() let mockPathname = '/apps' -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: () => mockPathname, useRouter: () => ({ replace: mockReplace, diff --git a/web/app/(commonLayout)/role-route-guard.tsx b/web/app/(commonLayout)/role-route-guard.tsx index 1c42be9d15..483dfef095 100644 --- a/web/app/(commonLayout)/role-route-guard.tsx +++ b/web/app/(commonLayout)/role-route-guard.tsx @@ -1,10 +1,10 @@ 'use client' import type { ReactNode } from 'react' -import { usePathname, useRouter } from 'next/navigation' import { useEffect } from 'react' import Loading from '@/app/components/base/loading' import { useAppContext } from '@/context/app-context' +import { usePathname, useRouter } from '@/next/navigation' const datasetOperatorRedirectRoutes = ['/apps', '/app', '/explore', '/tools'] as const diff --git a/web/app/(humanInputLayout)/form/[token]/form.tsx b/web/app/(humanInputLayout)/form/[token]/form.tsx index d027ef8b7d..035da6be8a 100644 --- a/web/app/(humanInputLayout)/form/[token]/form.tsx +++ b/web/app/(humanInputLayout)/form/[token]/form.tsx @@ -9,7 +9,6 @@ import { RiInformation2Fill, } from '@remixicon/react' import { produce } from 'immer' -import { useParams } from 'next/navigation' import * as React from 'react' import { useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -21,6 +20,7 @@ import { getButtonStyle } from '@/app/components/base/chat/chat/answer/human-inp import Loading from '@/app/components/base/loading' import DifyLogo from '@/app/components/base/logo/dify-logo' import useDocumentTitle from '@/hooks/use-document-title' +import { useParams } from '@/next/navigation' import { useGetHumanInputForm, useSubmitHumanInputForm } from '@/service/use-share' import { cn } from '@/utils/classnames' diff --git a/web/app/(shareLayout)/components/authenticated-layout.tsx b/web/app/(shareLayout)/components/authenticated-layout.tsx index c874990448..9f956a8501 100644 --- a/web/app/(shareLayout)/components/authenticated-layout.tsx +++ b/web/app/(shareLayout)/components/authenticated-layout.tsx @@ -1,12 +1,12 @@ 'use client' -import { usePathname, useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect } from 'react' import { useTranslation } from 'react-i18next' import AppUnavailable from '@/app/components/base/app-unavailable' import Loading from '@/app/components/base/loading' import { useWebAppStore } from '@/context/web-app-context' +import { usePathname, useRouter, useSearchParams } from '@/next/navigation' import { useGetUserCanAccessApp } from '@/service/access-control' import { useGetWebAppInfo, useGetWebAppMeta, useGetWebAppParams } from '@/service/use-share' import { webAppLogout } from '@/service/webapp-auth' diff --git a/web/app/(shareLayout)/components/splash.tsx b/web/app/(shareLayout)/components/splash.tsx index a2b847f74f..402005752d 100644 --- a/web/app/(shareLayout)/components/splash.tsx +++ b/web/app/(shareLayout)/components/splash.tsx @@ -1,11 +1,11 @@ 'use client' import type { FC, PropsWithChildren } from 'react' -import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import AppUnavailable from '@/app/components/base/app-unavailable' import Loading from '@/app/components/base/loading' import { useWebAppStore } from '@/context/web-app-context' +import { useRouter, useSearchParams } from '@/next/navigation' import { fetchAccessToken } from '@/service/share' import { setWebAppAccessToken, setWebAppPassport, webAppLoginStatus, webAppLogout } from '@/service/webapp-auth' diff --git a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx index fbf45259e5..a0aa86e35b 100644 --- a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx @@ -1,14 +1,14 @@ 'use client' import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react' -import { useRouter, useSearchParams } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import Countdown from '@/app/components/signin/countdown' - import { useLocale } from '@/context/i18n' + +import { useRouter, useSearchParams } from '@/next/navigation' import { sendWebAppResetPasswordCode, verifyWebAppResetPasswordCode } from '@/service/common' export default function CheckCode() { diff --git a/web/app/(shareLayout)/webapp-reset-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/page.tsx index 9b9a853cdd..3763e0bb2a 100644 --- a/web/app/(shareLayout)/webapp-reset-password/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/page.tsx @@ -1,8 +1,6 @@ 'use client' import { RiArrowLeftLine, RiLockPasswordLine } from '@remixicon/react' import { noop } from 'es-toolkit/function' -import Link from 'next/link' -import { useRouter, useSearchParams } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' @@ -10,9 +8,11 @@ import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown' import { emailRegex } from '@/config' - import { useLocale } from '@/context/i18n' import useDocumentTitle from '@/hooks/use-document-title' + +import Link from '@/next/link' +import { useRouter, useSearchParams } from '@/next/navigation' import { sendResetPasswordCode } from '@/service/common' export default function CheckCode() { diff --git a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx index 9f59e8f9eb..1a97f6440b 100644 --- a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx @@ -1,13 +1,13 @@ 'use client' import { RiCheckboxCircleFill } from '@remixicon/react' import { useCountDown } from 'ahooks' -import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import { validPassword } from '@/config' +import { useRouter, useSearchParams } from '@/next/navigation' import { changeWebAppPasswordWithToken } from '@/service/common' import { cn } from '@/utils/classnames' diff --git a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx index afea9d668b..81b7c1b9a6 100644 --- a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx +++ b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx @@ -1,7 +1,6 @@ 'use client' import type { FormEvent } from 'react' import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react' -import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' @@ -10,6 +9,7 @@ import Toast from '@/app/components/base/toast' import Countdown from '@/app/components/signin/countdown' import { useLocale } from '@/context/i18n' import { useWebAppStore } from '@/context/web-app-context' +import { useRouter, useSearchParams } from '@/next/navigation' import { sendWebAppEMailLoginCode, webAppEmailLoginWithCode } from '@/service/common' import { fetchAccessToken } from '@/service/share' import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth' diff --git a/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx index 0776df036d..391479c870 100644 --- a/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx @@ -1,11 +1,11 @@ 'use client' -import { useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect } from 'react' import AppUnavailable from '@/app/components/base/app-unavailable' import Loading from '@/app/components/base/loading' import Toast from '@/app/components/base/toast' import { useGlobalPublicStore } from '@/context/global-public-context' +import { useRouter, useSearchParams } from '@/next/navigation' import { fetchWebOAuth2SSOUrl, fetchWebOIDCSSOUrl, fetchWebSAMLSSOUrl } from '@/service/share' import { SSOProtocol } from '@/types/feature' diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx index 5aa9d9f141..b350549784 100644 --- a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx @@ -1,5 +1,4 @@ import { noop } from 'es-toolkit/function' -import { useRouter, useSearchParams } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' @@ -8,6 +7,7 @@ import Toast from '@/app/components/base/toast' import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown' import { emailRegex } from '@/config' import { useLocale } from '@/context/i18n' +import { useRouter, useSearchParams } from '@/next/navigation' import { sendWebAppEMailLoginCode } from '@/service/common' export default function MailAndCodeAuth() { diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx index e49559401d..87419438e3 100644 --- a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx @@ -1,7 +1,5 @@ 'use client' import { noop } from 'es-toolkit/function' -import Link from 'next/link' -import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' @@ -10,6 +8,8 @@ import Toast from '@/app/components/base/toast' import { emailRegex } from '@/config' import { useLocale } from '@/context/i18n' import { useWebAppStore } from '@/context/web-app-context' +import Link from '@/next/link' +import { useRouter, useSearchParams } from '@/next/navigation' import { webAppLogin } from '@/service/common' import { fetchAccessToken } from '@/service/share' import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth' diff --git a/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx index d8f3854868..79d67dde5c 100644 --- a/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx @@ -1,11 +1,11 @@ 'use client' import type { FC } from 'react' -import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' import Toast from '@/app/components/base/toast' +import { useRouter, useSearchParams } from '@/next/navigation' import { fetchMembersOAuth2SSOUrl, fetchMembersOIDCSSOUrl, fetchMembersSAMLSSOUrl } from '@/service/share' import { SSOProtocol } from '@/types/feature' diff --git a/web/app/(shareLayout)/webapp-signin/normalForm.tsx b/web/app/(shareLayout)/webapp-signin/normalForm.tsx index b15145346f..7ee08d66ae 100644 --- a/web/app/(shareLayout)/webapp-signin/normalForm.tsx +++ b/web/app/(shareLayout)/webapp-signin/normalForm.tsx @@ -1,12 +1,12 @@ 'use client' import { RiContractLine, RiDoorLockLine, RiErrorWarningFill } from '@remixicon/react' -import Link from 'next/link' import * as React from 'react' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import Loading from '@/app/components/base/loading' import { IS_CE_EDITION } from '@/config' import { useGlobalPublicStore } from '@/context/global-public-context' +import Link from '@/next/link' import { LicenseStatus } from '@/types/feature' import { cn } from '@/utils/classnames' import MailAndCodeAuth from './components/mail-and-code-auth' diff --git a/web/app/(shareLayout)/webapp-signin/page.tsx b/web/app/(shareLayout)/webapp-signin/page.tsx index b3ad1d48a6..a5c2528cc7 100644 --- a/web/app/(shareLayout)/webapp-signin/page.tsx +++ b/web/app/(shareLayout)/webapp-signin/page.tsx @@ -1,6 +1,5 @@ 'use client' import type { FC } from 'react' -import { useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' @@ -8,6 +7,7 @@ import AppUnavailable from '@/app/components/base/app-unavailable' import { useGlobalPublicStore } from '@/context/global-public-context' import { useWebAppStore } from '@/context/web-app-context' import { AccessMode } from '@/models/access-control' +import { useRouter, useSearchParams } from '@/next/navigation' import { webAppLogout } from '@/service/webapp-auth' import ExternalMemberSsoAuth from './components/external-member-sso-auth' import NormalForm from './normalForm' diff --git a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx index 9bd32d2576..3fc677d8d8 100644 --- a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx +++ b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx @@ -160,7 +160,7 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { isShow={isShowDeleteConfirm} onClose={() => setIsShowDeleteConfirm(false)} > -
{t('avatar.deleteTitle', { ns: 'common' })}
+
{t('avatar.deleteTitle', { ns: 'common' })}

{t('avatar.deleteDescription', { ns: 'common' })}

diff --git a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx index 463c27294a..f0dfd4f12f 100644 --- a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx +++ b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx @@ -1,7 +1,6 @@ import type { ResponseError } from '@/service/fetch' import { RiCloseLine } from '@remixicon/react' import { noop } from 'es-toolkit/function' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useState } from 'react' import { Trans, useTranslation } from 'react-i18next' @@ -10,6 +9,7 @@ import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' import { ToastContext } from '@/app/components/base/toast/context' +import { useRouter } from '@/next/navigation' import { checkEmailExisted, resetEmail, @@ -209,14 +209,14 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
{step === STEP.start && ( <> -
{t('account.changeEmail.title', { ns: 'common' })}
+
{t('account.changeEmail.title', { ns: 'common' })}
-
{t('account.changeEmail.authTip', { ns: 'common' })}
-
+
{t('account.changeEmail.authTip', { ns: 'common' })}
+
}} + components={{ email: }} values={{ email }} />
@@ -241,19 +241,19 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { )} {step === STEP.verifyOrigin && ( <> -
{t('account.changeEmail.verifyEmail', { ns: 'common' })}
+
{t('account.changeEmail.verifyEmail', { ns: 'common' })}
-
+
}} + components={{ email: }} values={{ email }} />
-
{t('account.changeEmail.codeLabel', { ns: 'common' })}
+
{t('account.changeEmail.codeLabel', { ns: 'common' })}
{ {t('operation.cancel', { ns: 'common' })}
-
+
{t('account.changeEmail.resendTip', { ns: 'common' })} {time > 0 && ( {t('account.changeEmail.resendCount', { ns: 'common', count: time })} )} {!time && ( - {t('account.changeEmail.resend', { ns: 'common' })} + {t('account.changeEmail.resend', { ns: 'common' })} )}
)} {step === STEP.newEmail && ( <> -
{t('account.changeEmail.newEmail', { ns: 'common' })}
+
{t('account.changeEmail.newEmail', { ns: 'common' })}
-
{t('account.changeEmail.content3', { ns: 'common' })}
+
{t('account.changeEmail.content3', { ns: 'common' })}
-
{t('account.changeEmail.emailLabel', { ns: 'common' })}
+
{t('account.changeEmail.emailLabel', { ns: 'common' })}
{ destructive={newEmailExited || unAvailableEmail} /> {newEmailExited && ( -
{t('account.changeEmail.existingEmail', { ns: 'common' })}
+
{t('account.changeEmail.existingEmail', { ns: 'common' })}
)} {unAvailableEmail && ( -
{t('account.changeEmail.unAvailableEmail', { ns: 'common' })}
+
{t('account.changeEmail.unAvailableEmail', { ns: 'common' })}
)}
@@ -331,19 +331,19 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { )} {step === STEP.verifyNew && ( <> -
{t('account.changeEmail.verifyNew', { ns: 'common' })}
+
{t('account.changeEmail.verifyNew', { ns: 'common' })}
-
+
}} + components={{ email: }} values={{ email: mail }} />
-
{t('account.changeEmail.codeLabel', { ns: 'common' })}
+
{t('account.changeEmail.codeLabel', { ns: 'common' })}
{ {t('operation.cancel', { ns: 'common' })}
-
+
{t('account.changeEmail.resendTip', { ns: 'common' })} {time > 0 && ( {t('account.changeEmail.resendCount', { ns: 'common', count: time })} )} {!time && ( - {t('account.changeEmail.resend', { ns: 'common' })} + {t('account.changeEmail.resend', { ns: 'common' })} )}
diff --git a/web/app/account/(commonLayout)/account-page/index.tsx b/web/app/account/(commonLayout)/account-page/index.tsx index 58331e3a77..9a104619da 100644 --- a/web/app/account/(commonLayout)/account-page/index.tsx +++ b/web/app/account/(commonLayout)/account-page/index.tsx @@ -145,7 +145,7 @@ export default function AccountPage() { imageUrl={icon_url} />
-
{item.name}
+
{item.name}
) } @@ -153,12 +153,12 @@ export default function AccountPage() { return ( <>
-

{t('account.myAccount', { ns: 'common' })}

+

{t('account.myAccount', { ns: 'common' })}

-

+

{userProfile.name} {isEducationAccount && ( @@ -167,16 +167,16 @@ export default function AccountPage() { )}

-

{userProfile.email}

+

{userProfile.email}

{t('account.name', { ns: 'common' })}
-
+
{userProfile.name}
-
+
{t('operation.edit', { ns: 'common' })}
@@ -184,11 +184,11 @@ export default function AccountPage() {
{t('account.email', { ns: 'common' })}
-
+
{userProfile.email}
{systemFeatures.enable_change_email && ( -
setShowUpdateEmail(true)}> +
setShowUpdateEmail(true)}> {t('operation.change', { ns: 'common' })}
)} @@ -198,8 +198,8 @@ export default function AccountPage() { systemFeatures.enable_email_password_login && (
-
{t('account.password', { ns: 'common' })}
-
{t('account.passwordTip', { ns: 'common' })}
+
{t('account.password', { ns: 'common' })}
+
{t('account.passwordTip', { ns: 'common' })}
@@ -226,7 +226,7 @@ export default function AccountPage() { onClose={() => setEditNameModalVisible(false)} className="!w-[420px] !p-6" > -
{t('account.editName', { ns: 'common' })}
+
{t('account.editName', { ns: 'common' })}
{t('account.name', { ns: 'common' })}
-
{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}
+
{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}
{userProfile.is_password_set && ( <>
{t('account.currentPassword', { ns: 'common' })}
@@ -279,7 +279,7 @@ export default function AccountPage() {
)} -
+
{userProfile.is_password_set ? t('account.newPassword', { ns: 'common' }) : t('account.password', { ns: 'common' })}
@@ -298,7 +298,7 @@ export default function AccountPage() {
-
{t('account.confirmPassword', { ns: 'common' })}
+
{t('account.confirmPassword', { ns: 'common' })}
{ diff --git a/web/app/account/oauth/authorize/page.tsx b/web/app/account/oauth/authorize/page.tsx index 835a1e702e..30cfdd25d3 100644 --- a/web/app/account/oauth/authorize/page.tsx +++ b/web/app/account/oauth/authorize/page.tsx @@ -7,16 +7,16 @@ import { RiMailLine, RiTranslate2, } from '@remixicon/react' -import { useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useEffect, useRef } from 'react' import { useTranslation } from 'react-i18next' import { Avatar } from '@/app/components/base/avatar' import Button from '@/app/components/base/button' import Loading from '@/app/components/base/loading' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks' import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect' +import { useRouter, useSearchParams } from '@/next/navigation' import { useIsLogin, useUserProfile } from '@/service/use-common' import { useAuthorizeOAuthApp, useOAuthAppInfo } from '@/service/use-oauth' @@ -91,9 +91,9 @@ export default function OAuthAuthorize() { globalThis.location.href = url.toString() } catch (err: any) { - Toast.notify({ + toast.add({ type: 'error', - message: `${t('error.authorizeFailed', { ns: 'oauth' })}: ${err.message}`, + title: `${t('error.authorizeFailed', { ns: 'oauth' })}: ${err.message}`, }) } } @@ -102,10 +102,10 @@ export default function OAuthAuthorize() { const invalidParams = !client_id || !redirect_uri if ((invalidParams || isError) && !hasNotifiedRef.current) { hasNotifiedRef.current = true - Toast.notify({ + toast.add({ type: 'error', - message: invalidParams ? t('error.invalidParams', { ns: 'oauth' }) : t('error.authAppInfoFetchFailed', { ns: 'oauth' }), - duration: 0, + title: invalidParams ? t('error.invalidParams', { ns: 'oauth' }) : t('error.authAppInfoFetchFailed', { ns: 'oauth' }), + timeout: 0, }) } }, [client_id, redirect_uri, isError]) diff --git a/web/app/activate/activateForm.tsx b/web/app/activate/activateForm.tsx index 421b816652..418d3b8bb1 100644 --- a/web/app/activate/activateForm.tsx +++ b/web/app/activate/activateForm.tsx @@ -1,11 +1,11 @@ 'use client' -import { useRouter, useSearchParams } from 'next/navigation' import { useEffect } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Loading from '@/app/components/base/loading' - import useDocumentTitle from '@/hooks/use-document-title' + +import { useRouter, useSearchParams } from '@/next/navigation' import { useInvitationCheck } from '@/service/use-common' import { cn } from '@/utils/classnames' diff --git a/web/app/components/browser-initializer.spec.ts b/web/app/components/__tests__/browser-initializer.spec.ts similarity index 100% rename from web/app/components/browser-initializer.spec.ts rename to web/app/components/__tests__/browser-initializer.spec.ts diff --git a/web/app/components/app-initializer.tsx b/web/app/components/app-initializer.tsx index bf7aa39580..e08ece6666 100644 --- a/web/app/components/app-initializer.tsx +++ b/web/app/components/app-initializer.tsx @@ -2,13 +2,13 @@ import type { ReactNode } from 'react' import Cookies from 'js-cookie' -import { usePathname, useRouter, useSearchParams } from 'next/navigation' import { parseAsBoolean, useQueryState } from 'nuqs' import { useCallback, useEffect, useState } from 'react' import { EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION, EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, } from '@/app/education-apply/constants' +import { usePathname, useRouter, useSearchParams } from '@/next/navigation' import { sendGAEvent } from '@/utils/gtag' import { fetchSetupStatusWithCache } from '@/utils/setup-status' import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect' diff --git a/web/app/components/app-sidebar/__tests__/index.spec.tsx b/web/app/components/app-sidebar/__tests__/index.spec.tsx index 89db80e0f1..b2e1e92bbb 100644 --- a/web/app/components/app-sidebar/__tests__/index.spec.tsx +++ b/web/app/components/app-sidebar/__tests__/index.spec.tsx @@ -19,7 +19,7 @@ vi.mock('zustand/react/shallow', () => ({ useShallow: (fn: unknown) => fn, })) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: () => mockPathname, })) diff --git a/web/app/components/app-sidebar/__tests__/text-squeeze-fix-verification.spec.tsx b/web/app/components/app-sidebar/__tests__/text-squeeze-fix-verification.spec.tsx index fb19833dd2..a3868a8330 100644 --- a/web/app/components/app-sidebar/__tests__/text-squeeze-fix-verification.spec.tsx +++ b/web/app/components/app-sidebar/__tests__/text-squeeze-fix-verification.spec.tsx @@ -7,7 +7,7 @@ import { render } from '@testing-library/react' import * as React from 'react' // Mock Next.js navigation -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSelectedLayoutSegment: () => 'overview', })) diff --git a/web/app/components/app-sidebar/app-info/__tests__/app-info-modals.spec.tsx b/web/app/components/app-sidebar/app-info/__tests__/app-info-modals.spec.tsx index f8612e8057..2f98089e40 100644 --- a/web/app/components/app-sidebar/app-info/__tests__/app-info-modals.spec.tsx +++ b/web/app/components/app-sidebar/app-info/__tests__/app-info-modals.spec.tsx @@ -5,7 +5,7 @@ import * as React from 'react' import { AppModeEnum } from '@/types/app' import AppInfoModals from '../app-info-modals' -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: (loader: () => Promise<{ default: React.ComponentType }>) => { const LazyComp = React.lazy(loader) return function DynamicWrapper(props: Record) { diff --git a/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts b/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts index 6104e2b641..deea28ce3e 100644 --- a/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts +++ b/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts @@ -23,7 +23,7 @@ let mockAppDetail: Record | undefined = { icon_background: '#FFEAD5', } -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ replace: mockReplace }), })) diff --git a/web/app/components/app-sidebar/app-info/app-info-modals.tsx b/web/app/components/app-sidebar/app-info/app-info-modals.tsx index 4ca7f6adbc..232afb18c7 100644 --- a/web/app/components/app-sidebar/app-info/app-info-modals.tsx +++ b/web/app/components/app-sidebar/app-info/app-info-modals.tsx @@ -3,9 +3,9 @@ import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-moda import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' import type { EnvironmentVariable } from '@/app/components/workflow/types' import type { App, AppSSO } from '@/types/app' -import dynamic from 'next/dynamic' import * as React from 'react' import { useTranslation } from 'react-i18next' +import dynamic from '@/next/dynamic' const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), { ssr: false }) const CreateAppModal = dynamic(() => import('@/app/components/explore/create-app-modal'), { ssr: false }) diff --git a/web/app/components/app-sidebar/app-info/use-app-info-actions.ts b/web/app/components/app-sidebar/app-info/use-app-info-actions.ts index 800f21de44..55ec13e506 100644 --- a/web/app/components/app-sidebar/app-info/use-app-info-actions.ts +++ b/web/app/components/app-sidebar/app-info/use-app-info-actions.ts @@ -1,7 +1,6 @@ import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal' import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' import type { EnvironmentVariable } from '@/app/components/workflow/types' -import { useRouter } from 'next/navigation' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' @@ -9,6 +8,7 @@ import { useStore as useAppStore } from '@/app/components/app/store' import { ToastContext } from '@/app/components/base/toast/context' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useProviderContext } from '@/context/provider-context' +import { useRouter } from '@/next/navigation' import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps' import { useInvalidateAppList } from '@/service/use-apps' import { fetchWorkflowDraft } from '@/service/workflow' diff --git a/web/app/components/app-sidebar/dataset-info/__tests__/dropdown-callbacks.spec.tsx b/web/app/components/app-sidebar/dataset-info/__tests__/dropdown-callbacks.spec.tsx index 512f9490c2..1df6fa79b7 100644 --- a/web/app/components/app-sidebar/dataset-info/__tests__/dropdown-callbacks.spec.tsx +++ b/web/app/components/app-sidebar/dataset-info/__tests__/dropdown-callbacks.spec.tsx @@ -80,7 +80,7 @@ const createDataset = (overrides: Partial = {}): DataSet => ({ ...overrides, }) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ replace: mockReplace }), })) diff --git a/web/app/components/app-sidebar/dataset-info/__tests__/index.spec.tsx b/web/app/components/app-sidebar/dataset-info/__tests__/index.spec.tsx index be27e247d7..a1e275d731 100644 --- a/web/app/components/app-sidebar/dataset-info/__tests__/index.spec.tsx +++ b/web/app/components/app-sidebar/dataset-info/__tests__/index.spec.tsx @@ -90,7 +90,7 @@ const createDataset = (overrides: Partial = {}): DataSet => ({ ...overrides, }) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ replace: mockReplace, }), diff --git a/web/app/components/app-sidebar/dataset-info/dropdown.tsx b/web/app/components/app-sidebar/dataset-info/dropdown.tsx index 96127c4210..528bac831f 100644 --- a/web/app/components/app-sidebar/dataset-info/dropdown.tsx +++ b/web/app/components/app-sidebar/dataset-info/dropdown.tsx @@ -1,11 +1,11 @@ import type { DataSet } from '@/models/datasets' import { RiMoreFill } from '@remixicon/react' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { useSelector as useAppContextWithSelector } from '@/context/app-context' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' +import { useRouter } from '@/next/navigation' import { checkIsUsedInApp, deleteDataset } from '@/service/datasets' import { datasetDetailQueryKeyPrefix, useInvalidDatasetList } from '@/service/knowledge/use-dataset' import { useInvalid } from '@/service/use-base' diff --git a/web/app/components/app-sidebar/index.tsx b/web/app/components/app-sidebar/index.tsx index e24b005d01..13fde97f89 100644 --- a/web/app/components/app-sidebar/index.tsx +++ b/web/app/components/app-sidebar/index.tsx @@ -1,12 +1,12 @@ import type { NavIcon } from './nav-link' import { useHover, useKeyPress } from 'ahooks' -import { usePathname } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect, useState } from 'react' import { useShallow } from 'zustand/react/shallow' import { useStore as useAppStore } from '@/app/components/app/store' import { useEventEmitterContextContext } from '@/context/event-emitter' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' +import { usePathname } from '@/next/navigation' import { cn } from '@/utils/classnames' import Divider from '../base/divider' import { getKeyboardKeyCodeBySystem } from '../workflow/utils' diff --git a/web/app/components/app-sidebar/nav-link/__tests__/index.spec.tsx b/web/app/components/app-sidebar/nav-link/__tests__/index.spec.tsx index 04ca7bd0e4..fe46290002 100644 --- a/web/app/components/app-sidebar/nav-link/__tests__/index.spec.tsx +++ b/web/app/components/app-sidebar/nav-link/__tests__/index.spec.tsx @@ -4,12 +4,12 @@ import * as React from 'react' import NavLink from '..' // Mock Next.js navigation -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSelectedLayoutSegment: () => 'overview', })) // Mock Next.js Link component -vi.mock('next/link', () => ({ +vi.mock('@/next/link', () => ({ default: function MockLink({ children, href, className, title }: { children: React.ReactNode, href: string, className?: string, title?: string }) { return ( diff --git a/web/app/components/app-sidebar/nav-link/index.tsx b/web/app/components/app-sidebar/nav-link/index.tsx index d69ed8590e..cf986a7407 100644 --- a/web/app/components/app-sidebar/nav-link/index.tsx +++ b/web/app/components/app-sidebar/nav-link/index.tsx @@ -1,8 +1,8 @@ 'use client' import type { RemixiconComponentType } from '@remixicon/react' -import Link from 'next/link' -import { useSelectedLayoutSegment } from 'next/navigation' import * as React from 'react' +import Link from '@/next/link' +import { useSelectedLayoutSegment } from '@/next/navigation' import { cn } from '@/utils/classnames' export type NavIcon = React.ComponentType< diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx index 118eaea58e..a969b3d491 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx @@ -94,7 +94,7 @@ const CSVUploader: FC = ({ />
{!file && ( -
+
diff --git a/web/app/components/app/configuration/base/warning-mask/has-not-set-api.spec.tsx b/web/app/components/app/configuration/base/warning-mask/has-not-set-api.spec.tsx index be4377bfd9..abcf5795d0 100644 --- a/web/app/components/app/configuration/base/warning-mask/has-not-set-api.spec.tsx +++ b/web/app/components/app/configuration/base/warning-mask/has-not-set-api.spec.tsx @@ -2,25 +2,19 @@ import { fireEvent, render, screen } from '@testing-library/react' import * as React from 'react' import HasNotSetAPI from './has-not-set-api' -describe('HasNotSetAPI WarningMask', () => { - it('should show default title when trial not finished', () => { - render() +describe('HasNotSetAPI', () => { + it('should render the empty state copy', () => { + render() - expect(screen.getByText('appDebug.notSetAPIKey.title')).toBeInTheDocument() - expect(screen.getByText('appDebug.notSetAPIKey.description')).toBeInTheDocument() + expect(screen.getByText('appDebug.noModelProviderConfigured')).toBeInTheDocument() + expect(screen.getByText('appDebug.noModelProviderConfiguredTip')).toBeInTheDocument() }) - it('should show trail finished title when flag is true', () => { - render() - - expect(screen.getByText('appDebug.notSetAPIKey.trailFinished')).toBeInTheDocument() - }) - - it('should call onSetting when primary button clicked', () => { + it('should call onSetting when manage models button is clicked', () => { const onSetting = vi.fn() - render() + render() - fireEvent.click(screen.getByRole('button', { name: 'appDebug.notSetAPIKey.settingBtn' })) + fireEvent.click(screen.getByRole('button', { name: 'appDebug.manageModels' })) expect(onSetting).toHaveBeenCalledTimes(1) }) }) diff --git a/web/app/components/app/configuration/base/warning-mask/has-not-set-api.tsx b/web/app/components/app/configuration/base/warning-mask/has-not-set-api.tsx index 84323e64f5..2c5fc5ff2f 100644 --- a/web/app/components/app/configuration/base/warning-mask/has-not-set-api.tsx +++ b/web/app/components/app/configuration/base/warning-mask/has-not-set-api.tsx @@ -2,38 +2,38 @@ import type { FC } from 'react' import * as React from 'react' import { useTranslation } from 'react-i18next' -import Button from '@/app/components/base/button' -import WarningMask from '.' export type IHasNotSetAPIProps = { - isTrailFinished: boolean onSetting: () => void } -const icon = ( - - - - -) - const HasNotSetAPI: FC = ({ - isTrailFinished, onSetting, }) => { const { t } = useTranslation() return ( - - {t('notSetAPIKey.settingBtn', { ns: 'appDebug' })} - {icon} - - )} - /> +
+
+
+
+ +
+
+
+
{t('noModelProviderConfigured', { ns: 'appDebug' })}
+
{t('noModelProviderConfiguredTip', { ns: 'appDebug' })}
+
+ +
+
) } export default React.memo(HasNotSetAPI) diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index c33d55873d..39a1699063 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -178,7 +178,7 @@ const Prompt: FC = ({ {!noTitle && (
-
{mode !== AppModeEnum.COMPLETION ? t('chatSubTitle', { ns: 'appDebug' }) : t('completionSubTitle', { ns: 'appDebug' })}
+
{mode !== AppModeEnum.COMPLETION ? t('chatSubTitle', { ns: 'appDebug' }) : t('completionSubTitle', { ns: 'appDebug' })}
{!readonly && ( = ({ )}
-
+
= (
-
{t('codegen.instruction', { ns: 'appDebug' })}
+
{t('codegen.instruction', { ns: 'appDebug' })}
= ( disabled={isLoading} > - {t('codegen.generate', { ns: 'appDebug' })} + {t('codegen.generate', { ns: 'appDebug' })}
diff --git a/web/app/components/app/configuration/dataset-config/context-var/index.spec.tsx b/web/app/components/app/configuration/dataset-config/context-var/index.spec.tsx index 7f71247d56..8c6e626b45 100644 --- a/web/app/components/app/configuration/dataset-config/context-var/index.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/context-var/index.spec.tsx @@ -5,7 +5,7 @@ import * as React from 'react' import ContextVar from './index' // Mock external dependencies only -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), usePathname: () => '/test', })) diff --git a/web/app/components/app/configuration/dataset-config/context-var/var-picker.spec.tsx b/web/app/components/app/configuration/dataset-config/context-var/var-picker.spec.tsx index aa8dae813f..6704fa0afd 100644 --- a/web/app/components/app/configuration/dataset-config/context-var/var-picker.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/context-var/var-picker.spec.tsx @@ -5,7 +5,7 @@ import * as React from 'react' import VarPicker from './var-picker' // Mock external dependencies only -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), usePathname: () => '/test', })) diff --git a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx index 91e5353cc4..8c2fb77c20 100644 --- a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx +++ b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx @@ -2,7 +2,6 @@ import type { FC } from 'react' import type { DataSet } from '@/models/datasets' import { useInfiniteScroll } from 'ahooks' -import Link from 'next/link' import * as React from 'react' import { useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -14,6 +13,7 @@ import Modal from '@/app/components/base/modal' import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import FeatureIcon from '@/app/components/header/account-setting/model-provider-page/model-selector/feature-icon' import { useKnowledge } from '@/hooks/use-knowledge' +import Link from '@/next/link' import { useInfiniteDatasets } from '@/service/knowledge/use-dataset' import { cn } from '@/utils/classnames' diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx index c9c8d080f2..bc534599de 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx @@ -210,7 +210,7 @@ const SettingsModal: FC = ({
-
{t('form.name', { ns: 'datasetSettings' })}
+
{t('form.name', { ns: 'datasetSettings' })}
= ({
-
{t('form.desc', { ns: 'datasetSettings' })}
+
{t('form.desc', { ns: 'datasetSettings' })}