mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into optional-plugin-invoke
This commit is contained in:
commit
c22b58d46c
|
|
@ -1,8 +1,11 @@
|
|||
"""Testcontainers integration tests for email register controller endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.email_register import (
|
||||
EmailRegisterCheckApi,
|
||||
|
|
@ -13,14 +16,11 @@ from services.account_service import AccountService
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
def app(flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
|
||||
class TestEmailRegisterSendEmailApi:
|
||||
@patch("controllers.console.auth.email_register.Session")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.email_register.AccountService.send_email_register_email")
|
||||
@patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze")
|
||||
|
|
@ -33,20 +33,15 @@ class TestEmailRegisterSendEmailApi:
|
|||
mock_is_freeze,
|
||||
mock_send_mail,
|
||||
mock_get_account,
|
||||
mock_session_cls,
|
||||
app,
|
||||
):
|
||||
mock_send_mail.return_value = "token-123"
|
||||
mock_is_freeze.return_value = False
|
||||
mock_account = MagicMock()
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
mock_get_account.return_value = mock_account
|
||||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
|
||||
patch("controllers.console.auth.email_register.dify_config", SimpleNamespace(BILLING_ENABLED=True)),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
|
||||
|
|
@ -61,7 +56,6 @@ class TestEmailRegisterSendEmailApi:
|
|||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_is_freeze.assert_called_once_with("invitee@example.com")
|
||||
mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US")
|
||||
mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
|
||||
|
|
@ -89,7 +83,6 @@ class TestEmailRegisterCheckApi:
|
|||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
|
||||
):
|
||||
|
|
@ -114,7 +107,6 @@ class TestEmailRegisterResetApi:
|
|||
@patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.email_register.AccountService.login")
|
||||
@patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
|
||||
@patch("controllers.console.auth.email_register.Session")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
|
||||
|
|
@ -125,7 +117,6 @@ class TestEmailRegisterResetApi:
|
|||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_get_account,
|
||||
mock_session_cls,
|
||||
mock_create_account,
|
||||
mock_login,
|
||||
mock_reset_login_rate,
|
||||
|
|
@ -136,14 +127,10 @@ class TestEmailRegisterResetApi:
|
|||
token_pair = MagicMock()
|
||||
token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
|
||||
mock_login.return_value = token_pair
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
mock_get_account.return_value = None
|
||||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
|
||||
):
|
||||
|
|
@ -159,19 +146,19 @@ class TestEmailRegisterResetApi:
|
|||
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
|
||||
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
|
||||
def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
|
||||
"""Test that case fallback tries lowercase when exact match fails."""
|
||||
mock_session = MagicMock()
|
||||
first_query = MagicMock()
|
||||
first_query.scalar_one_or_none.return_value = None
|
||||
first_result = MagicMock()
|
||||
first_result.scalar_one_or_none.return_value = None
|
||||
expected_account = MagicMock()
|
||||
second_query = MagicMock()
|
||||
second_query.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_query, second_query]
|
||||
second_result = MagicMock()
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
|
||||
|
||||
assert account is expected_account
|
||||
assert result is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
|
@ -1,8 +1,11 @@
|
|||
"""Testcontainers integration tests for forgot password controller endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.forgot_password import (
|
||||
ForgotPasswordCheckApi,
|
||||
|
|
@ -13,14 +16,11 @@ from services.account_service import AccountService
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
def app(flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
|
||||
class TestForgotPasswordSendEmailApi:
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
|
|
@ -31,19 +31,15 @@ class TestForgotPasswordSendEmailApi:
|
|||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_account,
|
||||
mock_session_cls,
|
||||
app,
|
||||
):
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_email.return_value = "token-123"
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
controller_features = SimpleNamespace(is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
|
||||
patch(
|
||||
"controllers.console.auth.forgot_password.FeatureService.get_system_features",
|
||||
return_value=controller_features,
|
||||
|
|
@ -59,7 +55,6 @@ class TestForgotPasswordSendEmailApi:
|
|||
response = ForgotPasswordSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
||||
mock_send_email.assert_called_once_with(
|
||||
account=mock_account,
|
||||
email="user@example.com",
|
||||
|
|
@ -117,7 +112,6 @@ class TestForgotPasswordCheckApi:
|
|||
|
||||
class TestForgotPasswordResetApi:
|
||||
@patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
|
|
@ -126,7 +120,6 @@ class TestForgotPasswordResetApi:
|
|||
mock_get_reset_data,
|
||||
mock_revoke_token,
|
||||
mock_get_account,
|
||||
mock_session_cls,
|
||||
mock_update_account,
|
||||
app,
|
||||
):
|
||||
|
|
@ -134,12 +127,8 @@ class TestForgotPasswordResetApi:
|
|||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||
with (
|
||||
patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
):
|
||||
|
|
@ -157,20 +146,22 @@ class TestForgotPasswordResetApi:
|
|||
assert response == {"result": "success"}
|
||||
mock_get_reset_data.assert_called_once_with("token-123")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
||||
mock_update_account.assert_called_once()
|
||||
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
|
||||
def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
|
||||
"""Test that case fallback tries lowercase when exact match fails."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_session = MagicMock()
|
||||
first_query = MagicMock()
|
||||
first_query.scalar_one_or_none.return_value = None
|
||||
first_result = MagicMock()
|
||||
first_result.scalar_one_or_none.return_value = None
|
||||
expected_account = MagicMock()
|
||||
second_query = MagicMock()
|
||||
second_query.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_query, second_query]
|
||||
second_result = MagicMock()
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
|
||||
|
||||
assert account is expected_account
|
||||
assert result is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
|
@ -707,3 +707,104 @@ class TestDatasetServiceRetrievalConfiguration:
|
|||
db_session_with_containers.refresh(dataset)
|
||||
assert result.id == dataset.id
|
||||
assert dataset.retrieval_model == update_data["retrieval_model"]
|
||||
|
||||
|
||||
class TestDocumentServicePauseRecoverRetry:
|
||||
"""Tests for pause/recover/retry orchestration using real DB and Redis."""
|
||||
|
||||
def _create_indexing_document(self, db_session_with_containers, indexing_status="indexing"):
|
||||
factory = DatasetServiceIntegrationDataFactory
|
||||
account, tenant = factory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id)
|
||||
doc = factory.create_document(db_session_with_containers, dataset, account.id)
|
||||
doc.indexing_status = indexing_status
|
||||
db_session_with_containers.commit()
|
||||
return doc, account
|
||||
|
||||
def test_pause_document_success(self, db_session_with_containers):
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="indexing")
|
||||
|
||||
with patch("services.dataset_service.current_user") as mock_user:
|
||||
mock_user.id = account.id
|
||||
DocumentService.pause_document(doc)
|
||||
|
||||
db_session_with_containers.refresh(doc)
|
||||
assert doc.is_paused is True
|
||||
assert doc.paused_by == account.id
|
||||
assert doc.paused_at is not None
|
||||
|
||||
cache_key = f"document_{doc.id}_is_paused"
|
||||
assert redis_client.get(cache_key) is not None
|
||||
redis_client.delete(cache_key)
|
||||
|
||||
def test_pause_document_invalid_status_error(self, db_session_with_containers):
|
||||
from services.dataset_service import DocumentService
|
||||
from services.errors.document import DocumentIndexingError
|
||||
|
||||
doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="completed")
|
||||
|
||||
with patch("services.dataset_service.current_user") as mock_user:
|
||||
mock_user.id = account.id
|
||||
with pytest.raises(DocumentIndexingError):
|
||||
DocumentService.pause_document(doc)
|
||||
|
||||
def test_recover_document_success(self, db_session_with_containers):
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="indexing")
|
||||
|
||||
# Pause first
|
||||
with patch("services.dataset_service.current_user") as mock_user:
|
||||
mock_user.id = account.id
|
||||
DocumentService.pause_document(doc)
|
||||
|
||||
# Recover
|
||||
with patch("services.dataset_service.recover_document_indexing_task") as recover_task:
|
||||
DocumentService.recover_document(doc)
|
||||
|
||||
db_session_with_containers.refresh(doc)
|
||||
assert doc.is_paused is False
|
||||
assert doc.paused_by is None
|
||||
assert doc.paused_at is None
|
||||
|
||||
cache_key = f"document_{doc.id}_is_paused"
|
||||
assert redis_client.get(cache_key) is None
|
||||
recover_task.delay.assert_called_once_with(doc.dataset_id, doc.id)
|
||||
|
||||
def test_retry_document_indexing_success(self, db_session_with_containers):
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
factory = DatasetServiceIntegrationDataFactory
|
||||
account, tenant = factory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id)
|
||||
doc1 = factory.create_document(db_session_with_containers, dataset, account.id, name="doc1.txt")
|
||||
doc2 = factory.create_document(db_session_with_containers, dataset, account.id, name="doc2.txt")
|
||||
doc2.position = 2
|
||||
doc1.indexing_status = "error"
|
||||
doc2.indexing_status = "error"
|
||||
db_session_with_containers.commit()
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.current_user") as mock_user,
|
||||
patch("services.dataset_service.retry_document_indexing_task") as retry_task,
|
||||
):
|
||||
mock_user.id = account.id
|
||||
DocumentService.retry_document(dataset.id, [doc1, doc2])
|
||||
|
||||
db_session_with_containers.refresh(doc1)
|
||||
db_session_with_containers.refresh(doc2)
|
||||
assert doc1.indexing_status == "waiting"
|
||||
assert doc2.indexing_status == "waiting"
|
||||
|
||||
# Verify redis keys were set
|
||||
assert redis_client.get(f"document_{doc1.id}_is_retried") is not None
|
||||
assert redis_client.get(f"document_{doc2.id}_is_retried") is not None
|
||||
retry_task.delay.assert_called_once_with(dataset.id, [doc1.id, doc2.id], account.id)
|
||||
|
||||
# Cleanup
|
||||
redis_client.delete(f"document_{doc1.id}_is_retried", f"document_{doc2.id}_is_retried")
|
||||
|
|
|
|||
|
|
@ -1,129 +0,0 @@
|
|||
"""Unit tests for non-SQL DocumentService orchestration behaviors.
|
||||
|
||||
This file intentionally keeps only collaborator-oriented document indexing
|
||||
orchestration tests. SQL-backed dataset lifecycle cases are covered by
|
||||
integration tests under testcontainers.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from models.dataset import Document
|
||||
from services.errors.document import DocumentIndexingError
|
||||
|
||||
|
||||
class DatasetServiceUnitDataFactory:
|
||||
"""Factory for creating lightweight document doubles used in unit tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_document_mock(
|
||||
document_id: str = "doc-123",
|
||||
dataset_id: str = "dataset-123",
|
||||
indexing_status: str = "completed",
|
||||
is_paused: bool = False,
|
||||
) -> Mock:
|
||||
"""Create a document-shaped mock for DocumentService orchestration tests."""
|
||||
document = Mock(spec=Document)
|
||||
document.id = document_id
|
||||
document.dataset_id = dataset_id
|
||||
document.indexing_status = indexing_status
|
||||
document.is_paused = is_paused
|
||||
document.paused_by = None
|
||||
document.paused_at = None
|
||||
return document
|
||||
|
||||
|
||||
class TestDatasetServiceDocumentIndexing:
|
||||
"""Unit tests for pause/recover/retry orchestration without SQL assertions."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_service_dependencies(self):
|
||||
"""Patch non-SQL collaborators used by DocumentService methods."""
|
||||
with (
|
||||
patch("services.dataset_service.redis_client") as mock_redis,
|
||||
patch("services.dataset_service.db.session") as mock_db,
|
||||
patch("services.dataset_service.current_user") as mock_current_user,
|
||||
):
|
||||
mock_current_user.id = "user-123"
|
||||
yield {
|
||||
"redis_client": mock_redis,
|
||||
"db_session": mock_db,
|
||||
"current_user": mock_current_user,
|
||||
}
|
||||
|
||||
def test_pause_document_success(self, mock_document_service_dependencies):
|
||||
"""Pause a document that is currently in an indexable status."""
|
||||
# Arrange
|
||||
document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing")
|
||||
|
||||
# Act
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
DocumentService.pause_document(document)
|
||||
|
||||
# Assert
|
||||
assert document.is_paused is True
|
||||
assert document.paused_by == "user-123"
|
||||
mock_document_service_dependencies["db_session"].add.assert_called_once_with(document)
|
||||
mock_document_service_dependencies["db_session"].commit.assert_called_once()
|
||||
mock_document_service_dependencies["redis_client"].setnx.assert_called_once_with(
|
||||
f"document_{document.id}_is_paused",
|
||||
"True",
|
||||
)
|
||||
|
||||
def test_pause_document_invalid_status_error(self, mock_document_service_dependencies):
|
||||
"""Raise DocumentIndexingError when pausing a completed document."""
|
||||
# Arrange
|
||||
document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="completed")
|
||||
|
||||
# Act / Assert
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
with pytest.raises(DocumentIndexingError):
|
||||
DocumentService.pause_document(document)
|
||||
|
||||
def test_recover_document_success(self, mock_document_service_dependencies):
|
||||
"""Recover a paused document and dispatch the recover indexing task."""
|
||||
# Arrange
|
||||
document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing", is_paused=True)
|
||||
|
||||
# Act
|
||||
with patch("services.dataset_service.recover_document_indexing_task") as recover_task:
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
DocumentService.recover_document(document)
|
||||
|
||||
# Assert
|
||||
assert document.is_paused is False
|
||||
assert document.paused_by is None
|
||||
assert document.paused_at is None
|
||||
mock_document_service_dependencies["db_session"].add.assert_called_once_with(document)
|
||||
mock_document_service_dependencies["db_session"].commit.assert_called_once()
|
||||
mock_document_service_dependencies["redis_client"].delete.assert_called_once_with(
|
||||
f"document_{document.id}_is_paused"
|
||||
)
|
||||
recover_task.delay.assert_called_once_with(document.dataset_id, document.id)
|
||||
|
||||
def test_retry_document_indexing_success(self, mock_document_service_dependencies):
|
||||
"""Reset documents to waiting state and dispatch retry indexing task."""
|
||||
# Arrange
|
||||
dataset_id = "dataset-123"
|
||||
documents = [
|
||||
DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1", indexing_status="error"),
|
||||
DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-2", indexing_status="error"),
|
||||
]
|
||||
mock_document_service_dependencies["redis_client"].get.return_value = None
|
||||
|
||||
# Act
|
||||
with patch("services.dataset_service.retry_document_indexing_task") as retry_task:
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
DocumentService.retry_document(dataset_id, documents)
|
||||
|
||||
# Assert
|
||||
assert all(document.indexing_status == "waiting" for document in documents)
|
||||
assert mock_document_service_dependencies["db_session"].add.call_count == 2
|
||||
assert mock_document_service_dependencies["db_session"].commit.call_count == 2
|
||||
assert mock_document_service_dependencies["redis_client"].setex.call_count == 2
|
||||
retry_task.delay.assert_called_once_with(dataset_id, ["doc-1", "doc-2"], "user-123")
|
||||
Loading…
Reference in New Issue