merge: resolve conflicts with upstream main

This commit is contained in:
Desel72 2026-03-24 14:38:12 +00:00
commit 88b6a15170
16 changed files with 226 additions and 1406 deletions

View File

@ -10,7 +10,7 @@ import sqlalchemy as sa
from flask import request, send_file
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import asc, desc, select
from sqlalchemy import asc, desc, func, select
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -211,12 +211,11 @@ class GetProcessRuleApi(Resource):
raise Forbidden(str(e))
# get the latest process rule
dataset_process_rule = (
db.session.query(DatasetProcessRule)
dataset_process_rule = db.session.scalar(
select(DatasetProcessRule)
.where(DatasetProcessRule.dataset_id == document.dataset_id)
.order_by(DatasetProcessRule.created_at.desc())
.limit(1)
.one_or_none()
)
if dataset_process_rule:
mode = dataset_process_rule.mode
@ -330,21 +329,23 @@ class DatasetDocumentListApi(Resource):
if fetch:
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
.count()
or 0
)
total_segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
.count()
or 0
)
document.completed_segments = completed_segments
document.total_segments = total_segments
@ -521,10 +522,10 @@ class DocumentIndexingEstimateApi(DocumentResource):
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
file = (
db.session.query(UploadFile)
file = db.session.scalar(
select(UploadFile)
.where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
.first()
.limit(1)
)
# raise error if file not found
@ -586,10 +587,10 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if not data_source_info:
continue
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
file_detail = db.session.scalar(
select(UploadFile)
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first()
.limit(1)
)
if file_detail is None:
@ -672,20 +673,23 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
documents_status = []
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
.count()
or 0
)
total_segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
.count()
or 0
)
# Create a dictionary with document attributes and additional fields
document_dict = {
@ -723,18 +727,23 @@ class DocumentIndexingStatusApi(DocumentResource):
document = self.get_document(dataset_id, document_id)
completed_segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
.count()
or 0
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT)
.count()
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
or 0
)
# Create a dictionary with document attributes and additional fields
@ -1258,11 +1267,11 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
log = (
db.session.query(DocumentPipelineExecutionLog)
.filter_by(document_id=document_id)
log = db.session.scalar(
select(DocumentPipelineExecutionLog)
.where(DocumentPipelineExecutionLog.document_id == document_id)
.order_by(DocumentPipelineExecutionLog.created_at.desc())
.first()
.limit(1)
)
if not log:
return {

View File

@ -45,7 +45,7 @@ def _get_segment_with_summary(segment, dataset_id):
"""Helper function to marshal segment and add summary information."""
from services.summary_index_service import SummaryIndexService
segment_dict = dict(marshal(segment, segment_fields))
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
# Query summary for this segment (only enabled summaries)
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
@ -206,7 +206,7 @@ class DatasetDocumentSegmentListApi(Resource):
# Add summary to each segment
segments_with_summary = []
for segment in segments.items:
segment_dict = dict(marshal(segment, segment_fields))
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
segment_dict["summary"] = summaries.get(segment.id)
segments_with_summary.append(segment_dict)

View File

@ -2,6 +2,8 @@ from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from sqlalchemy import select
from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db
from libs.login import current_account_with_tenant
@ -24,10 +26,8 @@ def get_rag_pipeline(view_func: Callable[P, R]):
del kwargs["pipeline_id"]
pipeline = (
db.session.query(Pipeline)
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
.first()
pipeline = db.session.scalar(
select(Pipeline).where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id).limit(1)
)
if not pipeline:

View File

@ -153,15 +153,15 @@ class DatasetListApi(DatasetApiResource):
data = marshal(datasets, dataset_detail_fields)
for item in data:
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: # type: ignore
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) # type: ignore
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # type: ignore
if item_model in model_names:
item["embedding_available"] = True
item["embedding_available"] = True # type: ignore
else:
item["embedding_available"] = False
item["embedding_available"] = False # type: ignore
else:
item["embedding_available"] = True
item["embedding_available"] = True # type: ignore
response = {
"data": data,
"has_more": len(datasets) == query.limit,

View File

@ -67,7 +67,8 @@ class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
if field_name == "inputs":
data = {
"messages": [
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) for msg in v
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) # type: ignore
for msg in v
]
if isinstance(v, list)
else v,

View File

@ -174,7 +174,7 @@ dev = [
"sseclient-py>=1.8.0",
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
"pyrefly>=0.55.0",
"pyrefly>=0.57.1",
]
############################################################

View File

@ -1,8 +1,11 @@
"""Testcontainers integration tests for email register controller endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.email_register import (
EmailRegisterCheckApi,
@ -13,14 +16,11 @@ from services.account_service import AccountService
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
return flask_app
def app(flask_app_with_containers):
return flask_app_with_containers
class TestEmailRegisterSendEmailApi:
@patch("controllers.console.auth.email_register.sessionmaker")
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.email_register.AccountService.send_email_register_email")
@patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze")
@ -33,20 +33,15 @@ class TestEmailRegisterSendEmailApi:
mock_is_freeze,
mock_send_mail,
mock_get_account,
mock_session_cls,
app,
):
mock_send_mail.return_value = "token-123"
mock_is_freeze.return_value = False
mock_account = MagicMock()
mock_session = MagicMock()
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
mock_get_account.return_value = mock_account
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
with (
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
patch("controllers.console.auth.email_register.dify_config", SimpleNamespace(BILLING_ENABLED=True)),
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
@ -61,7 +56,6 @@ class TestEmailRegisterSendEmailApi:
assert response == {"result": "success", "data": "token-123"}
mock_is_freeze.assert_called_once_with("invitee@example.com")
mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US")
mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
mock_extract_ip.assert_called_once()
mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1")
@ -89,7 +83,6 @@ class TestEmailRegisterCheckApi:
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
with (
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
):
@ -114,7 +107,6 @@ class TestEmailRegisterResetApi:
@patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit")
@patch("controllers.console.auth.email_register.AccountService.login")
@patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
@patch("controllers.console.auth.email_register.sessionmaker")
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
@ -125,7 +117,6 @@ class TestEmailRegisterResetApi:
mock_get_data,
mock_revoke_token,
mock_get_account,
mock_session_cls,
mock_create_account,
mock_login,
mock_reset_login_rate,
@ -136,14 +127,10 @@ class TestEmailRegisterResetApi:
token_pair = MagicMock()
token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
mock_login.return_value = token_pair
mock_session = MagicMock()
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
mock_get_account.return_value = None
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
with (
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
):
@ -159,19 +146,19 @@ class TestEmailRegisterResetApi:
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
mock_revoke_token.assert_called_once_with("token-123")
mock_extract_ip.assert_called_once()
mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
"""Test that case fallback tries lowercase when exact match fails."""
mock_session = MagicMock()
first_query = MagicMock()
first_query.scalar_one_or_none.return_value = None
first_result = MagicMock()
first_result.scalar_one_or_none.return_value = None
expected_account = MagicMock()
second_query = MagicMock()
second_query.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_query, second_query]
second_result = MagicMock()
second_result.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_result, second_result]
account = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
assert account is expected_account
assert result is expected_account
assert mock_session.execute.call_count == 2

View File

@ -1,8 +1,11 @@
"""Testcontainers integration tests for forgot password controller endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.forgot_password import (
ForgotPasswordCheckApi,
@ -13,14 +16,11 @@ from services.account_service import AccountService
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
return flask_app
def app(flask_app_with_containers):
return flask_app_with_containers
class TestForgotPasswordSendEmailApi:
@patch("controllers.console.auth.forgot_password.sessionmaker")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
@ -31,19 +31,15 @@ class TestForgotPasswordSendEmailApi:
mock_is_ip_limit,
mock_send_email,
mock_get_account,
mock_session_cls,
app,
):
mock_account = MagicMock()
mock_get_account.return_value = mock_account
mock_send_email.return_value = "token-123"
mock_session = MagicMock()
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
wraps_features = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
controller_features = SimpleNamespace(is_allow_register=True)
with (
patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
patch(
"controllers.console.auth.forgot_password.FeatureService.get_system_features",
return_value=controller_features,
@ -59,7 +55,6 @@ class TestForgotPasswordSendEmailApi:
response = ForgotPasswordSendEmailApi().post()
assert response == {"result": "success", "data": "token-123"}
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
mock_send_email.assert_called_once_with(
account=mock_account,
email="user@example.com",
@ -117,7 +112,6 @@ class TestForgotPasswordCheckApi:
class TestForgotPasswordResetApi:
@patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account")
@patch("controllers.console.auth.forgot_password.sessionmaker")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@ -126,7 +120,6 @@ class TestForgotPasswordResetApi:
mock_get_reset_data,
mock_revoke_token,
mock_get_account,
mock_session_cls,
mock_update_account,
app,
):
@ -134,12 +127,8 @@ class TestForgotPasswordResetApi:
mock_account = MagicMock()
mock_get_account.return_value = mock_account
mock_session = MagicMock()
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
wraps_features = SimpleNamespace(enable_email_password_login=True)
with (
patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
):
@ -157,20 +146,22 @@ class TestForgotPasswordResetApi:
assert response == {"result": "success"}
mock_get_reset_data.assert_called_once_with("token-123")
mock_revoke_token.assert_called_once_with("token-123")
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
mock_update_account.assert_called_once()
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
"""Test that case fallback tries lowercase when exact match fails."""
from unittest.mock import MagicMock
mock_session = MagicMock()
first_query = MagicMock()
first_query.scalar_one_or_none.return_value = None
first_result = MagicMock()
first_result.scalar_one_or_none.return_value = None
expected_account = MagicMock()
second_query = MagicMock()
second_query.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_query, second_query]
second_result = MagicMock()
second_result.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_result, second_result]
account = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
assert account is expected_account
assert result is expected_account
assert mock_session.execute.call_count == 2

View File

@ -1,7 +1,10 @@
"""Testcontainers integration tests for OAuth controller endpoints."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.oauth import (
OAuthCallback,
@ -18,10 +21,8 @@ from services.errors.account import AccountRegisterError
class TestGetOAuthProviders:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
def app(self, flask_app_with_containers):
return flask_app_with_containers
@pytest.mark.parametrize(
("github_config", "google_config", "expected_github", "expected_google"),
@ -64,10 +65,8 @@ class TestOAuthLogin:
return OAuthLogin()
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
def app(self, flask_app_with_containers):
return flask_app_with_containers
@pytest.fixture
def mock_oauth_provider(self):
@ -131,10 +130,8 @@ class TestOAuthCallback:
return OAuthCallback()
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
def app(self, flask_app_with_containers):
return flask_app_with_containers
@pytest.fixture
def oauth_setup(self):
@ -190,15 +187,8 @@ class TestOAuthCallback:
(KeyError("Missing key"), "OAuth process failed"),
],
)
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.get_oauth_providers")
def test_should_handle_oauth_exceptions(
self, mock_get_providers, mock_db, resource, app, exception, expected_error
):
# Mock database session
mock_db.session = MagicMock()
mock_db.session.rollback = MagicMock()
def test_should_handle_oauth_exceptions(self, mock_get_providers, resource, app, exception, expected_error):
# Import the real requests module to create a proper exception
import httpx
@ -258,7 +248,6 @@ class TestOAuthCallback:
)
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth._generate_account")
@ -269,7 +258,6 @@ class TestOAuthCallback:
mock_generate_account,
mock_get_providers,
mock_config,
mock_db,
mock_tenant_service,
mock_account_service,
resource,
@ -278,10 +266,6 @@ class TestOAuthCallback:
account_status,
expected_redirect,
):
# Mock database session
mock_db.session = MagicMock()
mock_db.session.rollback = MagicMock()
mock_db.session.commit = MagicMock()
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
@ -306,14 +290,12 @@ class TestOAuthCallback:
@patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth._generate_account")
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.AccountService")
def test_should_activate_pending_account(
self,
mock_account_service,
mock_tenant_service,
mock_db,
mock_generate_account,
mock_get_providers,
mock_config,
@ -338,12 +320,10 @@ class TestOAuthCallback:
assert mock_account.status == AccountStatus.ACTIVE
assert mock_account.initialized_at is not None
mock_db.session.commit.assert_called_once()
@patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth._generate_account")
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.redirect")
@ -352,7 +332,6 @@ class TestOAuthCallback:
mock_redirect,
mock_account_service,
mock_tenant_service,
mock_db,
mock_generate_account,
mock_get_providers,
mock_config,
@ -414,6 +393,10 @@ class TestOAuthCallback:
class TestAccountGeneration:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
@pytest.fixture
def user_info(self):
return OAuthUserInfo(id="123", name="Test User", email="test@example.com")
@ -425,15 +408,10 @@ class TestAccountGeneration:
return account
@patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.oauth.sessionmaker")
@patch("controllers.console.auth.oauth.Account")
@patch("controllers.console.auth.oauth.db")
def test_should_get_account_by_openid_or_email(
self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account
self, mock_account_model, mock_get_account, flask_req_ctx_with_containers, user_info, mock_account
):
# Mock db.engine for sessionmaker creation
mock_db.engine = MagicMock()
# Test OpenID found
mock_account_model.get_by_openid.return_value = mock_account
result = _get_account_by_openid_or_email("github", user_info)
@ -443,15 +421,14 @@ class TestAccountGeneration:
# Test fallback to email lookup
mock_account_model.get_by_openid.return_value = None
mock_session_instance = MagicMock()
mock_session.return_value.begin.return_value.__enter__.return_value = mock_session_instance
mock_get_account.return_value = mock_account
result = _get_account_by_openid_or_email("github", user_info)
assert result == mock_account
mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance)
mock_get_account.assert_called_once()
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self):
def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(self):
"""Test that case fallback tries lowercase when exact match fails."""
mock_session = MagicMock()
first_result = MagicMock()
first_result.scalar_one_or_none.return_value = None
@ -462,7 +439,7 @@ class TestAccountGeneration:
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
assert result == expected_account
assert result is expected_account
assert mock_session.execute.call_count == 2
@pytest.mark.parametrize(
@ -478,10 +455,8 @@ class TestAccountGeneration:
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.db")
def test_should_handle_account_generation_scenarios(
self,
mock_db,
mock_tenant_service,
mock_account_service,
mock_register_service,
@ -519,10 +494,8 @@ class TestAccountGeneration:
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.db")
def test_should_register_with_lowercase_email(
self,
mock_db,
mock_tenant_service,
mock_account_service,
mock_register_service,

View File

@ -141,3 +141,73 @@ class TestArchivedWorkflowRunDeletion:
db_session_with_containers.expunge_all()
deleted_run = db_session_with_containers.get(WorkflowRun, run_id)
assert deleted_run is None
def test_delete_run_dry_run(self, db_session_with_containers):
"""Dry run should return success without actually deleting."""
tenant_id = str(uuid4())
run = self._create_workflow_run(
db_session_with_containers,
tenant_id=tenant_id,
created_at=datetime.now(UTC),
)
run_id = run.id
deleter = ArchivedWorkflowRunDeletion(dry_run=True)
result = deleter._delete_run(run)
assert result.success is True
assert result.run_id == run_id
# Run should still exist because it's a dry run
db_session_with_containers.expire_all()
assert db_session_with_containers.get(WorkflowRun, run_id) is not None
def test_delete_run_exception_returns_error(self, db_session_with_containers):
"""Exception during deletion should return failure result."""
from unittest.mock import MagicMock, patch
tenant_id = str(uuid4())
run = self._create_workflow_run(
db_session_with_containers,
tenant_id=tenant_id,
created_at=datetime.now(UTC),
)
deleter = ArchivedWorkflowRunDeletion(dry_run=False)
with patch.object(deleter, "_get_workflow_run_repo") as mock_get_repo:
mock_repo = MagicMock()
mock_get_repo.return_value = mock_repo
mock_repo.delete_runs_with_related.side_effect = Exception("Database error")
result = deleter._delete_run(run)
assert result.success is False
assert result.error == "Database error"
def test_delete_by_run_id_success(self, db_session_with_containers):
"""Successfully delete an archived workflow run by ID."""
tenant_id = str(uuid4())
base_time = datetime.now(UTC)
run = self._create_workflow_run(
db_session_with_containers,
tenant_id=tenant_id,
created_at=base_time,
)
self._create_archive_log(db_session_with_containers, run=run)
run_id = run.id
deleter = ArchivedWorkflowRunDeletion()
result = deleter.delete_by_run_id(run_id)
assert result.success is True
db_session_with_containers.expunge_all()
assert db_session_with_containers.get(WorkflowRun, run_id) is None
def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers):
"""_get_workflow_run_repo should return a cached repo on subsequent calls."""
deleter = ArchivedWorkflowRunDeletion()
repo1 = deleter._get_workflow_run_repo()
repo2 = deleter._get_workflow_run_repo()
assert repo1 is repo2
assert deleter.workflow_run_repo is repo1

View File

@ -140,8 +140,8 @@ class TestDatasetDocumentListApi:
return_value=pagination,
),
patch(
"controllers.console.datasets.datasets_document.db.session.query",
return_value=MagicMock(where=lambda *a, **k: MagicMock(count=count_mock)),
"controllers.console.datasets.datasets_document.db.session.scalar",
return_value=2,
),
patch(
"controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status",
@ -700,10 +700,8 @@ class TestDocumentPipelineExecutionLogApi:
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.datasets_document.db.session.query",
return_value=MagicMock(
filter_by=lambda **k: MagicMock(order_by=lambda *a: MagicMock(first=lambda: log))
),
"controllers.console.datasets.datasets_document.db.session.scalar",
return_value=log,
),
):
response, status = method(api, "ds-1", "doc-1")
@ -827,15 +825,12 @@ class TestDocumentIndexingEstimateApi:
dataset_process_rule=None,
)
query_mock = MagicMock()
query_mock.where.return_value.first.return_value = None
with (
app.test_request_context("/"),
patch.object(api, "get_document", return_value=document),
patch(
"controllers.console.datasets.datasets_document.db.session.query",
return_value=query_mock,
"controllers.console.datasets.datasets_document.db.session.scalar",
return_value=None,
),
):
with pytest.raises(NotFound):
@ -863,10 +858,8 @@ class TestDocumentIndexingEstimateApi:
app.test_request_context("/"),
patch.object(api, "get_document", return_value=document),
patch(
"controllers.console.datasets.datasets_document.db.session.query",
return_value=MagicMock(
where=MagicMock(return_value=MagicMock(first=MagicMock(return_value=upload_file)))
),
"controllers.console.datasets.datasets_document.db.session.scalar",
return_value=upload_file,
),
patch(
"controllers.console.datasets.datasets_document.ExtractSetting",
@ -1239,12 +1232,8 @@ class TestDocumentPermissionCases:
return_value=None,
),
patch(
"controllers.console.datasets.datasets_document.db.session.query",
return_value=MagicMock(
where=lambda *a: MagicMock(
order_by=lambda *b: MagicMock(limit=lambda n: MagicMock(one_or_none=lambda: process_rule))
)
),
"controllers.console.datasets.datasets_document.db.session.scalar",
return_value=process_rule,
),
):
result = method(api)
@ -1364,8 +1353,8 @@ class TestDocumentIndexingEdgeCases:
app.test_request_context("/"),
patch.object(api, "get_document", return_value=document),
patch(
"controllers.console.datasets.datasets_document.db.session.query",
return_value=MagicMock(where=lambda *a: MagicMock(first=lambda: upload_file)),
"controllers.console.datasets.datasets_document.db.session.scalar",
return_value=upload_file,
),
patch(
"controllers.console.datasets.datasets_document.ExtractSetting",

View File

@ -26,12 +26,9 @@ class TestGetRagPipeline:
return_value=(Mock(), "tenant-1"),
)
mock_query = Mock()
mock_query.where.return_value.first.return_value = None
mocker.patch(
"controllers.console.datasets.wraps.db.session.query",
return_value=mock_query,
"controllers.console.datasets.wraps.db.session.scalar",
return_value=None,
)
with pytest.raises(PipelineNotFoundError):
@ -51,12 +48,9 @@ class TestGetRagPipeline:
return_value=(Mock(), "tenant-1"),
)
mock_query = Mock()
mock_query.where.return_value.first.return_value = pipeline
mocker.patch(
"controllers.console.datasets.wraps.db.session.query",
return_value=mock_query,
"controllers.console.datasets.wraps.db.session.scalar",
return_value=pipeline,
)
result = dummy_view(pipeline_id="pipeline-1")
@ -76,12 +70,9 @@ class TestGetRagPipeline:
return_value=(Mock(), "tenant-1"),
)
mock_query = Mock()
mock_query.where.return_value.first.return_value = pipeline
mocker.patch(
"controllers.console.datasets.wraps.db.session.query",
return_value=mock_query,
"controllers.console.datasets.wraps.db.session.scalar",
return_value=pipeline,
)
result = dummy_view(pipeline_id="pipeline-1")
@ -100,18 +91,15 @@ class TestGetRagPipeline:
return_value=(Mock(), "tenant-1"),
)
def where_side_effect(*args, **kwargs):
assert args[0].right.value == "123"
return Mock(first=lambda: pipeline)
mock_query = Mock()
mock_query.where.side_effect = where_side_effect
mocker.patch(
"controllers.console.datasets.wraps.db.session.query",
return_value=mock_query,
mock_scalar = mocker.patch(
"controllers.console.datasets.wraps.db.session.scalar",
return_value=pipeline,
)
result = dummy_view(pipeline_id=123)
assert result is pipeline
# Verify the pipeline_id was cast to string in the where clause
stmt = mock_scalar.call_args[0][0]
where_clauses = stmt.whereclause.clauses
assert where_clauses[0].right.value == "123"

View File

@ -1,216 +0,0 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy.orm import Session
from models.workflow import WorkflowRun
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion, DeleteResult
class TestArchivedWorkflowRunDeletion:
@pytest.fixture
def mock_db(self):
with patch("services.retention.workflow_run.delete_archived_workflow_run.db") as mock_db:
mock_db.engine = MagicMock()
yield mock_db
@pytest.fixture
def mock_sessionmaker(self):
with patch("services.retention.workflow_run.delete_archived_workflow_run.sessionmaker") as mock_sm:
mock_session = MagicMock(spec=Session)
mock_sm.return_value.return_value.__enter__.return_value = mock_session
yield mock_sm, mock_session
@pytest.fixture
def mock_workflow_run_repo(self):
with patch(
"services.retention.workflow_run.delete_archived_workflow_run.APIWorkflowRunRepository"
) as mock_repo_cls:
mock_repo = MagicMock()
yield mock_repo
def test_delete_by_run_id_success(self, mock_db, mock_sessionmaker):
mock_sm, mock_session = mock_sessionmaker
run_id = "run-123"
tenant_id = "tenant-456"
mock_run = MagicMock(spec=WorkflowRun)
mock_run.id = run_id
mock_run.tenant_id = tenant_id
mock_session.get.return_value = mock_run
deletion = ArchivedWorkflowRunDeletion()
with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo:
mock_repo = MagicMock()
mock_get_repo.return_value = mock_repo
mock_repo.get_archived_run_ids.return_value = [run_id]
with patch.object(deletion, "_delete_run") as mock_delete_run:
expected_result = DeleteResult(run_id=run_id, tenant_id=tenant_id, success=True)
mock_delete_run.return_value = expected_result
result = deletion.delete_by_run_id(run_id)
assert result == expected_result
mock_session.get.assert_called_once_with(WorkflowRun, run_id)
mock_repo.get_archived_run_ids.assert_called_once()
mock_delete_run.assert_called_once_with(mock_run)
def test_delete_by_run_id_not_found(self, mock_db, mock_sessionmaker):
mock_sm, mock_session = mock_sessionmaker
run_id = "run-123"
mock_session.get.return_value = None
deletion = ArchivedWorkflowRunDeletion()
with patch.object(deletion, "_get_workflow_run_repo"):
result = deletion.delete_by_run_id(run_id)
assert result.success is False
assert "not found" in result.error
assert result.run_id == run_id
def test_delete_by_run_id_not_archived(self, mock_db, mock_sessionmaker):
mock_sm, mock_session = mock_sessionmaker
run_id = "run-123"
mock_run = MagicMock(spec=WorkflowRun)
mock_run.id = run_id
mock_session.get.return_value = mock_run
deletion = ArchivedWorkflowRunDeletion()
with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo:
mock_repo = MagicMock()
mock_get_repo.return_value = mock_repo
mock_repo.get_archived_run_ids.return_value = []
result = deletion.delete_by_run_id(run_id)
assert result.success is False
assert "is not archived" in result.error
def test_delete_batch(self, mock_db, mock_sessionmaker):
mock_sm, mock_session = mock_sessionmaker
deletion = ArchivedWorkflowRunDeletion()
mock_run1 = MagicMock(spec=WorkflowRun)
mock_run1.id = "run-1"
mock_run2 = MagicMock(spec=WorkflowRun)
mock_run2.id = "run-2"
with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo:
mock_repo = MagicMock()
mock_get_repo.return_value = mock_repo
mock_repo.get_archived_runs_by_time_range.return_value = [mock_run1, mock_run2]
with patch.object(deletion, "_delete_run") as mock_delete_run:
mock_delete_run.side_effect = [
DeleteResult(run_id="run-1", tenant_id="t1", success=True),
DeleteResult(run_id="run-2", tenant_id="t1", success=True),
]
results = deletion.delete_batch(tenant_ids=["t1"], start_date=datetime.now(), end_date=datetime.now())
assert len(results) == 2
assert results[0].run_id == "run-1"
assert results[1].run_id == "run-2"
assert mock_delete_run.call_count == 2
def test_delete_run_dry_run(self):
deletion = ArchivedWorkflowRunDeletion(dry_run=True)
mock_run = MagicMock(spec=WorkflowRun)
mock_run.id = "run-123"
mock_run.tenant_id = "tenant-456"
result = deletion._delete_run(mock_run)
assert result.success is True
assert result.run_id == "run-123"
def test_delete_run_success(self):
deletion = ArchivedWorkflowRunDeletion(dry_run=False)
mock_run = MagicMock(spec=WorkflowRun)
mock_run.id = "run-123"
mock_run.tenant_id = "tenant-456"
with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo:
mock_repo = MagicMock()
mock_get_repo.return_value = mock_repo
mock_repo.delete_runs_with_related.return_value = {"workflow_runs": 1}
result = deletion._delete_run(mock_run)
assert result.success is True
assert result.deleted_counts == {"workflow_runs": 1}
def test_delete_run_exception(self):
deletion = ArchivedWorkflowRunDeletion(dry_run=False)
mock_run = MagicMock(spec=WorkflowRun)
mock_run.id = "run-123"
with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo:
mock_repo = MagicMock()
mock_get_repo.return_value = mock_repo
mock_repo.delete_runs_with_related.side_effect = Exception("Database error")
result = deletion._delete_run(mock_run)
assert result.success is False
assert result.error == "Database error"
def test_delete_trigger_logs(self):
mock_session = MagicMock(spec=Session)
run_ids = ["run-1", "run-2"]
with patch(
"services.retention.workflow_run.delete_archived_workflow_run.SQLAlchemyWorkflowTriggerLogRepository"
) as mock_repo_cls:
mock_repo = MagicMock()
mock_repo_cls.return_value = mock_repo
mock_repo.delete_by_run_ids.return_value = 5
count = ArchivedWorkflowRunDeletion._delete_trigger_logs(mock_session, run_ids)
assert count == 5
mock_repo_cls.assert_called_once_with(mock_session)
mock_repo.delete_by_run_ids.assert_called_once_with(run_ids)
def test_delete_node_executions(self):
mock_session = MagicMock(spec=Session)
mock_run = MagicMock(spec=WorkflowRun)
mock_run.id = "run-1"
runs = [mock_run]
with patch(
"repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository"
) as mock_create_repo:
mock_repo = MagicMock()
mock_create_repo.return_value = mock_repo
mock_repo.delete_by_runs.return_value = (1, 2)
with patch("services.retention.workflow_run.delete_archived_workflow_run.sessionmaker") as mock_sm:
result = ArchivedWorkflowRunDeletion._delete_node_executions(mock_session, runs)
assert result == (1, 2)
mock_create_repo.assert_called_once()
mock_repo.delete_by_runs.assert_called_once_with(mock_session, ["run-1"])
def test_get_workflow_run_repo(self, mock_db):
deletion = ArchivedWorkflowRunDeletion()
with patch(
"repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_run_repository"
) as mock_create_repo:
mock_repo = MagicMock()
mock_create_repo.return_value = mock_repo
# First call
repo1 = deletion._get_workflow_run_repo()
assert repo1 == mock_repo
assert deletion.workflow_run_repo == mock_repo
# Second call (should return cached)
repo2 = deletion._get_workflow_run_repo()
assert repo2 == mock_repo
mock_create_repo.assert_called_once()

View File

@ -1,346 +0,0 @@
"""
Unit tests for services.agent_service
"""
from collections.abc import Callable
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
import pytz
from core.plugin.impl.exc import PluginDaemonClientSideError
from models import Account
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
from services.agent_service import AgentService
def _make_current_user_account(timezone: str = "UTC") -> Account:
account = Account(name="Test User", email="test@example.com")
account.timezone = timezone
return account
def _make_app_model(app_model_config: MagicMock | None) -> MagicMock:
app_model = MagicMock(spec=App)
app_model.id = "app-123"
app_model.tenant_id = "tenant-123"
app_model.app_model_config = app_model_config
return app_model
def _make_conversation(from_end_user_id: str | None, from_account_id: str | None) -> MagicMock:
conversation = MagicMock(spec=Conversation)
conversation.id = "conv-123"
conversation.app_id = "app-123"
conversation.from_end_user_id = from_end_user_id
conversation.from_account_id = from_account_id
return conversation
def _make_message(agent_thoughts: list[MessageAgentThought]) -> MagicMock:
message = MagicMock(spec=Message)
message.id = "msg-123"
message.conversation_id = "conv-123"
message.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC)
message.provider_response_latency = 1.23
message.answer_tokens = 4
message.message_tokens = 6
message.agent_thoughts = agent_thoughts
message.message_files = ["file-a.txt"]
return message
def _make_agent_thought() -> MagicMock:
agent_thought = MagicMock(spec=MessageAgentThought)
agent_thought.tokens = 3
agent_thought.tool_input = "raw-input"
agent_thought.observation = "raw-output"
agent_thought.thought = "thinking"
agent_thought.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC)
agent_thought.files = []
agent_thought.tools = ["tool_a", "dataset_tool"]
agent_thought.tool_labels = {"tool_a": "Tool A"}
agent_thought.tool_meta = {
"tool_a": {
"tool_config": {
"tool_provider_type": "custom",
"tool_provider": "provider-1",
},
"tool_parameters": {"param": "value"},
"time_cost": 2.5,
},
"dataset_tool": {
"tool_config": {
"tool_provider_type": "dataset-retrieval",
"tool_provider": "dataset-provider",
}
},
}
agent_thought.tool_inputs_dict = {"tool_a": {"q": "hello"}, "dataset_tool": {"k": "v"}}
agent_thought.tool_outputs_dict = {"tool_a": {"result": "ok"}}
return agent_thought
def _build_query_side_effect(
conversation: Conversation | None,
message: Message | None,
executor: EndUser | Account | None,
) -> Callable[..., MagicMock]:
def _query_side_effect(*args: object, **kwargs: object) -> MagicMock:
query = MagicMock()
query.where.return_value = query
if any(arg is Conversation for arg in args):
query.first.return_value = conversation
elif any(arg is Message for arg in args):
query.first.return_value = message
elif any(arg is EndUser for arg in args) or any(arg is Account for arg in args):
query.first.return_value = executor
return query
return _query_side_effect
class TestAgentServiceGetAgentLogs:
"""Test suite for AgentService.get_agent_logs."""
def test_get_agent_logs_should_raise_when_conversation_missing(self) -> None:
"""Test missing conversation raises ValueError."""
# Arrange
app_model = _make_app_model(MagicMock())
with patch("services.agent_service.db") as mock_db:
query = MagicMock()
query.where.return_value = query
query.first.return_value = None
mock_db.session.query.return_value = query
# Act & Assert
with pytest.raises(ValueError):
AgentService.get_agent_logs(app_model, "missing-conv", "msg-1")
def test_get_agent_logs_should_raise_when_message_missing(self) -> None:
"""Test missing message raises ValueError."""
# Arrange
app_model = _make_app_model(MagicMock())
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None)
with patch("services.agent_service.db") as mock_db:
conversation_query = MagicMock()
conversation_query.where.return_value = conversation_query
conversation_query.first.return_value = conversation
message_query = MagicMock()
message_query.where.return_value = message_query
message_query.first.return_value = None
mock_db.session.query.side_effect = [conversation_query, message_query]
# Act & Assert
with pytest.raises(ValueError):
AgentService.get_agent_logs(app_model, conversation.id, "missing-msg")
def test_get_agent_logs_should_raise_when_app_model_config_missing(self) -> None:
"""Test missing app model config raises ValueError."""
# Arrange
app_model = _make_app_model(None)
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None)
message = _make_message([])
current_user = _make_current_user_account()
with patch("services.agent_service.db") as mock_db, patch("services.agent_service.current_user", current_user):
mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock())
# Act & Assert
with pytest.raises(ValueError):
AgentService.get_agent_logs(app_model, conversation.id, message.id)
def test_get_agent_logs_should_raise_when_agent_config_missing(self) -> None:
"""Test missing agent config raises ValueError."""
# Arrange
app_model_config = MagicMock()
app_model_config.agent_mode_dict = {"strategy": "react"}
app_model_config.to_dict.return_value = {"tools": []}
app_model = _make_app_model(app_model_config)
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None)
message = _make_message([])
current_user = _make_current_user_account()
with (
patch("services.agent_service.db") as mock_db,
patch("services.agent_service.AgentConfigManager.convert", return_value=None),
patch("services.agent_service.current_user", current_user),
):
mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock())
# Act & Assert
with pytest.raises(ValueError):
AgentService.get_agent_logs(app_model, conversation.id, message.id)
def test_get_agent_logs_should_return_logs_for_end_user_executor(self) -> None:
"""Test agent logs returned for end-user executor with tool icons."""
# Arrange
agent_thought = _make_agent_thought()
message = _make_message([agent_thought])
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None)
executor = MagicMock(spec=EndUser)
executor.name = "End User"
app_model_config = MagicMock()
app_model_config.agent_mode_dict = {"strategy": "react"}
app_model_config.to_dict.return_value = {"tools": []}
app_model = _make_app_model(app_model_config)
current_user = _make_current_user_account()
agent_tool = MagicMock()
agent_tool.tool_name = "tool_a"
agent_tool.provider_type = "custom"
agent_tool.provider_id = "provider-2"
agent_config = MagicMock()
agent_config.tools = [agent_tool]
with (
patch("services.agent_service.db") as mock_db,
patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config) as mock_convert,
patch("services.agent_service.ToolManager.get_tool_icon") as mock_get_icon,
patch("services.agent_service.current_user", current_user),
):
mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor)
mock_get_icon.side_effect = [None, "icon-a"]
# Act
result = AgentService.get_agent_logs(app_model, conversation.id, message.id)
# Assert
assert result["meta"]["status"] == "success"
assert result["meta"]["executor"] == "End User"
assert result["meta"]["total_tokens"] == 10
assert result["meta"]["agent_mode"] == "react"
assert result["meta"]["iterations"] == 1
assert result["files"] == ["file-a.txt"]
assert len(result["iterations"]) == 1
tool_calls = result["iterations"][0]["tool_calls"]
assert tool_calls[0]["tool_name"] == "tool_a"
assert tool_calls[0]["tool_icon"] == "icon-a"
assert tool_calls[1]["tool_name"] == "dataset_tool"
assert tool_calls[1]["tool_icon"] == ""
mock_convert.assert_called_once()
def test_get_agent_logs_should_return_account_executor_when_no_end_user(self) -> None:
"""Test agent logs fall back to account executor when end user is missing."""
# Arrange
agent_thought = _make_agent_thought()
message = _make_message([agent_thought])
conversation = _make_conversation(from_end_user_id=None, from_account_id="account-1")
executor = MagicMock(spec=Account)
executor.name = "Account User"
app_model_config = MagicMock()
app_model_config.agent_mode_dict = {"strategy": "react"}
app_model_config.to_dict.return_value = {"tools": []}
app_model = _make_app_model(app_model_config)
current_user = _make_current_user_account()
agent_config = MagicMock()
agent_config.tools = []
with (
patch("services.agent_service.db") as mock_db,
patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config),
patch("services.agent_service.ToolManager.get_tool_icon", return_value=""),
patch("services.agent_service.current_user", current_user),
):
mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor)
# Act
result = AgentService.get_agent_logs(app_model, conversation.id, message.id)
# Assert
assert result["meta"]["executor"] == "Account User"
def test_get_agent_logs_should_use_defaults_when_executor_and_tool_data_missing(self) -> None:
"""Test unknown executor and missing tool details fall back to defaults."""
# Arrange
agent_thought = _make_agent_thought()
agent_thought.tool_labels = {}
agent_thought.tool_inputs_dict = {}
agent_thought.tool_outputs_dict = None
agent_thought.tool_meta = {"tool_a": {"error": "failed"}}
agent_thought.tools = ["tool_a"]
message = _make_message([agent_thought])
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None)
app_model_config = MagicMock()
app_model_config.agent_mode_dict = {}
app_model_config.to_dict.return_value = {"tools": []}
app_model = _make_app_model(app_model_config)
current_user = _make_current_user_account()
agent_config = MagicMock()
agent_config.tools = []
with (
patch("services.agent_service.db") as mock_db,
patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config),
patch("services.agent_service.ToolManager.get_tool_icon", return_value=None),
patch("services.agent_service.current_user", current_user),
):
mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, None)
# Act
result = AgentService.get_agent_logs(app_model, conversation.id, message.id)
# Assert
assert result["meta"]["executor"] == "Unknown"
assert result["meta"]["agent_mode"] == "react"
tool_call = result["iterations"][0]["tool_calls"][0]
assert tool_call["status"] == "error"
assert tool_call["error"] == "failed"
assert tool_call["tool_label"] == "tool_a"
assert tool_call["tool_input"] == {}
assert tool_call["tool_output"] == {}
assert tool_call["time_cost"] == 0
assert tool_call["tool_parameters"] == {}
assert tool_call["tool_icon"] is None
class TestAgentServiceProviders:
"""Test suite for AgentService provider methods."""
def test_list_agent_providers_should_delegate_to_plugin_client(self) -> None:
"""Test list_agent_providers delegates to PluginAgentClient."""
# Arrange
tenant_id = "tenant-1"
expected = [{"name": "provider"}]
with patch("services.agent_service.PluginAgentClient") as mock_client:
mock_client.return_value.fetch_agent_strategy_providers.return_value = expected
# Act
result = AgentService.list_agent_providers("user-1", tenant_id)
# Assert
assert result == expected
mock_client.return_value.fetch_agent_strategy_providers.assert_called_once_with(tenant_id)
def test_get_agent_provider_should_return_provider_when_successful(self) -> None:
"""Test get_agent_provider returns provider when successful."""
# Arrange
tenant_id = "tenant-1"
provider_name = "provider-a"
expected = {"name": provider_name}
with patch("services.agent_service.PluginAgentClient") as mock_client:
mock_client.return_value.fetch_agent_strategy_provider.return_value = expected
# Act
result = AgentService.get_agent_provider("user-1", tenant_id, provider_name)
# Assert
assert result == expected
mock_client.return_value.fetch_agent_strategy_provider.assert_called_once_with(tenant_id, provider_name)
def test_get_agent_provider_should_raise_value_error_on_plugin_error(self) -> None:
"""Test get_agent_provider wraps PluginDaemonClientSideError into ValueError."""
# Arrange
tenant_id = "tenant-1"
provider_name = "provider-a"
with patch("services.agent_service.PluginAgentClient") as mock_client:
mock_client.return_value.fetch_agent_strategy_provider.side_effect = PluginDaemonClientSideError(
"plugin error"
)
# Act & Assert
with pytest.raises(ValueError):
AgentService.get_agent_provider("user-1", tenant_id, provider_name)

View File

@ -1,626 +0,0 @@
import csv
import io
import json
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from services.feedback_service import FeedbackService
class TestFeedbackServiceFactory:
"""Factory class for creating test data and mock objects for feedback service tests."""
@staticmethod
def create_feedback_mock(
feedback_id: str = "feedback-123",
app_id: str = "app-456",
conversation_id: str = "conv-789",
message_id: str = "msg-001",
rating: str = "like",
content: str | None = "Great response!",
from_source: str = "user",
from_account_id: str | None = None,
from_end_user_id: str | None = "end-user-001",
created_at: datetime | None = None,
) -> MagicMock:
"""Create a mock MessageFeedback object."""
feedback = MagicMock()
feedback.id = feedback_id
feedback.app_id = app_id
feedback.conversation_id = conversation_id
feedback.message_id = message_id
feedback.rating = rating
feedback.content = content
feedback.from_source = from_source
feedback.from_account_id = from_account_id
feedback.from_end_user_id = from_end_user_id
feedback.created_at = created_at or datetime.now()
return feedback
@staticmethod
def create_message_mock(
message_id: str = "msg-001",
query: str = "What is AI?",
answer: str = "AI stands for Artificial Intelligence.",
inputs: dict | None = None,
created_at: datetime | None = None,
):
"""Create a mock Message object."""
# Create a simple object with instance attributes
# Using a class with __init__ ensures attributes are instance attributes
class Message:
def __init__(self):
self.id = message_id
self.query = query
self.answer = answer
self.inputs = inputs
self.created_at = created_at or datetime.now()
return Message()
@staticmethod
def create_conversation_mock(
conversation_id: str = "conv-789",
name: str | None = "Test Conversation",
) -> MagicMock:
"""Create a mock Conversation object."""
conversation = MagicMock()
conversation.id = conversation_id
conversation.name = name
return conversation
@staticmethod
def create_app_mock(
app_id: str = "app-456",
name: str = "Test App",
) -> MagicMock:
"""Create a mock App object."""
app = MagicMock()
app.id = app_id
app.name = name
return app
@staticmethod
def create_account_mock(
account_id: str = "account-123",
name: str = "Test Admin",
) -> MagicMock:
"""Create a mock Account object."""
account = MagicMock()
account.id = account_id
account.name = name
return account
class TestFeedbackService:
"""
Comprehensive unit tests for FeedbackService.
This test suite covers:
- CSV and JSON export formats
- All filter combinations
- Edge cases and error handling
- Response validation
"""
@pytest.fixture
def factory(self):
"""Provide test data factory."""
return TestFeedbackServiceFactory()
@pytest.fixture
def sample_feedback_data(self, factory):
"""Create sample feedback data for testing."""
feedback = factory.create_feedback_mock(
rating="like",
content="Excellent answer!",
from_source="user",
)
message = factory.create_message_mock(
query="What is Python?",
answer="Python is a programming language.",
)
conversation = factory.create_conversation_mock(name="Python Discussion")
app = factory.create_app_mock(name="AI Assistant")
account = factory.create_account_mock(name="Admin User")
return [(feedback, message, conversation, app, account)]
# Test 01: CSV Export - Basic Functionality
@patch("services.feedback_service.db")
def test_export_feedbacks_csv_basic(self, mock_db, factory, sample_feedback_data):
"""Test basic CSV export with single feedback record."""
# Arrange
mock_query = MagicMock()
# Configure the mock to return itself for all chaining methods
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = sample_feedback_data
# Set up the session.query to return our mock
mock_db.session.query.return_value = mock_query
# Act
response = FeedbackService.export_feedbacks(app_id="app-456", format_type="csv")
# Assert
assert response.mimetype == "text/csv"
assert "charset=utf-8-sig" in response.content_type
assert "attachment" in response.headers["Content-Disposition"]
assert "dify_feedback_export_app-456" in response.headers["Content-Disposition"]
# Verify CSV content
csv_content = response.get_data(as_text=True)
reader = csv.DictReader(io.StringIO(csv_content))
rows = list(reader)
assert len(rows) == 1
assert rows[0]["feedback_rating"] == "👍"
assert rows[0]["feedback_rating_raw"] == "like"
assert rows[0]["feedback_comment"] == "Excellent answer!"
assert rows[0]["user_query"] == "What is Python?"
assert rows[0]["ai_response"] == "Python is a programming language."
# Test 02: JSON Export - Basic Functionality
@patch("services.feedback_service.db")
def test_export_feedbacks_json_basic(self, mock_db, factory, sample_feedback_data):
"""Test basic JSON export with metadata structure."""
# Arrange
mock_query = MagicMock()
# Configure the mock to return itself for all chaining methods
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = sample_feedback_data
# Set up the session.query to return our mock
mock_db.session.query.return_value = mock_query
# Act
response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json")
# Assert
assert response.mimetype == "application/json"
assert "charset=utf-8" in response.content_type
assert "attachment" in response.headers["Content-Disposition"]
# Verify JSON structure
json_content = json.loads(response.get_data(as_text=True))
assert "export_info" in json_content
assert "feedback_data" in json_content
assert json_content["export_info"]["app_id"] == "app-456"
assert json_content["export_info"]["total_records"] == 1
assert len(json_content["feedback_data"]) == 1
# Test 03: Filter by from_source
@patch("services.feedback_service.db")
def test_export_feedbacks_filter_from_source(self, mock_db, factory):
"""Test filtering by feedback source (user/admin)."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = []
# Act
FeedbackService.export_feedbacks(app_id="app-456", from_source="admin")
# Assert
mock_query.filter.assert_called()
# Test 04: Filter by rating
@patch("services.feedback_service.db")
def test_export_feedbacks_filter_rating(self, mock_db, factory):
"""Test filtering by rating (like/dislike)."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = []
# Act
FeedbackService.export_feedbacks(app_id="app-456", rating="dislike")
# Assert
mock_query.filter.assert_called()
# Test 05: Filter by has_comment (True)
@patch("services.feedback_service.db")
def test_export_feedbacks_filter_has_comment_true(self, mock_db, factory):
"""Test filtering for feedback with comments."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = []
# Act
FeedbackService.export_feedbacks(app_id="app-456", has_comment=True)
# Assert
mock_query.filter.assert_called()
# Test 06: Filter by has_comment (False)
@patch("services.feedback_service.db")
def test_export_feedbacks_filter_has_comment_false(self, mock_db, factory):
"""Test filtering for feedback without comments."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = []
# Act
FeedbackService.export_feedbacks(app_id="app-456", has_comment=False)
# Assert
mock_query.filter.assert_called()
# Test 07: Filter by date range
@patch("services.feedback_service.db")
def test_export_feedbacks_filter_date_range(self, mock_db, factory):
"""Test filtering by start and end dates."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = []
# Act
FeedbackService.export_feedbacks(
app_id="app-456",
start_date="2024-01-01",
end_date="2024-12-31",
)
# Assert
assert mock_query.filter.call_count >= 2 # Called for both start and end dates
# Test 08: Invalid date format - start_date
@patch("services.feedback_service.db")
def test_export_feedbacks_invalid_start_date(self, mock_db):
"""Test error handling for invalid start_date format."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
# Act & Assert
with pytest.raises(ValueError, match="Invalid start_date format"):
FeedbackService.export_feedbacks(app_id="app-456", start_date="invalid-date")
# Test 09: Invalid date format - end_date
@patch("services.feedback_service.db")
def test_export_feedbacks_invalid_end_date(self, mock_db):
"""Test error handling for invalid end_date format."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
# Act & Assert
with pytest.raises(ValueError, match="Invalid end_date format"):
FeedbackService.export_feedbacks(app_id="app-456", end_date="2024-13-45")
# Test 10: Unsupported format
def test_export_feedbacks_unsupported_format(self):
"""Test error handling for unsupported export format."""
# Act & Assert
with pytest.raises(ValueError, match="Unsupported format"):
FeedbackService.export_feedbacks(app_id="app-456", format_type="xml")
# Test 11: Empty result set - CSV
@patch("services.feedback_service.db")
def test_export_feedbacks_empty_results_csv(self, mock_db):
"""Test CSV export with no feedback records."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = []
# Act
response = FeedbackService.export_feedbacks(app_id="app-456", format_type="csv")
# Assert
csv_content = response.get_data(as_text=True)
reader = csv.DictReader(io.StringIO(csv_content))
rows = list(reader)
assert len(rows) == 0
# But headers should still be present
assert reader.fieldnames is not None
# Test 12: Empty result set - JSON
@patch("services.feedback_service.db")
def test_export_feedbacks_empty_results_json(self, mock_db):
"""Test JSON export with no feedback records."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = []
# Act
response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json")
# Assert
json_content = json.loads(response.get_data(as_text=True))
assert json_content["export_info"]["total_records"] == 0
assert len(json_content["feedback_data"]) == 0
# Test 13: Long response truncation
@patch("services.feedback_service.db")
def test_export_feedbacks_long_response_truncation(self, mock_db, factory):
"""Test that long AI responses are truncated to 500 characters."""
# Arrange
long_answer = "A" * 600 # 600 characters
feedback = factory.create_feedback_mock()
message = factory.create_message_mock(answer=long_answer)
conversation = factory.create_conversation_mock()
app = factory.create_app_mock()
account = factory.create_account_mock()
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [(feedback, message, conversation, app, account)]
# Act
response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json")
# Assert
json_content = json.loads(response.get_data(as_text=True))
ai_response = json_content["feedback_data"][0]["ai_response"]
assert len(ai_response) == 503 # 500 + "..."
assert ai_response.endswith("...")
# Test 14: Null account (end user feedback)
@patch("services.feedback_service.db")
def test_export_feedbacks_null_account(self, mock_db, factory):
"""Test handling of feedback from end users (no account)."""
# Arrange
feedback = factory.create_feedback_mock(from_account_id=None)
message = factory.create_message_mock()
conversation = factory.create_conversation_mock()
app = factory.create_app_mock()
account = None # No account for end user
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [(feedback, message, conversation, app, account)]
# Act
response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json")
# Assert
json_content = json.loads(response.get_data(as_text=True))
assert json_content["feedback_data"][0]["from_account_name"] == ""
# Test 15: Null conversation name
@patch("services.feedback_service.db")
def test_export_feedbacks_null_conversation_name(self, mock_db, factory):
"""Test handling of conversations without names."""
# Arrange
feedback = factory.create_feedback_mock()
message = factory.create_message_mock()
conversation = factory.create_conversation_mock(name=None)
app = factory.create_app_mock()
account = factory.create_account_mock()
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [(feedback, message, conversation, app, account)]
# Act
response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json")
# Assert
json_content = json.loads(response.get_data(as_text=True))
assert json_content["feedback_data"][0]["conversation_name"] == ""
# Test 16: Dislike rating emoji
@patch("services.feedback_service.db")
def test_export_feedbacks_dislike_rating(self, mock_db, factory):
"""Test that dislike rating shows thumbs down emoji."""
# Arrange
feedback = factory.create_feedback_mock(rating="dislike")
message = factory.create_message_mock()
conversation = factory.create_conversation_mock()
app = factory.create_app_mock()
account = factory.create_account_mock()
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [(feedback, message, conversation, app, account)]
# Act
response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json")
# Assert
json_content = json.loads(response.get_data(as_text=True))
assert json_content["feedback_data"][0]["feedback_rating"] == "👎"
assert json_content["feedback_data"][0]["feedback_rating_raw"] == "dislike"
# Test 17: Combined filters
@patch("services.feedback_service.db")
def test_export_feedbacks_combined_filters(self, mock_db, factory):
"""Test applying multiple filters simultaneously."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = []
# Act
FeedbackService.export_feedbacks(
app_id="app-456",
from_source="admin",
rating="like",
has_comment=True,
start_date="2024-01-01",
end_date="2024-12-31",
)
# Assert
# Should have called filter multiple times for each condition
assert mock_query.filter.call_count >= 4
# Test 18: Message query fallback to inputs
@patch("services.feedback_service.db")
def test_export_feedbacks_message_query_from_inputs(self, mock_db, factory):
"""Test fallback to inputs.query when message.query is None."""
# Arrange
feedback = factory.create_feedback_mock()
message = factory.create_message_mock(query=None, inputs={"query": "Query from inputs"})
conversation = factory.create_conversation_mock()
app = factory.create_app_mock()
account = factory.create_account_mock()
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [(feedback, message, conversation, app, account)]
# Act
response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json")
# Assert
json_content = json.loads(response.get_data(as_text=True))
assert json_content["feedback_data"][0]["user_query"] == "Query from inputs"
# Test 19: Empty feedback content
@patch("services.feedback_service.db")
def test_export_feedbacks_empty_feedback_content(self, mock_db, factory):
"""Test handling of feedback with empty/null content."""
# Arrange
feedback = factory.create_feedback_mock(content=None)
message = factory.create_message_mock()
conversation = factory.create_conversation_mock()
app = factory.create_app_mock()
account = factory.create_account_mock()
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = [(feedback, message, conversation, app, account)]
# Act
response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json")
# Assert
json_content = json.loads(response.get_data(as_text=True))
assert json_content["feedback_data"][0]["feedback_comment"] == ""
assert json_content["feedback_data"][0]["has_comment"] == "No"
# Test 20: CSV headers validation
@patch("services.feedback_service.db")
def test_export_feedbacks_csv_headers(self, mock_db, factory, sample_feedback_data):
"""Test that CSV contains all expected headers."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.join.return_value = mock_query
mock_query.outerjoin.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = sample_feedback_data
expected_headers = [
"feedback_id",
"app_name",
"app_id",
"conversation_id",
"conversation_name",
"message_id",
"user_query",
"ai_response",
"feedback_rating",
"feedback_rating_raw",
"feedback_comment",
"feedback_source",
"feedback_date",
"message_date",
"from_account_name",
"from_end_user_id",
"has_comment",
]
# Act
response = FeedbackService.export_feedbacks(app_id="app-456", format_type="csv")
# Assert
csv_content = response.get_data(as_text=True)
reader = csv.DictReader(io.StringIO(csv_content))
assert list(reader.fieldnames) == expected_headers

View File

@ -1767,7 +1767,7 @@ dev = [
{ name = "lxml-stubs", specifier = "~=0.5.1" },
{ name = "mypy", specifier = "~=1.19.1" },
{ name = "pandas-stubs", specifier = "~=3.0.0" },
{ name = "pyrefly", specifier = ">=0.55.0" },
{ name = "pyrefly", specifier = ">=0.57.1" },
{ name = "pytest", specifier = "~=9.0.2" },
{ name = "pytest-benchmark", specifier = "~=5.2.3" },
{ name = "pytest-cov", specifier = "~=7.1.0" },
@ -5402,18 +5402,18 @@ wheels = [
[[package]]
name = "pyrefly"
version = "0.55.0"
version = "0.57.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/bf/c4/76e0797215e62d007f81f86c9c4fb5d6202685a3f5e70810f3fd94294f92/pyrefly-0.55.0.tar.gz", hash = "sha256:434c3282532dd4525c4840f2040ed0eb79b0ec8224fe18d957956b15471f2441", size = 5135682, upload-time = "2026-03-03T00:46:38.122Z" }
sdist = { url = "https://files.pythonhosted.org/packages/c9/c1/c17211e5bbd2b90a24447484713da7cc2cee4e9455e57b87016ffc69d426/pyrefly-0.57.1.tar.gz", hash = "sha256:b05f6f5ee3a6a5d502ca19d84cb9ab62d67f05083819964a48c1510f2993efc6", size = 5310800, upload-time = "2026-03-18T18:42:35.614Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/39/b0/16e50cf716784513648e23e726a24f71f9544aa4f86103032dcaa5ff71a2/pyrefly-0.55.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:49aafcefe5e2dd4256147db93e5b0ada42bff7d9a60db70e03d1f7055338eec9", size = 12210073, upload-time = "2026-03-03T00:46:15.51Z" },
{ url = "https://files.pythonhosted.org/packages/3a/ad/89500c01bac3083383011600370289fbc67700c5be46e781787392628a3a/pyrefly-0.55.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2827426e6b28397c13badb93c0ede0fb0f48046a7a89e3d774cda04e8e2067cd", size = 11767474, upload-time = "2026-03-03T00:46:18.003Z" },
{ url = "https://files.pythonhosted.org/packages/78/68/4c66b260f817f304ead11176ff13985625f7c269e653304b4bdb546551af/pyrefly-0.55.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7346b2d64dc575bd61aa3bca854fbf8b5a19a471cbdb45e0ca1e09861b63488c", size = 33260395, upload-time = "2026-03-03T00:46:20.509Z" },
{ url = "https://files.pythonhosted.org/packages/47/09/10bd48c9f860064f29f412954126a827d60f6451512224912c265e26bbe6/pyrefly-0.55.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:233b861b4cff008b1aff62f4f941577ed752e4d0060834229eb9b6826e6973c9", size = 35848269, upload-time = "2026-03-03T00:46:23.418Z" },
{ url = "https://files.pythonhosted.org/packages/a9/39/bc65cdd5243eb2dfea25dd1321f9a5a93e8d9c3a308501c4c6c05d011585/pyrefly-0.55.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5aa85657d76da1d25d081a49f0e33c8fc3ec91c1a0f185a8ed393a5a3d9e178", size = 38449820, upload-time = "2026-03-03T00:46:26.309Z" },
{ url = "https://files.pythonhosted.org/packages/e5/64/58b38963b011af91209e87f868cc85cfc762ec49a4568ce610c45e7a5f40/pyrefly-0.55.0-py3-none-win32.whl", hash = "sha256:23f786a78536a56fed331b245b7d10ec8945bebee7b723491c8d66fdbc155fe6", size = 11259415, upload-time = "2026-03-03T00:46:30.875Z" },
{ url = "https://files.pythonhosted.org/packages/7a/0b/a4aa519ff632a1ea69eec942566951670b870b99b5c08407e1387b85b6a4/pyrefly-0.55.0-py3-none-win_amd64.whl", hash = "sha256:d465b49e999b50eeb069ad23f0f5710651cad2576f9452a82991bef557df91ee", size = 12043581, upload-time = "2026-03-03T00:46:33.674Z" },
{ url = "https://files.pythonhosted.org/packages/f1/51/89017636fbe1ffd166ad478990c6052df615b926182fa6d3c0842b407e89/pyrefly-0.55.0-py3-none-win_arm64.whl", hash = "sha256:732ff490e0e863b296e7c0b2471e08f8ba7952f9fa6e9de09d8347fd67dde77f", size = 11548076, upload-time = "2026-03-03T00:46:36.193Z" },
{ url = "https://files.pythonhosted.org/packages/b7/58/8af37856c8d45b365ece635a6728a14b0356b08d1ff1ac601d7120def1e0/pyrefly-0.57.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:91974bfbe951eebf5a7bc959c1f3921f0371c789cad84761511d695e9ab2265f", size = 12681847, upload-time = "2026-03-18T18:42:10.963Z" },
{ url = "https://files.pythonhosted.org/packages/5f/d7/fae6dd9d0355fc5b8df7793f1423b7433ca8e10b698ea934c35f0e4e6522/pyrefly-0.57.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:808087298537c70f5e7cdccb5bbaad482e7e056e947c0adf00fb612cbace9fdc", size = 12219634, upload-time = "2026-03-18T18:42:13.469Z" },
{ url = "https://files.pythonhosted.org/packages/29/8f/9511ae460f0690e837b9ba0f7e5e192079e16ff9a9ba8a272450e81f11f8/pyrefly-0.57.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b01f454fa5539e070c0cba17ddec46b3d2107d571d519bd8eca8f3142ba02a6", size = 34947757, upload-time = "2026-03-18T18:42:17.152Z" },
{ url = "https://files.pythonhosted.org/packages/07/43/f053bf9c65218f70e6a49561e9942c7233f8c3e4da8d42e5fe2aae50b3d2/pyrefly-0.57.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02ad59ea722191f51635f23e37574662116b82ca9d814529f7cb5528f041f381", size = 37621018, upload-time = "2026-03-18T18:42:20.79Z" },
{ url = "https://files.pythonhosted.org/packages/0e/76/9cea46de01665bbc125e4f215340c9365c8d56cda6198ff238a563ea8e75/pyrefly-0.57.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54bc0afe56776145e37733ff763e7e9679ee8a76c467b617dc3f227d4124a9e2", size = 40203649, upload-time = "2026-03-18T18:42:24.519Z" },
{ url = "https://files.pythonhosted.org/packages/fd/8b/2fb4a96d75e2a57df698a43e2970e441ba2704e3906cdc0386a055daa05a/pyrefly-0.57.1-py3-none-win32.whl", hash = "sha256:468e5839144b25bb0dce839bfc5fd879c9f38e68ebf5de561f30bed9ae19d8ca", size = 11732953, upload-time = "2026-03-18T18:42:27.379Z" },
{ url = "https://files.pythonhosted.org/packages/13/5a/4a197910fe2e9b102b15ae5e7687c45b7b5981275a11a564b41e185dd907/pyrefly-0.57.1-py3-none-win_amd64.whl", hash = "sha256:46db9c97093673c4fb7fab96d610e74d140661d54688a92d8e75ad885a56c141", size = 12537319, upload-time = "2026-03-18T18:42:30.196Z" },
{ url = "https://files.pythonhosted.org/packages/b5/c6/bc442874be1d9b63da1f9debb4f04b7d0c590a8dc4091921f3c288207242/pyrefly-0.57.1-py3-none-win_arm64.whl", hash = "sha256:feb1bbe3b0d8d5a70121dcdf1476e6a99cc056a26a49379a156f040729244dcb", size = 12013455, upload-time = "2026-03-18T18:42:32.928Z" },
]
[[package]]