mirror of https://github.com/langgenius/dify.git
refactor(api): Query API to select function_1 (#33565)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
076b297b18
commit
7757bb5089
|
|
@ -441,7 +441,7 @@ class BaseAgentRunner(AppRunner):
|
|||
continue
|
||||
|
||||
result.append(self.organize_agent_user_prompt(message))
|
||||
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
|
||||
agent_thoughts = message.agent_thoughts
|
||||
if agent_thoughts:
|
||||
for agent_thought in agent_thoughts:
|
||||
tool_names_raw = agent_thought.tool
|
||||
|
|
|
|||
|
|
@ -177,13 +177,11 @@ class Account(UserMixin, TypeBase):
|
|||
|
||||
@classmethod
|
||||
def get_by_openid(cls, provider: str, open_id: str):
|
||||
account_integrate = (
|
||||
db.session.query(AccountIntegrate)
|
||||
.where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
|
||||
.one_or_none()
|
||||
)
|
||||
account_integrate = db.session.execute(
|
||||
select(AccountIntegrate).where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
|
||||
).scalar_one_or_none()
|
||||
if account_integrate:
|
||||
return db.session.query(Account).where(Account.id == account_integrate.account_id).one_or_none()
|
||||
return db.session.scalar(select(Account).where(Account.id == account_integrate.account_id))
|
||||
return None
|
||||
|
||||
# check current_user.current_tenant.current_role in ['admin', 'owner']
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import os
|
|||
import pickle
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, TypedDict, cast
|
||||
|
|
@ -145,30 +146,25 @@ class Dataset(Base):
|
|||
|
||||
@property
|
||||
def total_documents(self):
|
||||
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
|
||||
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
|
||||
|
||||
@property
|
||||
def total_available_documents(self):
|
||||
return (
|
||||
db.session.query(func.count(Document.id))
|
||||
.where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
db.session.scalar(
|
||||
select(func.count(Document.id)).where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
@property
|
||||
def dataset_keyword_table(self):
|
||||
dataset_keyword_table = (
|
||||
db.session.query(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id).first()
|
||||
)
|
||||
if dataset_keyword_table:
|
||||
return dataset_keyword_table
|
||||
|
||||
return None
|
||||
return db.session.scalar(select(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id))
|
||||
|
||||
@property
|
||||
def index_struct_dict(self):
|
||||
|
|
@ -195,64 +191,66 @@ class Dataset(Base):
|
|||
|
||||
@property
|
||||
def latest_process_rule(self):
|
||||
return (
|
||||
db.session.query(DatasetProcessRule)
|
||||
return db.session.scalar(
|
||||
select(DatasetProcessRule)
|
||||
.where(DatasetProcessRule.dataset_id == self.id)
|
||||
.order_by(DatasetProcessRule.created_at.desc())
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
@property
|
||||
def app_count(self):
|
||||
return (
|
||||
db.session.query(func.count(AppDatasetJoin.id))
|
||||
.where(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
|
||||
.scalar()
|
||||
db.session.scalar(
|
||||
select(func.count(AppDatasetJoin.id)).where(
|
||||
AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
@property
|
||||
def document_count(self):
|
||||
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
|
||||
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
|
||||
|
||||
@property
|
||||
def available_document_count(self):
|
||||
return (
|
||||
db.session.query(func.count(Document.id))
|
||||
.where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
db.session.scalar(
|
||||
select(func.count(Document.id)).where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
@property
|
||||
def available_segment_count(self):
|
||||
return (
|
||||
db.session.query(func.count(DocumentSegment.id))
|
||||
.where(
|
||||
DocumentSegment.dataset_id == self.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.dataset_id == self.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
@property
|
||||
def word_count(self):
|
||||
return (
|
||||
db.session.query(Document)
|
||||
.with_entities(func.coalesce(func.sum(Document.word_count), 0))
|
||||
.where(Document.dataset_id == self.id)
|
||||
.scalar()
|
||||
return db.session.scalar(
|
||||
select(func.coalesce(func.sum(Document.word_count), 0)).where(Document.dataset_id == self.id)
|
||||
)
|
||||
|
||||
@property
|
||||
def doc_form(self) -> str | None:
|
||||
if self.chunk_structure:
|
||||
return self.chunk_structure
|
||||
document = db.session.query(Document).where(Document.dataset_id == self.id).first()
|
||||
document = db.session.scalar(select(Document).where(Document.dataset_id == self.id).limit(1))
|
||||
if document:
|
||||
return document.doc_form
|
||||
return None
|
||||
|
|
@ -270,8 +268,8 @@ class Dataset(Base):
|
|||
|
||||
@property
|
||||
def tags(self):
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
tags = db.session.scalars(
|
||||
select(Tag)
|
||||
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
.where(
|
||||
TagBinding.target_id == self.id,
|
||||
|
|
@ -279,8 +277,7 @@ class Dataset(Base):
|
|||
Tag.tenant_id == self.tenant_id,
|
||||
Tag.type == "knowledge",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
return tags or []
|
||||
|
||||
|
|
@ -288,8 +285,8 @@ class Dataset(Base):
|
|||
def external_knowledge_info(self):
|
||||
if self.provider != "external":
|
||||
return None
|
||||
external_knowledge_binding = (
|
||||
db.session.query(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id).first()
|
||||
external_knowledge_binding = db.session.scalar(
|
||||
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id)
|
||||
)
|
||||
if not external_knowledge_binding:
|
||||
return None
|
||||
|
|
@ -310,7 +307,7 @@ class Dataset(Base):
|
|||
@property
|
||||
def is_published(self):
|
||||
if self.pipeline_id:
|
||||
pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first()
|
||||
pipeline = db.session.scalar(select(Pipeline).where(Pipeline.id == self.pipeline_id))
|
||||
if pipeline:
|
||||
return pipeline.is_published
|
||||
return False
|
||||
|
|
@ -521,10 +518,8 @@ class Document(Base):
|
|||
if self.data_source_info:
|
||||
if self.data_source_type == "upload_file":
|
||||
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
.where(UploadFile.id == data_source_info_dict["upload_file_id"])
|
||||
.one_or_none()
|
||||
file_detail = db.session.scalar(
|
||||
select(UploadFile).where(UploadFile.id == data_source_info_dict["upload_file_id"])
|
||||
)
|
||||
if file_detail:
|
||||
return {
|
||||
|
|
@ -557,24 +552,23 @@ class Document(Base):
|
|||
|
||||
@property
|
||||
def dataset(self):
|
||||
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).one_or_none()
|
||||
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
|
||||
|
||||
@property
|
||||
def segment_count(self):
|
||||
return db.session.query(DocumentSegment).where(DocumentSegment.document_id == self.id).count()
|
||||
return (
|
||||
db.session.scalar(select(func.count(DocumentSegment.id)).where(DocumentSegment.document_id == self.id)) or 0
|
||||
)
|
||||
|
||||
@property
|
||||
def hit_count(self):
|
||||
return (
|
||||
db.session.query(DocumentSegment)
|
||||
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
|
||||
.where(DocumentSegment.document_id == self.id)
|
||||
.scalar()
|
||||
return db.session.scalar(
|
||||
select(func.coalesce(func.sum(DocumentSegment.hit_count), 0)).where(DocumentSegment.document_id == self.id)
|
||||
)
|
||||
|
||||
@property
|
||||
def uploader(self):
|
||||
user = db.session.query(Account).where(Account.id == self.created_by).first()
|
||||
user = db.session.scalar(select(Account).where(Account.id == self.created_by))
|
||||
return user.name if user else None
|
||||
|
||||
@property
|
||||
|
|
@ -588,14 +582,13 @@ class Document(Base):
|
|||
@property
|
||||
def doc_metadata_details(self) -> list[DocMetadataDetailItem] | None:
|
||||
if self.doc_metadata:
|
||||
document_metadatas = (
|
||||
db.session.query(DatasetMetadata)
|
||||
document_metadatas = db.session.scalars(
|
||||
select(DatasetMetadata)
|
||||
.join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
|
||||
.where(
|
||||
DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
metadata_list: list[DocMetadataDetailItem] = []
|
||||
for metadata in document_metadatas:
|
||||
metadata_dict: DocMetadataDetailItem = {
|
||||
|
|
@ -826,7 +819,7 @@ class DocumentSegment(Base):
|
|||
)
|
||||
|
||||
@property
|
||||
def child_chunks(self) -> list[Any]:
|
||||
def child_chunks(self) -> Sequence[Any]:
|
||||
if not self.document:
|
||||
return []
|
||||
process_rule = self.document.dataset_process_rule
|
||||
|
|
@ -835,16 +828,13 @@ class DocumentSegment(Base):
|
|||
if rules_dict:
|
||||
rules = Rule.model_validate(rules_dict)
|
||||
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(ChildChunk.segment_id == self.id)
|
||||
.order_by(ChildChunk.position.asc())
|
||||
.all()
|
||||
)
|
||||
child_chunks = db.session.scalars(
|
||||
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
|
||||
).all()
|
||||
return child_chunks or []
|
||||
return []
|
||||
|
||||
def get_child_chunks(self) -> list[Any]:
|
||||
def get_child_chunks(self) -> Sequence[Any]:
|
||||
if not self.document:
|
||||
return []
|
||||
process_rule = self.document.dataset_process_rule
|
||||
|
|
@ -853,12 +843,9 @@ class DocumentSegment(Base):
|
|||
if rules_dict:
|
||||
rules = Rule.model_validate(rules_dict)
|
||||
if rules.parent_mode:
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(ChildChunk.segment_id == self.id)
|
||||
.order_by(ChildChunk.position.asc())
|
||||
.all()
|
||||
)
|
||||
child_chunks = db.session.scalars(
|
||||
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
|
||||
).all()
|
||||
return child_chunks or []
|
||||
return []
|
||||
|
||||
|
|
@ -1007,15 +994,15 @@ class ChildChunk(Base):
|
|||
|
||||
@property
|
||||
def dataset(self):
|
||||
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).first()
|
||||
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
|
||||
|
||||
@property
|
||||
def document(self):
|
||||
return db.session.query(Document).where(Document.id == self.document_id).first()
|
||||
return db.session.scalar(select(Document).where(Document.id == self.document_id))
|
||||
|
||||
@property
|
||||
def segment(self):
|
||||
return db.session.query(DocumentSegment).where(DocumentSegment.id == self.segment_id).first()
|
||||
return db.session.scalar(select(DocumentSegment).where(DocumentSegment.id == self.segment_id))
|
||||
|
||||
|
||||
class AppDatasetJoin(TypeBase):
|
||||
|
|
@ -1076,7 +1063,7 @@ class DatasetQuery(TypeBase):
|
|||
if isinstance(queries, list):
|
||||
for query in queries:
|
||||
if query["content_type"] == QueryType.IMAGE_QUERY:
|
||||
file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first()
|
||||
file_info = db.session.scalar(select(UploadFile).where(UploadFile.id == query["content"]))
|
||||
if file_info:
|
||||
query["file_info"] = {
|
||||
"id": file_info.id,
|
||||
|
|
@ -1141,7 +1128,7 @@ class DatasetKeywordTable(TypeBase):
|
|||
super().__init__(object_hook=object_hook, *args, **kwargs)
|
||||
|
||||
# get dataset
|
||||
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
|
||||
dataset = db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
|
||||
if not dataset:
|
||||
return None
|
||||
if self.data_source_type == "database":
|
||||
|
|
@ -1535,7 +1522,7 @@ class PipelineCustomizedTemplate(TypeBase):
|
|||
|
||||
@property
|
||||
def created_user_name(self):
|
||||
account = db.session.query(Account).where(Account.id == self.created_by).first()
|
||||
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
|
||||
if account:
|
||||
return account.name
|
||||
return ""
|
||||
|
|
@ -1570,7 +1557,7 @@ class Pipeline(TypeBase):
|
|||
)
|
||||
|
||||
def retrieve_dataset(self, session: Session):
|
||||
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
|
||||
return session.scalar(select(Dataset).where(Dataset.pipeline_id == self.id))
|
||||
|
||||
|
||||
class DocumentPipelineExecutionLog(TypeBase):
|
||||
|
|
|
|||
|
|
@ -380,13 +380,12 @@ class App(Base):
|
|||
|
||||
@property
|
||||
def site(self) -> Site | None:
|
||||
site = db.session.query(Site).where(Site.app_id == self.id).first()
|
||||
return site
|
||||
return db.session.scalar(select(Site).where(Site.app_id == self.id))
|
||||
|
||||
@property
|
||||
def app_model_config(self) -> AppModelConfig | None:
|
||||
if self.app_model_config_id:
|
||||
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
|
||||
return db.session.scalar(select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id))
|
||||
|
||||
return None
|
||||
|
||||
|
|
@ -395,7 +394,7 @@ class App(Base):
|
|||
if self.workflow_id:
|
||||
from .workflow import Workflow
|
||||
|
||||
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
|
||||
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
|
||||
|
||||
return None
|
||||
|
||||
|
|
@ -405,8 +404,7 @@ class App(Base):
|
|||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
return tenant
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
|
||||
@property
|
||||
def is_agent(self) -> bool:
|
||||
|
|
@ -546,9 +544,9 @@ class App(Base):
|
|||
return deleted_tools
|
||||
|
||||
@property
|
||||
def tags(self) -> list[Tag]:
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
def tags(self) -> Sequence[Tag]:
|
||||
tags = db.session.scalars(
|
||||
select(Tag)
|
||||
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
.where(
|
||||
TagBinding.target_id == self.id,
|
||||
|
|
@ -556,15 +554,14 @@ class App(Base):
|
|||
Tag.tenant_id == self.tenant_id,
|
||||
Tag.type == "app",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
return tags or []
|
||||
|
||||
@property
|
||||
def author_name(self) -> str | None:
|
||||
if self.created_by:
|
||||
account = db.session.query(Account).where(Account.id == self.created_by).first()
|
||||
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
|
||||
if account:
|
||||
return account.name
|
||||
|
||||
|
|
@ -616,8 +613,7 @@ class AppModelConfig(TypeBase):
|
|||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
@property
|
||||
def model_dict(self) -> ModelConfig:
|
||||
|
|
@ -652,8 +648,8 @@ class AppModelConfig(TypeBase):
|
|||
|
||||
@property
|
||||
def annotation_reply_dict(self) -> AnnotationReplyConfig:
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
|
||||
annotation_setting = db.session.scalar(
|
||||
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id)
|
||||
)
|
||||
if annotation_setting:
|
||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||
|
|
@ -845,8 +841,7 @@ class RecommendedApp(Base): # bug
|
|||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
|
||||
class InstalledApp(TypeBase):
|
||||
|
|
@ -873,13 +868,11 @@ class InstalledApp(TypeBase):
|
|||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
return tenant
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
|
||||
|
||||
class TrialApp(Base):
|
||||
|
|
@ -899,8 +892,7 @@ class TrialApp(Base):
|
|||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
|
||||
class AccountTrialAppRecord(Base):
|
||||
|
|
@ -919,13 +911,11 @@ class AccountTrialAppRecord(Base):
|
|||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
@property
|
||||
def user(self) -> Account | None:
|
||||
user = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return user
|
||||
return db.session.scalar(select(Account).where(Account.id == self.account_id))
|
||||
|
||||
|
||||
class ExporleBanner(TypeBase):
|
||||
|
|
@ -1117,8 +1107,8 @@ class Conversation(Base):
|
|||
else:
|
||||
model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key]
|
||||
else:
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
|
||||
app_model_config = db.session.scalar(
|
||||
select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id)
|
||||
)
|
||||
if app_model_config:
|
||||
model_config = app_model_config.to_dict()
|
||||
|
|
@ -1141,36 +1131,43 @@ class Conversation(Base):
|
|||
|
||||
@property
|
||||
def annotated(self):
|
||||
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).count() > 0
|
||||
return (
|
||||
db.session.scalar(
|
||||
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.conversation_id == self.id)
|
||||
)
|
||||
or 0
|
||||
) > 0
|
||||
|
||||
@property
|
||||
def annotation(self):
|
||||
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).first()
|
||||
return db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).limit(1))
|
||||
|
||||
@property
|
||||
def message_count(self):
|
||||
return db.session.query(Message).where(Message.conversation_id == self.id).count()
|
||||
return db.session.scalar(select(func.count(Message.id)).where(Message.conversation_id == self.id)) or 0
|
||||
|
||||
@property
|
||||
def user_feedback_stats(self):
|
||||
like = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "like",
|
||||
db.session.scalar(
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "like",
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
|
||||
dislike = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "dislike",
|
||||
db.session.scalar(
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "dislike",
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
|
||||
return {"like": like, "dislike": dislike}
|
||||
|
|
@ -1178,23 +1175,25 @@ class Conversation(Base):
|
|||
@property
|
||||
def admin_feedback_stats(self):
|
||||
like = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "like",
|
||||
db.session.scalar(
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "like",
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
|
||||
dislike = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "dislike",
|
||||
db.session.scalar(
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "dislike",
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
|
||||
return {"like": like, "dislike": dislike}
|
||||
|
|
@ -1256,22 +1255,19 @@ class Conversation(Base):
|
|||
|
||||
@property
|
||||
def first_message(self):
|
||||
return (
|
||||
db.session.query(Message)
|
||||
.where(Message.conversation_id == self.id)
|
||||
.order_by(Message.created_at.asc())
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(Message).where(Message.conversation_id == self.id).order_by(Message.created_at.asc())
|
||||
)
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
return session.query(App).where(App.id == self.app_id).first()
|
||||
return session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
@property
|
||||
def from_end_user_session_id(self):
|
||||
if self.from_end_user_id:
|
||||
end_user = db.session.query(EndUser).where(EndUser.id == self.from_end_user_id).first()
|
||||
end_user = db.session.scalar(select(EndUser).where(EndUser.id == self.from_end_user_id))
|
||||
if end_user:
|
||||
return end_user.session_id
|
||||
|
||||
|
|
@ -1280,7 +1276,7 @@ class Conversation(Base):
|
|||
@property
|
||||
def from_account_name(self) -> str | None:
|
||||
if self.from_account_id:
|
||||
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
|
||||
account = db.session.scalar(select(Account).where(Account.id == self.from_account_id))
|
||||
if account:
|
||||
return account.name
|
||||
|
||||
|
|
@ -1505,21 +1501,15 @@ class Message(Base):
|
|||
|
||||
@property
|
||||
def user_feedback(self):
|
||||
feedback = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
|
||||
)
|
||||
return feedback
|
||||
|
||||
@property
|
||||
def admin_feedback(self):
|
||||
feedback = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
|
||||
)
|
||||
return feedback
|
||||
|
||||
@property
|
||||
def feedbacks(self):
|
||||
|
|
@ -1528,28 +1518,27 @@ class Message(Base):
|
|||
|
||||
@property
|
||||
def annotation(self):
|
||||
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == self.id).first()
|
||||
annotation = db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.message_id == self.id))
|
||||
return annotation
|
||||
|
||||
@property
|
||||
def annotation_hit_history(self):
|
||||
annotation_history = (
|
||||
db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id).first()
|
||||
annotation_history = db.session.scalar(
|
||||
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id)
|
||||
)
|
||||
if annotation_history:
|
||||
annotation = (
|
||||
db.session.query(MessageAnnotation)
|
||||
.where(MessageAnnotation.id == annotation_history.annotation_id)
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(MessageAnnotation).where(MessageAnnotation.id == annotation_history.annotation_id)
|
||||
)
|
||||
return annotation
|
||||
return None
|
||||
|
||||
@property
|
||||
def app_model_config(self):
|
||||
conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first()
|
||||
conversation = db.session.scalar(select(Conversation).where(Conversation.id == self.conversation_id))
|
||||
if conversation:
|
||||
return db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first()
|
||||
return db.session.scalar(
|
||||
select(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id)
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
|
@ -1562,13 +1551,12 @@ class Message(Base):
|
|||
return json.loads(self.message_metadata) if self.message_metadata else {}
|
||||
|
||||
@property
|
||||
def agent_thoughts(self) -> list[MessageAgentThought]:
|
||||
return (
|
||||
db.session.query(MessageAgentThought)
|
||||
def agent_thoughts(self) -> Sequence[MessageAgentThought]:
|
||||
return db.session.scalars(
|
||||
select(MessageAgentThought)
|
||||
.where(MessageAgentThought.message_id == self.id)
|
||||
.order_by(MessageAgentThought.position.asc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
@property
|
||||
def retriever_resources(self) -> Any:
|
||||
|
|
@ -1579,7 +1567,7 @@ class Message(Base):
|
|||
from factories import file_factory
|
||||
|
||||
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
|
||||
current_app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
current_app = db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
if not current_app:
|
||||
raise ValueError(f"App {self.app_id} not found")
|
||||
|
||||
|
|
@ -1743,8 +1731,7 @@ class MessageFeedback(TypeBase):
|
|||
|
||||
@property
|
||||
def from_account(self) -> Account | None:
|
||||
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
|
||||
return account
|
||||
return db.session.scalar(select(Account).where(Account.id == self.from_account_id))
|
||||
|
||||
def to_dict(self) -> MessageFeedbackDict:
|
||||
return {
|
||||
|
|
@ -1817,13 +1804,11 @@ class MessageAnnotation(Base):
|
|||
|
||||
@property
|
||||
def account(self):
|
||||
account = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return account
|
||||
return db.session.scalar(select(Account).where(Account.id == self.account_id))
|
||||
|
||||
@property
|
||||
def annotation_create_account(self):
|
||||
account = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return account
|
||||
return db.session.scalar(select(Account).where(Account.id == self.account_id))
|
||||
|
||||
|
||||
class AppAnnotationHitHistory(TypeBase):
|
||||
|
|
@ -1852,18 +1837,15 @@ class AppAnnotationHitHistory(TypeBase):
|
|||
|
||||
@property
|
||||
def account(self):
|
||||
account = (
|
||||
db.session.query(Account)
|
||||
return db.session.scalar(
|
||||
select(Account)
|
||||
.join(MessageAnnotation, MessageAnnotation.account_id == Account.id)
|
||||
.where(MessageAnnotation.id == self.annotation_id)
|
||||
.first()
|
||||
)
|
||||
return account
|
||||
|
||||
@property
|
||||
def annotation_create_account(self):
|
||||
account = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return account
|
||||
return db.session.scalar(select(Account).where(Account.id == self.account_id))
|
||||
|
||||
|
||||
class AppAnnotationSetting(TypeBase):
|
||||
|
|
@ -1896,12 +1878,9 @@ class AppAnnotationSetting(TypeBase):
|
|||
def collection_binding_detail(self):
|
||||
from .dataset import DatasetCollectionBinding
|
||||
|
||||
collection_binding_detail = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == self.collection_binding_id)
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == self.collection_binding_id)
|
||||
)
|
||||
return collection_binding_detail
|
||||
|
||||
|
||||
class OperationLog(TypeBase):
|
||||
|
|
@ -2007,7 +1986,9 @@ class AppMCPServer(TypeBase):
|
|||
def generate_server_code(n: int) -> str:
|
||||
while True:
|
||||
result = generate_string(n)
|
||||
while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0:
|
||||
while (
|
||||
db.session.scalar(select(func.count(AppMCPServer.id)).where(AppMCPServer.server_code == result)) or 0
|
||||
) > 0:
|
||||
result = generate_string(n)
|
||||
|
||||
return result
|
||||
|
|
@ -2068,7 +2049,7 @@ class Site(Base):
|
|||
def generate_code(n: int) -> str:
|
||||
while True:
|
||||
result = generate_string(n)
|
||||
while db.session.query(Site).where(Site.code == result).count() > 0:
|
||||
while (db.session.scalar(select(func.count(Site.id)).where(Site.code == result)) or 0) > 0:
|
||||
result = generate_string(n)
|
||||
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from functools import cached_property
|
|||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, func, text
|
||||
from sqlalchemy import DateTime, String, func, select, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
|
@ -96,7 +96,7 @@ class Provider(TypeBase):
|
|||
@cached_property
|
||||
def credential(self):
|
||||
if self.credential_id:
|
||||
return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first()
|
||||
return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
|
||||
|
||||
@property
|
||||
def credential_name(self):
|
||||
|
|
@ -159,10 +159,8 @@ class ProviderModel(TypeBase):
|
|||
@cached_property
|
||||
def credential(self):
|
||||
if self.credential_id:
|
||||
return (
|
||||
db.session.query(ProviderModelCredential)
|
||||
.where(ProviderModelCredential.id == self.credential_id)
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from uuid import uuid4
|
|||
|
||||
import sqlalchemy as sa
|
||||
from deprecated import deprecated
|
||||
from sqlalchemy import ForeignKey, String, func
|
||||
from sqlalchemy import ForeignKey, String, func, select
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
|
@ -184,11 +184,11 @@ class ApiToolProvider(TypeBase):
|
|||
def user(self) -> Account | None:
|
||||
if not self.user_id:
|
||||
return None
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
return db.session.scalar(select(Account).where(Account.id == self.user_id))
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
|
||||
|
||||
class ToolLabelBinding(TypeBase):
|
||||
|
|
@ -262,11 +262,11 @@ class WorkflowToolProvider(TypeBase):
|
|||
|
||||
@property
|
||||
def user(self) -> Account | None:
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
return db.session.scalar(select(Account).where(Account.id == self.user_id))
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
|
||||
@property
|
||||
def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
|
||||
|
|
@ -277,7 +277,7 @@ class WorkflowToolProvider(TypeBase):
|
|||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
return db.session.query(App).where(App.id == self.app_id).first()
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
|
||||
class MCPToolProvider(TypeBase):
|
||||
|
|
@ -334,7 +334,7 @@ class MCPToolProvider(TypeBase):
|
|||
encrypted_headers: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
|
||||
def load_user(self) -> Account | None:
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
return db.session.scalar(select(Account).where(Account.id == self.user_id))
|
||||
|
||||
@property
|
||||
def credentials(self) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from datetime import datetime
|
|||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy import DateTime, func, select
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import TypeBase
|
||||
|
|
@ -38,7 +38,7 @@ class SavedMessage(TypeBase):
|
|||
|
||||
@property
|
||||
def message(self):
|
||||
return db.session.query(Message).where(Message.id == self.message_id).first()
|
||||
return db.session.scalar(select(Message).where(Message.id == self.message_id))
|
||||
|
||||
|
||||
class PinnedConversation(TypeBase):
|
||||
|
|
|
|||
|
|
@ -679,14 +679,14 @@ class WorkflowRun(Base):
|
|||
def message(self):
|
||||
from .model import Message
|
||||
|
||||
return (
|
||||
db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first()
|
||||
return db.session.scalar(
|
||||
select(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id)
|
||||
)
|
||||
|
||||
@property
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
def workflow(self):
|
||||
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
|
||||
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager
|
|||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
|
||||
|
||||
class AgentService:
|
||||
|
|
@ -47,7 +47,7 @@ class AgentService:
|
|||
if not message:
|
||||
raise ValueError(f"Message not found: {message_id}")
|
||||
|
||||
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
|
||||
agent_thoughts = message.agent_thoughts
|
||||
|
||||
if conversation.from_end_user_id:
|
||||
# only select name field
|
||||
|
|
|
|||
|
|
@ -622,28 +622,10 @@ class TestAccountGetByOpenId:
|
|||
mock_account = Account(name="Test User", email="test@example.com")
|
||||
mock_account.id = account_id
|
||||
|
||||
# Mock the query chain
|
||||
mock_query = MagicMock()
|
||||
mock_where = MagicMock()
|
||||
mock_where.one_or_none.return_value = mock_account_integrate
|
||||
mock_query.where.return_value = mock_where
|
||||
mock_db.session.query.return_value = mock_query
|
||||
|
||||
# Mock the second query for account
|
||||
mock_account_query = MagicMock()
|
||||
mock_account_where = MagicMock()
|
||||
mock_account_where.one_or_none.return_value = mock_account
|
||||
mock_account_query.where.return_value = mock_account_where
|
||||
|
||||
# Setup query to return different results based on model
|
||||
def query_side_effect(model):
|
||||
if model.__name__ == "AccountIntegrate":
|
||||
return mock_query
|
||||
elif model.__name__ == "Account":
|
||||
return mock_account_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
# Mock db.session.execute().scalar_one_or_none() for AccountIntegrate lookup
|
||||
mock_db.session.execute.return_value.scalar_one_or_none.return_value = mock_account_integrate
|
||||
# Mock db.session.scalar() for Account lookup
|
||||
mock_db.session.scalar.return_value = mock_account
|
||||
|
||||
# Act
|
||||
result = Account.get_by_openid(provider, open_id)
|
||||
|
|
@ -658,12 +640,8 @@ class TestAccountGetByOpenId:
|
|||
provider = "github"
|
||||
open_id = "github_user_456"
|
||||
|
||||
# Mock the query chain to return None
|
||||
mock_query = MagicMock()
|
||||
mock_where = MagicMock()
|
||||
mock_where.one_or_none.return_value = None
|
||||
mock_query.where.return_value = mock_where
|
||||
mock_db.session.query.return_value = mock_query
|
||||
# Mock db.session.execute().scalar_one_or_none() to return None
|
||||
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
|
||||
# Act
|
||||
result = Account.get_by_openid(provider, open_id)
|
||||
|
|
|
|||
|
|
@ -300,10 +300,8 @@ class TestAppModelConfig:
|
|||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
# Mock database query to return None
|
||||
with patch("models.model.db.session.query", autospec=True) as mock_query:
|
||||
mock_query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Mock database scalar to return None (no annotation setting found)
|
||||
with patch("models.model.db.session.scalar", return_value=None):
|
||||
# Act
|
||||
result = config.annotation_reply_dict
|
||||
|
||||
|
|
@ -951,10 +949,8 @@ class TestSiteModel:
|
|||
|
||||
def test_site_generate_code(self):
|
||||
"""Test Site.generate_code static method."""
|
||||
# Mock database query to return 0 (no existing codes)
|
||||
with patch("models.model.db.session.query", autospec=True) as mock_query:
|
||||
mock_query.return_value.where.return_value.count.return_value = 0
|
||||
|
||||
# Mock database scalar to return 0 (no existing codes)
|
||||
with patch("models.model.db.session.scalar", return_value=0):
|
||||
# Act
|
||||
code = Site.generate_code(8)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue