# Mirror of https://github.com/langgenius/dify.git
# 1017 lines, 46 KiB, Python
"""Unit tests for SegmentService behaviors in dataset_service."""
|
|
|
|
from .dataset_service_test_helpers import (
|
|
Account,
|
|
ChildChunk,
|
|
ChildChunkDeleteIndexError,
|
|
ChildChunkIndexingError,
|
|
ChildChunkUpdateArgs,
|
|
DocumentSegment,
|
|
IndexStructureType,
|
|
MagicMock,
|
|
SegmentService,
|
|
SegmentUpdateArgs,
|
|
SimpleNamespace,
|
|
_make_child_chunk,
|
|
_make_dataset,
|
|
_make_document,
|
|
_make_lock_context,
|
|
_make_segment,
|
|
create_autospec,
|
|
patch,
|
|
pytest,
|
|
)
|
|
|
|
|
|
class TestSegmentServiceChildChunks:
    """Exercise the create, update, and delete helpers for child chunks."""

    @pytest.fixture
    def account_context(self):
        """Yield an autospecced account patched in as the service's ``current_user``."""
        user = create_autospec(Account, instance=True)
        user.id = "user-1"
        user.current_tenant_id = "tenant-1"

        with patch("services.dataset_service.current_user", user):
            yield user

    def test_create_child_chunk_assigns_next_position_and_commits(self, account_context):
        """With a stored max position of 2, the new chunk lands at position 3 and is committed."""
        parent_segment = _make_segment_pair = None  # placeholder removed below
        del parent_segment, _make_segment_pair
        dataset = SimpleNamespace(id="dataset-1")
        document = _make_document()
        segment = _make_segment()

        with (
            patch("services.dataset_service.redis_client") as redis_stub,
            patch("services.dataset_service.db") as db_stub,
            patch("services.dataset_service.uuid.uuid4", return_value="node-1"),
            patch("services.dataset_service.helper.generate_text_hash", return_value="hash-1"),
            patch("services.dataset_service.VectorService") as vector_stub,
        ):
            redis_stub.lock.return_value = _make_lock_context()
            # The max-position lookup reports 2 existing child chunks.
            db_stub.session.query.return_value.where.return_value.scalar.return_value = 2

            created = SegmentService.create_child_chunk("child content", segment, document, dataset)

        assert isinstance(created, ChildChunk)
        assert created.position == 3
        assert created.index_node_id == "node-1"
        assert created.index_node_hash == "hash-1"
        assert created.word_count == len("child content")
        db_stub.session.add.assert_called_once_with(created)
        vector_stub.create_child_chunk_vector.assert_called_once_with(created, dataset)
        db_stub.session.commit.assert_called_once()

    def test_create_child_chunk_rolls_back_and_raises_on_vector_failure(self, account_context):
        """A vector-store failure during create surfaces as ChildChunkIndexingError after a rollback."""
        dataset = SimpleNamespace(id="dataset-1")
        document = _make_document()
        segment = _make_segment()

        with (
            patch("services.dataset_service.redis_client") as redis_stub,
            patch("services.dataset_service.db") as db_stub,
            patch("services.dataset_service.uuid.uuid4", return_value="node-1"),
            patch("services.dataset_service.helper.generate_text_hash", return_value="hash-1"),
            patch("services.dataset_service.VectorService") as vector_stub,
        ):
            redis_stub.lock.return_value = _make_lock_context()
            # No existing child chunks: the max-position lookup yields None.
            db_stub.session.query.return_value.where.return_value.scalar.return_value = None
            vector_stub.create_child_chunk_vector.side_effect = RuntimeError("vector failed")

            with pytest.raises(ChildChunkIndexingError, match="vector failed"):
                SegmentService.create_child_chunk("child content", segment, document, dataset)

        db_stub.session.rollback.assert_called_once()
        db_stub.session.commit.assert_not_called()

    def test_update_child_chunks_updates_deletes_and_creates_records(self, account_context):
        """One call updates a listed chunk, deletes the unlisted one, and creates the id-less one."""
        dataset = SimpleNamespace(id="dataset-1")
        document = _make_document()
        segment = _make_segment()
        # Chunk that the update payload references by id.
        kept_chunk = ChildChunk(
            id="child-a",
            tenant_id="tenant-1",
            dataset_id="dataset-1",
            document_id="doc-1",
            segment_id="segment-1",
            position=1,
            content="old content",
            word_count=11,
            created_by="user-1",
        )
        # Chunk absent from the update payload.
        dropped_chunk = ChildChunk(
            id="child-b",
            tenant_id="tenant-1",
            dataset_id="dataset-1",
            document_id="doc-1",
            segment_id="segment-1",
            position=2,
            content="remove me",
            word_count=9,
            created_by="user-1",
        )

        with (
            patch("services.dataset_service.db") as db_stub,
            patch("services.dataset_service.uuid.uuid4", return_value="node-new"),
            patch("services.dataset_service.helper.generate_text_hash", return_value="hash-new"),
            patch("services.dataset_service.naive_utc_now", return_value="now"),
            patch("services.dataset_service.VectorService") as vector_stub,
        ):
            db_stub.session.scalars.return_value.all.return_value = [kept_chunk, dropped_chunk]
            update_payload = [
                ChildChunkUpdateArgs(id="child-a", content="updated content"),
                ChildChunkUpdateArgs(content="brand new"),
            ]

            result = SegmentService.update_child_chunks(update_payload, segment, document, dataset)

        assert [chunk.position for chunk in result] == [1, 3]
        assert kept_chunk.content == "updated content"
        assert kept_chunk.updated_by == account_context.id
        assert kept_chunk.updated_at == "now"
        db_stub.session.bulk_save_objects.assert_called_once_with([kept_chunk])
        db_stub.session.delete.assert_called_once_with(dropped_chunk)
        added_chunk = result[1]
        assert isinstance(added_chunk, ChildChunk)
        assert added_chunk.position == 3
        assert added_chunk.index_node_id == "node-new"
        vector_stub.update_child_chunk_vector.assert_called_once_with(
            [added_chunk], [kept_chunk], [dropped_chunk], dataset
        )
        db_stub.session.commit.assert_called_once()

    def test_update_child_chunks_rolls_back_on_vector_failure(self, account_context):
        """A vector-store failure during a bulk update raises ChildChunkIndexingError and rolls back."""
        dataset = SimpleNamespace(id="dataset-1")
        document = _make_document()
        segment = _make_segment()
        stored_chunk = _make_child_chunk()

        with (
            patch("services.dataset_service.db") as db_stub,
            patch("services.dataset_service.naive_utc_now", return_value="now"),
            patch("services.dataset_service.VectorService") as vector_stub,
        ):
            db_stub.session.scalars.return_value.all.return_value = [stored_chunk]
            vector_stub.update_child_chunk_vector.side_effect = RuntimeError("vector failed")
            update_payload = [ChildChunkUpdateArgs(id="child-a", content="updated content")]

            with pytest.raises(ChildChunkIndexingError, match="vector failed"):
                SegmentService.update_child_chunks(update_payload, segment, document, dataset)

        db_stub.session.rollback.assert_called_once()

    def test_update_child_chunk_updates_vector_and_commits(self, account_context):
        """Updating one chunk rewrites its content and audit fields, re-indexes it, and commits."""
        dataset = SimpleNamespace(id="dataset-1")
        target_chunk = _make_child_chunk()

        with (
            patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")),
            patch("services.dataset_service.db") as db_stub,
            patch("services.dataset_service.naive_utc_now", return_value="now"),
            patch("services.dataset_service.VectorService") as vector_stub,
        ):
            segment = _make_segment()
            document = _make_document()
            result = SegmentService.update_child_chunk("new content", target_chunk, segment, document, dataset)

        assert result is target_chunk
        assert target_chunk.content == "new content"
        assert target_chunk.word_count == len("new content")
        assert target_chunk.updated_by == "user-1"
        assert target_chunk.updated_at == "now"
        db_stub.session.add.assert_called_once_with(target_chunk)
        vector_stub.update_child_chunk_vector.assert_called_once_with([], [target_chunk], [], dataset)
        db_stub.session.commit.assert_called_once()

    def test_delete_child_chunk_raises_delete_index_error_on_vector_failure(self):
        """A vector-store failure during delete raises ChildChunkDeleteIndexError and rolls back."""
        dataset = SimpleNamespace(id="dataset-1")
        doomed_chunk = _make_child_chunk()

        with (
            patch("services.dataset_service.db") as db_stub,
            patch("services.dataset_service.VectorService") as vector_stub,
        ):
            vector_stub.delete_child_chunk_vector.side_effect = RuntimeError("delete failed")

            with pytest.raises(ChildChunkDeleteIndexError, match="delete failed"):
                SegmentService.delete_child_chunk(doomed_chunk, dataset)

        db_stub.session.delete.assert_called_once_with(doomed_chunk)
        db_stub.session.rollback.assert_called_once()
|
|
|
|
|
|
class TestSegmentServiceQueries:
    """Exercise the read-only lookup helpers for child chunks and segments."""

    @pytest.fixture
    def account_context(self):
        """Yield an autospecced account patched in as the service's ``current_user``."""
        user = create_autospec(Account, instance=True)
        user.id = "user-1"
        user.current_tenant_id = "tenant-1"

        with patch("services.dataset_service.current_user", user):
            yield user

    def test_get_child_chunks_applies_keyword_filter_and_paginate(self, account_context):
        """A keyword search escapes the keyword once and returns the paginate result as-is."""
        page_result = SimpleNamespace(items=["chunk"], total=1)

        with (
            patch("services.dataset_service.db") as db_stub,
            patch("services.dataset_service.helper.escape_like_pattern", return_value="escaped") as escape_stub,
        ):
            db_stub.paginate.return_value = page_result

            result = SegmentService.get_child_chunks(
                segment_id="segment-1",
                document_id="doc-1",
                dataset_id="dataset-1",
                page=2,
                limit=10,
                keyword="needle",
            )

        assert result is page_result
        escape_stub.assert_called_once_with("needle")
        db_stub.paginate.assert_called_once()

    def test_get_child_chunk_by_id_returns_only_child_chunk_instances(self):
        """A ChildChunk row is returned unchanged; a row of any other type becomes None."""
        real_chunk = _make_child_chunk()

        # Case 1: the query hands back a genuine ChildChunk.
        with patch("services.dataset_service.db") as db_stub:
            db_stub.session.query.return_value.where.return_value.first.return_value = real_chunk
            found = SegmentService.get_child_chunk_by_id("child-a", "tenant-1")

        assert found is real_chunk

        # Case 2: the query hands back an object that is not a ChildChunk.
        with patch("services.dataset_service.db") as db_stub:
            db_stub.session.query.return_value.where.return_value.first.return_value = SimpleNamespace()
            found = SegmentService.get_child_chunk_by_id("child-a", "tenant-1")

        assert found is None

    def test_get_segments_uses_status_and_keyword_filters(self):
        """Filtering by status and keyword escapes the keyword and unpacks items and total."""
        page_result = SimpleNamespace(items=["segment"], total=1)

        with (
            patch("services.dataset_service.db") as db_stub,
            patch("services.dataset_service.helper.escape_like_pattern", return_value="escaped") as escape_stub,
        ):
            db_stub.paginate.return_value = page_result

            items, total = SegmentService.get_segments(
                document_id="doc-1",
                tenant_id="tenant-1",
                status_list=["completed"],
                keyword="needle",
                page=1,
                limit=20,
            )

        assert items == ["segment"]
        assert total == 1
        escape_stub.assert_called_once_with("needle")
        db_stub.paginate.assert_called_once()

    def test_get_segment_by_id_returns_only_document_segment_instances(self):
        """A DocumentSegment row is returned unchanged; a row of any other type becomes None."""
        real_segment = DocumentSegment(
            id="segment-1",
            tenant_id="tenant-1",
            dataset_id="dataset-1",
            document_id="doc-1",
            position=1,
            content="segment",
            word_count=7,
            tokens=2,
            created_by="user-1",
        )

        # Case 1: the query hands back a genuine DocumentSegment.
        with patch("services.dataset_service.db") as db_stub:
            db_stub.session.query.return_value.where.return_value.first.return_value = real_segment
            found = SegmentService.get_segment_by_id("segment-1", "tenant-1")

        assert found is real_segment

        # Case 2: the query hands back an object that is not a DocumentSegment.
        with patch("services.dataset_service.db") as db_stub:
            db_stub.session.query.return_value.where.return_value.first.return_value = SimpleNamespace()
            found = SegmentService.get_segment_by_id("segment-1", "tenant-1")

        assert found is None

    def test_get_segments_by_document_and_dataset_returns_scalars_result(self):
        """The helper returns whatever ``session.scalars(...).all()`` produces, using one scalars call."""
        stored_segment = DocumentSegment(
            id="segment-1",
            tenant_id="tenant-1",
            dataset_id="dataset-1",
            document_id="doc-1",
            position=1,
            content="segment",
            word_count=7,
            tokens=2,
            created_by="user-1",
        )

        with patch("services.dataset_service.db") as db_stub:
            db_stub.session.scalars.return_value.all.return_value = [stored_segment]

            result = SegmentService.get_segments_by_document_and_dataset(
                document_id="doc-1",
                dataset_id="dataset-1",
                status="completed",
                enabled=True,
            )

        assert result == [stored_segment]
        db_stub.session.scalars.assert_called_once()
|
|
|
|
|
|
class TestSegmentServiceValidation:
    """Exercise ``SegmentService.segment_create_args_validate`` rejection paths."""

    def test_segment_create_args_validate_requires_answer_for_qa_model(self):
        """A QA-form document rejects a payload that carries no answer."""
        qa_document = _make_document(doc_form=IndexStructureType.QA_INDEX)
        payload = {"content": "question"}

        with pytest.raises(ValueError, match="Answer is required"):
            SegmentService.segment_create_args_validate(payload, qa_document)

    def test_segment_create_args_validate_requires_non_empty_content(self):
        """Whitespace-only content is rejected as empty."""
        paragraph_document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX)
        payload = {"content": " "}

        with pytest.raises(ValueError, match="Content is empty"):
            SegmentService.segment_create_args_validate(payload, paragraph_document)

    def test_segment_create_args_validate_enforces_attachment_limit(self):
        """Two attachment ids are rejected when the configured per-chunk limit is 1."""
        paragraph_document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX)
        payload = {"content": "hello", "attachment_ids": ["a-1", "a-2"]}

        with (
            patch("services.dataset_service.dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT", 1),
            pytest.raises(ValueError, match="Exceeded maximum attachment limit of 1"),
        ):
            SegmentService.segment_create_args_validate(payload, paragraph_document)

    def test_segment_create_args_validate_requires_attachment_ids_list(self):
        """A non-list ``attachment_ids`` value is rejected as invalid."""
        paragraph_document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX)
        payload = {"content": "hello", "attachment_ids": "bad-type"}

        with pytest.raises(ValueError, match="Attachment IDs is invalid"):
            SegmentService.segment_create_args_validate(payload, paragraph_document)
|
|
|
|
|
|
class TestSegmentServiceMutations:
|
|
"""Unit tests for segment create, update, delete, and bulk status flows."""
|
|
|
|
@pytest.fixture
|
|
def account_context(self):
|
|
account = create_autospec(Account, instance=True)
|
|
account.id = "user-1"
|
|
account.current_tenant_id = "tenant-1"
|
|
|
|
with patch("services.dataset_service.current_user", account):
|
|
yield account
|
|
|
|
def test_create_segment_creates_bindings_and_marks_segment_error_on_vector_failure(self, account_context):
|
|
dataset = _make_dataset(indexing_technique="economy")
|
|
document = _make_document(
|
|
dataset_id=dataset.id,
|
|
tenant_id=dataset.tenant_id,
|
|
doc_form=IndexStructureType.QA_INDEX,
|
|
word_count=0,
|
|
)
|
|
refreshed_segment = SimpleNamespace(id="segment-1")
|
|
args = {
|
|
"content": "question",
|
|
"answer": "answer",
|
|
"keywords": ["kw-1"],
|
|
"attachment_ids": ["att-1", "att-2"],
|
|
}
|
|
|
|
with (
|
|
patch("services.dataset_service.redis_client") as mock_redis,
|
|
patch("services.dataset_service.db") as mock_db,
|
|
patch("services.dataset_service.VectorService") as vector_service,
|
|
patch("services.dataset_service.helper.generate_text_hash", return_value="hash-1"),
|
|
patch("services.dataset_service.uuid.uuid4", return_value="node-1"),
|
|
patch("services.dataset_service.naive_utc_now", return_value="now"),
|
|
):
|
|
mock_redis.lock.return_value = _make_lock_context()
|
|
|
|
max_position_query = MagicMock()
|
|
max_position_query.where.return_value.scalar.return_value = 2
|
|
refresh_query = MagicMock()
|
|
refresh_query.where.return_value.first.return_value = refreshed_segment
|
|
mock_db.session.query.side_effect = [max_position_query, refresh_query]
|
|
|
|
def add_side_effect(obj):
|
|
if obj.__class__.__name__ == "DocumentSegment" and getattr(obj, "id", None) is None:
|
|
obj.id = "segment-1"
|
|
|
|
mock_db.session.add.side_effect = add_side_effect
|
|
vector_service.create_segments_vector.side_effect = RuntimeError("vector failed")
|
|
|
|
result = SegmentService.create_segment(args=args, document=document, dataset=dataset)
|
|
|
|
created_segment = vector_service.create_segments_vector.call_args.args[1][0]
|
|
attachment_bindings = [
|
|
call.args[0]
|
|
for call in mock_db.session.add.call_args_list
|
|
if call.args and call.args[0].__class__.__name__ == "SegmentAttachmentBinding"
|
|
]
|
|
|
|
assert result is refreshed_segment
|
|
assert created_segment.position == 3
|
|
assert created_segment.answer == "answer"
|
|
assert created_segment.word_count == len("question") + len("answer")
|
|
assert created_segment.status == "error"
|
|
assert created_segment.enabled is False
|
|
assert created_segment.error == "vector failed"
|
|
assert document.word_count == len("question") + len("answer")
|
|
assert len(attachment_bindings) == 2
|
|
assert {binding.attachment_id for binding in attachment_bindings} == {"att-1", "att-2"}
|
|
assert mock_db.session.commit.call_count == 3
|
|
|
|
def test_multi_create_segment_high_quality_marks_segments_error_when_vector_creation_fails(self, account_context):
|
|
dataset = _make_dataset(indexing_technique="high_quality")
|
|
document = _make_document(
|
|
dataset_id=dataset.id,
|
|
tenant_id=dataset.tenant_id,
|
|
doc_form=IndexStructureType.QA_INDEX,
|
|
word_count=5,
|
|
)
|
|
segments = [
|
|
{"content": "question-1", "answer": "answer-1", "keywords": ["k1"]},
|
|
{"content": "question-2", "answer": "answer-2"},
|
|
]
|
|
embedding_model = MagicMock()
|
|
embedding_model.get_text_embedding_num_tokens.side_effect = [[11], [13]]
|
|
|
|
with (
|
|
patch("services.dataset_service.redis_client") as mock_redis,
|
|
patch("services.dataset_service.db") as mock_db,
|
|
patch("services.dataset_service.ModelManager") as model_manager_cls,
|
|
patch("services.dataset_service.VectorService") as vector_service,
|
|
patch("services.dataset_service.helper.generate_text_hash", side_effect=["hash-1", "hash-2"]),
|
|
patch("services.dataset_service.uuid.uuid4", side_effect=["node-1", "node-2"]),
|
|
patch("services.dataset_service.naive_utc_now", return_value="now"),
|
|
):
|
|
mock_redis.lock.return_value = _make_lock_context()
|
|
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
|
|
mock_db.session.query.return_value.where.return_value.scalar.return_value = 1
|
|
vector_service.create_segments_vector.side_effect = RuntimeError("vector failed")
|
|
|
|
result = SegmentService.multi_create_segment(segments, document, dataset)
|
|
|
|
assert len(result) == 2
|
|
assert [segment.position for segment in result] == [2, 3]
|
|
assert [segment.tokens for segment in result] == [11, 13]
|
|
assert all(segment.status == "error" for segment in result)
|
|
assert all(segment.enabled is False for segment in result)
|
|
assert all(segment.error == "vector failed" for segment in result)
|
|
assert document.word_count == 5 + sum(len(item["content"]) + len(item["answer"]) for item in segments)
|
|
vector_service.create_segments_vector.assert_called_once_with(
|
|
[["k1"], None], result, dataset, document.doc_form
|
|
)
|
|
mock_db.session.commit.assert_called_once()
|
|
|
|
def test_update_segment_disables_enabled_segment_and_dispatches_index_cleanup(self, account_context):
|
|
segment = _make_segment(enabled=True)
|
|
document = _make_document()
|
|
dataset = _make_dataset()
|
|
args = SegmentUpdateArgs(enabled=False)
|
|
|
|
with (
|
|
patch("services.dataset_service.redis_client") as mock_redis,
|
|
patch("services.dataset_service.db") as mock_db,
|
|
patch("services.dataset_service.naive_utc_now", return_value="now"),
|
|
patch("services.dataset_service.disable_segment_from_index_task") as disable_task,
|
|
):
|
|
mock_redis.get.return_value = None
|
|
|
|
result = SegmentService.update_segment(args, segment, document, dataset)
|
|
|
|
assert result is segment
|
|
assert segment.enabled is False
|
|
assert segment.disabled_at == "now"
|
|
assert segment.disabled_by == account_context.id
|
|
mock_db.session.add.assert_called_once_with(segment)
|
|
mock_db.session.commit.assert_called_once()
|
|
mock_redis.setex.assert_called_once_with(f"segment_{segment.id}_indexing", 600, 1)
|
|
disable_task.delay.assert_called_once_with(segment.id)
|
|
|
|
def test_update_segment_rejects_updates_for_disabled_segment(self, account_context):
|
|
segment = _make_segment(enabled=False)
|
|
document = _make_document()
|
|
dataset = _make_dataset()
|
|
|
|
with patch("services.dataset_service.redis_client") as mock_redis:
|
|
mock_redis.get.return_value = None
|
|
|
|
with pytest.raises(ValueError, match="Can't update disabled segment"):
|
|
SegmentService.update_segment(SegmentUpdateArgs(content="new content"), segment, document, dataset)
|
|
|
|
def test_update_segment_rejects_when_indexing_cache_exists(self, account_context):
|
|
segment = _make_segment(enabled=True)
|
|
document = _make_document()
|
|
dataset = _make_dataset()
|
|
|
|
with patch("services.dataset_service.redis_client") as mock_redis:
|
|
mock_redis.get.return_value = "1"
|
|
|
|
with pytest.raises(ValueError, match="Segment is indexing"):
|
|
SegmentService.update_segment(SegmentUpdateArgs(content="new content"), segment, document, dataset)
|
|
|
|
def test_update_segment_updates_keywords_for_same_content_segment(self, account_context):
|
|
segment = _make_segment(content="same content", keywords=["old"])
|
|
document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=20)
|
|
dataset = _make_dataset()
|
|
refreshed_segment = SimpleNamespace(id=segment.id)
|
|
args = SegmentUpdateArgs(content="same content", keywords=["new"])
|
|
|
|
with (
|
|
patch("services.dataset_service.redis_client") as mock_redis,
|
|
patch("services.dataset_service.db") as mock_db,
|
|
patch("services.dataset_service.VectorService") as vector_service,
|
|
):
|
|
mock_redis.get.return_value = None
|
|
mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment
|
|
|
|
result = SegmentService.update_segment(args, segment, document, dataset)
|
|
|
|
assert result is refreshed_segment
|
|
assert segment.keywords == ["new"]
|
|
vector_service.update_segment_vector.assert_called_once_with(["new"], segment, dataset)
|
|
vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset)
|
|
|
|
def test_update_segment_regenerates_child_chunks_and_updates_manual_summary(self, account_context):
|
|
segment = _make_segment(content="same content", word_count=len("same content"))
|
|
document = _make_document(
|
|
doc_form=IndexStructureType.PARENT_CHILD_INDEX,
|
|
word_count=20,
|
|
)
|
|
dataset = _make_dataset(indexing_technique="high_quality")
|
|
refreshed_segment = SimpleNamespace(id=segment.id)
|
|
processing_rule = SimpleNamespace(id=document.dataset_process_rule_id)
|
|
existing_summary = SimpleNamespace(summary_content="old summary")
|
|
embedding_model_instance = object()
|
|
args = SegmentUpdateArgs(
|
|
content="same content",
|
|
regenerate_child_chunks=True,
|
|
summary="new summary",
|
|
)
|
|
|
|
with (
|
|
patch("services.dataset_service.redis_client") as mock_redis,
|
|
patch("services.dataset_service.db") as mock_db,
|
|
patch("services.dataset_service.ModelManager") as model_manager_cls,
|
|
patch("services.dataset_service.VectorService") as vector_service,
|
|
patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary,
|
|
):
|
|
mock_redis.get.return_value = None
|
|
model_manager_cls.return_value.get_model_instance.return_value = embedding_model_instance
|
|
|
|
processing_rule_query = MagicMock()
|
|
processing_rule_query.where.return_value.first.return_value = processing_rule
|
|
summary_query = MagicMock()
|
|
summary_query.where.return_value.first.return_value = existing_summary
|
|
refreshed_query = MagicMock()
|
|
refreshed_query.where.return_value.first.return_value = refreshed_segment
|
|
mock_db.session.query.side_effect = [processing_rule_query, summary_query, refreshed_query]
|
|
|
|
result = SegmentService.update_segment(args, segment, document, dataset)
|
|
|
|
assert result is refreshed_segment
|
|
vector_service.generate_child_chunks.assert_called_once_with(
|
|
segment,
|
|
document,
|
|
dataset,
|
|
embedding_model_instance,
|
|
processing_rule,
|
|
True,
|
|
)
|
|
update_summary.assert_called_once_with(segment, dataset, "new summary")
|
|
vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset)
|
|
|
|
def test_update_segment_auto_regenerates_summary_after_content_change(self, account_context):
|
|
segment = _make_segment(content="old", word_count=3)
|
|
document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=10)
|
|
dataset = _make_dataset(indexing_technique="high_quality")
|
|
dataset.summary_index_setting = {"enable": True}
|
|
refreshed_segment = SimpleNamespace(id=segment.id)
|
|
existing_summary = SimpleNamespace(summary_content="old summary")
|
|
embedding_model = MagicMock()
|
|
embedding_model.get_text_embedding_num_tokens.return_value = [9]
|
|
args = SegmentUpdateArgs(content="new content", keywords=["kw-1"])
|
|
|
|
with (
|
|
patch("services.dataset_service.redis_client") as mock_redis,
|
|
patch("services.dataset_service.db") as mock_db,
|
|
patch("services.dataset_service.ModelManager") as model_manager_cls,
|
|
patch("services.dataset_service.VectorService") as vector_service,
|
|
patch("services.dataset_service.helper.generate_text_hash", return_value="hash-1"),
|
|
patch("services.dataset_service.naive_utc_now", return_value="now"),
|
|
patch(
|
|
"services.summary_index_service.SummaryIndexService.generate_and_vectorize_summary"
|
|
) as generate_summary,
|
|
):
|
|
mock_redis.get.return_value = None
|
|
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
|
|
|
|
summary_query = MagicMock()
|
|
summary_query.where.return_value.first.return_value = existing_summary
|
|
refreshed_query = MagicMock()
|
|
refreshed_query.where.return_value.first.return_value = refreshed_segment
|
|
mock_db.session.query.side_effect = [summary_query, refreshed_query]
|
|
|
|
result = SegmentService.update_segment(args, segment, document, dataset)
|
|
|
|
assert result is refreshed_segment
|
|
assert segment.content == "new content"
|
|
assert segment.index_node_hash == "hash-1"
|
|
assert segment.tokens == 9
|
|
assert document.word_count == 18
|
|
vector_service.update_segment_vector.assert_called_once_with(["kw-1"], segment, dataset)
|
|
generate_summary.assert_called_once_with(segment, dataset, {"enable": True})
|
|
vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset)
|
|
|
|
def test_update_segment_regenerates_summary_when_manual_summary_is_unchanged(self, account_context):
|
|
segment = _make_segment(content="old", word_count=3)
|
|
document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=10)
|
|
dataset = _make_dataset(indexing_technique="high_quality")
|
|
dataset.summary_index_setting = {"enable": True}
|
|
refreshed_segment = SimpleNamespace(id=segment.id)
|
|
existing_summary = SimpleNamespace(summary_content="same summary")
|
|
embedding_model = MagicMock()
|
|
embedding_model.get_text_embedding_num_tokens.return_value = [7]
|
|
args = SegmentUpdateArgs(content="new text", summary="same summary")
|
|
|
|
with (
|
|
patch("services.dataset_service.redis_client") as mock_redis,
|
|
patch("services.dataset_service.db") as mock_db,
|
|
patch("services.dataset_service.ModelManager") as model_manager_cls,
|
|
patch("services.dataset_service.VectorService") as vector_service,
|
|
patch("services.dataset_service.helper.generate_text_hash", return_value="hash-2"),
|
|
patch("services.dataset_service.naive_utc_now", return_value="now"),
|
|
patch(
|
|
"services.summary_index_service.SummaryIndexService.generate_and_vectorize_summary"
|
|
) as generate_summary,
|
|
patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary,
|
|
):
|
|
mock_redis.get.return_value = None
|
|
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
|
|
|
|
summary_query = MagicMock()
|
|
summary_query.where.return_value.first.return_value = existing_summary
|
|
refreshed_query = MagicMock()
|
|
refreshed_query.where.return_value.first.return_value = refreshed_segment
|
|
mock_db.session.query.side_effect = [summary_query, refreshed_query]
|
|
|
|
result = SegmentService.update_segment(args, segment, document, dataset)
|
|
|
|
assert result is refreshed_segment
|
|
generate_summary.assert_called_once_with(segment, dataset, {"enable": True})
|
|
update_summary.assert_not_called()
|
|
vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset)
|
|
|
|
def test_delete_segment_removes_index_and_updates_document_word_count(self):
|
|
segment = _make_segment(word_count=4, index_node_id="parent-node")
|
|
document = _make_document(word_count=10)
|
|
dataset = _make_dataset()
|
|
|
|
with (
|
|
patch("services.dataset_service.redis_client") as mock_redis,
|
|
patch("services.dataset_service.db") as mock_db,
|
|
patch("services.dataset_service.delete_segment_from_index_task") as delete_task,
|
|
):
|
|
mock_redis.get.return_value = None
|
|
mock_db.session.query.return_value.where.return_value.all.return_value = [("child-1",), ("child-2",)]
|
|
|
|
SegmentService.delete_segment(segment, document, dataset)
|
|
|
|
assert document.word_count == 6
|
|
mock_redis.setex.assert_called_once_with(f"segment_{segment.id}_delete_indexing", 600, 1)
|
|
delete_task.delay.assert_called_once_with(
|
|
["parent-node"],
|
|
dataset.id,
|
|
document.id,
|
|
[segment.id],
|
|
["child-1", "child-2"],
|
|
)
|
|
mock_db.session.delete.assert_called_once_with(segment)
|
|
mock_db.session.add.assert_called_once_with(document)
|
|
mock_db.session.commit.assert_called_once()
|
|
|
|
def test_delete_segment_rejects_when_delete_is_already_in_progress(self):
|
|
segment = _make_segment()
|
|
document = _make_document()
|
|
dataset = _make_dataset()
|
|
|
|
with patch("services.dataset_service.redis_client") as mock_redis:
|
|
mock_redis.get.return_value = "1"
|
|
|
|
with pytest.raises(ValueError, match="Segment is deleting"):
|
|
SegmentService.delete_segment(segment, document, dataset)
|
|
|
|
def test_delete_segments_removes_records_and_clamps_document_word_count(self):
|
|
dataset = _make_dataset()
|
|
document = _make_document(word_count=3)
|
|
current_user = SimpleNamespace(current_tenant_id="tenant-1")
|
|
|
|
with (
|
|
patch("services.dataset_service.current_user", current_user),
|
|
patch("services.dataset_service.db") as mock_db,
|
|
patch("services.dataset_service.delete_segment_from_index_task") as delete_task,
|
|
):
|
|
segments_query = MagicMock()
|
|
segments_query.with_entities.return_value.where.return_value.all.return_value = [
|
|
("node-1", "segment-1", 2),
|
|
("node-2", "segment-2", 5),
|
|
]
|
|
child_query = MagicMock()
|
|
child_query.where.return_value.all.return_value = [("child-1",)]
|
|
delete_query = MagicMock()
|
|
delete_query.where.return_value.delete.return_value = 2
|
|
mock_db.session.query.side_effect = [segments_query, child_query, delete_query]
|
|
|
|
SegmentService.delete_segments(["segment-1", "segment-2"], document, dataset)
|
|
|
|
assert document.word_count == 0
|
|
mock_db.session.add.assert_called_once_with(document)
|
|
delete_task.delay.assert_called_once_with(
|
|
["node-1", "node-2"],
|
|
dataset.id,
|
|
document.id,
|
|
["segment-1", "segment-2"],
|
|
["child-1"],
|
|
)
|
|
delete_query.where.return_value.delete.assert_called_once()
|
|
mock_db.session.commit.assert_called_once()
|
|
|
|
def test_update_segments_status_enables_only_segments_without_indexing_cache(self):
|
|
dataset = _make_dataset()
|
|
document = _make_document()
|
|
segment_a = _make_segment(segment_id="segment-a", enabled=False)
|
|
segment_b = _make_segment(segment_id="segment-b", enabled=False)
|
|
current_user = SimpleNamespace(id="user-1", current_tenant_id="tenant-1")
|
|
|
|
with (
|
|
patch("services.dataset_service.current_user", current_user),
|
|
patch("services.dataset_service.redis_client") as mock_redis,
|
|
patch("services.dataset_service.db") as mock_db,
|
|
patch("services.dataset_service.naive_utc_now", return_value="now"),
|
|
patch("services.dataset_service.enable_segments_to_index_task") as enable_task,
|
|
):
|
|
mock_db.session.scalars.return_value.all.return_value = [segment_a, segment_b]
|
|
mock_redis.get.side_effect = [None, "1"]
|
|
|
|
SegmentService.update_segments_status(["segment-a", "segment-b"], "enable", dataset, document)
|
|
|
|
assert segment_a.enabled is True
|
|
assert segment_a.disabled_at is None
|
|
assert segment_a.disabled_by is None
|
|
assert segment_b.enabled is False
|
|
mock_db.session.add.assert_called_once_with(segment_a)
|
|
mock_db.session.commit.assert_called_once()
|
|
enable_task.delay.assert_called_once_with(["segment-a"], dataset.id, document.id)
|
|
|
|
def test_update_segments_status_disables_only_segments_without_indexing_cache(self):
    """Run the "disable" action over two enabled segments with a mixed Redis response.

    The mocked ``redis_client.get`` yields ``None`` on its first call and ``"1"``
    on its second. The test expects only ``segment-a`` to be disabled (stamped
    with the patched time and the current user's id), added to the session, and
    handed to the disable-index task.
    """
    dataset = _make_dataset()
    document = _make_document()
    first_segment = _make_segment(segment_id="segment-a", enabled=True)
    second_segment = _make_segment(segment_id="segment-b", enabled=True)
    acting_user = SimpleNamespace(id="user-1", current_tenant_id="tenant-1")
    requested_ids = ["segment-a", "segment-b"]

    with (
        patch("services.dataset_service.current_user", acting_user),
        patch("services.dataset_service.redis_client") as redis_stub,
        patch("services.dataset_service.db") as db_stub,
        patch("services.dataset_service.naive_utc_now", return_value="now"),
        patch("services.dataset_service.disable_segments_from_index_task") as disable_job,
    ):
        # One Redis answer per `get` call, consumed in call order.
        redis_stub.get.side_effect = [None, "1"]
        db_stub.session.scalars.return_value.all.return_value = [first_segment, second_segment]

        SegmentService.update_segments_status(requested_ids, "disable", dataset, document)

    # State of the segment that was expected to be disabled.
    assert first_segment.enabled is False
    assert first_segment.disabled_at == "now"
    assert first_segment.disabled_by == acting_user.id

    # The other segment must be left untouched.
    assert second_segment.enabled is True

    session = db_stub.session
    session.add.assert_called_once_with(first_segment)
    session.commit.assert_called_once()
    disable_job.delay.assert_called_once_with(["segment-a"], dataset.id, document.id)
|
|
|
|
|
|
class TestSegmentServiceChildChunkTailHelpers:
    """Covers one failure path of ``update_child_chunk`` and one success path of ``delete_child_chunk``."""

    def test_update_child_chunk_rolls_back_on_vector_failure(self):
        """A vector update error must surface as ChildChunkIndexingError after a rollback, with no commit."""
        target_dataset = SimpleNamespace(id="dataset-1")
        chunk = _make_child_chunk()
        fake_user = SimpleNamespace(id="user-1")

        with (
            patch("services.dataset_service.current_user", fake_user),
            patch("services.dataset_service.db") as db_stub,
            patch("services.dataset_service.naive_utc_now", return_value="now"),
            patch("services.dataset_service.VectorService") as vector_stub,
        ):
            # Force the vector layer to fail so the error-handling branch runs.
            vector_stub.update_child_chunk_vector.side_effect = RuntimeError("vector failed")

            with pytest.raises(ChildChunkIndexingError, match="vector failed"):
                SegmentService.update_child_chunk(
                    "new content",
                    chunk,
                    SimpleNamespace(),
                    SimpleNamespace(),
                    target_dataset,
                )

        session = db_stub.session
        session.rollback.assert_called_once()
        session.commit.assert_not_called()

    def test_delete_child_chunk_commits_after_successful_vector_delete(self):
        """Deleting a child chunk removes the row, deletes its vector, and commits exactly once."""
        target_dataset = SimpleNamespace(id="dataset-1")
        chunk = _make_child_chunk()

        with (
            patch("services.dataset_service.db") as db_stub,
            patch("services.dataset_service.VectorService") as vector_stub,
        ):
            SegmentService.delete_child_chunk(chunk, target_dataset)

        session = db_stub.session
        session.delete.assert_called_once_with(chunk)
        vector_stub.delete_child_chunk_vector.assert_called_once_with(chunk, target_dataset)
        session.commit.assert_called_once()
|
|
|
|
|
|
class TestSegmentServiceAdditionalRegenerationBranches:
|
|
"""Additional unit tests for segment update and regeneration edge cases."""
|
|
|
|
@pytest.fixture
def account_context(self):
    """Patch ``services.dataset_service.current_user`` with an autospecced Account and yield it.

    The mock carries ``id="user-1"`` and ``current_tenant_id="tenant-1"``; the
    patch stays active for as long as the consuming test runs.
    """
    fake_account = create_autospec(Account, instance=True)
    fake_account.id = "user-1"
    fake_account.current_tenant_id = "tenant-1"

    current_user_patch = patch("services.dataset_service.current_user", fake_account)
    with current_user_patch:
        yield fake_account
|
|
|
|
def test_update_segment_same_content_updates_answer_and_document_word_count_for_qa_segments(self, account_context):
    """Update a QA-form segment with unchanged content and a new answer.

    Expects the answer to be stored, the segment word count to become
    ``len(content) + len(answer)``, the document word count to shift by the
    same delta, and only the multimodel vector update to be invoked.
    """
    question_text = "question"
    answer_text = "new answer"
    original_segment_words = 8
    original_document_words = 20

    segment = _make_segment(content=question_text, word_count=original_segment_words)
    document = _make_document(doc_form=IndexStructureType.QA_INDEX, word_count=original_document_words)
    dataset = _make_dataset()
    reloaded = SimpleNamespace(id=segment.id)

    with (
        patch("services.dataset_service.redis_client") as redis_stub,
        patch("services.dataset_service.db") as db_stub,
        patch("services.dataset_service.VectorService") as vector_stub,
    ):
        redis_stub.get.return_value = None
        db_stub.session.query.return_value.where.return_value.first.return_value = reloaded

        update_args = SegmentUpdateArgs(content="question", answer="new answer")
        outcome = SegmentService.update_segment(update_args, segment, document, dataset)

    expected_segment_words = len(question_text) + len(answer_text)

    assert outcome is reloaded
    assert segment.answer == "new answer"
    assert segment.word_count == expected_segment_words
    assert document.word_count == original_document_words + (expected_segment_words - original_segment_words)
    vector_stub.update_segment_vector.assert_not_called()
    vector_stub.update_multimodel_vector.assert_called_once_with(segment, [], dataset)
|
|
|
|
def test_update_segment_content_change_uses_answer_when_counting_tokens_for_qa_segments(self, account_context):
    """Change both content and answer of a QA-form segment on a high-quality dataset.

    Expects token counting to be requested for the concatenated
    content + answer text, the mocked token count to be stored on the segment,
    and both the segment vector and multimodel vector updates to be invoked.
    """

    def query_whose_first_is(value):
        # Builds a query mock whose `.where(...).first()` resolves to `value`.
        query = MagicMock()
        query.where.return_value.first.return_value = value
        return query

    segment = _make_segment(content="old", word_count=3)
    document = _make_document(doc_form=IndexStructureType.QA_INDEX, word_count=10)
    dataset = _make_dataset(indexing_technique="high_quality")
    reloaded = SimpleNamespace(id=segment.id)
    embedder = MagicMock()
    embedder.get_text_embedding_num_tokens.return_value = [21]

    with (
        patch("services.dataset_service.redis_client") as redis_stub,
        patch("services.dataset_service.db") as db_stub,
        patch("services.dataset_service.ModelManager") as manager_cls,
        patch("services.dataset_service.VectorService") as vector_stub,
        patch("services.dataset_service.helper.generate_text_hash", return_value="hash-qa"),
        patch("services.dataset_service.naive_utc_now", return_value="now"),
    ):
        redis_stub.get.return_value = None
        manager_cls.return_value.get_model_instance.return_value = embedder
        # Consumed in call order by `session.query(...)`: the first query's
        # `.first()` yields None (no summary), the second yields the reloaded segment.
        db_stub.session.query.side_effect = [
            query_whose_first_is(None),
            query_whose_first_is(reloaded),
        ]

        update_args = SegmentUpdateArgs(content="new question", answer="new answer", keywords=["kw-1"])
        outcome = SegmentService.update_segment(update_args, segment, document, dataset)

    assert outcome is reloaded
    embedder.get_text_embedding_num_tokens.assert_called_once_with(texts=["new questionnew answer"])
    assert segment.answer == "new answer"
    assert segment.tokens == 21
    assert segment.word_count == len("new question") + len("new answer")
    vector_stub.update_segment_vector.assert_called_once_with(["kw-1"], segment, dataset)
    vector_stub.update_multimodel_vector.assert_called_once_with(segment, [], dataset)
|
|
|
|
def test_update_segment_content_change_parent_child_uses_default_embedding_and_ignores_summary_failures(
    self, account_context
):
    """Change content of a parent-child segment when the dataset has no embedding provider.

    Expects the default text-embedding model to be requested for the tenant,
    child chunks to be regenerated with it, and the update to still return the
    reloaded segment even though the patched summary update raises.
    """

    def query_whose_first_is(value):
        # Builds a query mock whose `.where(...).first()` resolves to `value`.
        query = MagicMock()
        query.where.return_value.first.return_value = value
        return query

    segment = _make_segment(content="old", word_count=3)
    document = _make_document(doc_form=IndexStructureType.PARENT_CHILD_INDEX, word_count=10)
    dataset = _make_dataset(indexing_technique="high_quality")
    dataset.embedding_model_provider = None
    reloaded = SimpleNamespace(id=segment.id)
    rule = SimpleNamespace(id=document.dataset_process_rule_id)
    stored_summary = SimpleNamespace(summary_content="old summary")
    default_embedder = object()

    with (
        patch("services.dataset_service.redis_client") as redis_stub,
        patch("services.dataset_service.db") as db_stub,
        patch("services.dataset_service.ModelManager") as manager_cls,
        patch("services.dataset_service.VectorService") as vector_stub,
        patch("services.dataset_service.helper.generate_text_hash", return_value="hash-parent"),
        patch("services.dataset_service.naive_utc_now", return_value="now"),
        patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as summary_update,
    ):
        redis_stub.get.return_value = None
        manager_cls.return_value.get_default_model_instance.return_value = default_embedder
        # The summary update blows up; the call under test must not propagate it.
        summary_update.side_effect = RuntimeError("summary failed")

        # Consumed in call order by `session.query(...)`.
        db_stub.session.query.side_effect = [
            query_whose_first_is(rule),
            query_whose_first_is(stored_summary),
            query_whose_first_is(reloaded),
        ]

        update_args = SegmentUpdateArgs(
            content="new parent content",
            regenerate_child_chunks=True,
            summary="new summary",
        )
        outcome = SegmentService.update_segment(update_args, segment, document, dataset)

    assert outcome is reloaded

    manager = manager_cls.return_value
    manager.get_default_model_instance.assert_called_once_with(
        tenant_id="tenant-1",
        model_type="text-embedding",
    )
    vector_stub.generate_child_chunks.assert_called_once_with(
        segment, document, dataset, default_embedder, rule, True
    )
    summary_update.assert_called_once_with(segment, dataset, "new summary")
    vector_stub.update_multimodel_vector.assert_called_once_with(segment, [], dataset)
|
|
|
|
def test_update_segment_same_content_parent_child_marks_segment_error_for_non_high_quality_dataset(
    self, account_context
):
    """Request child-chunk regeneration on unchanged content for an "economy" dataset.

    Expects the segment to be disabled at the patched time, flagged with
    status ``"error"`` and the "not high quality" message, with no multimodel
    vector update, while the reloaded segment is still returned.
    """
    unchanged_text = "same content"
    segment = _make_segment(content=unchanged_text, word_count=len(unchanged_text))
    document = _make_document(doc_form=IndexStructureType.PARENT_CHILD_INDEX, word_count=20)
    dataset = _make_dataset(indexing_technique="economy")
    reloaded = SimpleNamespace(id=segment.id)

    with (
        patch("services.dataset_service.redis_client") as redis_stub,
        patch("services.dataset_service.db") as db_stub,
        patch("services.dataset_service.naive_utc_now", return_value="now"),
        patch("services.dataset_service.VectorService") as vector_stub,
    ):
        redis_stub.get.return_value = None
        db_stub.session.query.return_value.where.return_value.first.return_value = reloaded

        update_args = SegmentUpdateArgs(content="same content", regenerate_child_chunks=True)
        outcome = SegmentService.update_segment(update_args, segment, document, dataset)

    assert outcome is reloaded

    # The segment is expected to carry the failure state.
    assert segment.enabled is False
    assert segment.disabled_at == "now"
    assert segment.status == "error"
    assert segment.error == "The knowledge base index technique is not high quality!"

    vector_stub.update_multimodel_vector.assert_not_called()
|