diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 3beea2a385..4fb73f61f3 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -30,6 +30,7 @@ from fields.raws import FilesContainedField from libs.helper import TimestampField, uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import current_account_with_tenant, login_required +from models.enums import FeedbackFromSource, FeedbackRating from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError @@ -335,7 +336,7 @@ class MessageFeedbackApi(Resource): if not args.rating and feedback: db.session.delete(feedback) elif args.rating and feedback: - feedback.rating = args.rating + feedback.rating = FeedbackRating(args.rating) feedback.content = args.content elif not args.rating and not feedback: raise ValueError("rating cannot be None when feedback not exists") @@ -347,9 +348,9 @@ class MessageFeedbackApi(Resource): app_id=app_model.id, conversation_id=message.conversation_id, message_id=message.id, - rating=rating_value, + rating=FeedbackRating(rating_value), content=args.content, - from_source="admin", + from_source=FeedbackFromSource.ADMIN, from_account_id=current_user.id, ) db.session.add(feedback) diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 53970dbd3b..15e1aea361 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -27,6 +27,7 @@ from fields.message_fields import MessageInfiniteScrollPagination, MessageListIt from libs import helper from libs.helper import UUIDStrOrEmpty from libs.login import current_account_with_tenant +from models.enums import FeedbackRating from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError @@ -116,7 +117,7 @@ class MessageFeedbackApi(InstalledAppResource): app_model=app_model, message_id=message_id, user=current_user, - rating=payload.rating, + rating=FeedbackRating(payload.rating) if payload.rating else None, content=payload.content, ) except MessageNotExistsError: diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 2aaf920efb..77fee9c142 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -15,6 +15,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem from libs.helper import UUIDStrOrEmpty +from models.enums import FeedbackRating from models.model import App, AppMode, EndUser from services.errors.message import ( FirstMessageNotExistsError, @@ -116,7 +117,7 @@ class MessageFeedbackApi(Resource): app_model=app_model, message_id=message_id, user=end_user, - rating=payload.rating, + rating=FeedbackRating(payload.rating) if payload.rating else None, content=payload.content, ) except MessageNotExistsError: diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 2b60691949..aa56292614 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -25,6 +25,7 @@ from fields.conversation_fields import ResultResponse from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem from libs import helper from libs.helper import uuid_value +from models.enums import FeedbackRating from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError @@ -157,7 +158,7 @@ class MessageFeedbackApi(WebApiResource): app_model=app_model, message_id=message_id, user=end_user, - rating=payload.rating, + rating=FeedbackRating(payload.rating) if payload.rating else None, content=payload.content, ) except MessageNotExistsError: diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 6583ba51e9..f7b5030d33 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -76,7 +76,7 @@ from dify_graph.system_variable import SystemVariable from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, Conversation, EndUser, Message, MessageFile -from models.enums import CreatorUserRole, MessageStatus +from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus from models.execution_extra_content import HumanInputContent from models.workflow import Workflow @@ -939,7 +939,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): type=file["type"], transfer_method=file["transfer_method"], url=file["remote_url"], - belongs_to="assistant", + belongs_to=MessageFileBelongsTo.ASSISTANT, upload_file_id=file["related_id"], created_by_role=CreatorUserRole.ACCOUNT if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 88714f3837..11fcbb7561 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -40,7 +40,7 @@ from dify_graph.model_runtime.entities.message_entities import ( from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError from extensions.ext_database import db -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import App, AppMode, Message, MessageAnnotation, MessageFile if TYPE_CHECKING: @@ -419,7 +419,7 @@ class AppRunner: message_id=message_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, - belongs_to="assistant", + belongs_to=MessageFileBelongsTo.ASSISTANT, url=f"/files/tools/{tool_file.id}", upload_file_id=tool_file.id, created_by_role=( diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 4e9a191dae..64c28ca60f 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -33,7 +33,7 @@ from extensions.ext_redis import get_pubsub_broadcast_channel from libs.broadcast_channel.channel import Topic from libs.datetime_utils import naive_utc_now from models import Account -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.conversation import ConversationNotExistsError @@ -225,7 +225,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): message_id=message.id, type=file.type, transfer_method=file.transfer_method, - belongs_to="user", + belongs_to=MessageFileBelongsTo.USER, url=file.remote_url, upload_file_id=file.related_id, created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER), diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 536ab02eae..62f27060b4 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -34,6 +34,7 @@ from core.llm_generator.llm_generator import LLMGenerator from core.tools.signature import sign_tool_file from extensions.ext_database import db from extensions.ext_redis import redis_client +from models.enums import MessageFileBelongsTo from models.model import AppMode, Conversation, MessageAnnotation, MessageFile from services.annotation_service import AppAnnotationService @@ -233,7 +234,7 @@ class MessageCycleManager: task_id=self._application_generate_entity.task_id, id=message_file.id, type=message_file.type, - belongs_to=message_file.belongs_to or "user", + belongs_to=message_file.belongs_to or MessageFileBelongsTo.USER, url=url, ) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 0f0eacbdc4..64212a2636 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -34,7 +34,7 @@ from core.tools.workflow_as_tool.tool import WorkflowTool from dify_graph.file import FileType from dify_graph.file.models import FileTransferMethod from extensions.ext_database import db -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import Message, MessageFile logger = logging.getLogger(__name__) @@ -352,7 +352,7 @@ class ToolEngine: message_id=agent_message.id, type=file_type, transfer_method=FileTransferMethod.TOOL_FILE, - belongs_to="assistant", + belongs_to=MessageFileBelongsTo.ASSISTANT, url=message.url, upload_file_id=tool_file_id, created_by_role=( diff --git a/api/models/enums.py b/api/models/enums.py index 6499c5b443..4849099d30 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -158,6 +158,13 @@ class FeedbackFromSource(StrEnum): ADMIN = "admin" +class FeedbackRating(StrEnum): + """MessageFeedback rating""" + + LIKE = "like" + DISLIKE = "dislike" + + class InvokeFrom(StrEnum): """How a conversation/message was invoked""" diff --git a/api/models/model.py b/api/models/model.py index 45d9c501ae..3bd68d1d95 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -36,7 +36,10 @@ from .enums import ( BannerStatus, ConversationStatus, CreatorUserRole, + FeedbackFromSource, + FeedbackRating, MessageChainType, + MessageFileBelongsTo, MessageStatus, ) from .provider_ids import GenericProviderID @@ -1165,7 +1168,7 @@ class Conversation(Base): select(func.count(MessageFeedback.id)).where( MessageFeedback.conversation_id == self.id, MessageFeedback.from_source == "user", - MessageFeedback.rating == "like", + MessageFeedback.rating == FeedbackRating.LIKE, ) ) or 0 @@ -1176,7 +1179,7 @@ class Conversation(Base): select(func.count(MessageFeedback.id)).where( MessageFeedback.conversation_id == self.id, MessageFeedback.from_source == "user", - MessageFeedback.rating == "dislike", + MessageFeedback.rating == FeedbackRating.DISLIKE, ) ) or 0 @@ -1191,7 +1194,7 @@ class Conversation(Base): select(func.count(MessageFeedback.id)).where( MessageFeedback.conversation_id == self.id, MessageFeedback.from_source == "admin", - MessageFeedback.rating == "like", + MessageFeedback.rating == FeedbackRating.LIKE, ) ) or 0 @@ -1202,7 +1205,7 @@ class Conversation(Base): select(func.count(MessageFeedback.id)).where( MessageFeedback.conversation_id == self.id, MessageFeedback.from_source == "admin", - MessageFeedback.rating == "dislike", + MessageFeedback.rating == FeedbackRating.DISLIKE, ) ) or 0 @@ -1725,8 +1728,8 @@ class MessageFeedback(TypeBase): app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - rating: Mapped[str] = mapped_column(String(255), nullable=False) - from_source: Mapped[str] = mapped_column(String(255), nullable=False) + rating: Mapped[FeedbackRating] = mapped_column(EnumText(FeedbackRating, length=255), nullable=False) + from_source: Mapped[FeedbackFromSource] = mapped_column(EnumText(FeedbackFromSource, length=255), nullable=False) content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) from_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) from_account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) @@ -1779,7 +1782,9 @@ class MessageFile(TypeBase): ) created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None) + belongs_to: Mapped[MessageFileBelongsTo | None] = mapped_column( + EnumText(MessageFileBelongsTo, length=255), nullable=True, default=None + ) url: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) created_at: Mapped[datetime] = mapped_column( diff --git a/api/services/feedback_service.py b/api/services/feedback_service.py index 1a1cbbb450..e7473d371b 100644 --- a/api/services/feedback_service.py +++ b/api/services/feedback_service.py @@ -7,6 +7,7 @@ from flask import Response from sqlalchemy import or_ from extensions.ext_database import db +from models.enums import FeedbackRating from models.model import Account, App, Conversation, Message, MessageFeedback @@ -100,7 +101,7 @@ class FeedbackService: "ai_response": message.answer[:500] + "..." if len(message.answer) > 500 else message.answer, # Truncate long responses - "feedback_rating": "πŸ‘" if feedback.rating == "like" else "πŸ‘Ž", + "feedback_rating": "πŸ‘" if feedback.rating == FeedbackRating.LIKE else "πŸ‘Ž", "feedback_rating_raw": feedback.rating, "feedback_comment": feedback.content or "", "feedback_source": feedback.from_source, diff --git a/api/services/message_service.py b/api/services/message_service.py index 789b6c2f8c..fc87802f51 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -16,6 +16,7 @@ from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account +from models.enums import FeedbackFromSource, FeedbackRating from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback from repositories.execution_extra_content_repository import ExecutionExtraContentRepository from repositories.sqlalchemy_execution_extra_content_repository import ( @@ -172,7 +173,7 @@ class MessageService: app_model: App, message_id: str, user: Union[Account, EndUser] | None, - rating: str | None, + rating: FeedbackRating | None, content: str | None, ): if not user: @@ -197,7 +198,7 @@ class MessageService: message_id=message.id, rating=rating, content=content, - from_source=("user" if isinstance(user, EndUser) else "admin"), + from_source=(FeedbackFromSource.USER if isinstance(user, EndUser) else FeedbackFromSource.ADMIN), from_end_user_id=(user.id if isinstance(user, EndUser) else None), from_account_id=(user.id if isinstance(user, Account) else None), ) diff --git a/api/tests/integration_tests/controllers/console/app/test_feedback_export_api.py b/api/tests/integration_tests/controllers/console/app/test_feedback_export_api.py index 0f8b42e98b..309a0b015a 100644 --- a/api/tests/integration_tests/controllers/console/app/test_feedback_export_api.py +++ b/api/tests/integration_tests/controllers/console/app/test_feedback_export_api.py @@ -14,6 +14,7 @@ from controllers.console.app import wraps from libs.datetime_utils import naive_utc_now from models import App, Tenant from models.account import Account, TenantAccountJoin, TenantAccountRole +from models.enums import FeedbackFromSource, FeedbackRating from models.model import AppMode, MessageFeedback from services.feedback_service import FeedbackService @@ -77,8 +78,8 @@ class TestFeedbackExportApi: app_id=app_id, conversation_id=conversation_id, message_id=message_id, - rating="like", - from_source="user", + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.USER, content=None, from_end_user_id=str(uuid.uuid4()), from_account_id=None, @@ -90,8 +91,8 @@ class TestFeedbackExportApi: app_id=app_id, conversation_id=conversation_id, message_id=message_id, - rating="dislike", - from_source="admin", + rating=FeedbackRating.DISLIKE, + from_source=FeedbackFromSource.ADMIN, content="The response was not helpful", from_end_user_id=None, from_account_id=str(uuid.uuid4()), @@ -277,8 +278,8 @@ class TestFeedbackExportApi: # Verify service was called with correct parameters mock_export_feedbacks.assert_called_once_with( app_id=mock_app_model.id, - from_source="user", - rating="dislike", + from_source=FeedbackFromSource.USER, + rating=FeedbackRating.DISLIKE, has_comment=True, start_date="2024-01-01", end_date="2024-12-31", diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index 4759d244fd..ee34b65831 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Session from core.plugin.impl.exc import PluginDaemonClientSideError from models import Account +from models.enums import MessageFileBelongsTo from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought from services.account_service import AccountService, TenantService from services.agent_service import AgentService @@ -852,7 +853,7 @@ class TestAgentService: type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, url="http://example.com/file1.jpg", - belongs_to="user", + belongs_to=MessageFileBelongsTo.USER, created_by_role=CreatorUserRole.ACCOUNT, created_by=message.from_account_id, ) @@ -861,7 +862,7 @@ class TestAgentService: type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, url="http://example.com/file2.png", - belongs_to="user", + belongs_to=MessageFileBelongsTo.USER, created_by_role=CreatorUserRole.ACCOUNT, created_by=message.from_account_id, ) diff --git a/api/tests/test_containers_integration_tests/services/test_feedback_service.py b/api/tests/test_containers_integration_tests/services/test_feedback_service.py index 60919dff0d..771f406775 100644 --- a/api/tests/test_containers_integration_tests/services/test_feedback_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feedback_service.py @@ -8,6 +8,7 @@ from unittest import mock import pytest from extensions.ext_database import db +from models.enums import FeedbackFromSource, FeedbackRating from models.model import App, Conversation, Message from services.feedback_service import FeedbackService @@ -47,8 +48,8 @@ class TestFeedbackService: app_id=app_id, conversation_id="test-conversation-id", message_id="test-message-id", - rating="like", - from_source="user", + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.USER, content="Great answer!", from_end_user_id="user-123", from_account_id=None, @@ -61,8 +62,8 @@ class TestFeedbackService: app_id=app_id, conversation_id="test-conversation-id", message_id="test-message-id", - rating="dislike", - from_source="admin", + rating=FeedbackRating.DISLIKE, + from_source=FeedbackFromSource.ADMIN, content="Could be more detailed", from_end_user_id=None, from_account_id="admin-456", @@ -179,8 +180,8 @@ class TestFeedbackService: # Test with filters result = FeedbackService.export_feedbacks( app_id=sample_data["app"].id, - from_source="admin", - rating="dislike", + from_source=FeedbackFromSource.ADMIN, + rating=FeedbackRating.DISLIKE, has_comment=True, start_date="2024-01-01", end_date="2024-12-31", @@ -293,8 +294,8 @@ class TestFeedbackService: app_id=sample_data["app"].id, conversation_id="test-conversation-id", message_id="test-message-id", - rating="dislike", - from_source="user", + rating=FeedbackRating.DISLIKE, + from_source=FeedbackFromSource.USER, content="ε›žη­”δΈε€Ÿθ―¦η»†οΌŒιœ€θ¦ζ›΄ε€šδΏ‘ζ―", from_end_user_id="user-123", from_account_id=None, diff --git a/api/tests/test_containers_integration_tests/services/test_message_export_service.py b/api/tests/test_containers_integration_tests/services/test_message_export_service.py index 200f688ae9..805bab9b9d 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_export_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_export_service.py @@ -7,6 +7,7 @@ import pytest from sqlalchemy.orm import Session from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.enums import FeedbackFromSource, FeedbackRating from models.model import ( App, AppAnnotationHitHistory, @@ -172,8 +173,8 @@ class TestAppMessageExportServiceIntegration: app_id=app.id, conversation_id=conversation.id, message_id=first_message.id, - rating="like", - from_source="user", + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.USER, content="first", from_end_user_id=conversation.from_end_user_id, ) @@ -181,8 +182,8 @@ class TestAppMessageExportServiceIntegration: app_id=app.id, conversation_id=conversation.id, message_id=first_message.id, - rating="dislike", - from_source="user", + rating=FeedbackRating.DISLIKE, + from_source=FeedbackFromSource.USER, content="second", from_end_user_id=conversation.from_end_user_id, ) @@ -190,8 +191,8 @@ class TestAppMessageExportServiceIntegration: app_id=app.id, conversation_id=conversation.id, message_id=first_message.id, - rating="like", - from_source="admin", + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.ADMIN, content="should-be-filtered", from_account_id=str(uuid.uuid4()), ) diff --git a/api/tests/test_containers_integration_tests/services/test_message_service.py b/api/tests/test_containers_integration_tests/services/test_message_service.py index a6d7bf27fd..af666a0375 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service.py @@ -4,6 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from models.enums import FeedbackRating from models.model import MessageFeedback from services.app_service import AppService from services.errors.message import ( @@ -405,7 +406,7 @@ class TestMessageService: message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create feedback - rating = "like" + rating = FeedbackRating.LIKE content = fake.text(max_nb_chars=100) feedback = MessageService.create_feedback( app_model=app, message_id=message.id, user=account, rating=rating, content=content @@ -435,7 +436,11 @@ class TestMessageService: # Test creating feedback with no user with pytest.raises(ValueError, match="user cannot be None"): MessageService.create_feedback( - app_model=app, message_id=message.id, user=None, rating="like", content=fake.text(max_nb_chars=100) + app_model=app, + message_id=message.id, + user=None, + rating=FeedbackRating.LIKE, + content=fake.text(max_nb_chars=100), ) def test_create_feedback_update_existing( @@ -452,14 +457,14 @@ class TestMessageService: message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create initial feedback - initial_rating = "like" + initial_rating = FeedbackRating.LIKE initial_content = fake.text(max_nb_chars=100) feedback = MessageService.create_feedback( app_model=app, message_id=message.id, user=account, rating=initial_rating, content=initial_content ) # Update feedback - updated_rating = "dislike" + updated_rating = FeedbackRating.DISLIKE updated_content = fake.text(max_nb_chars=100) updated_feedback = MessageService.create_feedback( app_model=app, message_id=message.id, user=account, rating=updated_rating, content=updated_content @@ -487,7 +492,11 @@ class TestMessageService: # Create initial feedback feedback = MessageService.create_feedback( - app_model=app, message_id=message.id, user=account, rating="like", content=fake.text(max_nb_chars=100) + app_model=app, + message_id=message.id, + user=account, + rating=FeedbackRating.LIKE, + content=fake.text(max_nb_chars=100), ) # Delete feedback by setting rating to None @@ -538,7 +547,7 @@ class TestMessageService: app_model=app, message_id=message.id, user=account, - rating="like" if i % 2 == 0 else "dislike", + rating=FeedbackRating.LIKE if i % 2 == 0 else FeedbackRating.DISLIKE, content=f"Feedback {i}: {fake.text(max_nb_chars=50)}", ) feedbacks.append(feedback) @@ -568,7 +577,11 @@ class TestMessageService: message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) MessageService.create_feedback( - app_model=app, message_id=message.id, user=account, rating="like", content=f"Feedback {i}" + app_model=app, + message_id=message.id, + user=account, + rating=FeedbackRating.LIKE, + content=f"Feedback {i}", ) # Get feedbacks with pagination diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index 7b5157fa61..863f013e19 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -11,7 +11,7 @@ from sqlalchemy.orm import Session from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole -from models.enums import DataSourceType, MessageChainType +from models.enums import DataSourceType, FeedbackFromSource, FeedbackRating, MessageChainType, MessageFileBelongsTo from models.model import ( App, AppAnnotationHitHistory, @@ -166,7 +166,7 @@ class TestMessagesCleanServiceIntegration: name="Test conversation", inputs={}, status="normal", - from_source="api", + from_source=FeedbackFromSource.USER, from_end_user_id=str(uuid.uuid4()), ) db_session_with_containers.add(conversation) @@ -196,7 +196,7 @@ class TestMessagesCleanServiceIntegration: answer_unit_price=Decimal("0.002"), total_price=Decimal("0.003"), currency="USD", - from_source="api", + from_source=FeedbackFromSource.USER, from_account_id=conversation.from_end_user_id, created_at=created_at, ) @@ -216,8 +216,8 @@ class TestMessagesCleanServiceIntegration: app_id=message.app_id, conversation_id=message.conversation_id, message_id=message.id, - rating="like", - from_source="api", + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.USER, from_end_user_id=str(uuid.uuid4()), ) db_session_with_containers.add(feedback) @@ -249,7 +249,7 @@ class TestMessagesCleanServiceIntegration: type="image", transfer_method="local_file", url="http://example.com/test.jpg", - belongs_to="user", + belongs_to=MessageFileBelongsTo.USER, created_by_role="end_user", created_by=str(uuid.uuid4()), ) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_message.py b/api/tests/unit_tests/controllers/service_api/app/test_message.py index 4de12de829..c2b8aed1ae 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_message.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_message.py @@ -31,6 +31,7 @@ from controllers.service_api.app.message import ( MessageListQuery, MessageSuggestedApi, ) +from models.enums import FeedbackRating from models.model import App, AppMode, EndUser from services.errors.conversation import ConversationNotExistsError from services.errors.message import ( @@ -310,7 +311,7 @@ class TestMessageService: app_model=Mock(spec=App), message_id=str(uuid.uuid4()), user=Mock(spec=EndUser), - rating="like", + rating=FeedbackRating.LIKE, content="Great response!", ) @@ -326,7 +327,7 @@ class TestMessageService: app_model=Mock(spec=App), message_id="invalid_message_id", user=Mock(spec=EndUser), - rating="like", + rating=FeedbackRating.LIKE, content=None, ) diff --git a/api/tests/unit_tests/services/test_message_service.py b/api/tests/unit_tests/services/test_message_service.py index 4b8bdde46b..e7740ef93a 100644 --- a/api/tests/unit_tests/services/test_message_service.py +++ b/api/tests/unit_tests/services/test_message_service.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch import pytest from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models.enums import FeedbackFromSource, FeedbackRating from models.model import App, AppMode, EndUser, Message from services.errors.message import ( FirstMessageNotExistsError, @@ -820,14 +821,14 @@ class TestMessageServiceFeedback: app_model=app, message_id="msg-123", user=user, - rating="like", + rating=FeedbackRating.LIKE, content="Good answer", ) # Assert - assert result.rating == "like" + assert result.rating == FeedbackRating.LIKE assert result.content == "Good answer" - assert result.from_source == "user" + assert result.from_source == FeedbackFromSource.USER mock_db.session.add.assert_called_once() mock_db.session.commit.assert_called_once() @@ -852,13 +853,13 @@ class TestMessageServiceFeedback: app_model=app, message_id="msg-123", user=user, - rating="dislike", + rating=FeedbackRating.DISLIKE, content="Bad answer", ) # Assert assert result == feedback - assert feedback.rating == "dislike" + assert feedback.rating == FeedbackRating.DISLIKE assert feedback.content == "Bad answer" mock_db.session.commit.assert_called_once()