# dify/api/tests/unit_tests/services/segment_service.py
#
# 1115 lines, 45 KiB, Python
# (repository viewer page header: Raw / Normal View / History)

from unittest.mock import MagicMock, Mock, patch
import pytest
from models.account import Account
from models.dataset import ChildChunk, Dataset, Document, DocumentSegment
from models.enums import SegmentType
from services.dataset_service import SegmentService
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
class SegmentTestDataFactory:
    """Factory for the mock model objects used by the segment service tests.

    Each ``create_*_mock`` helper returns a ``Mock`` built with ``spec`` set to
    the corresponding model class and pre-populated with default attribute
    values. Extra keyword arguments are applied last with ``setattr``, so a
    test can override any default or attach additional attributes.
    """

    @staticmethod
    def _apply_overrides(mock_obj: Mock, overrides: dict[str, object]) -> Mock:
        """Set every item of ``overrides`` as an attribute on ``mock_obj`` and return it."""
        for key, value in overrides.items():
            setattr(mock_obj, key, value)
        return mock_obj

    @staticmethod
    def create_segment_mock(
        segment_id: str = "segment-123",
        document_id: str = "doc-123",
        dataset_id: str = "dataset-123",
        tenant_id: str = "tenant-123",
        content: str = "Test segment content",
        position: int = 1,
        enabled: bool = True,
        status: str = "completed",
        word_count: int = 3,
        tokens: int = 5,
        **kwargs,
    ) -> Mock:
        """Create a mock segment (``spec=DocumentSegment``) with the given attributes.

        ``index_node_id`` is derived from ``segment_id`` as ``node-<segment_id>``.
        Any ``kwargs`` are set as attributes after the defaults.
        """
        segment = Mock(spec=DocumentSegment)
        segment.id = segment_id
        segment.document_id = document_id
        segment.dataset_id = dataset_id
        segment.tenant_id = tenant_id
        segment.content = content
        segment.position = position
        segment.enabled = enabled
        segment.status = status
        segment.word_count = word_count
        segment.tokens = tokens
        segment.index_node_id = f"node-{segment_id}"
        segment.index_node_hash = "hash-123"
        segment.keywords = []
        segment.answer = None
        segment.disabled_at = None
        segment.disabled_by = None
        segment.updated_by = None
        segment.updated_at = None
        segment.indexing_at = None
        segment.completed_at = None
        segment.error = None
        return SegmentTestDataFactory._apply_overrides(segment, kwargs)

    @staticmethod
    def create_child_chunk_mock(
        chunk_id: str = "chunk-123",
        segment_id: str = "segment-123",
        document_id: str = "doc-123",
        dataset_id: str = "dataset-123",
        tenant_id: str = "tenant-123",
        content: str = "Test child chunk content",
        position: int = 1,
        word_count: int = 3,
        **kwargs,
    ) -> Mock:
        """Create a mock child chunk (``spec=ChildChunk``) with the given attributes.

        ``index_node_id`` is derived from ``chunk_id`` as ``node-<chunk_id>`` and
        ``type`` defaults to ``SegmentType.AUTOMATIC``. Any ``kwargs`` are set as
        attributes after the defaults.
        """
        chunk = Mock(spec=ChildChunk)
        chunk.id = chunk_id
        chunk.segment_id = segment_id
        chunk.document_id = document_id
        chunk.dataset_id = dataset_id
        chunk.tenant_id = tenant_id
        chunk.content = content
        chunk.position = position
        chunk.word_count = word_count
        chunk.index_node_id = f"node-{chunk_id}"
        chunk.index_node_hash = "hash-123"
        chunk.type = SegmentType.AUTOMATIC
        chunk.created_by = "user-123"
        chunk.updated_by = None
        chunk.updated_at = None
        return SegmentTestDataFactory._apply_overrides(chunk, kwargs)

    @staticmethod
    def create_document_mock(
        document_id: str = "doc-123",
        dataset_id: str = "dataset-123",
        tenant_id: str = "tenant-123",
        doc_form: str = "text_model",
        word_count: int = 100,
        **kwargs,
    ) -> Mock:
        """Create a mock document (``spec=Document``) with the given attributes.

        Any ``kwargs`` are set as attributes after the defaults.
        """
        document = Mock(spec=Document)
        document.id = document_id
        document.dataset_id = dataset_id
        document.tenant_id = tenant_id
        document.doc_form = doc_form
        document.word_count = word_count
        return SegmentTestDataFactory._apply_overrides(document, kwargs)

    @staticmethod
    def create_dataset_mock(
        dataset_id: str = "dataset-123",
        tenant_id: str = "tenant-123",
        indexing_technique: str = "high_quality",
        embedding_model: str = "text-embedding-ada-002",
        embedding_model_provider: str = "openai",
        **kwargs,
    ) -> Mock:
        """Create a mock dataset (``spec=Dataset``) with the given attributes.

        Any ``kwargs`` are set as attributes after the defaults.
        """
        dataset = Mock(spec=Dataset)
        dataset.id = dataset_id
        dataset.tenant_id = tenant_id
        dataset.indexing_technique = indexing_technique
        dataset.embedding_model = embedding_model
        dataset.embedding_model_provider = embedding_model_provider
        return SegmentTestDataFactory._apply_overrides(dataset, kwargs)

    @staticmethod
    def create_user_mock(
        user_id: str = "user-789",
        tenant_id: str = "tenant-123",
        **kwargs,
    ) -> Mock:
        """Create a mock user (``spec=Account``) with the given attributes.

        ``tenant_id`` is stored on the mock as ``current_tenant_id``. Any
        ``kwargs`` are set as attributes after the defaults.
        """
        user = Mock(spec=Account)
        user.id = user_id
        user.current_tenant_id = tenant_id
        user.name = "Test User"
        return SegmentTestDataFactory._apply_overrides(user, kwargs)
class TestSegmentServiceCreateSegment:
    """Tests for SegmentService.create_segment method."""

    @pytest.fixture
    def mock_db_session(self):
        """Patch ``services.dataset_service.db.session`` and yield the mock session."""
        with patch("services.dataset_service.db.session", autospec=True) as mock_db:
            yield mock_db

    @pytest.fixture
    def mock_current_user(self):
        """Patch ``services.dataset_service.current_user`` with a mock account and yield it."""
        user = SegmentTestDataFactory.create_user_mock()
        with patch("services.dataset_service.current_user", user):
            yield user

    def test_create_segment_success(self, mock_db_session, mock_current_user):
        """Test successful creation of a segment."""
        # Arrange
        document = SegmentTestDataFactory.create_document_mock(word_count=100)
        dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
        args = {"content": "New segment content", "keywords": ["test", "segment"]}

        mock_query = MagicMock()
        mock_query.where.return_value.scalar.return_value = None  # No existing segments
        mock_db_session.query.return_value = mock_query

        # The same query chain answers ``.first()`` with the segment the test
        # expects create_segment to return.
        mock_segment = SegmentTestDataFactory.create_segment_mock()
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment

        with (
            patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
            patch(
                "services.dataset_service.VectorService.create_segments_vector", autospec=True
            ) as mock_vector_service,
            patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
            patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
        ):
            # Make the patched lock usable as a context manager.
            mock_lock.return_value.__enter__ = Mock()
            mock_lock.return_value.__exit__ = Mock(return_value=None)
            mock_hash.return_value = "hash-123"
            mock_now.return_value = "2024-01-01T00:00:00"

            # Act
            result = SegmentService.create_segment(args, document, dataset)

            # Assert
            assert mock_db_session.add.call_count == 2
            created_segment = mock_db_session.add.call_args_list[0].args[0]
            assert isinstance(created_segment, DocumentSegment)
            assert created_segment.content == args["content"]
            assert created_segment.word_count == len(args["content"])
            mock_db_session.commit.assert_called_once()
            mock_vector_service.assert_called_once()
            vector_call_args = mock_vector_service.call_args[0]
            assert vector_call_args[0] == [args["keywords"]]
            assert vector_call_args[1][0] == created_segment
            assert vector_call_args[2] == dataset
            assert vector_call_args[3] == document.doc_form
            assert result == mock_segment

    def test_create_segment_with_qa_model(self, mock_db_session, mock_current_user):
        """Test creation of segment with QA model (requires answer)."""
        # Arrange
        document = SegmentTestDataFactory.create_document_mock(doc_form="qa_model", word_count=100)
        dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
        args = {"content": "What is AI?", "answer": "AI is Artificial Intelligence", "keywords": ["ai"]}

        mock_query = MagicMock()
        mock_query.where.return_value.scalar.return_value = None
        mock_db_session.query.return_value = mock_query

        mock_segment = SegmentTestDataFactory.create_segment_mock()
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment

        with (
            patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
            patch("services.dataset_service.VectorService.create_segments_vector", autospec=True),
            patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
            patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
        ):
            # Make the patched lock usable as a context manager.
            mock_lock.return_value.__enter__ = Mock()
            mock_lock.return_value.__exit__ = Mock(return_value=None)
            mock_hash.return_value = "hash-123"
            mock_now.return_value = "2024-01-01T00:00:00"

            # Act
            result = SegmentService.create_segment(args, document, dataset)

            # Assert
            assert result == mock_segment
            mock_db_session.add.assert_called()
            mock_db_session.commit.assert_called()

    def test_create_segment_with_high_quality_indexing(self, mock_db_session, mock_current_user):
        """Test creation of segment with high quality indexing technique."""
        # Arrange
        document = SegmentTestDataFactory.create_document_mock(word_count=100)
        dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="high_quality")
        args = {"content": "New segment content", "keywords": ["test"]}

        mock_query = MagicMock()
        mock_query.where.return_value.scalar.return_value = None
        mock_db_session.query.return_value = mock_query

        # Stand-in embedding model, handed out by the patched ModelManager below.
        mock_embedding_model = MagicMock()
        mock_embedding_model.get_text_embedding_num_tokens.return_value = [10]
        mock_model_manager = MagicMock()
        mock_model_manager.get_model_instance.return_value = mock_embedding_model

        mock_segment = SegmentTestDataFactory.create_segment_mock()
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment

        with (
            patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
            patch("services.dataset_service.VectorService.create_segments_vector", autospec=True),
            patch("services.dataset_service.ModelManager", autospec=True) as mock_model_manager_class,
            patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
            patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
        ):
            # Make the patched lock usable as a context manager.
            mock_lock.return_value.__enter__ = Mock()
            mock_lock.return_value.__exit__ = Mock(return_value=None)
            mock_model_manager_class.return_value = mock_model_manager
            mock_hash.return_value = "hash-123"
            mock_now.return_value = "2024-01-01T00:00:00"

            # Act
            result = SegmentService.create_segment(args, document, dataset)

            # Assert
            assert result == mock_segment
            mock_model_manager.get_model_instance.assert_called_once()
            mock_embedding_model.get_text_embedding_num_tokens.assert_called_once()

    def test_create_segment_vector_index_failure(self, mock_db_session, mock_current_user):
        """Test segment creation when vector indexing fails."""
        # Arrange
        document = SegmentTestDataFactory.create_document_mock(word_count=100)
        dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
        args = {"content": "New segment content", "keywords": ["test"]}

        mock_query = MagicMock()
        mock_query.where.return_value.scalar.return_value = None
        mock_db_session.query.return_value = mock_query

        mock_segment = SegmentTestDataFactory.create_segment_mock(enabled=False, status="error")
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment

        with (
            patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
            patch(
                "services.dataset_service.VectorService.create_segments_vector", autospec=True
            ) as mock_vector_service,
            patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
            patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
        ):
            # Make the patched lock usable as a context manager.
            mock_lock.return_value.__enter__ = Mock()
            mock_lock.return_value.__exit__ = Mock(return_value=None)
            # Simulate the vector index step raising.
            mock_vector_service.side_effect = Exception("Vector indexing failed")
            mock_hash.return_value = "hash-123"
            mock_now.return_value = "2024-01-01T00:00:00"

            # Act
            result = SegmentService.create_segment(args, document, dataset)

            # Assert
            assert result == mock_segment
            assert mock_db_session.commit.call_count == 2  # Once for creation, once for error update
class TestSegmentServiceUpdateSegment:
    """Tests for SegmentService.update_segment method."""

    @pytest.fixture
    def mock_db_session(self):
        """Patch ``services.dataset_service.db.session`` and yield the mock session."""
        with patch("services.dataset_service.db.session", autospec=True) as mock_db:
            yield mock_db

    @pytest.fixture
    def mock_current_user(self):
        """Patch ``services.dataset_service.current_user`` with a mock account and yield it."""
        user = SegmentTestDataFactory.create_user_mock()
        with patch("services.dataset_service.current_user", user):
            yield user

    def test_update_segment_content_success(self, mock_db_session, mock_current_user):
        """Test successful update of segment content."""
        # Arrange
        segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10)
        document = SegmentTestDataFactory.create_document_mock(word_count=100)
        dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
        args = SegmentUpdateArgs(content="Updated content", keywords=["updated"])

        mock_db_session.query.return_value.where.return_value.first.return_value = segment

        with (
            patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
            patch("services.dataset_service.VectorService.update_segment_vector", autospec=True),
            patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
            patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
        ):
            mock_redis_get.return_value = None  # Not indexing
            mock_hash.return_value = "new-hash"
            mock_now.return_value = "2024-01-01T00:00:00"

            # Act
            result = SegmentService.update_segment(args, segment, document, dataset)

            # Assert
            assert result == segment
            assert segment.content == "Updated content"
            assert segment.keywords == ["updated"]
            assert segment.word_count == len("Updated content")
            # Document total shifts by the difference between new and old segment counts.
            assert document.word_count == 100 + (len("Updated content") - 10)
            mock_db_session.add.assert_called()
            mock_db_session.commit.assert_called()

    def test_update_segment_disable(self, mock_db_session, mock_current_user):
        """Test disabling a segment."""
        # Arrange
        segment = SegmentTestDataFactory.create_segment_mock(enabled=True)
        document = SegmentTestDataFactory.create_document_mock()
        dataset = SegmentTestDataFactory.create_dataset_mock()
        args = SegmentUpdateArgs(enabled=False)

        with (
            patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
            patch("services.dataset_service.redis_client.setex", autospec=True),
            patch("services.dataset_service.disable_segment_from_index_task", autospec=True) as mock_task,
            patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
        ):
            mock_redis_get.return_value = None
            mock_now.return_value = "2024-01-01T00:00:00"

            # Act
            result = SegmentService.update_segment(args, segment, document, dataset)

            # Assert
            assert result == segment
            assert segment.enabled is False
            mock_db_session.add.assert_called()
            mock_db_session.commit.assert_called()
            mock_task.delay.assert_called_once()

    def test_update_segment_indexing_in_progress(self, mock_db_session, mock_current_user):
        """Test update fails when segment is currently indexing."""
        # Arrange
        segment = SegmentTestDataFactory.create_segment_mock(enabled=True)
        document = SegmentTestDataFactory.create_document_mock()
        dataset = SegmentTestDataFactory.create_dataset_mock()
        args = SegmentUpdateArgs(content="Updated content")

        with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get:
            mock_redis_get.return_value = "1"  # Indexing in progress

            # Act & Assert
            with pytest.raises(ValueError, match="Segment is indexing"):
                SegmentService.update_segment(args, segment, document, dataset)

    def test_update_segment_disabled_segment(self, mock_db_session, mock_current_user):
        """Test update fails when segment is disabled."""
        # Arrange
        segment = SegmentTestDataFactory.create_segment_mock(enabled=False)
        document = SegmentTestDataFactory.create_document_mock()
        dataset = SegmentTestDataFactory.create_dataset_mock()
        args = SegmentUpdateArgs(content="Updated content")

        with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get:
            mock_redis_get.return_value = None

            # Act & Assert
            with pytest.raises(ValueError, match="Can't update disabled segment"):
                SegmentService.update_segment(args, segment, document, dataset)

    def test_update_segment_with_qa_model(self, mock_db_session, mock_current_user):
        """Test update segment with QA model (includes answer)."""
        # Arrange
        segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10)
        document = SegmentTestDataFactory.create_document_mock(doc_form="qa_model", word_count=100)
        dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
        args = SegmentUpdateArgs(content="Updated question", answer="Updated answer", keywords=["qa"])

        mock_db_session.query.return_value.where.return_value.first.return_value = segment

        with (
            patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
            patch("services.dataset_service.VectorService.update_segment_vector", autospec=True),
            patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
            patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
        ):
            mock_redis_get.return_value = None
            mock_hash.return_value = "new-hash"
            mock_now.return_value = "2024-01-01T00:00:00"

            # Act
            result = SegmentService.update_segment(args, segment, document, dataset)

            # Assert
            assert result == segment
            assert segment.content == "Updated question"
            assert segment.answer == "Updated answer"
            assert segment.keywords == ["qa"]
            # For the QA form the expected count is question length plus answer length.
            new_word_count = len("Updated question") + len("Updated answer")
            assert segment.word_count == new_word_count
            assert document.word_count == 100 + (new_word_count - 10)
            mock_db_session.commit.assert_called()
class TestSegmentServiceDeleteSegment:
    """Tests for SegmentService.delete_segment method."""

    @pytest.fixture
    def mock_db_session(self):
        """Patch ``services.dataset_service.db.session`` and yield the mock session."""
        with patch("services.dataset_service.db.session", autospec=True) as mock_db:
            yield mock_db

    def test_delete_segment_success(self, mock_db_session):
        """Test successful deletion of a segment."""
        # Arrange
        segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=50)
        document = SegmentTestDataFactory.create_document_mock(word_count=100)
        dataset = SegmentTestDataFactory.create_dataset_mock()

        # ``db.session.scalars(...).all()`` yields an empty list.
        mock_scalars = MagicMock()
        mock_scalars.all.return_value = []
        mock_db_session.scalars.return_value = mock_scalars

        with (
            patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
            patch("services.dataset_service.redis_client.setex", autospec=True),
            patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task,
            patch("services.dataset_service.select", autospec=True) as mock_select,
        ):
            mock_redis_get.return_value = None
            mock_select.return_value.where.return_value = mock_select

            # Act
            SegmentService.delete_segment(segment, document, dataset)

            # Assert
            mock_db_session.delete.assert_called_once_with(segment)
            mock_db_session.commit.assert_called_once()
            mock_task.delay.assert_called_once()

    def test_delete_segment_disabled(self, mock_db_session):
        """Test deletion of disabled segment (no index deletion)."""
        # Arrange
        segment = SegmentTestDataFactory.create_segment_mock(enabled=False, word_count=50)
        document = SegmentTestDataFactory.create_document_mock(word_count=100)
        dataset = SegmentTestDataFactory.create_dataset_mock()

        with (
            patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
            patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task,
        ):
            mock_redis_get.return_value = None

            # Act
            SegmentService.delete_segment(segment, document, dataset)

            # Assert
            mock_db_session.delete.assert_called_once_with(segment)
            mock_db_session.commit.assert_called_once()
            mock_task.delay.assert_not_called()

    def test_delete_segment_indexing_in_progress(self, mock_db_session):
        """Test deletion fails when segment is currently being deleted."""
        # Arrange
        segment = SegmentTestDataFactory.create_segment_mock(enabled=True)
        document = SegmentTestDataFactory.create_document_mock()
        dataset = SegmentTestDataFactory.create_dataset_mock()

        with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get:
            mock_redis_get.return_value = "1"  # Deletion in progress

            # Act & Assert
            with pytest.raises(ValueError, match="Segment is deleting"):
                SegmentService.delete_segment(segment, document, dataset)
class TestSegmentServiceDeleteSegments:
    """Tests for SegmentService.delete_segments method."""

    @pytest.fixture
    def mock_db_session(self):
        """Patch ``services.dataset_service.db.session`` and yield the mock session."""
        with patch("services.dataset_service.db.session", autospec=True) as mock_db:
            yield mock_db

    @pytest.fixture
    def mock_current_user(self):
        """Patch ``services.dataset_service.current_user`` with a mock account and yield it."""
        user = SegmentTestDataFactory.create_user_mock()
        with patch("services.dataset_service.current_user", user):
            yield user

    def test_delete_segments_success(self, mock_db_session, mock_current_user):
        """Test successful deletion of multiple segments."""
        # Arrange
        segment_ids = ["segment-1", "segment-2"]
        document = SegmentTestDataFactory.create_document_mock(word_count=200)
        dataset = SegmentTestDataFactory.create_dataset_mock()

        # Rows returned by ``query(...).with_entities(...).where(...).all()``.
        segments_info = [
            ("node-1", "segment-1", 50),
            ("node-2", "segment-2", 30),
        ]

        mock_query = MagicMock()
        mock_query.with_entities.return_value.where.return_value.all.return_value = segments_info
        mock_db_session.query.return_value = mock_query

        mock_scalars = MagicMock()
        mock_scalars.all.return_value = []
        # ``select(...).where(...)`` returns the same stand-in object.
        mock_select = MagicMock()
        mock_select.where.return_value = mock_select
        mock_db_session.scalars.return_value = mock_scalars

        with (
            patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task,
            patch("services.dataset_service.select", autospec=True) as mock_select_func,
        ):
            mock_select_func.return_value = mock_select

            # Act
            SegmentService.delete_segments(segment_ids, document, dataset)

            # Assert
            mock_db_session.query.return_value.where.return_value.delete.assert_called_once()
            mock_db_session.commit.assert_called_once()
            mock_task.delay.assert_called_once()

    def test_delete_segments_empty_list(self, mock_db_session, mock_current_user):
        """Test deletion with empty list (should return early)."""
        # Arrange
        document = SegmentTestDataFactory.create_document_mock()
        dataset = SegmentTestDataFactory.create_dataset_mock()

        # Act
        SegmentService.delete_segments([], document, dataset)

        # Assert
        mock_db_session.query.assert_not_called()
class TestSegmentServiceUpdateSegmentsStatus:
    """Tests for SegmentService.update_segments_status method."""

    @pytest.fixture
    def mock_db_session(self):
        """Patch ``services.dataset_service.db.session`` and yield the mock session."""
        with patch("services.dataset_service.db.session", autospec=True) as mock_db:
            yield mock_db

    @pytest.fixture
    def mock_current_user(self):
        """Patch ``services.dataset_service.current_user`` with a mock account and yield it."""
        user = SegmentTestDataFactory.create_user_mock()
        with patch("services.dataset_service.current_user", user):
            yield user

    def test_update_segments_status_enable(self, mock_db_session, mock_current_user):
        """Test enabling multiple segments."""
        # Arrange
        segment_ids = ["segment-1", "segment-2"]
        document = SegmentTestDataFactory.create_document_mock()
        dataset = SegmentTestDataFactory.create_dataset_mock()

        segments = [
            SegmentTestDataFactory.create_segment_mock(segment_id="segment-1", enabled=False),
            SegmentTestDataFactory.create_segment_mock(segment_id="segment-2", enabled=False),
        ]

        # ``db.session.scalars(...).all()`` yields the disabled segments.
        mock_scalars = MagicMock()
        mock_scalars.all.return_value = segments
        mock_select = MagicMock()
        mock_select.where.return_value = mock_select
        mock_db_session.scalars.return_value = mock_scalars

        with (
            patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
            patch("services.dataset_service.enable_segments_to_index_task", autospec=True) as mock_task,
            patch("services.dataset_service.select", autospec=True) as mock_select_func,
        ):
            mock_redis_get.return_value = None
            mock_select_func.return_value = mock_select

            # Act
            SegmentService.update_segments_status(segment_ids, "enable", dataset, document)

            # Assert
            assert all(seg.enabled is True for seg in segments)
            mock_db_session.commit.assert_called_once()
            mock_task.delay.assert_called_once()

    def test_update_segments_status_disable(self, mock_db_session, mock_current_user):
        """Test disabling multiple segments."""
        # Arrange
        segment_ids = ["segment-1", "segment-2"]
        document = SegmentTestDataFactory.create_document_mock()
        dataset = SegmentTestDataFactory.create_dataset_mock()

        segments = [
            SegmentTestDataFactory.create_segment_mock(segment_id="segment-1", enabled=True),
            SegmentTestDataFactory.create_segment_mock(segment_id="segment-2", enabled=True),
        ]

        # ``db.session.scalars(...).all()`` yields the enabled segments.
        mock_scalars = MagicMock()
        mock_scalars.all.return_value = segments
        mock_select = MagicMock()
        mock_select.where.return_value = mock_select
        mock_db_session.scalars.return_value = mock_scalars

        with (
            patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
            patch("services.dataset_service.disable_segments_from_index_task", autospec=True) as mock_task,
            patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
            patch("services.dataset_service.select", autospec=True) as mock_select_func,
        ):
            mock_redis_get.return_value = None
            mock_now.return_value = "2024-01-01T00:00:00"
            mock_select_func.return_value = mock_select

            # Act
            SegmentService.update_segments_status(segment_ids, "disable", dataset, document)

            # Assert
            assert all(seg.enabled is False for seg in segments)
            mock_db_session.commit.assert_called_once()
            mock_task.delay.assert_called_once()

    def test_update_segments_status_empty_list(self, mock_db_session, mock_current_user):
        """Test update with empty list (should return early)."""
        # Arrange
        document = SegmentTestDataFactory.create_document_mock()
        dataset = SegmentTestDataFactory.create_dataset_mock()

        # Act
        SegmentService.update_segments_status([], "enable", dataset, document)

        # Assert
        mock_db_session.scalars.assert_not_called()
class TestSegmentServiceGetSegments:
    """Tests for SegmentService.get_segments method."""

    @pytest.fixture
    def mock_db_session(self):
        """Patch ``services.dataset_service.db.session`` and yield the mock session."""
        with patch("services.dataset_service.db.session", autospec=True) as mock_db:
            yield mock_db

    @pytest.fixture
    def mock_current_user(self):
        """Patch ``services.dataset_service.current_user`` with a mock account and yield it."""
        user = SegmentTestDataFactory.create_user_mock()
        with patch("services.dataset_service.current_user", user):
            yield user

    def test_get_segments_success(self, mock_db_session, mock_current_user):
        """Test successful retrieval of segments."""
        # Arrange
        document_id = "doc-123"
        tenant_id = "tenant-123"
        segments = [
            SegmentTestDataFactory.create_segment_mock(segment_id="segment-1"),
            SegmentTestDataFactory.create_segment_mock(segment_id="segment-2"),
        ]

        # Pagination result exposing ``items`` and ``total``.
        mock_paginate = MagicMock()
        mock_paginate.items = segments
        mock_paginate.total = 2
        mock_db_session.paginate.return_value = mock_paginate

        # Act
        items, total = SegmentService.get_segments(document_id, tenant_id)

        # Assert
        assert len(items) == 2
        assert total == 2
        mock_db_session.paginate.assert_called_once()

    def test_get_segments_with_status_filter(self, mock_db_session, mock_current_user):
        """Test retrieval with status filter."""
        # Arrange
        document_id = "doc-123"
        tenant_id = "tenant-123"
        status_list = ["completed", "error"]

        mock_paginate = MagicMock()
        mock_paginate.items = []
        mock_paginate.total = 0
        mock_db_session.paginate.return_value = mock_paginate

        # Act
        items, total = SegmentService.get_segments(document_id, tenant_id, status_list=status_list)

        # Assert
        assert len(items) == 0
        assert total == 0

    def test_get_segments_with_keyword(self, mock_db_session, mock_current_user):
        """Test retrieval with keyword search."""
        # Arrange
        document_id = "doc-123"
        tenant_id = "tenant-123"
        keyword = "test"

        mock_paginate = MagicMock()
        mock_paginate.items = [SegmentTestDataFactory.create_segment_mock()]
        mock_paginate.total = 1
        mock_db_session.paginate.return_value = mock_paginate

        # Act
        items, total = SegmentService.get_segments(document_id, tenant_id, keyword=keyword)

        # Assert
        assert len(items) == 1
        assert total == 1
class TestSegmentServiceGetSegmentById:
    """Tests for SegmentService.get_segment_by_id method."""

    @pytest.fixture
    def mock_db_session(self):
        """Patch ``services.dataset_service.db.session`` and yield the mock session."""
        with patch("services.dataset_service.db.session", autospec=True) as mock_db:
            yield mock_db

    def test_get_segment_by_id_success(self, mock_db_session):
        """Test successful retrieval of segment by ID."""
        # Arrange
        segment_id = "segment-123"
        tenant_id = "tenant-123"
        segment = SegmentTestDataFactory.create_segment_mock(segment_id=segment_id)

        # ``query(...).where(...).first()`` yields the segment.
        mock_query = MagicMock()
        mock_query.where.return_value.first.return_value = segment
        mock_db_session.query.return_value = mock_query

        # Act
        result = SegmentService.get_segment_by_id(segment_id, tenant_id)

        # Assert
        assert result == segment

    def test_get_segment_by_id_not_found(self, mock_db_session):
        """Test retrieval when segment is not found."""
        # Arrange
        segment_id = "non-existent"
        tenant_id = "tenant-123"

        # ``query(...).where(...).first()`` yields ``None``.
        mock_query = MagicMock()
        mock_query.where.return_value.first.return_value = None
        mock_db_session.query.return_value = mock_query

        # Act
        result = SegmentService.get_segment_by_id(segment_id, tenant_id)

        # Assert
        assert result is None
class TestSegmentServiceGetChildChunks:
    """Tests for SegmentService.get_child_chunks method."""

    @pytest.fixture
    def mock_db_session(self):
        """Patch ``services.dataset_service.db.session`` and yield the mock session."""
        with patch("services.dataset_service.db.session", autospec=True) as mock_db:
            yield mock_db

    @pytest.fixture
    def mock_current_user(self):
        """Patch ``services.dataset_service.current_user`` with a mock account and yield it."""
        user = SegmentTestDataFactory.create_user_mock()
        with patch("services.dataset_service.current_user", user):
            yield user

    def test_get_child_chunks_success(self, mock_db_session, mock_current_user):
        """Test successful retrieval of child chunks."""
        # Arrange
        segment_id = "segment-123"
        document_id = "doc-123"
        dataset_id = "dataset-123"
        page = 1
        limit = 20

        # Pagination result the test expects get_child_chunks to return as-is.
        mock_paginate = MagicMock()
        mock_paginate.items = [
            SegmentTestDataFactory.create_child_chunk_mock(chunk_id="chunk-1"),
            SegmentTestDataFactory.create_child_chunk_mock(chunk_id="chunk-2"),
        ]
        mock_paginate.total = 2
        mock_db_session.paginate.return_value = mock_paginate

        # Act
        result = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit)

        # Assert
        assert result == mock_paginate
        mock_db_session.paginate.assert_called_once()

    def test_get_child_chunks_with_keyword(self, mock_db_session, mock_current_user):
        """Test retrieval with keyword search."""
        # Arrange
        segment_id = "segment-123"
        document_id = "doc-123"
        dataset_id = "dataset-123"
        page = 1
        limit = 20
        keyword = "test"

        mock_paginate = MagicMock()
        mock_paginate.items = []
        mock_paginate.total = 0
        mock_db_session.paginate.return_value = mock_paginate

        # Act
        result = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword=keyword)

        # Assert
        assert result == mock_paginate
class TestSegmentServiceGetChildChunkById:
    """Tests for SegmentService.get_child_chunk_by_id method."""

    @pytest.fixture
    def mock_db_session(self):
        """Patch ``services.dataset_service.db.session`` and yield the mock session."""
        with patch("services.dataset_service.db.session", autospec=True) as mock_db:
            yield mock_db

    def test_get_child_chunk_by_id_success(self, mock_db_session):
        """Test successful retrieval of child chunk by ID."""
        # Arrange
        chunk_id = "chunk-123"
        tenant_id = "tenant-123"
        chunk = SegmentTestDataFactory.create_child_chunk_mock(chunk_id=chunk_id)

        # ``query(...).where(...).first()`` yields the chunk.
        mock_query = MagicMock()
        mock_query.where.return_value.first.return_value = chunk
        mock_db_session.query.return_value = mock_query

        # Act
        result = SegmentService.get_child_chunk_by_id(chunk_id, tenant_id)

        # Assert
        assert result == chunk

    def test_get_child_chunk_by_id_not_found(self, mock_db_session):
        """Test retrieval when child chunk is not found."""
        # Arrange
        chunk_id = "non-existent"
        tenant_id = "tenant-123"

        # ``query(...).where(...).first()`` yields ``None``.
        mock_query = MagicMock()
        mock_query.where.return_value.first.return_value = None
        mock_db_session.query.return_value = mock_query

        # Act
        result = SegmentService.get_child_chunk_by_id(chunk_id, tenant_id)

        # Assert
        assert result is None
class TestSegmentServiceCreateChildChunk:
    """Unit tests covering SegmentService.create_child_chunk."""

    @pytest.fixture
    def mock_db_session(self):
        """Yield an autospecced replacement for the dataset service's db session."""
        with patch("services.dataset_service.db.session", autospec=True) as session_mock:
            yield session_mock

    @pytest.fixture
    def mock_current_user(self):
        """Swap the dataset service's ``current_user`` for a factory-built user mock."""
        fake_user = SegmentTestDataFactory.create_user_mock()
        with patch("services.dataset_service.current_user", fake_user):
            yield fake_user

    @staticmethod
    def _stub_scalar_result(session_mock, value):
        """Make ``session.query(...).where(...).scalar()`` return ``value``."""
        query_mock = MagicMock()
        query_mock.where.return_value.scalar.return_value = value
        session_mock.query.return_value = query_mock

    @staticmethod
    def _stub_lock_context(lock_mock):
        """Attach explicit ``__enter__``/``__exit__`` mocks to the patched lock's return value."""
        lock_context = lock_mock.return_value
        lock_context.__enter__ = Mock()
        lock_context.__exit__ = Mock(return_value=None)

    def test_create_child_chunk_success(self, mock_db_session, mock_current_user):
        """A successful create adds and commits once and calls the vector service once."""
        # Arrange
        new_content = "New child chunk content"
        parent_segment = SegmentTestDataFactory.create_segment_mock()
        parent_document = SegmentTestDataFactory.create_document_mock()
        parent_dataset = SegmentTestDataFactory.create_dataset_mock()
        self._stub_scalar_result(mock_db_session, None)

        with (
            patch("services.dataset_service.redis_client.lock", autospec=True) as lock_mock,
            patch(
                "services.dataset_service.VectorService.create_child_chunk_vector", autospec=True
            ) as vector_mock,
            patch("services.dataset_service.helper.generate_text_hash", autospec=True) as hash_mock,
        ):
            self._stub_lock_context(lock_mock)
            hash_mock.return_value = "hash-123"

            # Act
            created = SegmentService.create_child_chunk(
                new_content, parent_segment, parent_document, parent_dataset
            )

            # Assert
            assert created is not None
            mock_db_session.add.assert_called_once()
            mock_db_session.commit.assert_called_once()
            vector_mock.assert_called_once()

    def test_create_child_chunk_vector_index_failure(self, mock_db_session, mock_current_user):
        """A vector service failure surfaces as ChildChunkIndexingError and rolls back once."""
        # Arrange
        new_content = "New child chunk content"
        parent_segment = SegmentTestDataFactory.create_segment_mock()
        parent_document = SegmentTestDataFactory.create_document_mock()
        parent_dataset = SegmentTestDataFactory.create_dataset_mock()
        self._stub_scalar_result(mock_db_session, None)

        with (
            patch("services.dataset_service.redis_client.lock", autospec=True) as lock_mock,
            patch(
                "services.dataset_service.VectorService.create_child_chunk_vector", autospec=True
            ) as vector_mock,
            patch("services.dataset_service.helper.generate_text_hash", autospec=True) as hash_mock,
        ):
            self._stub_lock_context(lock_mock)
            vector_mock.side_effect = Exception("Vector indexing failed")
            hash_mock.return_value = "hash-123"

            # Act & Assert
            with pytest.raises(ChildChunkIndexingError):
                SegmentService.create_child_chunk(
                    new_content, parent_segment, parent_document, parent_dataset
                )

            # Checked after the raises block so the assertion actually runs.
            mock_db_session.rollback.assert_called_once()
class TestSegmentServiceUpdateChildChunk:
    """Unit tests covering SegmentService.update_child_chunk."""

    @pytest.fixture
    def mock_db_session(self):
        """Yield an autospecced replacement for the dataset service's db session."""
        with patch("services.dataset_service.db.session", autospec=True) as session_mock:
            yield session_mock

    @pytest.fixture
    def mock_current_user(self):
        """Swap the dataset service's ``current_user`` for a factory-built user mock."""
        fake_user = SegmentTestDataFactory.create_user_mock()
        with patch("services.dataset_service.current_user", fake_user):
            yield fake_user

    def test_update_child_chunk_success(self, mock_db_session, mock_current_user):
        """A successful update rewrites the chunk, persists it, and calls the vector service."""
        # Arrange
        new_content = "Updated child chunk content"
        target_chunk = SegmentTestDataFactory.create_child_chunk_mock()
        parent_segment = SegmentTestDataFactory.create_segment_mock()
        parent_document = SegmentTestDataFactory.create_document_mock()
        parent_dataset = SegmentTestDataFactory.create_dataset_mock()

        with (
            patch(
                "services.dataset_service.VectorService.update_child_chunk_vector", autospec=True
            ) as vector_mock,
            patch("services.dataset_service.naive_utc_now", autospec=True) as now_mock,
        ):
            now_mock.return_value = "2024-01-01T00:00:00"

            # Act
            updated = SegmentService.update_child_chunk(
                new_content, target_chunk, parent_segment, parent_document, parent_dataset
            )

            # Assert: the same chunk comes back with its content and word count refreshed
            assert updated == target_chunk
            assert target_chunk.content == new_content
            assert target_chunk.word_count == len(new_content)
            mock_db_session.add.assert_called_once_with(target_chunk)
            mock_db_session.commit.assert_called_once()
            vector_mock.assert_called_once()

    def test_update_child_chunk_vector_index_failure(self, mock_db_session, mock_current_user):
        """A vector service failure surfaces as ChildChunkIndexingError and rolls back once."""
        # Arrange
        new_content = "Updated content"
        target_chunk = SegmentTestDataFactory.create_child_chunk_mock()
        parent_segment = SegmentTestDataFactory.create_segment_mock()
        parent_document = SegmentTestDataFactory.create_document_mock()
        parent_dataset = SegmentTestDataFactory.create_dataset_mock()

        with (
            patch(
                "services.dataset_service.VectorService.update_child_chunk_vector", autospec=True
            ) as vector_mock,
            patch("services.dataset_service.naive_utc_now", autospec=True) as now_mock,
        ):
            vector_mock.side_effect = Exception("Vector indexing failed")
            now_mock.return_value = "2024-01-01T00:00:00"

            # Act & Assert
            with pytest.raises(ChildChunkIndexingError):
                SegmentService.update_child_chunk(
                    new_content, target_chunk, parent_segment, parent_document, parent_dataset
                )

            # Checked after the raises block so the assertion actually runs.
            mock_db_session.rollback.assert_called_once()
class TestSegmentServiceDeleteChildChunk:
    """Unit tests covering SegmentService.delete_child_chunk."""

    @pytest.fixture
    def mock_db_session(self):
        """Yield an autospecced replacement for the dataset service's db session."""
        with patch("services.dataset_service.db.session", autospec=True) as session_mock:
            yield session_mock

    def test_delete_child_chunk_success(self, mock_db_session):
        """A successful delete removes the chunk, commits once, and calls the vector service."""
        # Arrange
        doomed_chunk = SegmentTestDataFactory.create_child_chunk_mock()
        owning_dataset = SegmentTestDataFactory.create_dataset_mock()

        with patch(
            "services.dataset_service.VectorService.delete_child_chunk_vector", autospec=True
        ) as vector_mock:
            # Act
            SegmentService.delete_child_chunk(doomed_chunk, owning_dataset)

            # Assert
            mock_db_session.delete.assert_called_once_with(doomed_chunk)
            mock_db_session.commit.assert_called_once()
            vector_mock.assert_called_once_with(doomed_chunk, owning_dataset)

    def test_delete_child_chunk_vector_index_failure(self, mock_db_session):
        """A vector service failure surfaces as ChildChunkDeleteIndexError and rolls back once."""
        # Arrange
        doomed_chunk = SegmentTestDataFactory.create_child_chunk_mock()
        owning_dataset = SegmentTestDataFactory.create_dataset_mock()

        with patch(
            "services.dataset_service.VectorService.delete_child_chunk_vector", autospec=True
        ) as vector_mock:
            vector_mock.side_effect = Exception("Vector deletion failed")

            # Act & Assert
            with pytest.raises(ChildChunkDeleteIndexError):
                SegmentService.delete_child_chunk(doomed_chunk, owning_dataset)

            # Checked after the raises block so the assertion actually runs.
            mock_db_session.rollback.assert_called_once()