refactor: use SQLAlchemy 2.0-style select() queries in the console datasets, segments, and API key controllers

This commit is contained in:
RenzoMXD 2026-03-24 13:42:01 +00:00
parent e3c1112b15
commit 89c242060f
4 changed files with 76 additions and 92 deletions

View File

@ -3,7 +3,7 @@ from typing import Any, cast
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy import func, select
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -737,22 +737,20 @@ class DatasetIndexingStatusApi(Resource):
).all()
documents_status = []
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
completed_segments = db.session.scalar(
select(func.count(DocumentSegment.id))
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
) or 0
total_segments = db.session.scalar(
select(func.count(DocumentSegment.id))
.where(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
)
.count()
)
) or 0
# Create a dictionary with document attributes and additional fields
document_dict = {
"id": document.id,
@ -801,11 +799,10 @@ class DatasetApiKeyApi(Resource):
def post(self):
_, current_tenant_id = current_account_with_tenant()
current_key_count = (
db.session.query(ApiToken)
current_key_count = db.session.scalar(
select(func.count(ApiToken.id))
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
.count()
)
) or 0
if current_key_count >= self.max_keys:
console_ns.abort(
@ -839,14 +836,14 @@ class DatasetApiDeleteApi(Resource):
def delete(self, api_key_id):
_, current_tenant_id = current_account_with_tenant()
api_key_id = str(api_key_id)
key = (
db.session.query(ApiToken)
key = db.session.scalar(
select(ApiToken)
.where(
ApiToken.tenant_id == current_tenant_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
)
.first()
.limit(1)
)
if key is None:
@ -857,7 +854,7 @@ class DatasetApiDeleteApi(Resource):
assert key is not None # nosec - for type checker only
ApiTokenCache.delete(key.token, key.type)
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.delete(key)
db.session.commit()
return {"result": "success"}, 204

View File

@ -401,10 +401,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
@ -447,10 +447,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
@ -494,7 +494,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
payload = BatchImportPayload.model_validate(console_ns.payload or {})
upload_file_id = payload.upload_file_id
upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == upload_file_id).limit(1))
if not upload_file:
raise NotFound("UploadFile not found.")
@ -559,10 +559,10 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
@ -616,10 +616,10 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
@ -666,10 +666,10 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
@ -714,24 +714,24 @@ class ChildChunkUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = (
db.session.query(ChildChunk)
child_chunk = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id,
)
.first()
.limit(1)
)
if not child_chunk:
raise NotFound("Child chunk not found.")
@ -771,24 +771,24 @@ class ChildChunkUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = (
db.session.query(ChildChunk)
child_chunk = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id,
)
.first()
.limit(1)
)
if not child_chunk:
raise NotFound("Child chunk not found.")

View File

@ -1476,8 +1476,8 @@ class TestDatasetIndexingStatusApi:
return_value=MagicMock(all=lambda: [document]),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)),
"controllers.console.datasets.datasets.db.session.scalar",
return_value=3,
),
):
response, status = method(api, "dataset-1")
@ -1526,13 +1526,6 @@ class TestDatasetIndexingStatusApi:
document.error = None
document.stopped_at = None
# First count = completed segments, second = total segments
query_mock = MagicMock()
query_mock.where.side_effect = [
MagicMock(count=lambda: 2),
MagicMock(count=lambda: 5),
]
with (
app.test_request_context("/"),
patch(
@ -1544,8 +1537,8 @@ class TestDatasetIndexingStatusApi:
return_value=MagicMock(all=lambda: [document]),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=query_mock,
"controllers.console.datasets.datasets.db.session.scalar",
side_effect=[2, 5],
),
):
response, status = method(api, "dataset-1")
@ -1591,8 +1584,8 @@ class TestDatasetApiKeyApi:
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)),
"controllers.console.datasets.datasets.db.session.scalar",
return_value=3,
),
patch(
"controllers.console.datasets.datasets.ApiToken.generate_api_key",
@ -1625,8 +1618,8 @@ class TestDatasetApiKeyApi:
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 10)),
"controllers.console.datasets.datasets.db.session.scalar",
return_value=10,
),
):
with pytest.raises(BadRequest) as exc_info:
@ -1653,8 +1646,8 @@ class TestDatasetApiDeleteApi:
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: mock_key)),
"controllers.console.datasets.datasets.db.session.scalar",
return_value=mock_key,
),
patch(
"controllers.console.datasets.datasets.db.session.commit",
@ -1681,8 +1674,8 @@ class TestDatasetApiDeleteApi:
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets.db.session.query",
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: None)),
"controllers.console.datasets.datasets.db.session.scalar",
return_value=None,
),
):
with pytest.raises(NotFound):

View File

@ -526,8 +526,8 @@ class TestDatasetDocumentSegmentUpdateApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=segment,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -621,8 +621,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=upload_file,
),
patch(
"controllers.console.datasets.datasets_segments.redis_client.setnx",
@ -706,8 +706,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: None)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=None,
),
):
with pytest.raises(NotFound):
@ -738,8 +738,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=upload_file,
),
):
with pytest.raises(ValueError):
@ -770,8 +770,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=upload_file,
),
patch(
"controllers.console.datasets.datasets_segments.redis_client.setnx",
@ -831,8 +831,8 @@ class TestChildChunkAddApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=segment,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -880,8 +880,8 @@ class TestChildChunkAddApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=segment,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -924,11 +924,8 @@ class TestChildChunkUpdateApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
side_effect=[
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)),
],
"controllers.console.datasets.datasets_segments.db.session.scalar",
side_effect=[segment, child_chunk],
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -970,11 +967,8 @@ class TestChildChunkUpdateApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
side_effect=[
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)),
],
"controllers.console.datasets.datasets_segments.db.session.scalar",
side_effect=[segment, child_chunk],
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -1180,8 +1174,8 @@ class TestSegmentOperationCases:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=upload_file,
),
):
with pytest.raises(NotFound):
@ -1215,8 +1209,8 @@ class TestSegmentOperationCases:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_segments.db.session.scalar",
return_value=upload_file,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",