mirror of https://github.com/langgenius/dify.git
refactor: select in console app message controller (#33893)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
a942d4c926
commit
02e13e6d05
|
|
@ -4,7 +4,7 @@ from typing import Literal
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, marshal_with
|
from flask_restx import Resource, fields, marshal_with
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from sqlalchemy import exists, select
|
from sqlalchemy import exists, func, select
|
||||||
from werkzeug.exceptions import InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
|
|
@ -244,27 +244,25 @@ class ChatMessageListApi(Resource):
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
args = ChatMessagesQuery.model_validate(request.args.to_dict())
|
args = ChatMessagesQuery.model_validate(request.args.to_dict())
|
||||||
|
|
||||||
conversation = (
|
conversation = db.session.scalar(
|
||||||
db.session.query(Conversation)
|
select(Conversation)
|
||||||
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
|
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not conversation:
|
if not conversation:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
||||||
if args.first_id:
|
if args.first_id:
|
||||||
first_message = (
|
first_message = db.session.scalar(
|
||||||
db.session.query(Message)
|
select(Message).where(Message.conversation_id == conversation.id, Message.id == args.first_id).limit(1)
|
||||||
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not first_message:
|
if not first_message:
|
||||||
raise NotFound("First message not found")
|
raise NotFound("First message not found")
|
||||||
|
|
||||||
history_messages = (
|
history_messages = db.session.scalars(
|
||||||
db.session.query(Message)
|
select(Message)
|
||||||
.where(
|
.where(
|
||||||
Message.conversation_id == conversation.id,
|
Message.conversation_id == conversation.id,
|
||||||
Message.created_at < first_message.created_at,
|
Message.created_at < first_message.created_at,
|
||||||
|
|
@ -272,16 +270,14 @@ class ChatMessageListApi(Resource):
|
||||||
)
|
)
|
||||||
.order_by(Message.created_at.desc())
|
.order_by(Message.created_at.desc())
|
||||||
.limit(args.limit)
|
.limit(args.limit)
|
||||||
.all()
|
).all()
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
history_messages = (
|
history_messages = db.session.scalars(
|
||||||
db.session.query(Message)
|
select(Message)
|
||||||
.where(Message.conversation_id == conversation.id)
|
.where(Message.conversation_id == conversation.id)
|
||||||
.order_by(Message.created_at.desc())
|
.order_by(Message.created_at.desc())
|
||||||
.limit(args.limit)
|
.limit(args.limit)
|
||||||
.all()
|
).all()
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize has_more based on whether we have a full page
|
# Initialize has_more based on whether we have a full page
|
||||||
if len(history_messages) == args.limit:
|
if len(history_messages) == args.limit:
|
||||||
|
|
@ -326,7 +322,9 @@ class MessageFeedbackApi(Resource):
|
||||||
|
|
||||||
message_id = str(args.message_id)
|
message_id = str(args.message_id)
|
||||||
|
|
||||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
message = db.session.scalar(
|
||||||
|
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
|
||||||
|
)
|
||||||
|
|
||||||
if not message:
|
if not message:
|
||||||
raise NotFound("Message Not Exists.")
|
raise NotFound("Message Not Exists.")
|
||||||
|
|
@ -375,7 +373,9 @@ class MessageAnnotationCountApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()
|
count = db.session.scalar(
|
||||||
|
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id)
|
||||||
|
)
|
||||||
|
|
||||||
return {"count": count}
|
return {"count": count}
|
||||||
|
|
||||||
|
|
@ -479,7 +479,9 @@ class MessageApi(Resource):
|
||||||
def get(self, app_model, message_id: str):
|
def get(self, app_model, message_id: str):
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
message = db.session.scalar(
|
||||||
|
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
|
||||||
|
)
|
||||||
|
|
||||||
if not message:
|
if not message:
|
||||||
raise NotFound("Message Not Exists.")
|
raise NotFound("Message Not Exists.")
|
||||||
|
|
|
||||||
|
|
@ -170,7 +170,7 @@ class TestMessageEndpoints:
|
||||||
mock_app_model,
|
mock_app_model,
|
||||||
qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000"},
|
qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000"},
|
||||||
) as (api, mock_db, v_args):
|
) as (api, mock_db, v_args):
|
||||||
mock_db.data_query.where.return_value.first.return_value = None
|
mock_db.session.scalar.return_value = None
|
||||||
|
|
||||||
with pytest.raises(NotFound):
|
with pytest.raises(NotFound):
|
||||||
api.get(**v_args)
|
api.get(**v_args)
|
||||||
|
|
@ -198,11 +198,11 @@ class TestMessageEndpoints:
|
||||||
mock_msg.message = {}
|
mock_msg.message = {}
|
||||||
mock_msg.message_metadata_dict = {}
|
mock_msg.message_metadata_dict = {}
|
||||||
|
|
||||||
# mock returns
|
# scalar() is called twice: first for conversation lookup, second for has_more check
|
||||||
q_mock = mock_db.data_query
|
mock_db.session.scalar.side_effect = [mock_conv, False]
|
||||||
q_mock.where.return_value.first.side_effect = [mock_conv]
|
scalars_result = MagicMock()
|
||||||
q_mock.where.return_value.order_by.return_value.limit.return_value.all.return_value = [mock_msg]
|
scalars_result.all.return_value = [mock_msg]
|
||||||
mock_db.session.scalar.return_value = False
|
mock_db.session.scalars.return_value = scalars_result
|
||||||
|
|
||||||
resp = api.get(**v_args)
|
resp = api.get(**v_args)
|
||||||
assert resp["limit"] == 1
|
assert resp["limit"] == 1
|
||||||
|
|
@ -219,7 +219,7 @@ class TestMessageEndpoints:
|
||||||
mock_app_model,
|
mock_app_model,
|
||||||
payload={"message_id": "123e4567-e89b-12d3-a456-426614174000"},
|
payload={"message_id": "123e4567-e89b-12d3-a456-426614174000"},
|
||||||
) as (api, mock_db, v_args):
|
) as (api, mock_db, v_args):
|
||||||
mock_db.data_query.where.return_value.first.return_value = None
|
mock_db.session.scalar.return_value = None
|
||||||
|
|
||||||
with pytest.raises(NotFound):
|
with pytest.raises(NotFound):
|
||||||
api.post(**v_args)
|
api.post(**v_args)
|
||||||
|
|
@ -231,7 +231,7 @@ class TestMessageEndpoints:
|
||||||
) as (api, mock_db, v_args):
|
) as (api, mock_db, v_args):
|
||||||
mock_msg = MagicMock()
|
mock_msg = MagicMock()
|
||||||
mock_msg.admin_feedback = None
|
mock_msg.admin_feedback = None
|
||||||
mock_db.data_query.where.return_value.first.return_value = mock_msg
|
mock_db.session.scalar.return_value = mock_msg
|
||||||
|
|
||||||
resp = api.post(**v_args)
|
resp = api.post(**v_args)
|
||||||
assert resp == {"result": "success"}
|
assert resp == {"result": "success"}
|
||||||
|
|
@ -240,7 +240,7 @@ class TestMessageEndpoints:
|
||||||
with setup_test_context(
|
with setup_test_context(
|
||||||
app, MessageAnnotationCountApi, "/apps/app_123/annotations/count", "GET", mock_account, mock_app_model
|
app, MessageAnnotationCountApi, "/apps/app_123/annotations/count", "GET", mock_account, mock_app_model
|
||||||
) as (api, mock_db, v_args):
|
) as (api, mock_db, v_args):
|
||||||
mock_db.data_query.where.return_value.count.return_value = 5
|
mock_db.session.scalar.return_value = 5
|
||||||
|
|
||||||
resp = api.get(**v_args)
|
resp = api.get(**v_args)
|
||||||
assert resp == {"count": 5}
|
assert resp == {"count": 5}
|
||||||
|
|
@ -314,7 +314,7 @@ class TestMessageEndpoints:
|
||||||
mock_msg.message = {}
|
mock_msg.message = {}
|
||||||
mock_msg.message_metadata_dict = {}
|
mock_msg.message_metadata_dict = {}
|
||||||
|
|
||||||
mock_db.data_query.where.return_value.first.return_value = mock_msg
|
mock_db.session.scalar.return_value = mock_msg
|
||||||
|
|
||||||
resp = api.get(**v_args)
|
resp = api.get(**v_args)
|
||||||
assert resp["id"] == "msg_123"
|
assert resp["id"] == "msg_123"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue