refactor: select in console app message controller (#33893)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo 2026-03-23 08:38:04 +01:00 committed by GitHub
parent a942d4c926
commit 02e13e6d05
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 29 deletions

View File

@ -4,7 +4,7 @@ from typing import Literal
from flask import request from flask import request
from flask_restx import Resource, fields, marshal_with from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, select from sqlalchemy import exists, func, select
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
@ -244,27 +244,25 @@ class ChatMessageListApi(Resource):
def get(self, app_model): def get(self, app_model):
args = ChatMessagesQuery.model_validate(request.args.to_dict()) args = ChatMessagesQuery.model_validate(request.args.to_dict())
conversation = ( conversation = db.session.scalar(
db.session.query(Conversation) select(Conversation)
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id) .where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
.first() .limit(1)
) )
if not conversation: if not conversation:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
if args.first_id: if args.first_id:
first_message = ( first_message = db.session.scalar(
db.session.query(Message) select(Message).where(Message.conversation_id == conversation.id, Message.id == args.first_id).limit(1)
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
.first()
) )
if not first_message: if not first_message:
raise NotFound("First message not found") raise NotFound("First message not found")
history_messages = ( history_messages = db.session.scalars(
db.session.query(Message) select(Message)
.where( .where(
Message.conversation_id == conversation.id, Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at, Message.created_at < first_message.created_at,
@ -272,16 +270,14 @@ class ChatMessageListApi(Resource):
) )
.order_by(Message.created_at.desc()) .order_by(Message.created_at.desc())
.limit(args.limit) .limit(args.limit)
.all() ).all()
)
else: else:
history_messages = ( history_messages = db.session.scalars(
db.session.query(Message) select(Message)
.where(Message.conversation_id == conversation.id) .where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc()) .order_by(Message.created_at.desc())
.limit(args.limit) .limit(args.limit)
.all() ).all()
)
# Initialize has_more based on whether we have a full page # Initialize has_more based on whether we have a full page
if len(history_messages) == args.limit: if len(history_messages) == args.limit:
@ -326,7 +322,9 @@ class MessageFeedbackApi(Resource):
message_id = str(args.message_id) message_id = str(args.message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() message = db.session.scalar(
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
)
if not message: if not message:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
@ -375,7 +373,9 @@ class MessageAnnotationCountApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count() count = db.session.scalar(
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id)
)
return {"count": count} return {"count": count}
@ -479,7 +479,9 @@ class MessageApi(Resource):
def get(self, app_model, message_id: str): def get(self, app_model, message_id: str):
message_id = str(message_id) message_id = str(message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() message = db.session.scalar(
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
)
if not message: if not message:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")

View File

@ -170,7 +170,7 @@ class TestMessageEndpoints:
mock_app_model, mock_app_model,
qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000"}, qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000"},
) as (api, mock_db, v_args): ) as (api, mock_db, v_args):
mock_db.data_query.where.return_value.first.return_value = None mock_db.session.scalar.return_value = None
with pytest.raises(NotFound): with pytest.raises(NotFound):
api.get(**v_args) api.get(**v_args)
@ -198,11 +198,11 @@ class TestMessageEndpoints:
mock_msg.message = {} mock_msg.message = {}
mock_msg.message_metadata_dict = {} mock_msg.message_metadata_dict = {}
# mock returns # scalar() is called twice: first for conversation lookup, second for has_more check
q_mock = mock_db.data_query mock_db.session.scalar.side_effect = [mock_conv, False]
q_mock.where.return_value.first.side_effect = [mock_conv] scalars_result = MagicMock()
q_mock.where.return_value.order_by.return_value.limit.return_value.all.return_value = [mock_msg] scalars_result.all.return_value = [mock_msg]
mock_db.session.scalar.return_value = False mock_db.session.scalars.return_value = scalars_result
resp = api.get(**v_args) resp = api.get(**v_args)
assert resp["limit"] == 1 assert resp["limit"] == 1
@ -219,7 +219,7 @@ class TestMessageEndpoints:
mock_app_model, mock_app_model,
payload={"message_id": "123e4567-e89b-12d3-a456-426614174000"}, payload={"message_id": "123e4567-e89b-12d3-a456-426614174000"},
) as (api, mock_db, v_args): ) as (api, mock_db, v_args):
mock_db.data_query.where.return_value.first.return_value = None mock_db.session.scalar.return_value = None
with pytest.raises(NotFound): with pytest.raises(NotFound):
api.post(**v_args) api.post(**v_args)
@ -231,7 +231,7 @@ class TestMessageEndpoints:
) as (api, mock_db, v_args): ) as (api, mock_db, v_args):
mock_msg = MagicMock() mock_msg = MagicMock()
mock_msg.admin_feedback = None mock_msg.admin_feedback = None
mock_db.data_query.where.return_value.first.return_value = mock_msg mock_db.session.scalar.return_value = mock_msg
resp = api.post(**v_args) resp = api.post(**v_args)
assert resp == {"result": "success"} assert resp == {"result": "success"}
@ -240,7 +240,7 @@ class TestMessageEndpoints:
with setup_test_context( with setup_test_context(
app, MessageAnnotationCountApi, "/apps/app_123/annotations/count", "GET", mock_account, mock_app_model app, MessageAnnotationCountApi, "/apps/app_123/annotations/count", "GET", mock_account, mock_app_model
) as (api, mock_db, v_args): ) as (api, mock_db, v_args):
mock_db.data_query.where.return_value.count.return_value = 5 mock_db.session.scalar.return_value = 5
resp = api.get(**v_args) resp = api.get(**v_args)
assert resp == {"count": 5} assert resp == {"count": 5}
@ -314,7 +314,7 @@ class TestMessageEndpoints:
mock_msg.message = {} mock_msg.message = {}
mock_msg.message_metadata_dict = {} mock_msg.message_metadata_dict = {}
mock_db.data_query.where.return_value.first.return_value = mock_msg mock_db.session.scalar.return_value = mock_msg
resp = api.get(**v_args) resp = api.get(**v_args)
assert resp["id"] == "msg_123" assert resp["id"] == "msg_123"