mirror of https://github.com/langgenius/dify.git
merge: resolve conflicts with upstream main
This commit is contained in:
commit
88b6a15170
|
|
@ -10,7 +10,7 @@ import sqlalchemy as sa
|
|||
from flask import request, send_file
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import asc, desc, select
|
||||
from sqlalchemy import asc, desc, func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
|
|
@ -211,12 +211,11 @@ class GetProcessRuleApi(Resource):
|
|||
raise Forbidden(str(e))
|
||||
|
||||
# get the latest process rule
|
||||
dataset_process_rule = (
|
||||
db.session.query(DatasetProcessRule)
|
||||
dataset_process_rule = db.session.scalar(
|
||||
select(DatasetProcessRule)
|
||||
.where(DatasetProcessRule.dataset_id == document.dataset_id)
|
||||
.order_by(DatasetProcessRule.created_at.desc())
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
if dataset_process_rule:
|
||||
mode = dataset_process_rule.mode
|
||||
|
|
@ -330,21 +329,23 @@ class DatasetDocumentListApi(Resource):
|
|||
if fetch:
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
|
|
@ -521,10 +522,10 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
|||
if data_source_info and "upload_file_id" in data_source_info:
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
|
||||
file = (
|
||||
db.session.query(UploadFile)
|
||||
file = db.session.scalar(
|
||||
select(UploadFile)
|
||||
.where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# raise error if file not found
|
||||
|
|
@ -586,10 +587,10 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
if not data_source_info:
|
||||
continue
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
file_detail = db.session.scalar(
|
||||
select(UploadFile)
|
||||
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if file_detail is None:
|
||||
|
|
@ -672,20 +673,23 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
|||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
document_dict = {
|
||||
|
|
@ -723,18 +727,23 @@ class DocumentIndexingStatusApi(DocumentResource):
|
|||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
completed_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT)
|
||||
.count()
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
|
|
@ -1258,11 +1267,11 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
|
|||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
log = (
|
||||
db.session.query(DocumentPipelineExecutionLog)
|
||||
.filter_by(document_id=document_id)
|
||||
log = db.session.scalar(
|
||||
select(DocumentPipelineExecutionLog)
|
||||
.where(DocumentPipelineExecutionLog.document_id == document_id)
|
||||
.order_by(DocumentPipelineExecutionLog.created_at.desc())
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not log:
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ def _get_segment_with_summary(segment, dataset_id):
|
|||
"""Helper function to marshal segment and add summary information."""
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
segment_dict = dict(marshal(segment, segment_fields))
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
|
||||
# Query summary for this segment (only enabled summaries)
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
|
||||
segment_dict["summary"] = summary.summary_content if summary else None
|
||||
|
|
@ -206,7 +206,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||
# Add summary to each segment
|
||||
segments_with_summary = []
|
||||
for segment in segments.items:
|
||||
segment_dict = dict(marshal(segment, segment_fields))
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
|
||||
segment_dict["summary"] = summaries.get(segment.id)
|
||||
segments_with_summary.append(segment_dict)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ from collections.abc import Callable
|
|||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.console.datasets.error import PipelineNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
|
|
@ -24,10 +26,8 @@ def get_rag_pipeline(view_func: Callable[P, R]):
|
|||
|
||||
del kwargs["pipeline_id"]
|
||||
|
||||
pipeline = (
|
||||
db.session.query(Pipeline)
|
||||
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
pipeline = db.session.scalar(
|
||||
select(Pipeline).where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id).limit(1)
|
||||
)
|
||||
|
||||
if not pipeline:
|
||||
|
|
|
|||
|
|
@ -153,15 +153,15 @@ class DatasetListApi(DatasetApiResource):
|
|||
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
for item in data:
|
||||
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
|
||||
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: # type: ignore
|
||||
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) # type: ignore
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # type: ignore
|
||||
if item_model in model_names:
|
||||
item["embedding_available"] = True
|
||||
item["embedding_available"] = True # type: ignore
|
||||
else:
|
||||
item["embedding_available"] = False
|
||||
item["embedding_available"] = False # type: ignore
|
||||
else:
|
||||
item["embedding_available"] = True
|
||||
item["embedding_available"] = True # type: ignore
|
||||
response = {
|
||||
"data": data,
|
||||
"has_more": len(datasets) == query.limit,
|
||||
|
|
|
|||
|
|
@ -67,7 +67,8 @@ class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
|
|||
if field_name == "inputs":
|
||||
data = {
|
||||
"messages": [
|
||||
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) for msg in v
|
||||
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) # type: ignore
|
||||
for msg in v
|
||||
]
|
||||
if isinstance(v, list)
|
||||
else v,
|
||||
|
|
|
|||
|
|
@ -174,7 +174,7 @@ dev = [
|
|||
"sseclient-py>=1.8.0",
|
||||
"pytest-timeout>=2.4.0",
|
||||
"pytest-xdist>=3.8.0",
|
||||
"pyrefly>=0.55.0",
|
||||
"pyrefly>=0.57.1",
|
||||
]
|
||||
|
||||
############################################################
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
"""Testcontainers integration tests for email register controller endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.email_register import (
|
||||
EmailRegisterCheckApi,
|
||||
|
|
@ -13,14 +16,11 @@ from services.account_service import AccountService
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
def app(flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
|
||||
class TestEmailRegisterSendEmailApi:
|
||||
@patch("controllers.console.auth.email_register.sessionmaker")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.email_register.AccountService.send_email_register_email")
|
||||
@patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze")
|
||||
|
|
@ -33,20 +33,15 @@ class TestEmailRegisterSendEmailApi:
|
|||
mock_is_freeze,
|
||||
mock_send_mail,
|
||||
mock_get_account,
|
||||
mock_session_cls,
|
||||
app,
|
||||
):
|
||||
mock_send_mail.return_value = "token-123"
|
||||
mock_is_freeze.return_value = False
|
||||
mock_account = MagicMock()
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_get_account.return_value = mock_account
|
||||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
|
||||
patch("controllers.console.auth.email_register.dify_config", SimpleNamespace(BILLING_ENABLED=True)),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
|
||||
|
|
@ -61,7 +56,6 @@ class TestEmailRegisterSendEmailApi:
|
|||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_is_freeze.assert_called_once_with("invitee@example.com")
|
||||
mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US")
|
||||
mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
|
||||
|
|
@ -89,7 +83,6 @@ class TestEmailRegisterCheckApi:
|
|||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
|
||||
):
|
||||
|
|
@ -114,7 +107,6 @@ class TestEmailRegisterResetApi:
|
|||
@patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.email_register.AccountService.login")
|
||||
@patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
|
||||
@patch("controllers.console.auth.email_register.sessionmaker")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
|
||||
|
|
@ -125,7 +117,6 @@ class TestEmailRegisterResetApi:
|
|||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_get_account,
|
||||
mock_session_cls,
|
||||
mock_create_account,
|
||||
mock_login,
|
||||
mock_reset_login_rate,
|
||||
|
|
@ -136,14 +127,10 @@ class TestEmailRegisterResetApi:
|
|||
token_pair = MagicMock()
|
||||
token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
|
||||
mock_login.return_value = token_pair
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_get_account.return_value = None
|
||||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
|
||||
):
|
||||
|
|
@ -159,19 +146,19 @@ class TestEmailRegisterResetApi:
|
|||
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
|
||||
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
|
||||
def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
|
||||
"""Test that case fallback tries lowercase when exact match fails."""
|
||||
mock_session = MagicMock()
|
||||
first_query = MagicMock()
|
||||
first_query.scalar_one_or_none.return_value = None
|
||||
first_result = MagicMock()
|
||||
first_result.scalar_one_or_none.return_value = None
|
||||
expected_account = MagicMock()
|
||||
second_query = MagicMock()
|
||||
second_query.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_query, second_query]
|
||||
second_result = MagicMock()
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
|
||||
|
||||
assert account is expected_account
|
||||
assert result is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
|
@ -1,8 +1,11 @@
|
|||
"""Testcontainers integration tests for forgot password controller endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.forgot_password import (
|
||||
ForgotPasswordCheckApi,
|
||||
|
|
@ -13,14 +16,11 @@ from services.account_service import AccountService
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
def app(flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
|
||||
class TestForgotPasswordSendEmailApi:
|
||||
@patch("controllers.console.auth.forgot_password.sessionmaker")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
|
|
@ -31,19 +31,15 @@ class TestForgotPasswordSendEmailApi:
|
|||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_account,
|
||||
mock_session_cls,
|
||||
app,
|
||||
):
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_email.return_value = "token-123"
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
controller_features = SimpleNamespace(is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
|
||||
patch(
|
||||
"controllers.console.auth.forgot_password.FeatureService.get_system_features",
|
||||
return_value=controller_features,
|
||||
|
|
@ -59,7 +55,6 @@ class TestForgotPasswordSendEmailApi:
|
|||
response = ForgotPasswordSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
||||
mock_send_email.assert_called_once_with(
|
||||
account=mock_account,
|
||||
email="user@example.com",
|
||||
|
|
@ -117,7 +112,6 @@ class TestForgotPasswordCheckApi:
|
|||
|
||||
class TestForgotPasswordResetApi:
|
||||
@patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
||||
@patch("controllers.console.auth.forgot_password.sessionmaker")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
|
|
@ -126,7 +120,6 @@ class TestForgotPasswordResetApi:
|
|||
mock_get_reset_data,
|
||||
mock_revoke_token,
|
||||
mock_get_account,
|
||||
mock_session_cls,
|
||||
mock_update_account,
|
||||
app,
|
||||
):
|
||||
|
|
@ -134,12 +127,8 @@ class TestForgotPasswordResetApi:
|
|||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||
with (
|
||||
patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
):
|
||||
|
|
@ -157,20 +146,22 @@ class TestForgotPasswordResetApi:
|
|||
assert response == {"result": "success"}
|
||||
mock_get_reset_data.assert_called_once_with("token-123")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
||||
mock_update_account.assert_called_once()
|
||||
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
|
||||
def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
|
||||
"""Test that case fallback tries lowercase when exact match fails."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_session = MagicMock()
|
||||
first_query = MagicMock()
|
||||
first_query.scalar_one_or_none.return_value = None
|
||||
first_result = MagicMock()
|
||||
first_result.scalar_one_or_none.return_value = None
|
||||
expected_account = MagicMock()
|
||||
second_query = MagicMock()
|
||||
second_query.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_query, second_query]
|
||||
second_result = MagicMock()
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
|
||||
|
||||
assert account is expected_account
|
||||
assert result is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
|
@ -1,7 +1,10 @@
|
|||
"""Testcontainers integration tests for OAuth controller endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.oauth import (
|
||||
OAuthCallback,
|
||||
|
|
@ -18,10 +21,8 @@ from services.errors.account import AccountRegisterError
|
|||
|
||||
class TestGetOAuthProviders:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("github_config", "google_config", "expected_github", "expected_google"),
|
||||
|
|
@ -64,10 +65,8 @@ class TestOAuthLogin:
|
|||
return OAuthLogin()
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider(self):
|
||||
|
|
@ -131,10 +130,8 @@ class TestOAuthCallback:
|
|||
return OAuthCallback()
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.fixture
|
||||
def oauth_setup(self):
|
||||
|
|
@ -190,15 +187,8 @@ class TestOAuthCallback:
|
|||
(KeyError("Missing key"), "OAuth process failed"),
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
def test_should_handle_oauth_exceptions(
|
||||
self, mock_get_providers, mock_db, resource, app, exception, expected_error
|
||||
):
|
||||
# Mock database session
|
||||
mock_db.session = MagicMock()
|
||||
mock_db.session.rollback = MagicMock()
|
||||
|
||||
def test_should_handle_oauth_exceptions(self, mock_get_providers, resource, app, exception, expected_error):
|
||||
# Import the real requests module to create a proper exception
|
||||
import httpx
|
||||
|
||||
|
|
@ -258,7 +248,6 @@ class TestOAuthCallback:
|
|||
)
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth._generate_account")
|
||||
|
|
@ -269,7 +258,6 @@ class TestOAuthCallback:
|
|||
mock_generate_account,
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
mock_db,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
resource,
|
||||
|
|
@ -278,10 +266,6 @@ class TestOAuthCallback:
|
|||
account_status,
|
||||
expected_redirect,
|
||||
):
|
||||
# Mock database session
|
||||
mock_db.session = MagicMock()
|
||||
mock_db.session.rollback = MagicMock()
|
||||
mock_db.session.commit = MagicMock()
|
||||
|
||||
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
||||
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||
|
|
@ -306,14 +290,12 @@ class TestOAuthCallback:
|
|||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth._generate_account")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
def test_should_activate_pending_account(
|
||||
self,
|
||||
mock_account_service,
|
||||
mock_tenant_service,
|
||||
mock_db,
|
||||
mock_generate_account,
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
|
|
@ -338,12 +320,10 @@ class TestOAuthCallback:
|
|||
|
||||
assert mock_account.status == AccountStatus.ACTIVE
|
||||
assert mock_account.initialized_at is not None
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth._generate_account")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.redirect")
|
||||
|
|
@ -352,7 +332,6 @@ class TestOAuthCallback:
|
|||
mock_redirect,
|
||||
mock_account_service,
|
||||
mock_tenant_service,
|
||||
mock_db,
|
||||
mock_generate_account,
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
|
|
@ -414,6 +393,10 @@ class TestOAuthCallback:
|
|||
|
||||
|
||||
class TestAccountGeneration:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.fixture
|
||||
def user_info(self):
|
||||
return OAuthUserInfo(id="123", name="Test User", email="test@example.com")
|
||||
|
|
@ -425,15 +408,10 @@ class TestAccountGeneration:
|
|||
return account
|
||||
|
||||
@patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.oauth.sessionmaker")
|
||||
@patch("controllers.console.auth.oauth.Account")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
def test_should_get_account_by_openid_or_email(
|
||||
self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account
|
||||
self, mock_account_model, mock_get_account, flask_req_ctx_with_containers, user_info, mock_account
|
||||
):
|
||||
# Mock db.engine for sessionmaker creation
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Test OpenID found
|
||||
mock_account_model.get_by_openid.return_value = mock_account
|
||||
result = _get_account_by_openid_or_email("github", user_info)
|
||||
|
|
@ -443,15 +421,14 @@ class TestAccountGeneration:
|
|||
|
||||
# Test fallback to email lookup
|
||||
mock_account_model.get_by_openid.return_value = None
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session.return_value.begin.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_get_account.return_value = mock_account
|
||||
|
||||
result = _get_account_by_openid_or_email("github", user_info)
|
||||
assert result == mock_account
|
||||
mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance)
|
||||
mock_get_account.assert_called_once()
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self):
|
||||
def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(self):
|
||||
"""Test that case fallback tries lowercase when exact match fails."""
|
||||
mock_session = MagicMock()
|
||||
first_result = MagicMock()
|
||||
first_result.scalar_one_or_none.return_value = None
|
||||
|
|
@ -462,7 +439,7 @@ class TestAccountGeneration:
|
|||
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
|
||||
|
||||
assert result == expected_account
|
||||
assert result is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
@ -478,10 +455,8 @@ class TestAccountGeneration:
|
|||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
def test_should_handle_account_generation_scenarios(
|
||||
self,
|
||||
mock_db,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_register_service,
|
||||
|
|
@ -519,10 +494,8 @@ class TestAccountGeneration:
|
|||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
def test_should_register_with_lowercase_email(
|
||||
self,
|
||||
mock_db,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_register_service,
|
||||
|
|
@ -141,3 +141,73 @@ class TestArchivedWorkflowRunDeletion:
|
|||
db_session_with_containers.expunge_all()
|
||||
deleted_run = db_session_with_containers.get(WorkflowRun, run_id)
|
||||
assert deleted_run is None
|
||||
|
||||
def test_delete_run_dry_run(self, db_session_with_containers):
|
||||
"""Dry run should return success without actually deleting."""
|
||||
tenant_id = str(uuid4())
|
||||
run = self._create_workflow_run(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant_id,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
run_id = run.id
|
||||
deleter = ArchivedWorkflowRunDeletion(dry_run=True)
|
||||
|
||||
result = deleter._delete_run(run)
|
||||
|
||||
assert result.success is True
|
||||
assert result.run_id == run_id
|
||||
# Run should still exist because it's a dry run
|
||||
db_session_with_containers.expire_all()
|
||||
assert db_session_with_containers.get(WorkflowRun, run_id) is not None
|
||||
|
||||
def test_delete_run_exception_returns_error(self, db_session_with_containers):
|
||||
"""Exception during deletion should return failure result."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
tenant_id = str(uuid4())
|
||||
run = self._create_workflow_run(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant_id,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
deleter = ArchivedWorkflowRunDeletion(dry_run=False)
|
||||
|
||||
with patch.object(deleter, "_get_workflow_run_repo") as mock_get_repo:
|
||||
mock_repo = MagicMock()
|
||||
mock_get_repo.return_value = mock_repo
|
||||
mock_repo.delete_runs_with_related.side_effect = Exception("Database error")
|
||||
|
||||
result = deleter._delete_run(run)
|
||||
|
||||
assert result.success is False
|
||||
assert result.error == "Database error"
|
||||
|
||||
def test_delete_by_run_id_success(self, db_session_with_containers):
|
||||
"""Successfully delete an archived workflow run by ID."""
|
||||
tenant_id = str(uuid4())
|
||||
base_time = datetime.now(UTC)
|
||||
run = self._create_workflow_run(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant_id,
|
||||
created_at=base_time,
|
||||
)
|
||||
self._create_archive_log(db_session_with_containers, run=run)
|
||||
run_id = run.id
|
||||
|
||||
deleter = ArchivedWorkflowRunDeletion()
|
||||
result = deleter.delete_by_run_id(run_id)
|
||||
|
||||
assert result.success is True
|
||||
db_session_with_containers.expunge_all()
|
||||
assert db_session_with_containers.get(WorkflowRun, run_id) is None
|
||||
|
||||
def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers):
|
||||
"""_get_workflow_run_repo should return a cached repo on subsequent calls."""
|
||||
deleter = ArchivedWorkflowRunDeletion()
|
||||
|
||||
repo1 = deleter._get_workflow_run_repo()
|
||||
repo2 = deleter._get_workflow_run_repo()
|
||||
|
||||
assert repo1 is repo2
|
||||
assert deleter.workflow_run_repo is repo1
|
||||
|
|
|
|||
|
|
@ -140,8 +140,8 @@ class TestDatasetDocumentListApi:
|
|||
return_value=pagination,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(count=count_mock)),
|
||||
"controllers.console.datasets.datasets_document.db.session.scalar",
|
||||
return_value=2,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status",
|
||||
|
|
@ -700,10 +700,8 @@ class TestDocumentPipelineExecutionLogApi:
|
|||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.db.session.query",
|
||||
return_value=MagicMock(
|
||||
filter_by=lambda **k: MagicMock(order_by=lambda *a: MagicMock(first=lambda: log))
|
||||
),
|
||||
"controllers.console.datasets.datasets_document.db.session.scalar",
|
||||
return_value=log,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1")
|
||||
|
|
@ -827,15 +825,12 @@ class TestDocumentIndexingEstimateApi:
|
|||
dataset_process_rule=None,
|
||||
)
|
||||
|
||||
query_mock = MagicMock()
|
||||
query_mock.where.return_value.first.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(api, "get_document", return_value=document),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.db.session.query",
|
||||
return_value=query_mock,
|
||||
"controllers.console.datasets.datasets_document.db.session.scalar",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
|
|
@ -863,10 +858,8 @@ class TestDocumentIndexingEstimateApi:
|
|||
app.test_request_context("/"),
|
||||
patch.object(api, "get_document", return_value=document),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.db.session.query",
|
||||
return_value=MagicMock(
|
||||
where=MagicMock(return_value=MagicMock(first=MagicMock(return_value=upload_file)))
|
||||
),
|
||||
"controllers.console.datasets.datasets_document.db.session.scalar",
|
||||
return_value=upload_file,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.ExtractSetting",
|
||||
|
|
@ -1239,12 +1232,8 @@ class TestDocumentPermissionCases:
|
|||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.db.session.query",
|
||||
return_value=MagicMock(
|
||||
where=lambda *a: MagicMock(
|
||||
order_by=lambda *b: MagicMock(limit=lambda n: MagicMock(one_or_none=lambda: process_rule))
|
||||
)
|
||||
),
|
||||
"controllers.console.datasets.datasets_document.db.session.scalar",
|
||||
return_value=process_rule,
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
|
@ -1364,8 +1353,8 @@ class TestDocumentIndexingEdgeCases:
|
|||
app.test_request_context("/"),
|
||||
patch.object(api, "get_document", return_value=document),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a: MagicMock(first=lambda: upload_file)),
|
||||
"controllers.console.datasets.datasets_document.db.session.scalar",
|
||||
return_value=upload_file,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.ExtractSetting",
|
||||
|
|
|
|||
|
|
@ -26,12 +26,9 @@ class TestGetRagPipeline:
|
|||
return_value=(Mock(), "tenant-1"),
|
||||
)
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.first.return_value = None
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.query",
|
||||
return_value=mock_query,
|
||||
"controllers.console.datasets.wraps.db.session.scalar",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
with pytest.raises(PipelineNotFoundError):
|
||||
|
|
@ -51,12 +48,9 @@ class TestGetRagPipeline:
|
|||
return_value=(Mock(), "tenant-1"),
|
||||
)
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.first.return_value = pipeline
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.query",
|
||||
return_value=mock_query,
|
||||
"controllers.console.datasets.wraps.db.session.scalar",
|
||||
return_value=pipeline,
|
||||
)
|
||||
|
||||
result = dummy_view(pipeline_id="pipeline-1")
|
||||
|
|
@ -76,12 +70,9 @@ class TestGetRagPipeline:
|
|||
return_value=(Mock(), "tenant-1"),
|
||||
)
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.first.return_value = pipeline
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.query",
|
||||
return_value=mock_query,
|
||||
"controllers.console.datasets.wraps.db.session.scalar",
|
||||
return_value=pipeline,
|
||||
)
|
||||
|
||||
result = dummy_view(pipeline_id="pipeline-1")
|
||||
|
|
@ -100,18 +91,15 @@ class TestGetRagPipeline:
|
|||
return_value=(Mock(), "tenant-1"),
|
||||
)
|
||||
|
||||
def where_side_effect(*args, **kwargs):
|
||||
assert args[0].right.value == "123"
|
||||
return Mock(first=lambda: pipeline)
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.where.side_effect = where_side_effect
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.query",
|
||||
return_value=mock_query,
|
||||
mock_scalar = mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.scalar",
|
||||
return_value=pipeline,
|
||||
)
|
||||
|
||||
result = dummy_view(pipeline_id=123)
|
||||
|
||||
assert result is pipeline
|
||||
# Verify the pipeline_id was cast to string in the where clause
|
||||
stmt = mock_scalar.call_args[0][0]
|
||||
where_clauses = stmt.whereclause.clauses
|
||||
assert where_clauses[0].right.value == "123"
|
||||
|
|
|
|||
|
|
@ -1,216 +0,0 @@
|
|||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.workflow import WorkflowRun
|
||||
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion, DeleteResult
|
||||
|
||||
|
||||
class TestArchivedWorkflowRunDeletion:
|
||||
@pytest.fixture
|
||||
def mock_db(self):
|
||||
with patch("services.retention.workflow_run.delete_archived_workflow_run.db") as mock_db:
|
||||
mock_db.engine = MagicMock()
|
||||
yield mock_db
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sessionmaker(self):
|
||||
with patch("services.retention.workflow_run.delete_archived_workflow_run.sessionmaker") as mock_sm:
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_sm.return_value.return_value.__enter__.return_value = mock_session
|
||||
yield mock_sm, mock_session
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflow_run_repo(self):
|
||||
with patch(
|
||||
"services.retention.workflow_run.delete_archived_workflow_run.APIWorkflowRunRepository"
|
||||
) as mock_repo_cls:
|
||||
mock_repo = MagicMock()
|
||||
yield mock_repo
|
||||
|
||||
def test_delete_by_run_id_success(self, mock_db, mock_sessionmaker):
|
||||
mock_sm, mock_session = mock_sessionmaker
|
||||
run_id = "run-123"
|
||||
tenant_id = "tenant-456"
|
||||
|
||||
mock_run = MagicMock(spec=WorkflowRun)
|
||||
mock_run.id = run_id
|
||||
mock_run.tenant_id = tenant_id
|
||||
mock_session.get.return_value = mock_run
|
||||
|
||||
deletion = ArchivedWorkflowRunDeletion()
|
||||
|
||||
with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo:
|
||||
mock_repo = MagicMock()
|
||||
mock_get_repo.return_value = mock_repo
|
||||
mock_repo.get_archived_run_ids.return_value = [run_id]
|
||||
|
||||
with patch.object(deletion, "_delete_run") as mock_delete_run:
|
||||
expected_result = DeleteResult(run_id=run_id, tenant_id=tenant_id, success=True)
|
||||
mock_delete_run.return_value = expected_result
|
||||
|
||||
result = deletion.delete_by_run_id(run_id)
|
||||
|
||||
assert result == expected_result
|
||||
mock_session.get.assert_called_once_with(WorkflowRun, run_id)
|
||||
mock_repo.get_archived_run_ids.assert_called_once()
|
||||
mock_delete_run.assert_called_once_with(mock_run)
|
||||
|
||||
def test_delete_by_run_id_not_found(self, mock_db, mock_sessionmaker):
|
||||
mock_sm, mock_session = mock_sessionmaker
|
||||
run_id = "run-123"
|
||||
mock_session.get.return_value = None
|
||||
|
||||
deletion = ArchivedWorkflowRunDeletion()
|
||||
with patch.object(deletion, "_get_workflow_run_repo"):
|
||||
result = deletion.delete_by_run_id(run_id)
|
||||
|
||||
assert result.success is False
|
||||
assert "not found" in result.error
|
||||
assert result.run_id == run_id
|
||||
|
||||
def test_delete_by_run_id_not_archived(self, mock_db, mock_sessionmaker):
|
||||
mock_sm, mock_session = mock_sessionmaker
|
||||
run_id = "run-123"
|
||||
|
||||
mock_run = MagicMock(spec=WorkflowRun)
|
||||
mock_run.id = run_id
|
||||
mock_session.get.return_value = mock_run
|
||||
|
||||
deletion = ArchivedWorkflowRunDeletion()
|
||||
with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo:
|
||||
mock_repo = MagicMock()
|
||||
mock_get_repo.return_value = mock_repo
|
||||
mock_repo.get_archived_run_ids.return_value = []
|
||||
|
||||
result = deletion.delete_by_run_id(run_id)
|
||||
|
||||
assert result.success is False
|
||||
assert "is not archived" in result.error
|
||||
|
||||
def test_delete_batch(self, mock_db, mock_sessionmaker):
|
||||
mock_sm, mock_session = mock_sessionmaker
|
||||
deletion = ArchivedWorkflowRunDeletion()
|
||||
|
||||
mock_run1 = MagicMock(spec=WorkflowRun)
|
||||
mock_run1.id = "run-1"
|
||||
mock_run2 = MagicMock(spec=WorkflowRun)
|
||||
mock_run2.id = "run-2"
|
||||
|
||||
with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo:
|
||||
mock_repo = MagicMock()
|
||||
mock_get_repo.return_value = mock_repo
|
||||
mock_repo.get_archived_runs_by_time_range.return_value = [mock_run1, mock_run2]
|
||||
|
||||
with patch.object(deletion, "_delete_run") as mock_delete_run:
|
||||
mock_delete_run.side_effect = [
|
||||
DeleteResult(run_id="run-1", tenant_id="t1", success=True),
|
||||
DeleteResult(run_id="run-2", tenant_id="t1", success=True),
|
||||
]
|
||||
|
||||
results = deletion.delete_batch(tenant_ids=["t1"], start_date=datetime.now(), end_date=datetime.now())
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0].run_id == "run-1"
|
||||
assert results[1].run_id == "run-2"
|
||||
assert mock_delete_run.call_count == 2
|
||||
|
||||
def test_delete_run_dry_run(self):
|
||||
deletion = ArchivedWorkflowRunDeletion(dry_run=True)
|
||||
mock_run = MagicMock(spec=WorkflowRun)
|
||||
mock_run.id = "run-123"
|
||||
mock_run.tenant_id = "tenant-456"
|
||||
|
||||
result = deletion._delete_run(mock_run)
|
||||
|
||||
assert result.success is True
|
||||
assert result.run_id == "run-123"
|
||||
|
||||
def test_delete_run_success(self):
|
||||
deletion = ArchivedWorkflowRunDeletion(dry_run=False)
|
||||
mock_run = MagicMock(spec=WorkflowRun)
|
||||
mock_run.id = "run-123"
|
||||
mock_run.tenant_id = "tenant-456"
|
||||
|
||||
with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo:
|
||||
mock_repo = MagicMock()
|
||||
mock_get_repo.return_value = mock_repo
|
||||
mock_repo.delete_runs_with_related.return_value = {"workflow_runs": 1}
|
||||
|
||||
result = deletion._delete_run(mock_run)
|
||||
|
||||
assert result.success is True
|
||||
assert result.deleted_counts == {"workflow_runs": 1}
|
||||
|
||||
def test_delete_run_exception(self):
|
||||
deletion = ArchivedWorkflowRunDeletion(dry_run=False)
|
||||
mock_run = MagicMock(spec=WorkflowRun)
|
||||
mock_run.id = "run-123"
|
||||
|
||||
with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo:
|
||||
mock_repo = MagicMock()
|
||||
mock_get_repo.return_value = mock_repo
|
||||
mock_repo.delete_runs_with_related.side_effect = Exception("Database error")
|
||||
|
||||
result = deletion._delete_run(mock_run)
|
||||
|
||||
assert result.success is False
|
||||
assert result.error == "Database error"
|
||||
|
||||
def test_delete_trigger_logs(self):
|
||||
mock_session = MagicMock(spec=Session)
|
||||
run_ids = ["run-1", "run-2"]
|
||||
|
||||
with patch(
|
||||
"services.retention.workflow_run.delete_archived_workflow_run.SQLAlchemyWorkflowTriggerLogRepository"
|
||||
) as mock_repo_cls:
|
||||
mock_repo = MagicMock()
|
||||
mock_repo_cls.return_value = mock_repo
|
||||
mock_repo.delete_by_run_ids.return_value = 5
|
||||
|
||||
count = ArchivedWorkflowRunDeletion._delete_trigger_logs(mock_session, run_ids)
|
||||
|
||||
assert count == 5
|
||||
mock_repo_cls.assert_called_once_with(mock_session)
|
||||
mock_repo.delete_by_run_ids.assert_called_once_with(run_ids)
|
||||
|
||||
def test_delete_node_executions(self):
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_run = MagicMock(spec=WorkflowRun)
|
||||
mock_run.id = "run-1"
|
||||
runs = [mock_run]
|
||||
|
||||
with patch(
|
||||
"repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository"
|
||||
) as mock_create_repo:
|
||||
mock_repo = MagicMock()
|
||||
mock_create_repo.return_value = mock_repo
|
||||
mock_repo.delete_by_runs.return_value = (1, 2)
|
||||
|
||||
with patch("services.retention.workflow_run.delete_archived_workflow_run.sessionmaker") as mock_sm:
|
||||
result = ArchivedWorkflowRunDeletion._delete_node_executions(mock_session, runs)
|
||||
|
||||
assert result == (1, 2)
|
||||
mock_create_repo.assert_called_once()
|
||||
mock_repo.delete_by_runs.assert_called_once_with(mock_session, ["run-1"])
|
||||
|
||||
def test_get_workflow_run_repo(self, mock_db):
|
||||
deletion = ArchivedWorkflowRunDeletion()
|
||||
|
||||
with patch(
|
||||
"repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_run_repository"
|
||||
) as mock_create_repo:
|
||||
mock_repo = MagicMock()
|
||||
mock_create_repo.return_value = mock_repo
|
||||
|
||||
# First call
|
||||
repo1 = deletion._get_workflow_run_repo()
|
||||
assert repo1 == mock_repo
|
||||
assert deletion.workflow_run_repo == mock_repo
|
||||
|
||||
# Second call (should return cached)
|
||||
repo2 = deletion._get_workflow_run_repo()
|
||||
assert repo2 == mock_repo
|
||||
mock_create_repo.assert_called_once()
|
||||
|
|
@ -1,346 +0,0 @@
|
|||
"""
|
||||
Unit tests for services.agent_service
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytz
|
||||
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from models import Account
|
||||
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
|
||||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
def _make_current_user_account(timezone: str = "UTC") -> Account:
|
||||
account = Account(name="Test User", email="test@example.com")
|
||||
account.timezone = timezone
|
||||
return account
|
||||
|
||||
|
||||
def _make_app_model(app_model_config: MagicMock | None) -> MagicMock:
|
||||
app_model = MagicMock(spec=App)
|
||||
app_model.id = "app-123"
|
||||
app_model.tenant_id = "tenant-123"
|
||||
app_model.app_model_config = app_model_config
|
||||
return app_model
|
||||
|
||||
|
||||
def _make_conversation(from_end_user_id: str | None, from_account_id: str | None) -> MagicMock:
|
||||
conversation = MagicMock(spec=Conversation)
|
||||
conversation.id = "conv-123"
|
||||
conversation.app_id = "app-123"
|
||||
conversation.from_end_user_id = from_end_user_id
|
||||
conversation.from_account_id = from_account_id
|
||||
return conversation
|
||||
|
||||
|
||||
def _make_message(agent_thoughts: list[MessageAgentThought]) -> MagicMock:
|
||||
message = MagicMock(spec=Message)
|
||||
message.id = "msg-123"
|
||||
message.conversation_id = "conv-123"
|
||||
message.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC)
|
||||
message.provider_response_latency = 1.23
|
||||
message.answer_tokens = 4
|
||||
message.message_tokens = 6
|
||||
message.agent_thoughts = agent_thoughts
|
||||
message.message_files = ["file-a.txt"]
|
||||
return message
|
||||
|
||||
|
||||
def _make_agent_thought() -> MagicMock:
|
||||
agent_thought = MagicMock(spec=MessageAgentThought)
|
||||
agent_thought.tokens = 3
|
||||
agent_thought.tool_input = "raw-input"
|
||||
agent_thought.observation = "raw-output"
|
||||
agent_thought.thought = "thinking"
|
||||
agent_thought.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC)
|
||||
agent_thought.files = []
|
||||
agent_thought.tools = ["tool_a", "dataset_tool"]
|
||||
agent_thought.tool_labels = {"tool_a": "Tool A"}
|
||||
agent_thought.tool_meta = {
|
||||
"tool_a": {
|
||||
"tool_config": {
|
||||
"tool_provider_type": "custom",
|
||||
"tool_provider": "provider-1",
|
||||
},
|
||||
"tool_parameters": {"param": "value"},
|
||||
"time_cost": 2.5,
|
||||
},
|
||||
"dataset_tool": {
|
||||
"tool_config": {
|
||||
"tool_provider_type": "dataset-retrieval",
|
||||
"tool_provider": "dataset-provider",
|
||||
}
|
||||
},
|
||||
}
|
||||
agent_thought.tool_inputs_dict = {"tool_a": {"q": "hello"}, "dataset_tool": {"k": "v"}}
|
||||
agent_thought.tool_outputs_dict = {"tool_a": {"result": "ok"}}
|
||||
return agent_thought
|
||||
|
||||
|
||||
def _build_query_side_effect(
|
||||
conversation: Conversation | None,
|
||||
message: Message | None,
|
||||
executor: EndUser | Account | None,
|
||||
) -> Callable[..., MagicMock]:
|
||||
def _query_side_effect(*args: object, **kwargs: object) -> MagicMock:
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
if any(arg is Conversation for arg in args):
|
||||
query.first.return_value = conversation
|
||||
elif any(arg is Message for arg in args):
|
||||
query.first.return_value = message
|
||||
elif any(arg is EndUser for arg in args) or any(arg is Account for arg in args):
|
||||
query.first.return_value = executor
|
||||
return query
|
||||
|
||||
return _query_side_effect
|
||||
|
||||
|
||||
class TestAgentServiceGetAgentLogs:
|
||||
"""Test suite for AgentService.get_agent_logs."""
|
||||
|
||||
def test_get_agent_logs_should_raise_when_conversation_missing(self) -> None:
|
||||
"""Test missing conversation raises ValueError."""
|
||||
# Arrange
|
||||
app_model = _make_app_model(MagicMock())
|
||||
with patch("services.agent_service.db") as mock_db:
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = None
|
||||
mock_db.session.query.return_value = query
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
AgentService.get_agent_logs(app_model, "missing-conv", "msg-1")
|
||||
|
||||
def test_get_agent_logs_should_raise_when_message_missing(self) -> None:
|
||||
"""Test missing message raises ValueError."""
|
||||
# Arrange
|
||||
app_model = _make_app_model(MagicMock())
|
||||
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None)
|
||||
with patch("services.agent_service.db") as mock_db:
|
||||
conversation_query = MagicMock()
|
||||
conversation_query.where.return_value = conversation_query
|
||||
conversation_query.first.return_value = conversation
|
||||
|
||||
message_query = MagicMock()
|
||||
message_query.where.return_value = message_query
|
||||
message_query.first.return_value = None
|
||||
|
||||
mock_db.session.query.side_effect = [conversation_query, message_query]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
AgentService.get_agent_logs(app_model, conversation.id, "missing-msg")
|
||||
|
||||
def test_get_agent_logs_should_raise_when_app_model_config_missing(self) -> None:
|
||||
"""Test missing app model config raises ValueError."""
|
||||
# Arrange
|
||||
app_model = _make_app_model(None)
|
||||
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None)
|
||||
message = _make_message([])
|
||||
current_user = _make_current_user_account()
|
||||
|
||||
with patch("services.agent_service.db") as mock_db, patch("services.agent_service.current_user", current_user):
|
||||
mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock())
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
AgentService.get_agent_logs(app_model, conversation.id, message.id)
|
||||
|
||||
def test_get_agent_logs_should_raise_when_agent_config_missing(self) -> None:
|
||||
"""Test missing agent config raises ValueError."""
|
||||
# Arrange
|
||||
app_model_config = MagicMock()
|
||||
app_model_config.agent_mode_dict = {"strategy": "react"}
|
||||
app_model_config.to_dict.return_value = {"tools": []}
|
||||
app_model = _make_app_model(app_model_config)
|
||||
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None)
|
||||
message = _make_message([])
|
||||
current_user = _make_current_user_account()
|
||||
|
||||
with (
|
||||
patch("services.agent_service.db") as mock_db,
|
||||
patch("services.agent_service.AgentConfigManager.convert", return_value=None),
|
||||
patch("services.agent_service.current_user", current_user),
|
||||
):
|
||||
mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock())
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
AgentService.get_agent_logs(app_model, conversation.id, message.id)
|
||||
|
||||
def test_get_agent_logs_should_return_logs_for_end_user_executor(self) -> None:
|
||||
"""Test agent logs returned for end-user executor with tool icons."""
|
||||
# Arrange
|
||||
agent_thought = _make_agent_thought()
|
||||
message = _make_message([agent_thought])
|
||||
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None)
|
||||
executor = MagicMock(spec=EndUser)
|
||||
executor.name = "End User"
|
||||
app_model_config = MagicMock()
|
||||
app_model_config.agent_mode_dict = {"strategy": "react"}
|
||||
app_model_config.to_dict.return_value = {"tools": []}
|
||||
app_model = _make_app_model(app_model_config)
|
||||
current_user = _make_current_user_account()
|
||||
agent_tool = MagicMock()
|
||||
agent_tool.tool_name = "tool_a"
|
||||
agent_tool.provider_type = "custom"
|
||||
agent_tool.provider_id = "provider-2"
|
||||
agent_config = MagicMock()
|
||||
agent_config.tools = [agent_tool]
|
||||
|
||||
with (
|
||||
patch("services.agent_service.db") as mock_db,
|
||||
patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config) as mock_convert,
|
||||
patch("services.agent_service.ToolManager.get_tool_icon") as mock_get_icon,
|
||||
patch("services.agent_service.current_user", current_user),
|
||||
):
|
||||
mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor)
|
||||
mock_get_icon.side_effect = [None, "icon-a"]
|
||||
|
||||
# Act
|
||||
result = AgentService.get_agent_logs(app_model, conversation.id, message.id)
|
||||
|
||||
# Assert
|
||||
assert result["meta"]["status"] == "success"
|
||||
assert result["meta"]["executor"] == "End User"
|
||||
assert result["meta"]["total_tokens"] == 10
|
||||
assert result["meta"]["agent_mode"] == "react"
|
||||
assert result["meta"]["iterations"] == 1
|
||||
assert result["files"] == ["file-a.txt"]
|
||||
assert len(result["iterations"]) == 1
|
||||
tool_calls = result["iterations"][0]["tool_calls"]
|
||||
assert tool_calls[0]["tool_name"] == "tool_a"
|
||||
assert tool_calls[0]["tool_icon"] == "icon-a"
|
||||
assert tool_calls[1]["tool_name"] == "dataset_tool"
|
||||
assert tool_calls[1]["tool_icon"] == ""
|
||||
mock_convert.assert_called_once()
|
||||
|
||||
def test_get_agent_logs_should_return_account_executor_when_no_end_user(self) -> None:
|
||||
"""Test agent logs fall back to account executor when end user is missing."""
|
||||
# Arrange
|
||||
agent_thought = _make_agent_thought()
|
||||
message = _make_message([agent_thought])
|
||||
conversation = _make_conversation(from_end_user_id=None, from_account_id="account-1")
|
||||
executor = MagicMock(spec=Account)
|
||||
executor.name = "Account User"
|
||||
app_model_config = MagicMock()
|
||||
app_model_config.agent_mode_dict = {"strategy": "react"}
|
||||
app_model_config.to_dict.return_value = {"tools": []}
|
||||
app_model = _make_app_model(app_model_config)
|
||||
current_user = _make_current_user_account()
|
||||
agent_config = MagicMock()
|
||||
agent_config.tools = []
|
||||
|
||||
with (
|
||||
patch("services.agent_service.db") as mock_db,
|
||||
patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config),
|
||||
patch("services.agent_service.ToolManager.get_tool_icon", return_value=""),
|
||||
patch("services.agent_service.current_user", current_user),
|
||||
):
|
||||
mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor)
|
||||
|
||||
# Act
|
||||
result = AgentService.get_agent_logs(app_model, conversation.id, message.id)
|
||||
|
||||
# Assert
|
||||
assert result["meta"]["executor"] == "Account User"
|
||||
|
||||
def test_get_agent_logs_should_use_defaults_when_executor_and_tool_data_missing(self) -> None:
|
||||
"""Test unknown executor and missing tool details fall back to defaults."""
|
||||
# Arrange
|
||||
agent_thought = _make_agent_thought()
|
||||
agent_thought.tool_labels = {}
|
||||
agent_thought.tool_inputs_dict = {}
|
||||
agent_thought.tool_outputs_dict = None
|
||||
agent_thought.tool_meta = {"tool_a": {"error": "failed"}}
|
||||
agent_thought.tools = ["tool_a"]
|
||||
|
||||
message = _make_message([agent_thought])
|
||||
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None)
|
||||
app_model_config = MagicMock()
|
||||
app_model_config.agent_mode_dict = {}
|
||||
app_model_config.to_dict.return_value = {"tools": []}
|
||||
app_model = _make_app_model(app_model_config)
|
||||
current_user = _make_current_user_account()
|
||||
agent_config = MagicMock()
|
||||
agent_config.tools = []
|
||||
|
||||
with (
|
||||
patch("services.agent_service.db") as mock_db,
|
||||
patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config),
|
||||
patch("services.agent_service.ToolManager.get_tool_icon", return_value=None),
|
||||
patch("services.agent_service.current_user", current_user),
|
||||
):
|
||||
mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, None)
|
||||
|
||||
# Act
|
||||
result = AgentService.get_agent_logs(app_model, conversation.id, message.id)
|
||||
|
||||
# Assert
|
||||
assert result["meta"]["executor"] == "Unknown"
|
||||
assert result["meta"]["agent_mode"] == "react"
|
||||
tool_call = result["iterations"][0]["tool_calls"][0]
|
||||
assert tool_call["status"] == "error"
|
||||
assert tool_call["error"] == "failed"
|
||||
assert tool_call["tool_label"] == "tool_a"
|
||||
assert tool_call["tool_input"] == {}
|
||||
assert tool_call["tool_output"] == {}
|
||||
assert tool_call["time_cost"] == 0
|
||||
assert tool_call["tool_parameters"] == {}
|
||||
assert tool_call["tool_icon"] is None
|
||||
|
||||
|
||||
class TestAgentServiceProviders:
|
||||
"""Test suite for AgentService provider methods."""
|
||||
|
||||
def test_list_agent_providers_should_delegate_to_plugin_client(self) -> None:
|
||||
"""Test list_agent_providers delegates to PluginAgentClient."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-1"
|
||||
expected = [{"name": "provider"}]
|
||||
with patch("services.agent_service.PluginAgentClient") as mock_client:
|
||||
mock_client.return_value.fetch_agent_strategy_providers.return_value = expected
|
||||
|
||||
# Act
|
||||
result = AgentService.list_agent_providers("user-1", tenant_id)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
mock_client.return_value.fetch_agent_strategy_providers.assert_called_once_with(tenant_id)
|
||||
|
||||
def test_get_agent_provider_should_return_provider_when_successful(self) -> None:
|
||||
"""Test get_agent_provider returns provider when successful."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-1"
|
||||
provider_name = "provider-a"
|
||||
expected = {"name": provider_name}
|
||||
with patch("services.agent_service.PluginAgentClient") as mock_client:
|
||||
mock_client.return_value.fetch_agent_strategy_provider.return_value = expected
|
||||
|
||||
# Act
|
||||
result = AgentService.get_agent_provider("user-1", tenant_id, provider_name)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
mock_client.return_value.fetch_agent_strategy_provider.assert_called_once_with(tenant_id, provider_name)
|
||||
|
||||
def test_get_agent_provider_should_raise_value_error_on_plugin_error(self) -> None:
|
||||
"""Test get_agent_provider wraps PluginDaemonClientSideError into ValueError."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-1"
|
||||
provider_name = "provider-a"
|
||||
with patch("services.agent_service.PluginAgentClient") as mock_client:
|
||||
mock_client.return_value.fetch_agent_strategy_provider.side_effect = PluginDaemonClientSideError(
|
||||
"plugin error"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
AgentService.get_agent_provider("user-1", tenant_id, provider_name)
|
||||
|
|
@ -1,626 +0,0 @@
|
|||
import csv
|
||||
import io
|
||||
import json
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.feedback_service import FeedbackService
|
||||
|
||||
|
||||
class TestFeedbackServiceFactory:
|
||||
"""Factory class for creating test data and mock objects for feedback service tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_feedback_mock(
|
||||
feedback_id: str = "feedback-123",
|
||||
app_id: str = "app-456",
|
||||
conversation_id: str = "conv-789",
|
||||
message_id: str = "msg-001",
|
||||
rating: str = "like",
|
||||
content: str | None = "Great response!",
|
||||
from_source: str = "user",
|
||||
from_account_id: str | None = None,
|
||||
from_end_user_id: str | None = "end-user-001",
|
||||
created_at: datetime | None = None,
|
||||
) -> MagicMock:
|
||||
"""Create a mock MessageFeedback object."""
|
||||
feedback = MagicMock()
|
||||
feedback.id = feedback_id
|
||||
feedback.app_id = app_id
|
||||
feedback.conversation_id = conversation_id
|
||||
feedback.message_id = message_id
|
||||
feedback.rating = rating
|
||||
feedback.content = content
|
||||
feedback.from_source = from_source
|
||||
feedback.from_account_id = from_account_id
|
||||
feedback.from_end_user_id = from_end_user_id
|
||||
feedback.created_at = created_at or datetime.now()
|
||||
return feedback
|
||||
|
||||
@staticmethod
|
||||
def create_message_mock(
|
||||
message_id: str = "msg-001",
|
||||
query: str = "What is AI?",
|
||||
answer: str = "AI stands for Artificial Intelligence.",
|
||||
inputs: dict | None = None,
|
||||
created_at: datetime | None = None,
|
||||
):
|
||||
"""Create a mock Message object."""
|
||||
|
||||
# Create a simple object with instance attributes
|
||||
# Using a class with __init__ ensures attributes are instance attributes
|
||||
class Message:
|
||||
def __init__(self):
|
||||
self.id = message_id
|
||||
self.query = query
|
||||
self.answer = answer
|
||||
self.inputs = inputs
|
||||
self.created_at = created_at or datetime.now()
|
||||
|
||||
return Message()
|
||||
|
||||
@staticmethod
|
||||
def create_conversation_mock(
|
||||
conversation_id: str = "conv-789",
|
||||
name: str | None = "Test Conversation",
|
||||
) -> MagicMock:
|
||||
"""Create a mock Conversation object."""
|
||||
conversation = MagicMock()
|
||||
conversation.id = conversation_id
|
||||
conversation.name = name
|
||||
return conversation
|
||||
|
||||
@staticmethod
|
||||
def create_app_mock(
|
||||
app_id: str = "app-456",
|
||||
name: str = "Test App",
|
||||
) -> MagicMock:
|
||||
"""Create a mock App object."""
|
||||
app = MagicMock()
|
||||
app.id = app_id
|
||||
app.name = name
|
||||
return app
|
||||
|
||||
@staticmethod
|
||||
def create_account_mock(
|
||||
account_id: str = "account-123",
|
||||
name: str = "Test Admin",
|
||||
) -> MagicMock:
|
||||
"""Create a mock Account object."""
|
||||
account = MagicMock()
|
||||
account.id = account_id
|
||||
account.name = name
|
||||
return account
|
||||
|
||||
|
||||
class TestFeedbackService:
    """
    Comprehensive unit tests for FeedbackService.

    This test suite covers:
    - CSV and JSON export formats
    - All filter combinations
    - Edge cases and error handling
    - Response validation
    """

    @staticmethod
    def _mock_chained_query(mock_db, results):
        """Wire ``mock_db.session.query`` to a self-chaining query mock.

        Every chaining method (join/outerjoin/where/filter/order_by) returns
        the query mock itself so any call sequence resolves to the same
        object, and ``.all()`` yields *results*.  Returns the query mock so
        tests can assert on its calls.  Centralizes the eight lines of
        wiring that every export test previously copy-pasted.
        """
        mock_query = MagicMock()
        for chained in ("join", "outerjoin", "where", "filter", "order_by"):
            getattr(mock_query, chained).return_value = mock_query
        mock_query.all.return_value = results
        mock_db.session.query.return_value = mock_query
        return mock_query

    @pytest.fixture
    def factory(self):
        """Provide test data factory."""
        return TestFeedbackServiceFactory()

    @pytest.fixture
    def sample_feedback_data(self, factory):
        """Create sample feedback data for testing."""
        feedback = factory.create_feedback_mock(
            rating="like",
            content="Excellent answer!",
            from_source="user",
        )
        message = factory.create_message_mock(
            query="What is Python?",
            answer="Python is a programming language.",
        )
        conversation = factory.create_conversation_mock(name="Python Discussion")
        app = factory.create_app_mock(name="AI Assistant")
        account = factory.create_account_mock(name="Admin User")

        return [(feedback, message, conversation, app, account)]

    # Test 01: CSV Export - Basic Functionality
    @patch("services.feedback_service.db")
    def test_export_feedbacks_csv_basic(self, mock_db, factory, sample_feedback_data):
        """Test basic CSV export with single feedback record."""
        # Arrange
        self._mock_chained_query(mock_db, sample_feedback_data)

        # Act
        response = FeedbackService.export_feedbacks(app_id="app-456", format_type="csv")

        # Assert
        assert response.mimetype == "text/csv"
        assert "charset=utf-8-sig" in response.content_type
        assert "attachment" in response.headers["Content-Disposition"]
        assert "dify_feedback_export_app-456" in response.headers["Content-Disposition"]

        # Verify CSV content
        csv_content = response.get_data(as_text=True)
        reader = csv.DictReader(io.StringIO(csv_content))
        rows = list(reader)

        assert len(rows) == 1
        assert rows[0]["feedback_rating"] == "👍"
        assert rows[0]["feedback_rating_raw"] == "like"
        assert rows[0]["feedback_comment"] == "Excellent answer!"
        assert rows[0]["user_query"] == "What is Python?"
        assert rows[0]["ai_response"] == "Python is a programming language."

    # Test 02: JSON Export - Basic Functionality
    @patch("services.feedback_service.db")
    def test_export_feedbacks_json_basic(self, mock_db, factory, sample_feedback_data):
        """Test basic JSON export with metadata structure."""
        # Arrange
        self._mock_chained_query(mock_db, sample_feedback_data)

        # Act
        response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json")

        # Assert
        assert response.mimetype == "application/json"
        assert "charset=utf-8" in response.content_type
        assert "attachment" in response.headers["Content-Disposition"]

        # Verify JSON structure
        json_content = json.loads(response.get_data(as_text=True))
        assert "export_info" in json_content
        assert "feedback_data" in json_content
        assert json_content["export_info"]["app_id"] == "app-456"
        assert json_content["export_info"]["total_records"] == 1
        assert len(json_content["feedback_data"]) == 1

    # Test 03: Filter by from_source
    @patch("services.feedback_service.db")
    def test_export_feedbacks_filter_from_source(self, mock_db, factory):
        """Test filtering by feedback source (user/admin)."""
        # Arrange
        mock_query = self._mock_chained_query(mock_db, [])

        # Act
        FeedbackService.export_feedbacks(app_id="app-456", from_source="admin")

        # Assert
        mock_query.filter.assert_called()

    # Test 04: Filter by rating
    @patch("services.feedback_service.db")
    def test_export_feedbacks_filter_rating(self, mock_db, factory):
        """Test filtering by rating (like/dislike)."""
        # Arrange
        mock_query = self._mock_chained_query(mock_db, [])

        # Act
        FeedbackService.export_feedbacks(app_id="app-456", rating="dislike")

        # Assert
        mock_query.filter.assert_called()

    # Test 05: Filter by has_comment (True)
    @patch("services.feedback_service.db")
    def test_export_feedbacks_filter_has_comment_true(self, mock_db, factory):
        """Test filtering for feedback with comments."""
        # Arrange
        mock_query = self._mock_chained_query(mock_db, [])

        # Act
        FeedbackService.export_feedbacks(app_id="app-456", has_comment=True)

        # Assert
        mock_query.filter.assert_called()

    # Test 06: Filter by has_comment (False)
    @patch("services.feedback_service.db")
    def test_export_feedbacks_filter_has_comment_false(self, mock_db, factory):
        """Test filtering for feedback without comments."""
        # Arrange
        mock_query = self._mock_chained_query(mock_db, [])

        # Act
        FeedbackService.export_feedbacks(app_id="app-456", has_comment=False)

        # Assert
        mock_query.filter.assert_called()

    # Test 07: Filter by date range
    @patch("services.feedback_service.db")
    def test_export_feedbacks_filter_date_range(self, mock_db, factory):
        """Test filtering by start and end dates."""
        # Arrange
        mock_query = self._mock_chained_query(mock_db, [])

        # Act
        FeedbackService.export_feedbacks(
            app_id="app-456",
            start_date="2024-01-01",
            end_date="2024-12-31",
        )

        # Assert
        assert mock_query.filter.call_count >= 2  # Called for both start and end dates

    # Test 08: Invalid date format - start_date
    @patch("services.feedback_service.db")
    def test_export_feedbacks_invalid_start_date(self, mock_db):
        """Test error handling for invalid start_date format."""
        # Arrange (date parsing fails before the query is consumed)
        self._mock_chained_query(mock_db, [])

        # Act & Assert
        with pytest.raises(ValueError, match="Invalid start_date format"):
            FeedbackService.export_feedbacks(app_id="app-456", start_date="invalid-date")

    # Test 09: Invalid date format - end_date
    @patch("services.feedback_service.db")
    def test_export_feedbacks_invalid_end_date(self, mock_db):
        """Test error handling for invalid end_date format."""
        # Arrange (date parsing fails before the query is consumed)
        self._mock_chained_query(mock_db, [])

        # Act & Assert
        with pytest.raises(ValueError, match="Invalid end_date format"):
            FeedbackService.export_feedbacks(app_id="app-456", end_date="2024-13-45")

    # Test 10: Unsupported format
    def test_export_feedbacks_unsupported_format(self):
        """Test error handling for unsupported export format."""
        # Act & Assert
        with pytest.raises(ValueError, match="Unsupported format"):
            FeedbackService.export_feedbacks(app_id="app-456", format_type="xml")

    # Test 11: Empty result set - CSV
    @patch("services.feedback_service.db")
    def test_export_feedbacks_empty_results_csv(self, mock_db):
        """Test CSV export with no feedback records."""
        # Arrange
        self._mock_chained_query(mock_db, [])

        # Act
        response = FeedbackService.export_feedbacks(app_id="app-456", format_type="csv")

        # Assert
        csv_content = response.get_data(as_text=True)
        reader = csv.DictReader(io.StringIO(csv_content))
        rows = list(reader)
        assert len(rows) == 0
        # But headers should still be present
        assert reader.fieldnames is not None

    # Test 12: Empty result set - JSON
    @patch("services.feedback_service.db")
    def test_export_feedbacks_empty_results_json(self, mock_db):
        """Test JSON export with no feedback records."""
        # Arrange
        self._mock_chained_query(mock_db, [])

        # Act
        response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json")

        # Assert
        json_content = json.loads(response.get_data(as_text=True))
        assert json_content["export_info"]["total_records"] == 0
        assert len(json_content["feedback_data"]) == 0

    # Test 13: Long response truncation
    @patch("services.feedback_service.db")
    def test_export_feedbacks_long_response_truncation(self, mock_db, factory):
        """Test that long AI responses are truncated to 500 characters."""
        # Arrange
        long_answer = "A" * 600  # 600 characters
        feedback = factory.create_feedback_mock()
        message = factory.create_message_mock(answer=long_answer)
        conversation = factory.create_conversation_mock()
        app = factory.create_app_mock()
        account = factory.create_account_mock()

        self._mock_chained_query(mock_db, [(feedback, message, conversation, app, account)])

        # Act
        response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json")

        # Assert
        json_content = json.loads(response.get_data(as_text=True))
        ai_response = json_content["feedback_data"][0]["ai_response"]
        assert len(ai_response) == 503  # 500 + "..."
        assert ai_response.endswith("...")

    # Test 14: Null account (end user feedback)
    @patch("services.feedback_service.db")
    def test_export_feedbacks_null_account(self, mock_db, factory):
        """Test handling of feedback from end users (no account)."""
        # Arrange
        feedback = factory.create_feedback_mock(from_account_id=None)
        message = factory.create_message_mock()
        conversation = factory.create_conversation_mock()
        app = factory.create_app_mock()
        account = None  # No account for end user

        self._mock_chained_query(mock_db, [(feedback, message, conversation, app, account)])

        # Act
        response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json")

        # Assert
        json_content = json.loads(response.get_data(as_text=True))
        assert json_content["feedback_data"][0]["from_account_name"] == ""

    # Test 15: Null conversation name
    @patch("services.feedback_service.db")
    def test_export_feedbacks_null_conversation_name(self, mock_db, factory):
        """Test handling of conversations without names."""
        # Arrange
        feedback = factory.create_feedback_mock()
        message = factory.create_message_mock()
        conversation = factory.create_conversation_mock(name=None)
        app = factory.create_app_mock()
        account = factory.create_account_mock()

        self._mock_chained_query(mock_db, [(feedback, message, conversation, app, account)])

        # Act
        response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json")

        # Assert
        json_content = json.loads(response.get_data(as_text=True))
        assert json_content["feedback_data"][0]["conversation_name"] == ""

    # Test 16: Dislike rating emoji
    @patch("services.feedback_service.db")
    def test_export_feedbacks_dislike_rating(self, mock_db, factory):
        """Test that dislike rating shows thumbs down emoji."""
        # Arrange
        feedback = factory.create_feedback_mock(rating="dislike")
        message = factory.create_message_mock()
        conversation = factory.create_conversation_mock()
        app = factory.create_app_mock()
        account = factory.create_account_mock()

        self._mock_chained_query(mock_db, [(feedback, message, conversation, app, account)])

        # Act
        response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json")

        # Assert
        json_content = json.loads(response.get_data(as_text=True))
        assert json_content["feedback_data"][0]["feedback_rating"] == "👎"
        assert json_content["feedback_data"][0]["feedback_rating_raw"] == "dislike"

    # Test 17: Combined filters
    @patch("services.feedback_service.db")
    def test_export_feedbacks_combined_filters(self, mock_db, factory):
        """Test applying multiple filters simultaneously."""
        # Arrange
        mock_query = self._mock_chained_query(mock_db, [])

        # Act
        FeedbackService.export_feedbacks(
            app_id="app-456",
            from_source="admin",
            rating="like",
            has_comment=True,
            start_date="2024-01-01",
            end_date="2024-12-31",
        )

        # Assert
        # Should have called filter multiple times for each condition
        assert mock_query.filter.call_count >= 4

    # Test 18: Message query fallback to inputs
    @patch("services.feedback_service.db")
    def test_export_feedbacks_message_query_from_inputs(self, mock_db, factory):
        """Test fallback to inputs.query when message.query is None."""
        # Arrange
        feedback = factory.create_feedback_mock()
        message = factory.create_message_mock(query=None, inputs={"query": "Query from inputs"})
        conversation = factory.create_conversation_mock()
        app = factory.create_app_mock()
        account = factory.create_account_mock()

        self._mock_chained_query(mock_db, [(feedback, message, conversation, app, account)])

        # Act
        response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json")

        # Assert
        json_content = json.loads(response.get_data(as_text=True))
        assert json_content["feedback_data"][0]["user_query"] == "Query from inputs"

    # Test 19: Empty feedback content
    @patch("services.feedback_service.db")
    def test_export_feedbacks_empty_feedback_content(self, mock_db, factory):
        """Test handling of feedback with empty/null content."""
        # Arrange
        feedback = factory.create_feedback_mock(content=None)
        message = factory.create_message_mock()
        conversation = factory.create_conversation_mock()
        app = factory.create_app_mock()
        account = factory.create_account_mock()

        self._mock_chained_query(mock_db, [(feedback, message, conversation, app, account)])

        # Act
        response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json")

        # Assert
        json_content = json.loads(response.get_data(as_text=True))
        assert json_content["feedback_data"][0]["feedback_comment"] == ""
        assert json_content["feedback_data"][0]["has_comment"] == "No"

    # Test 20: CSV headers validation
    @patch("services.feedback_service.db")
    def test_export_feedbacks_csv_headers(self, mock_db, factory, sample_feedback_data):
        """Test that CSV contains all expected headers."""
        # Arrange
        self._mock_chained_query(mock_db, sample_feedback_data)

        expected_headers = [
            "feedback_id",
            "app_name",
            "app_id",
            "conversation_id",
            "conversation_name",
            "message_id",
            "user_query",
            "ai_response",
            "feedback_rating",
            "feedback_rating_raw",
            "feedback_comment",
            "feedback_source",
            "feedback_date",
            "message_date",
            "from_account_name",
            "from_end_user_id",
            "has_comment",
        ]

        # Act
        response = FeedbackService.export_feedbacks(app_id="app-456", format_type="csv")

        # Assert
        csv_content = response.get_data(as_text=True)
        reader = csv.DictReader(io.StringIO(csv_content))
        assert list(reader.fieldnames) == expected_headers
22
api/uv.lock
22
api/uv.lock
|
|
@ -1767,7 +1767,7 @@ dev = [
|
|||
{ name = "lxml-stubs", specifier = "~=0.5.1" },
|
||||
{ name = "mypy", specifier = "~=1.19.1" },
|
||||
{ name = "pandas-stubs", specifier = "~=3.0.0" },
|
||||
{ name = "pyrefly", specifier = ">=0.55.0" },
|
||||
{ name = "pyrefly", specifier = ">=0.57.1" },
|
||||
{ name = "pytest", specifier = "~=9.0.2" },
|
||||
{ name = "pytest-benchmark", specifier = "~=5.2.3" },
|
||||
{ name = "pytest-cov", specifier = "~=7.1.0" },
|
||||
|
|
@ -5402,18 +5402,18 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "pyrefly"
|
||||
version = "0.55.0"
|
||||
version = "0.57.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/bf/c4/76e0797215e62d007f81f86c9c4fb5d6202685a3f5e70810f3fd94294f92/pyrefly-0.55.0.tar.gz", hash = "sha256:434c3282532dd4525c4840f2040ed0eb79b0ec8224fe18d957956b15471f2441", size = 5135682, upload-time = "2026-03-03T00:46:38.122Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c9/c1/c17211e5bbd2b90a24447484713da7cc2cee4e9455e57b87016ffc69d426/pyrefly-0.57.1.tar.gz", hash = "sha256:b05f6f5ee3a6a5d502ca19d84cb9ab62d67f05083819964a48c1510f2993efc6", size = 5310800, upload-time = "2026-03-18T18:42:35.614Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/39/b0/16e50cf716784513648e23e726a24f71f9544aa4f86103032dcaa5ff71a2/pyrefly-0.55.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:49aafcefe5e2dd4256147db93e5b0ada42bff7d9a60db70e03d1f7055338eec9", size = 12210073, upload-time = "2026-03-03T00:46:15.51Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3a/ad/89500c01bac3083383011600370289fbc67700c5be46e781787392628a3a/pyrefly-0.55.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2827426e6b28397c13badb93c0ede0fb0f48046a7a89e3d774cda04e8e2067cd", size = 11767474, upload-time = "2026-03-03T00:46:18.003Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/78/68/4c66b260f817f304ead11176ff13985625f7c269e653304b4bdb546551af/pyrefly-0.55.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7346b2d64dc575bd61aa3bca854fbf8b5a19a471cbdb45e0ca1e09861b63488c", size = 33260395, upload-time = "2026-03-03T00:46:20.509Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/47/09/10bd48c9f860064f29f412954126a827d60f6451512224912c265e26bbe6/pyrefly-0.55.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:233b861b4cff008b1aff62f4f941577ed752e4d0060834229eb9b6826e6973c9", size = 35848269, upload-time = "2026-03-03T00:46:23.418Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a9/39/bc65cdd5243eb2dfea25dd1321f9a5a93e8d9c3a308501c4c6c05d011585/pyrefly-0.55.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5aa85657d76da1d25d081a49f0e33c8fc3ec91c1a0f185a8ed393a5a3d9e178", size = 38449820, upload-time = "2026-03-03T00:46:26.309Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e5/64/58b38963b011af91209e87f868cc85cfc762ec49a4568ce610c45e7a5f40/pyrefly-0.55.0-py3-none-win32.whl", hash = "sha256:23f786a78536a56fed331b245b7d10ec8945bebee7b723491c8d66fdbc155fe6", size = 11259415, upload-time = "2026-03-03T00:46:30.875Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7a/0b/a4aa519ff632a1ea69eec942566951670b870b99b5c08407e1387b85b6a4/pyrefly-0.55.0-py3-none-win_amd64.whl", hash = "sha256:d465b49e999b50eeb069ad23f0f5710651cad2576f9452a82991bef557df91ee", size = 12043581, upload-time = "2026-03-03T00:46:33.674Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f1/51/89017636fbe1ffd166ad478990c6052df615b926182fa6d3c0842b407e89/pyrefly-0.55.0-py3-none-win_arm64.whl", hash = "sha256:732ff490e0e863b296e7c0b2471e08f8ba7952f9fa6e9de09d8347fd67dde77f", size = 11548076, upload-time = "2026-03-03T00:46:36.193Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b7/58/8af37856c8d45b365ece635a6728a14b0356b08d1ff1ac601d7120def1e0/pyrefly-0.57.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:91974bfbe951eebf5a7bc959c1f3921f0371c789cad84761511d695e9ab2265f", size = 12681847, upload-time = "2026-03-18T18:42:10.963Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5f/d7/fae6dd9d0355fc5b8df7793f1423b7433ca8e10b698ea934c35f0e4e6522/pyrefly-0.57.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:808087298537c70f5e7cdccb5bbaad482e7e056e947c0adf00fb612cbace9fdc", size = 12219634, upload-time = "2026-03-18T18:42:13.469Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/29/8f/9511ae460f0690e837b9ba0f7e5e192079e16ff9a9ba8a272450e81f11f8/pyrefly-0.57.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b01f454fa5539e070c0cba17ddec46b3d2107d571d519bd8eca8f3142ba02a6", size = 34947757, upload-time = "2026-03-18T18:42:17.152Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/07/43/f053bf9c65218f70e6a49561e9942c7233f8c3e4da8d42e5fe2aae50b3d2/pyrefly-0.57.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02ad59ea722191f51635f23e37574662116b82ca9d814529f7cb5528f041f381", size = 37621018, upload-time = "2026-03-18T18:42:20.79Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0e/76/9cea46de01665bbc125e4f215340c9365c8d56cda6198ff238a563ea8e75/pyrefly-0.57.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54bc0afe56776145e37733ff763e7e9679ee8a76c467b617dc3f227d4124a9e2", size = 40203649, upload-time = "2026-03-18T18:42:24.519Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fd/8b/2fb4a96d75e2a57df698a43e2970e441ba2704e3906cdc0386a055daa05a/pyrefly-0.57.1-py3-none-win32.whl", hash = "sha256:468e5839144b25bb0dce839bfc5fd879c9f38e68ebf5de561f30bed9ae19d8ca", size = 11732953, upload-time = "2026-03-18T18:42:27.379Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/13/5a/4a197910fe2e9b102b15ae5e7687c45b7b5981275a11a564b41e185dd907/pyrefly-0.57.1-py3-none-win_amd64.whl", hash = "sha256:46db9c97093673c4fb7fab96d610e74d140661d54688a92d8e75ad885a56c141", size = 12537319, upload-time = "2026-03-18T18:42:30.196Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b5/c6/bc442874be1d9b63da1f9debb4f04b7d0c590a8dc4091921f3c288207242/pyrefly-0.57.1-py3-none-win_arm64.whl", hash = "sha256:feb1bbe3b0d8d5a70121dcdf1476e6a99cc056a26a49379a156f040729244dcb", size = 12013455, upload-time = "2026-03-18T18:42:32.928Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
Loading…
Reference in New Issue