mirror of https://github.com/langgenius/dify.git
refactor: select in console datasets segments and API key controllers
This commit is contained in:
parent
e3c1112b15
commit
89c242060f
|
|
@ -3,7 +3,7 @@ from typing import Any, cast
|
|||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
|
|
@ -737,22 +737,20 @@ class DatasetIndexingStatusApi(Resource):
|
|||
).all()
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
completed_segments = db.session.scalar(
|
||||
select(func.count(DocumentSegment.id))
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
) or 0
|
||||
total_segments = db.session.scalar(
|
||||
select(func.count(DocumentSegment.id))
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
|
||||
)
|
||||
.count()
|
||||
)
|
||||
) or 0
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
document_dict = {
|
||||
"id": document.id,
|
||||
|
|
@ -801,11 +799,10 @@ class DatasetApiKeyApi(Resource):
|
|||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
current_key_count = (
|
||||
db.session.query(ApiToken)
|
||||
current_key_count = db.session.scalar(
|
||||
select(func.count(ApiToken.id))
|
||||
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
|
||||
.count()
|
||||
)
|
||||
) or 0
|
||||
|
||||
if current_key_count >= self.max_keys:
|
||||
console_ns.abort(
|
||||
|
|
@ -839,14 +836,14 @@ class DatasetApiDeleteApi(Resource):
|
|||
def delete(self, api_key_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
api_key_id = str(api_key_id)
|
||||
key = (
|
||||
db.session.query(ApiToken)
|
||||
key = db.session.scalar(
|
||||
select(ApiToken)
|
||||
.where(
|
||||
ApiToken.tenant_id == current_tenant_id,
|
||||
ApiToken.type == self.resource_type,
|
||||
ApiToken.id == api_key_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if key is None:
|
||||
|
|
@ -857,7 +854,7 @@ class DatasetApiDeleteApi(Resource):
|
|||
assert key is not None # nosec - for type checker only
|
||||
ApiTokenCache.delete(key.token, key.type)
|
||||
|
||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||
db.session.delete(key)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
|
|
|||
|
|
@ -401,10 +401,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
|
@ -447,10 +447,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
|
@ -494,7 +494,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
|||
payload = BatchImportPayload.model_validate(console_ns.payload or {})
|
||||
upload_file_id = payload.upload_file_id
|
||||
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
|
||||
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == upload_file_id).limit(1))
|
||||
if not upload_file:
|
||||
raise NotFound("UploadFile not found.")
|
||||
|
||||
|
|
@ -559,10 +559,10 @@ class ChildChunkAddApi(Resource):
|
|||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
|
@ -616,10 +616,10 @@ class ChildChunkAddApi(Resource):
|
|||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
|
@ -666,10 +666,10 @@ class ChildChunkAddApi(Resource):
|
|||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
|
@ -714,24 +714,24 @@ class ChildChunkUpdateApi(Resource):
|
|||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
# check child chunk
|
||||
child_chunk_id = str(child_chunk_id)
|
||||
child_chunk = (
|
||||
db.session.query(ChildChunk)
|
||||
child_chunk = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
|
@ -771,24 +771,24 @@ class ChildChunkUpdateApi(Resource):
|
|||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
# check child chunk
|
||||
child_chunk_id = str(child_chunk_id)
|
||||
child_chunk = (
|
||||
db.session.query(ChildChunk)
|
||||
child_chunk = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
|
|
|||
|
|
@ -1476,8 +1476,8 @@ class TestDatasetIndexingStatusApi:
|
|||
return_value=MagicMock(all=lambda: [document]),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets.db.session.query",
|
||||
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)),
|
||||
"controllers.console.datasets.datasets.db.session.scalar",
|
||||
return_value=3,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "dataset-1")
|
||||
|
|
@ -1526,13 +1526,6 @@ class TestDatasetIndexingStatusApi:
|
|||
document.error = None
|
||||
document.stopped_at = None
|
||||
|
||||
# First count = completed segments, second = total segments
|
||||
query_mock = MagicMock()
|
||||
query_mock.where.side_effect = [
|
||||
MagicMock(count=lambda: 2),
|
||||
MagicMock(count=lambda: 5),
|
||||
]
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
|
|
@ -1544,8 +1537,8 @@ class TestDatasetIndexingStatusApi:
|
|||
return_value=MagicMock(all=lambda: [document]),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets.db.session.query",
|
||||
return_value=query_mock,
|
||||
"controllers.console.datasets.datasets.db.session.scalar",
|
||||
side_effect=[2, 5],
|
||||
),
|
||||
):
|
||||
response, status = method(api, "dataset-1")
|
||||
|
|
@ -1591,8 +1584,8 @@ class TestDatasetApiKeyApi:
|
|||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets.db.session.query",
|
||||
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)),
|
||||
"controllers.console.datasets.datasets.db.session.scalar",
|
||||
return_value=3,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets.ApiToken.generate_api_key",
|
||||
|
|
@ -1625,8 +1618,8 @@ class TestDatasetApiKeyApi:
|
|||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets.db.session.query",
|
||||
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 10)),
|
||||
"controllers.console.datasets.datasets.db.session.scalar",
|
||||
return_value=10,
|
||||
),
|
||||
):
|
||||
with pytest.raises(BadRequest) as exc_info:
|
||||
|
|
@ -1653,8 +1646,8 @@ class TestDatasetApiDeleteApi:
|
|||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets.db.session.query",
|
||||
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: mock_key)),
|
||||
"controllers.console.datasets.datasets.db.session.scalar",
|
||||
return_value=mock_key,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets.db.session.commit",
|
||||
|
|
@ -1681,8 +1674,8 @@ class TestDatasetApiDeleteApi:
|
|||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets.db.session.query",
|
||||
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: None)),
|
||||
"controllers.console.datasets.datasets.db.session.scalar",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
|
|
|
|||
|
|
@ -526,8 +526,8 @@ class TestDatasetDocumentSegmentUpdateApi:
|
|||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
return_value=segment,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
|
||||
|
|
@ -621,8 +621,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
|
|||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
return_value=upload_file,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.redis_client.setnx",
|
||||
|
|
@ -706,8 +706,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
|
|||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: None)),
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
|
|
@ -738,8 +738,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
|
|||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
return_value=upload_file,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
|
|
@ -770,8 +770,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
|
|||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
return_value=upload_file,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.redis_client.setnx",
|
||||
|
|
@ -831,8 +831,8 @@ class TestChildChunkAddApi:
|
|||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
return_value=segment,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
|
||||
|
|
@ -880,8 +880,8 @@ class TestChildChunkAddApi:
|
|||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
return_value=segment,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
|
||||
|
|
@ -924,11 +924,8 @@ class TestChildChunkUpdateApi:
|
|||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
side_effect=[
|
||||
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
|
||||
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)),
|
||||
],
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
side_effect=[segment, child_chunk],
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
|
||||
|
|
@ -970,11 +967,8 @@ class TestChildChunkUpdateApi:
|
|||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
side_effect=[
|
||||
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
|
||||
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)),
|
||||
],
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
side_effect=[segment, child_chunk],
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
|
||||
|
|
@ -1180,8 +1174,8 @@ class TestSegmentOperationCases:
|
|||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
return_value=upload_file,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
|
|
@ -1215,8 +1209,8 @@ class TestSegmentOperationCases:
|
|||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
return_value=upload_file,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
|
||||
|
|
|
|||
Loading…
Reference in New Issue