refactor: select in console datasets document controller (#34019)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Authored by Renzo on 2026-03-24 13:57:38 +01:00, committed by GitHub
parent 542c1a14e0
commit e3c1112b15
4 changed files with 82 additions and 96 deletions
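The change replaces legacy SQLAlchemy 1.x Query chains (db.session.query(...) ended by .first(), .one_or_none(), or .count()) with 2.0-style select() statements executed through db.session.scalar(). A minimal, self-contained sketch of the pattern; the model and engine below are illustrative stand-ins, not Dify code:

from sqlalchemy import create_engine, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Segment(Base):
    __tablename__ = "segments"
    id: Mapped[int] = mapped_column(primary_key=True)
    document_id: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Segment(document_id="doc-1"), Segment(document_id="doc-1")])
    session.commit()

    # Legacy 1.x style (what this commit removes):
    #   session.query(Segment).where(Segment.document_id == "doc-1").count()
    # 2.0 style (what it adds): a count() aggregate executed via
    # Session.scalar(), which returns the first column of the first row,
    # or None when there is no row at all.
    total = session.scalar(
        select(func.count(Segment.id)).where(Segment.document_id == "doc-1")
    ) or 0
    assert total == 2

    # First-row-or-None lookups become select(...).limit(1) plus scalar().
    first = session.scalar(select(Segment).where(Segment.document_id == "doc-1").limit(1))
    assert first is not None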

controllers/console/datasets/datasets_document.py

@@ -10,7 +10,7 @@ import sqlalchemy as sa
 from flask import request, send_file
 from flask_restx import Resource, fields, marshal, marshal_with
 from pydantic import BaseModel, Field
-from sqlalchemy import asc, desc, select
+from sqlalchemy import asc, desc, func, select
 from werkzeug.exceptions import Forbidden, NotFound
 
 import services
@@ -211,12 +211,11 @@ class GetProcessRuleApi(Resource):
             raise Forbidden(str(e))
 
         # get the latest process rule
-        dataset_process_rule = (
-            db.session.query(DatasetProcessRule)
+        dataset_process_rule = db.session.scalar(
+            select(DatasetProcessRule)
             .where(DatasetProcessRule.dataset_id == document.dataset_id)
             .order_by(DatasetProcessRule.created_at.desc())
             .limit(1)
-            .one_or_none()
         )
         if dataset_process_rule:
             mode = dataset_process_rule.mode
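One behavioral nuance in this hunk: Query.one_or_none() raises MultipleResultsFound when more than one row matches, while Session.scalar() silently returns the first row's first column (or None for an empty result). Because the statement keeps .order_by(...).limit(1), at most one row ever comes back, so the two spell the same result here; keeping .limit(1) also stops the database from shipping more rows than needed.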
@@ -330,21 +329,23 @@ class DatasetDocumentListApi(Resource):
         if fetch:
             for document in documents:
                 completed_segments = (
-                    db.session.query(DocumentSegment)
-                    .where(
-                        DocumentSegment.completed_at.isnot(None),
-                        DocumentSegment.document_id == str(document.id),
-                        DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                    db.session.scalar(
+                        select(func.count(DocumentSegment.id)).where(
+                            DocumentSegment.completed_at.isnot(None),
+                            DocumentSegment.document_id == str(document.id),
+                            DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                        )
                     )
-                    .count()
+                    or 0
                 )
                 total_segments = (
-                    db.session.query(DocumentSegment)
-                    .where(
-                        DocumentSegment.document_id == str(document.id),
-                        DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                    db.session.scalar(
+                        select(func.count(DocumentSegment.id)).where(
+                            DocumentSegment.document_id == str(document.id),
+                            DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                        )
                     )
-                    .count()
+                    or 0
                 )
                 document.completed_segments = completed_segments
                 document.total_segments = total_segments
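A select(func.count(...)) statement always yields exactly one row, so db.session.scalar(...) can realistically only return an int here; the trailing "or 0" is a defensive guard, and it also satisfies type checkers, since Session.scalar() is typed as returning an optional value.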
@@ -521,10 +522,10 @@ class DocumentIndexingEstimateApi(DocumentResource):
             if data_source_info and "upload_file_id" in data_source_info:
                 file_id = data_source_info["upload_file_id"]
-                file = (
-                    db.session.query(UploadFile)
+                file = db.session.scalar(
+                    select(UploadFile)
                     .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
-                    .first()
+                    .limit(1)
                 )
 
                 # raise error if file not found
@@ -586,10 +587,10 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                 if not data_source_info:
                     continue
                 file_id = data_source_info["upload_file_id"]
-                file_detail = (
-                    db.session.query(UploadFile)
+                file_detail = db.session.scalar(
+                    select(UploadFile)
                     .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
-                    .first()
+                    .limit(1)
                 )
 
                 if file_detail is None:
@@ -672,20 +673,23 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
         documents_status = []
         for document in documents:
             completed_segments = (
-                db.session.query(DocumentSegment)
-                .where(
-                    DocumentSegment.completed_at.isnot(None),
-                    DocumentSegment.document_id == str(document.id),
-                    DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                db.session.scalar(
+                    select(func.count(DocumentSegment.id)).where(
+                        DocumentSegment.completed_at.isnot(None),
+                        DocumentSegment.document_id == str(document.id),
+                        DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                    )
                 )
-                .count()
+                or 0
             )
             total_segments = (
-                db.session.query(DocumentSegment)
-                .where(
-                    DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
+                db.session.scalar(
+                    select(func.count(DocumentSegment.id)).where(
+                        DocumentSegment.document_id == str(document.id),
+                        DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                    )
                 )
-                .count()
+                or 0
             )
             # Create a dictionary with document attributes and additional fields
             document_dict = {
@@ -723,18 +727,23 @@ class DocumentIndexingStatusApi(DocumentResource):
         document = self.get_document(dataset_id, document_id)
 
         completed_segments = (
-            db.session.query(DocumentSegment)
-            .where(
-                DocumentSegment.completed_at.isnot(None),
-                DocumentSegment.document_id == str(document_id),
-                DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+            db.session.scalar(
+                select(func.count(DocumentSegment.id)).where(
+                    DocumentSegment.completed_at.isnot(None),
+                    DocumentSegment.document_id == str(document_id),
+                    DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                )
             )
-            .count()
+            or 0
        )
        total_segments = (
-            db.session.query(DocumentSegment)
-            .where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT)
-            .count()
+            db.session.scalar(
+                select(func.count(DocumentSegment.id)).where(
+                    DocumentSegment.document_id == str(document_id),
+                    DocumentSegment.status != SegmentStatus.RE_SEGMENT,
+                )
+            )
             or 0
        )
        # Create a dictionary with document attributes and additional fields
@@ -1258,11 +1267,11 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
         document = DocumentService.get_document(dataset.id, document_id)
         if not document:
             raise NotFound("Document not found.")
-        log = (
-            db.session.query(DocumentPipelineExecutionLog)
-            .filter_by(document_id=document_id)
+        log = db.session.scalar(
+            select(DocumentPipelineExecutionLog)
+            .where(DocumentPipelineExecutionLog.document_id == document_id)
             .order_by(DocumentPipelineExecutionLog.created_at.desc())
-            .first()
+            .limit(1)
         )
         if not log:
             return {
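Besides the executor change, this hunk also swaps .filter_by(document_id=document_id) for an explicit .where(DocumentPipelineExecutionLog.document_id == document_id). Both forms are still supported on 2.0-style select(), but the explicit column comparison is the more common 2.0 idiom and survives renames and static type checks better than the keyword-based form.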

controllers/console/datasets/wraps.py

@@ -2,6 +2,8 @@ from collections.abc import Callable
 from functools import wraps
 from typing import ParamSpec, TypeVar
 
+from sqlalchemy import select
+
 from controllers.console.datasets.error import PipelineNotFoundError
 from extensions.ext_database import db
 from libs.login import current_account_with_tenant
@@ -24,10 +26,8 @@ def get_rag_pipeline(view_func: Callable[P, R]):
         del kwargs["pipeline_id"]
 
-        pipeline = (
-            db.session.query(Pipeline)
-            .where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
-            .first()
+        pipeline = db.session.scalar(
+            select(Pipeline).where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id).limit(1)
         )
         if not pipeline:

unit tests: datasets_document controller

@@ -140,8 +140,8 @@ class TestDatasetDocumentListApi:
                 return_value=pagination,
             ),
             patch(
-                "controllers.console.datasets.datasets_document.db.session.query",
-                return_value=MagicMock(where=lambda *a, **k: MagicMock(count=count_mock)),
+                "controllers.console.datasets.datasets_document.db.session.scalar",
+                return_value=2,
             ),
             patch(
                 "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status",
@@ -700,10 +700,8 @@ class TestDocumentPipelineExecutionLogApi:
                 return_value=MagicMock(),
             ),
             patch(
-                "controllers.console.datasets.datasets_document.db.session.query",
-                return_value=MagicMock(
-                    filter_by=lambda **k: MagicMock(order_by=lambda *a: MagicMock(first=lambda: log))
-                ),
+                "controllers.console.datasets.datasets_document.db.session.scalar",
+                return_value=log,
             ),
         ):
             response, status = method(api, "ds-1", "doc-1")
@@ -827,15 +825,12 @@ class TestDocumentIndexingEstimateApi:
             dataset_process_rule=None,
         )
 
-        query_mock = MagicMock()
-        query_mock.where.return_value.first.return_value = None
-
         with (
             app.test_request_context("/"),
             patch.object(api, "get_document", return_value=document),
             patch(
-                "controllers.console.datasets.datasets_document.db.session.query",
-                return_value=query_mock,
+                "controllers.console.datasets.datasets_document.db.session.scalar",
+                return_value=None,
             ),
         ):
             with pytest.raises(NotFound):
@@ -863,10 +858,8 @@ class TestDocumentIndexingEstimateApi:
             app.test_request_context("/"),
             patch.object(api, "get_document", return_value=document),
             patch(
-                "controllers.console.datasets.datasets_document.db.session.query",
-                return_value=MagicMock(
-                    where=MagicMock(return_value=MagicMock(first=MagicMock(return_value=upload_file)))
-                ),
+                "controllers.console.datasets.datasets_document.db.session.scalar",
+                return_value=upload_file,
             ),
             patch(
                 "controllers.console.datasets.datasets_document.ExtractSetting",
@@ -1239,12 +1232,8 @@ class TestDocumentPermissionCases:
                 return_value=None,
             ),
             patch(
-                "controllers.console.datasets.datasets_document.db.session.query",
-                return_value=MagicMock(
-                    where=lambda *a: MagicMock(
-                        order_by=lambda *b: MagicMock(limit=lambda n: MagicMock(one_or_none=lambda: process_rule))
-                    )
-                ),
+                "controllers.console.datasets.datasets_document.db.session.scalar",
+                return_value=process_rule,
             ),
         ):
             result = method(api)
@@ -1364,8 +1353,8 @@ class TestDocumentIndexingEdgeCases:
             app.test_request_context("/"),
             patch.object(api, "get_document", return_value=document),
             patch(
-                "controllers.console.datasets.datasets_document.db.session.query",
-                return_value=MagicMock(where=lambda *a: MagicMock(first=lambda: upload_file)),
+                "controllers.console.datasets.datasets_document.db.session.scalar",
+                return_value=upload_file,
             ),
             patch(
                 "controllers.console.datasets.datasets_document.ExtractSetting",

unit tests: get_rag_pipeline decorator (wraps)

@@ -26,12 +26,9 @@ class TestGetRagPipeline:
             return_value=(Mock(), "tenant-1"),
         )
-        mock_query = Mock()
-        mock_query.where.return_value.first.return_value = None
         mocker.patch(
-            "controllers.console.datasets.wraps.db.session.query",
-            return_value=mock_query,
+            "controllers.console.datasets.wraps.db.session.scalar",
+            return_value=None,
         )
 
         with pytest.raises(PipelineNotFoundError):
@@ -51,12 +48,9 @@ class TestGetRagPipeline:
             return_value=(Mock(), "tenant-1"),
         )
-        mock_query = Mock()
-        mock_query.where.return_value.first.return_value = pipeline
         mocker.patch(
-            "controllers.console.datasets.wraps.db.session.query",
-            return_value=mock_query,
+            "controllers.console.datasets.wraps.db.session.scalar",
+            return_value=pipeline,
        )
 
         result = dummy_view(pipeline_id="pipeline-1")
@@ -76,12 +70,9 @@ class TestGetRagPipeline:
             return_value=(Mock(), "tenant-1"),
        )
-        mock_query = Mock()
-        mock_query.where.return_value.first.return_value = pipeline
         mocker.patch(
-            "controllers.console.datasets.wraps.db.session.query",
-            return_value=mock_query,
+            "controllers.console.datasets.wraps.db.session.scalar",
+            return_value=pipeline,
        )
 
         result = dummy_view(pipeline_id="pipeline-1")
@@ -100,18 +91,15 @@ class TestGetRagPipeline:
             return_value=(Mock(), "tenant-1"),
         )
 
-        def where_side_effect(*args, **kwargs):
-            assert args[0].right.value == "123"
-            return Mock(first=lambda: pipeline)
-
-        mock_query = Mock()
-        mock_query.where.side_effect = where_side_effect
-        mocker.patch(
-            "controllers.console.datasets.wraps.db.session.query",
-            return_value=mock_query,
+        mock_scalar = mocker.patch(
+            "controllers.console.datasets.wraps.db.session.scalar",
+            return_value=pipeline,
         )
 
         result = dummy_view(pipeline_id=123)
         assert result is pipeline
+        # Verify the pipeline_id was cast to string in the where clause
+        stmt = mock_scalar.call_args[0][0]
+        where_clauses = stmt.whereclause.clauses
+        assert where_clauses[0].right.value == "123"
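The rewritten assertion works because the patched db.session.scalar now receives a real Select object, which can be introspected after the fact: stmt.whereclause holds the AND-ed criteria, and each binary comparison exposes its bound value on .right.value. A self-contained sketch of that trick; the model and its columns here are illustrative, not the repo's Pipeline model:

from sqlalchemy import Column, String, select
from sqlalchemy.orm import DeclarativeBase

class Base(DeclarativeBase):
    pass

class Pipeline(Base):
    __tablename__ = "pipelines"
    id = Column(String, primary_key=True)
    tenant_id = Column(String)

stmt = select(Pipeline).where(Pipeline.id == str(123), Pipeline.tenant_id == "t-1")
clauses = stmt.whereclause.clauses  # BooleanClauseList of the AND-ed conditions
assert clauses[0].right.value == "123"  # the bind parameter carries the cast value
assert clauses[1].right.value == "t-1"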