refactor(api): Query API to select function_1 (#33565)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo 2026-03-17 15:29:16 +01:00 committed by GitHub
parent 076b297b18
commit 7757bb5089
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 196 additions and 258 deletions

View File

@ -441,7 +441,7 @@ class BaseAgentRunner(AppRunner):
continue
result.append(self.organize_agent_user_prompt(message))
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
agent_thoughts = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
tool_names_raw = agent_thought.tool

View File

@ -177,13 +177,11 @@ class Account(UserMixin, TypeBase):
@classmethod
def get_by_openid(cls, provider: str, open_id: str):
account_integrate = (
db.session.query(AccountIntegrate)
.where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
.one_or_none()
)
account_integrate = db.session.execute(
select(AccountIntegrate).where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
).scalar_one_or_none()
if account_integrate:
return db.session.query(Account).where(Account.id == account_integrate.account_id).one_or_none()
return db.session.scalar(select(Account).where(Account.id == account_integrate.account_id))
return None
# check current_user.current_tenant.current_role in ['admin', 'owner']

View File

@ -8,6 +8,7 @@ import os
import pickle
import re
import time
from collections.abc import Sequence
from datetime import datetime
from json import JSONDecodeError
from typing import Any, TypedDict, cast
@ -145,30 +146,25 @@ class Dataset(Base):
@property
def total_documents(self):
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
@property
def total_available_documents(self):
return (
db.session.query(func.count(Document.id))
.where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
db.session.scalar(
select(func.count(Document.id)).where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
)
.scalar()
or 0
)
@property
def dataset_keyword_table(self):
dataset_keyword_table = (
db.session.query(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id).first()
)
if dataset_keyword_table:
return dataset_keyword_table
return None
return db.session.scalar(select(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id))
@property
def index_struct_dict(self):
@ -195,64 +191,66 @@ class Dataset(Base):
@property
def latest_process_rule(self):
return (
db.session.query(DatasetProcessRule)
return db.session.scalar(
select(DatasetProcessRule)
.where(DatasetProcessRule.dataset_id == self.id)
.order_by(DatasetProcessRule.created_at.desc())
.first()
.limit(1)
)
@property
def app_count(self):
return (
db.session.query(func.count(AppDatasetJoin.id))
.where(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
.scalar()
db.session.scalar(
select(func.count(AppDatasetJoin.id)).where(
AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id
)
)
or 0
)
@property
def document_count(self):
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
@property
def available_document_count(self):
return (
db.session.query(func.count(Document.id))
.where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
db.session.scalar(
select(func.count(Document.id)).where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
)
.scalar()
or 0
)
@property
def available_segment_count(self):
return (
db.session.query(func.count(DocumentSegment.id))
.where(
DocumentSegment.dataset_id == self.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.dataset_id == self.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
)
.scalar()
or 0
)
@property
def word_count(self):
return (
db.session.query(Document)
.with_entities(func.coalesce(func.sum(Document.word_count), 0))
.where(Document.dataset_id == self.id)
.scalar()
return db.session.scalar(
select(func.coalesce(func.sum(Document.word_count), 0)).where(Document.dataset_id == self.id)
)
@property
def doc_form(self) -> str | None:
if self.chunk_structure:
return self.chunk_structure
document = db.session.query(Document).where(Document.dataset_id == self.id).first()
document = db.session.scalar(select(Document).where(Document.dataset_id == self.id).limit(1))
if document:
return document.doc_form
return None
@ -270,8 +268,8 @@ class Dataset(Base):
@property
def tags(self):
tags = (
db.session.query(Tag)
tags = db.session.scalars(
select(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
.where(
TagBinding.target_id == self.id,
@ -279,8 +277,7 @@ class Dataset(Base):
Tag.tenant_id == self.tenant_id,
Tag.type == "knowledge",
)
.all()
)
).all()
return tags or []
@ -288,8 +285,8 @@ class Dataset(Base):
def external_knowledge_info(self):
if self.provider != "external":
return None
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id).first()
external_knowledge_binding = db.session.scalar(
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id)
)
if not external_knowledge_binding:
return None
@ -310,7 +307,7 @@ class Dataset(Base):
@property
def is_published(self):
if self.pipeline_id:
pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first()
pipeline = db.session.scalar(select(Pipeline).where(Pipeline.id == self.pipeline_id))
if pipeline:
return pipeline.is_published
return False
@ -521,10 +518,8 @@ class Document(Base):
if self.data_source_info:
if self.data_source_type == "upload_file":
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.id == data_source_info_dict["upload_file_id"])
.one_or_none()
file_detail = db.session.scalar(
select(UploadFile).where(UploadFile.id == data_source_info_dict["upload_file_id"])
)
if file_detail:
return {
@ -557,24 +552,23 @@ class Document(Base):
@property
def dataset(self):
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).one_or_none()
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
@property
def segment_count(self):
return db.session.query(DocumentSegment).where(DocumentSegment.document_id == self.id).count()
return (
db.session.scalar(select(func.count(DocumentSegment.id)).where(DocumentSegment.document_id == self.id)) or 0
)
@property
def hit_count(self):
return (
db.session.query(DocumentSegment)
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
.where(DocumentSegment.document_id == self.id)
.scalar()
return db.session.scalar(
select(func.coalesce(func.sum(DocumentSegment.hit_count), 0)).where(DocumentSegment.document_id == self.id)
)
@property
def uploader(self):
user = db.session.query(Account).where(Account.id == self.created_by).first()
user = db.session.scalar(select(Account).where(Account.id == self.created_by))
return user.name if user else None
@property
@ -588,14 +582,13 @@ class Document(Base):
@property
def doc_metadata_details(self) -> list[DocMetadataDetailItem] | None:
if self.doc_metadata:
document_metadatas = (
db.session.query(DatasetMetadata)
document_metadatas = db.session.scalars(
select(DatasetMetadata)
.join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
.where(
DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
)
.all()
)
).all()
metadata_list: list[DocMetadataDetailItem] = []
for metadata in document_metadatas:
metadata_dict: DocMetadataDetailItem = {
@ -826,7 +819,7 @@ class DocumentSegment(Base):
)
@property
def child_chunks(self) -> list[Any]:
def child_chunks(self) -> Sequence[Any]:
if not self.document:
return []
process_rule = self.document.dataset_process_rule
@ -835,16 +828,13 @@ class DocumentSegment(Base):
if rules_dict:
rules = Rule.model_validate(rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
child_chunks = db.session.scalars(
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
).all()
return child_chunks or []
return []
def get_child_chunks(self) -> list[Any]:
def get_child_chunks(self) -> Sequence[Any]:
if not self.document:
return []
process_rule = self.document.dataset_process_rule
@ -853,12 +843,9 @@ class DocumentSegment(Base):
if rules_dict:
rules = Rule.model_validate(rules_dict)
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
child_chunks = db.session.scalars(
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
).all()
return child_chunks or []
return []
@ -1007,15 +994,15 @@ class ChildChunk(Base):
@property
def dataset(self):
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).first()
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
@property
def document(self):
return db.session.query(Document).where(Document.id == self.document_id).first()
return db.session.scalar(select(Document).where(Document.id == self.document_id))
@property
def segment(self):
return db.session.query(DocumentSegment).where(DocumentSegment.id == self.segment_id).first()
return db.session.scalar(select(DocumentSegment).where(DocumentSegment.id == self.segment_id))
class AppDatasetJoin(TypeBase):
@ -1076,7 +1063,7 @@ class DatasetQuery(TypeBase):
if isinstance(queries, list):
for query in queries:
if query["content_type"] == QueryType.IMAGE_QUERY:
file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first()
file_info = db.session.scalar(select(UploadFile).where(UploadFile.id == query["content"]))
if file_info:
query["file_info"] = {
"id": file_info.id,
@ -1141,7 +1128,7 @@ class DatasetKeywordTable(TypeBase):
super().__init__(object_hook=object_hook, *args, **kwargs)
# get dataset
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
dataset = db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
if not dataset:
return None
if self.data_source_type == "database":
@ -1535,7 +1522,7 @@ class PipelineCustomizedTemplate(TypeBase):
@property
def created_user_name(self):
account = db.session.query(Account).where(Account.id == self.created_by).first()
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
if account:
return account.name
return ""
@ -1570,7 +1557,7 @@ class Pipeline(TypeBase):
)
def retrieve_dataset(self, session: Session):
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
return session.scalar(select(Dataset).where(Dataset.pipeline_id == self.id))
class DocumentPipelineExecutionLog(TypeBase):

View File

@ -380,13 +380,12 @@ class App(Base):
@property
def site(self) -> Site | None:
site = db.session.query(Site).where(Site.app_id == self.id).first()
return site
return db.session.scalar(select(Site).where(Site.app_id == self.id))
@property
def app_model_config(self) -> AppModelConfig | None:
if self.app_model_config_id:
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
return db.session.scalar(select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id))
return None
@ -395,7 +394,7 @@ class App(Base):
if self.workflow_id:
from .workflow import Workflow
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
return None
@ -405,8 +404,7 @@ class App(Base):
@property
def tenant(self) -> Tenant | None:
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
@property
def is_agent(self) -> bool:
@ -546,9 +544,9 @@ class App(Base):
return deleted_tools
@property
def tags(self) -> list[Tag]:
tags = (
db.session.query(Tag)
def tags(self) -> Sequence[Tag]:
tags = db.session.scalars(
select(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
.where(
TagBinding.target_id == self.id,
@ -556,15 +554,14 @@ class App(Base):
Tag.tenant_id == self.tenant_id,
Tag.type == "app",
)
.all()
)
).all()
return tags or []
@property
def author_name(self) -> str | None:
if self.created_by:
account = db.session.query(Account).where(Account.id == self.created_by).first()
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
if account:
return account.name
@ -616,8 +613,7 @@ class AppModelConfig(TypeBase):
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
@property
def model_dict(self) -> ModelConfig:
@ -652,8 +648,8 @@ class AppModelConfig(TypeBase):
@property
def annotation_reply_dict(self) -> AnnotationReplyConfig:
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
annotation_setting = db.session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id)
)
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
@ -845,8 +841,7 @@ class RecommendedApp(Base): # bug
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
class InstalledApp(TypeBase):
@ -873,13 +868,11 @@ class InstalledApp(TypeBase):
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
@property
def tenant(self) -> Tenant | None:
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
class TrialApp(Base):
@ -899,8 +892,7 @@ class TrialApp(Base):
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
class AccountTrialAppRecord(Base):
@ -919,13 +911,11 @@ class AccountTrialAppRecord(Base):
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
@property
def user(self) -> Account | None:
user = db.session.query(Account).where(Account.id == self.account_id).first()
return user
return db.session.scalar(select(Account).where(Account.id == self.account_id))
class ExporleBanner(TypeBase):
@ -1117,8 +1107,8 @@ class Conversation(Base):
else:
model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key]
else:
app_model_config = (
db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
app_model_config = db.session.scalar(
select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id)
)
if app_model_config:
model_config = app_model_config.to_dict()
@ -1141,36 +1131,43 @@ class Conversation(Base):
@property
def annotated(self):
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).count() > 0
return (
db.session.scalar(
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.conversation_id == self.id)
)
or 0
) > 0
@property
def annotation(self):
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).first()
return db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).limit(1))
@property
def message_count(self):
return db.session.query(Message).where(Message.conversation_id == self.id).count()
return db.session.scalar(select(func.count(Message.id)).where(Message.conversation_id == self.id)) or 0
@property
def user_feedback_stats(self):
like = (
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "like",
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "like",
)
)
.count()
or 0
)
dislike = (
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "dislike",
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "dislike",
)
)
.count()
or 0
)
return {"like": like, "dislike": dislike}
@ -1178,23 +1175,25 @@ class Conversation(Base):
@property
def admin_feedback_stats(self):
like = (
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "like",
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "like",
)
)
.count()
or 0
)
dislike = (
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "dislike",
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "dislike",
)
)
.count()
or 0
)
return {"like": like, "dislike": dislike}
@ -1256,22 +1255,19 @@ class Conversation(Base):
@property
def first_message(self):
return (
db.session.query(Message)
.where(Message.conversation_id == self.id)
.order_by(Message.created_at.asc())
.first()
return db.session.scalar(
select(Message).where(Message.conversation_id == self.id).order_by(Message.created_at.asc())
)
@property
def app(self) -> App | None:
with Session(db.engine, expire_on_commit=False) as session:
return session.query(App).where(App.id == self.app_id).first()
return session.scalar(select(App).where(App.id == self.app_id))
@property
def from_end_user_session_id(self):
if self.from_end_user_id:
end_user = db.session.query(EndUser).where(EndUser.id == self.from_end_user_id).first()
end_user = db.session.scalar(select(EndUser).where(EndUser.id == self.from_end_user_id))
if end_user:
return end_user.session_id
@ -1280,7 +1276,7 @@ class Conversation(Base):
@property
def from_account_name(self) -> str | None:
if self.from_account_id:
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
account = db.session.scalar(select(Account).where(Account.id == self.from_account_id))
if account:
return account.name
@ -1505,21 +1501,15 @@ class Message(Base):
@property
def user_feedback(self):
feedback = (
db.session.query(MessageFeedback)
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
.first()
return db.session.scalar(
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
)
return feedback
@property
def admin_feedback(self):
feedback = (
db.session.query(MessageFeedback)
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
.first()
return db.session.scalar(
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
)
return feedback
@property
def feedbacks(self):
@ -1528,28 +1518,27 @@ class Message(Base):
@property
def annotation(self):
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == self.id).first()
annotation = db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.message_id == self.id))
return annotation
@property
def annotation_hit_history(self):
annotation_history = (
db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id).first()
annotation_history = db.session.scalar(
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id)
)
if annotation_history:
annotation = (
db.session.query(MessageAnnotation)
.where(MessageAnnotation.id == annotation_history.annotation_id)
.first()
return db.session.scalar(
select(MessageAnnotation).where(MessageAnnotation.id == annotation_history.annotation_id)
)
return annotation
return None
@property
def app_model_config(self):
conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first()
conversation = db.session.scalar(select(Conversation).where(Conversation.id == self.conversation_id))
if conversation:
return db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first()
return db.session.scalar(
select(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id)
)
return None
@ -1562,13 +1551,12 @@ class Message(Base):
return json.loads(self.message_metadata) if self.message_metadata else {}
@property
def agent_thoughts(self) -> list[MessageAgentThought]:
return (
db.session.query(MessageAgentThought)
def agent_thoughts(self) -> Sequence[MessageAgentThought]:
return db.session.scalars(
select(MessageAgentThought)
.where(MessageAgentThought.message_id == self.id)
.order_by(MessageAgentThought.position.asc())
.all()
)
).all()
@property
def retriever_resources(self) -> Any:
@ -1579,7 +1567,7 @@ class Message(Base):
from factories import file_factory
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
current_app = db.session.query(App).where(App.id == self.app_id).first()
current_app = db.session.scalar(select(App).where(App.id == self.app_id))
if not current_app:
raise ValueError(f"App {self.app_id} not found")
@ -1743,8 +1731,7 @@ class MessageFeedback(TypeBase):
@property
def from_account(self) -> Account | None:
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
return account
return db.session.scalar(select(Account).where(Account.id == self.from_account_id))
def to_dict(self) -> MessageFeedbackDict:
return {
@ -1817,13 +1804,11 @@ class MessageAnnotation(Base):
@property
def account(self):
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
return db.session.scalar(select(Account).where(Account.id == self.account_id))
@property
def annotation_create_account(self):
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
return db.session.scalar(select(Account).where(Account.id == self.account_id))
class AppAnnotationHitHistory(TypeBase):
@ -1852,18 +1837,15 @@ class AppAnnotationHitHistory(TypeBase):
@property
def account(self):
account = (
db.session.query(Account)
return db.session.scalar(
select(Account)
.join(MessageAnnotation, MessageAnnotation.account_id == Account.id)
.where(MessageAnnotation.id == self.annotation_id)
.first()
)
return account
@property
def annotation_create_account(self):
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
return db.session.scalar(select(Account).where(Account.id == self.account_id))
class AppAnnotationSetting(TypeBase):
@ -1896,12 +1878,9 @@ class AppAnnotationSetting(TypeBase):
def collection_binding_detail(self):
from .dataset import DatasetCollectionBinding
collection_binding_detail = (
db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == self.collection_binding_id)
.first()
return db.session.scalar(
select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == self.collection_binding_id)
)
return collection_binding_detail
class OperationLog(TypeBase):
@ -2007,7 +1986,9 @@ class AppMCPServer(TypeBase):
def generate_server_code(n: int) -> str:
while True:
result = generate_string(n)
while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0:
while (
db.session.scalar(select(func.count(AppMCPServer.id)).where(AppMCPServer.server_code == result)) or 0
) > 0:
result = generate_string(n)
return result
@ -2068,7 +2049,7 @@ class Site(Base):
def generate_code(n: int) -> str:
while True:
result = generate_string(n)
while db.session.query(Site).where(Site.code == result).count() > 0:
while (db.session.scalar(select(func.count(Site.id)).where(Site.code == result)) or 0) > 0:
result = generate_string(n)
return result

View File

@ -6,7 +6,7 @@ from functools import cached_property
from uuid import uuid4
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func, text
from sqlalchemy import DateTime, String, func, select, text
from sqlalchemy.orm import Mapped, mapped_column
from libs.uuid_utils import uuidv7
@ -96,7 +96,7 @@ class Provider(TypeBase):
@cached_property
def credential(self):
if self.credential_id:
return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first()
return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
@property
def credential_name(self):
@ -159,10 +159,8 @@ class ProviderModel(TypeBase):
@cached_property
def credential(self):
if self.credential_id:
return (
db.session.query(ProviderModelCredential)
.where(ProviderModelCredential.id == self.credential_id)
.first()
return db.session.scalar(
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
)
@property

View File

@ -8,7 +8,7 @@ from uuid import uuid4
import sqlalchemy as sa
from deprecated import deprecated
from sqlalchemy import ForeignKey, String, func
from sqlalchemy import ForeignKey, String, func, select
from sqlalchemy.orm import Mapped, mapped_column
from core.tools.entities.common_entities import I18nObject
@ -184,11 +184,11 @@ class ApiToolProvider(TypeBase):
def user(self) -> Account | None:
if not self.user_id:
return None
return db.session.query(Account).where(Account.id == self.user_id).first()
return db.session.scalar(select(Account).where(Account.id == self.user_id))
@property
def tenant(self) -> Tenant | None:
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
class ToolLabelBinding(TypeBase):
@ -262,11 +262,11 @@ class WorkflowToolProvider(TypeBase):
@property
def user(self) -> Account | None:
return db.session.query(Account).where(Account.id == self.user_id).first()
return db.session.scalar(select(Account).where(Account.id == self.user_id))
@property
def tenant(self) -> Tenant | None:
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
@property
def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
@ -277,7 +277,7 @@ class WorkflowToolProvider(TypeBase):
@property
def app(self) -> App | None:
return db.session.query(App).where(App.id == self.app_id).first()
return db.session.scalar(select(App).where(App.id == self.app_id))
class MCPToolProvider(TypeBase):
@ -334,7 +334,7 @@ class MCPToolProvider(TypeBase):
encrypted_headers: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
def load_user(self) -> Account | None:
return db.session.query(Account).where(Account.id == self.user_id).first()
return db.session.scalar(select(Account).where(Account.id == self.user_id))
@property
def credentials(self) -> dict[str, Any]:

View File

@ -2,7 +2,7 @@ from datetime import datetime
from uuid import uuid4
import sqlalchemy as sa
from sqlalchemy import DateTime, func
from sqlalchemy import DateTime, func, select
from sqlalchemy.orm import Mapped, mapped_column
from .base import TypeBase
@ -38,7 +38,7 @@ class SavedMessage(TypeBase):
@property
def message(self):
return db.session.query(Message).where(Message.id == self.message_id).first()
return db.session.scalar(select(Message).where(Message.id == self.message_id))
class PinnedConversation(TypeBase):

View File

@ -679,14 +679,14 @@ class WorkflowRun(Base):
def message(self):
from .model import Message
return (
db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first()
return db.session.scalar(
select(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id)
)
@property
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
def workflow(self):
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
def to_dict(self):
return {

View File

@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from libs.login import current_user
from models import Account
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
from models.model import App, Conversation, EndUser, Message
class AgentService:
@ -47,7 +47,7 @@ class AgentService:
if not message:
raise ValueError(f"Message not found: {message_id}")
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
agent_thoughts = message.agent_thoughts
if conversation.from_end_user_id:
# only select name field

View File

@ -622,28 +622,10 @@ class TestAccountGetByOpenId:
mock_account = Account(name="Test User", email="test@example.com")
mock_account.id = account_id
# Mock the query chain
mock_query = MagicMock()
mock_where = MagicMock()
mock_where.one_or_none.return_value = mock_account_integrate
mock_query.where.return_value = mock_where
mock_db.session.query.return_value = mock_query
# Mock the second query for account
mock_account_query = MagicMock()
mock_account_where = MagicMock()
mock_account_where.one_or_none.return_value = mock_account
mock_account_query.where.return_value = mock_account_where
# Setup query to return different results based on model
def query_side_effect(model):
if model.__name__ == "AccountIntegrate":
return mock_query
elif model.__name__ == "Account":
return mock_account_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
# Mock db.session.execute().scalar_one_or_none() for AccountIntegrate lookup
mock_db.session.execute.return_value.scalar_one_or_none.return_value = mock_account_integrate
# Mock db.session.scalar() for Account lookup
mock_db.session.scalar.return_value = mock_account
# Act
result = Account.get_by_openid(provider, open_id)
@ -658,12 +640,8 @@ class TestAccountGetByOpenId:
provider = "github"
open_id = "github_user_456"
# Mock the query chain to return None
mock_query = MagicMock()
mock_where = MagicMock()
mock_where.one_or_none.return_value = None
mock_query.where.return_value = mock_where
mock_db.session.query.return_value = mock_query
# Mock db.session.execute().scalar_one_or_none() to return None
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
# Act
result = Account.get_by_openid(provider, open_id)

View File

@ -300,10 +300,8 @@ class TestAppModelConfig:
created_by=str(uuid4()),
)
# Mock database query to return None
with patch("models.model.db.session.query", autospec=True) as mock_query:
mock_query.return_value.where.return_value.first.return_value = None
# Mock database scalar to return None (no annotation setting found)
with patch("models.model.db.session.scalar", return_value=None):
# Act
result = config.annotation_reply_dict
@ -951,10 +949,8 @@ class TestSiteModel:
def test_site_generate_code(self):
"""Test Site.generate_code static method."""
# Mock database query to return 0 (no existing codes)
with patch("models.model.db.session.query", autospec=True) as mock_query:
mock_query.return_value.where.return_value.count.return_value = 0
# Mock database scalar to return 0 (no existing codes)
with patch("models.model.db.session.scalar", return_value=0):
# Act
code = Site.generate_code(8)