mirror of https://github.com/langgenius/dify.git
refactor: select in console datasets document controller (#34019)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
542c1a14e0
commit
e3c1112b15
|
|
@ -10,7 +10,7 @@ import sqlalchemy as sa
|
|||
from flask import request, send_file
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import asc, desc, select
|
||||
from sqlalchemy import asc, desc, func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
|
|
@ -211,12 +211,11 @@ class GetProcessRuleApi(Resource):
|
|||
raise Forbidden(str(e))
|
||||
|
||||
# get the latest process rule
|
||||
dataset_process_rule = (
|
||||
db.session.query(DatasetProcessRule)
|
||||
dataset_process_rule = db.session.scalar(
|
||||
select(DatasetProcessRule)
|
||||
.where(DatasetProcessRule.dataset_id == document.dataset_id)
|
||||
.order_by(DatasetProcessRule.created_at.desc())
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
if dataset_process_rule:
|
||||
mode = dataset_process_rule.mode
|
||||
|
|
@ -330,21 +329,23 @@ class DatasetDocumentListApi(Resource):
|
|||
if fetch:
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
|
|
@ -521,10 +522,10 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
|||
if data_source_info and "upload_file_id" in data_source_info:
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
|
||||
file = (
|
||||
db.session.query(UploadFile)
|
||||
file = db.session.scalar(
|
||||
select(UploadFile)
|
||||
.where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# raise error if file not found
|
||||
|
|
@ -586,10 +587,10 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
if not data_source_info:
|
||||
continue
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
file_detail = db.session.scalar(
|
||||
select(UploadFile)
|
||||
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if file_detail is None:
|
||||
|
|
@ -672,20 +673,23 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
|||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
document_dict = {
|
||||
|
|
@ -723,18 +727,23 @@ class DocumentIndexingStatusApi(DocumentResource):
|
|||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
completed_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT)
|
||||
.count()
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
|
|
@ -1258,11 +1267,11 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
|
|||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
log = (
|
||||
db.session.query(DocumentPipelineExecutionLog)
|
||||
.filter_by(document_id=document_id)
|
||||
log = db.session.scalar(
|
||||
select(DocumentPipelineExecutionLog)
|
||||
.where(DocumentPipelineExecutionLog.document_id == document_id)
|
||||
.order_by(DocumentPipelineExecutionLog.created_at.desc())
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not log:
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ from collections.abc import Callable
|
|||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.console.datasets.error import PipelineNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
|
|
@ -24,10 +26,8 @@ def get_rag_pipeline(view_func: Callable[P, R]):
|
|||
|
||||
del kwargs["pipeline_id"]
|
||||
|
||||
pipeline = (
|
||||
db.session.query(Pipeline)
|
||||
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
pipeline = db.session.scalar(
|
||||
select(Pipeline).where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id).limit(1)
|
||||
)
|
||||
|
||||
if not pipeline:
|
||||
|
|
|
|||
|
|
@ -140,8 +140,8 @@ class TestDatasetDocumentListApi:
|
|||
return_value=pagination,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(count=count_mock)),
|
||||
"controllers.console.datasets.datasets_document.db.session.scalar",
|
||||
return_value=2,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status",
|
||||
|
|
@ -700,10 +700,8 @@ class TestDocumentPipelineExecutionLogApi:
|
|||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.db.session.query",
|
||||
return_value=MagicMock(
|
||||
filter_by=lambda **k: MagicMock(order_by=lambda *a: MagicMock(first=lambda: log))
|
||||
),
|
||||
"controllers.console.datasets.datasets_document.db.session.scalar",
|
||||
return_value=log,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1")
|
||||
|
|
@ -827,15 +825,12 @@ class TestDocumentIndexingEstimateApi:
|
|||
dataset_process_rule=None,
|
||||
)
|
||||
|
||||
query_mock = MagicMock()
|
||||
query_mock.where.return_value.first.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(api, "get_document", return_value=document),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.db.session.query",
|
||||
return_value=query_mock,
|
||||
"controllers.console.datasets.datasets_document.db.session.scalar",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
|
|
@ -863,10 +858,8 @@ class TestDocumentIndexingEstimateApi:
|
|||
app.test_request_context("/"),
|
||||
patch.object(api, "get_document", return_value=document),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.db.session.query",
|
||||
return_value=MagicMock(
|
||||
where=MagicMock(return_value=MagicMock(first=MagicMock(return_value=upload_file)))
|
||||
),
|
||||
"controllers.console.datasets.datasets_document.db.session.scalar",
|
||||
return_value=upload_file,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.ExtractSetting",
|
||||
|
|
@ -1239,12 +1232,8 @@ class TestDocumentPermissionCases:
|
|||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.db.session.query",
|
||||
return_value=MagicMock(
|
||||
where=lambda *a: MagicMock(
|
||||
order_by=lambda *b: MagicMock(limit=lambda n: MagicMock(one_or_none=lambda: process_rule))
|
||||
)
|
||||
),
|
||||
"controllers.console.datasets.datasets_document.db.session.scalar",
|
||||
return_value=process_rule,
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
|
@ -1364,8 +1353,8 @@ class TestDocumentIndexingEdgeCases:
|
|||
app.test_request_context("/"),
|
||||
patch.object(api, "get_document", return_value=document),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a: MagicMock(first=lambda: upload_file)),
|
||||
"controllers.console.datasets.datasets_document.db.session.scalar",
|
||||
return_value=upload_file,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.ExtractSetting",
|
||||
|
|
|
|||
|
|
@ -26,12 +26,9 @@ class TestGetRagPipeline:
|
|||
return_value=(Mock(), "tenant-1"),
|
||||
)
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.first.return_value = None
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.query",
|
||||
return_value=mock_query,
|
||||
"controllers.console.datasets.wraps.db.session.scalar",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
with pytest.raises(PipelineNotFoundError):
|
||||
|
|
@ -51,12 +48,9 @@ class TestGetRagPipeline:
|
|||
return_value=(Mock(), "tenant-1"),
|
||||
)
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.first.return_value = pipeline
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.query",
|
||||
return_value=mock_query,
|
||||
"controllers.console.datasets.wraps.db.session.scalar",
|
||||
return_value=pipeline,
|
||||
)
|
||||
|
||||
result = dummy_view(pipeline_id="pipeline-1")
|
||||
|
|
@ -76,12 +70,9 @@ class TestGetRagPipeline:
|
|||
return_value=(Mock(), "tenant-1"),
|
||||
)
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.first.return_value = pipeline
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.query",
|
||||
return_value=mock_query,
|
||||
"controllers.console.datasets.wraps.db.session.scalar",
|
||||
return_value=pipeline,
|
||||
)
|
||||
|
||||
result = dummy_view(pipeline_id="pipeline-1")
|
||||
|
|
@ -100,18 +91,15 @@ class TestGetRagPipeline:
|
|||
return_value=(Mock(), "tenant-1"),
|
||||
)
|
||||
|
||||
def where_side_effect(*args, **kwargs):
|
||||
assert args[0].right.value == "123"
|
||||
return Mock(first=lambda: pipeline)
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.where.side_effect = where_side_effect
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.query",
|
||||
return_value=mock_query,
|
||||
mock_scalar = mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.scalar",
|
||||
return_value=pipeline,
|
||||
)
|
||||
|
||||
result = dummy_view(pipeline_id=123)
|
||||
|
||||
assert result is pipeline
|
||||
# Verify the pipeline_id was cast to string in the where clause
|
||||
stmt = mock_scalar.call_args[0][0]
|
||||
where_clauses = stmt.whereclause.clauses
|
||||
assert where_clauses[0].right.value == "123"
|
||||
|
|
|
|||
Loading…
Reference in New Issue