mirror of https://github.com/langgenius/dify.git
Merge remote-tracking branch 'origin/main' into yanli/fix-iter-log
This commit is contained in:
commit
dea90b0ccd
|
|
@ -63,7 +63,8 @@ pnpm analyze-component <path> --review
|
|||
|
||||
### File Naming
|
||||
|
||||
- Test files: `ComponentName.spec.tsx` (same directory as component)
|
||||
- Test files: `ComponentName.spec.tsx` inside a same-level `__tests__/` directory
|
||||
- Placement rule: Component, hook, and utility tests must live in a sibling `__tests__/` folder at the same level as the source under test. For example, `foo/index.tsx` maps to `foo/__tests__/index.spec.tsx`, and `foo/bar.ts` maps to `foo/__tests__/bar.spec.ts`.
|
||||
- Integration tests: `web/__tests__/` directory
|
||||
|
||||
## Test Structure Template
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ import userEvent from '@testing-library/user-event'
|
|||
// Router (if component uses useRouter, usePathname, useSearchParams)
|
||||
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
|
||||
// const mockPush = vi.fn()
|
||||
// vi.mock('next/navigation', () => ({
|
||||
// vi.mock('@/next/navigation', () => ({
|
||||
// useRouter: () => ({ push: mockPush }),
|
||||
// usePathname: () => '/test-path',
|
||||
// }))
|
||||
|
|
|
|||
|
|
@ -29,8 +29,8 @@ jobs:
|
|||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
shardIndex: [1, 2, 3, 4]
|
||||
shardTotal: [4]
|
||||
shardIndex: [1, 2, 3, 4, 5, 6]
|
||||
shardTotal: [6]
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from core.rag.models.document import ChildDocument, Document
|
|||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import DatasetMetadataType, IndexingStatus, SegmentStatus
|
||||
from models.model import App, AppAnnotationSetting, MessageAnnotation
|
||||
|
||||
|
||||
|
|
@ -242,7 +243,7 @@ def migrate_knowledge_vector_database():
|
|||
dataset_documents = db.session.scalars(
|
||||
select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id == dataset.id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.indexing_status == IndexingStatus.COMPLETED,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
|
|
@ -254,7 +255,7 @@ def migrate_knowledge_vector_database():
|
|||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.status == SegmentStatus.COMPLETED,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
).all()
|
||||
|
|
@ -430,7 +431,7 @@ def old_metadata_migration():
|
|||
tenant_id=document.tenant_id,
|
||||
dataset_id=document.dataset_id,
|
||||
name=key,
|
||||
type="string",
|
||||
type=DatasetMetadataType.STRING,
|
||||
created_by=document.created_by,
|
||||
)
|
||||
db.session.add(dataset_metadata)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
|
||||
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
|
|
@ -116,3 +116,13 @@ class RedisConfig(BaseSettings):
|
|||
description="Maximum connections in the Redis connection pool (unset for library default)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@field_validator("REDIS_MAX_CONNECTIONS", mode="before")
|
||||
@classmethod
|
||||
def _empty_string_to_none_for_max_conns(cls, v):
|
||||
"""Allow empty string in env/.env to mean 'unset' (None)."""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str) and v.strip() == "":
|
||||
return None
|
||||
return v
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ from fields.document_fields import document_status_fields
|
|||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermission, DatasetPermissionEnum
|
||||
from models.enums import SegmentStatus
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.api_token_service import ApiTokenCache
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
|
@ -741,13 +742,15 @@ class DatasetIndexingStatusApi(Resource):
|
|||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
|
||||
)
|
||||
.count()
|
||||
)
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ from libs.datetime_utils import naive_utc_now
|
|||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from models.enums import IndexingStatus, SegmentStatus
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
|
||||
from services.file_service import FileService
|
||||
|
|
@ -332,13 +333,16 @@ class DatasetDocumentListApi(Resource):
|
|||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
document.completed_segments = completed_segments
|
||||
|
|
@ -503,7 +507,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
|||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
|
||||
data_process_rule = document.dataset_process_rule
|
||||
|
|
@ -573,7 +577,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
|
||||
extract_settings = []
|
||||
for document in documents:
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
data_source_info = document.data_source_info_dict
|
||||
match document.data_source_type:
|
||||
|
|
@ -671,19 +675,21 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
|||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
|
||||
)
|
||||
.count()
|
||||
)
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
document_dict = {
|
||||
"id": document.id,
|
||||
"indexing_status": "paused" if document.is_paused else document.indexing_status,
|
||||
"indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status,
|
||||
"processing_started_at": document.processing_started_at,
|
||||
"parsing_completed_at": document.parsing_completed_at,
|
||||
"cleaning_completed_at": document.cleaning_completed_at,
|
||||
|
|
@ -720,20 +726,20 @@ class DocumentIndexingStatusApi(DocumentResource):
|
|||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
|
||||
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT)
|
||||
.count()
|
||||
)
|
||||
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
document_dict = {
|
||||
"id": document.id,
|
||||
"indexing_status": "paused" if document.is_paused else document.indexing_status,
|
||||
"indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status,
|
||||
"processing_started_at": document.processing_started_at,
|
||||
"parsing_completed_at": document.parsing_completed_at,
|
||||
"cleaning_completed_at": document.cleaning_completed_at,
|
||||
|
|
@ -955,7 +961,7 @@ class DocumentProcessingApi(DocumentResource):
|
|||
|
||||
match action:
|
||||
case "pause":
|
||||
if document.indexing_status != "indexing":
|
||||
if document.indexing_status != IndexingStatus.INDEXING:
|
||||
raise InvalidActionError("Document not in indexing state.")
|
||||
|
||||
document.paused_by = current_user.id
|
||||
|
|
@ -964,7 +970,7 @@ class DocumentProcessingApi(DocumentResource):
|
|||
db.session.commit()
|
||||
|
||||
case "resume":
|
||||
if document.indexing_status not in {"paused", "error"}:
|
||||
if document.indexing_status not in {IndexingStatus.PAUSED, IndexingStatus.ERROR}:
|
||||
raise InvalidActionError("Document not in paused or error state.")
|
||||
|
||||
document.paused_by = None
|
||||
|
|
@ -1169,7 +1175,7 @@ class DocumentRetryApi(DocumentResource):
|
|||
raise ArchivedDocumentImmutableError()
|
||||
|
||||
# 400 if document is completed
|
||||
if document.indexing_status == "completed":
|
||||
if document.indexing_status == IndexingStatus.COMPLETED:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
retry_documents.append(document)
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -46,6 +46,8 @@ class PipelineTemplateDetailApi(Resource):
|
|||
type = request.args.get("type", default="built-in", type=str)
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
|
||||
if pipeline_template is None:
|
||||
return {"error": "Pipeline template not found from upstream service."}, 404
|
||||
return pipeline_template, 200
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ from extensions.ext_database import db
|
|||
from fields.document_fields import document_fields, document_status_fields
|
||||
from libs.login import current_user
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import SegmentStatus
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
KnowledgeConfig,
|
||||
|
|
@ -622,13 +623,15 @@ class DocumentIndexingStatusApi(DatasetApiResource):
|
|||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
|
||||
)
|
||||
.count()
|
||||
)
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
|
|
|
|||
|
|
@ -70,7 +70,14 @@ def handle_webhook(webhook_id: str):
|
|||
|
||||
@bp.route("/webhook-debug/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||
def handle_webhook_debug(webhook_id: str):
|
||||
"""Handle webhook debug calls without triggering production workflow execution."""
|
||||
"""Handle webhook debug calls without triggering production workflow execution.
|
||||
|
||||
The debug webhook endpoint is only for draft inspection flows. It never enqueues
|
||||
Celery work for the published workflow; instead it dispatches an in-memory debug
|
||||
event to an active Variable Inspector listener. Returning a clear error when no
|
||||
listener is registered prevents a misleading 200 response for requests that are
|
||||
effectively dropped.
|
||||
"""
|
||||
try:
|
||||
webhook_trigger, _, node_config, webhook_data, error = _prepare_webhook_execution(webhook_id, is_debug=True)
|
||||
if error:
|
||||
|
|
@ -94,11 +101,32 @@ def handle_webhook_debug(webhook_id: str):
|
|||
"method": webhook_data.get("method"),
|
||||
},
|
||||
)
|
||||
TriggerDebugEventBus.dispatch(
|
||||
dispatch_count = TriggerDebugEventBus.dispatch(
|
||||
tenant_id=webhook_trigger.tenant_id,
|
||||
event=event,
|
||||
pool_key=pool_key,
|
||||
)
|
||||
if dispatch_count == 0:
|
||||
logger.warning(
|
||||
"Webhook debug request dropped without an active listener for webhook %s (tenant=%s, app=%s, node=%s)",
|
||||
webhook_trigger.webhook_id,
|
||||
webhook_trigger.tenant_id,
|
||||
webhook_trigger.app_id,
|
||||
webhook_trigger.node_id,
|
||||
)
|
||||
return (
|
||||
jsonify(
|
||||
{
|
||||
"error": "No active debug listener",
|
||||
"message": (
|
||||
"The webhook debug URL only works while the Variable Inspector is listening. "
|
||||
"Use the published webhook URL to execute the workflow in Celery."
|
||||
),
|
||||
"execution_url": webhook_trigger.webhook_url,
|
||||
}
|
||||
),
|
||||
409,
|
||||
)
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
return jsonify(response_data), status_code
|
||||
|
||||
|
|
|
|||
|
|
@ -441,7 +441,7 @@ class BaseAgentRunner(AppRunner):
|
|||
continue
|
||||
|
||||
result.append(self.organize_agent_user_prompt(message))
|
||||
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
|
||||
agent_thoughts = message.agent_thoughts
|
||||
if agent_thoughts:
|
||||
for agent_thought in agent_thoughts:
|
||||
tool_names_raw = agent_thought.tool
|
||||
|
|
|
|||
|
|
@ -1,13 +1,36 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
|
||||
|
||||
class SystemParametersDict(TypedDict):
|
||||
image_file_size_limit: int
|
||||
video_file_size_limit: int
|
||||
audio_file_size_limit: int
|
||||
file_size_limit: int
|
||||
workflow_file_upload_limit: int
|
||||
|
||||
|
||||
class AppParametersDict(TypedDict):
|
||||
opening_statement: str | None
|
||||
suggested_questions: list[str]
|
||||
suggested_questions_after_answer: dict[str, Any]
|
||||
speech_to_text: dict[str, Any]
|
||||
text_to_speech: dict[str, Any]
|
||||
retriever_resource: dict[str, Any]
|
||||
annotation_reply: dict[str, Any]
|
||||
more_like_this: dict[str, Any]
|
||||
user_input_form: list[dict[str, Any]]
|
||||
sensitive_word_avoidance: dict[str, Any]
|
||||
file_upload: dict[str, Any]
|
||||
system_parameters: SystemParametersDict
|
||||
|
||||
|
||||
def get_parameters_from_feature_dict(
|
||||
*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]
|
||||
) -> Mapping[str, Any]:
|
||||
) -> AppParametersDict:
|
||||
"""
|
||||
Mapping from feature dict to webapp parameters
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from core.app.app_config.entities import (
|
|||
ModelConfig,
|
||||
)
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||
from models.model import AppMode, AppModelConfigDict
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
|
|
@ -117,8 +118,10 @@ class DatasetConfigManager:
|
|||
score_threshold=float(score_threshold_val)
|
||||
if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None
|
||||
else None,
|
||||
reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None,
|
||||
weights=weights_val if isinstance(weights_val, dict) else None,
|
||||
reranking_model=cast(RerankingModelDict, reranking_model_val)
|
||||
if isinstance(reranking_model_val, dict)
|
||||
else None,
|
||||
weights=cast(WeightsDict, weights_val) if isinstance(weights_val, dict) else None,
|
||||
reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)),
|
||||
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
|
||||
metadata_filtering_mode=cast(
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from typing import Any, Literal
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||
from dify_graph.file import FileUploadConfig
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMMode
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
|
||||
|
|
@ -194,8 +195,8 @@ class DatasetRetrieveConfigEntity(BaseModel):
|
|||
top_k: int | None = None
|
||||
score_threshold: float | None = 0.0
|
||||
rerank_mode: str | None = "reranking_model"
|
||||
reranking_model: dict | None = None
|
||||
weights: dict | None = None
|
||||
reranking_model: RerankingModelDict | None = None
|
||||
weights: WeightsDict | None = None
|
||||
reranking_enabled: bool | None = True
|
||||
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
|
||||
metadata_model_config: ModelConfig | None = None
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import time
|
|||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, NewType, Union
|
||||
from typing import Any, NewType, TypedDict, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -76,6 +76,20 @@ NodeExecutionId = NewType("NodeExecutionId", str)
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AccountCreatedByDict(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
|
||||
|
||||
class EndUserCreatedByDict(TypedDict):
|
||||
id: str
|
||||
user: str
|
||||
|
||||
|
||||
CreatedByDict = AccountCreatedByDict | EndUserCreatedByDict
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _NodeSnapshot:
|
||||
"""In-memory cache for node metadata between start and completion events."""
|
||||
|
|
@ -249,19 +263,19 @@ class WorkflowResponseConverter:
|
|||
outputs_mapping = graph_runtime_state.outputs or {}
|
||||
encoded_outputs = WorkflowRuntimeTypeConverter().to_json_encodable(outputs_mapping)
|
||||
|
||||
created_by: Mapping[str, object] | None
|
||||
created_by: CreatedByDict | dict[str, object] = {}
|
||||
user = self._user
|
||||
if isinstance(user, Account):
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
}
|
||||
else:
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"user": user.session_id,
|
||||
}
|
||||
created_by = AccountCreatedByDict(
|
||||
id=user.id,
|
||||
name=user.name,
|
||||
email=user.email,
|
||||
)
|
||||
elif isinstance(user, EndUser):
|
||||
created_by = EndUserCreatedByDict(
|
||||
id=user.id,
|
||||
user=user.session_id,
|
||||
)
|
||||
|
||||
return WorkflowFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.enums import CollectionBindingType
|
||||
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
|
@ -43,7 +44,7 @@ class AnnotationReplyFeature:
|
|||
embedding_model_name = collection_binding_detail.model_name
|
||||
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_provider_name, embedding_model_name, "annotation"
|
||||
embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION
|
||||
)
|
||||
|
||||
dataset = Dataset(
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing import TypedDict
|
||||
|
||||
from core.tools.signature import sign_tool_file
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from dify_graph.file.enums import FileTransferMethod
|
||||
|
|
@ -6,7 +8,20 @@ from models.model import MessageFile, UploadFile
|
|||
MAX_TOOL_FILE_EXTENSION_LENGTH = 10
|
||||
|
||||
|
||||
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict:
|
||||
class MessageFileInfoDict(TypedDict):
|
||||
related_id: str
|
||||
extension: str
|
||||
filename: str
|
||||
size: int
|
||||
mime_type: str
|
||||
transfer_method: str
|
||||
type: str
|
||||
url: str
|
||||
upload_file_id: str
|
||||
remote_url: str | None
|
||||
|
||||
|
||||
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> MessageFileInfoDict:
|
||||
"""
|
||||
Prepare file dictionary for message end stream response.
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from core.rag.models.document import Document
|
|||
from extensions.ext_database import db
|
||||
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, DatasetQuerySource
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -36,7 +36,7 @@ class DatasetIndexToolCallbackHandler:
|
|||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset_id,
|
||||
content=query,
|
||||
source="app",
|
||||
source=DatasetQuerySource.APP,
|
||||
source_app_id=self._app_id,
|
||||
created_by_role=(
|
||||
CreatorUserRole.ACCOUNT
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
|||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.engine import db
|
||||
from models.enums import CredentialSourceType
|
||||
from models.provider import (
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
|
|
@ -546,7 +547,7 @@ class ProviderConfiguration(BaseModel):
|
|||
self._update_load_balancing_configs_with_credential(
|
||||
credential_id=credential_id,
|
||||
credential_record=credential_record,
|
||||
credential_source="provider",
|
||||
credential_source=CredentialSourceType.PROVIDER,
|
||||
session=session,
|
||||
)
|
||||
except Exception:
|
||||
|
|
@ -623,7 +624,7 @@ class ProviderConfiguration(BaseModel):
|
|||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == "provider",
|
||||
LoadBalancingModelConfig.credential_source_type == CredentialSourceType.PROVIDER,
|
||||
)
|
||||
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
|
||||
try:
|
||||
|
|
@ -1043,7 +1044,7 @@ class ProviderConfiguration(BaseModel):
|
|||
self._update_load_balancing_configs_with_credential(
|
||||
credential_id=credential_id,
|
||||
credential_record=credential_record,
|
||||
credential_source="custom_model",
|
||||
credential_source=CredentialSourceType.CUSTOM_MODEL,
|
||||
session=session,
|
||||
)
|
||||
except Exception:
|
||||
|
|
@ -1073,7 +1074,7 @@ class ProviderConfiguration(BaseModel):
|
|||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == "custom_model",
|
||||
LoadBalancingModelConfig.credential_source_type == CredentialSourceType.CUSTOM_MODEL,
|
||||
)
|
||||
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
|
||||
|
||||
|
|
@ -1711,7 +1712,7 @@ class ProviderConfiguration(BaseModel):
|
|||
provider_model_lb_configs = [
|
||||
config
|
||||
for config in model_setting.load_balancing_configs
|
||||
if config.credential_source_type != "custom_model"
|
||||
if config.credential_source_type != CredentialSourceType.CUSTOM_MODEL
|
||||
]
|
||||
|
||||
load_balancing_enabled = model_setting.load_balancing_enabled
|
||||
|
|
@ -1769,7 +1770,7 @@ class ProviderConfiguration(BaseModel):
|
|||
custom_model_lb_configs = [
|
||||
config
|
||||
for config in model_setting.load_balancing_configs
|
||||
if config.credential_source_type != "provider"
|
||||
if config.credential_source_type != CredentialSourceType.PROVIDER
|
||||
]
|
||||
|
||||
load_balancing_enabled = model_setting.load_balancing_enabled
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ from libs.datetime_utils import naive_utc_now
|
|||
from models import Account
|
||||
from models.dataset import AutomaticRulesConfig, ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import DataSourceType, IndexingStatus, ProcessRuleMode, SegmentStatus
|
||||
from models.model import UploadFile
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
|
@ -56,7 +57,7 @@ class IndexingRunner:
|
|||
logger.exception("consume document failed")
|
||||
document = db.session.get(DatasetDocument, document_id)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
error_message = getattr(error, "description", str(error))
|
||||
document.error = str(error_message)
|
||||
document.stopped_at = naive_utc_now()
|
||||
|
|
@ -219,7 +220,7 @@ class IndexingRunner:
|
|||
if document_segments:
|
||||
for document_segment in document_segments:
|
||||
# transform segment to node
|
||||
if document_segment.status != "completed":
|
||||
if document_segment.status != SegmentStatus.COMPLETED:
|
||||
document = Document(
|
||||
page_content=document_segment.content,
|
||||
metadata={
|
||||
|
|
@ -382,7 +383,7 @@ class IndexingRunner:
|
|||
data_source_info = dataset_document.data_source_info_dict
|
||||
text_docs = []
|
||||
match dataset_document.data_source_type:
|
||||
case "upload_file":
|
||||
case DataSourceType.UPLOAD_FILE:
|
||||
if not data_source_info or "upload_file_id" not in data_source_info:
|
||||
raise ValueError("no upload file found")
|
||||
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
|
||||
|
|
@ -395,7 +396,7 @@ class IndexingRunner:
|
|||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
case "notion_import":
|
||||
case DataSourceType.NOTION_IMPORT:
|
||||
if (
|
||||
not data_source_info
|
||||
or "notion_workspace_id" not in data_source_info
|
||||
|
|
@ -417,7 +418,7 @@ class IndexingRunner:
|
|||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
case "website_crawl":
|
||||
case DataSourceType.WEBSITE_CRAWL:
|
||||
if (
|
||||
not data_source_info
|
||||
or "provider" not in data_source_info
|
||||
|
|
@ -445,7 +446,7 @@ class IndexingRunner:
|
|||
# update document status to splitting
|
||||
self._update_document_index_status(
|
||||
document_id=dataset_document.id,
|
||||
after_indexing_status="splitting",
|
||||
after_indexing_status=IndexingStatus.SPLITTING,
|
||||
extra_update_params={
|
||||
DatasetDocument.parsing_completed_at: naive_utc_now(),
|
||||
},
|
||||
|
|
@ -545,7 +546,7 @@ class IndexingRunner:
|
|||
Clean the document text according to the processing rules.
|
||||
"""
|
||||
rules: AutomaticRulesConfig | dict[str, Any]
|
||||
if processing_rule.mode == "automatic":
|
||||
if processing_rule.mode == ProcessRuleMode.AUTOMATIC:
|
||||
rules = DatasetProcessRule.AUTOMATIC_RULES
|
||||
else:
|
||||
rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
|
||||
|
|
@ -636,7 +637,7 @@ class IndexingRunner:
|
|||
# update document status to completed
|
||||
self._update_document_index_status(
|
||||
document_id=dataset_document.id,
|
||||
after_indexing_status="completed",
|
||||
after_indexing_status=IndexingStatus.COMPLETED,
|
||||
extra_update_params={
|
||||
DatasetDocument.tokens: tokens,
|
||||
DatasetDocument.completed_at: naive_utc_now(),
|
||||
|
|
@ -659,10 +660,10 @@ class IndexingRunner:
|
|||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.index_node_id.in_(document_ids),
|
||||
DocumentSegment.status == "indexing",
|
||||
DocumentSegment.status == SegmentStatus.INDEXING,
|
||||
).update(
|
||||
{
|
||||
DocumentSegment.status: "completed",
|
||||
DocumentSegment.status: SegmentStatus.COMPLETED,
|
||||
DocumentSegment.enabled: True,
|
||||
DocumentSegment.completed_at: naive_utc_now(),
|
||||
}
|
||||
|
|
@ -703,10 +704,10 @@ class IndexingRunner:
|
|||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(document_ids),
|
||||
DocumentSegment.status == "indexing",
|
||||
DocumentSegment.status == SegmentStatus.INDEXING,
|
||||
).update(
|
||||
{
|
||||
DocumentSegment.status: "completed",
|
||||
DocumentSegment.status: SegmentStatus.COMPLETED,
|
||||
DocumentSegment.enabled: True,
|
||||
DocumentSegment.completed_at: naive_utc_now(),
|
||||
}
|
||||
|
|
@ -725,7 +726,7 @@ class IndexingRunner:
|
|||
|
||||
@staticmethod
|
||||
def _update_document_index_status(
|
||||
document_id: str, after_indexing_status: str, extra_update_params: dict | None = None
|
||||
document_id: str, after_indexing_status: IndexingStatus, extra_update_params: dict | None = None
|
||||
):
|
||||
"""
|
||||
Update the document indexing status.
|
||||
|
|
@ -803,7 +804,7 @@ class IndexingRunner:
|
|||
cur_time = naive_utc_now()
|
||||
self._update_document_index_status(
|
||||
document_id=dataset_document.id,
|
||||
after_indexing_status="indexing",
|
||||
after_indexing_status=IndexingStatus.INDEXING,
|
||||
extra_update_params={
|
||||
DatasetDocument.cleaning_completed_at: cur_time,
|
||||
DatasetDocument.splitting_completed_at: cur_time,
|
||||
|
|
@ -815,7 +816,7 @@ class IndexingRunner:
|
|||
self._update_segments_by_document(
|
||||
dataset_document_id=dataset_document.id,
|
||||
update_params={
|
||||
DocumentSegment.status: "indexing",
|
||||
DocumentSegment.status: SegmentStatus.INDEXING,
|
||||
DocumentSegment.indexing_at: naive_utc_now(),
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing_extensions import TypedDict
|
||||
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.rag.data_post_processor.reorder import ReorderRunner
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
|
|
@ -10,6 +12,26 @@ from dify_graph.model_runtime.entities.model_entities import ModelType
|
|||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
|
||||
class RerankingModelDict(TypedDict):
|
||||
reranking_provider_name: str
|
||||
reranking_model_name: str
|
||||
|
||||
|
||||
class VectorSettingDict(TypedDict):
|
||||
vector_weight: float
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class KeywordSettingDict(TypedDict):
|
||||
keyword_weight: float
|
||||
|
||||
|
||||
class WeightsDict(TypedDict):
|
||||
vector_setting: VectorSettingDict
|
||||
keyword_setting: KeywordSettingDict
|
||||
|
||||
|
||||
class DataPostProcessor:
|
||||
"""Interface for data post-processing document."""
|
||||
|
||||
|
|
@ -17,8 +39,8 @@ class DataPostProcessor:
|
|||
self,
|
||||
tenant_id: str,
|
||||
reranking_mode: str,
|
||||
reranking_model: dict | None = None,
|
||||
weights: dict | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
reorder_enabled: bool = False,
|
||||
):
|
||||
self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights)
|
||||
|
|
@ -45,8 +67,8 @@ class DataPostProcessor:
|
|||
self,
|
||||
reranking_mode: str,
|
||||
tenant_id: str,
|
||||
reranking_model: dict | None = None,
|
||||
weights: dict | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
) -> BaseRerankRunner | None:
|
||||
if reranking_mode == RerankMode.WEIGHTED_SCORE and weights:
|
||||
runner = RerankRunnerFactory.create_rerank_runner(
|
||||
|
|
@ -79,12 +101,14 @@ class DataPostProcessor:
|
|||
return ReorderRunner()
|
||||
return None
|
||||
|
||||
def _get_rerank_model_instance(self, tenant_id: str, reranking_model: dict | None) -> ModelInstance | None:
|
||||
def _get_rerank_model_instance(
|
||||
self, tenant_id: str, reranking_model: RerankingModelDict | None
|
||||
) -> ModelInstance | None:
|
||||
if reranking_model:
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
reranking_provider_name = reranking_model.get("reranking_provider_name")
|
||||
reranking_model_name = reranking_model.get("reranking_model_name")
|
||||
reranking_provider_name = reranking_model["reranking_provider_name"]
|
||||
reranking_model_name = reranking_model["reranking_model_name"]
|
||||
if not reranking_provider_name or not reranking_model_name:
|
||||
return None
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
|
|
|
|||
|
|
@ -1,19 +1,20 @@
|
|||
import concurrent.futures
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
from typing import Any, NotRequired
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, load_only
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
|
||||
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
|
||||
from core.rag.entities.metadata_entities import MetadataCondition
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
|
|
@ -35,7 +36,46 @@ from models.dataset import Document as DatasetDocument
|
|||
from models.model import UploadFile
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model = {
|
||||
|
||||
class SegmentAttachmentResult(TypedDict):
|
||||
attachment_info: AttachmentInfoDict
|
||||
segment_id: str
|
||||
|
||||
|
||||
class SegmentAttachmentInfoResult(TypedDict):
|
||||
attachment_id: str
|
||||
attachment_info: AttachmentInfoDict
|
||||
segment_id: str
|
||||
|
||||
|
||||
class ChildChunkDetail(TypedDict):
|
||||
id: str
|
||||
content: str
|
||||
position: int
|
||||
score: float
|
||||
|
||||
|
||||
class SegmentChildMapDetail(TypedDict):
|
||||
max_score: float
|
||||
child_chunks: list[ChildChunkDetail]
|
||||
|
||||
|
||||
class SegmentRecord(TypedDict):
|
||||
segment: DocumentSegment
|
||||
score: NotRequired[float]
|
||||
child_chunks: NotRequired[list[ChildChunkDetail]]
|
||||
files: NotRequired[list[AttachmentInfoDict]]
|
||||
|
||||
|
||||
class DefaultRetrievalModelDict(TypedDict):
|
||||
search_method: RetrievalMethod | str
|
||||
reranking_enable: bool
|
||||
reranking_model: RerankingModelDict
|
||||
top_k: int
|
||||
score_threshold_enabled: bool
|
||||
|
||||
|
||||
default_retrieval_model: DefaultRetrievalModelDict = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
|
|
@ -56,9 +96,9 @@ class RetrievalService:
|
|||
query: str,
|
||||
top_k: int = 4,
|
||||
score_threshold: float | None = 0.0,
|
||||
reranking_model: dict | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
reranking_mode: str = "reranking_model",
|
||||
weights: dict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
attachment_ids: list | None = None,
|
||||
):
|
||||
|
|
@ -235,7 +275,7 @@ class RetrievalService:
|
|||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: float | None,
|
||||
reranking_model: dict | None,
|
||||
reranking_model: RerankingModelDict | None,
|
||||
all_documents: list,
|
||||
retrieval_method: RetrievalMethod,
|
||||
exceptions: list,
|
||||
|
|
@ -277,8 +317,8 @@ class RetrievalService:
|
|||
if documents:
|
||||
if (
|
||||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and reranking_model["reranking_model_name"]
|
||||
and reranking_model["reranking_provider_name"]
|
||||
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
|
|
@ -288,8 +328,8 @@ class RetrievalService:
|
|||
model_manager = ModelManager()
|
||||
is_support_vision = model_manager.check_model_support_vision(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=reranking_model.get("reranking_provider_name") or "",
|
||||
model=reranking_model.get("reranking_model_name") or "",
|
||||
provider=reranking_model["reranking_provider_name"],
|
||||
model=reranking_model["reranking_model_name"],
|
||||
model_type=ModelType.RERANK,
|
||||
)
|
||||
if is_support_vision:
|
||||
|
|
@ -329,7 +369,7 @@ class RetrievalService:
|
|||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: float | None,
|
||||
reranking_model: dict | None,
|
||||
reranking_model: RerankingModelDict | None,
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
exceptions: list,
|
||||
|
|
@ -349,8 +389,8 @@ class RetrievalService:
|
|||
if documents:
|
||||
if (
|
||||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and reranking_model["reranking_model_name"]
|
||||
and reranking_model["reranking_provider_name"]
|
||||
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
|
|
@ -459,7 +499,7 @@ class RetrievalService:
|
|||
segment_ids: list[str] = []
|
||||
index_node_segments: list[DocumentSegment] = []
|
||||
segments: list[DocumentSegment] = []
|
||||
attachment_map: dict[str, list[dict[str, Any]]] = {}
|
||||
attachment_map: dict[str, list[AttachmentInfoDict]] = {}
|
||||
child_chunk_map: dict[str, list[ChildChunk]] = {}
|
||||
doc_segment_map: dict[str, list[str]] = {}
|
||||
segment_summary_map: dict[str, str] = {} # Map segment_id to summary content
|
||||
|
|
@ -544,12 +584,12 @@ class RetrievalService:
|
|||
segment_summary_map[summary.chunk_id] = summary.summary_content
|
||||
|
||||
include_segment_ids = set()
|
||||
segment_child_map: dict[str, dict[str, Any]] = {}
|
||||
records: list[dict[str, Any]] = []
|
||||
segment_child_map: dict[str, SegmentChildMapDetail] = {}
|
||||
records: list[SegmentRecord] = []
|
||||
|
||||
for segment in segments:
|
||||
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
|
||||
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
|
||||
attachment_infos: list[AttachmentInfoDict] = attachment_map.get(segment.id, [])
|
||||
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
|
||||
|
||||
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
|
|
@ -560,14 +600,14 @@ class RetrievalService:
|
|||
max_score = summary_score_map.get(segment.id, 0.0)
|
||||
|
||||
if child_chunks or attachment_infos:
|
||||
child_chunk_details = []
|
||||
child_chunk_details: list[ChildChunkDetail] = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id)
|
||||
if child_document:
|
||||
child_score = child_document.metadata.get("score", 0.0)
|
||||
else:
|
||||
child_score = 0.0
|
||||
child_chunk_detail = {
|
||||
child_chunk_detail: ChildChunkDetail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
|
|
@ -580,7 +620,7 @@ class RetrievalService:
|
|||
if file_document:
|
||||
max_score = max(max_score, file_document.metadata.get("score", 0.0))
|
||||
|
||||
map_detail = {
|
||||
map_detail: SegmentChildMapDetail = {
|
||||
"max_score": max_score,
|
||||
"child_chunks": child_chunk_details,
|
||||
}
|
||||
|
|
@ -593,7 +633,7 @@ class RetrievalService:
|
|||
"max_score": summary_score,
|
||||
"child_chunks": [],
|
||||
}
|
||||
record: dict[str, Any] = {
|
||||
record: SegmentRecord = {
|
||||
"segment": segment,
|
||||
}
|
||||
records.append(record)
|
||||
|
|
@ -617,19 +657,19 @@ class RetrievalService:
|
|||
if file_doc:
|
||||
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
|
||||
|
||||
record = {
|
||||
another_record: SegmentRecord = {
|
||||
"segment": segment,
|
||||
"score": max_score,
|
||||
}
|
||||
records.append(record)
|
||||
records.append(another_record)
|
||||
|
||||
# Add child chunks information to records
|
||||
for record in records:
|
||||
if record["segment"].id in segment_child_map:
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id]["child_chunks"]
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"]
|
||||
if record["segment"].id in attachment_map:
|
||||
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
|
||||
record["files"] = attachment_map[record["segment"].id]
|
||||
|
||||
result: list[RetrievalSegments] = []
|
||||
for record in records:
|
||||
|
|
@ -693,9 +733,9 @@ class RetrievalService:
|
|||
query: str | None = None,
|
||||
top_k: int = 4,
|
||||
score_threshold: float | None = 0.0,
|
||||
reranking_model: dict | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
reranking_mode: str = "reranking_model",
|
||||
weights: dict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
attachment_id: str | None = None,
|
||||
):
|
||||
|
|
@ -807,7 +847,7 @@ class RetrievalService:
|
|||
@classmethod
|
||||
def get_segment_attachment_info(
|
||||
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
|
||||
) -> dict[str, Any] | None:
|
||||
) -> SegmentAttachmentResult | None:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
|
||||
if upload_file:
|
||||
attachment_binding = (
|
||||
|
|
@ -816,7 +856,7 @@ class RetrievalService:
|
|||
.first()
|
||||
)
|
||||
if attachment_binding:
|
||||
attachment_info = {
|
||||
attachment_info: AttachmentInfoDict = {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"extension": "." + upload_file.extension,
|
||||
|
|
@ -828,8 +868,10 @@ class RetrievalService:
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
|
||||
attachment_infos = []
|
||||
def get_segment_attachment_infos(
|
||||
cls, attachment_ids: list[str], session: Session
|
||||
) -> list[SegmentAttachmentInfoResult]:
|
||||
attachment_infos: list[SegmentAttachmentInfoResult] = []
|
||||
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
|
||||
if upload_files:
|
||||
upload_file_ids = [upload_file.id for upload_file in upload_files]
|
||||
|
|
@ -843,7 +885,7 @@ class RetrievalService:
|
|||
if attachment_bindings:
|
||||
for upload_file in upload_files:
|
||||
attachment_binding = attachment_binding_map.get(upload_file.id)
|
||||
attachment_info = {
|
||||
info: AttachmentInfoDict = {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"extension": "." + upload_file.extension,
|
||||
|
|
@ -855,7 +897,7 @@ class RetrievalService:
|
|||
attachment_infos.append(
|
||||
{
|
||||
"attachment_id": attachment_binding.attachment_id,
|
||||
"attachment_info": attachment_info,
|
||||
"attachment_info": info,
|
||||
"segment_id": attachment_binding.segment_id,
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ This module provides integration with Weaviate vector database for storing and r
|
|||
document embeddings used in retrieval-augmented generation workflows.
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
|
|
@ -37,6 +38,32 @@ _weaviate_client: weaviate.WeaviateClient | None = None
|
|||
_weaviate_client_lock = threading.Lock()
|
||||
|
||||
|
||||
def _shutdown_weaviate_client() -> None:
|
||||
"""
|
||||
Best-effort shutdown hook to close the module-level Weaviate client.
|
||||
|
||||
This is registered with atexit so that HTTP/gRPC resources are released
|
||||
when the Python interpreter exits.
|
||||
"""
|
||||
global _weaviate_client
|
||||
|
||||
# Ensure thread-safety when accessing the shared client instance
|
||||
with _weaviate_client_lock:
|
||||
client = _weaviate_client
|
||||
_weaviate_client = None
|
||||
|
||||
if client is not None:
|
||||
try:
|
||||
client.close()
|
||||
except Exception:
|
||||
# Best-effort cleanup; log at debug level and ignore errors.
|
||||
logger.debug("Failed to close Weaviate client during shutdown", exc_info=True)
|
||||
|
||||
|
||||
# Register the shutdown hook once per process.
|
||||
atexit.register(_shutdown_weaviate_client)
|
||||
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
"""
|
||||
Configuration model for Weaviate connection settings.
|
||||
|
|
@ -85,18 +112,6 @@ class WeaviateVector(BaseVector):
|
|||
self._client = self._init_client(config)
|
||||
self._attributes = attributes
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Destructor to properly close the Weaviate client connection.
|
||||
Prevents connection leaks and resource warnings.
|
||||
"""
|
||||
if hasattr(self, "_client") and self._client is not None:
|
||||
try:
|
||||
self._client.close()
|
||||
except Exception as e:
|
||||
# Ignore errors during cleanup as object is being destroyed
|
||||
logger.warning("Error closing Weaviate client %s", e, exc_info=True)
|
||||
|
||||
def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient:
|
||||
"""
|
||||
Initializes and returns a connected Weaviate client.
|
||||
|
|
|
|||
|
|
@ -1,8 +1,18 @@
|
|||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from models.dataset import DocumentSegment
|
||||
|
||||
|
||||
class AttachmentInfoDict(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
extension: str
|
||||
mime_type: str
|
||||
source_url: str
|
||||
size: int
|
||||
|
||||
|
||||
class RetrievalChildChunk(BaseModel):
|
||||
"""Retrieval segments."""
|
||||
|
||||
|
|
@ -19,5 +29,5 @@ class RetrievalSegments(BaseModel):
|
|||
segment: DocumentSegment
|
||||
child_chunks: list[RetrievalChildChunk] | None = None
|
||||
score: float | None = None
|
||||
files: list[dict[str, str | int]] | None = None
|
||||
files: list[AttachmentInfoDict] | None = None
|
||||
summary: str | None = None # Summary content if retrieved via summary index
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from flask import current_app
|
|||
from sqlalchemy import delete, func, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
|
||||
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
|
@ -51,7 +52,7 @@ class IndexProcessor:
|
|||
original_document_id: str,
|
||||
chunks: Mapping[str, Any],
|
||||
batch: Any,
|
||||
summary_index_setting: dict | None = None,
|
||||
summary_index_setting: SummaryIndexSettingDict | None = None,
|
||||
):
|
||||
with session_factory.create_session() as session:
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
|
|
@ -131,7 +132,12 @@ class IndexProcessor:
|
|||
}
|
||||
|
||||
def get_preview_output(
|
||||
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
|
||||
self,
|
||||
chunks: Any,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
chunk_structure: str,
|
||||
summary_index_setting: SummaryIndexSettingDict | None,
|
||||
) -> Preview:
|
||||
doc_language = None
|
||||
with session_factory.create_session() as session:
|
||||
|
|
|
|||
|
|
@ -7,14 +7,16 @@ import os
|
|||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, NotRequired, Optional
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
import httpx
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.helper import ssrf_proxy
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.models.document import AttachmentDocument, Document
|
||||
|
|
@ -35,6 +37,13 @@ if TYPE_CHECKING:
|
|||
from core.model_manager import ModelInstance
|
||||
|
||||
|
||||
class SummaryIndexSettingDict(TypedDict):
|
||||
enable: bool
|
||||
model_name: NotRequired[str]
|
||||
model_provider_name: NotRequired[str]
|
||||
summary_prompt: NotRequired[str]
|
||||
|
||||
|
||||
class BaseIndexProcessor(ABC):
|
||||
"""Interface for extract files."""
|
||||
|
||||
|
|
@ -51,7 +60,7 @@ class BaseIndexProcessor(ABC):
|
|||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
|
|
@ -98,7 +107,7 @@ class BaseIndexProcessor(ABC):
|
|||
dataset: Dataset,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_model: dict,
|
||||
reranking_model: RerankingModelDict,
|
||||
) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
|
|||
from core.model_manager import ModelInstance
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
|
|
@ -22,7 +23,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
|
|||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
|
||||
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
|
|
@ -175,7 +176,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||
dataset: Dataset,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_model: dict,
|
||||
reranking_model: RerankingModelDict,
|
||||
) -> list[Document]:
|
||||
# Set search parameters.
|
||||
results = RetrievalService.retrieve(
|
||||
|
|
@ -278,7 +279,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
|
|
@ -362,7 +363,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||
def generate_summary(
|
||||
tenant_id: str,
|
||||
text: str,
|
||||
summary_index_setting: dict | None = None,
|
||||
summary_index_setting: SummaryIndexSettingDict | None = None,
|
||||
segment_id: str | None = None,
|
||||
document_language: str | None = None,
|
||||
) -> tuple[str, LLMUsage]:
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from core.db.session_factory import session_factory
|
|||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
|
|
@ -18,7 +19,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
|
|||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
|
||||
from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -215,7 +216,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||
dataset: Dataset,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_model: dict,
|
||||
reranking_model: RerankingModelDict,
|
||||
) -> list[Document]:
|
||||
# Set search parameters.
|
||||
results = RetrievalService.retrieve(
|
||||
|
|
@ -361,7 +362,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -15,13 +15,14 @@ from core.db.session_factory import session_factory
|
|||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
|
||||
from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
|
|
@ -185,7 +186,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
|||
dataset: Dataset,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_model: dict,
|
||||
reranking_model: RerankingModelDict,
|
||||
):
|
||||
# Set search parameters.
|
||||
results = RetrievalService.retrieve(
|
||||
|
|
@ -244,7 +245,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
|||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ from core.ops.utils import measure_time
|
|||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
|
|
@ -83,7 +83,7 @@ from models.dataset import (
|
|||
)
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.dataset import Document as DocumentModel
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, DatasetQuerySource
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
|
@ -727,8 +727,8 @@ class DatasetRetrieval:
|
|||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_mode: str,
|
||||
reranking_model: dict | None = None,
|
||||
weights: dict[str, Any] | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
reranking_enable: bool = True,
|
||||
message_id: str | None = None,
|
||||
metadata_filter_document_ids: dict[str, list[str]] | None = None,
|
||||
|
|
@ -1008,7 +1008,7 @@ class DatasetRetrieval:
|
|||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset_id,
|
||||
content=json.dumps(contents),
|
||||
source="app",
|
||||
source=DatasetQuerySource.APP,
|
||||
source_app_id=app_id,
|
||||
created_by_role=CreatorUserRole(user_from),
|
||||
created_by=user_id,
|
||||
|
|
@ -1181,8 +1181,8 @@ class DatasetRetrieval:
|
|||
hit_callbacks=[hit_callback],
|
||||
return_resource=return_resource,
|
||||
retriever_from=invoke_from.to_source(),
|
||||
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
|
||||
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
|
||||
reranking_provider_name=retrieve_config.reranking_model["reranking_provider_name"],
|
||||
reranking_model_name=retrieve_config.reranking_model["reranking_model_name"],
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
|
|
@ -1685,8 +1685,8 @@ class DatasetRetrieval:
|
|||
tenant_id: str,
|
||||
reranking_enable: bool,
|
||||
reranking_mode: str,
|
||||
reranking_model: dict | None,
|
||||
weights: dict[str, Any] | None,
|
||||
reranking_model: RerankingModelDict | None,
|
||||
weights: WeightsDict | None,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
query: str | None,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import concurrent.futures
|
|||
import logging
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
|
|
@ -11,7 +12,11 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class SummaryIndex:
|
||||
def generate_and_vectorize_summary(
|
||||
self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
|
||||
self,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
is_preview: bool,
|
||||
summary_index_setting: SummaryIndexSettingDict | None = None,
|
||||
) -> None:
|
||||
if is_preview:
|
||||
with session_factory.create_session() as session:
|
||||
|
|
|
|||
|
|
@ -72,6 +72,11 @@ class ApiProviderControllerItem(TypedDict):
|
|||
controller: ApiToolProviderController
|
||||
|
||||
|
||||
class EmojiIconDict(TypedDict):
|
||||
background: str
|
||||
content: str
|
||||
|
||||
|
||||
class ToolManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
|
||||
|
|
@ -916,7 +921,7 @@ class ToolManager:
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]:
|
||||
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
|
||||
try:
|
||||
workflow_provider: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
|
|
@ -933,7 +938,7 @@ class ToolManager:
|
|||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@classmethod
|
||||
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]:
|
||||
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
|
||||
try:
|
||||
api_provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
|
|
@ -950,7 +955,7 @@ class ToolManager:
|
|||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@classmethod
|
||||
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str:
|
||||
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str:
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
|
|
@ -970,7 +975,7 @@ class ToolManager:
|
|||
tenant_id: str,
|
||||
provider_type: ToolProviderType,
|
||||
provider_id: str,
|
||||
) -> str | Mapping[str, str]:
|
||||
) -> str | EmojiIconDict | dict[str, str]:
|
||||
"""
|
||||
get the tool icon
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import threading
|
||||
from typing import Any
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import BaseModel, Field
|
||||
|
|
@ -13,11 +12,12 @@ from core.rag.models.document import Document as RagDocument
|
|||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
default_retrieval_model: dict[str, Any] = {
|
||||
default_retrieval_model: DefaultRetrievalModelDict = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Any, cast
|
||||
from typing import NotRequired, TypedDict, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
|
|
@ -16,7 +17,19 @@ from models.dataset import Dataset
|
|||
from models.dataset import Document as DatasetDocument
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model: dict[str, Any] = {
|
||||
|
||||
class DefaultRetrievalModelDict(TypedDict):
|
||||
search_method: RetrievalMethod
|
||||
reranking_enable: bool
|
||||
reranking_model: RerankingModelDict
|
||||
reranking_mode: NotRequired[str]
|
||||
weights: NotRequired[WeightsDict | None]
|
||||
score_threshold: NotRequired[float]
|
||||
top_k: int
|
||||
score_threshold_enabled: bool
|
||||
|
||||
|
||||
default_retrieval_model: DefaultRetrievalModelDict = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
|
|
@ -125,7 +138,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
if metadata_condition and not document_ids_filter:
|
||||
return ""
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
||||
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
||||
retrieval_resource_list: list[RetrievalSourceMetadata] = []
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import re
|
||||
from collections.abc import Mapping
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
from json.decoder import JSONDecodeError
|
||||
|
|
@ -20,10 +21,18 @@ class InterfaceDict(TypedDict):
|
|||
operation: dict[str, Any]
|
||||
|
||||
|
||||
class OpenAPISpecDict(TypedDict):
|
||||
openapi: str
|
||||
info: dict[str, str]
|
||||
servers: list[dict[str, Any]]
|
||||
paths: dict[str, Any]
|
||||
components: dict[str, Any]
|
||||
|
||||
|
||||
class ApiBasedToolSchemaParser:
|
||||
@staticmethod
|
||||
def parse_openapi_to_tool_bundle(
|
||||
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||
openapi: Mapping[str, Any], extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
|
@ -277,7 +286,7 @@ class ApiBasedToolSchemaParser:
|
|||
@staticmethod
|
||||
def parse_swagger_to_openapi(
|
||||
swagger: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> dict[str, Any]:
|
||||
) -> OpenAPISpecDict:
|
||||
warning = warning or {}
|
||||
"""
|
||||
parse swagger to openapi
|
||||
|
|
@ -293,7 +302,7 @@ class ApiBasedToolSchemaParser:
|
|||
if len(servers) == 0:
|
||||
raise ToolApiSchemaError("No server found in the swagger yaml.")
|
||||
|
||||
converted_openapi: dict[str, Any] = {
|
||||
converted_openapi: OpenAPISpecDict = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": info.get("title", "Swagger"),
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from typing import Literal, Union
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
|
|
@ -161,4 +162,4 @@ class KnowledgeIndexNodeData(BaseNodeData):
|
|||
chunk_structure: str
|
||||
index_chunk_variable_selector: list[str]
|
||||
indexing_technique: str | None = None
|
||||
summary_index_setting: dict | None = None
|
||||
summary_index_setting: SummaryIndexSettingDict | None = None
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from collections.abc import Mapping
|
|||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.rag.index_processor.index_processor import IndexProcessor
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from core.rag.summary_index.summary_index import SummaryIndex
|
||||
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
|
|
@ -127,7 +128,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
|||
is_preview: bool,
|
||||
batch: Any,
|
||||
chunks: Mapping[str, Any],
|
||||
summary_index_setting: dict | None = None,
|
||||
summary_index_setting: SummaryIndexSettingDict | None = None,
|
||||
):
|
||||
if not document_id:
|
||||
raise KnowledgeIndexNodeError("document_id is required.")
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from collections.abc import Mapping, Sequence
|
|||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
|
|
@ -201,8 +202,8 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
if node_data.multiple_retrieval_config is None:
|
||||
raise ValueError("multiple_retrieval_config is required")
|
||||
reranking_model = None
|
||||
weights = None
|
||||
reranking_model: RerankingModelDict | None = None
|
||||
weights: WeightsDict | None = None
|
||||
match node_data.multiple_retrieval_config.reranking_mode:
|
||||
case "reranking_model":
|
||||
if node_data.multiple_retrieval_config.reranking_model:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from typing import Any, Literal, Protocol
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||
from dify_graph.model_runtime.entities import LLMUsage
|
||||
from dify_graph.nodes.llm.entities import ModelConfig
|
||||
|
||||
|
|
@ -75,8 +76,8 @@ class KnowledgeRetrievalRequest(BaseModel):
|
|||
top_k: int = Field(default=0, description="Number of top results to return")
|
||||
score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold")
|
||||
reranking_mode: str = Field(default="reranking_model", description="Reranking strategy")
|
||||
reranking_model: dict | None = Field(default=None, description="Reranking model configuration")
|
||||
weights: dict[str, Any] | None = Field(default=None, description="Weights for weighted score reranking")
|
||||
reranking_model: RerankingModelDict | None = Field(default=None, description="Reranking model configuration")
|
||||
weights: WeightsDict | None = Field(default=None, description="Weights for weighted score reranking")
|
||||
reranking_enable: bool = Field(default=True, description="Whether reranking is enabled")
|
||||
attachment_ids: list[str] | None = Field(default=None, description="List of attachment file IDs for retrieval")
|
||||
|
||||
|
|
|
|||
|
|
@ -101,7 +101,6 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
|||
timeout=self._get_request_timeout(self.node_data),
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
http_request_config=self._http_request_config,
|
||||
max_retries=0,
|
||||
ssl_verify=self.node_data.ssl_verify,
|
||||
http_client=self._http_client,
|
||||
file_manager=self._file_manager,
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from events.document_index_event import document_index_created
|
|||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Document
|
||||
from models.enums import IndexingStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -35,7 +36,7 @@ def handle(sender, **kwargs):
|
|||
if not document:
|
||||
raise NotFound("Document not found")
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.indexing_status = IndexingStatus.PARSING
|
||||
document.processing_started_at = naive_utc_now()
|
||||
documents.append(document)
|
||||
db.session.add(document)
|
||||
|
|
|
|||
|
|
@ -177,13 +177,11 @@ class Account(UserMixin, TypeBase):
|
|||
|
||||
@classmethod
|
||||
def get_by_openid(cls, provider: str, open_id: str):
|
||||
account_integrate = (
|
||||
db.session.query(AccountIntegrate)
|
||||
.where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
|
||||
.one_or_none()
|
||||
)
|
||||
account_integrate = db.session.execute(
|
||||
select(AccountIntegrate).where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
|
||||
).scalar_one_or_none()
|
||||
if account_integrate:
|
||||
return db.session.query(Account).where(Account.id == account_integrate.account_id).one_or_none()
|
||||
return db.session.scalar(select(Account).where(Account.id == account_integrate.account_id))
|
||||
return None
|
||||
|
||||
# check current_user.current_tenant.current_role in ['admin', 'owner']
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import os
|
|||
import pickle
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, TypedDict, cast
|
||||
|
|
@ -30,7 +31,20 @@ from services.entities.knowledge_entities.knowledge_entities import ParentMode,
|
|||
from .account import Account
|
||||
from .base import Base, TypeBase
|
||||
from .engine import db
|
||||
from .enums import CreatorUserRole
|
||||
from .enums import (
|
||||
CollectionBindingType,
|
||||
CreatorUserRole,
|
||||
DatasetMetadataType,
|
||||
DatasetQuerySource,
|
||||
DatasetRuntimeMode,
|
||||
DataSourceType,
|
||||
DocumentCreatedFrom,
|
||||
DocumentDocType,
|
||||
IndexingStatus,
|
||||
ProcessRuleMode,
|
||||
SegmentStatus,
|
||||
SummaryStatus,
|
||||
)
|
||||
from .model import App, Tag, TagBinding, UploadFile
|
||||
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
|
||||
|
||||
|
|
@ -120,7 +134,7 @@ class Dataset(Base):
|
|||
server_default=sa.text("'only_me'"),
|
||||
default=DatasetPermissionEnum.ONLY_ME,
|
||||
)
|
||||
data_source_type = mapped_column(String(255))
|
||||
data_source_type = mapped_column(EnumText(DataSourceType, length=255))
|
||||
indexing_technique: Mapped[str | None] = mapped_column(String(255))
|
||||
index_struct = mapped_column(LongText, nullable=True)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
|
|
@ -137,7 +151,9 @@ class Dataset(Base):
|
|||
summary_index_setting = mapped_column(AdjustedJSON, nullable=True)
|
||||
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
icon_info = mapped_column(AdjustedJSON, nullable=True)
|
||||
runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'"))
|
||||
runtime_mode = mapped_column(
|
||||
EnumText(DatasetRuntimeMode, length=255), nullable=True, server_default=sa.text("'general'")
|
||||
)
|
||||
pipeline_id = mapped_column(StringUUID, nullable=True)
|
||||
chunk_structure = mapped_column(sa.String(255), nullable=True)
|
||||
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
|
|
@ -145,30 +161,25 @@ class Dataset(Base):
|
|||
|
||||
@property
|
||||
def total_documents(self):
|
||||
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
|
||||
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
|
||||
|
||||
@property
|
||||
def total_available_documents(self):
|
||||
return (
|
||||
db.session.query(func.count(Document.id))
|
||||
.where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
db.session.scalar(
|
||||
select(func.count(Document.id)).where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
@property
|
||||
def dataset_keyword_table(self):
|
||||
dataset_keyword_table = (
|
||||
db.session.query(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id).first()
|
||||
)
|
||||
if dataset_keyword_table:
|
||||
return dataset_keyword_table
|
||||
|
||||
return None
|
||||
return db.session.scalar(select(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id))
|
||||
|
||||
@property
|
||||
def index_struct_dict(self):
|
||||
|
|
@ -195,64 +206,66 @@ class Dataset(Base):
|
|||
|
||||
@property
|
||||
def latest_process_rule(self):
|
||||
return (
|
||||
db.session.query(DatasetProcessRule)
|
||||
return db.session.scalar(
|
||||
select(DatasetProcessRule)
|
||||
.where(DatasetProcessRule.dataset_id == self.id)
|
||||
.order_by(DatasetProcessRule.created_at.desc())
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
@property
|
||||
def app_count(self):
|
||||
return (
|
||||
db.session.query(func.count(AppDatasetJoin.id))
|
||||
.where(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
|
||||
.scalar()
|
||||
db.session.scalar(
|
||||
select(func.count(AppDatasetJoin.id)).where(
|
||||
AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
@property
|
||||
def document_count(self):
|
||||
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
|
||||
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
|
||||
|
||||
@property
|
||||
def available_document_count(self):
|
||||
return (
|
||||
db.session.query(func.count(Document.id))
|
||||
.where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
db.session.scalar(
|
||||
select(func.count(Document.id)).where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
@property
|
||||
def available_segment_count(self):
|
||||
return (
|
||||
db.session.query(func.count(DocumentSegment.id))
|
||||
.where(
|
||||
DocumentSegment.dataset_id == self.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.dataset_id == self.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
@property
|
||||
def word_count(self):
|
||||
return (
|
||||
db.session.query(Document)
|
||||
.with_entities(func.coalesce(func.sum(Document.word_count), 0))
|
||||
.where(Document.dataset_id == self.id)
|
||||
.scalar()
|
||||
return db.session.scalar(
|
||||
select(func.coalesce(func.sum(Document.word_count), 0)).where(Document.dataset_id == self.id)
|
||||
)
|
||||
|
||||
@property
|
||||
def doc_form(self) -> str | None:
|
||||
if self.chunk_structure:
|
||||
return self.chunk_structure
|
||||
document = db.session.query(Document).where(Document.dataset_id == self.id).first()
|
||||
document = db.session.scalar(select(Document).where(Document.dataset_id == self.id).limit(1))
|
||||
if document:
|
||||
return document.doc_form
|
||||
return None
|
||||
|
|
@ -270,8 +283,8 @@ class Dataset(Base):
|
|||
|
||||
@property
|
||||
def tags(self):
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
tags = db.session.scalars(
|
||||
select(Tag)
|
||||
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
.where(
|
||||
TagBinding.target_id == self.id,
|
||||
|
|
@ -279,8 +292,7 @@ class Dataset(Base):
|
|||
Tag.tenant_id == self.tenant_id,
|
||||
Tag.type == "knowledge",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
return tags or []
|
||||
|
||||
|
|
@ -288,8 +300,8 @@ class Dataset(Base):
|
|||
def external_knowledge_info(self):
|
||||
if self.provider != "external":
|
||||
return None
|
||||
external_knowledge_binding = (
|
||||
db.session.query(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id).first()
|
||||
external_knowledge_binding = db.session.scalar(
|
||||
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id)
|
||||
)
|
||||
if not external_knowledge_binding:
|
||||
return None
|
||||
|
|
@ -310,7 +322,7 @@ class Dataset(Base):
|
|||
@property
|
||||
def is_published(self):
|
||||
if self.pipeline_id:
|
||||
pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first()
|
||||
pipeline = db.session.scalar(select(Pipeline).where(Pipeline.id == self.pipeline_id))
|
||||
if pipeline:
|
||||
return pipeline.is_published
|
||||
return False
|
||||
|
|
@ -382,7 +394,7 @@ class DatasetProcessRule(Base): # bug
|
|||
|
||||
id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
mode = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'"))
|
||||
mode = mapped_column(EnumText(ProcessRuleMode, length=255), nullable=False, server_default=sa.text("'automatic'"))
|
||||
rules = mapped_column(LongText, nullable=True)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
@ -428,12 +440,12 @@ class Document(Base):
|
|||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
data_source_type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
data_source_type: Mapped[str] = mapped_column(EnumText(DataSourceType, length=255), nullable=False)
|
||||
data_source_info = mapped_column(LongText, nullable=True)
|
||||
dataset_process_rule_id = mapped_column(StringUUID, nullable=True)
|
||||
batch: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_from: Mapped[str] = mapped_column(EnumText(DocumentCreatedFrom, length=255), nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_api_request_id = mapped_column(StringUUID, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
@ -467,7 +479,9 @@ class Document(Base):
|
|||
stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# basic fields
|
||||
indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'"))
|
||||
indexing_status = mapped_column(
|
||||
EnumText(IndexingStatus, length=255), nullable=False, server_default=sa.text("'waiting'")
|
||||
)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
disabled_by = mapped_column(StringUUID, nullable=True)
|
||||
|
|
@ -478,7 +492,7 @@ class Document(Base):
|
|||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
doc_type = mapped_column(String(40), nullable=True)
|
||||
doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True)
|
||||
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
|
||||
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
|
||||
doc_language = mapped_column(String(255), nullable=True)
|
||||
|
|
@ -521,10 +535,8 @@ class Document(Base):
|
|||
if self.data_source_info:
|
||||
if self.data_source_type == "upload_file":
|
||||
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
.where(UploadFile.id == data_source_info_dict["upload_file_id"])
|
||||
.one_or_none()
|
||||
file_detail = db.session.scalar(
|
||||
select(UploadFile).where(UploadFile.id == data_source_info_dict["upload_file_id"])
|
||||
)
|
||||
if file_detail:
|
||||
return {
|
||||
|
|
@ -557,24 +569,23 @@ class Document(Base):
|
|||
|
||||
@property
|
||||
def dataset(self):
|
||||
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).one_or_none()
|
||||
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
|
||||
|
||||
@property
|
||||
def segment_count(self):
|
||||
return db.session.query(DocumentSegment).where(DocumentSegment.document_id == self.id).count()
|
||||
return (
|
||||
db.session.scalar(select(func.count(DocumentSegment.id)).where(DocumentSegment.document_id == self.id)) or 0
|
||||
)
|
||||
|
||||
@property
|
||||
def hit_count(self):
|
||||
return (
|
||||
db.session.query(DocumentSegment)
|
||||
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
|
||||
.where(DocumentSegment.document_id == self.id)
|
||||
.scalar()
|
||||
return db.session.scalar(
|
||||
select(func.coalesce(func.sum(DocumentSegment.hit_count), 0)).where(DocumentSegment.document_id == self.id)
|
||||
)
|
||||
|
||||
@property
|
||||
def uploader(self):
|
||||
user = db.session.query(Account).where(Account.id == self.created_by).first()
|
||||
user = db.session.scalar(select(Account).where(Account.id == self.created_by))
|
||||
return user.name if user else None
|
||||
|
||||
@property
|
||||
|
|
@ -588,14 +599,13 @@ class Document(Base):
|
|||
@property
|
||||
def doc_metadata_details(self) -> list[DocMetadataDetailItem] | None:
|
||||
if self.doc_metadata:
|
||||
document_metadatas = (
|
||||
db.session.query(DatasetMetadata)
|
||||
document_metadatas = db.session.scalars(
|
||||
select(DatasetMetadata)
|
||||
.join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
|
||||
.where(
|
||||
DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
metadata_list: list[DocMetadataDetailItem] = []
|
||||
for metadata in document_metadatas:
|
||||
metadata_dict: DocMetadataDetailItem = {
|
||||
|
|
@ -791,7 +801,7 @@ class DocumentSegment(Base):
|
|||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
disabled_by = mapped_column(StringUUID, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'"))
|
||||
status: Mapped[str] = mapped_column(EnumText(SegmentStatus, length=255), server_default=sa.text("'waiting'"))
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
|
|
@ -826,7 +836,7 @@ class DocumentSegment(Base):
|
|||
)
|
||||
|
||||
@property
|
||||
def child_chunks(self) -> list[Any]:
|
||||
def child_chunks(self) -> Sequence[Any]:
|
||||
if not self.document:
|
||||
return []
|
||||
process_rule = self.document.dataset_process_rule
|
||||
|
|
@ -835,16 +845,13 @@ class DocumentSegment(Base):
|
|||
if rules_dict:
|
||||
rules = Rule.model_validate(rules_dict)
|
||||
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(ChildChunk.segment_id == self.id)
|
||||
.order_by(ChildChunk.position.asc())
|
||||
.all()
|
||||
)
|
||||
child_chunks = db.session.scalars(
|
||||
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
|
||||
).all()
|
||||
return child_chunks or []
|
||||
return []
|
||||
|
||||
def get_child_chunks(self) -> list[Any]:
|
||||
def get_child_chunks(self) -> Sequence[Any]:
|
||||
if not self.document:
|
||||
return []
|
||||
process_rule = self.document.dataset_process_rule
|
||||
|
|
@ -853,12 +860,9 @@ class DocumentSegment(Base):
|
|||
if rules_dict:
|
||||
rules = Rule.model_validate(rules_dict)
|
||||
if rules.parent_mode:
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(ChildChunk.segment_id == self.id)
|
||||
.order_by(ChildChunk.position.asc())
|
||||
.all()
|
||||
)
|
||||
child_chunks = db.session.scalars(
|
||||
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
|
||||
).all()
|
||||
return child_chunks or []
|
||||
return []
|
||||
|
||||
|
|
@ -1007,15 +1011,15 @@ class ChildChunk(Base):
|
|||
|
||||
@property
|
||||
def dataset(self):
|
||||
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).first()
|
||||
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
|
||||
|
||||
@property
|
||||
def document(self):
|
||||
return db.session.query(Document).where(Document.id == self.document_id).first()
|
||||
return db.session.scalar(select(Document).where(Document.id == self.document_id))
|
||||
|
||||
@property
|
||||
def segment(self):
|
||||
return db.session.query(DocumentSegment).where(DocumentSegment.id == self.segment_id).first()
|
||||
return db.session.scalar(select(DocumentSegment).where(DocumentSegment.id == self.segment_id))
|
||||
|
||||
|
||||
class AppDatasetJoin(TypeBase):
|
||||
|
|
@ -1061,7 +1065,7 @@ class DatasetQuery(TypeBase):
|
|||
)
|
||||
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
content: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
source: Mapped[str] = mapped_column(EnumText(DatasetQuerySource, length=255), nullable=False)
|
||||
source_app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
|
|
@ -1076,7 +1080,7 @@ class DatasetQuery(TypeBase):
|
|||
if isinstance(queries, list):
|
||||
for query in queries:
|
||||
if query["content_type"] == QueryType.IMAGE_QUERY:
|
||||
file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first()
|
||||
file_info = db.session.scalar(select(UploadFile).where(UploadFile.id == query["content"]))
|
||||
if file_info:
|
||||
query["file_info"] = {
|
||||
"id": file_info.id,
|
||||
|
|
@ -1141,7 +1145,7 @@ class DatasetKeywordTable(TypeBase):
|
|||
super().__init__(object_hook=object_hook, *args, **kwargs)
|
||||
|
||||
# get dataset
|
||||
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
|
||||
dataset = db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
|
||||
if not dataset:
|
||||
return None
|
||||
if self.data_source_type == "database":
|
||||
|
|
@ -1206,7 +1210,9 @@ class DatasetCollectionBinding(TypeBase):
|
|||
)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False)
|
||||
type: Mapped[str] = mapped_column(
|
||||
EnumText(CollectionBindingType, length=40), server_default=sa.text("'dataset'"), nullable=False
|
||||
)
|
||||
collection_name: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
|
|
@ -1433,7 +1439,7 @@ class DatasetMetadata(TypeBase):
|
|||
)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
type: Mapped[str] = mapped_column(EnumText(DatasetMetadataType, length=255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
|
||||
|
|
@ -1535,7 +1541,7 @@ class PipelineCustomizedTemplate(TypeBase):
|
|||
|
||||
@property
|
||||
def created_user_name(self):
|
||||
account = db.session.query(Account).where(Account.id == self.created_by).first()
|
||||
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
|
||||
if account:
|
||||
return account.name
|
||||
return ""
|
||||
|
|
@ -1570,7 +1576,7 @@ class Pipeline(TypeBase):
|
|||
)
|
||||
|
||||
def retrieve_dataset(self, session: Session):
|
||||
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
|
||||
return session.scalar(select(Dataset).where(Dataset.pipeline_id == self.id))
|
||||
|
||||
|
||||
class DocumentPipelineExecutionLog(TypeBase):
|
||||
|
|
@ -1660,7 +1666,9 @@ class DocumentSegmentSummary(Base):
|
|||
summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
|
||||
summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
|
||||
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'"))
|
||||
status: Mapped[str] = mapped_column(
|
||||
EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'")
|
||||
)
|
||||
error: Mapped[str] = mapped_column(LongText, nullable=True)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,13 @@ class CreatorUserRole(StrEnum):
|
|||
ACCOUNT = "account"
|
||||
END_USER = "end_user"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
if value == "end-user":
|
||||
return cls.END_USER
|
||||
else:
|
||||
return super()._missing_(value)
|
||||
|
||||
|
||||
class WorkflowRunTriggeredFrom(StrEnum):
|
||||
DEBUGGING = "debugging"
|
||||
|
|
@ -96,3 +103,216 @@ class ConversationStatus(StrEnum):
|
|||
"""Conversation Status Enum"""
|
||||
|
||||
NORMAL = "normal"
|
||||
|
||||
|
||||
class DataSourceType(StrEnum):
|
||||
"""Data Source Type for Dataset and Document"""
|
||||
|
||||
UPLOAD_FILE = "upload_file"
|
||||
NOTION_IMPORT = "notion_import"
|
||||
WEBSITE_CRAWL = "website_crawl"
|
||||
LOCAL_FILE = "local_file"
|
||||
ONLINE_DOCUMENT = "online_document"
|
||||
|
||||
|
||||
class ProcessRuleMode(StrEnum):
|
||||
"""Dataset Process Rule Mode"""
|
||||
|
||||
AUTOMATIC = "automatic"
|
||||
CUSTOM = "custom"
|
||||
HIERARCHICAL = "hierarchical"
|
||||
|
||||
|
||||
class IndexingStatus(StrEnum):
|
||||
"""Document Indexing Status"""
|
||||
|
||||
WAITING = "waiting"
|
||||
PARSING = "parsing"
|
||||
CLEANING = "cleaning"
|
||||
SPLITTING = "splitting"
|
||||
INDEXING = "indexing"
|
||||
PAUSED = "paused"
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class DocumentCreatedFrom(StrEnum):
|
||||
"""Document Created From"""
|
||||
|
||||
WEB = "web"
|
||||
API = "api"
|
||||
RAG_PIPELINE = "rag-pipeline"
|
||||
|
||||
|
||||
class ConversationFromSource(StrEnum):
|
||||
"""Conversation / Message from_source"""
|
||||
|
||||
API = "api"
|
||||
CONSOLE = "console"
|
||||
|
||||
|
||||
class FeedbackFromSource(StrEnum):
|
||||
"""MessageFeedback from_source"""
|
||||
|
||||
USER = "user"
|
||||
ADMIN = "admin"
|
||||
|
||||
|
||||
class InvokeFrom(StrEnum):
|
||||
"""How a conversation/message was invoked"""
|
||||
|
||||
SERVICE_API = "service-api"
|
||||
WEB_APP = "web-app"
|
||||
TRIGGER = "trigger"
|
||||
EXPLORE = "explore"
|
||||
DEBUGGER = "debugger"
|
||||
PUBLISHED_PIPELINE = "published"
|
||||
VALIDATION = "validation"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "InvokeFrom":
|
||||
return cls(value)
|
||||
|
||||
def to_source(self) -> str:
|
||||
source_mapping = {
|
||||
InvokeFrom.WEB_APP: "web_app",
|
||||
InvokeFrom.DEBUGGER: "dev",
|
||||
InvokeFrom.EXPLORE: "explore_app",
|
||||
InvokeFrom.TRIGGER: "trigger",
|
||||
InvokeFrom.SERVICE_API: "api",
|
||||
}
|
||||
return source_mapping.get(self, "dev")
|
||||
|
||||
|
||||
class DocumentDocType(StrEnum):
|
||||
"""Document doc_type classification"""
|
||||
|
||||
BOOK = "book"
|
||||
WEB_PAGE = "web_page"
|
||||
PAPER = "paper"
|
||||
SOCIAL_MEDIA_POST = "social_media_post"
|
||||
WIKIPEDIA_ENTRY = "wikipedia_entry"
|
||||
PERSONAL_DOCUMENT = "personal_document"
|
||||
BUSINESS_DOCUMENT = "business_document"
|
||||
IM_CHAT_LOG = "im_chat_log"
|
||||
SYNCED_FROM_NOTION = "synced_from_notion"
|
||||
SYNCED_FROM_GITHUB = "synced_from_github"
|
||||
OTHERS = "others"
|
||||
|
||||
|
||||
class TagType(StrEnum):
|
||||
"""Tag type"""
|
||||
|
||||
KNOWLEDGE = "knowledge"
|
||||
APP = "app"
|
||||
|
||||
|
||||
class DatasetMetadataType(StrEnum):
|
||||
"""Dataset metadata value type"""
|
||||
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
TIME = "time"
|
||||
|
||||
|
||||
class SegmentStatus(StrEnum):
|
||||
"""Document segment status"""
|
||||
|
||||
WAITING = "waiting"
|
||||
INDEXING = "indexing"
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
PAUSED = "paused"
|
||||
RE_SEGMENT = "re_segment"
|
||||
|
||||
|
||||
class DatasetRuntimeMode(StrEnum):
|
||||
"""Dataset runtime mode"""
|
||||
|
||||
GENERAL = "general"
|
||||
RAG_PIPELINE = "rag_pipeline"
|
||||
|
||||
|
||||
class CollectionBindingType(StrEnum):
|
||||
"""Dataset collection binding type"""
|
||||
|
||||
DATASET = "dataset"
|
||||
ANNOTATION = "annotation"
|
||||
|
||||
|
||||
class DatasetQuerySource(StrEnum):
|
||||
"""Dataset query source"""
|
||||
|
||||
HIT_TESTING = "hit_testing"
|
||||
APP = "app"
|
||||
|
||||
|
||||
class TidbAuthBindingStatus(StrEnum):
|
||||
"""TiDB auth binding status"""
|
||||
|
||||
CREATING = "CREATING"
|
||||
ACTIVE = "ACTIVE"
|
||||
|
||||
|
||||
class MessageFileBelongsTo(StrEnum):
|
||||
"""MessageFile belongs_to"""
|
||||
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
|
||||
|
||||
class CredentialSourceType(StrEnum):
|
||||
"""Load balancing credential source type"""
|
||||
|
||||
PROVIDER = "provider"
|
||||
CUSTOM_MODEL = "custom_model"
|
||||
|
||||
|
||||
class PaymentStatus(StrEnum):
|
||||
"""Provider order payment status"""
|
||||
|
||||
WAIT_PAY = "wait_pay"
|
||||
PAID = "paid"
|
||||
FAILED = "failed"
|
||||
REFUNDED = "refunded"
|
||||
|
||||
|
||||
class BannerStatus(StrEnum):
|
||||
"""ExporleBanner status"""
|
||||
|
||||
ENABLED = "enabled"
|
||||
DISABLED = "disabled"
|
||||
|
||||
|
||||
class SummaryStatus(StrEnum):
|
||||
"""Document segment summary status"""
|
||||
|
||||
NOT_STARTED = "not_started"
|
||||
GENERATING = "generating"
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
TIMEOUT = "timeout"
|
||||
|
||||
|
||||
class MessageChainType(StrEnum):
|
||||
"""Message chain type"""
|
||||
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class ProviderQuotaType(StrEnum):
|
||||
PAID = "paid"
|
||||
"""hosted paid quota"""
|
||||
|
||||
FREE = "free"
|
||||
"""third-party free quota"""
|
||||
|
||||
TRIAL = "trial"
|
||||
"""hosted trial quota"""
|
||||
|
||||
@staticmethod
|
||||
def value_of(value: str) -> "ProviderQuotaType":
|
||||
for member in ProviderQuotaType:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
|
|
|||
|
|
@ -380,13 +380,12 @@ class App(Base):
|
|||
|
||||
@property
|
||||
def site(self) -> Site | None:
|
||||
site = db.session.query(Site).where(Site.app_id == self.id).first()
|
||||
return site
|
||||
return db.session.scalar(select(Site).where(Site.app_id == self.id))
|
||||
|
||||
@property
|
||||
def app_model_config(self) -> AppModelConfig | None:
|
||||
if self.app_model_config_id:
|
||||
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
|
||||
return db.session.scalar(select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id))
|
||||
|
||||
return None
|
||||
|
||||
|
|
@ -395,7 +394,7 @@ class App(Base):
|
|||
if self.workflow_id:
|
||||
from .workflow import Workflow
|
||||
|
||||
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
|
||||
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
|
||||
|
||||
return None
|
||||
|
||||
|
|
@ -405,8 +404,7 @@ class App(Base):
|
|||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
return tenant
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
|
||||
@property
|
||||
def is_agent(self) -> bool:
|
||||
|
|
@ -546,9 +544,9 @@ class App(Base):
|
|||
return deleted_tools
|
||||
|
||||
@property
|
||||
def tags(self) -> list[Tag]:
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
def tags(self) -> Sequence[Tag]:
|
||||
tags = db.session.scalars(
|
||||
select(Tag)
|
||||
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
.where(
|
||||
TagBinding.target_id == self.id,
|
||||
|
|
@ -556,15 +554,14 @@ class App(Base):
|
|||
Tag.tenant_id == self.tenant_id,
|
||||
Tag.type == "app",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
return tags or []
|
||||
|
||||
@property
|
||||
def author_name(self) -> str | None:
|
||||
if self.created_by:
|
||||
account = db.session.query(Account).where(Account.id == self.created_by).first()
|
||||
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
|
||||
if account:
|
||||
return account.name
|
||||
|
||||
|
|
@ -616,8 +613,7 @@ class AppModelConfig(TypeBase):
|
|||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
@property
|
||||
def model_dict(self) -> ModelConfig:
|
||||
|
|
@ -652,8 +648,8 @@ class AppModelConfig(TypeBase):
|
|||
|
||||
@property
|
||||
def annotation_reply_dict(self) -> AnnotationReplyConfig:
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
|
||||
annotation_setting = db.session.scalar(
|
||||
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id)
|
||||
)
|
||||
if annotation_setting:
|
||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||
|
|
@ -845,8 +841,7 @@ class RecommendedApp(Base): # bug
|
|||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
|
||||
class InstalledApp(TypeBase):
|
||||
|
|
@ -873,13 +868,11 @@ class InstalledApp(TypeBase):
|
|||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
return tenant
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
|
||||
|
||||
class TrialApp(Base):
|
||||
|
|
@ -899,8 +892,7 @@ class TrialApp(Base):
|
|||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
|
||||
class AccountTrialAppRecord(Base):
|
||||
|
|
@ -919,13 +911,11 @@ class AccountTrialAppRecord(Base):
|
|||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
@property
|
||||
def user(self) -> Account | None:
|
||||
user = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return user
|
||||
return db.session.scalar(select(Account).where(Account.id == self.account_id))
|
||||
|
||||
|
||||
class ExporleBanner(TypeBase):
|
||||
|
|
@ -1117,8 +1107,8 @@ class Conversation(Base):
|
|||
else:
|
||||
model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key]
|
||||
else:
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
|
||||
app_model_config = db.session.scalar(
|
||||
select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id)
|
||||
)
|
||||
if app_model_config:
|
||||
model_config = app_model_config.to_dict()
|
||||
|
|
@ -1141,36 +1131,43 @@ class Conversation(Base):
|
|||
|
||||
@property
|
||||
def annotated(self):
|
||||
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).count() > 0
|
||||
return (
|
||||
db.session.scalar(
|
||||
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.conversation_id == self.id)
|
||||
)
|
||||
or 0
|
||||
) > 0
|
||||
|
||||
@property
|
||||
def annotation(self):
|
||||
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).first()
|
||||
return db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).limit(1))
|
||||
|
||||
@property
|
||||
def message_count(self):
|
||||
return db.session.query(Message).where(Message.conversation_id == self.id).count()
|
||||
return db.session.scalar(select(func.count(Message.id)).where(Message.conversation_id == self.id)) or 0
|
||||
|
||||
@property
|
||||
def user_feedback_stats(self):
|
||||
like = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "like",
|
||||
db.session.scalar(
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "like",
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
|
||||
dislike = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "dislike",
|
||||
db.session.scalar(
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "dislike",
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
|
||||
return {"like": like, "dislike": dislike}
|
||||
|
|
@ -1178,23 +1175,25 @@ class Conversation(Base):
|
|||
@property
|
||||
def admin_feedback_stats(self):
|
||||
like = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "like",
|
||||
db.session.scalar(
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "like",
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
|
||||
dislike = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "dislike",
|
||||
db.session.scalar(
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "dislike",
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
|
||||
return {"like": like, "dislike": dislike}
|
||||
|
|
@ -1256,22 +1255,19 @@ class Conversation(Base):
|
|||
|
||||
@property
|
||||
def first_message(self):
|
||||
return (
|
||||
db.session.query(Message)
|
||||
.where(Message.conversation_id == self.id)
|
||||
.order_by(Message.created_at.asc())
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(Message).where(Message.conversation_id == self.id).order_by(Message.created_at.asc())
|
||||
)
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
return session.query(App).where(App.id == self.app_id).first()
|
||||
return session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
@property
|
||||
def from_end_user_session_id(self):
|
||||
if self.from_end_user_id:
|
||||
end_user = db.session.query(EndUser).where(EndUser.id == self.from_end_user_id).first()
|
||||
end_user = db.session.scalar(select(EndUser).where(EndUser.id == self.from_end_user_id))
|
||||
if end_user:
|
||||
return end_user.session_id
|
||||
|
||||
|
|
@ -1280,7 +1276,7 @@ class Conversation(Base):
|
|||
@property
|
||||
def from_account_name(self) -> str | None:
|
||||
if self.from_account_id:
|
||||
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
|
||||
account = db.session.scalar(select(Account).where(Account.id == self.from_account_id))
|
||||
if account:
|
||||
return account.name
|
||||
|
||||
|
|
@ -1505,21 +1501,15 @@ class Message(Base):
|
|||
|
||||
@property
|
||||
def user_feedback(self):
|
||||
feedback = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
|
||||
)
|
||||
return feedback
|
||||
|
||||
@property
|
||||
def admin_feedback(self):
|
||||
feedback = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
|
||||
)
|
||||
return feedback
|
||||
|
||||
@property
|
||||
def feedbacks(self):
|
||||
|
|
@ -1528,28 +1518,27 @@ class Message(Base):
|
|||
|
||||
@property
|
||||
def annotation(self):
|
||||
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == self.id).first()
|
||||
annotation = db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.message_id == self.id))
|
||||
return annotation
|
||||
|
||||
@property
|
||||
def annotation_hit_history(self):
|
||||
annotation_history = (
|
||||
db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id).first()
|
||||
annotation_history = db.session.scalar(
|
||||
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id)
|
||||
)
|
||||
if annotation_history:
|
||||
annotation = (
|
||||
db.session.query(MessageAnnotation)
|
||||
.where(MessageAnnotation.id == annotation_history.annotation_id)
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(MessageAnnotation).where(MessageAnnotation.id == annotation_history.annotation_id)
|
||||
)
|
||||
return annotation
|
||||
return None
|
||||
|
||||
@property
|
||||
def app_model_config(self):
|
||||
conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first()
|
||||
conversation = db.session.scalar(select(Conversation).where(Conversation.id == self.conversation_id))
|
||||
if conversation:
|
||||
return db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first()
|
||||
return db.session.scalar(
|
||||
select(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id)
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
|
@ -1562,13 +1551,12 @@ class Message(Base):
|
|||
return json.loads(self.message_metadata) if self.message_metadata else {}
|
||||
|
||||
@property
|
||||
def agent_thoughts(self) -> list[MessageAgentThought]:
|
||||
return (
|
||||
db.session.query(MessageAgentThought)
|
||||
def agent_thoughts(self) -> Sequence[MessageAgentThought]:
|
||||
return db.session.scalars(
|
||||
select(MessageAgentThought)
|
||||
.where(MessageAgentThought.message_id == self.id)
|
||||
.order_by(MessageAgentThought.position.asc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
@property
|
||||
def retriever_resources(self) -> Any:
|
||||
|
|
@ -1579,7 +1567,7 @@ class Message(Base):
|
|||
from factories import file_factory
|
||||
|
||||
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
|
||||
current_app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
current_app = db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
if not current_app:
|
||||
raise ValueError(f"App {self.app_id} not found")
|
||||
|
||||
|
|
@ -1743,8 +1731,7 @@ class MessageFeedback(TypeBase):
|
|||
|
||||
@property
|
||||
def from_account(self) -> Account | None:
|
||||
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
|
||||
return account
|
||||
return db.session.scalar(select(Account).where(Account.id == self.from_account_id))
|
||||
|
||||
def to_dict(self) -> MessageFeedbackDict:
|
||||
return {
|
||||
|
|
@ -1817,13 +1804,11 @@ class MessageAnnotation(Base):
|
|||
|
||||
@property
|
||||
def account(self):
|
||||
account = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return account
|
||||
return db.session.scalar(select(Account).where(Account.id == self.account_id))
|
||||
|
||||
@property
|
||||
def annotation_create_account(self):
|
||||
account = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return account
|
||||
return db.session.scalar(select(Account).where(Account.id == self.account_id))
|
||||
|
||||
|
||||
class AppAnnotationHitHistory(TypeBase):
|
||||
|
|
@ -1852,18 +1837,15 @@ class AppAnnotationHitHistory(TypeBase):
|
|||
|
||||
@property
|
||||
def account(self):
|
||||
account = (
|
||||
db.session.query(Account)
|
||||
return db.session.scalar(
|
||||
select(Account)
|
||||
.join(MessageAnnotation, MessageAnnotation.account_id == Account.id)
|
||||
.where(MessageAnnotation.id == self.annotation_id)
|
||||
.first()
|
||||
)
|
||||
return account
|
||||
|
||||
@property
|
||||
def annotation_create_account(self):
|
||||
account = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return account
|
||||
return db.session.scalar(select(Account).where(Account.id == self.account_id))
|
||||
|
||||
|
||||
class AppAnnotationSetting(TypeBase):
|
||||
|
|
@ -1896,12 +1878,9 @@ class AppAnnotationSetting(TypeBase):
|
|||
def collection_binding_detail(self):
|
||||
from .dataset import DatasetCollectionBinding
|
||||
|
||||
collection_binding_detail = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == self.collection_binding_id)
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == self.collection_binding_id)
|
||||
)
|
||||
return collection_binding_detail
|
||||
|
||||
|
||||
class OperationLog(TypeBase):
|
||||
|
|
@ -2007,7 +1986,9 @@ class AppMCPServer(TypeBase):
|
|||
def generate_server_code(n: int) -> str:
|
||||
while True:
|
||||
result = generate_string(n)
|
||||
while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0:
|
||||
while (
|
||||
db.session.scalar(select(func.count(AppMCPServer.id)).where(AppMCPServer.server_code == result)) or 0
|
||||
) > 0:
|
||||
result = generate_string(n)
|
||||
|
||||
return result
|
||||
|
|
@ -2068,7 +2049,7 @@ class Site(Base):
|
|||
def generate_code(n: int) -> str:
|
||||
while True:
|
||||
result = generate_string(n)
|
||||
while db.session.query(Site).where(Site.code == result).count() > 0:
|
||||
while (db.session.scalar(select(func.count(Site.id)).where(Site.code == result)) or 0) > 0:
|
||||
result = generate_string(n)
|
||||
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -6,13 +6,14 @@ from functools import cached_property
|
|||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, func, text
|
||||
from sqlalchemy import DateTime, String, func, select, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
from .base import TypeBase
|
||||
from .engine import db
|
||||
from .enums import CredentialSourceType, PaymentStatus
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
|
||||
|
|
@ -96,7 +97,7 @@ class Provider(TypeBase):
|
|||
@cached_property
|
||||
def credential(self):
|
||||
if self.credential_id:
|
||||
return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first()
|
||||
return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
|
||||
|
||||
@property
|
||||
def credential_name(self):
|
||||
|
|
@ -159,10 +160,8 @@ class ProviderModel(TypeBase):
|
|||
@cached_property
|
||||
def credential(self):
|
||||
if self.credential_id:
|
||||
return (
|
||||
db.session.query(ProviderModelCredential)
|
||||
.where(ProviderModelCredential.id == self.credential_id)
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
@ -239,7 +238,9 @@ class ProviderOrder(TypeBase):
|
|||
quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1"))
|
||||
currency: Mapped[str | None] = mapped_column(String(40))
|
||||
total_amount: Mapped[int | None] = mapped_column(sa.Integer)
|
||||
payment_status: Mapped[str] = mapped_column(String(40), nullable=False, server_default=text("'wait_pay'"))
|
||||
payment_status: Mapped[PaymentStatus] = mapped_column(
|
||||
EnumText(PaymentStatus, length=40), nullable=False, server_default=text("'wait_pay'")
|
||||
)
|
||||
paid_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
refunded_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
|
|
@ -302,7 +303,9 @@ class LoadBalancingModelConfig(TypeBase):
|
|||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True, default=None)
|
||||
credential_source_type: Mapped[CredentialSourceType | None] = mapped_column(
|
||||
EnumText(CredentialSourceType, length=40), nullable=True, default=None
|
||||
)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from uuid import uuid4
|
|||
|
||||
import sqlalchemy as sa
|
||||
from deprecated import deprecated
|
||||
from sqlalchemy import ForeignKey, String, func
|
||||
from sqlalchemy import ForeignKey, String, func, select
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
|
@ -184,11 +184,11 @@ class ApiToolProvider(TypeBase):
|
|||
def user(self) -> Account | None:
|
||||
if not self.user_id:
|
||||
return None
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
return db.session.scalar(select(Account).where(Account.id == self.user_id))
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
|
||||
|
||||
class ToolLabelBinding(TypeBase):
|
||||
|
|
@ -262,11 +262,11 @@ class WorkflowToolProvider(TypeBase):
|
|||
|
||||
@property
|
||||
def user(self) -> Account | None:
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
return db.session.scalar(select(Account).where(Account.id == self.user_id))
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
|
||||
@property
|
||||
def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
|
||||
|
|
@ -277,7 +277,7 @@ class WorkflowToolProvider(TypeBase):
|
|||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
return db.session.query(App).where(App.id == self.app_id).first()
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
|
||||
class MCPToolProvider(TypeBase):
|
||||
|
|
@ -334,7 +334,7 @@ class MCPToolProvider(TypeBase):
|
|||
encrypted_headers: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
|
||||
def load_user(self) -> Account | None:
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
return db.session.scalar(select(Account).where(Account.id == self.user_id))
|
||||
|
||||
@property
|
||||
def credentials(self) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from datetime import datetime
|
|||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy import DateTime, func, select
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import TypeBase
|
||||
|
|
@ -38,7 +38,7 @@ class SavedMessage(TypeBase):
|
|||
|
||||
@property
|
||||
def message(self):
|
||||
return db.session.query(Message).where(Message.id == self.message_id).first()
|
||||
return db.session.scalar(select(Message).where(Message.id == self.message_id))
|
||||
|
||||
|
||||
class PinnedConversation(TypeBase):
|
||||
|
|
|
|||
|
|
@ -679,14 +679,14 @@ class WorkflowRun(Base):
|
|||
def message(self):
|
||||
from .model import Message
|
||||
|
||||
return (
|
||||
db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first()
|
||||
return db.session.scalar(
|
||||
select(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id)
|
||||
)
|
||||
|
||||
@property
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
def workflow(self):
|
||||
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
|
||||
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -182,4 +182,9 @@ tasks/app_generate/workflow_execute_task.py
|
|||
tasks/regenerate_summary_index_task.py
|
||||
tasks/trigger_processing_tasks.py
|
||||
tasks/workflow_cfs_scheduler/cfs_scheduler.py
|
||||
tasks/add_document_to_index_task.py
|
||||
tasks/create_segment_to_index_task.py
|
||||
tasks/disable_segment_from_index_task.py
|
||||
tasks/enable_segment_to_index_task.py
|
||||
tasks/remove_document_from_index_task.py
|
||||
tasks/workflow_execution_tasks.py
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager
|
|||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
|
||||
|
||||
class AgentService:
|
||||
|
|
@ -47,7 +47,7 @@ class AgentService:
|
|||
if not message:
|
||||
raise ValueError(f"Message not found: {message_id}")
|
||||
|
||||
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
|
||||
agent_thoughts = message.agent_thoughts
|
||||
|
||||
if conversation.from_end_user_id:
|
||||
# only select name field
|
||||
|
|
|
|||
|
|
@ -51,6 +51,14 @@ from models.dataset import (
|
|||
Pipeline,
|
||||
SegmentAttachmentBinding,
|
||||
)
|
||||
from models.enums import (
|
||||
DatasetRuntimeMode,
|
||||
DataSourceType,
|
||||
DocumentCreatedFrom,
|
||||
IndexingStatus,
|
||||
ProcessRuleMode,
|
||||
SegmentStatus,
|
||||
)
|
||||
from models.model import UploadFile
|
||||
from models.provider_ids import ModelProviderID
|
||||
from models.source import DataSourceOauthBinding
|
||||
|
|
@ -319,7 +327,7 @@ class DatasetService:
|
|||
description=rag_pipeline_dataset_create_entity.description,
|
||||
permission=rag_pipeline_dataset_create_entity.permission,
|
||||
provider="vendor",
|
||||
runtime_mode="rag_pipeline",
|
||||
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
|
||||
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
|
||||
created_by=current_user.id,
|
||||
pipeline_id=pipeline.id,
|
||||
|
|
@ -614,7 +622,7 @@ class DatasetService:
|
|||
"""
|
||||
Update pipeline knowledge base node data.
|
||||
"""
|
||||
if dataset.runtime_mode != "rag_pipeline":
|
||||
if dataset.runtime_mode != DatasetRuntimeMode.RAG_PIPELINE:
|
||||
return
|
||||
|
||||
pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first()
|
||||
|
|
@ -1229,10 +1237,15 @@ class DocumentService:
|
|||
"enabled": "available",
|
||||
}
|
||||
|
||||
_INDEXING_STATUSES: tuple[str, ...] = ("parsing", "cleaning", "splitting", "indexing")
|
||||
_INDEXING_STATUSES: tuple[IndexingStatus, ...] = (
|
||||
IndexingStatus.PARSING,
|
||||
IndexingStatus.CLEANING,
|
||||
IndexingStatus.SPLITTING,
|
||||
IndexingStatus.INDEXING,
|
||||
)
|
||||
|
||||
DISPLAY_STATUS_FILTERS: dict[str, tuple[Any, ...]] = {
|
||||
"queuing": (Document.indexing_status == "waiting",),
|
||||
"queuing": (Document.indexing_status == IndexingStatus.WAITING,),
|
||||
"indexing": (
|
||||
Document.indexing_status.in_(_INDEXING_STATUSES),
|
||||
Document.is_paused.is_not(True),
|
||||
|
|
@ -1241,19 +1254,19 @@ class DocumentService:
|
|||
Document.indexing_status.in_(_INDEXING_STATUSES),
|
||||
Document.is_paused.is_(True),
|
||||
),
|
||||
"error": (Document.indexing_status == "error",),
|
||||
"error": (Document.indexing_status == IndexingStatus.ERROR,),
|
||||
"available": (
|
||||
Document.indexing_status == "completed",
|
||||
Document.indexing_status == IndexingStatus.COMPLETED,
|
||||
Document.archived.is_(False),
|
||||
Document.enabled.is_(True),
|
||||
),
|
||||
"disabled": (
|
||||
Document.indexing_status == "completed",
|
||||
Document.indexing_status == IndexingStatus.COMPLETED,
|
||||
Document.archived.is_(False),
|
||||
Document.enabled.is_(False),
|
||||
),
|
||||
"archived": (
|
||||
Document.indexing_status == "completed",
|
||||
Document.indexing_status == IndexingStatus.COMPLETED,
|
||||
Document.archived.is_(True),
|
||||
),
|
||||
}
|
||||
|
|
@ -1536,7 +1549,7 @@ class DocumentService:
|
|||
"""
|
||||
Normalize and validate `Document -> UploadFile` linkage for download flows.
|
||||
"""
|
||||
if document.data_source_type != "upload_file":
|
||||
if document.data_source_type != DataSourceType.UPLOAD_FILE:
|
||||
raise NotFound(invalid_source_message)
|
||||
|
||||
data_source_info: dict[str, Any] = document.data_source_info_dict or {}
|
||||
|
|
@ -1617,7 +1630,7 @@ class DocumentService:
|
|||
select(Document).where(
|
||||
Document.id.in_(document_ids),
|
||||
Document.enabled == True,
|
||||
Document.indexing_status == "completed",
|
||||
Document.indexing_status == IndexingStatus.COMPLETED,
|
||||
Document.archived == False,
|
||||
)
|
||||
).all()
|
||||
|
|
@ -1640,7 +1653,7 @@ class DocumentService:
|
|||
select(Document).where(
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.enabled == True,
|
||||
Document.indexing_status == "completed",
|
||||
Document.indexing_status == IndexingStatus.COMPLETED,
|
||||
Document.archived == False,
|
||||
)
|
||||
).all()
|
||||
|
|
@ -1650,7 +1663,10 @@ class DocumentService:
|
|||
@staticmethod
|
||||
def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
|
||||
documents = db.session.scalars(
|
||||
select(Document).where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
|
||||
select(Document).where(
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.indexing_status.in_([IndexingStatus.ERROR, IndexingStatus.PAUSED]),
|
||||
)
|
||||
).all()
|
||||
return documents
|
||||
|
||||
|
|
@ -1683,7 +1699,7 @@ class DocumentService:
|
|||
def delete_document(document):
|
||||
# trigger document_was_deleted signal
|
||||
file_id = None
|
||||
if document.data_source_type == "upload_file":
|
||||
if document.data_source_type == DataSourceType.UPLOAD_FILE:
|
||||
if document.data_source_info:
|
||||
data_source_info = document.data_source_info_dict
|
||||
if data_source_info and "upload_file_id" in data_source_info:
|
||||
|
|
@ -1704,7 +1720,7 @@ class DocumentService:
|
|||
file_ids = [
|
||||
document.data_source_info_dict.get("upload_file_id", "")
|
||||
for document in documents
|
||||
if document.data_source_type == "upload_file" and document.data_source_info_dict
|
||||
if document.data_source_type == DataSourceType.UPLOAD_FILE and document.data_source_info_dict
|
||||
]
|
||||
|
||||
# Delete documents first, then dispatch cleanup task after commit
|
||||
|
|
@ -1753,7 +1769,13 @@ class DocumentService:
|
|||
|
||||
@staticmethod
|
||||
def pause_document(document):
|
||||
if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}:
|
||||
if document.indexing_status not in {
|
||||
IndexingStatus.WAITING,
|
||||
IndexingStatus.PARSING,
|
||||
IndexingStatus.CLEANING,
|
||||
IndexingStatus.SPLITTING,
|
||||
IndexingStatus.INDEXING,
|
||||
}:
|
||||
raise DocumentIndexingError()
|
||||
# update document to be paused
|
||||
assert current_user is not None
|
||||
|
|
@ -1793,7 +1815,7 @@ class DocumentService:
|
|||
if cache_result is not None:
|
||||
raise ValueError("Document is being retried, please try again later")
|
||||
# retry document indexing
|
||||
document.indexing_status = "waiting"
|
||||
document.indexing_status = IndexingStatus.WAITING
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -1812,7 +1834,7 @@ class DocumentService:
|
|||
if cache_result is not None:
|
||||
raise ValueError("Document is being synced, please try again later")
|
||||
# sync document indexing
|
||||
document.indexing_status = "waiting"
|
||||
document.indexing_status = IndexingStatus.WAITING
|
||||
data_source_info = document.data_source_info_dict
|
||||
if data_source_info:
|
||||
data_source_info["mode"] = "scrape"
|
||||
|
|
@ -1840,7 +1862,7 @@ class DocumentService:
|
|||
knowledge_config: KnowledgeConfig,
|
||||
account: Account | Any,
|
||||
dataset_process_rule: DatasetProcessRule | None = None,
|
||||
created_from: str = "web",
|
||||
created_from: str = DocumentCreatedFrom.WEB,
|
||||
) -> tuple[list[Document], str]:
|
||||
# check doc_form
|
||||
DatasetService.check_doc_form(dataset, knowledge_config.doc_form)
|
||||
|
|
@ -1932,7 +1954,7 @@ class DocumentService:
|
|||
if not dataset_process_rule:
|
||||
process_rule = knowledge_config.process_rule
|
||||
if process_rule:
|
||||
if process_rule.mode in ("custom", "hierarchical"):
|
||||
if process_rule.mode in (ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL):
|
||||
if process_rule.rules:
|
||||
dataset_process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
|
|
@ -1944,7 +1966,7 @@ class DocumentService:
|
|||
dataset_process_rule = dataset.latest_process_rule
|
||||
if not dataset_process_rule:
|
||||
raise ValueError("No process rule found.")
|
||||
elif process_rule.mode == "automatic":
|
||||
elif process_rule.mode == ProcessRuleMode.AUTOMATIC:
|
||||
dataset_process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
mode=process_rule.mode,
|
||||
|
|
@ -1967,7 +1989,7 @@ class DocumentService:
|
|||
if not dataset_process_rule:
|
||||
dataset_process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
mode="automatic",
|
||||
mode=ProcessRuleMode.AUTOMATIC,
|
||||
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
|
||||
created_by=account.id,
|
||||
)
|
||||
|
|
@ -2001,7 +2023,7 @@ class DocumentService:
|
|||
.where(
|
||||
Document.dataset_id == dataset.id,
|
||||
Document.tenant_id == current_user.current_tenant_id,
|
||||
Document.data_source_type == "upload_file",
|
||||
Document.data_source_type == DataSourceType.UPLOAD_FILE,
|
||||
Document.enabled == True,
|
||||
Document.name.in_(file_names),
|
||||
)
|
||||
|
|
@ -2021,7 +2043,7 @@ class DocumentService:
|
|||
document.doc_language = knowledge_config.doc_language
|
||||
document.data_source_info = json.dumps(data_source_info)
|
||||
document.batch = batch
|
||||
document.indexing_status = "waiting"
|
||||
document.indexing_status = IndexingStatus.WAITING
|
||||
db.session.add(document)
|
||||
documents.append(document)
|
||||
duplicate_document_ids.append(document.id)
|
||||
|
|
@ -2056,7 +2078,7 @@ class DocumentService:
|
|||
.filter_by(
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
enabled=True,
|
||||
)
|
||||
.all()
|
||||
|
|
@ -2507,7 +2529,7 @@ class DocumentService:
|
|||
document_data: KnowledgeConfig,
|
||||
account: Account,
|
||||
dataset_process_rule: DatasetProcessRule | None = None,
|
||||
created_from: str = "web",
|
||||
created_from: str = DocumentCreatedFrom.WEB,
|
||||
):
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
|
|
@ -2520,14 +2542,14 @@ class DocumentService:
|
|||
# save process rule
|
||||
if document_data.process_rule:
|
||||
process_rule = document_data.process_rule
|
||||
if process_rule.mode in {"custom", "hierarchical"}:
|
||||
if process_rule.mode in {ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL}:
|
||||
dataset_process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
mode=process_rule.mode,
|
||||
rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
|
||||
created_by=account.id,
|
||||
)
|
||||
elif process_rule.mode == "automatic":
|
||||
elif process_rule.mode == ProcessRuleMode.AUTOMATIC:
|
||||
dataset_process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
mode=process_rule.mode,
|
||||
|
|
@ -2609,7 +2631,7 @@ class DocumentService:
|
|||
if document_data.name:
|
||||
document.name = document_data.name
|
||||
# update document to be waiting
|
||||
document.indexing_status = "waiting"
|
||||
document.indexing_status = IndexingStatus.WAITING
|
||||
document.completed_at = None
|
||||
document.processing_started_at = None
|
||||
document.parsing_completed_at = None
|
||||
|
|
@ -2623,7 +2645,7 @@ class DocumentService:
|
|||
# update document segment
|
||||
|
||||
db.session.query(DocumentSegment).filter_by(document_id=document.id).update(
|
||||
{DocumentSegment.status: "re_segment"}
|
||||
{DocumentSegment.status: SegmentStatus.RE_SEGMENT}
|
||||
)
|
||||
db.session.commit()
|
||||
# trigger async task
|
||||
|
|
@ -2754,7 +2776,7 @@ class DocumentService:
|
|||
if knowledge_config.process_rule.mode not in DatasetProcessRule.MODES:
|
||||
raise ValueError("Process rule mode is invalid")
|
||||
|
||||
if knowledge_config.process_rule.mode == "automatic":
|
||||
if knowledge_config.process_rule.mode == ProcessRuleMode.AUTOMATIC:
|
||||
knowledge_config.process_rule.rules = None
|
||||
else:
|
||||
if not knowledge_config.process_rule.rules:
|
||||
|
|
@ -2785,7 +2807,7 @@ class DocumentService:
|
|||
raise ValueError("Process rule segmentation separator is invalid")
|
||||
|
||||
if not (
|
||||
knowledge_config.process_rule.mode == "hierarchical"
|
||||
knowledge_config.process_rule.mode == ProcessRuleMode.HIERARCHICAL
|
||||
and knowledge_config.process_rule.rules.parent_mode == "full-doc"
|
||||
):
|
||||
if not knowledge_config.process_rule.rules.segmentation.max_tokens:
|
||||
|
|
@ -2814,7 +2836,7 @@ class DocumentService:
|
|||
if args["process_rule"]["mode"] not in DatasetProcessRule.MODES:
|
||||
raise ValueError("Process rule mode is invalid")
|
||||
|
||||
if args["process_rule"]["mode"] == "automatic":
|
||||
if args["process_rule"]["mode"] == ProcessRuleMode.AUTOMATIC:
|
||||
args["process_rule"]["rules"] = {}
|
||||
else:
|
||||
if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]:
|
||||
|
|
@ -3021,7 +3043,7 @@ class DocumentService:
|
|||
@staticmethod
|
||||
def _prepare_disable_update(document, user, now):
|
||||
"""Prepare updates for disabling a document."""
|
||||
if not document.completed_at or document.indexing_status != "completed":
|
||||
if not document.completed_at or document.indexing_status != IndexingStatus.COMPLETED:
|
||||
raise DocumentIndexingError(f"Document: {document.name} is not completed.")
|
||||
|
||||
if not document.enabled:
|
||||
|
|
@ -3130,7 +3152,7 @@ class SegmentService:
|
|||
content=content,
|
||||
word_count=len(content),
|
||||
tokens=tokens,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
indexing_at=naive_utc_now(),
|
||||
completed_at=naive_utc_now(),
|
||||
created_by=current_user.id,
|
||||
|
|
@ -3167,7 +3189,7 @@ class SegmentService:
|
|||
logger.exception("create segment index failed")
|
||||
segment_document.enabled = False
|
||||
segment_document.disabled_at = naive_utc_now()
|
||||
segment_document.status = "error"
|
||||
segment_document.status = SegmentStatus.ERROR
|
||||
segment_document.error = str(e)
|
||||
db.session.commit()
|
||||
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
|
||||
|
|
@ -3227,7 +3249,7 @@ class SegmentService:
|
|||
word_count=len(content),
|
||||
tokens=tokens,
|
||||
keywords=segment_item.get("keywords", []),
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
indexing_at=naive_utc_now(),
|
||||
completed_at=naive_utc_now(),
|
||||
created_by=current_user.id,
|
||||
|
|
@ -3259,7 +3281,7 @@ class SegmentService:
|
|||
for segment_document in segment_data_list:
|
||||
segment_document.enabled = False
|
||||
segment_document.disabled_at = naive_utc_now()
|
||||
segment_document.status = "error"
|
||||
segment_document.status = SegmentStatus.ERROR
|
||||
segment_document.error = str(e)
|
||||
db.session.commit()
|
||||
return segment_data_list
|
||||
|
|
@ -3405,7 +3427,7 @@ class SegmentService:
|
|||
segment.index_node_hash = segment_hash
|
||||
segment.word_count = len(content)
|
||||
segment.tokens = tokens
|
||||
segment.status = "completed"
|
||||
segment.status = SegmentStatus.COMPLETED
|
||||
segment.indexing_at = naive_utc_now()
|
||||
segment.completed_at = naive_utc_now()
|
||||
segment.updated_by = current_user.id
|
||||
|
|
@ -3530,7 +3552,7 @@ class SegmentService:
|
|||
logger.exception("update segment index failed")
|
||||
segment.enabled = False
|
||||
segment.disabled_at = naive_utc_now()
|
||||
segment.status = "error"
|
||||
segment.status = SegmentStatus.ERROR
|
||||
segment.error = str(e)
|
||||
db.session.commit()
|
||||
new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from dify_graph.model_runtime.entities import LLMMode
|
|||
from extensions.ext_database import db
|
||||
from models import Account
|
||||
from models.dataset import Dataset, DatasetQuery
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, DatasetQuerySource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -97,7 +97,7 @@ class HitTestingService:
|
|||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset.id,
|
||||
content=json.dumps(dataset_queries),
|
||||
source="hit_testing",
|
||||
source=DatasetQuerySource.HIT_TESTING,
|
||||
source_app_id=None,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
|
|
@ -137,7 +137,7 @@ class HitTestingService:
|
|||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset.id,
|
||||
content=query,
|
||||
source="hit_testing",
|
||||
source=DatasetQuerySource.HIT_TESTING,
|
||||
source_app_id=None,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from extensions.ext_redis import redis_client
|
|||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding
|
||||
from models.enums import DatasetMetadataType
|
||||
from services.dataset_service import DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
MetadataArgs,
|
||||
|
|
@ -130,11 +131,11 @@ class MetadataService:
|
|||
@staticmethod
|
||||
def get_built_in_fields():
|
||||
return [
|
||||
{"name": BuiltInField.document_name, "type": "string"},
|
||||
{"name": BuiltInField.uploader, "type": "string"},
|
||||
{"name": BuiltInField.upload_date, "type": "time"},
|
||||
{"name": BuiltInField.last_update_date, "type": "time"},
|
||||
{"name": BuiltInField.source, "type": "string"},
|
||||
{"name": BuiltInField.document_name, "type": DatasetMetadataType.STRING},
|
||||
{"name": BuiltInField.uploader, "type": DatasetMetadataType.STRING},
|
||||
{"name": BuiltInField.upload_date, "type": DatasetMetadataType.TIME},
|
||||
{"name": BuiltInField.last_update_date, "type": DatasetMetadataType.TIME},
|
||||
{"name": BuiltInField.source, "type": DatasetMetadataType.STRING},
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from dify_graph.model_runtime.entities.provider_entities import (
|
|||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CredentialSourceType
|
||||
from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -103,9 +104,9 @@ class ModelLoadBalancingService:
|
|||
is_load_balancing_enabled = True
|
||||
|
||||
if config_from == "predefined-model":
|
||||
credential_source_type = "provider"
|
||||
credential_source_type = CredentialSourceType.PROVIDER
|
||||
else:
|
||||
credential_source_type = "custom_model"
|
||||
credential_source_type = CredentialSourceType.CUSTOM_MODEL
|
||||
|
||||
# Get load balancing configurations
|
||||
load_balancing_configs = (
|
||||
|
|
@ -421,7 +422,11 @@ class ModelLoadBalancingService:
|
|||
raise ValueError("Invalid load balancing config name")
|
||||
|
||||
if credential_id:
|
||||
credential_source = "provider" if config_from == "predefined-model" else "custom_model"
|
||||
credential_source = (
|
||||
CredentialSourceType.PROVIDER
|
||||
if config_from == "predefined-model"
|
||||
else CredentialSourceType.CUSTOM_MODEL
|
||||
)
|
||||
assert credential_record is not None
|
||||
load_balancing_model_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant_id,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
|||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Document, Pipeline
|
||||
from models.enums import IndexingStatus
|
||||
from models.model import Account, App, EndUser
|
||||
from models.workflow import Workflow
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
|
@ -111,6 +112,6 @@ class PipelineGenerateService:
|
|||
"""
|
||||
document = db.session.query(Document).where(Document.id == document_id).first()
|
||||
if document:
|
||||
document.indexing_status = "waiting"
|
||||
document.indexing_status = IndexingStatus.WAITING
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
|||
Retrieval recommended app from dify official
|
||||
"""
|
||||
|
||||
def get_pipeline_template_detail(self, template_id: str):
|
||||
def get_pipeline_template_detail(self, template_id: str) -> dict | None:
|
||||
result: dict | None
|
||||
try:
|
||||
result = self.fetch_pipeline_template_detail_from_dify_official(template_id)
|
||||
except Exception as e:
|
||||
|
|
@ -35,17 +36,23 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
|||
return PipelineTemplateType.REMOTE
|
||||
|
||||
@classmethod
|
||||
def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict | None:
|
||||
def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict:
|
||||
"""
|
||||
Fetch pipeline template detail from dify official.
|
||||
:param template_id: Pipeline ID
|
||||
:return:
|
||||
|
||||
:param template_id: Pipeline template ID
|
||||
:return: Template detail dict
|
||||
:raises ValueError: When upstream returns a non-200 status code
|
||||
"""
|
||||
domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN
|
||||
url = f"{domain}/pipeline-templates/{template_id}"
|
||||
response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0))
|
||||
if response.status_code != 200:
|
||||
return None
|
||||
raise ValueError(
|
||||
"fetch pipeline template detail failed,"
|
||||
+ f" status_code: {response.status_code},"
|
||||
+ f" response: {response.text[:1000]}"
|
||||
)
|
||||
data: dict = response.json()
|
||||
return data
|
||||
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ from models.dataset import ( # type: ignore
|
|||
PipelineCustomizedTemplate,
|
||||
PipelineRecommendedPlugin,
|
||||
)
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.enums import IndexingStatus, WorkflowRunTriggeredFrom
|
||||
from models.model import EndUser
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
|
|
@ -117,13 +117,21 @@ class RagPipelineService:
|
|||
def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None:
|
||||
"""
|
||||
Get pipeline template detail.
|
||||
|
||||
:param template_id: template id
|
||||
:return:
|
||||
:param type: template type, "built-in" or "customized"
|
||||
:return: template detail dict, or None if not found
|
||||
"""
|
||||
if type == "built-in":
|
||||
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
||||
built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
|
||||
if built_in_result is None:
|
||||
logger.warning(
|
||||
"pipeline template retrieval returned empty result, template_id: %s, mode: %s",
|
||||
template_id,
|
||||
mode,
|
||||
)
|
||||
return built_in_result
|
||||
else:
|
||||
mode = "customized"
|
||||
|
|
@ -906,7 +914,7 @@ class RagPipelineService:
|
|||
if document_id:
|
||||
document = db.session.query(Document).where(Document.id == document_id.value).first()
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = error
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ from extensions.ext_redis import redis_client
|
|||
from factories import variable_factory
|
||||
from models import Account
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
|
||||
from models.enums import CollectionBindingType, DatasetRuntimeMode
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import (
|
||||
IconInfo,
|
||||
|
|
@ -313,7 +314,7 @@ class RagPipelineDslService:
|
|||
indexing_technique=knowledge_configuration.indexing_technique,
|
||||
created_by=account.id,
|
||||
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
|
||||
runtime_mode="rag_pipeline",
|
||||
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
|
||||
chunk_structure=knowledge_configuration.chunk_structure,
|
||||
)
|
||||
if knowledge_configuration.indexing_technique == "high_quality":
|
||||
|
|
@ -323,7 +324,7 @@ class RagPipelineDslService:
|
|||
DatasetCollectionBinding.provider_name
|
||||
== knowledge_configuration.embedding_model_provider,
|
||||
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
|
||||
DatasetCollectionBinding.type == "dataset",
|
||||
DatasetCollectionBinding.type == CollectionBindingType.DATASET,
|
||||
)
|
||||
.order_by(DatasetCollectionBinding.created_at)
|
||||
.first()
|
||||
|
|
@ -334,7 +335,7 @@ class RagPipelineDslService:
|
|||
provider_name=knowledge_configuration.embedding_model_provider,
|
||||
model_name=knowledge_configuration.embedding_model,
|
||||
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
|
||||
type="dataset",
|
||||
type=CollectionBindingType.DATASET,
|
||||
)
|
||||
self._session.add(dataset_collection_binding)
|
||||
self._session.commit()
|
||||
|
|
@ -445,13 +446,13 @@ class RagPipelineDslService:
|
|||
indexing_technique=knowledge_configuration.indexing_technique,
|
||||
created_by=account.id,
|
||||
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
|
||||
runtime_mode="rag_pipeline",
|
||||
runtime_mode=DatasetRuntimeMode.RAG_PIPELINE,
|
||||
chunk_structure=knowledge_configuration.chunk_structure,
|
||||
)
|
||||
else:
|
||||
dataset.indexing_technique = knowledge_configuration.indexing_technique
|
||||
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
|
||||
dataset.runtime_mode = "rag_pipeline"
|
||||
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
|
||||
dataset.chunk_structure = knowledge_configuration.chunk_structure
|
||||
if knowledge_configuration.indexing_technique == "high_quality":
|
||||
dataset_collection_binding = (
|
||||
|
|
@ -460,7 +461,7 @@ class RagPipelineDslService:
|
|||
DatasetCollectionBinding.provider_name
|
||||
== knowledge_configuration.embedding_model_provider,
|
||||
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
|
||||
DatasetCollectionBinding.type == "dataset",
|
||||
DatasetCollectionBinding.type == CollectionBindingType.DATASET,
|
||||
)
|
||||
.order_by(DatasetCollectionBinding.created_at)
|
||||
.first()
|
||||
|
|
@ -471,7 +472,7 @@ class RagPipelineDslService:
|
|||
provider_name=knowledge_configuration.embedding_model_provider,
|
||||
model_name=knowledge_configuration.embedding_model,
|
||||
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
|
||||
type="dataset",
|
||||
type=CollectionBindingType.DATASET,
|
||||
)
|
||||
self._session.add(dataset_collection_binding)
|
||||
self._session.commit()
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
|||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline
|
||||
from models.enums import DatasetRuntimeMode, DataSourceType
|
||||
from models.model import UploadFile
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration, RetrievalSetting
|
||||
|
|
@ -27,7 +28,7 @@ class RagPipelineTransformService:
|
|||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
if dataset.pipeline_id and dataset.runtime_mode == "rag_pipeline":
|
||||
if dataset.pipeline_id and dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE:
|
||||
return {
|
||||
"pipeline_id": dataset.pipeline_id,
|
||||
"dataset_id": dataset_id,
|
||||
|
|
@ -85,7 +86,7 @@ class RagPipelineTransformService:
|
|||
else:
|
||||
raise ValueError("Unsupported doc form")
|
||||
|
||||
dataset.runtime_mode = "rag_pipeline"
|
||||
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
|
||||
dataset.pipeline_id = pipeline.id
|
||||
|
||||
# deal document data
|
||||
|
|
@ -102,7 +103,7 @@ class RagPipelineTransformService:
|
|||
pipeline_yaml = {}
|
||||
if doc_form == "text_model":
|
||||
match datasource_type:
|
||||
case "upload_file":
|
||||
case DataSourceType.UPLOAD_FILE:
|
||||
if indexing_technique == "high_quality":
|
||||
# get graph from transform.file-general-high-quality.yml
|
||||
with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f:
|
||||
|
|
@ -111,7 +112,7 @@ class RagPipelineTransformService:
|
|||
# get graph from transform.file-general-economy.yml
|
||||
with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f:
|
||||
pipeline_yaml = yaml.safe_load(f)
|
||||
case "notion_import":
|
||||
case DataSourceType.NOTION_IMPORT:
|
||||
if indexing_technique == "high_quality":
|
||||
# get graph from transform.notion-general-high-quality.yml
|
||||
with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f:
|
||||
|
|
@ -120,7 +121,7 @@ class RagPipelineTransformService:
|
|||
# get graph from transform.notion-general-economy.yml
|
||||
with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f:
|
||||
pipeline_yaml = yaml.safe_load(f)
|
||||
case "website_crawl":
|
||||
case DataSourceType.WEBSITE_CRAWL:
|
||||
if indexing_technique == "high_quality":
|
||||
# get graph from transform.website-crawl-general-high-quality.yml
|
||||
with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f:
|
||||
|
|
@ -133,15 +134,15 @@ class RagPipelineTransformService:
|
|||
raise ValueError("Unsupported datasource type")
|
||||
elif doc_form == "hierarchical_model":
|
||||
match datasource_type:
|
||||
case "upload_file":
|
||||
case DataSourceType.UPLOAD_FILE:
|
||||
# get graph from transform.file-parentchild.yml
|
||||
with open(f"{Path(__file__).parent}/transform/file-parentchild.yml") as f:
|
||||
pipeline_yaml = yaml.safe_load(f)
|
||||
case "notion_import":
|
||||
case DataSourceType.NOTION_IMPORT:
|
||||
# get graph from transform.notion-parentchild.yml
|
||||
with open(f"{Path(__file__).parent}/transform/notion-parentchild.yml") as f:
|
||||
pipeline_yaml = yaml.safe_load(f)
|
||||
case "website_crawl":
|
||||
case DataSourceType.WEBSITE_CRAWL:
|
||||
# get graph from transform.website-crawl-parentchild.yml
|
||||
with open(f"{Path(__file__).parent}/transform/website-crawl-parentchild.yml") as f:
|
||||
pipeline_yaml = yaml.safe_load(f)
|
||||
|
|
@ -287,7 +288,7 @@ class RagPipelineTransformService:
|
|||
db.session.flush()
|
||||
|
||||
dataset.pipeline_id = pipeline.id
|
||||
dataset.runtime_mode = "rag_pipeline"
|
||||
dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
|
||||
dataset.updated_by = current_user.id
|
||||
dataset.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
db.session.add(dataset)
|
||||
|
|
@ -310,8 +311,8 @@ class RagPipelineTransformService:
|
|||
data_source_info_dict = document.data_source_info_dict
|
||||
if not data_source_info_dict:
|
||||
continue
|
||||
if document.data_source_type == "upload_file":
|
||||
document.data_source_type = "local_file"
|
||||
if document.data_source_type == DataSourceType.UPLOAD_FILE:
|
||||
document.data_source_type = DataSourceType.LOCAL_FILE
|
||||
file_id = data_source_info_dict.get("upload_file_id")
|
||||
if file_id:
|
||||
file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
|
|
@ -331,7 +332,7 @@ class RagPipelineTransformService:
|
|||
document_pipeline_execution_log = DocumentPipelineExecutionLog(
|
||||
document_id=document.id,
|
||||
pipeline_id=dataset.pipeline_id,
|
||||
datasource_type="local_file",
|
||||
datasource_type=DataSourceType.LOCAL_FILE,
|
||||
datasource_info=data_source_info,
|
||||
input_data={},
|
||||
created_by=document.created_by,
|
||||
|
|
@ -340,8 +341,8 @@ class RagPipelineTransformService:
|
|||
document_pipeline_execution_log.created_at = document.created_at
|
||||
db.session.add(document)
|
||||
db.session.add(document_pipeline_execution_log)
|
||||
elif document.data_source_type == "notion_import":
|
||||
document.data_source_type = "online_document"
|
||||
elif document.data_source_type == DataSourceType.NOTION_IMPORT:
|
||||
document.data_source_type = DataSourceType.ONLINE_DOCUMENT
|
||||
data_source_info = json.dumps(
|
||||
{
|
||||
"workspace_id": data_source_info_dict.get("notion_workspace_id"),
|
||||
|
|
@ -359,7 +360,7 @@ class RagPipelineTransformService:
|
|||
document_pipeline_execution_log = DocumentPipelineExecutionLog(
|
||||
document_id=document.id,
|
||||
pipeline_id=dataset.pipeline_id,
|
||||
datasource_type="online_document",
|
||||
datasource_type=DataSourceType.ONLINE_DOCUMENT,
|
||||
datasource_info=data_source_info,
|
||||
input_data={},
|
||||
created_by=document.created_by,
|
||||
|
|
@ -368,8 +369,7 @@ class RagPipelineTransformService:
|
|||
document_pipeline_execution_log.created_at = document.created_at
|
||||
db.session.add(document)
|
||||
db.session.add(document_pipeline_execution_log)
|
||||
elif document.data_source_type == "website_crawl":
|
||||
document.data_source_type = "website_crawl"
|
||||
elif document.data_source_type == DataSourceType.WEBSITE_CRAWL:
|
||||
data_source_info = json.dumps(
|
||||
{
|
||||
"source_url": data_source_info_dict.get("url"),
|
||||
|
|
@ -388,7 +388,7 @@ class RagPipelineTransformService:
|
|||
document_pipeline_execution_log = DocumentPipelineExecutionLog(
|
||||
document_id=document.id,
|
||||
pipeline_id=dataset.pipeline_id,
|
||||
datasource_type="website_crawl",
|
||||
datasource_type=DataSourceType.WEBSITE_CRAWL,
|
||||
datasource_info=data_source_info,
|
||||
input_data={},
|
||||
created_by=document.created_by,
|
||||
|
|
|
|||
|
|
@ -12,12 +12,14 @@ from core.db.session_factory import session_factory
|
|||
from core.model_manager import ModelManager
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from core.rag.models.document import Document
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from libs import helper
|
||||
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import SummaryStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -29,7 +31,7 @@ class SummaryIndexService:
|
|||
def generate_summary_for_segment(
|
||||
segment: DocumentSegment,
|
||||
dataset: Dataset,
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
) -> tuple[str, LLMUsage]:
|
||||
"""
|
||||
Generate summary for a single segment.
|
||||
|
|
@ -73,7 +75,7 @@ class SummaryIndexService:
|
|||
segment: DocumentSegment,
|
||||
dataset: Dataset,
|
||||
summary_content: str,
|
||||
status: str = "generating",
|
||||
status: SummaryStatus = SummaryStatus.GENERATING,
|
||||
) -> DocumentSegmentSummary:
|
||||
"""
|
||||
Create or update a DocumentSegmentSummary record.
|
||||
|
|
@ -83,7 +85,7 @@ class SummaryIndexService:
|
|||
segment: DocumentSegment to create summary for
|
||||
dataset: Dataset containing the segment
|
||||
summary_content: Generated summary content
|
||||
status: Summary status (default: "generating")
|
||||
status: Summary status (default: SummaryStatus.GENERATING)
|
||||
|
||||
Returns:
|
||||
Created or updated DocumentSegmentSummary instance
|
||||
|
|
@ -326,7 +328,7 @@ class SummaryIndexService:
|
|||
summary_index_node_id=summary_index_node_id,
|
||||
summary_index_node_hash=summary_hash,
|
||||
tokens=embedding_tokens,
|
||||
status="completed",
|
||||
status=SummaryStatus.COMPLETED,
|
||||
enabled=True,
|
||||
)
|
||||
session.add(summary_record_in_session)
|
||||
|
|
@ -362,7 +364,7 @@ class SummaryIndexService:
|
|||
summary_record_in_session.summary_index_node_id = summary_index_node_id
|
||||
summary_record_in_session.summary_index_node_hash = summary_hash
|
||||
summary_record_in_session.tokens = embedding_tokens # Save embedding tokens
|
||||
summary_record_in_session.status = "completed"
|
||||
summary_record_in_session.status = SummaryStatus.COMPLETED
|
||||
# Ensure summary_content is preserved (use the latest from summary_record parameter)
|
||||
# This is critical: use the parameter value, not the database value
|
||||
summary_record_in_session.summary_content = summary_content
|
||||
|
|
@ -400,7 +402,7 @@ class SummaryIndexService:
|
|||
summary_record.summary_index_node_id = summary_index_node_id
|
||||
summary_record.summary_index_node_hash = summary_hash
|
||||
summary_record.tokens = embedding_tokens
|
||||
summary_record.status = "completed"
|
||||
summary_record.status = SummaryStatus.COMPLETED
|
||||
summary_record.summary_content = summary_content
|
||||
if summary_record_in_session.updated_at:
|
||||
summary_record.updated_at = summary_record_in_session.updated_at
|
||||
|
|
@ -487,7 +489,7 @@ class SummaryIndexService:
|
|||
)
|
||||
|
||||
if summary_record_in_session:
|
||||
summary_record_in_session.status = "error"
|
||||
summary_record_in_session.status = SummaryStatus.ERROR
|
||||
summary_record_in_session.error = f"Vectorization failed: {str(e)}"
|
||||
summary_record_in_session.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
error_session.add(summary_record_in_session)
|
||||
|
|
@ -498,7 +500,7 @@ class SummaryIndexService:
|
|||
summary_record_in_session.id,
|
||||
)
|
||||
# Update the original object for consistency
|
||||
summary_record.status = "error"
|
||||
summary_record.status = SummaryStatus.ERROR
|
||||
summary_record.error = summary_record_in_session.error
|
||||
summary_record.updated_at = summary_record_in_session.updated_at
|
||||
else:
|
||||
|
|
@ -514,7 +516,7 @@ class SummaryIndexService:
|
|||
def batch_create_summary_records(
|
||||
segments: list[DocumentSegment],
|
||||
dataset: Dataset,
|
||||
status: str = "not_started",
|
||||
status: SummaryStatus = SummaryStatus.NOT_STARTED,
|
||||
) -> None:
|
||||
"""
|
||||
Batch create summary records for segments with specified status.
|
||||
|
|
@ -523,7 +525,7 @@ class SummaryIndexService:
|
|||
Args:
|
||||
segments: List of DocumentSegment instances
|
||||
dataset: Dataset containing the segments
|
||||
status: Initial status for the records (default: "not_started")
|
||||
status: Initial status for the records (default: SummaryStatus.NOT_STARTED)
|
||||
"""
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
if not segment_ids:
|
||||
|
|
@ -588,7 +590,7 @@ class SummaryIndexService:
|
|||
)
|
||||
|
||||
if summary_record:
|
||||
summary_record.status = "error"
|
||||
summary_record.status = SummaryStatus.ERROR
|
||||
summary_record.error = error
|
||||
session.add(summary_record)
|
||||
session.commit()
|
||||
|
|
@ -599,7 +601,7 @@ class SummaryIndexService:
|
|||
def generate_and_vectorize_summary(
|
||||
segment: DocumentSegment,
|
||||
dataset: Dataset,
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
) -> DocumentSegmentSummary:
|
||||
"""
|
||||
Generate summary for a segment and vectorize it.
|
||||
|
|
@ -631,14 +633,14 @@ class SummaryIndexService:
|
|||
document_id=segment.document_id,
|
||||
chunk_id=segment.id,
|
||||
summary_content="",
|
||||
status="generating",
|
||||
status=SummaryStatus.GENERATING,
|
||||
enabled=True,
|
||||
)
|
||||
session.add(summary_record_in_session)
|
||||
session.flush()
|
||||
|
||||
# Update status to "generating"
|
||||
summary_record_in_session.status = "generating"
|
||||
summary_record_in_session.status = SummaryStatus.GENERATING
|
||||
summary_record_in_session.error = None # type: ignore[assignment]
|
||||
session.add(summary_record_in_session)
|
||||
# Don't flush here - wait until after vectorization succeeds
|
||||
|
|
@ -681,7 +683,7 @@ class SummaryIndexService:
|
|||
except Exception as vectorize_error:
|
||||
# If vectorization fails, update status to error in current session
|
||||
logger.exception("Failed to vectorize summary for segment %s", segment.id)
|
||||
summary_record_in_session.status = "error"
|
||||
summary_record_in_session.status = SummaryStatus.ERROR
|
||||
summary_record_in_session.error = f"Vectorization failed: {str(vectorize_error)}"
|
||||
session.add(summary_record_in_session)
|
||||
session.commit()
|
||||
|
|
@ -694,7 +696,7 @@ class SummaryIndexService:
|
|||
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
|
||||
)
|
||||
if summary_record_in_session:
|
||||
summary_record_in_session.status = "error"
|
||||
summary_record_in_session.status = SummaryStatus.ERROR
|
||||
summary_record_in_session.error = str(e)
|
||||
session.add(summary_record_in_session)
|
||||
session.commit()
|
||||
|
|
@ -704,7 +706,7 @@ class SummaryIndexService:
|
|||
def generate_summaries_for_document(
|
||||
dataset: Dataset,
|
||||
document: DatasetDocument,
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
segment_ids: list[str] | None = None,
|
||||
only_parent_chunks: bool = False,
|
||||
) -> list[DocumentSegmentSummary]:
|
||||
|
|
@ -770,7 +772,7 @@ class SummaryIndexService:
|
|||
SummaryIndexService.batch_create_summary_records(
|
||||
segments=segments,
|
||||
dataset=dataset,
|
||||
status="not_started",
|
||||
status=SummaryStatus.NOT_STARTED,
|
||||
)
|
||||
|
||||
summary_records = []
|
||||
|
|
@ -1067,7 +1069,7 @@ class SummaryIndexService:
|
|||
|
||||
# Update summary content
|
||||
summary_record.summary_content = summary_content
|
||||
summary_record.status = "generating"
|
||||
summary_record.status = SummaryStatus.GENERATING
|
||||
summary_record.error = None # type: ignore[assignment] # Clear any previous errors
|
||||
session.add(summary_record)
|
||||
# Flush to ensure summary_content is saved before vectorize_summary queries it
|
||||
|
|
@ -1102,7 +1104,7 @@ class SummaryIndexService:
|
|||
# If vectorization fails, update status to error in current session
|
||||
# Don't raise the exception - just log it and return the record with error status
|
||||
# This allows the segment update to complete even if vectorization fails
|
||||
summary_record.status = "error"
|
||||
summary_record.status = SummaryStatus.ERROR
|
||||
summary_record.error = f"Vectorization failed: {str(e)}"
|
||||
session.commit()
|
||||
logger.exception("Failed to vectorize summary for segment %s", segment.id)
|
||||
|
|
@ -1112,7 +1114,7 @@ class SummaryIndexService:
|
|||
else:
|
||||
# Create new summary record if doesn't exist
|
||||
summary_record = SummaryIndexService.create_summary_record(
|
||||
segment, dataset, summary_content, status="generating"
|
||||
segment, dataset, summary_content, status=SummaryStatus.GENERATING
|
||||
)
|
||||
# Re-vectorize summary (this will update status to "completed" and tokens in its own session)
|
||||
# Note: summary_record was created in a different session,
|
||||
|
|
@ -1132,7 +1134,7 @@ class SummaryIndexService:
|
|||
# If vectorization fails, update status to error in current session
|
||||
# Merge the record into current session first
|
||||
error_record = session.merge(summary_record)
|
||||
error_record.status = "error"
|
||||
error_record.status = SummaryStatus.ERROR
|
||||
error_record.error = f"Vectorization failed: {str(e)}"
|
||||
session.commit()
|
||||
logger.exception("Failed to vectorize summary for segment %s", segment.id)
|
||||
|
|
@ -1146,7 +1148,7 @@ class SummaryIndexService:
|
|||
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
|
||||
)
|
||||
if summary_record:
|
||||
summary_record.status = "error"
|
||||
summary_record.status = SummaryStatus.ERROR
|
||||
summary_record.error = str(e)
|
||||
session.add(summary_record)
|
||||
session.commit()
|
||||
|
|
@ -1266,7 +1268,7 @@ class SummaryIndexService:
|
|||
# Check if there are any "not_started" or "generating" status summaries
|
||||
has_pending_summaries = any(
|
||||
summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True)
|
||||
and summary_status_map[segment_id] in ("not_started", "generating")
|
||||
and summary_status_map[segment_id] in (SummaryStatus.NOT_STARTED, SummaryStatus.GENERATING)
|
||||
for segment_id in segment_ids
|
||||
)
|
||||
|
||||
|
|
@ -1330,7 +1332,7 @@ class SummaryIndexService:
|
|||
# it means the summary is disabled (enabled=False) or not created yet, ignore it
|
||||
has_pending_summaries = any(
|
||||
summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True)
|
||||
and summary_status_map[segment_id] in ("not_started", "generating")
|
||||
and summary_status_map[segment_id] in (SummaryStatus.NOT_STARTED, SummaryStatus.GENERATING)
|
||||
for segment_id in segment_ids
|
||||
)
|
||||
|
||||
|
|
@ -1393,17 +1395,17 @@ class SummaryIndexService:
|
|||
|
||||
# Count statuses
|
||||
status_counts = {
|
||||
"completed": 0,
|
||||
"generating": 0,
|
||||
"error": 0,
|
||||
"not_started": 0,
|
||||
SummaryStatus.COMPLETED: 0,
|
||||
SummaryStatus.GENERATING: 0,
|
||||
SummaryStatus.ERROR: 0,
|
||||
SummaryStatus.NOT_STARTED: 0,
|
||||
}
|
||||
|
||||
summary_list = []
|
||||
for segment in segments:
|
||||
summary = summary_map.get(segment.id)
|
||||
if summary:
|
||||
status = summary.status
|
||||
status = SummaryStatus(summary.status)
|
||||
status_counts[status] = status_counts.get(status, 0) + 1
|
||||
summary_list.append(
|
||||
{
|
||||
|
|
@ -1421,12 +1423,12 @@ class SummaryIndexService:
|
|||
}
|
||||
)
|
||||
else:
|
||||
status_counts["not_started"] += 1
|
||||
status_counts[SummaryStatus.NOT_STARTED] += 1
|
||||
summary_list.append(
|
||||
{
|
||||
"segment_id": segment.id,
|
||||
"segment_position": segment.position,
|
||||
"status": "not_started",
|
||||
"status": SummaryStatus.NOT_STARTED,
|
||||
"summary_preview": None,
|
||||
"error": None,
|
||||
"created_at": None,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from extensions.ext_redis import redis_client
|
|||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import DatasetAutoDisableLog, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import IndexingStatus, SegmentStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -34,7 +35,7 @@ def add_document_to_index_task(dataset_document_id: str):
|
|||
logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
|
||||
return
|
||||
|
||||
if dataset_document.indexing_status != "completed":
|
||||
if dataset_document.indexing_status != IndexingStatus.COMPLETED:
|
||||
return
|
||||
|
||||
indexing_cache_key = f"document_{dataset_document.id}_indexing"
|
||||
|
|
@ -48,7 +49,7 @@ def add_document_to_index_task(dataset_document_id: str):
|
|||
session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.status == SegmentStatus.COMPLETED,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
|
|
@ -139,7 +140,7 @@ def add_document_to_index_task(dataset_document_id: str):
|
|||
logger.exception("add document to index failed")
|
||||
dataset_document.enabled = False
|
||||
dataset_document.disabled_at = naive_utc_now()
|
||||
dataset_document.indexing_status = "error"
|
||||
dataset_document.indexing_status = IndexingStatus.ERROR
|
||||
dataset_document.error = str(e)
|
||||
session.commit()
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from core.rag.models.document import Document
|
|||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset
|
||||
from models.enums import CollectionBindingType
|
||||
from models.model import App, AppAnnotationSetting, MessageAnnotation
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
|
|
@ -47,7 +48,7 @@ def enable_annotation_reply_task(
|
|||
try:
|
||||
documents = []
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_provider_name, embedding_model_name, "annotation"
|
||||
embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION
|
||||
)
|
||||
annotation_setting = (
|
||||
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
|
||||
|
|
@ -56,7 +57,7 @@ def enable_annotation_reply_task(
|
|||
if dataset_collection_binding.id != annotation_setting.collection_binding_id:
|
||||
old_dataset_collection_binding = (
|
||||
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
|
||||
annotation_setting.collection_binding_id, "annotation"
|
||||
annotation_setting.collection_binding_id, CollectionBindingType.ANNOTATION
|
||||
)
|
||||
)
|
||||
if old_dataset_collection_binding and annotations:
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from core.rag.models.document import Document
|
|||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import DocumentSegment
|
||||
from models.enums import IndexingStatus, SegmentStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -31,7 +32,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
|
|||
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
|
||||
return
|
||||
|
||||
if segment.status != "waiting":
|
||||
if segment.status != SegmentStatus.WAITING:
|
||||
return
|
||||
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
|
|
@ -40,7 +41,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
|
|||
# update segment status to indexing
|
||||
session.query(DocumentSegment).filter_by(id=segment.id).update(
|
||||
{
|
||||
DocumentSegment.status: "indexing",
|
||||
DocumentSegment.status: SegmentStatus.INDEXING,
|
||||
DocumentSegment.indexing_at: naive_utc_now(),
|
||||
}
|
||||
)
|
||||
|
|
@ -70,7 +71,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
|
|||
if (
|
||||
not dataset_document.enabled
|
||||
or dataset_document.archived
|
||||
or dataset_document.indexing_status != "completed"
|
||||
or dataset_document.indexing_status != IndexingStatus.COMPLETED
|
||||
):
|
||||
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
|
||||
return
|
||||
|
|
@ -82,7 +83,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
|
|||
# update segment to completed
|
||||
session.query(DocumentSegment).filter_by(id=segment.id).update(
|
||||
{
|
||||
DocumentSegment.status: "completed",
|
||||
DocumentSegment.status: SegmentStatus.COMPLETED,
|
||||
DocumentSegment.completed_at: naive_utc_now(),
|
||||
}
|
||||
)
|
||||
|
|
@ -94,7 +95,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
|
|||
logger.exception("create segment to index failed")
|
||||
segment.enabled = False
|
||||
segment.disabled_at = naive_utc_now()
|
||||
segment.status = "error"
|
||||
segment.status = SegmentStatus.ERROR
|
||||
segment.error = str(e)
|
||||
session.commit()
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
|
|||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import IndexingStatus
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -37,7 +38,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
|||
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
|
||||
return
|
||||
|
||||
if document.indexing_status == "parsing":
|
||||
if document.indexing_status == IndexingStatus.PARSING:
|
||||
logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
|
||||
return
|
||||
|
||||
|
|
@ -88,7 +89,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
|||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
|
||||
document.stopped_at = naive_utc_now()
|
||||
return
|
||||
|
|
@ -128,7 +129,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
|||
data_source_info["last_edited_time"] = last_edited_time
|
||||
document.data_source_info = json.dumps(data_source_info)
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.indexing_status = IndexingStatus.PARSING
|
||||
document.processing_started_at = naive_utc_now()
|
||||
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
|
|
@ -151,6 +152,6 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
|||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
|||
from enums.cloud_plan import CloudPlan
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import IndexingStatus
|
||||
from services.feature_service import FeatureService
|
||||
from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
|
||||
|
|
@ -81,7 +82,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
|||
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
session.add(document)
|
||||
|
|
@ -96,7 +97,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
|||
|
||||
for document in documents:
|
||||
if document:
|
||||
document.indexing_status = "parsing"
|
||||
document.indexing_status = IndexingStatus.PARSING
|
||||
document.processing_started_at = naive_utc_now()
|
||||
session.add(document)
|
||||
# Transaction committed and closed
|
||||
|
|
@ -148,7 +149,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
|||
document.need_summary,
|
||||
)
|
||||
if (
|
||||
document.indexing_status == "completed"
|
||||
document.indexing_status == IndexingStatus.COMPLETED
|
||||
and document.doc_form != "qa_model"
|
||||
and document.need_summary is True
|
||||
):
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
|||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import IndexingStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -33,7 +34,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
|
|||
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
|
||||
return
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.indexing_status = IndexingStatus.PARSING
|
||||
document.processing_started_at = naive_utc_now()
|
||||
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
|||
from enums.cloud_plan import CloudPlan
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import IndexingStatus
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -112,7 +113,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
|
|||
)
|
||||
for document in documents:
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
session.add(document)
|
||||
|
|
@ -146,7 +147,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
|
|||
session.execute(segment_delete_stmt)
|
||||
session.commit()
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.indexing_status = IndexingStatus.PARSING
|
||||
document.processing_started_at = naive_utc_now()
|
||||
session.add(document)
|
||||
session.commit()
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from core.rag.models.document import AttachmentDocument, ChildDocument, Document
|
|||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import DocumentSegment
|
||||
from models.enums import IndexingStatus, SegmentStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -33,7 +34,7 @@ def enable_segment_to_index_task(segment_id: str):
|
|||
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
|
||||
return
|
||||
|
||||
if segment.status != "completed":
|
||||
if segment.status != SegmentStatus.COMPLETED:
|
||||
logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red"))
|
||||
return
|
||||
|
||||
|
|
@ -65,7 +66,7 @@ def enable_segment_to_index_task(segment_id: str):
|
|||
if (
|
||||
not dataset_document.enabled
|
||||
or dataset_document.archived
|
||||
or dataset_document.indexing_status != "completed"
|
||||
or dataset_document.indexing_status != IndexingStatus.COMPLETED
|
||||
):
|
||||
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
|
||||
return
|
||||
|
|
@ -123,7 +124,7 @@ def enable_segment_to_index_task(segment_id: str):
|
|||
logger.exception("enable segment to index failed")
|
||||
segment.enabled = False
|
||||
segment.disabled_at = naive_utc_now()
|
||||
segment.status = "error"
|
||||
segment.status = SegmentStatus.ERROR
|
||||
segment.error = str(e)
|
||||
session.commit()
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from extensions.ext_redis import redis_client
|
|||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Tenant
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import IndexingStatus
|
||||
from services.feature_service import FeatureService
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
|
|
@ -63,7 +64,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
|
|||
.first()
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
session.add(document)
|
||||
|
|
@ -95,7 +96,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
|
|||
session.execute(segment_delete_stmt)
|
||||
session.commit()
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.indexing_status = IndexingStatus.PARSING
|
||||
document.processing_started_at = naive_utc_now()
|
||||
session.add(document)
|
||||
session.commit()
|
||||
|
|
@ -108,7 +109,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
|
|||
indexing_runner.run([document])
|
||||
redis_client.delete(retry_indexing_cache_key)
|
||||
except Exception as ex:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = str(ex)
|
||||
document.stopped_at = naive_utc_now()
|
||||
session.add(document)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
|
|||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import IndexingStatus
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -48,7 +49,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
|
|||
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
session.add(document)
|
||||
|
|
@ -76,7 +77,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
|
|||
session.execute(segment_delete_stmt)
|
||||
session.commit()
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.indexing_status = IndexingStatus.PARSING
|
||||
document.processing_started_at = naive_utc_now()
|
||||
session.add(document)
|
||||
session.commit()
|
||||
|
|
@ -85,7 +86,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
|
|||
indexing_runner.run([document])
|
||||
redis_client.delete(sync_indexing_cache_key)
|
||||
except Exception as ex:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = str(ex)
|
||||
document.stopped_at = naive_utc_now()
|
||||
session.add(document)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from faker import Faker
|
|||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
|
||||
from services.account_service import AccountService, TenantService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
|
@ -35,7 +36,7 @@ class TestGetAvailableDatasetsIntegration:
|
|||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
indexing_technique="high_quality",
|
||||
)
|
||||
|
|
@ -49,14 +50,14 @@ class TestGetAvailableDatasetsIntegration:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=i,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
name=f"Document {i}",
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
)
|
||||
|
|
@ -94,7 +95,7 @@ class TestGetAvailableDatasetsIntegration:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -106,13 +107,13 @@ class TestGetAvailableDatasetsIntegration:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=i,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=f"Archived Document {i}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=True, # Archived
|
||||
)
|
||||
|
|
@ -147,7 +148,7 @@ class TestGetAvailableDatasetsIntegration:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -159,13 +160,13 @@ class TestGetAvailableDatasetsIntegration:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=i,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=f"Disabled Document {i}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=False, # Disabled
|
||||
archived=False,
|
||||
)
|
||||
|
|
@ -200,21 +201,21 @@ class TestGetAvailableDatasetsIntegration:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
||||
# Create documents with non-completed status
|
||||
for i, status in enumerate(["indexing", "parsing", "splitting"]):
|
||||
for i, status in enumerate([IndexingStatus.INDEXING, IndexingStatus.PARSING, IndexingStatus.SPLITTING]):
|
||||
document = Document(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=i,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=f"Document {status}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
|
|
@ -263,7 +264,7 @@ class TestGetAvailableDatasetsIntegration:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="external", # External provider
|
||||
data_source_type="external",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -307,7 +308,7 @@ class TestGetAvailableDatasetsIntegration:
|
|||
tenant_id=tenant1.id,
|
||||
name="Tenant 1 Dataset",
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account1.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset1)
|
||||
|
|
@ -318,7 +319,7 @@ class TestGetAvailableDatasetsIntegration:
|
|||
tenant_id=tenant2.id,
|
||||
name="Tenant 2 Dataset",
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account2.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset2)
|
||||
|
|
@ -330,13 +331,13 @@ class TestGetAvailableDatasetsIntegration:
|
|||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=f"Document for {dataset.name}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
)
|
||||
|
|
@ -398,7 +399,7 @@ class TestGetAvailableDatasetsIntegration:
|
|||
tenant_id=tenant.id,
|
||||
name=f"Dataset {i}",
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -410,13 +411,13 @@ class TestGetAvailableDatasetsIntegration:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=f"Document {i}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
)
|
||||
|
|
@ -456,7 +457,7 @@ class TestKnowledgeRetrievalIntegration:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
indexing_technique="high_quality",
|
||||
)
|
||||
|
|
@ -467,12 +468,12 @@ class TestKnowledgeRetrievalIntegration:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch=str(uuid.uuid4()), # Required field
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=fake.sentence(),
|
||||
created_by=account.id,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
doc_form="text_model",
|
||||
|
|
@ -525,7 +526,7 @@ class TestKnowledgeRetrievalIntegration:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -572,7 +573,7 @@ class TestKnowledgeRetrievalIntegration:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
provider="dify",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import pytest
|
|||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
|
||||
|
||||
|
||||
class TestDatasetDocumentProperties:
|
||||
|
|
@ -29,7 +30,7 @@ class TestDatasetDocumentProperties:
|
|||
created_by = str(uuid4())
|
||||
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
|
||||
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.flush()
|
||||
|
|
@ -39,10 +40,10 @@ class TestDatasetDocumentProperties:
|
|||
tenant_id=tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=i + 1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="batch_001",
|
||||
name=f"doc_{i}.pdf",
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by,
|
||||
)
|
||||
db_session_with_containers.add(doc)
|
||||
|
|
@ -56,7 +57,7 @@ class TestDatasetDocumentProperties:
|
|||
created_by = str(uuid4())
|
||||
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
|
||||
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.flush()
|
||||
|
|
@ -65,12 +66,12 @@ class TestDatasetDocumentProperties:
|
|||
tenant_id=tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="batch_001",
|
||||
name="available.pdf",
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
)
|
||||
|
|
@ -78,12 +79,12 @@ class TestDatasetDocumentProperties:
|
|||
tenant_id=tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=2,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="batch_001",
|
||||
name="pending.pdf",
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by,
|
||||
indexing_status="waiting",
|
||||
indexing_status=IndexingStatus.WAITING,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
)
|
||||
|
|
@ -91,12 +92,12 @@ class TestDatasetDocumentProperties:
|
|||
tenant_id=tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=3,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="batch_001",
|
||||
name="disabled.pdf",
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=False,
|
||||
archived=False,
|
||||
)
|
||||
|
|
@ -111,7 +112,7 @@ class TestDatasetDocumentProperties:
|
|||
created_by = str(uuid4())
|
||||
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
|
||||
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.flush()
|
||||
|
|
@ -121,10 +122,10 @@ class TestDatasetDocumentProperties:
|
|||
tenant_id=tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=i + 1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="batch_001",
|
||||
name=f"doc_{i}.pdf",
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by,
|
||||
word_count=wc,
|
||||
)
|
||||
|
|
@ -139,7 +140,7 @@ class TestDatasetDocumentProperties:
|
|||
created_by = str(uuid4())
|
||||
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
|
||||
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.flush()
|
||||
|
|
@ -148,10 +149,10 @@ class TestDatasetDocumentProperties:
|
|||
tenant_id=tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="batch_001",
|
||||
name="doc.pdf",
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by,
|
||||
)
|
||||
db_session_with_containers.add(doc)
|
||||
|
|
@ -166,7 +167,7 @@ class TestDatasetDocumentProperties:
|
|||
content=f"segment {i}",
|
||||
word_count=100,
|
||||
tokens=50,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
enabled=True,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
|
@ -180,7 +181,7 @@ class TestDatasetDocumentProperties:
|
|||
content="waiting segment",
|
||||
word_count=100,
|
||||
tokens=50,
|
||||
status="waiting",
|
||||
status=SegmentStatus.WAITING,
|
||||
enabled=True,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
|
@ -195,7 +196,7 @@ class TestDatasetDocumentProperties:
|
|||
created_by = str(uuid4())
|
||||
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
|
||||
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.flush()
|
||||
|
|
@ -204,10 +205,10 @@ class TestDatasetDocumentProperties:
|
|||
tenant_id=tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="batch_001",
|
||||
name="doc.pdf",
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by,
|
||||
)
|
||||
db_session_with_containers.add(doc)
|
||||
|
|
@ -235,7 +236,7 @@ class TestDatasetDocumentProperties:
|
|||
created_by = str(uuid4())
|
||||
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
|
||||
tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.flush()
|
||||
|
|
@ -244,10 +245,10 @@ class TestDatasetDocumentProperties:
|
|||
tenant_id=tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="batch_001",
|
||||
name="doc.pdf",
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by,
|
||||
)
|
||||
db_session_with_containers.add(doc)
|
||||
|
|
@ -288,7 +289,7 @@ class TestDocumentSegmentNavigationProperties:
|
|||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
name="Test Dataset",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=created_by,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -298,10 +299,10 @@ class TestDocumentSegmentNavigationProperties:
|
|||
tenant_id=tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="batch_001",
|
||||
name="test.pdf",
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
|
|
@ -335,7 +336,7 @@ class TestDocumentSegmentNavigationProperties:
|
|||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
name="Test Dataset",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=created_by,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -345,10 +346,10 @@ class TestDocumentSegmentNavigationProperties:
|
|||
tenant_id=tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="batch_001",
|
||||
name="test.pdf",
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
|
|
@ -382,7 +383,7 @@ class TestDocumentSegmentNavigationProperties:
|
|||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
name="Test Dataset",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=created_by,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -392,10 +393,10 @@ class TestDocumentSegmentNavigationProperties:
|
|||
tenant_id=tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="batch_001",
|
||||
name="test.pdf",
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
|
|
@ -439,7 +440,7 @@ class TestDocumentSegmentNavigationProperties:
|
|||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
name="Test Dataset",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=created_by,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -449,10 +450,10 @@ class TestDocumentSegmentNavigationProperties:
|
|||
tenant_id=tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="batch_001",
|
||||
name="test.pdf",
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import pytest
|
|||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.dataset import DatasetCollectionBinding
|
||||
from models.enums import CollectionBindingType
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
|
||||
|
|
@ -32,7 +33,7 @@ class DatasetCollectionBindingTestDataFactory:
|
|||
provider_name: str = "openai",
|
||||
model_name: str = "text-embedding-ada-002",
|
||||
collection_name: str = "collection-abc",
|
||||
collection_type: str = "dataset",
|
||||
collection_type: str = CollectionBindingType.DATASET,
|
||||
) -> DatasetCollectionBinding:
|
||||
"""
|
||||
Create a DatasetCollectionBinding with specified attributes.
|
||||
|
|
@ -41,7 +42,7 @@ class DatasetCollectionBindingTestDataFactory:
|
|||
provider_name: Name of the embedding model provider (e.g., "openai", "cohere")
|
||||
model_name: Name of the embedding model (e.g., "text-embedding-ada-002")
|
||||
collection_name: Name of the vector database collection
|
||||
collection_type: Type of collection (default: "dataset")
|
||||
collection_type: Type of collection (default: CollectionBindingType.DATASET)
|
||||
|
||||
Returns:
|
||||
DatasetCollectionBinding instance
|
||||
|
|
@ -76,7 +77,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
|
|||
# Arrange
|
||||
provider_name = "openai"
|
||||
model_name = "text-embedding-ada-002"
|
||||
collection_type = "dataset"
|
||||
collection_type = CollectionBindingType.DATASET
|
||||
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding(
|
||||
db_session_with_containers,
|
||||
provider_name=provider_name,
|
||||
|
|
@ -104,7 +105,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
|
|||
# Arrange
|
||||
provider_name = f"provider-{uuid4()}"
|
||||
model_name = f"model-{uuid4()}"
|
||||
collection_type = "dataset"
|
||||
collection_type = CollectionBindingType.DATASET
|
||||
|
||||
# Act
|
||||
result = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
|
|
@ -145,7 +146,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
|
|||
result = DatasetCollectionBindingService.get_dataset_collection_binding(provider_name, model_name)
|
||||
|
||||
# Assert
|
||||
assert result.type == "dataset"
|
||||
assert result.type == CollectionBindingType.DATASET
|
||||
assert result.provider_name == provider_name
|
||||
assert result.model_name == model_name
|
||||
|
||||
|
|
@ -186,18 +187,20 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
|
|||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
collection_name="test-collection",
|
||||
collection_type="dataset",
|
||||
collection_type=CollectionBindingType.DATASET,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(binding.id, "dataset")
|
||||
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
|
||||
binding.id, CollectionBindingType.DATASET
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.id == binding.id
|
||||
assert result.provider_name == "openai"
|
||||
assert result.model_name == "text-embedding-ada-002"
|
||||
assert result.collection_name == "test-collection"
|
||||
assert result.type == "dataset"
|
||||
assert result.type == CollectionBindingType.DATASET
|
||||
|
||||
def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers: Session):
|
||||
"""Test error handling when collection binding is not found by ID and type."""
|
||||
|
|
@ -206,7 +209,9 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
|
|||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Dataset collection binding not found"):
|
||||
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(non_existent_id, "dataset")
|
||||
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
|
||||
non_existent_id, CollectionBindingType.DATASET
|
||||
)
|
||||
|
||||
def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(
|
||||
self, db_session_with_containers: Session
|
||||
|
|
@ -240,7 +245,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
|
|||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
collection_name="test-collection",
|
||||
collection_type="dataset",
|
||||
collection_type=CollectionBindingType.DATASET,
|
||||
)
|
||||
|
||||
# Act
|
||||
|
|
@ -248,7 +253,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
|
|||
|
||||
# Assert
|
||||
assert result.id == binding.id
|
||||
assert result.type == "dataset"
|
||||
assert result.type == CollectionBindingType.DATASET
|
||||
|
||||
def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers: Session):
|
||||
"""Test error when binding exists but with wrong collection type."""
|
||||
|
|
@ -258,7 +263,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
|
|||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
collection_name="test-collection",
|
||||
collection_type="dataset",
|
||||
collection_type=CollectionBindingType.DATASET,
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from werkzeug.exceptions import NotFound
|
|||
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum
|
||||
from models.enums import DataSourceType
|
||||
from models.model import App
|
||||
from services.dataset_service import DatasetService
|
||||
from services.errors.account import NoPermissionError
|
||||
|
|
@ -72,7 +73,7 @@ class DatasetUpdateDeleteTestDataFactory:
|
|||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description="Test description",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
indexing_technique="high_quality",
|
||||
created_by=created_by,
|
||||
permission=permission,
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import pytest
|
|||
|
||||
from models import Account
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DocumentService
|
||||
from services.errors.document import DocumentIndexingError
|
||||
|
|
@ -88,7 +88,7 @@ class DocumentStatusTestDataFactory:
|
|||
data_source_info=json.dumps(data_source_info or {}),
|
||||
batch=f"batch-{uuid4()}",
|
||||
name=name,
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by,
|
||||
doc_form="text_model",
|
||||
)
|
||||
|
|
@ -100,7 +100,7 @@ class DocumentStatusTestDataFactory:
|
|||
document.paused_by = paused_by
|
||||
document.paused_at = paused_at
|
||||
document.doc_metadata = doc_metadata or {}
|
||||
if indexing_status == "completed" and "completed_at" not in kwargs:
|
||||
if indexing_status == IndexingStatus.COMPLETED and "completed_at" not in kwargs:
|
||||
document.completed_at = FIXED_TIME
|
||||
|
||||
for key, value in kwargs.items():
|
||||
|
|
@ -139,7 +139,7 @@ class DocumentStatusTestDataFactory:
|
|||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=created_by,
|
||||
)
|
||||
dataset.id = dataset_id
|
||||
|
|
@ -291,7 +291,7 @@ class TestDocumentServicePauseDocument:
|
|||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
indexing_status="waiting",
|
||||
indexing_status=IndexingStatus.WAITING,
|
||||
is_paused=False,
|
||||
)
|
||||
|
||||
|
|
@ -326,7 +326,7 @@ class TestDocumentServicePauseDocument:
|
|||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
indexing_status="indexing",
|
||||
indexing_status=IndexingStatus.INDEXING,
|
||||
is_paused=False,
|
||||
)
|
||||
|
||||
|
|
@ -354,7 +354,7 @@ class TestDocumentServicePauseDocument:
|
|||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
indexing_status="parsing",
|
||||
indexing_status=IndexingStatus.PARSING,
|
||||
is_paused=False,
|
||||
)
|
||||
|
||||
|
|
@ -383,7 +383,7 @@ class TestDocumentServicePauseDocument:
|
|||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
is_paused=False,
|
||||
)
|
||||
|
||||
|
|
@ -412,7 +412,7 @@ class TestDocumentServicePauseDocument:
|
|||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
indexing_status="error",
|
||||
indexing_status=IndexingStatus.ERROR,
|
||||
is_paused=False,
|
||||
)
|
||||
|
||||
|
|
@ -487,7 +487,7 @@ class TestDocumentServiceRecoverDocument:
|
|||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
indexing_status="indexing",
|
||||
indexing_status=IndexingStatus.INDEXING,
|
||||
is_paused=True,
|
||||
paused_by=str(uuid4()),
|
||||
paused_at=paused_time,
|
||||
|
|
@ -526,7 +526,7 @@ class TestDocumentServiceRecoverDocument:
|
|||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
indexing_status="indexing",
|
||||
indexing_status=IndexingStatus.INDEXING,
|
||||
is_paused=False,
|
||||
)
|
||||
|
||||
|
|
@ -609,7 +609,7 @@ class TestDocumentServiceRetryDocument:
|
|||
dataset_id=dataset.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
document_id=str(uuid4()),
|
||||
indexing_status="error",
|
||||
indexing_status=IndexingStatus.ERROR,
|
||||
)
|
||||
|
||||
mock_document_service_dependencies["redis_client"].get.return_value = None
|
||||
|
|
@ -619,7 +619,7 @@ class TestDocumentServiceRetryDocument:
|
|||
|
||||
# Assert
|
||||
db_session_with_containers.refresh(document)
|
||||
assert document.indexing_status == "waiting"
|
||||
assert document.indexing_status == IndexingStatus.WAITING
|
||||
|
||||
expected_cache_key = f"document_{document.id}_is_retried"
|
||||
mock_document_service_dependencies["redis_client"].setex.assert_called_once_with(expected_cache_key, 600, 1)
|
||||
|
|
@ -646,14 +646,14 @@ class TestDocumentServiceRetryDocument:
|
|||
dataset_id=dataset.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
document_id=str(uuid4()),
|
||||
indexing_status="error",
|
||||
indexing_status=IndexingStatus.ERROR,
|
||||
)
|
||||
document2 = DocumentStatusTestDataFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
document_id=str(uuid4()),
|
||||
indexing_status="error",
|
||||
indexing_status=IndexingStatus.ERROR,
|
||||
position=2,
|
||||
)
|
||||
|
||||
|
|
@ -665,8 +665,8 @@ class TestDocumentServiceRetryDocument:
|
|||
# Assert
|
||||
db_session_with_containers.refresh(document1)
|
||||
db_session_with_containers.refresh(document2)
|
||||
assert document1.indexing_status == "waiting"
|
||||
assert document2.indexing_status == "waiting"
|
||||
assert document1.indexing_status == IndexingStatus.WAITING
|
||||
assert document2.indexing_status == IndexingStatus.WAITING
|
||||
|
||||
mock_document_service_dependencies["retry_task"].delay.assert_called_once_with(
|
||||
dataset.id, [document1.id, document2.id], mock_document_service_dependencies["user_id"]
|
||||
|
|
@ -693,7 +693,7 @@ class TestDocumentServiceRetryDocument:
|
|||
dataset_id=dataset.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
document_id=str(uuid4()),
|
||||
indexing_status="error",
|
||||
indexing_status=IndexingStatus.ERROR,
|
||||
)
|
||||
|
||||
mock_document_service_dependencies["redis_client"].get.return_value = "1"
|
||||
|
|
@ -703,7 +703,7 @@ class TestDocumentServiceRetryDocument:
|
|||
DocumentService.retry_document(dataset.id, [document])
|
||||
|
||||
db_session_with_containers.refresh(document)
|
||||
assert document.indexing_status == "error"
|
||||
assert document.indexing_status == IndexingStatus.ERROR
|
||||
|
||||
def test_retry_document_missing_current_user_error(
|
||||
self, db_session_with_containers, mock_document_service_dependencies
|
||||
|
|
@ -726,7 +726,7 @@ class TestDocumentServiceRetryDocument:
|
|||
dataset_id=dataset.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
document_id=str(uuid4()),
|
||||
indexing_status="error",
|
||||
indexing_status=IndexingStatus.ERROR,
|
||||
)
|
||||
|
||||
mock_document_service_dependencies["redis_client"].get.return_value = None
|
||||
|
|
@ -816,7 +816,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
|
|||
tenant_id=dataset.tenant_id,
|
||||
document_id=str(uuid4()),
|
||||
enabled=False,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
)
|
||||
document2 = DocumentStatusTestDataFactory.create_document(
|
||||
db_session_with_containers,
|
||||
|
|
@ -824,7 +824,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
|
|||
tenant_id=dataset.tenant_id,
|
||||
document_id=str(uuid4()),
|
||||
enabled=False,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
position=2,
|
||||
)
|
||||
document_ids = [document1.id, document2.id]
|
||||
|
|
@ -866,7 +866,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
|
|||
tenant_id=dataset.tenant_id,
|
||||
document_id=str(uuid4()),
|
||||
enabled=True,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
completed_at=FIXED_TIME,
|
||||
)
|
||||
document_ids = [document.id]
|
||||
|
|
@ -909,7 +909,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
|
|||
document_id=str(uuid4()),
|
||||
archived=False,
|
||||
enabled=True,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
)
|
||||
document_ids = [document.id]
|
||||
|
||||
|
|
@ -951,7 +951,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
|
|||
document_id=str(uuid4()),
|
||||
archived=True,
|
||||
enabled=True,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
)
|
||||
document_ids = [document.id]
|
||||
|
||||
|
|
@ -1015,7 +1015,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
|
|||
dataset_id=dataset.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
document_id=str(uuid4()),
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
)
|
||||
document_ids = [document.id]
|
||||
|
||||
|
|
@ -1098,7 +1098,7 @@ class TestDocumentServiceRenameDocument:
|
|||
document_id=document_id,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
)
|
||||
|
||||
# Act
|
||||
|
|
@ -1139,7 +1139,7 @@ class TestDocumentServiceRenameDocument:
|
|||
dataset_id=dataset.id,
|
||||
tenant_id=tenant_id,
|
||||
doc_metadata={"existing_key": "existing_value"},
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
)
|
||||
|
||||
# Act
|
||||
|
|
@ -1187,7 +1187,7 @@ class TestDocumentServiceRenameDocument:
|
|||
dataset_id=dataset.id,
|
||||
tenant_id=tenant_id,
|
||||
data_source_info={"upload_file_id": upload_file.id},
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
)
|
||||
|
||||
# Act
|
||||
|
|
@ -1277,7 +1277,7 @@ class TestDocumentServiceRenameDocument:
|
|||
document_id=document_id,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=str(uuid4()),
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from models.dataset import (
|
|||
DatasetPermission,
|
||||
DatasetPermissionEnum,
|
||||
)
|
||||
from models.enums import DataSourceType
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.errors.account import NoPermissionError
|
||||
|
||||
|
|
@ -67,7 +68,7 @@ class DatasetPermissionTestDataFactory:
|
|||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description="desc",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
indexing_technique="high_quality",
|
||||
created_by=created_by,
|
||||
permission=permission,
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
|||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline
|
||||
from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RerankingModel, RetrievalModel
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
||||
|
|
@ -74,7 +75,7 @@ class DatasetServiceIntegrationDataFactory:
|
|||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=description,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
indexing_technique=indexing_technique,
|
||||
created_by=created_by,
|
||||
provider=provider,
|
||||
|
|
@ -98,13 +99,13 @@ class DatasetServiceIntegrationDataFactory:
|
|||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
data_source_info='{"upload_file_id": "upload-file-id"}',
|
||||
batch=str(uuid4()),
|
||||
name=name,
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
doc_form="text_model",
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
|
|
@ -437,7 +438,7 @@ class TestDatasetServiceCreateRagPipelineDataset:
|
|||
created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id)
|
||||
assert created_dataset is not None
|
||||
assert created_dataset.name == entity.name
|
||||
assert created_dataset.runtime_mode == "rag_pipeline"
|
||||
assert created_dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE
|
||||
assert created_dataset.created_by == account.id
|
||||
assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME
|
||||
assert created_pipeline is not None
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ import pytest
|
|||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
|
||||
from services.dataset_service import DocumentService
|
||||
from services.errors.document import DocumentIndexingError
|
||||
|
||||
|
|
@ -42,7 +43,7 @@ class DocumentBatchUpdateIntegrationDataFactory:
|
|||
dataset = Dataset(
|
||||
tenant_id=tenant_id or str(uuid4()),
|
||||
name=name,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=created_by or str(uuid4()),
|
||||
)
|
||||
if dataset_id:
|
||||
|
|
@ -72,11 +73,11 @@ class DocumentBatchUpdateIntegrationDataFactory:
|
|||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=position,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
data_source_info=json.dumps({"upload_file_id": str(uuid4())}),
|
||||
batch=f"batch-{uuid4()}",
|
||||
name=name,
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by or str(uuid4()),
|
||||
doc_form="text_model",
|
||||
)
|
||||
|
|
@ -85,7 +86,9 @@ class DocumentBatchUpdateIntegrationDataFactory:
|
|||
document.archived = archived
|
||||
document.indexing_status = indexing_status
|
||||
document.completed_at = (
|
||||
completed_at if completed_at is not None else (FIXED_TIME if indexing_status == "completed" else None)
|
||||
completed_at
|
||||
if completed_at is not None
|
||||
else (FIXED_TIME if indexing_status == IndexingStatus.COMPLETED else None)
|
||||
)
|
||||
|
||||
for key, value in kwargs.items():
|
||||
|
|
@ -243,7 +246,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
|||
dataset=dataset,
|
||||
document_ids=document_ids,
|
||||
enabled=True,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
)
|
||||
|
||||
# Act
|
||||
|
|
@ -277,7 +280,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
|||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
enabled=False,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
completed_at=FIXED_TIME,
|
||||
)
|
||||
|
||||
|
|
@ -306,7 +309,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
|||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
enabled=True,
|
||||
indexing_status="indexing",
|
||||
indexing_status=IndexingStatus.INDEXING,
|
||||
completed_at=None,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from uuid import uuid4
|
|||
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
|
||||
|
|
@ -58,7 +59,7 @@ class DatasetDeleteIntegrationDataFactory:
|
|||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=f"dataset-{uuid4()}",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
indexing_technique=indexing_technique,
|
||||
index_struct=index_struct,
|
||||
created_by=created_by,
|
||||
|
|
@ -84,10 +85,10 @@ class DatasetDeleteIntegrationDataFactory:
|
|||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch=f"batch-{uuid4()}",
|
||||
name="Document",
|
||||
created_from="upload_file",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by,
|
||||
doc_form=doc_form,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from sqlalchemy.orm import Session
|
|||
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
|
||||
|
|
@ -62,7 +63,7 @@ class SegmentServiceTestDataFactory:
|
|||
tenant_id=tenant_id,
|
||||
name=f"Test Dataset {uuid4()}",
|
||||
description="Test description",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
indexing_technique="high_quality",
|
||||
created_by=created_by,
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
|
|
@ -82,10 +83,10 @@ class SegmentServiceTestDataFactory:
|
|||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch=f"batch-{uuid4()}",
|
||||
name=f"test-doc-{uuid4()}.txt",
|
||||
created_from="api",
|
||||
created_from=DocumentCreatedFrom.API,
|
||||
created_by=created_by,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from models.dataset import (
|
|||
DatasetProcessRule,
|
||||
DatasetQuery,
|
||||
)
|
||||
from models.enums import DatasetQuerySource, DataSourceType, ProcessRuleMode
|
||||
from models.model import Tag, TagBinding
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
|
||||
|
|
@ -100,7 +101,7 @@ class DatasetRetrievalTestDataFactory:
|
|||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description="desc",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
indexing_technique="high_quality",
|
||||
created_by=created_by,
|
||||
permission=permission,
|
||||
|
|
@ -149,7 +150,7 @@ class DatasetRetrievalTestDataFactory:
|
|||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset_id,
|
||||
content=content,
|
||||
source="web",
|
||||
source=DatasetQuerySource.APP,
|
||||
source_app_id=None,
|
||||
created_by_role="account",
|
||||
created_by=created_by,
|
||||
|
|
@ -601,7 +602,7 @@ class TestDatasetServiceGetProcessRules:
|
|||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
created_by=account.id,
|
||||
mode="custom",
|
||||
mode=ProcessRuleMode.CUSTOM,
|
||||
rules=rules_data,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
|
|||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, ExternalKnowledgeBindings
|
||||
from models.enums import DataSourceType
|
||||
from services.dataset_service import DatasetService
|
||||
from services.errors.account import NoPermissionError
|
||||
|
||||
|
|
@ -64,7 +65,7 @@ class DatasetUpdateTestDataFactory:
|
|||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=description,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
indexing_technique=indexing_technique,
|
||||
created_by=created_by,
|
||||
provider=provider,
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from uuid import uuid4
|
|||
from sqlalchemy import select
|
||||
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
|
||||
|
|
@ -11,7 +12,7 @@ def _create_dataset(db_session_with_containers) -> Dataset:
|
|||
dataset = Dataset(
|
||||
tenant_id=str(uuid4()),
|
||||
name=f"dataset-{uuid4()}",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
dataset.id = str(uuid4())
|
||||
|
|
@ -35,11 +36,11 @@ def _create_document(
|
|||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
position=position,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
data_source_info="{}",
|
||||
batch=f"batch-{uuid4()}",
|
||||
name=f"doc-{uuid4()}",
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=str(uuid4()),
|
||||
doc_form="text_model",
|
||||
)
|
||||
|
|
@ -48,7 +49,7 @@ def _create_document(
|
|||
document.enabled = enabled
|
||||
document.archived = archived
|
||||
document.is_paused = is_paused
|
||||
if indexing_status == "completed":
|
||||
if indexing_status == IndexingStatus.COMPLETED:
|
||||
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
|
||||
db_session_with_containers.add(document)
|
||||
|
|
@ -62,7 +63,7 @@ def test_build_display_status_filters_available(db_session_with_containers):
|
|||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
position=1,
|
||||
|
|
@ -71,7 +72,7 @@ def test_build_display_status_filters_available(db_session_with_containers):
|
|||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=False,
|
||||
archived=False,
|
||||
position=2,
|
||||
|
|
@ -80,7 +81,7 @@ def test_build_display_status_filters_available(db_session_with_containers):
|
|||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=True,
|
||||
position=3,
|
||||
|
|
@ -101,14 +102,14 @@ def test_apply_display_status_filter_applies_when_status_present(db_session_with
|
|||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
indexing_status="waiting",
|
||||
indexing_status=IndexingStatus.WAITING,
|
||||
position=1,
|
||||
)
|
||||
_create_document(
|
||||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
position=2,
|
||||
)
|
||||
|
||||
|
|
@ -125,14 +126,14 @@ def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_c
|
|||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
indexing_status="waiting",
|
||||
indexing_status=IndexingStatus.WAITING,
|
||||
position=1,
|
||||
)
|
||||
doc2 = _create_document(
|
||||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=dataset.tenant_id,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
position=2,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import pytest
|
|||
|
||||
from models import Account
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
|
|
@ -33,7 +33,7 @@ def make_dataset(db_session_with_containers, dataset_id=None, tenant_id=None, bu
|
|||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=f"dataset-{uuid4()}",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
dataset.id = dataset_id
|
||||
|
|
@ -62,11 +62,11 @@ def make_document(
|
|||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
data_source_info=json.dumps(data_source_info or {}),
|
||||
batch=f"batch-{uuid4()}",
|
||||
name=name,
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=str(uuid4()),
|
||||
doc_form="text_model",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
|
|||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import DataSourceType
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
|
|
@ -287,7 +288,7 @@ class TestMessagesCleanServiceIntegration:
|
|||
dataset_name="Test dataset",
|
||||
document_id=str(uuid.uuid4()),
|
||||
document_name="Test document",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
segment_id=str(uuid.uuid4()),
|
||||
score=0.9,
|
||||
content="Test content",
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
|
|||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document
|
||||
from models.enums import DatasetMetadataType, DataSourceType, DocumentCreatedFrom
|
||||
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
|
@ -101,7 +102,7 @@ class TestMetadataService:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
built_in_field_enabled=False,
|
||||
)
|
||||
|
|
@ -132,11 +133,11 @@ class TestMetadataService:
|
|||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
data_source_info="{}",
|
||||
batch="test-batch",
|
||||
name=fake.file_name(),
|
||||
created_from="web",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text",
|
||||
doc_language="en",
|
||||
|
|
@ -163,7 +164,7 @@ class TestMetadataService:
|
|||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
metadata_args = MetadataArgs(type="string", name="test_metadata")
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
|
|
@ -201,7 +202,7 @@ class TestMetadataService:
|
|||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
long_name = "a" * 256 # 256 characters, exceeding 255 limit
|
||||
metadata_args = MetadataArgs(type="string", name=long_name)
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name=long_name)
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."):
|
||||
|
|
@ -226,11 +227,11 @@ class TestMetadataService:
|
|||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create first metadata
|
||||
first_metadata_args = MetadataArgs(type="string", name="duplicate_name")
|
||||
first_metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="duplicate_name")
|
||||
MetadataService.create_metadata(dataset.id, first_metadata_args)
|
||||
|
||||
# Try to create second metadata with same name
|
||||
second_metadata_args = MetadataArgs(type="number", name="duplicate_name")
|
||||
second_metadata_args = MetadataArgs(type=DatasetMetadataType.NUMBER, name="duplicate_name")
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError, match="Metadata name already exists."):
|
||||
|
|
@ -256,7 +257,7 @@ class TestMetadataService:
|
|||
|
||||
# Try to create metadata with built-in field name
|
||||
built_in_field_name = BuiltInField.document_name
|
||||
metadata_args = MetadataArgs(type="string", name=built_in_field_name)
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name=built_in_field_name)
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."):
|
||||
|
|
@ -281,7 +282,7 @@ class TestMetadataService:
|
|||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create metadata first
|
||||
metadata_args = MetadataArgs(type="string", name="old_name")
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
|
||||
# Act: Execute the method under test
|
||||
|
|
@ -318,7 +319,7 @@ class TestMetadataService:
|
|||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create metadata first
|
||||
metadata_args = MetadataArgs(type="string", name="old_name")
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
|
||||
# Try to update with too long name
|
||||
|
|
@ -347,10 +348,10 @@ class TestMetadataService:
|
|||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create two metadata entries
|
||||
first_metadata_args = MetadataArgs(type="string", name="first_metadata")
|
||||
first_metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="first_metadata")
|
||||
first_metadata = MetadataService.create_metadata(dataset.id, first_metadata_args)
|
||||
|
||||
second_metadata_args = MetadataArgs(type="number", name="second_metadata")
|
||||
second_metadata_args = MetadataArgs(type=DatasetMetadataType.NUMBER, name="second_metadata")
|
||||
second_metadata = MetadataService.create_metadata(dataset.id, second_metadata_args)
|
||||
|
||||
# Try to update first metadata with second metadata's name
|
||||
|
|
@ -376,7 +377,7 @@ class TestMetadataService:
|
|||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create metadata first
|
||||
metadata_args = MetadataArgs(type="string", name="old_name")
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
|
||||
# Try to update with built-in field name
|
||||
|
|
@ -432,7 +433,7 @@ class TestMetadataService:
|
|||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create metadata first
|
||||
metadata_args = MetadataArgs(type="string", name="to_be_deleted")
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="to_be_deleted")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
|
||||
# Act: Execute the method under test
|
||||
|
|
@ -496,7 +497,7 @@ class TestMetadataService:
|
|||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create metadata
|
||||
metadata_args = MetadataArgs(type="string", name="test_metadata")
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
|
||||
# Create metadata binding
|
||||
|
|
@ -798,7 +799,7 @@ class TestMetadataService:
|
|||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create metadata
|
||||
metadata_args = MetadataArgs(type="string", name="test_metadata")
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
|
||||
# Mock DocumentService.get_document
|
||||
|
|
@ -866,7 +867,7 @@ class TestMetadataService:
|
|||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create metadata
|
||||
metadata_args = MetadataArgs(type="string", name="test_metadata")
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
|
||||
# Mock DocumentService.get_document
|
||||
|
|
@ -917,7 +918,7 @@ class TestMetadataService:
|
|||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create metadata
|
||||
metadata_args = MetadataArgs(type="string", name="test_metadata")
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
|
||||
# Create metadata operation data
|
||||
|
|
@ -1038,7 +1039,7 @@ class TestMetadataService:
|
|||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create metadata
|
||||
metadata_args = MetadataArgs(type="string", name="test_metadata")
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
|
||||
# Create document and metadata binding
|
||||
|
|
@ -1101,7 +1102,7 @@ class TestMetadataService:
|
|||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create metadata
|
||||
metadata_args = MetadataArgs(type="string", name="test_metadata")
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
|
||||
# Act: Execute the method under test
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from werkzeug.exceptions import NotFound
|
|||
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset
|
||||
from models.enums import DataSourceType
|
||||
from models.model import App, Tag, TagBinding
|
||||
from services.tag_service import TagService
|
||||
|
||||
|
|
@ -100,7 +101,7 @@ class TestTagService:
|
|||
description=fake.text(max_nb_chars=100),
|
||||
provider="vendor",
|
||||
permission="only_me",
|
||||
data_source_type="upload",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
indexing_technique="high_quality",
|
||||
tenant_id=tenant_id,
|
||||
created_by=mock_external_service_dependencies["current_user"].id,
|
||||
|
|
|
|||
|
|
@ -510,7 +510,7 @@ class TestWorkflowConverter:
|
|||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
|
||||
top_k=10,
|
||||
score_threshold=0.8,
|
||||
reranking_model={"provider": "cohere", "model": "rerank-v2"},
|
||||
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
|
||||
reranking_enabled=True,
|
||||
),
|
||||
)
|
||||
|
|
@ -543,8 +543,8 @@ class TestWorkflowConverter:
|
|||
multiple_config = node["data"]["multiple_retrieval_config"]
|
||||
assert multiple_config["top_k"] == 10
|
||||
assert multiple_config["score_threshold"] == 0.8
|
||||
assert multiple_config["reranking_model"]["provider"] == "cohere"
|
||||
assert multiple_config["reranking_model"]["model"] == "rerank-v2"
|
||||
assert multiple_config["reranking_model"]["reranking_provider_name"] == "cohere"
|
||||
assert multiple_config["reranking_model"]["reranking_model_name"] == "rerank-v2"
|
||||
|
||||
# Verify single retrieval config is None for multiple strategy
|
||||
assert node["data"]["single_retrieval_config"] is None
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from core.rag.index_processor.constant.index_type import IndexStructureType
|
|||
from extensions.ext_redis import redis_client
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
|
||||
from tasks.add_document_to_index_task import add_document_to_index_task
|
||||
|
||||
|
||||
|
|
@ -79,7 +80,7 @@ class TestAddDocumentToIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
indexing_technique="high_quality",
|
||||
created_by=account.id,
|
||||
)
|
||||
|
|
@ -92,12 +93,12 @@ class TestAddDocumentToIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="test_batch",
|
||||
name=fake.file_name(),
|
||||
created_from="upload_file",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
|
|
@ -137,7 +138,7 @@ class TestAddDocumentToIndexTask:
|
|||
index_node_id=f"node_{i}",
|
||||
index_node_hash=f"hash_{i}",
|
||||
enabled=False,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
created_by=document.created_by,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
|
|
@ -297,7 +298,7 @@ class TestAddDocumentToIndexTask:
|
|||
)
|
||||
|
||||
# Set invalid indexing status
|
||||
document.indexing_status = "processing"
|
||||
document.indexing_status = IndexingStatus.INDEXING
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act: Execute the task
|
||||
|
|
@ -339,7 +340,7 @@ class TestAddDocumentToIndexTask:
|
|||
# Assert: Verify error handling
|
||||
db_session_with_containers.refresh(document)
|
||||
assert document.enabled is False
|
||||
assert document.indexing_status == "error"
|
||||
assert document.indexing_status == IndexingStatus.ERROR
|
||||
assert document.error is not None
|
||||
assert "doesn't exist" in document.error
|
||||
assert document.disabled_at is not None
|
||||
|
|
@ -434,7 +435,7 @@ class TestAddDocumentToIndexTask:
|
|||
Test document indexing when segments are already enabled.
|
||||
|
||||
This test verifies:
|
||||
- Segments with status="completed" are processed regardless of enabled status
|
||||
- Segments with status=SegmentStatus.COMPLETED are processed regardless of enabled status
|
||||
- Index processing occurs with all completed segments
|
||||
- Auto disable log deletion still occurs
|
||||
- Redis cache is cleared
|
||||
|
|
@ -460,7 +461,7 @@ class TestAddDocumentToIndexTask:
|
|||
index_node_id=f"node_{i}",
|
||||
index_node_hash=f"hash_{i}",
|
||||
enabled=True, # Already enabled
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
created_by=document.created_by,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
|
|
@ -482,7 +483,7 @@ class TestAddDocumentToIndexTask:
|
|||
mock_external_service_dependencies["index_processor"].load.assert_called_once()
|
||||
|
||||
# Verify the load method was called with all completed segments
|
||||
# (implementation doesn't filter by enabled status, only by status="completed")
|
||||
# (implementation doesn't filter by enabled status, only by status=SegmentStatus.COMPLETED)
|
||||
call_args = mock_external_service_dependencies["index_processor"].load.call_args
|
||||
assert call_args is not None
|
||||
documents = call_args[0][1] # Second argument should be documents list
|
||||
|
|
@ -594,7 +595,7 @@ class TestAddDocumentToIndexTask:
|
|||
# Assert: Verify error handling
|
||||
db_session_with_containers.refresh(document)
|
||||
assert document.enabled is False
|
||||
assert document.indexing_status == "error"
|
||||
assert document.indexing_status == IndexingStatus.ERROR
|
||||
assert document.error is not None
|
||||
assert "Index processing failed" in document.error
|
||||
assert document.disabled_at is not None
|
||||
|
|
@ -614,7 +615,7 @@ class TestAddDocumentToIndexTask:
|
|||
Test segment filtering with various edge cases.
|
||||
|
||||
This test verifies:
|
||||
- Only segments with status="completed" are processed (regardless of enabled status)
|
||||
- Only segments with status=SegmentStatus.COMPLETED are processed (regardless of enabled status)
|
||||
- Segments with status!="completed" are NOT processed
|
||||
- Segments are ordered by position correctly
|
||||
- Mixed segment states are handled properly
|
||||
|
|
@ -630,7 +631,7 @@ class TestAddDocumentToIndexTask:
|
|||
fake = Faker()
|
||||
segments = []
|
||||
|
||||
# Segment 1: Should be processed (enabled=False, status="completed")
|
||||
# Segment 1: Should be processed (enabled=False, status=SegmentStatus.COMPLETED)
|
||||
segment1 = DocumentSegment(
|
||||
id=fake.uuid4(),
|
||||
tenant_id=document.tenant_id,
|
||||
|
|
@ -643,14 +644,14 @@ class TestAddDocumentToIndexTask:
|
|||
index_node_id="node_0",
|
||||
index_node_hash="hash_0",
|
||||
enabled=False,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
created_by=document.created_by,
|
||||
)
|
||||
db_session_with_containers.add(segment1)
|
||||
segments.append(segment1)
|
||||
|
||||
# Segment 2: Should be processed (enabled=True, status="completed")
|
||||
# Note: Implementation doesn't filter by enabled status, only by status="completed"
|
||||
# Segment 2: Should be processed (enabled=True, status=SegmentStatus.COMPLETED)
|
||||
# Note: Implementation doesn't filter by enabled status, only by status=SegmentStatus.COMPLETED
|
||||
segment2 = DocumentSegment(
|
||||
id=fake.uuid4(),
|
||||
tenant_id=document.tenant_id,
|
||||
|
|
@ -663,7 +664,7 @@ class TestAddDocumentToIndexTask:
|
|||
index_node_id="node_1",
|
||||
index_node_hash="hash_1",
|
||||
enabled=True, # Already enabled, but will still be processed
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
created_by=document.created_by,
|
||||
)
|
||||
db_session_with_containers.add(segment2)
|
||||
|
|
@ -682,13 +683,13 @@ class TestAddDocumentToIndexTask:
|
|||
index_node_id="node_2",
|
||||
index_node_hash="hash_2",
|
||||
enabled=False,
|
||||
status="processing", # Not completed
|
||||
status=SegmentStatus.INDEXING, # Not completed
|
||||
created_by=document.created_by,
|
||||
)
|
||||
db_session_with_containers.add(segment3)
|
||||
segments.append(segment3)
|
||||
|
||||
# Segment 4: Should be processed (enabled=False, status="completed")
|
||||
# Segment 4: Should be processed (enabled=False, status=SegmentStatus.COMPLETED)
|
||||
segment4 = DocumentSegment(
|
||||
id=fake.uuid4(),
|
||||
tenant_id=document.tenant_id,
|
||||
|
|
@ -701,7 +702,7 @@ class TestAddDocumentToIndexTask:
|
|||
index_node_id="node_3",
|
||||
index_node_hash="hash_3",
|
||||
enabled=False,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
created_by=document.created_by,
|
||||
)
|
||||
db_session_with_containers.add(segment4)
|
||||
|
|
@ -726,7 +727,7 @@ class TestAddDocumentToIndexTask:
|
|||
call_args = mock_external_service_dependencies["index_processor"].load.call_args
|
||||
assert call_args is not None
|
||||
documents = call_args[0][1] # Second argument should be documents list
|
||||
assert len(documents) == 3 # 3 segments with status="completed" should be processed
|
||||
assert len(documents) == 3 # 3 segments with status=SegmentStatus.COMPLETED should be processed
|
||||
|
||||
# Verify correct segments were processed (by position order)
|
||||
# Segments 1, 2, 4 should be processed (positions 0, 1, 3)
|
||||
|
|
@ -799,7 +800,7 @@ class TestAddDocumentToIndexTask:
|
|||
# Assert: Verify consistent error handling
|
||||
db_session_with_containers.refresh(document)
|
||||
assert document.enabled is False, f"Document should be disabled for {error_name}"
|
||||
assert document.indexing_status == "error", f"Document status should be error for {error_name}"
|
||||
assert document.indexing_status == IndexingStatus.ERROR, f"Document status should be error for {error_name}"
|
||||
assert document.error is not None, f"Error should be recorded for {error_name}"
|
||||
assert str(exception) in document.error, f"Error message should contain exception for {error_name}"
|
||||
assert document.disabled_at is not None, f"Disabled timestamp should be set for {error_name}"
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
|
|||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
|
||||
from models.model import UploadFile
|
||||
from tasks.batch_clean_document_task import batch_clean_document_task
|
||||
|
||||
|
|
@ -113,7 +114,7 @@ class TestBatchCleanDocumentTask:
|
|||
tenant_id=account.current_tenant.id,
|
||||
name=fake.word(),
|
||||
description=fake.sentence(),
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
embedding_model="text-embedding-ada-002",
|
||||
embedding_model_provider="openai",
|
||||
|
|
@ -144,12 +145,12 @@ class TestBatchCleanDocumentTask:
|
|||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
name=fake.word(),
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
data_source_info=json.dumps({"upload_file_id": str(uuid.uuid4())}),
|
||||
batch="test_batch",
|
||||
created_from="test",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
doc_form="text_model",
|
||||
)
|
||||
|
||||
|
|
@ -183,7 +184,7 @@ class TestBatchCleanDocumentTask:
|
|||
tokens=50,
|
||||
index_node_id=str(uuid.uuid4()),
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
)
|
||||
|
||||
db_session_with_containers.add(segment)
|
||||
|
|
@ -297,7 +298,7 @@ class TestBatchCleanDocumentTask:
|
|||
tokens=50,
|
||||
index_node_id=str(uuid.uuid4()),
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
)
|
||||
|
||||
db_session_with_containers.add(segment)
|
||||
|
|
@ -671,7 +672,7 @@ class TestBatchCleanDocumentTask:
|
|||
tokens=25 + i * 5,
|
||||
index_node_id=str(uuid.uuid4()),
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
)
|
||||
segments.append(segment)
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from sqlalchemy.orm import Session
|
|||
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
|
||||
from models.model import UploadFile
|
||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||
|
||||
|
|
@ -139,7 +139,7 @@ class TestBatchCreateSegmentToIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(),
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
indexing_technique="high_quality",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
embedding_model_provider="openai",
|
||||
|
|
@ -170,12 +170,12 @@ class TestBatchCreateSegmentToIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="test_batch",
|
||||
name=fake.file_name(),
|
||||
created_from="upload_file",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
doc_form="text_model",
|
||||
|
|
@ -301,7 +301,7 @@ class TestBatchCreateSegmentToIndexTask:
|
|||
assert segment.dataset_id == dataset.id
|
||||
assert segment.document_id == document.id
|
||||
assert segment.position == i + 1
|
||||
assert segment.status == "completed"
|
||||
assert segment.status == SegmentStatus.COMPLETED
|
||||
assert segment.indexing_at is not None
|
||||
assert segment.completed_at is not None
|
||||
assert segment.answer is None # text_model doesn't have answers
|
||||
|
|
@ -442,12 +442,12 @@ class TestBatchCreateSegmentToIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="test_batch",
|
||||
name="disabled_document",
|
||||
created_from="upload_file",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=False, # Document is disabled
|
||||
archived=False,
|
||||
doc_form="text_model",
|
||||
|
|
@ -458,12 +458,12 @@ class TestBatchCreateSegmentToIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=2,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="test_batch",
|
||||
name="archived_document",
|
||||
created_from="upload_file",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=True, # Document is archived
|
||||
doc_form="text_model",
|
||||
|
|
@ -474,12 +474,12 @@ class TestBatchCreateSegmentToIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=3,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="test_batch",
|
||||
name="incomplete_document",
|
||||
created_from="upload_file",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
indexing_status="indexing", # Not completed
|
||||
indexing_status=IndexingStatus.INDEXING, # Not completed
|
||||
enabled=True,
|
||||
archived=False,
|
||||
doc_form="text_model",
|
||||
|
|
@ -643,7 +643,7 @@ class TestBatchCreateSegmentToIndexTask:
|
|||
word_count=len(f"Existing segment {i + 1}"),
|
||||
tokens=10,
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
index_node_id=str(uuid.uuid4()),
|
||||
index_node_hash=f"hash_{i}",
|
||||
)
|
||||
|
|
@ -694,7 +694,7 @@ class TestBatchCreateSegmentToIndexTask:
|
|||
for i, segment in enumerate(new_segments):
|
||||
expected_position = 4 + i # Should start at position 4
|
||||
assert segment.position == expected_position
|
||||
assert segment.status == "completed"
|
||||
assert segment.status == SegmentStatus.COMPLETED
|
||||
assert segment.indexing_at is not None
|
||||
assert segment.completed_at is not None
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,14 @@ from models.dataset import (
|
|||
Document,
|
||||
DocumentSegment,
|
||||
)
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import (
|
||||
CreatorUserRole,
|
||||
DatasetMetadataType,
|
||||
DataSourceType,
|
||||
DocumentCreatedFrom,
|
||||
IndexingStatus,
|
||||
SegmentStatus,
|
||||
)
|
||||
from models.model import UploadFile
|
||||
from tasks.clean_dataset_task import clean_dataset_task
|
||||
|
||||
|
|
@ -176,12 +183,12 @@ class TestCleanDatasetTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="test_batch",
|
||||
name="test_document",
|
||||
created_from="upload_file",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
doc_form="paragraph_index",
|
||||
|
|
@ -219,7 +226,7 @@ class TestCleanDatasetTask:
|
|||
word_count=20,
|
||||
tokens=30,
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
index_node_id=str(uuid.uuid4()),
|
||||
index_node_hash="test_hash",
|
||||
created_at=datetime.now(),
|
||||
|
|
@ -373,7 +380,7 @@ class TestCleanDatasetTask:
|
|||
dataset_id=dataset.id,
|
||||
tenant_id=tenant.id,
|
||||
name="test_metadata",
|
||||
type="string",
|
||||
type=DatasetMetadataType.STRING,
|
||||
created_by=account.id,
|
||||
)
|
||||
metadata.id = str(uuid.uuid4())
|
||||
|
|
@ -587,7 +594,7 @@ class TestCleanDatasetTask:
|
|||
word_count=len(segment_content),
|
||||
tokens=50,
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
index_node_id=str(uuid.uuid4()),
|
||||
index_node_hash="test_hash",
|
||||
created_at=datetime.now(),
|
||||
|
|
@ -686,7 +693,7 @@ class TestCleanDatasetTask:
|
|||
dataset_id=dataset.id,
|
||||
tenant_id=tenant.id,
|
||||
name=f"test_metadata_{i}",
|
||||
type="string",
|
||||
type=DatasetMetadataType.STRING,
|
||||
created_by=account.id,
|
||||
)
|
||||
metadata.id = str(uuid.uuid4())
|
||||
|
|
@ -880,11 +887,11 @@ class TestCleanDatasetTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
data_source_info="{}",
|
||||
batch="test_batch",
|
||||
name=f"test_doc_{special_content}",
|
||||
created_from="test",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
|
|
@ -905,7 +912,7 @@ class TestCleanDatasetTask:
|
|||
word_count=len(segment_content.split()),
|
||||
tokens=len(segment_content) // 4, # Rough token estimation
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
index_node_id=str(uuid.uuid4()),
|
||||
index_node_hash="test_hash_" + "x" * 50, # Long hash within limits
|
||||
created_at=datetime.now(),
|
||||
|
|
@ -946,7 +953,7 @@ class TestCleanDatasetTask:
|
|||
dataset_id=dataset.id,
|
||||
tenant_id=tenant.id,
|
||||
name=f"metadata_{special_content}",
|
||||
type="string",
|
||||
type=DatasetMetadataType.STRING,
|
||||
created_by=account.id,
|
||||
)
|
||||
special_metadata.id = str(uuid.uuid4())
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import pytest
|
|||
from faker import Faker
|
||||
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
|
||||
from services.account_service import AccountService, TenantService
|
||||
from tasks.clean_notion_document_task import clean_notion_document_task
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
|
@ -88,7 +89,7 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -105,17 +106,17 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=i,
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
data_source_info=json.dumps(
|
||||
{"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"}
|
||||
),
|
||||
batch="test_batch",
|
||||
name=f"Notion Page {i}",
|
||||
created_from="notion_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model", # Set doc_form to ensure dataset.doc_form works
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.flush()
|
||||
|
|
@ -134,7 +135,7 @@ class TestCleanNotionDocumentTask:
|
|||
tokens=50,
|
||||
index_node_id=f"node_{i}_{j}",
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
segments.append(segment)
|
||||
|
|
@ -220,7 +221,7 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -269,7 +270,7 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=tenant.id,
|
||||
name=f"{fake.company()}_{index_type}",
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -281,17 +282,17 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
data_source_info=json.dumps(
|
||||
{"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"}
|
||||
),
|
||||
batch="test_batch",
|
||||
name="Test Notion Page",
|
||||
created_from="notion_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form=index_type,
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.flush()
|
||||
|
|
@ -308,7 +309,7 @@ class TestCleanNotionDocumentTask:
|
|||
tokens=50,
|
||||
index_node_id="test_node",
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
db_session_with_containers.commit()
|
||||
|
|
@ -357,7 +358,7 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -369,16 +370,16 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
data_source_info=json.dumps(
|
||||
{"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"}
|
||||
),
|
||||
batch="test_batch",
|
||||
name="Test Notion Page",
|
||||
created_from="notion_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.flush()
|
||||
|
|
@ -397,7 +398,7 @@ class TestCleanNotionDocumentTask:
|
|||
tokens=50,
|
||||
index_node_id=None, # No index node ID
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
segments.append(segment)
|
||||
|
|
@ -443,7 +444,7 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -460,16 +461,16 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=i,
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
data_source_info=json.dumps(
|
||||
{"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"}
|
||||
),
|
||||
batch="test_batch",
|
||||
name=f"Notion Page {i}",
|
||||
created_from="notion_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.flush()
|
||||
|
|
@ -488,7 +489,7 @@ class TestCleanNotionDocumentTask:
|
|||
tokens=50,
|
||||
index_node_id=f"node_{i}_{j}",
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
all_segments.append(segment)
|
||||
|
|
@ -558,7 +559,7 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -570,22 +571,22 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
data_source_info=json.dumps(
|
||||
{"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"}
|
||||
),
|
||||
batch="test_batch",
|
||||
name="Test Notion Page",
|
||||
created_from="notion_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
# Create segments with different statuses
|
||||
segment_statuses = ["waiting", "processing", "completed", "error"]
|
||||
segment_statuses = [SegmentStatus.WAITING, SegmentStatus.INDEXING, SegmentStatus.COMPLETED, SegmentStatus.ERROR]
|
||||
segments = []
|
||||
index_node_ids = []
|
||||
|
||||
|
|
@ -654,7 +655,7 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -666,16 +667,16 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
data_source_info=json.dumps(
|
||||
{"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"}
|
||||
),
|
||||
batch="test_batch",
|
||||
name="Test Notion Page",
|
||||
created_from="notion_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.flush()
|
||||
|
|
@ -692,7 +693,7 @@ class TestCleanNotionDocumentTask:
|
|||
tokens=50,
|
||||
index_node_id="test_node",
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
db_session_with_containers.commit()
|
||||
|
|
@ -736,7 +737,7 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -754,16 +755,16 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=i,
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
data_source_info=json.dumps(
|
||||
{"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"}
|
||||
),
|
||||
batch="test_batch",
|
||||
name=f"Notion Page {i}",
|
||||
created_from="notion_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.flush()
|
||||
|
|
@ -783,7 +784,7 @@ class TestCleanNotionDocumentTask:
|
|||
tokens=50,
|
||||
index_node_id=f"node_{i}_{j}",
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
all_segments.append(segment)
|
||||
|
|
@ -848,7 +849,7 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=tenant.id,
|
||||
name=f"{fake.company()}_{i}",
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -866,16 +867,16 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=account.current_tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
data_source_info=json.dumps(
|
||||
{"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"}
|
||||
),
|
||||
batch="test_batch",
|
||||
name=f"Notion Page {i}",
|
||||
created_from="notion_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.flush()
|
||||
|
|
@ -894,7 +895,7 @@ class TestCleanNotionDocumentTask:
|
|||
tokens=50,
|
||||
index_node_id=f"node_{i}_{j}",
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
all_segments.append(segment)
|
||||
|
|
@ -963,14 +964,22 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
# Create documents with different indexing statuses
|
||||
document_statuses = ["waiting", "parsing", "cleaning", "splitting", "indexing", "completed", "error"]
|
||||
document_statuses = [
|
||||
IndexingStatus.WAITING,
|
||||
IndexingStatus.PARSING,
|
||||
IndexingStatus.CLEANING,
|
||||
IndexingStatus.SPLITTING,
|
||||
IndexingStatus.INDEXING,
|
||||
IndexingStatus.COMPLETED,
|
||||
IndexingStatus.ERROR,
|
||||
]
|
||||
documents = []
|
||||
all_segments = []
|
||||
all_index_node_ids = []
|
||||
|
|
@ -981,13 +990,13 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=i,
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
data_source_info=json.dumps(
|
||||
{"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"}
|
||||
),
|
||||
batch="test_batch",
|
||||
name=f"Notion Page {i}",
|
||||
created_from="notion_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_language="en",
|
||||
indexing_status=status,
|
||||
|
|
@ -1009,7 +1018,7 @@ class TestCleanNotionDocumentTask:
|
|||
tokens=50,
|
||||
index_node_id=f"node_{i}_{j}",
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
all_segments.append(segment)
|
||||
|
|
@ -1066,7 +1075,7 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
created_by=account.id,
|
||||
built_in_field_enabled=True,
|
||||
)
|
||||
|
|
@ -1079,7 +1088,7 @@ class TestCleanNotionDocumentTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="notion_import",
|
||||
data_source_type=DataSourceType.NOTION_IMPORT,
|
||||
data_source_info=json.dumps(
|
||||
{
|
||||
"notion_workspace_id": "workspace_test",
|
||||
|
|
@ -1091,10 +1100,10 @@ class TestCleanNotionDocumentTask:
|
|||
),
|
||||
batch="test_batch",
|
||||
name="Test Notion Page with Metadata",
|
||||
created_from="notion_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
doc_metadata={
|
||||
"document_name": "Test Notion Page with Metadata",
|
||||
"uploader": account.name,
|
||||
|
|
@ -1122,7 +1131,7 @@ class TestCleanNotionDocumentTask:
|
|||
tokens=75,
|
||||
index_node_id=f"node_{i}",
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
keywords={"key1": ["value1", "value2"], "key2": ["value3"]},
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from faker import Faker
|
|||
from extensions.ext_redis import redis_client
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
|
||||
from tasks.create_segment_to_index_task import create_segment_to_index_task
|
||||
|
||||
|
||||
|
|
@ -118,7 +119,7 @@ class TestCreateSegmentToIndexTask:
|
|||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
tenant_id=tenant_id,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
|
|
@ -133,13 +134,13 @@ class TestCreateSegmentToIndexTask:
|
|||
dataset_id=dataset.id,
|
||||
tenant_id=tenant_id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="test_batch",
|
||||
created_from="upload_file",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account_id,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
doc_form="qa_model",
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
|
|
@ -148,7 +149,7 @@ class TestCreateSegmentToIndexTask:
|
|||
return dataset, document
|
||||
|
||||
def _create_test_segment(
|
||||
self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status="waiting"
|
||||
self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status=SegmentStatus.WAITING
|
||||
):
|
||||
"""
|
||||
Helper method to create a test document segment for testing.
|
||||
|
|
@ -200,7 +201,7 @@ class TestCreateSegmentToIndexTask:
|
|||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
|
||||
)
|
||||
|
||||
# Act: Execute the task
|
||||
|
|
@ -208,7 +209,7 @@ class TestCreateSegmentToIndexTask:
|
|||
|
||||
# Assert: Verify segment status changes
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "completed"
|
||||
assert segment.status == SegmentStatus.COMPLETED
|
||||
assert segment.indexing_at is not None
|
||||
assert segment.completed_at is not None
|
||||
assert segment.error is None
|
||||
|
|
@ -257,7 +258,7 @@ class TestCreateSegmentToIndexTask:
|
|||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="completed"
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.COMPLETED
|
||||
)
|
||||
|
||||
# Act: Execute the task
|
||||
|
|
@ -268,7 +269,7 @@ class TestCreateSegmentToIndexTask:
|
|||
|
||||
# Verify segment status unchanged
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "completed"
|
||||
assert segment.status == SegmentStatus.COMPLETED
|
||||
assert segment.indexing_at is None
|
||||
|
||||
# Verify no index processor calls were made
|
||||
|
|
@ -293,20 +294,25 @@ class TestCreateSegmentToIndexTask:
|
|||
dataset_id=invalid_dataset_id,
|
||||
tenant_id=tenant.id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="test_batch",
|
||||
created_from="upload_file",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
doc_form="text_model",
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, invalid_dataset_id, document.id, tenant.id, account.id, status="waiting"
|
||||
db_session_with_containers,
|
||||
invalid_dataset_id,
|
||||
document.id,
|
||||
tenant.id,
|
||||
account.id,
|
||||
status=SegmentStatus.WAITING,
|
||||
)
|
||||
|
||||
# Act: Execute the task
|
||||
|
|
@ -317,7 +323,7 @@ class TestCreateSegmentToIndexTask:
|
|||
|
||||
# Verify segment status changed to indexing (task updates status before checking document)
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "indexing"
|
||||
assert segment.status == SegmentStatus.INDEXING
|
||||
|
||||
# Verify no index processor calls were made
|
||||
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
|
||||
|
|
@ -337,7 +343,12 @@ class TestCreateSegmentToIndexTask:
|
|||
invalid_document_id = str(uuid4())
|
||||
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, dataset.id, invalid_document_id, tenant.id, account.id, status="waiting"
|
||||
db_session_with_containers,
|
||||
dataset.id,
|
||||
invalid_document_id,
|
||||
tenant.id,
|
||||
account.id,
|
||||
status=SegmentStatus.WAITING,
|
||||
)
|
||||
|
||||
# Act: Execute the task
|
||||
|
|
@ -348,7 +359,7 @@ class TestCreateSegmentToIndexTask:
|
|||
|
||||
# Verify segment status changed to indexing (task updates status before checking document)
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "indexing"
|
||||
assert segment.status == SegmentStatus.INDEXING
|
||||
|
||||
# Verify no index processor calls were made
|
||||
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
|
||||
|
|
@ -373,7 +384,7 @@ class TestCreateSegmentToIndexTask:
|
|||
db_session_with_containers.commit()
|
||||
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
|
||||
)
|
||||
|
||||
# Act: Execute the task
|
||||
|
|
@ -384,7 +395,7 @@ class TestCreateSegmentToIndexTask:
|
|||
|
||||
# Verify segment status changed to indexing (task updates status before checking document)
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "indexing"
|
||||
assert segment.status == SegmentStatus.INDEXING
|
||||
|
||||
# Verify no index processor calls were made
|
||||
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
|
||||
|
|
@ -409,7 +420,7 @@ class TestCreateSegmentToIndexTask:
|
|||
db_session_with_containers.commit()
|
||||
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
|
||||
)
|
||||
|
||||
# Act: Execute the task
|
||||
|
|
@ -420,7 +431,7 @@ class TestCreateSegmentToIndexTask:
|
|||
|
||||
# Verify segment status changed to indexing (task updates status before checking document)
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "indexing"
|
||||
assert segment.status == SegmentStatus.INDEXING
|
||||
|
||||
# Verify no index processor calls were made
|
||||
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
|
||||
|
|
@ -445,7 +456,7 @@ class TestCreateSegmentToIndexTask:
|
|||
db_session_with_containers.commit()
|
||||
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
|
||||
)
|
||||
|
||||
# Act: Execute the task
|
||||
|
|
@ -456,7 +467,7 @@ class TestCreateSegmentToIndexTask:
|
|||
|
||||
# Verify segment status changed to indexing (task updates status before checking document)
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "indexing"
|
||||
assert segment.status == SegmentStatus.INDEXING
|
||||
|
||||
# Verify no index processor calls were made
|
||||
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
|
||||
|
|
@ -477,7 +488,7 @@ class TestCreateSegmentToIndexTask:
|
|||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
|
||||
)
|
||||
|
||||
# Mock processor to raise exception
|
||||
|
|
@ -488,7 +499,7 @@ class TestCreateSegmentToIndexTask:
|
|||
|
||||
# Assert: Verify error handling
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "error"
|
||||
assert segment.status == SegmentStatus.ERROR
|
||||
assert segment.enabled is False
|
||||
assert segment.disabled_at is not None
|
||||
assert segment.error == "Processor failed"
|
||||
|
|
@ -512,7 +523,7 @@ class TestCreateSegmentToIndexTask:
|
|||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
|
||||
)
|
||||
custom_keywords = ["custom", "keywords", "test"]
|
||||
|
||||
|
|
@ -521,7 +532,7 @@ class TestCreateSegmentToIndexTask:
|
|||
|
||||
# Assert: Verify successful indexing
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "completed"
|
||||
assert segment.status == SegmentStatus.COMPLETED
|
||||
assert segment.indexing_at is not None
|
||||
assert segment.completed_at is not None
|
||||
|
||||
|
|
@ -555,7 +566,7 @@ class TestCreateSegmentToIndexTask:
|
|||
db_session_with_containers.commit()
|
||||
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
|
||||
)
|
||||
|
||||
# Act: Execute the task
|
||||
|
|
@ -563,7 +574,7 @@ class TestCreateSegmentToIndexTask:
|
|||
|
||||
# Assert: Verify successful indexing
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "completed"
|
||||
assert segment.status == SegmentStatus.COMPLETED
|
||||
|
||||
# Verify correct doc_form was passed to factory
|
||||
mock_external_service_dependencies["index_processor_factory"].assert_called_with(doc_form)
|
||||
|
|
@ -583,7 +594,7 @@ class TestCreateSegmentToIndexTask:
|
|||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
|
||||
)
|
||||
|
||||
# Act: Execute the task and measure time
|
||||
|
|
@ -597,7 +608,7 @@ class TestCreateSegmentToIndexTask:
|
|||
|
||||
# Verify successful completion
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "completed"
|
||||
assert segment.status == SegmentStatus.COMPLETED
|
||||
|
||||
def test_create_segment_to_index_concurrent_execution(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
|
|
@ -617,7 +628,7 @@ class TestCreateSegmentToIndexTask:
|
|||
segments = []
|
||||
for i in range(3):
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
|
||||
)
|
||||
segments.append(segment)
|
||||
|
||||
|
|
@ -629,7 +640,7 @@ class TestCreateSegmentToIndexTask:
|
|||
# Assert: Verify all segments processed
|
||||
for segment in segments:
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "completed"
|
||||
assert segment.status == SegmentStatus.COMPLETED
|
||||
assert segment.indexing_at is not None
|
||||
assert segment.completed_at is not None
|
||||
|
||||
|
|
@ -665,7 +676,7 @@ class TestCreateSegmentToIndexTask:
|
|||
keywords=["large", "content", "test"],
|
||||
index_node_id=str(uuid4()),
|
||||
index_node_hash=str(uuid4()),
|
||||
status="waiting",
|
||||
status=SegmentStatus.WAITING,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
|
|
@ -681,7 +692,7 @@ class TestCreateSegmentToIndexTask:
|
|||
assert execution_time < 10.0 # Should complete within 10 seconds
|
||||
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "completed"
|
||||
assert segment.status == SegmentStatus.COMPLETED
|
||||
assert segment.indexing_at is not None
|
||||
assert segment.completed_at is not None
|
||||
|
||||
|
|
@ -700,7 +711,7 @@ class TestCreateSegmentToIndexTask:
|
|||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
|
||||
)
|
||||
|
||||
# Set up Redis cache key to simulate indexing in progress
|
||||
|
|
@ -718,7 +729,7 @@ class TestCreateSegmentToIndexTask:
|
|||
|
||||
# Assert: Verify indexing still completed successfully despite Redis failure
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "completed"
|
||||
assert segment.status == SegmentStatus.COMPLETED
|
||||
assert segment.indexing_at is not None
|
||||
assert segment.completed_at is not None
|
||||
|
||||
|
|
@ -740,7 +751,7 @@ class TestCreateSegmentToIndexTask:
|
|||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
|
||||
)
|
||||
|
||||
# Simulate an error during indexing to trigger rollback path
|
||||
|
|
@ -752,7 +763,7 @@ class TestCreateSegmentToIndexTask:
|
|||
|
||||
# Assert: Verify error handling and rollback
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "error"
|
||||
assert segment.status == SegmentStatus.ERROR
|
||||
assert segment.enabled is False
|
||||
assert segment.disabled_at is not None
|
||||
assert segment.error is not None
|
||||
|
|
@ -772,7 +783,7 @@ class TestCreateSegmentToIndexTask:
|
|||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
|
||||
)
|
||||
|
||||
# Act: Execute the task
|
||||
|
|
@ -780,7 +791,7 @@ class TestCreateSegmentToIndexTask:
|
|||
|
||||
# Assert: Verify successful indexing
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "completed"
|
||||
assert segment.status == SegmentStatus.COMPLETED
|
||||
|
||||
# Verify index processor was called with correct metadata
|
||||
mock_processor = mock_external_service_dependencies["index_processor"]
|
||||
|
|
@ -814,11 +825,11 @@ class TestCreateSegmentToIndexTask:
|
|||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
|
||||
)
|
||||
|
||||
# Verify initial state
|
||||
assert segment.status == "waiting"
|
||||
assert segment.status == SegmentStatus.WAITING
|
||||
assert segment.indexing_at is None
|
||||
assert segment.completed_at is None
|
||||
|
||||
|
|
@ -827,7 +838,7 @@ class TestCreateSegmentToIndexTask:
|
|||
|
||||
# Assert: Verify final state
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "completed"
|
||||
assert segment.status == SegmentStatus.COMPLETED
|
||||
assert segment.indexing_at is not None
|
||||
assert segment.completed_at is not None
|
||||
|
||||
|
|
@ -861,7 +872,7 @@ class TestCreateSegmentToIndexTask:
|
|||
keywords=[],
|
||||
index_node_id=str(uuid4()),
|
||||
index_node_hash=str(uuid4()),
|
||||
status="waiting",
|
||||
status=SegmentStatus.WAITING,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
|
|
@ -872,7 +883,7 @@ class TestCreateSegmentToIndexTask:
|
|||
|
||||
# Assert: Verify successful indexing
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "completed"
|
||||
assert segment.status == SegmentStatus.COMPLETED
|
||||
assert segment.indexing_at is not None
|
||||
assert segment.completed_at is not None
|
||||
|
||||
|
|
@ -907,7 +918,7 @@ class TestCreateSegmentToIndexTask:
|
|||
keywords=["special", "unicode", "test"],
|
||||
index_node_id=str(uuid4()),
|
||||
index_node_hash=str(uuid4()),
|
||||
status="waiting",
|
||||
status=SegmentStatus.WAITING,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
|
|
@ -918,7 +929,7 @@ class TestCreateSegmentToIndexTask:
|
|||
|
||||
# Assert: Verify successful indexing
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "completed"
|
||||
assert segment.status == SegmentStatus.COMPLETED
|
||||
assert segment.indexing_at is not None
|
||||
assert segment.completed_at is not None
|
||||
|
||||
|
|
@ -937,7 +948,7 @@ class TestCreateSegmentToIndexTask:
|
|||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
|
||||
)
|
||||
|
||||
# Create long keyword list
|
||||
|
|
@ -948,7 +959,7 @@ class TestCreateSegmentToIndexTask:
|
|||
|
||||
# Assert: Verify successful indexing
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "completed"
|
||||
assert segment.status == SegmentStatus.COMPLETED
|
||||
assert segment.indexing_at is not None
|
||||
assert segment.completed_at is not None
|
||||
|
||||
|
|
@ -979,10 +990,10 @@ class TestCreateSegmentToIndexTask:
|
|||
)
|
||||
|
||||
segment1 = self._create_test_segment(
|
||||
db_session_with_containers, dataset1.id, document1.id, tenant1.id, account1.id, status="waiting"
|
||||
db_session_with_containers, dataset1.id, document1.id, tenant1.id, account1.id, status=SegmentStatus.WAITING
|
||||
)
|
||||
segment2 = self._create_test_segment(
|
||||
db_session_with_containers, dataset2.id, document2.id, tenant2.id, account2.id, status="waiting"
|
||||
db_session_with_containers, dataset2.id, document2.id, tenant2.id, account2.id, status=SegmentStatus.WAITING
|
||||
)
|
||||
|
||||
# Act: Execute tasks for both tenants
|
||||
|
|
@ -993,8 +1004,8 @@ class TestCreateSegmentToIndexTask:
|
|||
db_session_with_containers.refresh(segment1)
|
||||
db_session_with_containers.refresh(segment2)
|
||||
|
||||
assert segment1.status == "completed"
|
||||
assert segment2.status == "completed"
|
||||
assert segment1.status == SegmentStatus.COMPLETED
|
||||
assert segment2.status == SegmentStatus.COMPLETED
|
||||
assert segment1.tenant_id == tenant1.id
|
||||
assert segment2.tenant_id == tenant2.id
|
||||
assert segment1.tenant_id != segment2.tenant_id
|
||||
|
|
@ -1014,7 +1025,7 @@ class TestCreateSegmentToIndexTask:
|
|||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id)
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
|
||||
)
|
||||
|
||||
# Act: Execute the task with None keywords
|
||||
|
|
@ -1022,7 +1033,7 @@ class TestCreateSegmentToIndexTask:
|
|||
|
||||
# Assert: Verify successful indexing
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "completed"
|
||||
assert segment.status == SegmentStatus.COMPLETED
|
||||
assert segment.indexing_at is not None
|
||||
assert segment.completed_at is not None
|
||||
|
||||
|
|
@ -1050,7 +1061,7 @@ class TestCreateSegmentToIndexTask:
|
|||
segments = []
|
||||
for i in range(5):
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
|
||||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING
|
||||
)
|
||||
segments.append(segment)
|
||||
|
||||
|
|
@ -1067,7 +1078,7 @@ class TestCreateSegmentToIndexTask:
|
|||
# Verify all segments processed successfully
|
||||
for segment in segments:
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.status == "completed"
|
||||
assert segment.status == SegmentStatus.COMPLETED
|
||||
assert segment.indexing_at is not None
|
||||
assert segment.completed_at is not None
|
||||
assert segment.error is None
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from core.indexing_runner import DocumentIsPausedError
|
|||
from enums.cloud_plan import CloudPlan
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
|
||||
from tasks.document_indexing_task import (
|
||||
_document_indexing,
|
||||
_document_indexing_with_tenant_queue,
|
||||
|
|
@ -139,7 +140,7 @@ class TestDatasetIndexingTaskIntegration:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
indexing_technique="high_quality",
|
||||
created_by=account.id,
|
||||
)
|
||||
|
|
@ -155,12 +156,12 @@ class TestDatasetIndexingTaskIntegration:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=position,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch="test_batch",
|
||||
name=f"doc-{position}.txt",
|
||||
created_from="upload_file",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
indexing_status="waiting",
|
||||
indexing_status=IndexingStatus.WAITING,
|
||||
enabled=True,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
|
|
@ -181,7 +182,7 @@ class TestDatasetIndexingTaskIntegration:
|
|||
for document_id in document_ids:
|
||||
updated = self._query_document(db_session_with_containers, document_id)
|
||||
assert updated is not None
|
||||
assert updated.indexing_status == "parsing"
|
||||
assert updated.indexing_status == IndexingStatus.PARSING
|
||||
assert updated.processing_started_at is not None
|
||||
|
||||
def _assert_documents_error_contains(
|
||||
|
|
@ -195,7 +196,7 @@ class TestDatasetIndexingTaskIntegration:
|
|||
for document_id in document_ids:
|
||||
updated = self._query_document(db_session_with_containers, document_id)
|
||||
assert updated is not None
|
||||
assert updated.indexing_status == "error"
|
||||
assert updated.indexing_status == IndexingStatus.ERROR
|
||||
assert updated.error is not None
|
||||
assert expected_error_substring in updated.error
|
||||
assert updated.stopped_at is not None
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import pytest
|
|||
from faker import Faker
|
||||
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
|
||||
from services.account_service import AccountService, TenantService
|
||||
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
|
@ -90,7 +91,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -102,13 +103,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Document for doc_form",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -150,7 +151,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -162,13 +163,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Document for doc_form",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -182,13 +183,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Test Document",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -209,7 +210,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
index_node_id=f"node_{uuid.uuid4()}",
|
||||
index_node_hash=f"hash_{uuid.uuid4()}",
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
enabled=True,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
|
|
@ -220,7 +221,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
|
||||
# Verify document status was updated to indexing then completed
|
||||
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
|
||||
assert updated_document.indexing_status == "completed"
|
||||
assert updated_document.indexing_status == IndexingStatus.COMPLETED
|
||||
|
||||
# Verify index processor load method was called
|
||||
mock_factory = mock_index_processor_factory.return_value
|
||||
|
|
@ -251,7 +252,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -263,13 +264,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Document for doc_form",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="parent_child_index",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -283,13 +284,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Test Document",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="parent_child_index",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -310,7 +311,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
index_node_id=f"node_{uuid.uuid4()}",
|
||||
index_node_hash=f"hash_{uuid.uuid4()}",
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
enabled=True,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
|
|
@ -321,7 +322,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
|
||||
# Verify document status was updated to indexing then completed
|
||||
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
|
||||
assert updated_document.indexing_status == "completed"
|
||||
assert updated_document.indexing_status == IndexingStatus.COMPLETED
|
||||
|
||||
# Verify index processor clean and load methods were called
|
||||
mock_factory = mock_index_processor_factory.return_value
|
||||
|
|
@ -367,7 +368,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -399,7 +400,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -411,13 +412,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Test Document",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -430,7 +431,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
|
||||
# Verify document status was updated to indexing then completed
|
||||
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
|
||||
assert updated_document.indexing_status == "completed"
|
||||
assert updated_document.indexing_status == IndexingStatus.COMPLETED
|
||||
|
||||
# Verify that no index processor load was called since no segments exist
|
||||
mock_factory = mock_index_processor_factory.return_value
|
||||
|
|
@ -455,7 +456,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -488,7 +489,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -500,13 +501,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Document for doc_form",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -520,13 +521,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Test Document",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -547,7 +548,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
index_node_id=f"node_{uuid.uuid4()}",
|
||||
index_node_hash=f"hash_{uuid.uuid4()}",
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
enabled=True,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
|
|
@ -563,7 +564,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
|
||||
# Verify document status was updated to error
|
||||
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
|
||||
assert updated_document.indexing_status == "error"
|
||||
assert updated_document.indexing_status == IndexingStatus.ERROR
|
||||
assert "Test exception during indexing" in updated_document.error
|
||||
|
||||
def test_deal_dataset_vector_index_task_with_custom_index_type(
|
||||
|
|
@ -584,7 +585,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -596,13 +597,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Test Document",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="qa_index",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -623,7 +624,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
index_node_id=f"node_{uuid.uuid4()}",
|
||||
index_node_hash=f"hash_{uuid.uuid4()}",
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
enabled=True,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
|
|
@ -634,7 +635,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
|
||||
# Verify document status was updated to indexing then completed
|
||||
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
|
||||
assert updated_document.indexing_status == "completed"
|
||||
assert updated_document.indexing_status == IndexingStatus.COMPLETED
|
||||
|
||||
# Verify index processor was initialized with custom index type
|
||||
mock_index_processor_factory.assert_called_once_with("qa_index")
|
||||
|
|
@ -660,7 +661,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -672,13 +673,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Test Document",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -699,7 +700,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
index_node_id=f"node_{uuid.uuid4()}",
|
||||
index_node_hash=f"hash_{uuid.uuid4()}",
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
enabled=True,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
|
|
@ -710,7 +711,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
|
||||
# Verify document status was updated to indexing then completed
|
||||
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
|
||||
assert updated_document.indexing_status == "completed"
|
||||
assert updated_document.indexing_status == IndexingStatus.COMPLETED
|
||||
|
||||
# Verify index processor was initialized with the document's index type
|
||||
mock_index_processor_factory.assert_called_once_with("text_model")
|
||||
|
|
@ -736,7 +737,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -748,13 +749,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Document for doc_form",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -770,13 +771,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=i,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name=f"Test Document {i}",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -801,7 +802,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
index_node_id=f"node_{i}_{j}",
|
||||
index_node_hash=f"hash_{i}_{j}",
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
enabled=True,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
|
|
@ -814,7 +815,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
# Verify all documents were processed
|
||||
for document in documents:
|
||||
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
|
||||
assert updated_document.indexing_status == "completed"
|
||||
assert updated_document.indexing_status == IndexingStatus.COMPLETED
|
||||
|
||||
# Verify index processor load was called multiple times
|
||||
mock_factory = mock_index_processor_factory.return_value
|
||||
|
|
@ -839,7 +840,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -851,13 +852,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Document for doc_form",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -871,13 +872,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Test Document",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -898,7 +899,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
index_node_id=f"node_{uuid.uuid4()}",
|
||||
index_node_hash=f"hash_{uuid.uuid4()}",
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
enabled=True,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
|
|
@ -916,7 +917,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
|
||||
# Verify final document status
|
||||
updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
|
||||
assert updated_document.indexing_status == "completed"
|
||||
assert updated_document.indexing_status == IndexingStatus.COMPLETED
|
||||
|
||||
def test_deal_dataset_vector_index_task_with_disabled_documents(
|
||||
self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
|
||||
|
|
@ -936,7 +937,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -948,13 +949,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Document for doc_form",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -968,13 +969,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Enabled Document",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -987,13 +988,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=1,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Disabled Document",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=False, # This document should be skipped
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -1015,7 +1016,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
index_node_id=f"node_{uuid.uuid4()}",
|
||||
index_node_hash=f"hash_{uuid.uuid4()}",
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
enabled=True,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
|
|
@ -1026,13 +1027,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
|
||||
# Verify only enabled document was processed
|
||||
updated_enabled_document = db_session_with_containers.query(Document).filter_by(id=enabled_document.id).first()
|
||||
assert updated_enabled_document.indexing_status == "completed"
|
||||
assert updated_enabled_document.indexing_status == IndexingStatus.COMPLETED
|
||||
|
||||
# Verify disabled document status remains unchanged
|
||||
updated_disabled_document = (
|
||||
db_session_with_containers.query(Document).filter_by(id=disabled_document.id).first()
|
||||
)
|
||||
assert updated_disabled_document.indexing_status == "completed" # Should not change
|
||||
assert updated_disabled_document.indexing_status == IndexingStatus.COMPLETED # Should not change
|
||||
|
||||
# Verify index processor load was called only once (for enabled document)
|
||||
mock_factory = mock_index_processor_factory.return_value
|
||||
|
|
@ -1057,7 +1058,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -1069,13 +1070,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Document for doc_form",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -1089,13 +1090,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Active Document",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -1108,13 +1109,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=1,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Archived Document",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=True, # This document should be skipped
|
||||
batch="test_batch",
|
||||
|
|
@ -1136,7 +1137,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
index_node_id=f"node_{uuid.uuid4()}",
|
||||
index_node_hash=f"hash_{uuid.uuid4()}",
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
enabled=True,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
|
|
@ -1147,13 +1148,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
|
||||
# Verify only active document was processed
|
||||
updated_active_document = db_session_with_containers.query(Document).filter_by(id=active_document.id).first()
|
||||
assert updated_active_document.indexing_status == "completed"
|
||||
assert updated_active_document.indexing_status == IndexingStatus.COMPLETED
|
||||
|
||||
# Verify archived document status remains unchanged
|
||||
updated_archived_document = (
|
||||
db_session_with_containers.query(Document).filter_by(id=archived_document.id).first()
|
||||
)
|
||||
assert updated_archived_document.indexing_status == "completed" # Should not change
|
||||
assert updated_archived_document.indexing_status == IndexingStatus.COMPLETED # Should not change
|
||||
|
||||
# Verify index processor load was called only once (for active document)
|
||||
mock_factory = mock_index_processor_factory.return_value
|
||||
|
|
@ -1178,7 +1179,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
|
|
@ -1190,13 +1191,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Document for doc_form",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -1210,13 +1211,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Completed Document",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -1229,13 +1230,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=1,
|
||||
data_source_type="file_import",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
name="Incomplete Document",
|
||||
created_from="file_import",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
indexing_status="indexing", # This document should be skipped
|
||||
indexing_status=IndexingStatus.INDEXING, # This document should be skipped
|
||||
enabled=True,
|
||||
archived=False,
|
||||
batch="test_batch",
|
||||
|
|
@ -1257,7 +1258,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
index_node_id=f"node_{uuid.uuid4()}",
|
||||
index_node_hash=f"hash_{uuid.uuid4()}",
|
||||
created_by=account.id,
|
||||
status="completed",
|
||||
status=SegmentStatus.COMPLETED,
|
||||
enabled=True,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
|
|
@ -1270,13 +1271,13 @@ class TestDealDatasetVectorIndexTask:
|
|||
updated_completed_document = (
|
||||
db_session_with_containers.query(Document).filter_by(id=completed_document.id).first()
|
||||
)
|
||||
assert updated_completed_document.indexing_status == "completed"
|
||||
assert updated_completed_document.indexing_status == IndexingStatus.COMPLETED
|
||||
|
||||
# Verify incomplete document status remains unchanged
|
||||
updated_incomplete_document = (
|
||||
db_session_with_containers.query(Document).filter_by(id=incomplete_document.id).first()
|
||||
)
|
||||
assert updated_incomplete_document.indexing_status == "indexing" # Should not change
|
||||
assert updated_incomplete_document.indexing_status == IndexingStatus.INDEXING # Should not change
|
||||
|
||||
# Verify index processor load was called only once (for completed document)
|
||||
mock_factory = mock_index_processor_factory.return_value
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from faker import Faker
|
|||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models import Account, Dataset, Document, DocumentSegment, Tenant
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus
|
||||
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -106,7 +107,7 @@ class TestDeleteSegmentFromIndexTask:
|
|||
dataset.description = fake.text(max_nb_chars=200)
|
||||
dataset.provider = "vendor"
|
||||
dataset.permission = "only_me"
|
||||
dataset.data_source_type = "upload_file"
|
||||
dataset.data_source_type = DataSourceType.UPLOAD_FILE
|
||||
dataset.indexing_technique = "high_quality"
|
||||
dataset.index_struct = '{"type": "paragraph"}'
|
||||
dataset.created_by = account.id
|
||||
|
|
@ -145,7 +146,7 @@ class TestDeleteSegmentFromIndexTask:
|
|||
document.data_source_info = kwargs.get("data_source_info", "{}")
|
||||
document.batch = kwargs.get("batch", fake.uuid4())
|
||||
document.name = kwargs.get("name", f"Test Document {fake.word()}")
|
||||
document.created_from = kwargs.get("created_from", "api")
|
||||
document.created_from = kwargs.get("created_from", DocumentCreatedFrom.API)
|
||||
document.created_by = account.id
|
||||
document.created_at = fake.date_time_this_year()
|
||||
document.processing_started_at = kwargs.get("processing_started_at", fake.date_time_this_year())
|
||||
|
|
@ -162,7 +163,7 @@ class TestDeleteSegmentFromIndexTask:
|
|||
document.enabled = kwargs.get("enabled", True)
|
||||
document.archived = kwargs.get("archived", False)
|
||||
document.updated_at = fake.date_time_this_year()
|
||||
document.doc_type = kwargs.get("doc_type", "text")
|
||||
document.doc_type = kwargs.get("doc_type", DocumentDocType.PERSONAL_DOCUMENT)
|
||||
document.doc_metadata = kwargs.get("doc_metadata", {})
|
||||
document.doc_form = kwargs.get("doc_form", IndexStructureType.PARAGRAPH_INDEX)
|
||||
document.doc_language = kwargs.get("doc_language", "en")
|
||||
|
|
@ -204,7 +205,7 @@ class TestDeleteSegmentFromIndexTask:
|
|||
segment.index_node_hash = fake.sha256()
|
||||
segment.hit_count = 0
|
||||
segment.enabled = True
|
||||
segment.status = "completed"
|
||||
segment.status = SegmentStatus.COMPLETED
|
||||
segment.created_by = account.id
|
||||
segment.created_at = fake.date_time_this_year()
|
||||
segment.updated_by = account.id
|
||||
|
|
@ -386,7 +387,7 @@ class TestDeleteSegmentFromIndexTask:
|
|||
account = self._create_test_account(db_session_with_containers, tenant, fake)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake)
|
||||
document = self._create_test_document(
|
||||
db_session_with_containers, dataset, account, fake, indexing_status="indexing"
|
||||
db_session_with_containers, dataset, account, fake, indexing_status=IndexingStatus.INDEXING
|
||||
)
|
||||
segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake)
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from sqlalchemy.orm import Session
|
|||
from extensions.ext_redis import redis_client
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
|
||||
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -97,7 +98,7 @@ class TestDisableSegmentFromIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
name=fake.sentence(nb_words=3),
|
||||
description=fake.text(max_nb_chars=200),
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
indexing_technique="high_quality",
|
||||
created_by=account.id,
|
||||
)
|
||||
|
|
@ -132,12 +133,12 @@ class TestDisableSegmentFromIndexTask:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
batch=fake.uuid4(),
|
||||
name=fake.file_name(),
|
||||
created_from="api",
|
||||
created_from=DocumentCreatedFrom.API,
|
||||
created_by=account.id,
|
||||
indexing_status="completed",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
doc_form=doc_form,
|
||||
|
|
@ -189,7 +190,7 @@ class TestDisableSegmentFromIndexTask:
|
|||
status=status,
|
||||
enabled=enabled,
|
||||
created_by=account.id,
|
||||
completed_at=datetime.now(UTC) if status == "completed" else None,
|
||||
completed_at=datetime.now(UTC) if status == SegmentStatus.COMPLETED else None,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
db_session_with_containers.commit()
|
||||
|
|
@ -271,7 +272,7 @@ class TestDisableSegmentFromIndexTask:
|
|||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, document, dataset, tenant, account, status="indexing", enabled=True
|
||||
db_session_with_containers, document, dataset, tenant, account, status=SegmentStatus.INDEXING, enabled=True
|
||||
)
|
||||
|
||||
# Act: Execute the task
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from sqlalchemy.orm import Session
|
|||
from models import Account, Dataset, DocumentSegment
|
||||
from models import Document as DatasetDocument
|
||||
from models.dataset import DatasetProcessRule
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, ProcessRuleMode, SegmentStatus
|
||||
from tasks.disable_segments_from_index_task import disable_segments_from_index_task
|
||||
|
||||
|
||||
|
|
@ -100,7 +101,7 @@ class TestDisableSegmentsFromIndexTask:
|
|||
description=fake.text(max_nb_chars=200),
|
||||
provider="vendor",
|
||||
permission="only_me",
|
||||
data_source_type="upload_file",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
indexing_technique="high_quality",
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
|
|
@ -134,11 +135,11 @@ class TestDisableSegmentsFromIndexTask:
|
|||
document.tenant_id = dataset.tenant_id
|
||||
document.dataset_id = dataset.id
|
||||
document.position = 1
|
||||
document.data_source_type = "upload_file"
|
||||
document.data_source_type = DataSourceType.UPLOAD_FILE
|
||||
document.data_source_info = '{"upload_file_id": "test_file_id"}'
|
||||
document.batch = fake.uuid4()
|
||||
document.name = f"Test Document {fake.word()}.txt"
|
||||
document.created_from = "upload_file"
|
||||
document.created_from = DocumentCreatedFrom.WEB
|
||||
document.created_by = account.id
|
||||
document.created_api_request_id = fake.uuid4()
|
||||
document.processing_started_at = fake.date_time_this_year()
|
||||
|
|
@ -197,7 +198,7 @@ class TestDisableSegmentsFromIndexTask:
|
|||
segment.enabled = True
|
||||
segment.disabled_at = None
|
||||
segment.disabled_by = None
|
||||
segment.status = "completed"
|
||||
segment.status = SegmentStatus.COMPLETED
|
||||
segment.created_by = account.id
|
||||
segment.updated_by = account.id
|
||||
segment.indexing_at = fake.date_time_this_year()
|
||||
|
|
@ -230,7 +231,7 @@ class TestDisableSegmentsFromIndexTask:
|
|||
process_rule.id = fake.uuid4()
|
||||
process_rule.tenant_id = dataset.tenant_id
|
||||
process_rule.dataset_id = dataset.id
|
||||
process_rule.mode = "automatic"
|
||||
process_rule.mode = ProcessRuleMode.AUTOMATIC
|
||||
process_rule.rules = (
|
||||
"{"
|
||||
'"mode": "automatic", '
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue