Merge remote-tracking branch 'origin/main' into yanli/fix-iter-log

This commit is contained in:
Yanli 盐粒 2026-03-18 17:51:26 +08:00
commit dea90b0ccd
839 changed files with 26780 additions and 8907 deletions

View File

@ -63,7 +63,8 @@ pnpm analyze-component <path> --review
### File Naming
- Test files: `ComponentName.spec.tsx` (same directory as component)
- Test files: `ComponentName.spec.tsx` inside a same-level `__tests__/` directory
- Placement rule: Component, hook, and utility tests must live in a sibling `__tests__/` folder at the same level as the source under test. For example, `foo/index.tsx` maps to `foo/__tests__/index.spec.tsx`, and `foo/bar.ts` maps to `foo/__tests__/bar.spec.ts`.
- Integration tests: `web/__tests__/` directory
## Test Structure Template

View File

@ -41,7 +41,7 @@ import userEvent from '@testing-library/user-event'
// Router (if component uses useRouter, usePathname, useSearchParams)
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
// const mockPush = vi.fn()
// vi.mock('next/navigation', () => ({
// vi.mock('@/next/navigation', () => ({
// useRouter: () => ({ push: mockPush }),
// usePathname: () => '/test-path',
// }))

View File

@ -29,8 +29,8 @@ jobs:
strategy:
fail-fast: false
matrix:
shardIndex: [1, 2, 3, 4]
shardTotal: [4]
shardIndex: [1, 2, 3, 4, 5, 6]
shardTotal: [6]
defaults:
run:
shell: bash

View File

@ -14,6 +14,7 @@ from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.enums import DatasetMetadataType, IndexingStatus, SegmentStatus
from models.model import App, AppAnnotationSetting, MessageAnnotation
@ -242,7 +243,7 @@ def migrate_knowledge_vector_database():
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.indexing_status == IndexingStatus.COMPLETED,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
@ -254,7 +255,7 @@ def migrate_knowledge_vector_database():
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
DocumentSegment.status == SegmentStatus.COMPLETED,
DocumentSegment.enabled == True,
)
).all()
@ -430,7 +431,7 @@ def old_metadata_migration():
tenant_id=document.tenant_id,
dataset_id=document.dataset_id,
name=key,
type="string",
type=DatasetMetadataType.STRING,
created_by=document.created_by,
)
db.session.add(dataset_metadata)

View File

@ -1,4 +1,4 @@
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, field_validator
from pydantic_settings import BaseSettings
@ -116,3 +116,13 @@ class RedisConfig(BaseSettings):
description="Maximum connections in the Redis connection pool (unset for library default)",
default=None,
)
@field_validator("REDIS_MAX_CONNECTIONS", mode="before")
@classmethod
def _empty_string_to_none_for_max_conns(cls, v):
"""Allow empty string in env/.env to mean 'unset' (None)."""
if v is None:
return None
if isinstance(v, str) and v.strip() == "":
return None
return v

View File

@ -54,6 +54,7 @@ from fields.document_fields import document_status_fields
from libs.login import current_account_with_tenant, login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermission, DatasetPermissionEnum
from models.enums import SegmentStatus
from models.provider_ids import ModelProviderID
from services.api_token_service import ApiTokenCache
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@ -741,13 +742,15 @@ class DatasetIndexingStatusApi(Resource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
)
.count()
)
# Create a dictionary with document attributes and additional fields

View File

@ -42,6 +42,7 @@ from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from models.enums import IndexingStatus, SegmentStatus
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.file_service import FileService
@ -332,13 +333,16 @@ class DatasetDocumentListApi(Resource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
document.completed_segments = completed_segments
@ -503,7 +507,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
if document.indexing_status in {"completed", "error"}:
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
raise DocumentAlreadyFinishedError()
data_process_rule = document.dataset_process_rule
@ -573,7 +577,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
extract_settings = []
for document in documents:
if document.indexing_status in {"completed", "error"}:
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
raise DocumentAlreadyFinishedError()
data_source_info = document.data_source_info_dict
match document.data_source_type:
@ -671,19 +675,21 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
)
.count()
)
# Create a dictionary with document attributes and additional fields
document_dict = {
"id": document.id,
"indexing_status": "paused" if document.is_paused else document.indexing_status,
"indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status,
"processing_started_at": document.processing_started_at,
"parsing_completed_at": document.parsing_completed_at,
"cleaning_completed_at": document.cleaning_completed_at,
@ -720,20 +726,20 @@ class DocumentIndexingStatusApi(DocumentResource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != "re_segment",
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT)
.count()
)
# Create a dictionary with document attributes and additional fields
document_dict = {
"id": document.id,
"indexing_status": "paused" if document.is_paused else document.indexing_status,
"indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status,
"processing_started_at": document.processing_started_at,
"parsing_completed_at": document.parsing_completed_at,
"cleaning_completed_at": document.cleaning_completed_at,
@ -955,7 +961,7 @@ class DocumentProcessingApi(DocumentResource):
match action:
case "pause":
if document.indexing_status != "indexing":
if document.indexing_status != IndexingStatus.INDEXING:
raise InvalidActionError("Document not in indexing state.")
document.paused_by = current_user.id
@ -964,7 +970,7 @@ class DocumentProcessingApi(DocumentResource):
db.session.commit()
case "resume":
if document.indexing_status not in {"paused", "error"}:
if document.indexing_status not in {IndexingStatus.PAUSED, IndexingStatus.ERROR}:
raise InvalidActionError("Document not in paused or error state.")
document.paused_by = None
@ -1169,7 +1175,7 @@ class DocumentRetryApi(DocumentResource):
raise ArchivedDocumentImmutableError()
# 400 if document is completed
if document.indexing_status == "completed":
if document.indexing_status == IndexingStatus.COMPLETED:
raise DocumentAlreadyFinishedError()
retry_documents.append(document)
except Exception:

View File

@ -46,6 +46,8 @@ class PipelineTemplateDetailApi(Resource):
type = request.args.get("type", default="built-in", type=str)
rag_pipeline_service = RagPipelineService()
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
if pipeline_template is None:
return {"error": "Pipeline template not found from upstream service."}, 404
return pipeline_template, 200

View File

@ -36,6 +36,7 @@ from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import SegmentStatus
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import (
KnowledgeConfig,
@ -622,13 +623,15 @@ class DocumentIndexingStatusApi(DatasetApiResource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
)
.count()
)
# Create a dictionary with document attributes and additional fields

View File

@ -70,7 +70,14 @@ def handle_webhook(webhook_id: str):
@bp.route("/webhook-debug/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
def handle_webhook_debug(webhook_id: str):
"""Handle webhook debug calls without triggering production workflow execution."""
"""Handle webhook debug calls without triggering production workflow execution.
The debug webhook endpoint is only for draft inspection flows. It never enqueues
Celery work for the published workflow; instead it dispatches an in-memory debug
event to an active Variable Inspector listener. Returning a clear error when no
listener is registered prevents a misleading 200 response for requests that are
effectively dropped.
"""
try:
webhook_trigger, _, node_config, webhook_data, error = _prepare_webhook_execution(webhook_id, is_debug=True)
if error:
@ -94,11 +101,32 @@ def handle_webhook_debug(webhook_id: str):
"method": webhook_data.get("method"),
},
)
TriggerDebugEventBus.dispatch(
dispatch_count = TriggerDebugEventBus.dispatch(
tenant_id=webhook_trigger.tenant_id,
event=event,
pool_key=pool_key,
)
if dispatch_count == 0:
logger.warning(
"Webhook debug request dropped without an active listener for webhook %s (tenant=%s, app=%s, node=%s)",
webhook_trigger.webhook_id,
webhook_trigger.tenant_id,
webhook_trigger.app_id,
webhook_trigger.node_id,
)
return (
jsonify(
{
"error": "No active debug listener",
"message": (
"The webhook debug URL only works while the Variable Inspector is listening. "
"Use the published webhook URL to execute the workflow in Celery."
),
"execution_url": webhook_trigger.webhook_url,
}
),
409,
)
response_data, status_code = WebhookService.generate_webhook_response(node_config)
return jsonify(response_data), status_code

View File

@ -441,7 +441,7 @@ class BaseAgentRunner(AppRunner):
continue
result.append(self.organize_agent_user_prompt(message))
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
agent_thoughts = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
tool_names_raw = agent_thought.tool

View File

@ -1,13 +1,36 @@
from collections.abc import Mapping
from typing import Any
from typing import Any, TypedDict
from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
class SystemParametersDict(TypedDict):
image_file_size_limit: int
video_file_size_limit: int
audio_file_size_limit: int
file_size_limit: int
workflow_file_upload_limit: int
class AppParametersDict(TypedDict):
opening_statement: str | None
suggested_questions: list[str]
suggested_questions_after_answer: dict[str, Any]
speech_to_text: dict[str, Any]
text_to_speech: dict[str, Any]
retriever_resource: dict[str, Any]
annotation_reply: dict[str, Any]
more_like_this: dict[str, Any]
user_input_form: list[dict[str, Any]]
sensitive_word_avoidance: dict[str, Any]
file_upload: dict[str, Any]
system_parameters: SystemParametersDict
def get_parameters_from_feature_dict(
*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]
) -> Mapping[str, Any]:
) -> AppParametersDict:
"""
Mapping from feature dict to webapp parameters
"""

View File

@ -8,6 +8,7 @@ from core.app.app_config.entities import (
ModelConfig,
)
from core.entities.agent_entities import PlanningStrategy
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from models.model import AppMode, AppModelConfigDict
from services.dataset_service import DatasetService
@ -117,8 +118,10 @@ class DatasetConfigManager:
score_threshold=float(score_threshold_val)
if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None
else None,
reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None,
weights=weights_val if isinstance(weights_val, dict) else None,
reranking_model=cast(RerankingModelDict, reranking_model_val)
if isinstance(reranking_model_val, dict)
else None,
weights=cast(WeightsDict, weights_val) if isinstance(weights_val, dict) else None,
reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)),
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
metadata_filtering_mode=cast(

View File

@ -4,6 +4,7 @@ from typing import Any, Literal
from pydantic import BaseModel, Field
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from dify_graph.file import FileUploadConfig
from dify_graph.model_runtime.entities.llm_entities import LLMMode
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
@ -194,8 +195,8 @@ class DatasetRetrieveConfigEntity(BaseModel):
top_k: int | None = None
score_threshold: float | None = 0.0
rerank_mode: str | None = "reranking_model"
reranking_model: dict | None = None
weights: dict | None = None
reranking_model: RerankingModelDict | None = None
weights: WeightsDict | None = None
reranking_enabled: bool | None = True
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
metadata_model_config: ModelConfig | None = None

View File

@ -3,7 +3,7 @@ import time
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime
from typing import Any, NewType, Union
from typing import Any, NewType, TypedDict, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -76,6 +76,20 @@ NodeExecutionId = NewType("NodeExecutionId", str)
logger = logging.getLogger(__name__)
class AccountCreatedByDict(TypedDict):
id: str
name: str
email: str
class EndUserCreatedByDict(TypedDict):
id: str
user: str
CreatedByDict = AccountCreatedByDict | EndUserCreatedByDict
@dataclass(slots=True)
class _NodeSnapshot:
"""In-memory cache for node metadata between start and completion events."""
@ -249,19 +263,19 @@ class WorkflowResponseConverter:
outputs_mapping = graph_runtime_state.outputs or {}
encoded_outputs = WorkflowRuntimeTypeConverter().to_json_encodable(outputs_mapping)
created_by: Mapping[str, object] | None
created_by: CreatedByDict | dict[str, object] = {}
user = self._user
if isinstance(user, Account):
created_by = {
"id": user.id,
"name": user.name,
"email": user.email,
}
else:
created_by = {
"id": user.id,
"user": user.session_id,
}
created_by = AccountCreatedByDict(
id=user.id,
name=user.name,
email=user.email,
)
elif isinstance(user, EndUser):
created_by = EndUserCreatedByDict(
id=user.id,
user=user.session_id,
)
return WorkflowFinishStreamResponse(
task_id=task_id,

View File

@ -6,6 +6,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from models.enums import CollectionBindingType
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.dataset_service import DatasetCollectionBindingService
@ -43,7 +44,7 @@ class AnnotationReplyFeature:
embedding_model_name = collection_binding_detail.model_name
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name, embedding_model_name, "annotation"
embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION
)
dataset = Dataset(

View File

@ -1,3 +1,5 @@
from typing import TypedDict
from core.tools.signature import sign_tool_file
from dify_graph.file import helpers as file_helpers
from dify_graph.file.enums import FileTransferMethod
@ -6,7 +8,20 @@ from models.model import MessageFile, UploadFile
MAX_TOOL_FILE_EXTENSION_LENGTH = 10
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict:
class MessageFileInfoDict(TypedDict):
related_id: str
extension: str
filename: str
size: int
mime_type: str
transfer_method: str
type: str
url: str
upload_file_id: str
remote_url: str | None
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> MessageFileInfoDict:
"""
Prepare file dictionary for message end stream response.

View File

@ -12,7 +12,7 @@ from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, DatasetQuerySource
_logger = logging.getLogger(__name__)
@ -36,7 +36,7 @@ class DatasetIndexToolCallbackHandler:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=query,
source="app",
source=DatasetQuerySource.APP,
source_app_id=self._app_id,
created_by_role=(
CreatorUserRole.ACCOUNT

View File

@ -30,6 +30,7 @@ from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from libs.datetime_utils import naive_utc_now
from models.engine import db
from models.enums import CredentialSourceType
from models.provider import (
LoadBalancingModelConfig,
Provider,
@ -546,7 +547,7 @@ class ProviderConfiguration(BaseModel):
self._update_load_balancing_configs_with_credential(
credential_id=credential_id,
credential_record=credential_record,
credential_source="provider",
credential_source=CredentialSourceType.PROVIDER,
session=session,
)
except Exception:
@ -623,7 +624,7 @@ class ProviderConfiguration(BaseModel):
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "provider",
LoadBalancingModelConfig.credential_source_type == CredentialSourceType.PROVIDER,
)
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
try:
@ -1043,7 +1044,7 @@ class ProviderConfiguration(BaseModel):
self._update_load_balancing_configs_with_credential(
credential_id=credential_id,
credential_record=credential_record,
credential_source="custom_model",
credential_source=CredentialSourceType.CUSTOM_MODEL,
session=session,
)
except Exception:
@ -1073,7 +1074,7 @@ class ProviderConfiguration(BaseModel):
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "custom_model",
LoadBalancingModelConfig.credential_source_type == CredentialSourceType.CUSTOM_MODEL,
)
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
@ -1711,7 +1712,7 @@ class ProviderConfiguration(BaseModel):
provider_model_lb_configs = [
config
for config in model_setting.load_balancing_configs
if config.credential_source_type != "custom_model"
if config.credential_source_type != CredentialSourceType.CUSTOM_MODEL
]
load_balancing_enabled = model_setting.load_balancing_enabled
@ -1769,7 +1770,7 @@ class ProviderConfiguration(BaseModel):
custom_model_lb_configs = [
config
for config in model_setting.load_balancing_configs
if config.credential_source_type != "provider"
if config.credential_source_type != CredentialSourceType.PROVIDER
]
load_balancing_enabled = model_setting.load_balancing_enabled

View File

@ -40,6 +40,7 @@ from libs.datetime_utils import naive_utc_now
from models import Account
from models.dataset import AutomaticRulesConfig, ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.enums import DataSourceType, IndexingStatus, ProcessRuleMode, SegmentStatus
from models.model import UploadFile
from services.feature_service import FeatureService
@ -56,7 +57,7 @@ class IndexingRunner:
logger.exception("consume document failed")
document = db.session.get(DatasetDocument, document_id)
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
error_message = getattr(error, "description", str(error))
document.error = str(error_message)
document.stopped_at = naive_utc_now()
@ -219,7 +220,7 @@ class IndexingRunner:
if document_segments:
for document_segment in document_segments:
# transform segment to node
if document_segment.status != "completed":
if document_segment.status != SegmentStatus.COMPLETED:
document = Document(
page_content=document_segment.content,
metadata={
@ -382,7 +383,7 @@ class IndexingRunner:
data_source_info = dataset_document.data_source_info_dict
text_docs = []
match dataset_document.data_source_type:
case "upload_file":
case DataSourceType.UPLOAD_FILE:
if not data_source_info or "upload_file_id" not in data_source_info:
raise ValueError("no upload file found")
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
@ -395,7 +396,7 @@ class IndexingRunner:
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
case "notion_import":
case DataSourceType.NOTION_IMPORT:
if (
not data_source_info
or "notion_workspace_id" not in data_source_info
@ -417,7 +418,7 @@ class IndexingRunner:
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
case "website_crawl":
case DataSourceType.WEBSITE_CRAWL:
if (
not data_source_info
or "provider" not in data_source_info
@ -445,7 +446,7 @@ class IndexingRunner:
# update document status to splitting
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="splitting",
after_indexing_status=IndexingStatus.SPLITTING,
extra_update_params={
DatasetDocument.parsing_completed_at: naive_utc_now(),
},
@ -545,7 +546,7 @@ class IndexingRunner:
Clean the document text according to the processing rules.
"""
rules: AutomaticRulesConfig | dict[str, Any]
if processing_rule.mode == "automatic":
if processing_rule.mode == ProcessRuleMode.AUTOMATIC:
rules = DatasetProcessRule.AUTOMATIC_RULES
else:
rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
@ -636,7 +637,7 @@ class IndexingRunner:
# update document status to completed
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="completed",
after_indexing_status=IndexingStatus.COMPLETED,
extra_update_params={
DatasetDocument.tokens: tokens,
DatasetDocument.completed_at: naive_utc_now(),
@ -659,10 +660,10 @@ class IndexingRunner:
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == "indexing",
DocumentSegment.status == SegmentStatus.INDEXING,
).update(
{
DocumentSegment.status: "completed",
DocumentSegment.status: SegmentStatus.COMPLETED,
DocumentSegment.enabled: True,
DocumentSegment.completed_at: naive_utc_now(),
}
@ -703,10 +704,10 @@ class IndexingRunner:
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == "indexing",
DocumentSegment.status == SegmentStatus.INDEXING,
).update(
{
DocumentSegment.status: "completed",
DocumentSegment.status: SegmentStatus.COMPLETED,
DocumentSegment.enabled: True,
DocumentSegment.completed_at: naive_utc_now(),
}
@ -725,7 +726,7 @@ class IndexingRunner:
@staticmethod
def _update_document_index_status(
document_id: str, after_indexing_status: str, extra_update_params: dict | None = None
document_id: str, after_indexing_status: IndexingStatus, extra_update_params: dict | None = None
):
"""
Update the document indexing status.
@ -803,7 +804,7 @@ class IndexingRunner:
cur_time = naive_utc_now()
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="indexing",
after_indexing_status=IndexingStatus.INDEXING,
extra_update_params={
DatasetDocument.cleaning_completed_at: cur_time,
DatasetDocument.splitting_completed_at: cur_time,
@ -815,7 +816,7 @@ class IndexingRunner:
self._update_segments_by_document(
dataset_document_id=dataset_document.id,
update_params={
DocumentSegment.status: "indexing",
DocumentSegment.status: SegmentStatus.INDEXING,
DocumentSegment.indexing_at: naive_utc_now(),
},
)

View File

@ -1,3 +1,5 @@
from typing_extensions import TypedDict
from core.model_manager import ModelInstance, ModelManager
from core.rag.data_post_processor.reorder import ReorderRunner
from core.rag.index_processor.constant.query_type import QueryType
@ -10,6 +12,26 @@ from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
class RerankingModelDict(TypedDict):
reranking_provider_name: str
reranking_model_name: str
class VectorSettingDict(TypedDict):
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSettingDict(TypedDict):
keyword_weight: float
class WeightsDict(TypedDict):
vector_setting: VectorSettingDict
keyword_setting: KeywordSettingDict
class DataPostProcessor:
"""Interface for data post-processing document."""
@ -17,8 +39,8 @@ class DataPostProcessor:
self,
tenant_id: str,
reranking_mode: str,
reranking_model: dict | None = None,
weights: dict | None = None,
reranking_model: RerankingModelDict | None = None,
weights: WeightsDict | None = None,
reorder_enabled: bool = False,
):
self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights)
@ -45,8 +67,8 @@ class DataPostProcessor:
self,
reranking_mode: str,
tenant_id: str,
reranking_model: dict | None = None,
weights: dict | None = None,
reranking_model: RerankingModelDict | None = None,
weights: WeightsDict | None = None,
) -> BaseRerankRunner | None:
if reranking_mode == RerankMode.WEIGHTED_SCORE and weights:
runner = RerankRunnerFactory.create_rerank_runner(
@ -79,12 +101,14 @@ class DataPostProcessor:
return ReorderRunner()
return None
def _get_rerank_model_instance(self, tenant_id: str, reranking_model: dict | None) -> ModelInstance | None:
def _get_rerank_model_instance(
self, tenant_id: str, reranking_model: RerankingModelDict | None
) -> ModelInstance | None:
if reranking_model:
try:
model_manager = ModelManager()
reranking_provider_name = reranking_model.get("reranking_provider_name")
reranking_model_name = reranking_model.get("reranking_model_name")
reranking_provider_name = reranking_model["reranking_provider_name"]
reranking_model_name = reranking_model["reranking_model_name"]
if not reranking_provider_name or not reranking_model_name:
return None
rerank_model_instance = model_manager.get_model_instance(

View File

@ -1,19 +1,20 @@
import concurrent.futures
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from typing import Any, NotRequired
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from typing_extensions import TypedDict
from configs import dify_config
from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
@ -35,7 +36,46 @@ from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
class SegmentAttachmentResult(TypedDict):
attachment_info: AttachmentInfoDict
segment_id: str
class SegmentAttachmentInfoResult(TypedDict):
attachment_id: str
attachment_info: AttachmentInfoDict
segment_id: str
class ChildChunkDetail(TypedDict):
id: str
content: str
position: int
score: float
class SegmentChildMapDetail(TypedDict):
max_score: float
child_chunks: list[ChildChunkDetail]
class SegmentRecord(TypedDict):
segment: DocumentSegment
score: NotRequired[float]
child_chunks: NotRequired[list[ChildChunkDetail]]
files: NotRequired[list[AttachmentInfoDict]]
class DefaultRetrievalModelDict(TypedDict):
search_method: RetrievalMethod | str
reranking_enable: bool
reranking_model: RerankingModelDict
top_k: int
score_threshold_enabled: bool
default_retrieval_model: DefaultRetrievalModelDict = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@ -56,9 +96,9 @@ class RetrievalService:
query: str,
top_k: int = 4,
score_threshold: float | None = 0.0,
reranking_model: dict | None = None,
reranking_model: RerankingModelDict | None = None,
reranking_mode: str = "reranking_model",
weights: dict | None = None,
weights: WeightsDict | None = None,
document_ids_filter: list[str] | None = None,
attachment_ids: list | None = None,
):
@ -235,7 +275,7 @@ class RetrievalService:
query: str,
top_k: int,
score_threshold: float | None,
reranking_model: dict | None,
reranking_model: RerankingModelDict | None,
all_documents: list,
retrieval_method: RetrievalMethod,
exceptions: list,
@ -277,8 +317,8 @@ class RetrievalService:
if documents:
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and reranking_model["reranking_model_name"]
and reranking_model["reranking_provider_name"]
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH
):
data_post_processor = DataPostProcessor(
@ -288,8 +328,8 @@ class RetrievalService:
model_manager = ModelManager()
is_support_vision = model_manager.check_model_support_vision(
tenant_id=dataset.tenant_id,
provider=reranking_model.get("reranking_provider_name") or "",
model=reranking_model.get("reranking_model_name") or "",
provider=reranking_model["reranking_provider_name"],
model=reranking_model["reranking_model_name"],
model_type=ModelType.RERANK,
)
if is_support_vision:
@ -329,7 +369,7 @@ class RetrievalService:
query: str,
top_k: int,
score_threshold: float | None,
reranking_model: dict | None,
reranking_model: RerankingModelDict | None,
all_documents: list,
retrieval_method: str,
exceptions: list,
@ -349,8 +389,8 @@ class RetrievalService:
if documents:
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and reranking_model["reranking_model_name"]
and reranking_model["reranking_provider_name"]
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH
):
data_post_processor = DataPostProcessor(
@ -459,7 +499,7 @@ class RetrievalService:
segment_ids: list[str] = []
index_node_segments: list[DocumentSegment] = []
segments: list[DocumentSegment] = []
attachment_map: dict[str, list[dict[str, Any]]] = {}
attachment_map: dict[str, list[AttachmentInfoDict]] = {}
child_chunk_map: dict[str, list[ChildChunk]] = {}
doc_segment_map: dict[str, list[str]] = {}
segment_summary_map: dict[str, str] = {} # Map segment_id to summary content
@ -544,12 +584,12 @@ class RetrievalService:
segment_summary_map[summary.chunk_id] = summary.summary_content
include_segment_ids = set()
segment_child_map: dict[str, dict[str, Any]] = {}
records: list[dict[str, Any]] = []
segment_child_map: dict[str, SegmentChildMapDetail] = {}
records: list[SegmentRecord] = []
for segment in segments:
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
attachment_infos: list[AttachmentInfoDict] = attachment_map.get(segment.id, [])
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
@ -560,14 +600,14 @@ class RetrievalService:
max_score = summary_score_map.get(segment.id, 0.0)
if child_chunks or attachment_infos:
child_chunk_details = []
child_chunk_details: list[ChildChunkDetail] = []
for child_chunk in child_chunks:
child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id)
if child_document:
child_score = child_document.metadata.get("score", 0.0)
else:
child_score = 0.0
child_chunk_detail = {
child_chunk_detail: ChildChunkDetail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
@ -580,7 +620,7 @@ class RetrievalService:
if file_document:
max_score = max(max_score, file_document.metadata.get("score", 0.0))
map_detail = {
map_detail: SegmentChildMapDetail = {
"max_score": max_score,
"child_chunks": child_chunk_details,
}
@ -593,7 +633,7 @@ class RetrievalService:
"max_score": summary_score,
"child_chunks": [],
}
record: dict[str, Any] = {
record: SegmentRecord = {
"segment": segment,
}
records.append(record)
@ -617,19 +657,19 @@ class RetrievalService:
if file_doc:
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
record = {
another_record: SegmentRecord = {
"segment": segment,
"score": max_score,
}
records.append(record)
records.append(another_record)
# Add child chunks information to records
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
record["child_chunks"] = segment_child_map[record["segment"].id]["child_chunks"]
record["score"] = segment_child_map[record["segment"].id]["max_score"]
if record["segment"].id in attachment_map:
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
record["files"] = attachment_map[record["segment"].id]
result: list[RetrievalSegments] = []
for record in records:
@ -693,9 +733,9 @@ class RetrievalService:
query: str | None = None,
top_k: int = 4,
score_threshold: float | None = 0.0,
reranking_model: dict | None = None,
reranking_model: RerankingModelDict | None = None,
reranking_mode: str = "reranking_model",
weights: dict | None = None,
weights: WeightsDict | None = None,
document_ids_filter: list[str] | None = None,
attachment_id: str | None = None,
):
@ -807,7 +847,7 @@ class RetrievalService:
@classmethod
def get_segment_attachment_info(
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
) -> dict[str, Any] | None:
) -> SegmentAttachmentResult | None:
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
if upload_file:
attachment_binding = (
@ -816,7 +856,7 @@ class RetrievalService:
.first()
)
if attachment_binding:
attachment_info = {
attachment_info: AttachmentInfoDict = {
"id": upload_file.id,
"name": upload_file.name,
"extension": "." + upload_file.extension,
@ -828,8 +868,10 @@ class RetrievalService:
return None
@classmethod
def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
attachment_infos = []
def get_segment_attachment_infos(
cls, attachment_ids: list[str], session: Session
) -> list[SegmentAttachmentInfoResult]:
attachment_infos: list[SegmentAttachmentInfoResult] = []
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
if upload_files:
upload_file_ids = [upload_file.id for upload_file in upload_files]
@ -843,7 +885,7 @@ class RetrievalService:
if attachment_bindings:
for upload_file in upload_files:
attachment_binding = attachment_binding_map.get(upload_file.id)
attachment_info = {
info: AttachmentInfoDict = {
"id": upload_file.id,
"name": upload_file.name,
"extension": "." + upload_file.extension,
@ -855,7 +897,7 @@ class RetrievalService:
attachment_infos.append(
{
"attachment_id": attachment_binding.attachment_id,
"attachment_info": attachment_info,
"attachment_info": info,
"segment_id": attachment_binding.segment_id,
}
)

View File

@ -5,6 +5,7 @@ This module provides integration with Weaviate vector database for storing and r
document embeddings used in retrieval-augmented generation workflows.
"""
import atexit
import datetime
import json
import logging
@ -37,6 +38,32 @@ _weaviate_client: weaviate.WeaviateClient | None = None
_weaviate_client_lock = threading.Lock()
def _shutdown_weaviate_client() -> None:
"""
Best-effort shutdown hook to close the module-level Weaviate client.
This is registered with atexit so that HTTP/gRPC resources are released
when the Python interpreter exits.
"""
global _weaviate_client
# Ensure thread-safety when accessing the shared client instance
with _weaviate_client_lock:
client = _weaviate_client
_weaviate_client = None
if client is not None:
try:
client.close()
except Exception:
# Best-effort cleanup; log at debug level and ignore errors.
logger.debug("Failed to close Weaviate client during shutdown", exc_info=True)
# Register the shutdown hook once per process.
atexit.register(_shutdown_weaviate_client)
class WeaviateConfig(BaseModel):
"""
Configuration model for Weaviate connection settings.
@ -85,18 +112,6 @@ class WeaviateVector(BaseVector):
self._client = self._init_client(config)
self._attributes = attributes
def __del__(self):
"""
Destructor to properly close the Weaviate client connection.
Prevents connection leaks and resource warnings.
"""
if hasattr(self, "_client") and self._client is not None:
try:
self._client.close()
except Exception as e:
# Ignore errors during cleanup as object is being destroyed
logger.warning("Error closing Weaviate client %s", e, exc_info=True)
def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient:
"""
Initializes and returns a connected Weaviate client.

View File

@ -1,8 +1,18 @@
from pydantic import BaseModel
from typing_extensions import TypedDict
from models.dataset import DocumentSegment
class AttachmentInfoDict(TypedDict):
id: str
name: str
extension: str
mime_type: str
source_url: str
size: int
class RetrievalChildChunk(BaseModel):
"""Retrieval segments."""
@ -19,5 +29,5 @@ class RetrievalSegments(BaseModel):
segment: DocumentSegment
child_chunks: list[RetrievalChildChunk] | None = None
score: float | None = None
files: list[dict[str, str | int]] | None = None
files: list[AttachmentInfoDict] | None = None
summary: str | None = None # Summary content if retrieved via summary index

View File

@ -9,6 +9,7 @@ from flask import current_app
from sqlalchemy import delete, func, select
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview
from models.dataset import Dataset, Document, DocumentSegment
@ -51,7 +52,7 @@ class IndexProcessor:
original_document_id: str,
chunks: Mapping[str, Any],
batch: Any,
summary_index_setting: dict | None = None,
summary_index_setting: SummaryIndexSettingDict | None = None,
):
with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first()
@ -131,7 +132,12 @@ class IndexProcessor:
}
def get_preview_output(
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
self,
chunks: Any,
dataset_id: str,
document_id: str,
chunk_structure: str,
summary_index_setting: SummaryIndexSettingDict | None,
) -> Preview:
doc_language = None
with session_factory.create_session() as session:

View File

@ -7,14 +7,16 @@ import os
import re
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, NotRequired, Optional
from urllib.parse import unquote, urlparse
import httpx
from typing_extensions import TypedDict
from configs import dify_config
from core.entities.knowledge_entities import PreviewDetail
from core.helper import ssrf_proxy
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import AttachmentDocument, Document
@ -35,6 +37,13 @@ if TYPE_CHECKING:
from core.model_manager import ModelInstance
class SummaryIndexSettingDict(TypedDict):
enable: bool
model_name: NotRequired[str]
model_provider_name: NotRequired[str]
summary_prompt: NotRequired[str]
class BaseIndexProcessor(ABC):
"""Interface for extract files."""
@ -51,7 +60,7 @@ class BaseIndexProcessor(ABC):
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
summary_index_setting: SummaryIndexSettingDict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""
@ -98,7 +107,7 @@ class BaseIndexProcessor(ABC):
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
reranking_model: RerankingModelDict,
) -> list[Document]:
raise NotImplementedError

View File

@ -14,6 +14,7 @@ from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
from core.provider_manager import ProviderManager
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
@ -22,7 +23,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
@ -175,7 +176,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
reranking_model: RerankingModelDict,
) -> list[Document]:
# Set search parameters.
results = RetrievalService.retrieve(
@ -278,7 +279,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
summary_index_setting: SummaryIndexSettingDict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""
@ -362,7 +363,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
def generate_summary(
tenant_id: str,
text: str,
summary_index_setting: dict | None = None,
summary_index_setting: SummaryIndexSettingDict | None = None,
segment_id: str | None = None,
document_language: str | None = None,
) -> tuple[str, LLMUsage]:

View File

@ -11,6 +11,7 @@ from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
@ -18,7 +19,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
@ -215,7 +216,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
reranking_model: RerankingModelDict,
) -> list[Document]:
# Set search parameters.
results = RetrievalService.retrieve(
@ -361,7 +362,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
summary_index_setting: SummaryIndexSettingDict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""

View File

@ -15,13 +15,14 @@ from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.llm_generator import LLMGenerator
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
@ -185,7 +186,7 @@ class QAIndexProcessor(BaseIndexProcessor):
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
reranking_model: RerankingModelDict,
):
# Set search parameters.
results = RetrievalService.retrieve(
@ -244,7 +245,7 @@ class QAIndexProcessor(BaseIndexProcessor):
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
summary_index_setting: SummaryIndexSettingDict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""

View File

@ -31,7 +31,7 @@ from core.ops.utils import measure_time
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
@ -83,7 +83,7 @@ from models.dataset import (
)
from models.dataset import Document as DatasetDocument
from models.dataset import Document as DocumentModel
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, DatasetQuerySource
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureService
@ -727,8 +727,8 @@ class DatasetRetrieval:
top_k: int,
score_threshold: float,
reranking_mode: str,
reranking_model: dict | None = None,
weights: dict[str, Any] | None = None,
reranking_model: RerankingModelDict | None = None,
weights: WeightsDict | None = None,
reranking_enable: bool = True,
message_id: str | None = None,
metadata_filter_document_ids: dict[str, list[str]] | None = None,
@ -1008,7 +1008,7 @@ class DatasetRetrieval:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=json.dumps(contents),
source="app",
source=DatasetQuerySource.APP,
source_app_id=app_id,
created_by_role=CreatorUserRole(user_from),
created_by=user_id,
@ -1181,8 +1181,8 @@ class DatasetRetrieval:
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source(),
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
reranking_provider_name=retrieve_config.reranking_model["reranking_provider_name"],
reranking_model_name=retrieve_config.reranking_model["reranking_model_name"],
)
tools.append(tool)
@ -1685,8 +1685,8 @@ class DatasetRetrieval:
tenant_id: str,
reranking_enable: bool,
reranking_mode: str,
reranking_model: dict | None,
weights: dict[str, Any] | None,
reranking_model: RerankingModelDict | None,
weights: WeightsDict | None,
top_k: int,
score_threshold: float,
query: str | None,

View File

@ -2,6 +2,7 @@ import concurrent.futures
import logging
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
from services.summary_index_service import SummaryIndexService
from tasks.generate_summary_index_task import generate_summary_index_task
@ -11,7 +12,11 @@ logger = logging.getLogger(__name__)
class SummaryIndex:
def generate_and_vectorize_summary(
self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
self,
dataset_id: str,
document_id: str,
is_preview: bool,
summary_index_setting: SummaryIndexSettingDict | None = None,
) -> None:
if is_preview:
with session_factory.create_session() as session:

View File

@ -72,6 +72,11 @@ class ApiProviderControllerItem(TypedDict):
controller: ApiToolProviderController
class EmojiIconDict(TypedDict):
background: str
content: str
class ToolManager:
_builtin_provider_lock = Lock()
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
@ -916,7 +921,7 @@ class ToolManager:
)
@classmethod
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]:
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
try:
workflow_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
@ -933,7 +938,7 @@ class ToolManager:
return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]:
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
try:
api_provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
@ -950,7 +955,7 @@ class ToolManager:
return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str:
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str:
try:
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
@ -970,7 +975,7 @@ class ToolManager:
tenant_id: str,
provider_type: ToolProviderType,
provider_id: str,
) -> str | Mapping[str, str]:
) -> str | EmojiIconDict | dict[str, str]:
"""
get the tool icon

View File

@ -1,5 +1,4 @@
import threading
from typing import Any
from flask import Flask, current_app
from pydantic import BaseModel, Field
@ -13,11 +12,12 @@ from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
default_retrieval_model: dict[str, Any] = {
default_retrieval_model: DefaultRetrievalModelDict = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},

View File

@ -1,9 +1,10 @@
from typing import Any, cast
from typing import NotRequired, TypedDict, cast
from pydantic import BaseModel, Field
from sqlalchemy import select
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
@ -16,7 +17,19 @@ from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model: dict[str, Any] = {
class DefaultRetrievalModelDict(TypedDict):
search_method: RetrievalMethod
reranking_enable: bool
reranking_model: RerankingModelDict
reranking_mode: NotRequired[str]
weights: NotRequired[WeightsDict | None]
score_threshold: NotRequired[float]
top_k: int
score_threshold_enabled: bool
default_retrieval_model: DefaultRetrievalModelDict = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@ -125,7 +138,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
if metadata_condition and not document_ids_filter:
return ""
# get retrieval model , if the model is not setting , using default
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
retrieval_model = dataset.retrieval_model or default_retrieval_model
retrieval_resource_list: list[RetrievalSourceMetadata] = []
if dataset.indexing_technique == "economy":
# use keyword table query

View File

@ -1,4 +1,5 @@
import re
from collections.abc import Mapping
from json import dumps as json_dumps
from json import loads as json_loads
from json.decoder import JSONDecodeError
@ -20,10 +21,18 @@ class InterfaceDict(TypedDict):
operation: dict[str, Any]
class OpenAPISpecDict(TypedDict):
openapi: str
info: dict[str, str]
servers: list[dict[str, Any]]
paths: dict[str, Any]
components: dict[str, Any]
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
openapi: Mapping[str, Any], extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
@ -277,7 +286,7 @@ class ApiBasedToolSchemaParser:
@staticmethod
def parse_swagger_to_openapi(
swagger: dict, extra_info: dict | None = None, warning: dict | None = None
) -> dict[str, Any]:
) -> OpenAPISpecDict:
warning = warning or {}
"""
parse swagger to openapi
@ -293,7 +302,7 @@ class ApiBasedToolSchemaParser:
if len(servers) == 0:
raise ToolApiSchemaError("No server found in the swagger yaml.")
converted_openapi: dict[str, Any] = {
converted_openapi: OpenAPISpecDict = {
"openapi": "3.0.0",
"info": {
"title": info.get("title", "Swagger"),

View File

@ -2,6 +2,7 @@ from typing import Literal, Union
from pydantic import BaseModel
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from dify_graph.entities.base_node_data import BaseNodeData
@ -161,4 +162,4 @@ class KnowledgeIndexNodeData(BaseNodeData):
chunk_structure: str
index_chunk_variable_selector: list[str]
indexing_technique: str | None = None
summary_index_setting: dict | None = None
summary_index_setting: SummaryIndexSettingDict | None = None

View File

@ -3,6 +3,7 @@ from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from core.rag.index_processor.index_processor import IndexProcessor
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.rag.summary_index.summary_index import SummaryIndex
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from dify_graph.entities.graph_config import NodeConfigDict
@ -127,7 +128,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
is_preview: bool,
batch: Any,
chunks: Mapping[str, Any],
summary_index_setting: dict | None = None,
summary_index_setting: SummaryIndexSettingDict | None = None,
):
if not document_id:
raise KnowledgeIndexNodeError("document_id is required.")

View File

@ -9,6 +9,7 @@ from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict
@ -201,8 +202,8 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
reranking_model = None
weights = None
reranking_model: RerankingModelDict | None = None
weights: WeightsDict | None = None
match node_data.multiple_retrieval_config.reranking_mode:
case "reranking_model":
if node_data.multiple_retrieval_config.reranking_model:

View File

@ -2,6 +2,7 @@ from typing import Any, Literal, Protocol
from pydantic import BaseModel, Field
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from dify_graph.model_runtime.entities import LLMUsage
from dify_graph.nodes.llm.entities import ModelConfig
@ -75,8 +76,8 @@ class KnowledgeRetrievalRequest(BaseModel):
top_k: int = Field(default=0, description="Number of top results to return")
score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold")
reranking_mode: str = Field(default="reranking_model", description="Reranking strategy")
reranking_model: dict | None = Field(default=None, description="Reranking model configuration")
weights: dict[str, Any] | None = Field(default=None, description="Weights for weighted score reranking")
reranking_model: RerankingModelDict | None = Field(default=None, description="Reranking model configuration")
weights: WeightsDict | None = Field(default=None, description="Weights for weighted score reranking")
reranking_enable: bool = Field(default=True, description="Whether reranking is enabled")
attachment_ids: list[str] | None = Field(default=None, description="List of attachment file IDs for retrieval")

View File

@ -101,7 +101,6 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
http_request_config=self._http_request_config,
max_retries=0,
ssl_verify=self.node_data.ssl_verify,
http_client=self._http_client,
file_manager=self._file_manager,

View File

@ -10,6 +10,7 @@ from events.document_index_event import document_index_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Document
from models.enums import IndexingStatus
logger = logging.getLogger(__name__)
@ -35,7 +36,7 @@ def handle(sender, **kwargs):
if not document:
raise NotFound("Document not found")
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)

View File

@ -177,13 +177,11 @@ class Account(UserMixin, TypeBase):
@classmethod
def get_by_openid(cls, provider: str, open_id: str):
account_integrate = (
db.session.query(AccountIntegrate)
.where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
.one_or_none()
)
account_integrate = db.session.execute(
select(AccountIntegrate).where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
).scalar_one_or_none()
if account_integrate:
return db.session.query(Account).where(Account.id == account_integrate.account_id).one_or_none()
return db.session.scalar(select(Account).where(Account.id == account_integrate.account_id))
return None
# check current_user.current_tenant.current_role in ['admin', 'owner']

View File

@ -8,6 +8,7 @@ import os
import pickle
import re
import time
from collections.abc import Sequence
from datetime import datetime
from json import JSONDecodeError
from typing import Any, TypedDict, cast
@ -30,7 +31,20 @@ from services.entities.knowledge_entities.knowledge_entities import ParentMode,
from .account import Account
from .base import Base, TypeBase
from .engine import db
from .enums import CreatorUserRole
from .enums import (
CollectionBindingType,
CreatorUserRole,
DatasetMetadataType,
DatasetQuerySource,
DatasetRuntimeMode,
DataSourceType,
DocumentCreatedFrom,
DocumentDocType,
IndexingStatus,
ProcessRuleMode,
SegmentStatus,
SummaryStatus,
)
from .model import App, Tag, TagBinding, UploadFile
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
@ -120,7 +134,7 @@ class Dataset(Base):
server_default=sa.text("'only_me'"),
default=DatasetPermissionEnum.ONLY_ME,
)
data_source_type = mapped_column(String(255))
data_source_type = mapped_column(EnumText(DataSourceType, length=255))
indexing_technique: Mapped[str | None] = mapped_column(String(255))
index_struct = mapped_column(LongText, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
@ -137,7 +151,9 @@ class Dataset(Base):
summary_index_setting = mapped_column(AdjustedJSON, nullable=True)
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
icon_info = mapped_column(AdjustedJSON, nullable=True)
runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'"))
runtime_mode = mapped_column(
EnumText(DatasetRuntimeMode, length=255), nullable=True, server_default=sa.text("'general'")
)
pipeline_id = mapped_column(StringUUID, nullable=True)
chunk_structure = mapped_column(sa.String(255), nullable=True)
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
@ -145,30 +161,25 @@ class Dataset(Base):
@property
def total_documents(self):
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
@property
def total_available_documents(self):
return (
db.session.query(func.count(Document.id))
.where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
db.session.scalar(
select(func.count(Document.id)).where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
)
.scalar()
or 0
)
@property
def dataset_keyword_table(self):
dataset_keyword_table = (
db.session.query(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id).first()
)
if dataset_keyword_table:
return dataset_keyword_table
return None
return db.session.scalar(select(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id))
@property
def index_struct_dict(self):
@ -195,64 +206,66 @@ class Dataset(Base):
@property
def latest_process_rule(self):
return (
db.session.query(DatasetProcessRule)
return db.session.scalar(
select(DatasetProcessRule)
.where(DatasetProcessRule.dataset_id == self.id)
.order_by(DatasetProcessRule.created_at.desc())
.first()
.limit(1)
)
@property
def app_count(self):
return (
db.session.query(func.count(AppDatasetJoin.id))
.where(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
.scalar()
db.session.scalar(
select(func.count(AppDatasetJoin.id)).where(
AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id
)
)
or 0
)
@property
def document_count(self):
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
@property
def available_document_count(self):
return (
db.session.query(func.count(Document.id))
.where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
db.session.scalar(
select(func.count(Document.id)).where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
)
.scalar()
or 0
)
@property
def available_segment_count(self):
return (
db.session.query(func.count(DocumentSegment.id))
.where(
DocumentSegment.dataset_id == self.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.dataset_id == self.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
)
.scalar()
or 0
)
@property
def word_count(self):
return (
db.session.query(Document)
.with_entities(func.coalesce(func.sum(Document.word_count), 0))
.where(Document.dataset_id == self.id)
.scalar()
return db.session.scalar(
select(func.coalesce(func.sum(Document.word_count), 0)).where(Document.dataset_id == self.id)
)
@property
def doc_form(self) -> str | None:
if self.chunk_structure:
return self.chunk_structure
document = db.session.query(Document).where(Document.dataset_id == self.id).first()
document = db.session.scalar(select(Document).where(Document.dataset_id == self.id).limit(1))
if document:
return document.doc_form
return None
@ -270,8 +283,8 @@ class Dataset(Base):
@property
def tags(self):
tags = (
db.session.query(Tag)
tags = db.session.scalars(
select(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
.where(
TagBinding.target_id == self.id,
@ -279,8 +292,7 @@ class Dataset(Base):
Tag.tenant_id == self.tenant_id,
Tag.type == "knowledge",
)
.all()
)
).all()
return tags or []
@ -288,8 +300,8 @@ class Dataset(Base):
def external_knowledge_info(self):
if self.provider != "external":
return None
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id).first()
external_knowledge_binding = db.session.scalar(
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id)
)
if not external_knowledge_binding:
return None
@ -310,7 +322,7 @@ class Dataset(Base):
@property
def is_published(self):
if self.pipeline_id:
pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first()
pipeline = db.session.scalar(select(Pipeline).where(Pipeline.id == self.pipeline_id))
if pipeline:
return pipeline.is_published
return False
@ -382,7 +394,7 @@ class DatasetProcessRule(Base): # bug
id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
dataset_id = mapped_column(StringUUID, nullable=False)
mode = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'"))
mode = mapped_column(EnumText(ProcessRuleMode, length=255), nullable=False, server_default=sa.text("'automatic'"))
rules = mapped_column(LongText, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@ -428,12 +440,12 @@ class Document(Base):
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
data_source_type: Mapped[str] = mapped_column(String(255), nullable=False)
data_source_type: Mapped[str] = mapped_column(EnumText(DataSourceType, length=255), nullable=False)
data_source_info = mapped_column(LongText, nullable=True)
dataset_process_rule_id = mapped_column(StringUUID, nullable=True)
batch: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
created_from: Mapped[str] = mapped_column(EnumText(DocumentCreatedFrom, length=255), nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_api_request_id = mapped_column(StringUUID, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@ -467,7 +479,9 @@ class Document(Base):
stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
# basic fields
indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'"))
indexing_status = mapped_column(
EnumText(IndexingStatus, length=255), nullable=False, server_default=sa.text("'waiting'")
)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
@ -478,7 +492,7 @@ class Document(Base):
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
doc_type = mapped_column(String(40), nullable=True)
doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True)
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
doc_language = mapped_column(String(255), nullable=True)
@ -521,10 +535,8 @@ class Document(Base):
if self.data_source_info:
if self.data_source_type == "upload_file":
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.id == data_source_info_dict["upload_file_id"])
.one_or_none()
file_detail = db.session.scalar(
select(UploadFile).where(UploadFile.id == data_source_info_dict["upload_file_id"])
)
if file_detail:
return {
@ -557,24 +569,23 @@ class Document(Base):
@property
def dataset(self):
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).one_or_none()
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
@property
def segment_count(self):
return db.session.query(DocumentSegment).where(DocumentSegment.document_id == self.id).count()
return (
db.session.scalar(select(func.count(DocumentSegment.id)).where(DocumentSegment.document_id == self.id)) or 0
)
@property
def hit_count(self):
return (
db.session.query(DocumentSegment)
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
.where(DocumentSegment.document_id == self.id)
.scalar()
return db.session.scalar(
select(func.coalesce(func.sum(DocumentSegment.hit_count), 0)).where(DocumentSegment.document_id == self.id)
)
@property
def uploader(self):
user = db.session.query(Account).where(Account.id == self.created_by).first()
user = db.session.scalar(select(Account).where(Account.id == self.created_by))
return user.name if user else None
@property
@ -588,14 +599,13 @@ class Document(Base):
@property
def doc_metadata_details(self) -> list[DocMetadataDetailItem] | None:
if self.doc_metadata:
document_metadatas = (
db.session.query(DatasetMetadata)
document_metadatas = db.session.scalars(
select(DatasetMetadata)
.join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
.where(
DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
)
.all()
)
).all()
metadata_list: list[DocMetadataDetailItem] = []
for metadata in document_metadatas:
metadata_dict: DocMetadataDetailItem = {
@ -791,7 +801,7 @@ class DocumentSegment(Base):
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'"))
status: Mapped[str] = mapped_column(EnumText(SegmentStatus, length=255), server_default=sa.text("'waiting'"))
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
@ -826,7 +836,7 @@ class DocumentSegment(Base):
)
@property
def child_chunks(self) -> list[Any]:
def child_chunks(self) -> Sequence[Any]:
if not self.document:
return []
process_rule = self.document.dataset_process_rule
@ -835,16 +845,13 @@ class DocumentSegment(Base):
if rules_dict:
rules = Rule.model_validate(rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
child_chunks = db.session.scalars(
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
).all()
return child_chunks or []
return []
def get_child_chunks(self) -> list[Any]:
def get_child_chunks(self) -> Sequence[Any]:
if not self.document:
return []
process_rule = self.document.dataset_process_rule
@ -853,12 +860,9 @@ class DocumentSegment(Base):
if rules_dict:
rules = Rule.model_validate(rules_dict)
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
child_chunks = db.session.scalars(
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
).all()
return child_chunks or []
return []
@ -1007,15 +1011,15 @@ class ChildChunk(Base):
@property
def dataset(self):
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).first()
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
@property
def document(self):
return db.session.query(Document).where(Document.id == self.document_id).first()
return db.session.scalar(select(Document).where(Document.id == self.document_id))
@property
def segment(self):
return db.session.query(DocumentSegment).where(DocumentSegment.id == self.segment_id).first()
return db.session.scalar(select(DocumentSegment).where(DocumentSegment.id == self.segment_id))
class AppDatasetJoin(TypeBase):
@ -1061,7 +1065,7 @@ class DatasetQuery(TypeBase):
)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False)
source: Mapped[str] = mapped_column(String(255), nullable=False)
source: Mapped[str] = mapped_column(EnumText(DatasetQuerySource, length=255), nullable=False)
source_app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -1076,7 +1080,7 @@ class DatasetQuery(TypeBase):
if isinstance(queries, list):
for query in queries:
if query["content_type"] == QueryType.IMAGE_QUERY:
file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first()
file_info = db.session.scalar(select(UploadFile).where(UploadFile.id == query["content"]))
if file_info:
query["file_info"] = {
"id": file_info.id,
@ -1141,7 +1145,7 @@ class DatasetKeywordTable(TypeBase):
super().__init__(object_hook=object_hook, *args, **kwargs)
# get dataset
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
dataset = db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
if not dataset:
return None
if self.data_source_type == "database":
@ -1206,7 +1210,9 @@ class DatasetCollectionBinding(TypeBase):
)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[str] = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False)
type: Mapped[str] = mapped_column(
EnumText(CollectionBindingType, length=40), server_default=sa.text("'dataset'"), nullable=False
)
collection_name: Mapped[str] = mapped_column(String(64), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
@ -1433,7 +1439,7 @@ class DatasetMetadata(TypeBase):
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[str] = mapped_column(EnumText(DatasetMetadataType, length=255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
@ -1535,7 +1541,7 @@ class PipelineCustomizedTemplate(TypeBase):
@property
def created_user_name(self):
account = db.session.query(Account).where(Account.id == self.created_by).first()
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
if account:
return account.name
return ""
@ -1570,7 +1576,7 @@ class Pipeline(TypeBase):
)
def retrieve_dataset(self, session: Session):
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
return session.scalar(select(Dataset).where(Dataset.pipeline_id == self.id))
class DocumentPipelineExecutionLog(TypeBase):
@ -1660,7 +1666,9 @@ class DocumentSegmentSummary(Base):
summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'"))
status: Mapped[str] = mapped_column(
EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'")
)
error: Mapped[str] = mapped_column(LongText, nullable=True)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)

View File

@ -11,6 +11,13 @@ class CreatorUserRole(StrEnum):
ACCOUNT = "account"
END_USER = "end_user"
@classmethod
def _missing_(cls, value):
if value == "end-user":
return cls.END_USER
else:
return super()._missing_(value)
class WorkflowRunTriggeredFrom(StrEnum):
DEBUGGING = "debugging"
@ -96,3 +103,216 @@ class ConversationStatus(StrEnum):
"""Conversation Status Enum"""
NORMAL = "normal"
class DataSourceType(StrEnum):
"""Data Source Type for Dataset and Document"""
UPLOAD_FILE = "upload_file"
NOTION_IMPORT = "notion_import"
WEBSITE_CRAWL = "website_crawl"
LOCAL_FILE = "local_file"
ONLINE_DOCUMENT = "online_document"
class ProcessRuleMode(StrEnum):
"""Dataset Process Rule Mode"""
AUTOMATIC = "automatic"
CUSTOM = "custom"
HIERARCHICAL = "hierarchical"
class IndexingStatus(StrEnum):
"""Document Indexing Status"""
WAITING = "waiting"
PARSING = "parsing"
CLEANING = "cleaning"
SPLITTING = "splitting"
INDEXING = "indexing"
PAUSED = "paused"
COMPLETED = "completed"
ERROR = "error"
class DocumentCreatedFrom(StrEnum):
"""Document Created From"""
WEB = "web"
API = "api"
RAG_PIPELINE = "rag-pipeline"
class ConversationFromSource(StrEnum):
"""Conversation / Message from_source"""
API = "api"
CONSOLE = "console"
class FeedbackFromSource(StrEnum):
"""MessageFeedback from_source"""
USER = "user"
ADMIN = "admin"
class InvokeFrom(StrEnum):
"""How a conversation/message was invoked"""
SERVICE_API = "service-api"
WEB_APP = "web-app"
TRIGGER = "trigger"
EXPLORE = "explore"
DEBUGGER = "debugger"
PUBLISHED_PIPELINE = "published"
VALIDATION = "validation"
@classmethod
def value_of(cls, value: str) -> "InvokeFrom":
return cls(value)
def to_source(self) -> str:
source_mapping = {
InvokeFrom.WEB_APP: "web_app",
InvokeFrom.DEBUGGER: "dev",
InvokeFrom.EXPLORE: "explore_app",
InvokeFrom.TRIGGER: "trigger",
InvokeFrom.SERVICE_API: "api",
}
return source_mapping.get(self, "dev")
class DocumentDocType(StrEnum):
"""Document doc_type classification"""
BOOK = "book"
WEB_PAGE = "web_page"
PAPER = "paper"
SOCIAL_MEDIA_POST = "social_media_post"
WIKIPEDIA_ENTRY = "wikipedia_entry"
PERSONAL_DOCUMENT = "personal_document"
BUSINESS_DOCUMENT = "business_document"
IM_CHAT_LOG = "im_chat_log"
SYNCED_FROM_NOTION = "synced_from_notion"
SYNCED_FROM_GITHUB = "synced_from_github"
OTHERS = "others"
class TagType(StrEnum):
"""Tag type"""
KNOWLEDGE = "knowledge"
APP = "app"
class DatasetMetadataType(StrEnum):
"""Dataset metadata value type"""
STRING = "string"
NUMBER = "number"
TIME = "time"
class SegmentStatus(StrEnum):
"""Document segment status"""
WAITING = "waiting"
INDEXING = "indexing"
COMPLETED = "completed"
ERROR = "error"
PAUSED = "paused"
RE_SEGMENT = "re_segment"
class DatasetRuntimeMode(StrEnum):
"""Dataset runtime mode"""
GENERAL = "general"
RAG_PIPELINE = "rag_pipeline"
class CollectionBindingType(StrEnum):
"""Dataset collection binding type"""
DATASET = "dataset"
ANNOTATION = "annotation"
class DatasetQuerySource(StrEnum):
"""Dataset query source"""
HIT_TESTING = "hit_testing"
APP = "app"
class TidbAuthBindingStatus(StrEnum):
"""TiDB auth binding status"""
CREATING = "CREATING"
ACTIVE = "ACTIVE"
class MessageFileBelongsTo(StrEnum):
"""MessageFile belongs_to"""
USER = "user"
ASSISTANT = "assistant"
class CredentialSourceType(StrEnum):
"""Load balancing credential source type"""
PROVIDER = "provider"
CUSTOM_MODEL = "custom_model"
class PaymentStatus(StrEnum):
"""Provider order payment status"""
WAIT_PAY = "wait_pay"
PAID = "paid"
FAILED = "failed"
REFUNDED = "refunded"
class BannerStatus(StrEnum):
"""ExporleBanner status"""
ENABLED = "enabled"
DISABLED = "disabled"
class SummaryStatus(StrEnum):
"""Document segment summary status"""
NOT_STARTED = "not_started"
GENERATING = "generating"
COMPLETED = "completed"
ERROR = "error"
TIMEOUT = "timeout"
class MessageChainType(StrEnum):
"""Message chain type"""
SYSTEM = "system"
class ProviderQuotaType(StrEnum):
PAID = "paid"
"""hosted paid quota"""
FREE = "free"
"""third-party free quota"""
TRIAL = "trial"
"""hosted trial quota"""
@staticmethod
def value_of(value: str) -> "ProviderQuotaType":
for member in ProviderQuotaType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")

View File

@ -380,13 +380,12 @@ class App(Base):
@property
def site(self) -> Site | None:
site = db.session.query(Site).where(Site.app_id == self.id).first()
return site
return db.session.scalar(select(Site).where(Site.app_id == self.id))
@property
def app_model_config(self) -> AppModelConfig | None:
if self.app_model_config_id:
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
return db.session.scalar(select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id))
return None
@ -395,7 +394,7 @@ class App(Base):
if self.workflow_id:
from .workflow import Workflow
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
return None
@ -405,8 +404,7 @@ class App(Base):
@property
def tenant(self) -> Tenant | None:
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
@property
def is_agent(self) -> bool:
@ -546,9 +544,9 @@ class App(Base):
return deleted_tools
@property
def tags(self) -> list[Tag]:
tags = (
db.session.query(Tag)
def tags(self) -> Sequence[Tag]:
tags = db.session.scalars(
select(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
.where(
TagBinding.target_id == self.id,
@ -556,15 +554,14 @@ class App(Base):
Tag.tenant_id == self.tenant_id,
Tag.type == "app",
)
.all()
)
).all()
return tags or []
@property
def author_name(self) -> str | None:
if self.created_by:
account = db.session.query(Account).where(Account.id == self.created_by).first()
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
if account:
return account.name
@ -616,8 +613,7 @@ class AppModelConfig(TypeBase):
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
@property
def model_dict(self) -> ModelConfig:
@ -652,8 +648,8 @@ class AppModelConfig(TypeBase):
@property
def annotation_reply_dict(self) -> AnnotationReplyConfig:
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
annotation_setting = db.session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id)
)
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
@ -845,8 +841,7 @@ class RecommendedApp(Base): # bug
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
class InstalledApp(TypeBase):
@ -873,13 +868,11 @@ class InstalledApp(TypeBase):
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
@property
def tenant(self) -> Tenant | None:
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
class TrialApp(Base):
@ -899,8 +892,7 @@ class TrialApp(Base):
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
class AccountTrialAppRecord(Base):
@ -919,13 +911,11 @@ class AccountTrialAppRecord(Base):
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
@property
def user(self) -> Account | None:
user = db.session.query(Account).where(Account.id == self.account_id).first()
return user
return db.session.scalar(select(Account).where(Account.id == self.account_id))
class ExporleBanner(TypeBase):
@ -1117,8 +1107,8 @@ class Conversation(Base):
else:
model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key]
else:
app_model_config = (
db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
app_model_config = db.session.scalar(
select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id)
)
if app_model_config:
model_config = app_model_config.to_dict()
@ -1141,36 +1131,43 @@ class Conversation(Base):
@property
def annotated(self):
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).count() > 0
return (
db.session.scalar(
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.conversation_id == self.id)
)
or 0
) > 0
@property
def annotation(self):
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).first()
return db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).limit(1))
@property
def message_count(self):
return db.session.query(Message).where(Message.conversation_id == self.id).count()
return db.session.scalar(select(func.count(Message.id)).where(Message.conversation_id == self.id)) or 0
@property
def user_feedback_stats(self):
like = (
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "like",
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "like",
)
)
.count()
or 0
)
dislike = (
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "dislike",
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "dislike",
)
)
.count()
or 0
)
return {"like": like, "dislike": dislike}
@ -1178,23 +1175,25 @@ class Conversation(Base):
@property
def admin_feedback_stats(self):
like = (
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "like",
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "like",
)
)
.count()
or 0
)
dislike = (
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "dislike",
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "dislike",
)
)
.count()
or 0
)
return {"like": like, "dislike": dislike}
@ -1256,22 +1255,19 @@ class Conversation(Base):
@property
def first_message(self):
return (
db.session.query(Message)
.where(Message.conversation_id == self.id)
.order_by(Message.created_at.asc())
.first()
return db.session.scalar(
select(Message).where(Message.conversation_id == self.id).order_by(Message.created_at.asc())
)
@property
def app(self) -> App | None:
with Session(db.engine, expire_on_commit=False) as session:
return session.query(App).where(App.id == self.app_id).first()
return session.scalar(select(App).where(App.id == self.app_id))
@property
def from_end_user_session_id(self):
if self.from_end_user_id:
end_user = db.session.query(EndUser).where(EndUser.id == self.from_end_user_id).first()
end_user = db.session.scalar(select(EndUser).where(EndUser.id == self.from_end_user_id))
if end_user:
return end_user.session_id
@ -1280,7 +1276,7 @@ class Conversation(Base):
@property
def from_account_name(self) -> str | None:
if self.from_account_id:
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
account = db.session.scalar(select(Account).where(Account.id == self.from_account_id))
if account:
return account.name
@ -1505,21 +1501,15 @@ class Message(Base):
@property
def user_feedback(self):
feedback = (
db.session.query(MessageFeedback)
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
.first()
return db.session.scalar(
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
)
return feedback
@property
def admin_feedback(self):
feedback = (
db.session.query(MessageFeedback)
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
.first()
return db.session.scalar(
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
)
return feedback
@property
def feedbacks(self):
@ -1528,28 +1518,27 @@ class Message(Base):
@property
def annotation(self):
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == self.id).first()
annotation = db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.message_id == self.id))
return annotation
@property
def annotation_hit_history(self):
annotation_history = (
db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id).first()
annotation_history = db.session.scalar(
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id)
)
if annotation_history:
annotation = (
db.session.query(MessageAnnotation)
.where(MessageAnnotation.id == annotation_history.annotation_id)
.first()
return db.session.scalar(
select(MessageAnnotation).where(MessageAnnotation.id == annotation_history.annotation_id)
)
return annotation
return None
@property
def app_model_config(self):
conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first()
conversation = db.session.scalar(select(Conversation).where(Conversation.id == self.conversation_id))
if conversation:
return db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first()
return db.session.scalar(
select(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id)
)
return None
@ -1562,13 +1551,12 @@ class Message(Base):
return json.loads(self.message_metadata) if self.message_metadata else {}
@property
def agent_thoughts(self) -> list[MessageAgentThought]:
return (
db.session.query(MessageAgentThought)
def agent_thoughts(self) -> Sequence[MessageAgentThought]:
return db.session.scalars(
select(MessageAgentThought)
.where(MessageAgentThought.message_id == self.id)
.order_by(MessageAgentThought.position.asc())
.all()
)
).all()
@property
def retriever_resources(self) -> Any:
@ -1579,7 +1567,7 @@ class Message(Base):
from factories import file_factory
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
current_app = db.session.query(App).where(App.id == self.app_id).first()
current_app = db.session.scalar(select(App).where(App.id == self.app_id))
if not current_app:
raise ValueError(f"App {self.app_id} not found")
@ -1743,8 +1731,7 @@ class MessageFeedback(TypeBase):
@property
def from_account(self) -> Account | None:
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
return account
return db.session.scalar(select(Account).where(Account.id == self.from_account_id))
def to_dict(self) -> MessageFeedbackDict:
return {
@ -1817,13 +1804,11 @@ class MessageAnnotation(Base):
@property
def account(self):
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
return db.session.scalar(select(Account).where(Account.id == self.account_id))
@property
def annotation_create_account(self):
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
return db.session.scalar(select(Account).where(Account.id == self.account_id))
class AppAnnotationHitHistory(TypeBase):
@ -1852,18 +1837,15 @@ class AppAnnotationHitHistory(TypeBase):
@property
def account(self):
account = (
db.session.query(Account)
return db.session.scalar(
select(Account)
.join(MessageAnnotation, MessageAnnotation.account_id == Account.id)
.where(MessageAnnotation.id == self.annotation_id)
.first()
)
return account
@property
def annotation_create_account(self):
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
return db.session.scalar(select(Account).where(Account.id == self.account_id))
class AppAnnotationSetting(TypeBase):
@ -1896,12 +1878,9 @@ class AppAnnotationSetting(TypeBase):
def collection_binding_detail(self):
from .dataset import DatasetCollectionBinding
collection_binding_detail = (
db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == self.collection_binding_id)
.first()
return db.session.scalar(
select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == self.collection_binding_id)
)
return collection_binding_detail
class OperationLog(TypeBase):
@ -2007,7 +1986,9 @@ class AppMCPServer(TypeBase):
def generate_server_code(n: int) -> str:
while True:
result = generate_string(n)
while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0:
while (
db.session.scalar(select(func.count(AppMCPServer.id)).where(AppMCPServer.server_code == result)) or 0
) > 0:
result = generate_string(n)
return result
@ -2068,7 +2049,7 @@ class Site(Base):
def generate_code(n: int) -> str:
while True:
result = generate_string(n)
while db.session.query(Site).where(Site.code == result).count() > 0:
while (db.session.scalar(select(func.count(Site.id)).where(Site.code == result)) or 0) > 0:
result = generate_string(n)
return result

View File

@ -6,13 +6,14 @@ from functools import cached_property
from uuid import uuid4
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func, text
from sqlalchemy import DateTime, String, func, select, text
from sqlalchemy.orm import Mapped, mapped_column
from libs.uuid_utils import uuidv7
from .base import TypeBase
from .engine import db
from .enums import CredentialSourceType, PaymentStatus
from .types import EnumText, LongText, StringUUID
@ -96,7 +97,7 @@ class Provider(TypeBase):
@cached_property
def credential(self):
if self.credential_id:
return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first()
return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
@property
def credential_name(self):
@ -159,10 +160,8 @@ class ProviderModel(TypeBase):
@cached_property
def credential(self):
if self.credential_id:
return (
db.session.query(ProviderModelCredential)
.where(ProviderModelCredential.id == self.credential_id)
.first()
return db.session.scalar(
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
)
@property
@ -239,7 +238,9 @@ class ProviderOrder(TypeBase):
quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1"))
currency: Mapped[str | None] = mapped_column(String(40))
total_amount: Mapped[int | None] = mapped_column(sa.Integer)
payment_status: Mapped[str] = mapped_column(String(40), nullable=False, server_default=text("'wait_pay'"))
payment_status: Mapped[PaymentStatus] = mapped_column(
EnumText(PaymentStatus, length=40), nullable=False, server_default=text("'wait_pay'")
)
paid_at: Mapped[datetime | None] = mapped_column(DateTime)
pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime)
refunded_at: Mapped[datetime | None] = mapped_column(DateTime)
@ -302,7 +303,9 @@ class LoadBalancingModelConfig(TypeBase):
name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True, default=None)
credential_source_type: Mapped[CredentialSourceType | None] = mapped_column(
EnumText(CredentialSourceType, length=40), nullable=True, default=None
)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False

View File

@ -8,7 +8,7 @@ from uuid import uuid4
import sqlalchemy as sa
from deprecated import deprecated
from sqlalchemy import ForeignKey, String, func
from sqlalchemy import ForeignKey, String, func, select
from sqlalchemy.orm import Mapped, mapped_column
from core.tools.entities.common_entities import I18nObject
@ -184,11 +184,11 @@ class ApiToolProvider(TypeBase):
def user(self) -> Account | None:
if not self.user_id:
return None
return db.session.query(Account).where(Account.id == self.user_id).first()
return db.session.scalar(select(Account).where(Account.id == self.user_id))
@property
def tenant(self) -> Tenant | None:
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
class ToolLabelBinding(TypeBase):
@ -262,11 +262,11 @@ class WorkflowToolProvider(TypeBase):
@property
def user(self) -> Account | None:
return db.session.query(Account).where(Account.id == self.user_id).first()
return db.session.scalar(select(Account).where(Account.id == self.user_id))
@property
def tenant(self) -> Tenant | None:
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
@property
def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
@ -277,7 +277,7 @@ class WorkflowToolProvider(TypeBase):
@property
def app(self) -> App | None:
return db.session.query(App).where(App.id == self.app_id).first()
return db.session.scalar(select(App).where(App.id == self.app_id))
class MCPToolProvider(TypeBase):
@ -334,7 +334,7 @@ class MCPToolProvider(TypeBase):
encrypted_headers: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
def load_user(self) -> Account | None:
return db.session.query(Account).where(Account.id == self.user_id).first()
return db.session.scalar(select(Account).where(Account.id == self.user_id))
@property
def credentials(self) -> dict[str, Any]:

View File

@ -2,7 +2,7 @@ from datetime import datetime
from uuid import uuid4
import sqlalchemy as sa
from sqlalchemy import DateTime, func
from sqlalchemy import DateTime, func, select
from sqlalchemy.orm import Mapped, mapped_column
from .base import TypeBase
@ -38,7 +38,7 @@ class SavedMessage(TypeBase):
@property
def message(self):
return db.session.query(Message).where(Message.id == self.message_id).first()
return db.session.scalar(select(Message).where(Message.id == self.message_id))
class PinnedConversation(TypeBase):

View File

@ -679,14 +679,14 @@ class WorkflowRun(Base):
def message(self):
from .model import Message
return (
db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first()
return db.session.scalar(
select(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id)
)
@property
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
def workflow(self):
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
def to_dict(self):
return {

View File

@ -182,4 +182,9 @@ tasks/app_generate/workflow_execute_task.py
tasks/regenerate_summary_index_task.py
tasks/trigger_processing_tasks.py
tasks/workflow_cfs_scheduler/cfs_scheduler.py
tasks/add_document_to_index_task.py
tasks/create_segment_to_index_task.py
tasks/disable_segment_from_index_task.py
tasks/enable_segment_to_index_task.py
tasks/remove_document_from_index_task.py
tasks/workflow_execution_tasks.py

View File

@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from libs.login import current_user
from models import Account
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
from models.model import App, Conversation, EndUser, Message
class AgentService:
@ -47,7 +47,7 @@ class AgentService:
if not message:
raise ValueError(f"Message not found: {message_id}")
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
agent_thoughts = message.agent_thoughts
if conversation.from_end_user_id:
# only select name field

View File

@ -51,6 +51,14 @@ from models.dataset import (
Pipeline,
SegmentAttachmentBinding,
)
from models.enums import (
DatasetRuntimeMode,
DataSourceType,
DocumentCreatedFrom,
IndexingStatus,
ProcessRuleMode,
SegmentStatus,
)
from models.model import UploadFile
from models.provider_ids import ModelProviderID
from models.source import DataSourceOauthBinding
@ -319,7 +327,7 @@ class DatasetService:
description=rag_pipeline_dataset_create_entity.description,
permission=rag_pipeline_dataset_create_entity.permission,
provider="vendor",
runtime_mode="rag_pipeline",
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
created_by=current_user.id,
pipeline_id=pipeline.id,
@ -614,7 +622,7 @@ class DatasetService:
"""
Update pipeline knowledge base node data.
"""
if dataset.runtime_mode != "rag_pipeline":
if dataset.runtime_mode != DatasetRuntimeMode.RAG_PIPELINE:
return
pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first()
@ -1229,10 +1237,15 @@ class DocumentService:
"enabled": "available",
}
_INDEXING_STATUSES: tuple[str, ...] = ("parsing", "cleaning", "splitting", "indexing")
_INDEXING_STATUSES: tuple[IndexingStatus, ...] = (
IndexingStatus.PARSING,
IndexingStatus.CLEANING,
IndexingStatus.SPLITTING,
IndexingStatus.INDEXING,
)
DISPLAY_STATUS_FILTERS: dict[str, tuple[Any, ...]] = {
"queuing": (Document.indexing_status == "waiting",),
"queuing": (Document.indexing_status == IndexingStatus.WAITING,),
"indexing": (
Document.indexing_status.in_(_INDEXING_STATUSES),
Document.is_paused.is_not(True),
@ -1241,19 +1254,19 @@ class DocumentService:
Document.indexing_status.in_(_INDEXING_STATUSES),
Document.is_paused.is_(True),
),
"error": (Document.indexing_status == "error",),
"error": (Document.indexing_status == IndexingStatus.ERROR,),
"available": (
Document.indexing_status == "completed",
Document.indexing_status == IndexingStatus.COMPLETED,
Document.archived.is_(False),
Document.enabled.is_(True),
),
"disabled": (
Document.indexing_status == "completed",
Document.indexing_status == IndexingStatus.COMPLETED,
Document.archived.is_(False),
Document.enabled.is_(False),
),
"archived": (
Document.indexing_status == "completed",
Document.indexing_status == IndexingStatus.COMPLETED,
Document.archived.is_(True),
),
}
@ -1536,7 +1549,7 @@ class DocumentService:
"""
Normalize and validate `Document -> UploadFile` linkage for download flows.
"""
if document.data_source_type != "upload_file":
if document.data_source_type != DataSourceType.UPLOAD_FILE:
raise NotFound(invalid_source_message)
data_source_info: dict[str, Any] = document.data_source_info_dict or {}
@ -1617,7 +1630,7 @@ class DocumentService:
select(Document).where(
Document.id.in_(document_ids),
Document.enabled == True,
Document.indexing_status == "completed",
Document.indexing_status == IndexingStatus.COMPLETED,
Document.archived == False,
)
).all()
@ -1640,7 +1653,7 @@ class DocumentService:
select(Document).where(
Document.dataset_id == dataset_id,
Document.enabled == True,
Document.indexing_status == "completed",
Document.indexing_status == IndexingStatus.COMPLETED,
Document.archived == False,
)
).all()
@ -1650,7 +1663,10 @@ class DocumentService:
@staticmethod
def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
documents = db.session.scalars(
select(Document).where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
select(Document).where(
Document.dataset_id == dataset_id,
Document.indexing_status.in_([IndexingStatus.ERROR, IndexingStatus.PAUSED]),
)
).all()
return documents
@ -1683,7 +1699,7 @@ class DocumentService:
def delete_document(document):
# trigger document_was_deleted signal
file_id = None
if document.data_source_type == "upload_file":
if document.data_source_type == DataSourceType.UPLOAD_FILE:
if document.data_source_info:
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
@ -1704,7 +1720,7 @@ class DocumentService:
file_ids = [
document.data_source_info_dict.get("upload_file_id", "")
for document in documents
if document.data_source_type == "upload_file" and document.data_source_info_dict
if document.data_source_type == DataSourceType.UPLOAD_FILE and document.data_source_info_dict
]
# Delete documents first, then dispatch cleanup task after commit
@ -1753,7 +1769,13 @@ class DocumentService:
@staticmethod
def pause_document(document):
if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}:
if document.indexing_status not in {
IndexingStatus.WAITING,
IndexingStatus.PARSING,
IndexingStatus.CLEANING,
IndexingStatus.SPLITTING,
IndexingStatus.INDEXING,
}:
raise DocumentIndexingError()
# update document to be paused
assert current_user is not None
@ -1793,7 +1815,7 @@ class DocumentService:
if cache_result is not None:
raise ValueError("Document is being retried, please try again later")
# retry document indexing
document.indexing_status = "waiting"
document.indexing_status = IndexingStatus.WAITING
db.session.add(document)
db.session.commit()
@ -1812,7 +1834,7 @@ class DocumentService:
if cache_result is not None:
raise ValueError("Document is being synced, please try again later")
# sync document indexing
document.indexing_status = "waiting"
document.indexing_status = IndexingStatus.WAITING
data_source_info = document.data_source_info_dict
if data_source_info:
data_source_info["mode"] = "scrape"
@ -1840,7 +1862,7 @@ class DocumentService:
knowledge_config: KnowledgeConfig,
account: Account | Any,
dataset_process_rule: DatasetProcessRule | None = None,
created_from: str = "web",
created_from: str = DocumentCreatedFrom.WEB,
) -> tuple[list[Document], str]:
# check doc_form
DatasetService.check_doc_form(dataset, knowledge_config.doc_form)
@ -1932,7 +1954,7 @@ class DocumentService:
if not dataset_process_rule:
process_rule = knowledge_config.process_rule
if process_rule:
if process_rule.mode in ("custom", "hierarchical"):
if process_rule.mode in (ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL):
if process_rule.rules:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
@ -1944,7 +1966,7 @@ class DocumentService:
dataset_process_rule = dataset.latest_process_rule
if not dataset_process_rule:
raise ValueError("No process rule found.")
elif process_rule.mode == "automatic":
elif process_rule.mode == ProcessRuleMode.AUTOMATIC:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
@ -1967,7 +1989,7 @@ class DocumentService:
if not dataset_process_rule:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode="automatic",
mode=ProcessRuleMode.AUTOMATIC,
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
created_by=account.id,
)
@ -2001,7 +2023,7 @@ class DocumentService:
.where(
Document.dataset_id == dataset.id,
Document.tenant_id == current_user.current_tenant_id,
Document.data_source_type == "upload_file",
Document.data_source_type == DataSourceType.UPLOAD_FILE,
Document.enabled == True,
Document.name.in_(file_names),
)
@ -2021,7 +2043,7 @@ class DocumentService:
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
document.indexing_status = IndexingStatus.WAITING
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
@ -2056,7 +2078,7 @@ class DocumentService:
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
enabled=True,
)
.all()
@ -2507,7 +2529,7 @@ class DocumentService:
document_data: KnowledgeConfig,
account: Account,
dataset_process_rule: DatasetProcessRule | None = None,
created_from: str = "web",
created_from: str = DocumentCreatedFrom.WEB,
):
assert isinstance(current_user, Account)
@ -2520,14 +2542,14 @@ class DocumentService:
# save process rule
if document_data.process_rule:
process_rule = document_data.process_rule
if process_rule.mode in {"custom", "hierarchical"}:
if process_rule.mode in {ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL}:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
created_by=account.id,
)
elif process_rule.mode == "automatic":
elif process_rule.mode == ProcessRuleMode.AUTOMATIC:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
@ -2609,7 +2631,7 @@ class DocumentService:
if document_data.name:
document.name = document_data.name
# update document to be waiting
document.indexing_status = "waiting"
document.indexing_status = IndexingStatus.WAITING
document.completed_at = None
document.processing_started_at = None
document.parsing_completed_at = None
@ -2623,7 +2645,7 @@ class DocumentService:
# update document segment
db.session.query(DocumentSegment).filter_by(document_id=document.id).update(
{DocumentSegment.status: "re_segment"}
{DocumentSegment.status: SegmentStatus.RE_SEGMENT}
)
db.session.commit()
# trigger async task
@ -2754,7 +2776,7 @@ class DocumentService:
if knowledge_config.process_rule.mode not in DatasetProcessRule.MODES:
raise ValueError("Process rule mode is invalid")
if knowledge_config.process_rule.mode == "automatic":
if knowledge_config.process_rule.mode == ProcessRuleMode.AUTOMATIC:
knowledge_config.process_rule.rules = None
else:
if not knowledge_config.process_rule.rules:
@ -2785,7 +2807,7 @@ class DocumentService:
raise ValueError("Process rule segmentation separator is invalid")
if not (
knowledge_config.process_rule.mode == "hierarchical"
knowledge_config.process_rule.mode == ProcessRuleMode.HIERARCHICAL
and knowledge_config.process_rule.rules.parent_mode == "full-doc"
):
if not knowledge_config.process_rule.rules.segmentation.max_tokens:
@ -2814,7 +2836,7 @@ class DocumentService:
if args["process_rule"]["mode"] not in DatasetProcessRule.MODES:
raise ValueError("Process rule mode is invalid")
if args["process_rule"]["mode"] == "automatic":
if args["process_rule"]["mode"] == ProcessRuleMode.AUTOMATIC:
args["process_rule"]["rules"] = {}
else:
if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]:
@ -3021,7 +3043,7 @@ class DocumentService:
@staticmethod
def _prepare_disable_update(document, user, now):
"""Prepare updates for disabling a document."""
if not document.completed_at or document.indexing_status != "completed":
if not document.completed_at or document.indexing_status != IndexingStatus.COMPLETED:
raise DocumentIndexingError(f"Document: {document.name} is not completed.")
if not document.enabled:
@ -3130,7 +3152,7 @@ class SegmentService:
content=content,
word_count=len(content),
tokens=tokens,
status="completed",
status=SegmentStatus.COMPLETED,
indexing_at=naive_utc_now(),
completed_at=naive_utc_now(),
created_by=current_user.id,
@ -3167,7 +3189,7 @@ class SegmentService:
logger.exception("create segment index failed")
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.status = SegmentStatus.ERROR
segment_document.error = str(e)
db.session.commit()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
@ -3227,7 +3249,7 @@ class SegmentService:
word_count=len(content),
tokens=tokens,
keywords=segment_item.get("keywords", []),
status="completed",
status=SegmentStatus.COMPLETED,
indexing_at=naive_utc_now(),
completed_at=naive_utc_now(),
created_by=current_user.id,
@ -3259,7 +3281,7 @@ class SegmentService:
for segment_document in segment_data_list:
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.status = SegmentStatus.ERROR
segment_document.error = str(e)
db.session.commit()
return segment_data_list
@ -3405,7 +3427,7 @@ class SegmentService:
segment.index_node_hash = segment_hash
segment.word_count = len(content)
segment.tokens = tokens
segment.status = "completed"
segment.status = SegmentStatus.COMPLETED
segment.indexing_at = naive_utc_now()
segment.completed_at = naive_utc_now()
segment.updated_by = current_user.id
@ -3530,7 +3552,7 @@ class SegmentService:
logger.exception("update segment index failed")
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.status = SegmentStatus.ERROR
segment.error = str(e)
db.session.commit()
new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()

View File

@ -13,7 +13,7 @@ from dify_graph.model_runtime.entities import LLMMode
from extensions.ext_database import db
from models import Account
from models.dataset import Dataset, DatasetQuery
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, DatasetQuerySource
logger = logging.getLogger(__name__)
@ -97,7 +97,7 @@ class HitTestingService:
dataset_query = DatasetQuery(
dataset_id=dataset.id,
content=json.dumps(dataset_queries),
source="hit_testing",
source=DatasetQuerySource.HIT_TESTING,
source_app_id=None,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
@ -137,7 +137,7 @@ class HitTestingService:
dataset_query = DatasetQuery(
dataset_id=dataset.id,
content=query,
source="hit_testing",
source=DatasetQuerySource.HIT_TESTING,
source_app_id=None,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,

View File

@ -7,6 +7,7 @@ from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding
from models.enums import DatasetMetadataType
from services.dataset_service import DocumentService
from services.entities.knowledge_entities.knowledge_entities import (
MetadataArgs,
@ -130,11 +131,11 @@ class MetadataService:
@staticmethod
def get_built_in_fields():
return [
{"name": BuiltInField.document_name, "type": "string"},
{"name": BuiltInField.uploader, "type": "string"},
{"name": BuiltInField.upload_date, "type": "time"},
{"name": BuiltInField.last_update_date, "type": "time"},
{"name": BuiltInField.source, "type": "string"},
{"name": BuiltInField.document_name, "type": DatasetMetadataType.STRING},
{"name": BuiltInField.uploader, "type": DatasetMetadataType.STRING},
{"name": BuiltInField.upload_date, "type": DatasetMetadataType.TIME},
{"name": BuiltInField.last_update_date, "type": DatasetMetadataType.TIME},
{"name": BuiltInField.source, "type": DatasetMetadataType.STRING},
]
@staticmethod

View File

@ -19,6 +19,7 @@ from dify_graph.model_runtime.entities.provider_entities import (
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.enums import CredentialSourceType
from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential
logger = logging.getLogger(__name__)
@ -103,9 +104,9 @@ class ModelLoadBalancingService:
is_load_balancing_enabled = True
if config_from == "predefined-model":
credential_source_type = "provider"
credential_source_type = CredentialSourceType.PROVIDER
else:
credential_source_type = "custom_model"
credential_source_type = CredentialSourceType.CUSTOM_MODEL
# Get load balancing configurations
load_balancing_configs = (
@ -421,7 +422,11 @@ class ModelLoadBalancingService:
raise ValueError("Invalid load balancing config name")
if credential_id:
credential_source = "provider" if config_from == "predefined-model" else "custom_model"
credential_source = (
CredentialSourceType.PROVIDER
if config_from == "predefined-model"
else CredentialSourceType.CUSTOM_MODEL
)
assert credential_record is not None
load_balancing_model_config = LoadBalancingModelConfig(
tenant_id=tenant_id,

View File

@ -6,6 +6,7 @@ from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from models.dataset import Document, Pipeline
from models.enums import IndexingStatus
from models.model import Account, App, EndUser
from models.workflow import Workflow
from services.rag_pipeline.rag_pipeline import RagPipelineService
@ -111,6 +112,6 @@ class PipelineGenerateService:
"""
document = db.session.query(Document).where(Document.id == document_id).first()
if document:
document.indexing_status = "waiting"
document.indexing_status = IndexingStatus.WAITING
db.session.add(document)
db.session.commit()

View File

@ -15,7 +15,8 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
Retrieval recommended app from dify official
"""
def get_pipeline_template_detail(self, template_id: str):
def get_pipeline_template_detail(self, template_id: str) -> dict | None:
result: dict | None
try:
result = self.fetch_pipeline_template_detail_from_dify_official(template_id)
except Exception as e:
@ -35,17 +36,23 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return PipelineTemplateType.REMOTE
@classmethod
def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict | None:
def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict:
"""
Fetch pipeline template detail from dify official.
:param template_id: Pipeline ID
:return:
:param template_id: Pipeline template ID
:return: Template detail dict
:raises ValueError: When upstream returns a non-200 status code
"""
domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN
url = f"{domain}/pipeline-templates/{template_id}"
response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0))
if response.status_code != 200:
return None
raise ValueError(
"fetch pipeline template detail failed,"
+ f" status_code: {response.status_code},"
+ f" response: {response.text[:1000]}"
)
data: dict = response.json()
return data

View File

@ -64,7 +64,7 @@ from models.dataset import ( # type: ignore
PipelineCustomizedTemplate,
PipelineRecommendedPlugin,
)
from models.enums import WorkflowRunTriggeredFrom
from models.enums import IndexingStatus, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
Workflow,
@ -117,13 +117,21 @@ class RagPipelineService:
def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None:
"""
Get pipeline template detail.
:param template_id: template id
:return:
:param type: template type, "built-in" or "customized"
:return: template detail dict, or None if not found
"""
if type == "built-in":
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
if built_in_result is None:
logger.warning(
"pipeline template retrieval returned empty result, template_id: %s, mode: %s",
template_id,
mode,
)
return built_in_result
else:
mode = "customized"
@ -906,7 +914,7 @@ class RagPipelineService:
if document_id:
document = db.session.query(Document).where(Document.id == document_id.value).first()
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = error
db.session.add(document)
db.session.commit()

View File

@ -35,6 +35,7 @@ from extensions.ext_redis import redis_client
from factories import variable_factory
from models import Account
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
from models.enums import CollectionBindingType, DatasetRuntimeMode
from models.workflow import Workflow, WorkflowType
from services.entities.knowledge_entities.rag_pipeline_entities import (
IconInfo,
@ -313,7 +314,7 @@ class RagPipelineDslService:
indexing_technique=knowledge_configuration.indexing_technique,
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode="rag_pipeline",
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
chunk_structure=knowledge_configuration.chunk_structure,
)
if knowledge_configuration.indexing_technique == "high_quality":
@ -323,7 +324,7 @@ class RagPipelineDslService:
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset",
DatasetCollectionBinding.type == CollectionBindingType.DATASET,
)
.order_by(DatasetCollectionBinding.created_at)
.first()
@ -334,7 +335,7 @@ class RagPipelineDslService:
provider_name=knowledge_configuration.embedding_model_provider,
model_name=knowledge_configuration.embedding_model,
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
type="dataset",
type=CollectionBindingType.DATASET,
)
self._session.add(dataset_collection_binding)
self._session.commit()
@ -445,13 +446,13 @@ class RagPipelineDslService:
indexing_technique=knowledge_configuration.indexing_technique,
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode="rag_pipeline",
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
chunk_structure=knowledge_configuration.chunk_structure,
)
else:
dataset.indexing_technique = knowledge_configuration.indexing_technique
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
dataset.runtime_mode = "rag_pipeline"
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.indexing_technique == "high_quality":
dataset_collection_binding = (
@ -460,7 +461,7 @@ class RagPipelineDslService:
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset",
DatasetCollectionBinding.type == CollectionBindingType.DATASET,
)
.order_by(DatasetCollectionBinding.created_at)
.first()
@ -471,7 +472,7 @@ class RagPipelineDslService:
provider_name=knowledge_configuration.embedding_model_provider,
model_name=knowledge_configuration.embedding_model,
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
type="dataset",
type=CollectionBindingType.DATASET,
)
self._session.add(dataset_collection_binding)
self._session.commit()

View File

@ -13,6 +13,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from factories import variable_factory
from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import DatasetRuntimeMode, DataSourceType
from models.model import UploadFile
from models.workflow import Workflow, WorkflowType
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration, RetrievalSetting
@ -27,7 +28,7 @@ class RagPipelineTransformService:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset not found")
if dataset.pipeline_id and dataset.runtime_mode == "rag_pipeline":
if dataset.pipeline_id and dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE:
return {
"pipeline_id": dataset.pipeline_id,
"dataset_id": dataset_id,
@ -85,7 +86,7 @@ class RagPipelineTransformService:
else:
raise ValueError("Unsupported doc form")
dataset.runtime_mode = "rag_pipeline"
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
dataset.pipeline_id = pipeline.id
# deal document data
@ -102,7 +103,7 @@ class RagPipelineTransformService:
pipeline_yaml = {}
if doc_form == "text_model":
match datasource_type:
case "upload_file":
case DataSourceType.UPLOAD_FILE:
if indexing_technique == "high_quality":
# get graph from transform.file-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f:
@ -111,7 +112,7 @@ class RagPipelineTransformService:
# get graph from transform.file-general-economy.yml
with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case "notion_import":
case DataSourceType.NOTION_IMPORT:
if indexing_technique == "high_quality":
# get graph from transform.notion-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f:
@ -120,7 +121,7 @@ class RagPipelineTransformService:
# get graph from transform.notion-general-economy.yml
with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case "website_crawl":
case DataSourceType.WEBSITE_CRAWL:
if indexing_technique == "high_quality":
# get graph from transform.website-crawl-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f:
@ -133,15 +134,15 @@ class RagPipelineTransformService:
raise ValueError("Unsupported datasource type")
elif doc_form == "hierarchical_model":
match datasource_type:
case "upload_file":
case DataSourceType.UPLOAD_FILE:
# get graph from transform.file-parentchild.yml
with open(f"{Path(__file__).parent}/transform/file-parentchild.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case "notion_import":
case DataSourceType.NOTION_IMPORT:
# get graph from transform.notion-parentchild.yml
with open(f"{Path(__file__).parent}/transform/notion-parentchild.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case "website_crawl":
case DataSourceType.WEBSITE_CRAWL:
# get graph from transform.website-crawl-parentchild.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-parentchild.yml") as f:
pipeline_yaml = yaml.safe_load(f)
@ -287,7 +288,7 @@ class RagPipelineTransformService:
db.session.flush()
dataset.pipeline_id = pipeline.id
dataset.runtime_mode = "rag_pipeline"
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
dataset.updated_by = current_user.id
dataset.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.add(dataset)
@ -310,8 +311,8 @@ class RagPipelineTransformService:
data_source_info_dict = document.data_source_info_dict
if not data_source_info_dict:
continue
if document.data_source_type == "upload_file":
document.data_source_type = "local_file"
if document.data_source_type == DataSourceType.UPLOAD_FILE:
document.data_source_type = DataSourceType.LOCAL_FILE
file_id = data_source_info_dict.get("upload_file_id")
if file_id:
file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
@ -331,7 +332,7 @@ class RagPipelineTransformService:
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document.id,
pipeline_id=dataset.pipeline_id,
datasource_type="local_file",
datasource_type=DataSourceType.LOCAL_FILE,
datasource_info=data_source_info,
input_data={},
created_by=document.created_by,
@ -340,8 +341,8 @@ class RagPipelineTransformService:
document_pipeline_execution_log.created_at = document.created_at
db.session.add(document)
db.session.add(document_pipeline_execution_log)
elif document.data_source_type == "notion_import":
document.data_source_type = "online_document"
elif document.data_source_type == DataSourceType.NOTION_IMPORT:
document.data_source_type = DataSourceType.ONLINE_DOCUMENT
data_source_info = json.dumps(
{
"workspace_id": data_source_info_dict.get("notion_workspace_id"),
@ -359,7 +360,7 @@ class RagPipelineTransformService:
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document.id,
pipeline_id=dataset.pipeline_id,
datasource_type="online_document",
datasource_type=DataSourceType.ONLINE_DOCUMENT,
datasource_info=data_source_info,
input_data={},
created_by=document.created_by,
@ -368,8 +369,7 @@ class RagPipelineTransformService:
document_pipeline_execution_log.created_at = document.created_at
db.session.add(document)
db.session.add(document_pipeline_execution_log)
elif document.data_source_type == "website_crawl":
document.data_source_type = "website_crawl"
elif document.data_source_type == DataSourceType.WEBSITE_CRAWL:
data_source_info = json.dumps(
{
"source_url": data_source_info_dict.get("url"),
@ -388,7 +388,7 @@ class RagPipelineTransformService:
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document.id,
pipeline_id=dataset.pipeline_id,
datasource_type="website_crawl",
datasource_type=DataSourceType.WEBSITE_CRAWL,
datasource_info=data_source_info,
input_data={},
created_by=document.created_by,

View File

@ -12,12 +12,14 @@ from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.rag.models.document import Document
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.entities.model_entities import ModelType
from libs import helper
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
from models.dataset import Document as DatasetDocument
from models.enums import SummaryStatus
logger = logging.getLogger(__name__)
@ -29,7 +31,7 @@ class SummaryIndexService:
def generate_summary_for_segment(
segment: DocumentSegment,
dataset: Dataset,
summary_index_setting: dict,
summary_index_setting: SummaryIndexSettingDict,
) -> tuple[str, LLMUsage]:
"""
Generate summary for a single segment.
@ -73,7 +75,7 @@ class SummaryIndexService:
segment: DocumentSegment,
dataset: Dataset,
summary_content: str,
status: str = "generating",
status: SummaryStatus = SummaryStatus.GENERATING,
) -> DocumentSegmentSummary:
"""
Create or update a DocumentSegmentSummary record.
@ -83,7 +85,7 @@ class SummaryIndexService:
segment: DocumentSegment to create summary for
dataset: Dataset containing the segment
summary_content: Generated summary content
status: Summary status (default: "generating")
status: Summary status (default: SummaryStatus.GENERATING)
Returns:
Created or updated DocumentSegmentSummary instance
@ -326,7 +328,7 @@ class SummaryIndexService:
summary_index_node_id=summary_index_node_id,
summary_index_node_hash=summary_hash,
tokens=embedding_tokens,
status="completed",
status=SummaryStatus.COMPLETED,
enabled=True,
)
session.add(summary_record_in_session)
@ -362,7 +364,7 @@ class SummaryIndexService:
summary_record_in_session.summary_index_node_id = summary_index_node_id
summary_record_in_session.summary_index_node_hash = summary_hash
summary_record_in_session.tokens = embedding_tokens # Save embedding tokens
summary_record_in_session.status = "completed"
summary_record_in_session.status = SummaryStatus.COMPLETED
# Ensure summary_content is preserved (use the latest from summary_record parameter)
# This is critical: use the parameter value, not the database value
summary_record_in_session.summary_content = summary_content
@ -400,7 +402,7 @@ class SummaryIndexService:
summary_record.summary_index_node_id = summary_index_node_id
summary_record.summary_index_node_hash = summary_hash
summary_record.tokens = embedding_tokens
summary_record.status = "completed"
summary_record.status = SummaryStatus.COMPLETED
summary_record.summary_content = summary_content
if summary_record_in_session.updated_at:
summary_record.updated_at = summary_record_in_session.updated_at
@ -487,7 +489,7 @@ class SummaryIndexService:
)
if summary_record_in_session:
summary_record_in_session.status = "error"
summary_record_in_session.status = SummaryStatus.ERROR
summary_record_in_session.error = f"Vectorization failed: {str(e)}"
summary_record_in_session.updated_at = datetime.now(UTC).replace(tzinfo=None)
error_session.add(summary_record_in_session)
@ -498,7 +500,7 @@ class SummaryIndexService:
summary_record_in_session.id,
)
# Update the original object for consistency
summary_record.status = "error"
summary_record.status = SummaryStatus.ERROR
summary_record.error = summary_record_in_session.error
summary_record.updated_at = summary_record_in_session.updated_at
else:
@ -514,7 +516,7 @@ class SummaryIndexService:
def batch_create_summary_records(
segments: list[DocumentSegment],
dataset: Dataset,
status: str = "not_started",
status: SummaryStatus = SummaryStatus.NOT_STARTED,
) -> None:
"""
Batch create summary records for segments with specified status.
@ -523,7 +525,7 @@ class SummaryIndexService:
Args:
segments: List of DocumentSegment instances
dataset: Dataset containing the segments
status: Initial status for the records (default: "not_started")
status: Initial status for the records (default: SummaryStatus.NOT_STARTED)
"""
segment_ids = [segment.id for segment in segments]
if not segment_ids:
@ -588,7 +590,7 @@ class SummaryIndexService:
)
if summary_record:
summary_record.status = "error"
summary_record.status = SummaryStatus.ERROR
summary_record.error = error
session.add(summary_record)
session.commit()
@ -599,7 +601,7 @@ class SummaryIndexService:
def generate_and_vectorize_summary(
segment: DocumentSegment,
dataset: Dataset,
summary_index_setting: dict,
summary_index_setting: SummaryIndexSettingDict,
) -> DocumentSegmentSummary:
"""
Generate summary for a segment and vectorize it.
@ -631,14 +633,14 @@ class SummaryIndexService:
document_id=segment.document_id,
chunk_id=segment.id,
summary_content="",
status="generating",
status=SummaryStatus.GENERATING,
enabled=True,
)
session.add(summary_record_in_session)
session.flush()
# Update status to "generating"
summary_record_in_session.status = "generating"
summary_record_in_session.status = SummaryStatus.GENERATING
summary_record_in_session.error = None # type: ignore[assignment]
session.add(summary_record_in_session)
# Don't flush here - wait until after vectorization succeeds
@ -681,7 +683,7 @@ class SummaryIndexService:
except Exception as vectorize_error:
# If vectorization fails, update status to error in current session
logger.exception("Failed to vectorize summary for segment %s", segment.id)
summary_record_in_session.status = "error"
summary_record_in_session.status = SummaryStatus.ERROR
summary_record_in_session.error = f"Vectorization failed: {str(vectorize_error)}"
session.add(summary_record_in_session)
session.commit()
@ -694,7 +696,7 @@ class SummaryIndexService:
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
)
if summary_record_in_session:
summary_record_in_session.status = "error"
summary_record_in_session.status = SummaryStatus.ERROR
summary_record_in_session.error = str(e)
session.add(summary_record_in_session)
session.commit()
@ -704,7 +706,7 @@ class SummaryIndexService:
def generate_summaries_for_document(
dataset: Dataset,
document: DatasetDocument,
summary_index_setting: dict,
summary_index_setting: SummaryIndexSettingDict,
segment_ids: list[str] | None = None,
only_parent_chunks: bool = False,
) -> list[DocumentSegmentSummary]:
@ -770,7 +772,7 @@ class SummaryIndexService:
SummaryIndexService.batch_create_summary_records(
segments=segments,
dataset=dataset,
status="not_started",
status=SummaryStatus.NOT_STARTED,
)
summary_records = []
@ -1067,7 +1069,7 @@ class SummaryIndexService:
# Update summary content
summary_record.summary_content = summary_content
summary_record.status = "generating"
summary_record.status = SummaryStatus.GENERATING
summary_record.error = None # type: ignore[assignment] # Clear any previous errors
session.add(summary_record)
# Flush to ensure summary_content is saved before vectorize_summary queries it
@ -1102,7 +1104,7 @@ class SummaryIndexService:
# If vectorization fails, update status to error in current session
# Don't raise the exception - just log it and return the record with error status
# This allows the segment update to complete even if vectorization fails
summary_record.status = "error"
summary_record.status = SummaryStatus.ERROR
summary_record.error = f"Vectorization failed: {str(e)}"
session.commit()
logger.exception("Failed to vectorize summary for segment %s", segment.id)
@ -1112,7 +1114,7 @@ class SummaryIndexService:
else:
# Create new summary record if doesn't exist
summary_record = SummaryIndexService.create_summary_record(
segment, dataset, summary_content, status="generating"
segment, dataset, summary_content, status=SummaryStatus.GENERATING
)
# Re-vectorize summary (this will update status to "completed" and tokens in its own session)
# Note: summary_record was created in a different session,
@ -1132,7 +1134,7 @@ class SummaryIndexService:
# If vectorization fails, update status to error in current session
# Merge the record into current session first
error_record = session.merge(summary_record)
error_record.status = "error"
error_record.status = SummaryStatus.ERROR
error_record.error = f"Vectorization failed: {str(e)}"
session.commit()
logger.exception("Failed to vectorize summary for segment %s", segment.id)
@ -1146,7 +1148,7 @@ class SummaryIndexService:
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
)
if summary_record:
summary_record.status = "error"
summary_record.status = SummaryStatus.ERROR
summary_record.error = str(e)
session.add(summary_record)
session.commit()
@ -1266,7 +1268,7 @@ class SummaryIndexService:
# Check if there are any "not_started" or "generating" status summaries
has_pending_summaries = any(
summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True)
and summary_status_map[segment_id] in ("not_started", "generating")
and summary_status_map[segment_id] in (SummaryStatus.NOT_STARTED, SummaryStatus.GENERATING)
for segment_id in segment_ids
)
@ -1330,7 +1332,7 @@ class SummaryIndexService:
# it means the summary is disabled (enabled=False) or not created yet, ignore it
has_pending_summaries = any(
summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True)
and summary_status_map[segment_id] in ("not_started", "generating")
and summary_status_map[segment_id] in (SummaryStatus.NOT_STARTED, SummaryStatus.GENERATING)
for segment_id in segment_ids
)
@ -1393,17 +1395,17 @@ class SummaryIndexService:
# Count statuses
status_counts = {
"completed": 0,
"generating": 0,
"error": 0,
"not_started": 0,
SummaryStatus.COMPLETED: 0,
SummaryStatus.GENERATING: 0,
SummaryStatus.ERROR: 0,
SummaryStatus.NOT_STARTED: 0,
}
summary_list = []
for segment in segments:
summary = summary_map.get(segment.id)
if summary:
status = summary.status
status = SummaryStatus(summary.status)
status_counts[status] = status_counts.get(status, 0) + 1
summary_list.append(
{
@ -1421,12 +1423,12 @@ class SummaryIndexService:
}
)
else:
status_counts["not_started"] += 1
status_counts[SummaryStatus.NOT_STARTED] += 1
summary_list.append(
{
"segment_id": segment.id,
"segment_position": segment.position,
"status": "not_started",
"status": SummaryStatus.NOT_STARTED,
"summary_preview": None,
"error": None,
"created_at": None,

View File

@ -13,6 +13,7 @@ from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DatasetAutoDisableLog, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.enums import IndexingStatus, SegmentStatus
logger = logging.getLogger(__name__)
@ -34,7 +35,7 @@ def add_document_to_index_task(dataset_document_id: str):
logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
return
if dataset_document.indexing_status != "completed":
if dataset_document.indexing_status != IndexingStatus.COMPLETED:
return
indexing_cache_key = f"document_{dataset_document.id}_indexing"
@ -48,7 +49,7 @@ def add_document_to_index_task(dataset_document_id: str):
session.query(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
DocumentSegment.status == SegmentStatus.COMPLETED,
)
.order_by(DocumentSegment.position.asc())
.all()
@ -139,7 +140,7 @@ def add_document_to_index_task(dataset_document_id: str):
logger.exception("add document to index failed")
dataset_document.enabled = False
dataset_document.disabled_at = naive_utc_now()
dataset_document.indexing_status = "error"
dataset_document.indexing_status = IndexingStatus.ERROR
dataset_document.error = str(e)
session.commit()
finally:

View File

@ -11,6 +11,7 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset
from models.enums import CollectionBindingType
from models.model import App, AppAnnotationSetting, MessageAnnotation
from services.dataset_service import DatasetCollectionBindingService
@ -47,7 +48,7 @@ def enable_annotation_reply_task(
try:
documents = []
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name, embedding_model_name, "annotation"
embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION
)
annotation_setting = (
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
@ -56,7 +57,7 @@ def enable_annotation_reply_task(
if dataset_collection_binding.id != annotation_setting.collection_binding_id:
old_dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
annotation_setting.collection_binding_id, "annotation"
annotation_setting.collection_binding_id, CollectionBindingType.ANNOTATION
)
)
if old_dataset_collection_binding and annotations:

View File

@ -10,6 +10,7 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
from models.enums import IndexingStatus, SegmentStatus
logger = logging.getLogger(__name__)
@ -31,7 +32,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
if segment.status != "waiting":
if segment.status != SegmentStatus.WAITING:
return
indexing_cache_key = f"segment_{segment.id}_indexing"
@ -40,7 +41,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
# update segment status to indexing
session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: "indexing",
DocumentSegment.status: SegmentStatus.INDEXING,
DocumentSegment.indexing_at: naive_utc_now(),
}
)
@ -70,7 +71,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
if (
not dataset_document.enabled
or dataset_document.archived
or dataset_document.indexing_status != "completed"
or dataset_document.indexing_status != IndexingStatus.COMPLETED
):
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
@ -82,7 +83,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
# update segment to completed
session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: "completed",
DocumentSegment.status: SegmentStatus.COMPLETED,
DocumentSegment.completed_at: naive_utc_now(),
}
)
@ -94,7 +95,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
logger.exception("create segment to index failed")
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.status = SegmentStatus.ERROR
segment.error = str(e)
session.commit()
finally:

View File

@ -12,6 +12,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import IndexingStatus
from services.datasource_provider_service import DatasourceProviderService
logger = logging.getLogger(__name__)
@ -37,7 +38,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return
if document.indexing_status == "parsing":
if document.indexing_status == IndexingStatus.PARSING:
logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
return
@ -88,7 +89,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
document.stopped_at = naive_utc_now()
return
@ -128,7 +129,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
data_source_info["last_edited_time"] = last_edited_time
document.data_source_info = json.dumps(data_source_info)
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
@ -151,6 +152,6 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(e)
document.stopped_at = naive_utc_now()

View File

@ -14,6 +14,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
from models.enums import IndexingStatus
from services.feature_service import FeatureService
from tasks.generate_summary_index_task import generate_summary_index_task
@ -81,7 +82,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(e)
document.stopped_at = naive_utc_now()
session.add(document)
@ -96,7 +97,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
for document in documents:
if document:
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
session.add(document)
# Transaction committed and closed
@ -148,7 +149,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
document.need_summary,
)
if (
document.indexing_status == "completed"
document.indexing_status == IndexingStatus.COMPLETED
and document.doc_form != "qa_model"
and document.need_summary is True
):

View File

@ -10,6 +10,7 @@ from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import IndexingStatus
logger = logging.getLogger(__name__)
@ -33,7 +34,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()

View File

@ -15,6 +15,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import IndexingStatus
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@ -112,7 +113,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
)
for document in documents:
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(e)
document.stopped_at = naive_utc_now()
session.add(document)
@ -146,7 +147,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
session.execute(segment_delete_stmt)
session.commit()
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
session.add(document)
session.commit()

View File

@ -12,6 +12,7 @@ from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
from models.enums import IndexingStatus, SegmentStatus
logger = logging.getLogger(__name__)
@ -33,7 +34,7 @@ def enable_segment_to_index_task(segment_id: str):
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
if segment.status != "completed":
if segment.status != SegmentStatus.COMPLETED:
logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red"))
return
@ -65,7 +66,7 @@ def enable_segment_to_index_task(segment_id: str):
if (
not dataset_document.enabled
or dataset_document.archived
or dataset_document.indexing_status != "completed"
or dataset_document.indexing_status != IndexingStatus.COMPLETED
):
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
@ -123,7 +124,7 @@ def enable_segment_to_index_task(segment_id: str):
logger.exception("enable segment to index failed")
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.status = SegmentStatus.ERROR
segment.error = str(e)
session.commit()
finally:

View File

@ -12,6 +12,7 @@ from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models import Account, Tenant
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import IndexingStatus
from services.feature_service import FeatureService
from services.rag_pipeline.rag_pipeline import RagPipelineService
@ -63,7 +64,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
.first()
)
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(e)
document.stopped_at = naive_utc_now()
session.add(document)
@ -95,7 +96,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
session.execute(segment_delete_stmt)
session.commit()
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
session.add(document)
session.commit()
@ -108,7 +109,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
indexing_runner.run([document])
redis_client.delete(retry_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(ex)
document.stopped_at = naive_utc_now()
session.add(document)

View File

@ -11,6 +11,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import IndexingStatus
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@ -48,7 +49,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(e)
document.stopped_at = naive_utc_now()
session.add(document)
@ -76,7 +77,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
session.execute(segment_delete_stmt)
session.commit()
document.indexing_status = "parsing"
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
session.add(document)
session.commit()
@ -85,7 +86,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
indexing_runner.run([document])
redis_client.delete(sync_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
document.error = str(ex)
document.stopped_at = naive_utc_now()
session.add(document)

View File

@ -7,6 +7,7 @@ from faker import Faker
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.account_service import AccountService, TenantService
from tests.test_containers_integration_tests.helpers import generate_valid_password
@ -35,7 +36,7 @@ class TestGetAvailableDatasetsIntegration:
name=fake.company(),
description=fake.text(max_nb_chars=100),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
indexing_technique="high_quality",
)
@ -49,14 +50,14 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
name=f"Document {i}",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
)
@ -94,7 +95,7 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -106,13 +107,13 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
created_from="web",
created_from=DocumentCreatedFrom.WEB,
name=f"Archived Document {i}",
created_by=account.id,
doc_form="text_model",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=True, # Archived
)
@ -147,7 +148,7 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -159,13 +160,13 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
created_from="web",
created_from=DocumentCreatedFrom.WEB,
name=f"Disabled Document {i}",
created_by=account.id,
doc_form="text_model",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=False, # Disabled
archived=False,
)
@ -200,21 +201,21 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
# Create documents with non-completed status
for i, status in enumerate(["indexing", "parsing", "splitting"]):
for i, status in enumerate([IndexingStatus.INDEXING, IndexingStatus.PARSING, IndexingStatus.SPLITTING]):
document = Document(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
created_from="web",
created_from=DocumentCreatedFrom.WEB,
name=f"Document {status}",
created_by=account.id,
doc_form="text_model",
@ -263,7 +264,7 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="external", # External provider
data_source_type="external",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -307,7 +308,7 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant1.id,
name="Tenant 1 Dataset",
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account1.id,
)
db_session_with_containers.add(dataset1)
@ -318,7 +319,7 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant2.id,
name="Tenant 2 Dataset",
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account2.id,
)
db_session_with_containers.add(dataset2)
@ -330,13 +331,13 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
created_from="web",
created_from=DocumentCreatedFrom.WEB,
name=f"Document for {dataset.name}",
created_by=account.id,
doc_form="text_model",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
)
@ -398,7 +399,7 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
name=f"Dataset {i}",
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -410,13 +411,13 @@ class TestGetAvailableDatasetsIntegration:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
created_from="web",
created_from=DocumentCreatedFrom.WEB,
name=f"Document {i}",
created_by=account.id,
doc_form="text_model",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
)
@ -456,7 +457,7 @@ class TestKnowledgeRetrievalIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
indexing_technique="high_quality",
)
@ -467,12 +468,12 @@ class TestKnowledgeRetrievalIntegration:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=str(uuid.uuid4()), # Required field
created_from="web",
created_from=DocumentCreatedFrom.WEB,
name=fake.sentence(),
created_by=account.id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
doc_form="text_model",
@ -525,7 +526,7 @@ class TestKnowledgeRetrievalIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -572,7 +573,7 @@ class TestKnowledgeRetrievalIntegration:
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)

View File

@ -12,6 +12,7 @@ import pytest
from sqlalchemy.orm import Session
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
class TestDatasetDocumentProperties:
@ -29,7 +30,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -39,10 +40,10 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=i + 1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name=f"doc_{i}.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(doc)
@ -56,7 +57,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -65,12 +66,12 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="available.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
)
@ -78,12 +79,12 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=2,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="pending.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
enabled=True,
archived=False,
)
@ -91,12 +92,12 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=3,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="disabled.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=False,
archived=False,
)
@ -111,7 +112,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -121,10 +122,10 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=i + 1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name=f"doc_{i}.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
word_count=wc,
)
@ -139,7 +140,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -148,10 +149,10 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="doc.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(doc)
@ -166,7 +167,7 @@ class TestDatasetDocumentProperties:
content=f"segment {i}",
word_count=100,
tokens=50,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
created_by=created_by,
)
@ -180,7 +181,7 @@ class TestDatasetDocumentProperties:
content="waiting segment",
word_count=100,
tokens=50,
status="waiting",
status=SegmentStatus.WAITING,
enabled=True,
created_by=created_by,
)
@ -195,7 +196,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -204,10 +205,10 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="doc.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(doc)
@ -235,7 +236,7 @@ class TestDatasetDocumentProperties:
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
@ -244,10 +245,10 @@ class TestDatasetDocumentProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="doc.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(doc)
@ -288,7 +289,7 @@ class TestDocumentSegmentNavigationProperties:
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
db_session_with_containers.add(dataset)
@ -298,10 +299,10 @@ class TestDocumentSegmentNavigationProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(document)
@ -335,7 +336,7 @@ class TestDocumentSegmentNavigationProperties:
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
db_session_with_containers.add(dataset)
@ -345,10 +346,10 @@ class TestDocumentSegmentNavigationProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(document)
@ -382,7 +383,7 @@ class TestDocumentSegmentNavigationProperties:
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
db_session_with_containers.add(dataset)
@ -392,10 +393,10 @@ class TestDocumentSegmentNavigationProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(document)
@ -439,7 +440,7 @@ class TestDocumentSegmentNavigationProperties:
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
db_session_with_containers.add(dataset)
@ -449,10 +450,10 @@ class TestDocumentSegmentNavigationProperties:
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="batch_001",
name="test.pdf",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
)
db_session_with_containers.add(document)

View File

@ -12,6 +12,7 @@ import pytest
from sqlalchemy.orm import Session
from models.dataset import DatasetCollectionBinding
from models.enums import CollectionBindingType
from services.dataset_service import DatasetCollectionBindingService
@ -32,7 +33,7 @@ class DatasetCollectionBindingTestDataFactory:
provider_name: str = "openai",
model_name: str = "text-embedding-ada-002",
collection_name: str = "collection-abc",
collection_type: str = "dataset",
collection_type: str = CollectionBindingType.DATASET,
) -> DatasetCollectionBinding:
"""
Create a DatasetCollectionBinding with specified attributes.
@ -41,7 +42,7 @@ class DatasetCollectionBindingTestDataFactory:
provider_name: Name of the embedding model provider (e.g., "openai", "cohere")
model_name: Name of the embedding model (e.g., "text-embedding-ada-002")
collection_name: Name of the vector database collection
collection_type: Type of collection (default: "dataset")
collection_type: Type of collection (default: CollectionBindingType.DATASET)
Returns:
DatasetCollectionBinding instance
@ -76,7 +77,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
# Arrange
provider_name = "openai"
model_name = "text-embedding-ada-002"
collection_type = "dataset"
collection_type = CollectionBindingType.DATASET
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding(
db_session_with_containers,
provider_name=provider_name,
@ -104,7 +105,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
# Arrange
provider_name = f"provider-{uuid4()}"
model_name = f"model-{uuid4()}"
collection_type = "dataset"
collection_type = CollectionBindingType.DATASET
# Act
result = DatasetCollectionBindingService.get_dataset_collection_binding(
@ -145,7 +146,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
result = DatasetCollectionBindingService.get_dataset_collection_binding(provider_name, model_name)
# Assert
assert result.type == "dataset"
assert result.type == CollectionBindingType.DATASET
assert result.provider_name == provider_name
assert result.model_name == model_name
@ -186,18 +187,20 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
provider_name="openai",
model_name="text-embedding-ada-002",
collection_name="test-collection",
collection_type="dataset",
collection_type=CollectionBindingType.DATASET,
)
# Act
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(binding.id, "dataset")
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
binding.id, CollectionBindingType.DATASET
)
# Assert
assert result.id == binding.id
assert result.provider_name == "openai"
assert result.model_name == "text-embedding-ada-002"
assert result.collection_name == "test-collection"
assert result.type == "dataset"
assert result.type == CollectionBindingType.DATASET
def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers: Session):
"""Test error handling when collection binding is not found by ID and type."""
@ -206,7 +209,9 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
# Act & Assert
with pytest.raises(ValueError, match="Dataset collection binding not found"):
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(non_existent_id, "dataset")
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
non_existent_id, CollectionBindingType.DATASET
)
def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(
self, db_session_with_containers: Session
@ -240,7 +245,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
provider_name="openai",
model_name="text-embedding-ada-002",
collection_name="test-collection",
collection_type="dataset",
collection_type=CollectionBindingType.DATASET,
)
# Act
@ -248,7 +253,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
# Assert
assert result.id == binding.id
assert result.type == "dataset"
assert result.type == CollectionBindingType.DATASET
def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers: Session):
"""Test error when binding exists but with wrong collection type."""
@ -258,7 +263,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
provider_name="openai",
model_name="text-embedding-ada-002",
collection_name="test-collection",
collection_type="dataset",
collection_type=CollectionBindingType.DATASET,
)
# Act & Assert

View File

@ -15,6 +15,7 @@ from werkzeug.exceptions import NotFound
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum
from models.enums import DataSourceType
from models.model import App
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
@ -72,7 +73,7 @@ class DatasetUpdateDeleteTestDataFactory:
tenant_id=tenant_id,
name=name,
description="Test description",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=created_by,
permission=permission,

View File

@ -15,7 +15,7 @@ import pytest
from models import Account
from models.dataset import Dataset, Document
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus
from models.model import UploadFile
from services.dataset_service import DocumentService
from services.errors.document import DocumentIndexingError
@ -88,7 +88,7 @@ class DocumentStatusTestDataFactory:
data_source_info=json.dumps(data_source_info or {}),
batch=f"batch-{uuid4()}",
name=name,
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
doc_form="text_model",
)
@ -100,7 +100,7 @@ class DocumentStatusTestDataFactory:
document.paused_by = paused_by
document.paused_at = paused_at
document.doc_metadata = doc_metadata or {}
if indexing_status == "completed" and "completed_at" not in kwargs:
if indexing_status == IndexingStatus.COMPLETED and "completed_at" not in kwargs:
document.completed_at = FIXED_TIME
for key, value in kwargs.items():
@ -139,7 +139,7 @@ class DocumentStatusTestDataFactory:
dataset = Dataset(
tenant_id=tenant_id,
name=name,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by,
)
dataset.id = dataset_id
@ -291,7 +291,7 @@ class TestDocumentServicePauseDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
is_paused=False,
)
@ -326,7 +326,7 @@ class TestDocumentServicePauseDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="indexing",
indexing_status=IndexingStatus.INDEXING,
is_paused=False,
)
@ -354,7 +354,7 @@ class TestDocumentServicePauseDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="parsing",
indexing_status=IndexingStatus.PARSING,
is_paused=False,
)
@ -383,7 +383,7 @@ class TestDocumentServicePauseDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
is_paused=False,
)
@ -412,7 +412,7 @@ class TestDocumentServicePauseDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
is_paused=False,
)
@ -487,7 +487,7 @@ class TestDocumentServiceRecoverDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="indexing",
indexing_status=IndexingStatus.INDEXING,
is_paused=True,
paused_by=str(uuid4()),
paused_at=paused_time,
@ -526,7 +526,7 @@ class TestDocumentServiceRecoverDocument:
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="indexing",
indexing_status=IndexingStatus.INDEXING,
is_paused=False,
)
@ -609,7 +609,7 @@ class TestDocumentServiceRetryDocument:
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
)
mock_document_service_dependencies["redis_client"].get.return_value = None
@ -619,7 +619,7 @@ class TestDocumentServiceRetryDocument:
# Assert
db_session_with_containers.refresh(document)
assert document.indexing_status == "waiting"
assert document.indexing_status == IndexingStatus.WAITING
expected_cache_key = f"document_{document.id}_is_retried"
mock_document_service_dependencies["redis_client"].setex.assert_called_once_with(expected_cache_key, 600, 1)
@ -646,14 +646,14 @@ class TestDocumentServiceRetryDocument:
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
)
document2 = DocumentStatusTestDataFactory.create_document(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
position=2,
)
@ -665,8 +665,8 @@ class TestDocumentServiceRetryDocument:
# Assert
db_session_with_containers.refresh(document1)
db_session_with_containers.refresh(document2)
assert document1.indexing_status == "waiting"
assert document2.indexing_status == "waiting"
assert document1.indexing_status == IndexingStatus.WAITING
assert document2.indexing_status == IndexingStatus.WAITING
mock_document_service_dependencies["retry_task"].delay.assert_called_once_with(
dataset.id, [document1.id, document2.id], mock_document_service_dependencies["user_id"]
@ -693,7 +693,7 @@ class TestDocumentServiceRetryDocument:
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
)
mock_document_service_dependencies["redis_client"].get.return_value = "1"
@ -703,7 +703,7 @@ class TestDocumentServiceRetryDocument:
DocumentService.retry_document(dataset.id, [document])
db_session_with_containers.refresh(document)
assert document.indexing_status == "error"
assert document.indexing_status == IndexingStatus.ERROR
def test_retry_document_missing_current_user_error(
self, db_session_with_containers, mock_document_service_dependencies
@ -726,7 +726,7 @@ class TestDocumentServiceRetryDocument:
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="error",
indexing_status=IndexingStatus.ERROR,
)
mock_document_service_dependencies["redis_client"].get.return_value = None
@ -816,7 +816,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
enabled=False,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
document2 = DocumentStatusTestDataFactory.create_document(
db_session_with_containers,
@ -824,7 +824,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
enabled=False,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
position=2,
)
document_ids = [document1.id, document2.id]
@ -866,7 +866,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
enabled=True,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
completed_at=FIXED_TIME,
)
document_ids = [document.id]
@ -909,7 +909,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
document_id=str(uuid4()),
archived=False,
enabled=True,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
document_ids = [document.id]
@ -951,7 +951,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
document_id=str(uuid4()),
archived=True,
enabled=True,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
document_ids = [document.id]
@ -1015,7 +1015,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
document_id=str(uuid4()),
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
document_ids = [document.id]
@ -1098,7 +1098,7 @@ class TestDocumentServiceRenameDocument:
document_id=document_id,
dataset_id=dataset.id,
tenant_id=tenant_id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
# Act
@ -1139,7 +1139,7 @@ class TestDocumentServiceRenameDocument:
dataset_id=dataset.id,
tenant_id=tenant_id,
doc_metadata={"existing_key": "existing_value"},
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
# Act
@ -1187,7 +1187,7 @@ class TestDocumentServiceRenameDocument:
dataset_id=dataset.id,
tenant_id=tenant_id,
data_source_info={"upload_file_id": upload_file.id},
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
# Act
@ -1277,7 +1277,7 @@ class TestDocumentServiceRenameDocument:
document_id=document_id,
dataset_id=dataset.id,
tenant_id=str(uuid4()),
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
# Act & Assert

View File

@ -16,6 +16,7 @@ from models.dataset import (
DatasetPermission,
DatasetPermissionEnum,
)
from models.enums import DataSourceType
from services.dataset_service import DatasetPermissionService, DatasetService
from services.errors.account import NoPermissionError
@ -67,7 +68,7 @@ class DatasetPermissionTestDataFactory:
tenant_id=tenant_id,
name=name,
description="desc",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=created_by,
permission=permission,

View File

@ -15,6 +15,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline
from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import RerankingModel, RetrievalModel
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
@ -74,7 +75,7 @@ class DatasetServiceIntegrationDataFactory:
tenant_id=tenant_id,
name=name,
description=description,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique=indexing_technique,
created_by=created_by,
provider=provider,
@ -98,13 +99,13 @@ class DatasetServiceIntegrationDataFactory:
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info='{"upload_file_id": "upload-file-id"}',
batch=str(uuid4()),
name=name,
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
doc_form="text_model",
)
db_session_with_containers.add(document)
@ -437,7 +438,7 @@ class TestDatasetServiceCreateRagPipelineDataset:
created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id)
assert created_dataset is not None
assert created_dataset.name == entity.name
assert created_dataset.runtime_mode == "rag_pipeline"
assert created_dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE
assert created_dataset.created_by == account.id
assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME
assert created_pipeline is not None

View File

@ -14,6 +14,7 @@ import pytest
from sqlalchemy.orm import Session
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.dataset_service import DocumentService
from services.errors.document import DocumentIndexingError
@ -42,7 +43,7 @@ class DocumentBatchUpdateIntegrationDataFactory:
dataset = Dataset(
tenant_id=tenant_id or str(uuid4()),
name=name,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=created_by or str(uuid4()),
)
if dataset_id:
@ -72,11 +73,11 @@ class DocumentBatchUpdateIntegrationDataFactory:
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=position,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info=json.dumps({"upload_file_id": str(uuid4())}),
batch=f"batch-{uuid4()}",
name=name,
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by or str(uuid4()),
doc_form="text_model",
)
@ -85,7 +86,9 @@ class DocumentBatchUpdateIntegrationDataFactory:
document.archived = archived
document.indexing_status = indexing_status
document.completed_at = (
completed_at if completed_at is not None else (FIXED_TIME if indexing_status == "completed" else None)
completed_at
if completed_at is not None
else (FIXED_TIME if indexing_status == IndexingStatus.COMPLETED else None)
)
for key, value in kwargs.items():
@ -243,7 +246,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
dataset=dataset,
document_ids=document_ids,
enabled=True,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
# Act
@ -277,7 +280,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
db_session_with_containers,
dataset=dataset,
enabled=False,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
completed_at=FIXED_TIME,
)
@ -306,7 +309,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
db_session_with_containers,
dataset=dataset,
enabled=True,
indexing_status="indexing",
indexing_status=IndexingStatus.INDEXING,
completed_at=None,
)

View File

@ -5,6 +5,7 @@ from uuid import uuid4
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom
from services.dataset_service import DatasetService
@ -58,7 +59,7 @@ class DatasetDeleteIntegrationDataFactory:
dataset = Dataset(
tenant_id=tenant_id,
name=f"dataset-{uuid4()}",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique=indexing_technique,
index_struct=index_struct,
created_by=created_by,
@ -84,10 +85,10 @@ class DatasetDeleteIntegrationDataFactory:
tenant_id=tenant_id,
dataset_id=dataset_id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=f"batch-{uuid4()}",
name="Document",
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
doc_form=doc_form,
)

View File

@ -14,6 +14,7 @@ from sqlalchemy.orm import Session
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom
from services.dataset_service import SegmentService
@ -62,7 +63,7 @@ class SegmentServiceTestDataFactory:
tenant_id=tenant_id,
name=f"Test Dataset {uuid4()}",
description="Test description",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=created_by,
permission=DatasetPermissionEnum.ONLY_ME,
@ -82,10 +83,10 @@ class SegmentServiceTestDataFactory:
tenant_id=tenant_id,
dataset_id=dataset_id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=f"batch-{uuid4()}",
name=f"test-doc-{uuid4()}.txt",
created_from="api",
created_from=DocumentCreatedFrom.API,
created_by=created_by,
)
db_session_with_containers.add(document)

View File

@ -24,6 +24,7 @@ from models.dataset import (
DatasetProcessRule,
DatasetQuery,
)
from models.enums import DatasetQuerySource, DataSourceType, ProcessRuleMode
from models.model import Tag, TagBinding
from services.dataset_service import DatasetService, DocumentService
@ -100,7 +101,7 @@ class DatasetRetrievalTestDataFactory:
tenant_id=tenant_id,
name=name,
description="desc",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=created_by,
permission=permission,
@ -149,7 +150,7 @@ class DatasetRetrievalTestDataFactory:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=content,
source="web",
source=DatasetQuerySource.APP,
source_app_id=None,
created_by_role="account",
created_by=created_by,
@ -601,7 +602,7 @@ class TestDatasetServiceGetProcessRules:
db_session_with_containers,
dataset_id=dataset.id,
created_by=account.id,
mode="custom",
mode=ProcessRuleMode.CUSTOM,
rules=rules_data,
)

View File

@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, ExternalKnowledgeBindings
from models.enums import DataSourceType
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
@ -64,7 +65,7 @@ class DatasetUpdateTestDataFactory:
tenant_id=tenant_id,
name=name,
description=description,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique=indexing_technique,
created_by=created_by,
provider=provider,

View File

@ -4,6 +4,7 @@ from uuid import uuid4
from sqlalchemy import select
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.dataset_service import DocumentService
@ -11,7 +12,7 @@ def _create_dataset(db_session_with_containers) -> Dataset:
dataset = Dataset(
tenant_id=str(uuid4()),
name=f"dataset-{uuid4()}",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
)
dataset.id = str(uuid4())
@ -35,11 +36,11 @@ def _create_document(
tenant_id=tenant_id,
dataset_id=dataset_id,
position=position,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info="{}",
batch=f"batch-{uuid4()}",
name=f"doc-{uuid4()}",
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
doc_form="text_model",
)
@ -48,7 +49,7 @@ def _create_document(
document.enabled = enabled
document.archived = archived
document.is_paused = is_paused
if indexing_status == "completed":
if indexing_status == IndexingStatus.COMPLETED:
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db_session_with_containers.add(document)
@ -62,7 +63,7 @@ def test_build_display_status_filters_available(db_session_with_containers):
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
position=1,
@ -71,7 +72,7 @@ def test_build_display_status_filters_available(db_session_with_containers):
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=False,
archived=False,
position=2,
@ -80,7 +81,7 @@ def test_build_display_status_filters_available(db_session_with_containers):
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=True,
position=3,
@ -101,14 +102,14 @@ def test_apply_display_status_filter_applies_when_status_present(db_session_with
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
position=1,
)
_create_document(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
position=2,
)
@ -125,14 +126,14 @@ def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_c
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
position=1,
)
doc2 = _create_document(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
position=2,
)

View File

@ -9,7 +9,7 @@ import pytest
from models import Account
from models.dataset import Dataset, Document
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom
from models.model import UploadFile
from services.dataset_service import DocumentService
@ -33,7 +33,7 @@ def make_dataset(db_session_with_containers, dataset_id=None, tenant_id=None, bu
dataset = Dataset(
tenant_id=tenant_id,
name=f"dataset-{uuid4()}",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
)
dataset.id = dataset_id
@ -62,11 +62,11 @@ def make_document(
tenant_id=tenant_id,
dataset_id=dataset_id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info=json.dumps(data_source_info or {}),
batch=f"batch-{uuid4()}",
name=name,
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
doc_form="text_model",
)

View File

@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import DataSourceType
from models.model import (
App,
AppAnnotationHitHistory,
@ -287,7 +288,7 @@ class TestMessagesCleanServiceIntegration:
dataset_name="Test dataset",
document_id=str(uuid.uuid4()),
document_name="Test document",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
segment_id=str(uuid.uuid4()),
score=0.9,
content="Test content",

View File

@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
from core.rag.index_processor.constant.built_in_field import BuiltInField
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document
from models.enums import DatasetMetadataType, DataSourceType, DocumentCreatedFrom
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
from services.metadata_service import MetadataService
@ -101,7 +102,7 @@ class TestMetadataService:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
built_in_field_enabled=False,
)
@ -132,11 +133,11 @@ class TestMetadataService:
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info="{}",
batch="test-batch",
name=fake.file_name(),
created_from="web",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text",
doc_language="en",
@ -163,7 +164,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
metadata_args = MetadataArgs(type="string", name="test_metadata")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
# Act: Execute the method under test
result = MetadataService.create_metadata(dataset.id, metadata_args)
@ -201,7 +202,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
long_name = "a" * 256 # 256 characters, exceeding 255 limit
metadata_args = MetadataArgs(type="string", name=long_name)
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name=long_name)
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."):
@ -226,11 +227,11 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create first metadata
first_metadata_args = MetadataArgs(type="string", name="duplicate_name")
first_metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="duplicate_name")
MetadataService.create_metadata(dataset.id, first_metadata_args)
# Try to create second metadata with same name
second_metadata_args = MetadataArgs(type="number", name="duplicate_name")
second_metadata_args = MetadataArgs(type=DatasetMetadataType.NUMBER, name="duplicate_name")
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError, match="Metadata name already exists."):
@ -256,7 +257,7 @@ class TestMetadataService:
# Try to create metadata with built-in field name
built_in_field_name = BuiltInField.document_name
metadata_args = MetadataArgs(type="string", name=built_in_field_name)
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name=built_in_field_name)
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."):
@ -281,7 +282,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata first
metadata_args = MetadataArgs(type="string", name="old_name")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Act: Execute the method under test
@ -318,7 +319,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata first
metadata_args = MetadataArgs(type="string", name="old_name")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Try to update with too long name
@ -347,10 +348,10 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create two metadata entries
first_metadata_args = MetadataArgs(type="string", name="first_metadata")
first_metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="first_metadata")
first_metadata = MetadataService.create_metadata(dataset.id, first_metadata_args)
second_metadata_args = MetadataArgs(type="number", name="second_metadata")
second_metadata_args = MetadataArgs(type=DatasetMetadataType.NUMBER, name="second_metadata")
second_metadata = MetadataService.create_metadata(dataset.id, second_metadata_args)
# Try to update first metadata with second metadata's name
@ -376,7 +377,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata first
metadata_args = MetadataArgs(type="string", name="old_name")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Try to update with built-in field name
@ -432,7 +433,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata first
metadata_args = MetadataArgs(type="string", name="to_be_deleted")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="to_be_deleted")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Act: Execute the method under test
@ -496,7 +497,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata
metadata_args = MetadataArgs(type="string", name="test_metadata")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Create metadata binding
@ -798,7 +799,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata
metadata_args = MetadataArgs(type="string", name="test_metadata")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Mock DocumentService.get_document
@ -866,7 +867,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata
metadata_args = MetadataArgs(type="string", name="test_metadata")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Mock DocumentService.get_document
@ -917,7 +918,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata
metadata_args = MetadataArgs(type="string", name="test_metadata")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Create metadata operation data
@ -1038,7 +1039,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata
metadata_args = MetadataArgs(type="string", name="test_metadata")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Create document and metadata binding
@ -1101,7 +1102,7 @@ class TestMetadataService:
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata
metadata_args = MetadataArgs(type="string", name="test_metadata")
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
# Act: Execute the method under test

View File

@ -9,6 +9,7 @@ from werkzeug.exceptions import NotFound
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset
from models.enums import DataSourceType
from models.model import App, Tag, TagBinding
from services.tag_service import TagService
@ -100,7 +101,7 @@ class TestTagService:
description=fake.text(max_nb_chars=100),
provider="vendor",
permission="only_me",
data_source_type="upload",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
tenant_id=tenant_id,
created_by=mock_external_service_dependencies["current_user"].id,

View File

@ -510,7 +510,7 @@ class TestWorkflowConverter:
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
top_k=10,
score_threshold=0.8,
reranking_model={"provider": "cohere", "model": "rerank-v2"},
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
reranking_enabled=True,
),
)
@ -543,8 +543,8 @@ class TestWorkflowConverter:
multiple_config = node["data"]["multiple_retrieval_config"]
assert multiple_config["top_k"] == 10
assert multiple_config["score_threshold"] == 0.8
assert multiple_config["reranking_model"]["provider"] == "cohere"
assert multiple_config["reranking_model"]["model"] == "rerank-v2"
assert multiple_config["reranking_model"]["reranking_provider_name"] == "cohere"
assert multiple_config["reranking_model"]["reranking_model_name"] == "rerank-v2"
# Verify single retrieval config is None for multiple strategy
assert node["data"]["single_retrieval_config"] is None

View File

@ -8,6 +8,7 @@ from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from tasks.add_document_to_index_task import add_document_to_index_task
@ -79,7 +80,7 @@ class TestAddDocumentToIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=account.id,
)
@ -92,12 +93,12 @@ class TestAddDocumentToIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
@ -137,7 +138,7 @@ class TestAddDocumentToIndexTask:
index_node_id=f"node_{i}",
index_node_hash=f"hash_{i}",
enabled=False,
status="completed",
status=SegmentStatus.COMPLETED,
created_by=document.created_by,
)
db_session_with_containers.add(segment)
@ -297,7 +298,7 @@ class TestAddDocumentToIndexTask:
)
# Set invalid indexing status
document.indexing_status = "processing"
document.indexing_status = IndexingStatus.INDEXING
db_session_with_containers.commit()
# Act: Execute the task
@ -339,7 +340,7 @@ class TestAddDocumentToIndexTask:
# Assert: Verify error handling
db_session_with_containers.refresh(document)
assert document.enabled is False
assert document.indexing_status == "error"
assert document.indexing_status == IndexingStatus.ERROR
assert document.error is not None
assert "doesn't exist" in document.error
assert document.disabled_at is not None
@ -434,7 +435,7 @@ class TestAddDocumentToIndexTask:
Test document indexing when segments are already enabled.
This test verifies:
- Segments with status="completed" are processed regardless of enabled status
- Segments with status=SegmentStatus.COMPLETED are processed regardless of enabled status
- Index processing occurs with all completed segments
- Auto disable log deletion still occurs
- Redis cache is cleared
@ -460,7 +461,7 @@ class TestAddDocumentToIndexTask:
index_node_id=f"node_{i}",
index_node_hash=f"hash_{i}",
enabled=True, # Already enabled
status="completed",
status=SegmentStatus.COMPLETED,
created_by=document.created_by,
)
db_session_with_containers.add(segment)
@ -482,7 +483,7 @@ class TestAddDocumentToIndexTask:
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify the load method was called with all completed segments
# (implementation doesn't filter by enabled status, only by status="completed")
# (implementation doesn't filter by enabled status, only by status=SegmentStatus.COMPLETED)
call_args = mock_external_service_dependencies["index_processor"].load.call_args
assert call_args is not None
documents = call_args[0][1] # Second argument should be documents list
@ -594,7 +595,7 @@ class TestAddDocumentToIndexTask:
# Assert: Verify error handling
db_session_with_containers.refresh(document)
assert document.enabled is False
assert document.indexing_status == "error"
assert document.indexing_status == IndexingStatus.ERROR
assert document.error is not None
assert "Index processing failed" in document.error
assert document.disabled_at is not None
@ -614,7 +615,7 @@ class TestAddDocumentToIndexTask:
Test segment filtering with various edge cases.
This test verifies:
- Only segments with status="completed" are processed (regardless of enabled status)
- Only segments with status=SegmentStatus.COMPLETED are processed (regardless of enabled status)
- Segments with status!="completed" are NOT processed
- Segments are ordered by position correctly
- Mixed segment states are handled properly
@ -630,7 +631,7 @@ class TestAddDocumentToIndexTask:
fake = Faker()
segments = []
# Segment 1: Should be processed (enabled=False, status="completed")
# Segment 1: Should be processed (enabled=False, status=SegmentStatus.COMPLETED)
segment1 = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
@ -643,14 +644,14 @@ class TestAddDocumentToIndexTask:
index_node_id="node_0",
index_node_hash="hash_0",
enabled=False,
status="completed",
status=SegmentStatus.COMPLETED,
created_by=document.created_by,
)
db_session_with_containers.add(segment1)
segments.append(segment1)
# Segment 2: Should be processed (enabled=True, status="completed")
# Note: Implementation doesn't filter by enabled status, only by status="completed"
# Segment 2: Should be processed (enabled=True, status=SegmentStatus.COMPLETED)
# Note: Implementation doesn't filter by enabled status, only by status=SegmentStatus.COMPLETED
segment2 = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
@ -663,7 +664,7 @@ class TestAddDocumentToIndexTask:
index_node_id="node_1",
index_node_hash="hash_1",
enabled=True, # Already enabled, but will still be processed
status="completed",
status=SegmentStatus.COMPLETED,
created_by=document.created_by,
)
db_session_with_containers.add(segment2)
@ -682,13 +683,13 @@ class TestAddDocumentToIndexTask:
index_node_id="node_2",
index_node_hash="hash_2",
enabled=False,
status="processing", # Not completed
status=SegmentStatus.INDEXING, # Not completed
created_by=document.created_by,
)
db_session_with_containers.add(segment3)
segments.append(segment3)
# Segment 4: Should be processed (enabled=False, status="completed")
# Segment 4: Should be processed (enabled=False, status=SegmentStatus.COMPLETED)
segment4 = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
@ -701,7 +702,7 @@ class TestAddDocumentToIndexTask:
index_node_id="node_3",
index_node_hash="hash_3",
enabled=False,
status="completed",
status=SegmentStatus.COMPLETED,
created_by=document.created_by,
)
db_session_with_containers.add(segment4)
@ -726,7 +727,7 @@ class TestAddDocumentToIndexTask:
call_args = mock_external_service_dependencies["index_processor"].load.call_args
assert call_args is not None
documents = call_args[0][1] # Second argument should be documents list
assert len(documents) == 3 # 3 segments with status="completed" should be processed
assert len(documents) == 3 # 3 segments with status=SegmentStatus.COMPLETED should be processed
# Verify correct segments were processed (by position order)
# Segments 1, 2, 4 should be processed (positions 0, 1, 3)
@ -799,7 +800,7 @@ class TestAddDocumentToIndexTask:
# Assert: Verify consistent error handling
db_session_with_containers.refresh(document)
assert document.enabled is False, f"Document should be disabled for {error_name}"
assert document.indexing_status == "error", f"Document status should be error for {error_name}"
assert document.indexing_status == IndexingStatus.ERROR, f"Document status should be error for {error_name}"
assert document.error is not None, f"Error should be recorded for {error_name}"
assert str(exception) in document.error, f"Error message should contain exception for {error_name}"
assert document.disabled_at is not None, f"Disabled timestamp should be set for {error_name}"

View File

@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
from libs.datetime_utils import naive_utc_now
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from models.model import UploadFile
from tasks.batch_clean_document_task import batch_clean_document_task
@ -113,7 +114,7 @@ class TestBatchCleanDocumentTask:
tenant_id=account.current_tenant.id,
name=fake.word(),
description=fake.sentence(),
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
embedding_model="text-embedding-ada-002",
embedding_model_provider="openai",
@ -144,12 +145,12 @@ class TestBatchCleanDocumentTask:
dataset_id=dataset.id,
position=0,
name=fake.word(),
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info=json.dumps({"upload_file_id": str(uuid.uuid4())}),
batch="test_batch",
created_from="test",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
doc_form="text_model",
)
@ -183,7 +184,7 @@ class TestBatchCleanDocumentTask:
tokens=50,
index_node_id=str(uuid.uuid4()),
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
db_session_with_containers.add(segment)
@ -297,7 +298,7 @@ class TestBatchCleanDocumentTask:
tokens=50,
index_node_id=str(uuid.uuid4()),
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
db_session_with_containers.add(segment)
@ -671,7 +672,7 @@ class TestBatchCleanDocumentTask:
tokens=25 + i * 5,
index_node_id=str(uuid.uuid4()),
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
segments.append(segment)

View File

@ -21,7 +21,7 @@ from sqlalchemy.orm import Session
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from models.model import UploadFile
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
@ -139,7 +139,7 @@ class TestBatchCreateSegmentToIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(),
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
embedding_model="text-embedding-ada-002",
embedding_model_provider="openai",
@ -170,12 +170,12 @@ class TestBatchCreateSegmentToIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
doc_form="text_model",
@ -301,7 +301,7 @@ class TestBatchCreateSegmentToIndexTask:
assert segment.dataset_id == dataset.id
assert segment.document_id == document.id
assert segment.position == i + 1
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
assert segment.answer is None # text_model doesn't have answers
@ -442,12 +442,12 @@ class TestBatchCreateSegmentToIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name="disabled_document",
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=False, # Document is disabled
archived=False,
doc_form="text_model",
@ -458,12 +458,12 @@ class TestBatchCreateSegmentToIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=2,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name="archived_document",
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=True, # Document is archived
doc_form="text_model",
@ -474,12 +474,12 @@ class TestBatchCreateSegmentToIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=3,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name="incomplete_document",
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="indexing", # Not completed
indexing_status=IndexingStatus.INDEXING, # Not completed
enabled=True,
archived=False,
doc_form="text_model",
@ -643,7 +643,7 @@ class TestBatchCreateSegmentToIndexTask:
word_count=len(f"Existing segment {i + 1}"),
tokens=10,
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
index_node_id=str(uuid.uuid4()),
index_node_hash=f"hash_{i}",
)
@ -694,7 +694,7 @@ class TestBatchCreateSegmentToIndexTask:
for i, segment in enumerate(new_segments):
expected_position = 4 + i # Should start at position 4
assert segment.position == expected_position
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None

View File

@ -29,7 +29,14 @@ from models.dataset import (
Document,
DocumentSegment,
)
from models.enums import CreatorUserRole
from models.enums import (
CreatorUserRole,
DatasetMetadataType,
DataSourceType,
DocumentCreatedFrom,
IndexingStatus,
SegmentStatus,
)
from models.model import UploadFile
from tasks.clean_dataset_task import clean_dataset_task
@ -176,12 +183,12 @@ class TestCleanDatasetTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name="test_document",
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
doc_form="paragraph_index",
@ -219,7 +226,7 @@ class TestCleanDatasetTask:
word_count=20,
tokens=30,
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
index_node_id=str(uuid.uuid4()),
index_node_hash="test_hash",
created_at=datetime.now(),
@ -373,7 +380,7 @@ class TestCleanDatasetTask:
dataset_id=dataset.id,
tenant_id=tenant.id,
name="test_metadata",
type="string",
type=DatasetMetadataType.STRING,
created_by=account.id,
)
metadata.id = str(uuid.uuid4())
@ -587,7 +594,7 @@ class TestCleanDatasetTask:
word_count=len(segment_content),
tokens=50,
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
index_node_id=str(uuid.uuid4()),
index_node_hash="test_hash",
created_at=datetime.now(),
@ -686,7 +693,7 @@ class TestCleanDatasetTask:
dataset_id=dataset.id,
tenant_id=tenant.id,
name=f"test_metadata_{i}",
type="string",
type=DatasetMetadataType.STRING,
created_by=account.id,
)
metadata.id = str(uuid.uuid4())
@ -880,11 +887,11 @@ class TestCleanDatasetTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info="{}",
batch="test_batch",
name=f"test_doc_{special_content}",
created_from="test",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
created_at=datetime.now(),
updated_at=datetime.now(),
@ -905,7 +912,7 @@ class TestCleanDatasetTask:
word_count=len(segment_content.split()),
tokens=len(segment_content) // 4, # Rough token estimation
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
index_node_id=str(uuid.uuid4()),
index_node_hash="test_hash_" + "x" * 50, # Long hash within limits
created_at=datetime.now(),
@ -946,7 +953,7 @@ class TestCleanDatasetTask:
dataset_id=dataset.id,
tenant_id=tenant.id,
name=f"metadata_{special_content}",
type="string",
type=DatasetMetadataType.STRING,
created_by=account.id,
)
special_metadata.id = str(uuid.uuid4())

View File

@ -13,6 +13,7 @@ import pytest
from faker import Faker
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from services.account_service import AccountService, TenantService
from tasks.clean_notion_document_task import clean_notion_document_task
from tests.test_containers_integration_tests.helpers import generate_valid_password
@ -88,7 +89,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -105,17 +106,17 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(
{"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"}
),
batch="test_batch",
name=f"Notion Page {i}",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model", # Set doc_form to ensure dataset.doc_form works
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
@ -134,7 +135,7 @@ class TestCleanNotionDocumentTask:
tokens=50,
index_node_id=f"node_{i}_{j}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
db_session_with_containers.add(segment)
segments.append(segment)
@ -220,7 +221,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -269,7 +270,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=f"{fake.company()}_{index_type}",
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -281,17 +282,17 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(
{"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"}
),
batch="test_batch",
name="Test Notion Page",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form=index_type,
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
@ -308,7 +309,7 @@ class TestCleanNotionDocumentTask:
tokens=50,
index_node_id="test_node",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
db_session_with_containers.add(segment)
db_session_with_containers.commit()
@ -357,7 +358,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -369,16 +370,16 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(
{"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"}
),
batch="test_batch",
name="Test Notion Page",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
@ -397,7 +398,7 @@ class TestCleanNotionDocumentTask:
tokens=50,
index_node_id=None, # No index node ID
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
db_session_with_containers.add(segment)
segments.append(segment)
@ -443,7 +444,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -460,16 +461,16 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(
{"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"}
),
batch="test_batch",
name=f"Notion Page {i}",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
@ -488,7 +489,7 @@ class TestCleanNotionDocumentTask:
tokens=50,
index_node_id=f"node_{i}_{j}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
db_session_with_containers.add(segment)
all_segments.append(segment)
@ -558,7 +559,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -570,22 +571,22 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(
{"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"}
),
batch="test_batch",
name="Test Notion Page",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
# Create segments with different statuses
segment_statuses = ["waiting", "processing", "completed", "error"]
segment_statuses = [SegmentStatus.WAITING, SegmentStatus.INDEXING, SegmentStatus.COMPLETED, SegmentStatus.ERROR]
segments = []
index_node_ids = []
@ -654,7 +655,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -666,16 +667,16 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(
{"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"}
),
batch="test_batch",
name="Test Notion Page",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
@ -692,7 +693,7 @@ class TestCleanNotionDocumentTask:
tokens=50,
index_node_id="test_node",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
db_session_with_containers.add(segment)
db_session_with_containers.commit()
@ -736,7 +737,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -754,16 +755,16 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(
{"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"}
),
batch="test_batch",
name=f"Notion Page {i}",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
@ -783,7 +784,7 @@ class TestCleanNotionDocumentTask:
tokens=50,
index_node_id=f"node_{i}_{j}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
db_session_with_containers.add(segment)
all_segments.append(segment)
@ -848,7 +849,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=f"{fake.company()}_{i}",
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -866,16 +867,16 @@ class TestCleanNotionDocumentTask:
tenant_id=account.current_tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(
{"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"}
),
batch="test_batch",
name=f"Notion Page {i}",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
@ -894,7 +895,7 @@ class TestCleanNotionDocumentTask:
tokens=50,
index_node_id=f"node_{i}_{j}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
db_session_with_containers.add(segment)
all_segments.append(segment)
@ -963,14 +964,22 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
# Create documents with different indexing statuses
document_statuses = ["waiting", "parsing", "cleaning", "splitting", "indexing", "completed", "error"]
document_statuses = [
IndexingStatus.WAITING,
IndexingStatus.PARSING,
IndexingStatus.CLEANING,
IndexingStatus.SPLITTING,
IndexingStatus.INDEXING,
IndexingStatus.COMPLETED,
IndexingStatus.ERROR,
]
documents = []
all_segments = []
all_index_node_ids = []
@ -981,13 +990,13 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(
{"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"}
),
batch="test_batch",
name=f"Notion Page {i}",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_language="en",
indexing_status=status,
@ -1009,7 +1018,7 @@ class TestCleanNotionDocumentTask:
tokens=50,
index_node_id=f"node_{i}_{j}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
)
db_session_with_containers.add(segment)
all_segments.append(segment)
@ -1066,7 +1075,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
created_by=account.id,
built_in_field_enabled=True,
)
@ -1079,7 +1088,7 @@ class TestCleanNotionDocumentTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="notion_import",
data_source_type=DataSourceType.NOTION_IMPORT,
data_source_info=json.dumps(
{
"notion_workspace_id": "workspace_test",
@ -1091,10 +1100,10 @@ class TestCleanNotionDocumentTask:
),
batch="test_batch",
name="Test Notion Page with Metadata",
created_from="notion_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
doc_metadata={
"document_name": "Test Notion Page with Metadata",
"uploader": account.name,
@ -1122,7 +1131,7 @@ class TestCleanNotionDocumentTask:
tokens=75,
index_node_id=f"node_{i}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
keywords={"key1": ["value1", "value2"], "key2": ["value3"]},
)
db_session_with_containers.add(segment)

View File

@ -15,6 +15,7 @@ from faker import Faker
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from tasks.create_segment_to_index_task import create_segment_to_index_task
@ -118,7 +119,7 @@ class TestCreateSegmentToIndexTask:
name=fake.company(),
description=fake.text(max_nb_chars=100),
tenant_id=tenant_id,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
@ -133,13 +134,13 @@ class TestCreateSegmentToIndexTask:
dataset_id=dataset.id,
tenant_id=tenant_id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account_id,
enabled=True,
archived=False,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
doc_form="qa_model",
)
db_session_with_containers.add(document)
@ -148,7 +149,7 @@ class TestCreateSegmentToIndexTask:
return dataset, document
def _create_test_segment(
self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status="waiting"
self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status=SegmentStatus.WAITING
):
"""
Helper method to create a test document segment for testing.
@ -200,7 +201,7 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Act: Execute the task
@ -208,7 +209,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify segment status changes
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
assert segment.error is None
@ -257,7 +258,7 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="completed"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.COMPLETED
)
# Act: Execute the task
@ -268,7 +269,7 @@ class TestCreateSegmentToIndexTask:
# Verify segment status unchanged
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is None
# Verify no index processor calls were made
@ -293,20 +294,25 @@ class TestCreateSegmentToIndexTask:
dataset_id=invalid_dataset_id,
tenant_id=tenant.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
enabled=True,
archived=False,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
doc_form="text_model",
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
segment = self._create_test_segment(
db_session_with_containers, invalid_dataset_id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers,
invalid_dataset_id,
document.id,
tenant.id,
account.id,
status=SegmentStatus.WAITING,
)
# Act: Execute the task
@ -317,7 +323,7 @@ class TestCreateSegmentToIndexTask:
# Verify segment status changed to indexing (task updates status before checking document)
db_session_with_containers.refresh(segment)
assert segment.status == "indexing"
assert segment.status == SegmentStatus.INDEXING
# Verify no index processor calls were made
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
@ -337,7 +343,12 @@ class TestCreateSegmentToIndexTask:
invalid_document_id = str(uuid4())
segment = self._create_test_segment(
db_session_with_containers, dataset.id, invalid_document_id, tenant.id, account.id, status="waiting"
db_session_with_containers,
dataset.id,
invalid_document_id,
tenant.id,
account.id,
status=SegmentStatus.WAITING,
)
# Act: Execute the task
@ -348,7 +359,7 @@ class TestCreateSegmentToIndexTask:
# Verify segment status changed to indexing (task updates status before checking document)
db_session_with_containers.refresh(segment)
assert segment.status == "indexing"
assert segment.status == SegmentStatus.INDEXING
# Verify no index processor calls were made
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
@ -373,7 +384,7 @@ class TestCreateSegmentToIndexTask:
db_session_with_containers.commit()
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Act: Execute the task
@ -384,7 +395,7 @@ class TestCreateSegmentToIndexTask:
# Verify segment status changed to indexing (task updates status before checking document)
db_session_with_containers.refresh(segment)
assert segment.status == "indexing"
assert segment.status == SegmentStatus.INDEXING
# Verify no index processor calls were made
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
@ -409,7 +420,7 @@ class TestCreateSegmentToIndexTask:
db_session_with_containers.commit()
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Act: Execute the task
@ -420,7 +431,7 @@ class TestCreateSegmentToIndexTask:
# Verify segment status changed to indexing (task updates status before checking document)
db_session_with_containers.refresh(segment)
assert segment.status == "indexing"
assert segment.status == SegmentStatus.INDEXING
# Verify no index processor calls were made
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
@ -445,7 +456,7 @@ class TestCreateSegmentToIndexTask:
db_session_with_containers.commit()
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Act: Execute the task
@ -456,7 +467,7 @@ class TestCreateSegmentToIndexTask:
# Verify segment status changed to indexing (task updates status before checking document)
db_session_with_containers.refresh(segment)
assert segment.status == "indexing"
assert segment.status == SegmentStatus.INDEXING
# Verify no index processor calls were made
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
@ -477,7 +488,7 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Mock processor to raise exception
@ -488,7 +499,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify error handling
db_session_with_containers.refresh(segment)
assert segment.status == "error"
assert segment.status == SegmentStatus.ERROR
assert segment.enabled is False
assert segment.disabled_at is not None
assert segment.error == "Processor failed"
@ -512,7 +523,7 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
custom_keywords = ["custom", "keywords", "test"]
@ -521,7 +532,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify successful indexing
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
@ -555,7 +566,7 @@ class TestCreateSegmentToIndexTask:
db_session_with_containers.commit()
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Act: Execute the task
@ -563,7 +574,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify successful indexing
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
# Verify correct doc_form was passed to factory
mock_external_service_dependencies["index_processor_factory"].assert_called_with(doc_form)
@ -583,7 +594,7 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Act: Execute the task and measure time
@ -597,7 +608,7 @@ class TestCreateSegmentToIndexTask:
# Verify successful completion
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
def test_create_segment_to_index_concurrent_execution(
self, db_session_with_containers, mock_external_service_dependencies
@ -617,7 +628,7 @@ class TestCreateSegmentToIndexTask:
segments = []
for i in range(3):
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
segments.append(segment)
@ -629,7 +640,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify all segments processed
for segment in segments:
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
@ -665,7 +676,7 @@ class TestCreateSegmentToIndexTask:
keywords=["large", "content", "test"],
index_node_id=str(uuid4()),
index_node_hash=str(uuid4()),
status="waiting",
status=SegmentStatus.WAITING,
created_by=account.id,
)
db_session_with_containers.add(segment)
@ -681,7 +692,7 @@ class TestCreateSegmentToIndexTask:
assert execution_time < 10.0 # Should complete within 10 seconds
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
@ -700,7 +711,7 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Set up Redis cache key to simulate indexing in progress
@ -718,7 +729,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify indexing still completed successfully despite Redis failure
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
@ -740,7 +751,7 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Simulate an error during indexing to trigger rollback path
@ -752,7 +763,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify error handling and rollback
db_session_with_containers.refresh(segment)
assert segment.status == "error"
assert segment.status == SegmentStatus.ERROR
assert segment.enabled is False
assert segment.disabled_at is not None
assert segment.error is not None
@ -772,7 +783,7 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Act: Execute the task
@ -780,7 +791,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify successful indexing
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
# Verify index processor was called with correct metadata
mock_processor = mock_external_service_dependencies["index_processor"]
@ -814,11 +825,11 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Verify initial state
assert segment.status == "waiting"
assert segment.status == SegmentStatus.WAITING
assert segment.indexing_at is None
assert segment.completed_at is None
@ -827,7 +838,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify final state
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
@ -861,7 +872,7 @@ class TestCreateSegmentToIndexTask:
keywords=[],
index_node_id=str(uuid4()),
index_node_hash=str(uuid4()),
status="waiting",
status=SegmentStatus.WAITING,
created_by=account.id,
)
db_session_with_containers.add(segment)
@ -872,7 +883,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify successful indexing
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
@ -907,7 +918,7 @@ class TestCreateSegmentToIndexTask:
keywords=["special", "unicode", "test"],
index_node_id=str(uuid4()),
index_node_hash=str(uuid4()),
status="waiting",
status=SegmentStatus.WAITING,
created_by=account.id,
)
db_session_with_containers.add(segment)
@ -918,7 +929,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify successful indexing
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
@ -937,7 +948,7 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Create long keyword list
@ -948,7 +959,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify successful indexing
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
@ -979,10 +990,10 @@ class TestCreateSegmentToIndexTask:
)
segment1 = self._create_test_segment(
db_session_with_containers, dataset1.id, document1.id, tenant1.id, account1.id, status="waiting"
db_session_with_containers, dataset1.id, document1.id, tenant1.id, account1.id, status=SegmentStatus.WAITING
)
segment2 = self._create_test_segment(
db_session_with_containers, dataset2.id, document2.id, tenant2.id, account2.id, status="waiting"
db_session_with_containers, dataset2.id, document2.id, tenant2.id, account2.id, status=SegmentStatus.WAITING
)
# Act: Execute tasks for both tenants
@ -993,8 +1004,8 @@ class TestCreateSegmentToIndexTask:
db_session_with_containers.refresh(segment1)
db_session_with_containers.refresh(segment2)
assert segment1.status == "completed"
assert segment2.status == "completed"
assert segment1.status == SegmentStatus.COMPLETED
assert segment2.status == SegmentStatus.COMPLETED
assert segment1.tenant_id == tenant1.id
assert segment2.tenant_id == tenant2.id
assert segment1.tenant_id != segment2.tenant_id
@ -1014,7 +1025,7 @@ class TestCreateSegmentToIndexTask:
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
# Act: Execute the task with None keywords
@ -1022,7 +1033,7 @@ class TestCreateSegmentToIndexTask:
# Assert: Verify successful indexing
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
@ -1050,7 +1061,7 @@ class TestCreateSegmentToIndexTask:
segments = []
for i in range(5):
segment = self._create_test_segment(
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
)
segments.append(segment)
@ -1067,7 +1078,7 @@ class TestCreateSegmentToIndexTask:
# Verify all segments processed successfully
for segment in segments:
db_session_with_containers.refresh(segment)
assert segment.status == "completed"
assert segment.status == SegmentStatus.COMPLETED
assert segment.indexing_at is not None
assert segment.completed_at is not None
assert segment.error is None

View File

@ -11,6 +11,7 @@ from core.indexing_runner import DocumentIsPausedError
from enums.cloud_plan import CloudPlan
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from tasks.document_indexing_task import (
_document_indexing,
_document_indexing_with_tenant_queue,
@ -139,7 +140,7 @@ class TestDatasetIndexingTaskIntegration:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=account.id,
)
@ -155,12 +156,12 @@ class TestDatasetIndexingTaskIntegration:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=position,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch="test_batch",
name=f"doc-{position}.txt",
created_from="upload_file",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status="waiting",
indexing_status=IndexingStatus.WAITING,
enabled=True,
)
db_session_with_containers.add(document)
@ -181,7 +182,7 @@ class TestDatasetIndexingTaskIntegration:
for document_id in document_ids:
updated = self._query_document(db_session_with_containers, document_id)
assert updated is not None
assert updated.indexing_status == "parsing"
assert updated.indexing_status == IndexingStatus.PARSING
assert updated.processing_started_at is not None
def _assert_documents_error_contains(
@ -195,7 +196,7 @@ class TestDatasetIndexingTaskIntegration:
for document_id in document_ids:
updated = self._query_document(db_session_with_containers, document_id)
assert updated is not None
assert updated.indexing_status == "error"
assert updated.indexing_status == IndexingStatus.ERROR
assert updated.error is not None
assert expected_error_substring in updated.error
assert updated.stopped_at is not None

View File

@ -13,6 +13,7 @@ import pytest
from faker import Faker
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from services.account_service import AccountService, TenantService
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tests.test_containers_integration_tests.helpers import generate_valid_password
@ -90,7 +91,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -102,13 +103,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Document for doc_form",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -150,7 +151,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -162,13 +163,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Document for doc_form",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -182,13 +183,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Test Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -209,7 +210,7 @@ class TestDealDatasetVectorIndexTask:
index_node_id=f"node_{uuid.uuid4()}",
index_node_hash=f"hash_{uuid.uuid4()}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
)
db_session_with_containers.add(segment)
@ -220,7 +221,7 @@ class TestDealDatasetVectorIndexTask:
# Verify document status was updated to indexing then completed
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
assert updated_document.indexing_status == "completed"
assert updated_document.indexing_status == IndexingStatus.COMPLETED
# Verify index processor load method was called
mock_factory = mock_index_processor_factory.return_value
@ -251,7 +252,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -263,13 +264,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Document for doc_form",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="parent_child_index",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -283,13 +284,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Test Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="parent_child_index",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -310,7 +311,7 @@ class TestDealDatasetVectorIndexTask:
index_node_id=f"node_{uuid.uuid4()}",
index_node_hash=f"hash_{uuid.uuid4()}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
)
db_session_with_containers.add(segment)
@ -321,7 +322,7 @@ class TestDealDatasetVectorIndexTask:
# Verify document status was updated to indexing then completed
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
assert updated_document.indexing_status == "completed"
assert updated_document.indexing_status == IndexingStatus.COMPLETED
# Verify index processor clean and load methods were called
mock_factory = mock_index_processor_factory.return_value
@ -367,7 +368,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -399,7 +400,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -411,13 +412,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Test Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -430,7 +431,7 @@ class TestDealDatasetVectorIndexTask:
# Verify document status was updated to indexing then completed
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
assert updated_document.indexing_status == "completed"
assert updated_document.indexing_status == IndexingStatus.COMPLETED
# Verify that no index processor load was called since no segments exist
mock_factory = mock_index_processor_factory.return_value
@ -455,7 +456,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -488,7 +489,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -500,13 +501,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Document for doc_form",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -520,13 +521,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Test Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -547,7 +548,7 @@ class TestDealDatasetVectorIndexTask:
index_node_id=f"node_{uuid.uuid4()}",
index_node_hash=f"hash_{uuid.uuid4()}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
)
db_session_with_containers.add(segment)
@ -563,7 +564,7 @@ class TestDealDatasetVectorIndexTask:
# Verify document status was updated to error
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
assert updated_document.indexing_status == "error"
assert updated_document.indexing_status == IndexingStatus.ERROR
assert "Test exception during indexing" in updated_document.error
def test_deal_dataset_vector_index_task_with_custom_index_type(
@ -584,7 +585,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -596,13 +597,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Test Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="qa_index",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -623,7 +624,7 @@ class TestDealDatasetVectorIndexTask:
index_node_id=f"node_{uuid.uuid4()}",
index_node_hash=f"hash_{uuid.uuid4()}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
)
db_session_with_containers.add(segment)
@ -634,7 +635,7 @@ class TestDealDatasetVectorIndexTask:
# Verify document status was updated to indexing then completed
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
assert updated_document.indexing_status == "completed"
assert updated_document.indexing_status == IndexingStatus.COMPLETED
# Verify index processor was initialized with custom index type
mock_index_processor_factory.assert_called_once_with("qa_index")
@ -660,7 +661,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -672,13 +673,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Test Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -699,7 +700,7 @@ class TestDealDatasetVectorIndexTask:
index_node_id=f"node_{uuid.uuid4()}",
index_node_hash=f"hash_{uuid.uuid4()}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
)
db_session_with_containers.add(segment)
@ -710,7 +711,7 @@ class TestDealDatasetVectorIndexTask:
# Verify document status was updated to indexing then completed
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
assert updated_document.indexing_status == "completed"
assert updated_document.indexing_status == IndexingStatus.COMPLETED
# Verify index processor was initialized with the document's index type
mock_index_processor_factory.assert_called_once_with("text_model")
@ -736,7 +737,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -748,13 +749,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Document for doc_form",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -770,13 +771,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name=f"Test Document {i}",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -801,7 +802,7 @@ class TestDealDatasetVectorIndexTask:
index_node_id=f"node_{i}_{j}",
index_node_hash=f"hash_{i}_{j}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
)
db_session_with_containers.add(segment)
@ -814,7 +815,7 @@ class TestDealDatasetVectorIndexTask:
# Verify all documents were processed
for document in documents:
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
assert updated_document.indexing_status == "completed"
assert updated_document.indexing_status == IndexingStatus.COMPLETED
# Verify index processor load was called multiple times
mock_factory = mock_index_processor_factory.return_value
@ -839,7 +840,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -851,13 +852,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Document for doc_form",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -871,13 +872,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Test Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -898,7 +899,7 @@ class TestDealDatasetVectorIndexTask:
index_node_id=f"node_{uuid.uuid4()}",
index_node_hash=f"hash_{uuid.uuid4()}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
)
db_session_with_containers.add(segment)
@ -916,7 +917,7 @@ class TestDealDatasetVectorIndexTask:
# Verify final document status
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
assert updated_document.indexing_status == "completed"
assert updated_document.indexing_status == IndexingStatus.COMPLETED
def test_deal_dataset_vector_index_task_with_disabled_documents(
self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
@ -936,7 +937,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -948,13 +949,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Document for doc_form",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -968,13 +969,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Enabled Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -987,13 +988,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Disabled Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=False, # This document should be skipped
archived=False,
batch="test_batch",
@ -1015,7 +1016,7 @@ class TestDealDatasetVectorIndexTask:
index_node_id=f"node_{uuid.uuid4()}",
index_node_hash=f"hash_{uuid.uuid4()}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
)
db_session_with_containers.add(segment)
@ -1026,13 +1027,13 @@ class TestDealDatasetVectorIndexTask:
# Verify only enabled document was processed
updated_enabled_document = db_session_with_containers.query(Document).filter_by(id=enabled_document.id).first()
assert updated_enabled_document.indexing_status == "completed"
assert updated_enabled_document.indexing_status == IndexingStatus.COMPLETED
# Verify disabled document status remains unchanged
updated_disabled_document = (
db_session_with_containers.query(Document).filter_by(id=disabled_document.id).first()
)
assert updated_disabled_document.indexing_status == "completed" # Should not change
assert updated_disabled_document.indexing_status == IndexingStatus.COMPLETED # Should not change
# Verify index processor load was called only once (for enabled document)
mock_factory = mock_index_processor_factory.return_value
@ -1057,7 +1058,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -1069,13 +1070,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Document for doc_form",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -1089,13 +1090,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Active Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -1108,13 +1109,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Archived Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=True, # This document should be skipped
batch="test_batch",
@ -1136,7 +1137,7 @@ class TestDealDatasetVectorIndexTask:
index_node_id=f"node_{uuid.uuid4()}",
index_node_hash=f"hash_{uuid.uuid4()}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
)
db_session_with_containers.add(segment)
@ -1147,13 +1148,13 @@ class TestDealDatasetVectorIndexTask:
# Verify only active document was processed
updated_active_document = db_session_with_containers.query(Document).filter_by(id=active_document.id).first()
assert updated_active_document.indexing_status == "completed"
assert updated_active_document.indexing_status == IndexingStatus.COMPLETED
# Verify archived document status remains unchanged
updated_archived_document = (
db_session_with_containers.query(Document).filter_by(id=archived_document.id).first()
)
assert updated_archived_document.indexing_status == "completed" # Should not change
assert updated_archived_document.indexing_status == IndexingStatus.COMPLETED # Should not change
# Verify index processor load was called only once (for active document)
mock_factory = mock_index_processor_factory.return_value
@ -1178,7 +1179,7 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=account.id,
)
db_session_with_containers.add(dataset)
@ -1190,13 +1191,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Document for doc_form",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -1210,13 +1211,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Completed Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
batch="test_batch",
@ -1229,13 +1230,13 @@ class TestDealDatasetVectorIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="file_import",
data_source_type=DataSourceType.UPLOAD_FILE,
name="Incomplete Document",
created_from="file_import",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="indexing", # This document should be skipped
indexing_status=IndexingStatus.INDEXING, # This document should be skipped
enabled=True,
archived=False,
batch="test_batch",
@ -1257,7 +1258,7 @@ class TestDealDatasetVectorIndexTask:
index_node_id=f"node_{uuid.uuid4()}",
index_node_hash=f"hash_{uuid.uuid4()}",
created_by=account.id,
status="completed",
status=SegmentStatus.COMPLETED,
enabled=True,
)
db_session_with_containers.add(segment)
@ -1270,13 +1271,13 @@ class TestDealDatasetVectorIndexTask:
updated_completed_document = (
db_session_with_containers.query(Document).filter_by(id=completed_document.id).first()
)
assert updated_completed_document.indexing_status == "completed"
assert updated_completed_document.indexing_status == IndexingStatus.COMPLETED
# Verify incomplete document status remains unchanged
updated_incomplete_document = (
db_session_with_containers.query(Document).filter_by(id=incomplete_document.id).first()
)
assert updated_incomplete_document.indexing_status == "indexing" # Should not change
assert updated_incomplete_document.indexing_status == IndexingStatus.INDEXING # Should not change
# Verify index processor load was called only once (for completed document)
mock_factory = mock_index_processor_factory.return_value

View File

@ -14,6 +14,7 @@ from faker import Faker
from core.rag.index_processor.constant.index_type import IndexStructureType
from models import Account, Dataset, Document, DocumentSegment, Tenant
from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
logger = logging.getLogger(__name__)
@ -106,7 +107,7 @@ class TestDeleteSegmentFromIndexTask:
dataset.description = fake.text(max_nb_chars=200)
dataset.provider = "vendor"
dataset.permission = "only_me"
dataset.data_source_type = "upload_file"
dataset.data_source_type = DataSourceType.UPLOAD_FILE
dataset.indexing_technique = "high_quality"
dataset.index_struct = '{"type": "paragraph"}'
dataset.created_by = account.id
@ -145,7 +146,7 @@ class TestDeleteSegmentFromIndexTask:
document.data_source_info = kwargs.get("data_source_info", "{}")
document.batch = kwargs.get("batch", fake.uuid4())
document.name = kwargs.get("name", f"Test Document {fake.word()}")
document.created_from = kwargs.get("created_from", "api")
document.created_from = kwargs.get("created_from", DocumentCreatedFrom.API)
document.created_by = account.id
document.created_at = fake.date_time_this_year()
document.processing_started_at = kwargs.get("processing_started_at", fake.date_time_this_year())
@ -162,7 +163,7 @@ class TestDeleteSegmentFromIndexTask:
document.enabled = kwargs.get("enabled", True)
document.archived = kwargs.get("archived", False)
document.updated_at = fake.date_time_this_year()
document.doc_type = kwargs.get("doc_type", "text")
document.doc_type = kwargs.get("doc_type", DocumentDocType.PERSONAL_DOCUMENT)
document.doc_metadata = kwargs.get("doc_metadata", {})
document.doc_form = kwargs.get("doc_form", IndexStructureType.PARAGRAPH_INDEX)
document.doc_language = kwargs.get("doc_language", "en")
@ -204,7 +205,7 @@ class TestDeleteSegmentFromIndexTask:
segment.index_node_hash = fake.sha256()
segment.hit_count = 0
segment.enabled = True
segment.status = "completed"
segment.status = SegmentStatus.COMPLETED
segment.created_by = account.id
segment.created_at = fake.date_time_this_year()
segment.updated_by = account.id
@ -386,7 +387,7 @@ class TestDeleteSegmentFromIndexTask:
account = self._create_test_account(db_session_with_containers, tenant, fake)
dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
document = self._create_test_document(
db_session_with_containers, dataset, account, fake, indexing_status="indexing"
db_session_with_containers, dataset, account, fake, indexing_status=IndexingStatus.INDEXING
)
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)

View File

@ -18,6 +18,7 @@ from sqlalchemy.orm import Session
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
logger = logging.getLogger(__name__)
@ -97,7 +98,7 @@ class TestDisableSegmentFromIndexTask:
tenant_id=tenant.id,
name=fake.sentence(nb_words=3),
description=fake.text(max_nb_chars=200),
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=account.id,
)
@ -132,12 +133,12 @@ class TestDisableSegmentFromIndexTask:
tenant_id=tenant.id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
batch=fake.uuid4(),
name=fake.file_name(),
created_from="api",
created_from=DocumentCreatedFrom.API,
created_by=account.id,
indexing_status="completed",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
doc_form=doc_form,
@ -189,7 +190,7 @@ class TestDisableSegmentFromIndexTask:
status=status,
enabled=enabled,
created_by=account.id,
completed_at=datetime.now(UTC) if status == "completed" else None,
completed_at=datetime.now(UTC) if status == SegmentStatus.COMPLETED else None,
)
db_session_with_containers.add(segment)
db_session_with_containers.commit()
@ -271,7 +272,7 @@ class TestDisableSegmentFromIndexTask:
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
segment = self._create_test_segment(
db_session_with_containers, document, dataset, tenant, account, status="indexing", enabled=True
db_session_with_containers, document, dataset, tenant, account, status=SegmentStatus.INDEXING, enabled=True
)
# Act: Execute the task

View File

@ -14,6 +14,7 @@ from sqlalchemy.orm import Session
from models import Account, Dataset, DocumentSegment
from models import Document as DatasetDocument
from models.dataset import DatasetProcessRule
from models.enums import DataSourceType, DocumentCreatedFrom, ProcessRuleMode, SegmentStatus
from tasks.disable_segments_from_index_task import disable_segments_from_index_task
@ -100,7 +101,7 @@ class TestDisableSegmentsFromIndexTask:
description=fake.text(max_nb_chars=200),
provider="vendor",
permission="only_me",
data_source_type="upload_file",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique="high_quality",
created_by=account.id,
updated_by=account.id,
@ -134,11 +135,11 @@ class TestDisableSegmentsFromIndexTask:
document.tenant_id = dataset.tenant_id
document.dataset_id = dataset.id
document.position = 1
document.data_source_type = "upload_file"
document.data_source_type = DataSourceType.UPLOAD_FILE
document.data_source_info = '{"upload_file_id": "test_file_id"}'
document.batch = fake.uuid4()
document.name = f"Test Document {fake.word()}.txt"
document.created_from = "upload_file"
document.created_from = DocumentCreatedFrom.WEB
document.created_by = account.id
document.created_api_request_id = fake.uuid4()
document.processing_started_at = fake.date_time_this_year()
@ -197,7 +198,7 @@ class TestDisableSegmentsFromIndexTask:
segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
segment.status = "completed"
segment.status = SegmentStatus.COMPLETED
segment.created_by = account.id
segment.updated_by = account.id
segment.indexing_at = fake.date_time_this_year()
@ -230,7 +231,7 @@ class TestDisableSegmentsFromIndexTask:
process_rule.id = fake.uuid4()
process_rule.tenant_id = dataset.tenant_id
process_rule.dataset_id = dataset.id
process_rule.mode = "automatic"
process_rule.mode = ProcessRuleMode.AUTOMATIC
process_rule.rules = (
"{"
'"mode": "automatic", '

Some files were not shown because too many files have changed in this diff Show More