dify/api/tests/unit_tests/services/test_dataset_service_segmen...

1017 lines
46 KiB
Python

"""Unit tests for SegmentService behaviors in dataset_service."""
from .dataset_service_test_helpers import (
Account,
ChildChunk,
ChildChunkDeleteIndexError,
ChildChunkIndexingError,
ChildChunkUpdateArgs,
DocumentSegment,
IndexStructureType,
MagicMock,
SegmentService,
SegmentUpdateArgs,
SimpleNamespace,
_make_child_chunk,
_make_dataset,
_make_document,
_make_lock_context,
_make_segment,
create_autospec,
patch,
pytest,
)
class TestSegmentServiceChildChunks:
    """Unit tests for child-chunk CRUD helpers.

    Each test patches the collaborators looked up on ``services.dataset_service``
    (``db``, ``redis_client``, ``VectorService``, ``uuid.uuid4``, ...) so only the
    control flow of ``SegmentService`` is exercised. Failure-path tests check both
    that the session is rolled back and that nothing is committed.
    """

    @pytest.fixture
    def account_context(self):
        """Patch ``services.dataset_service.current_user`` with an autospecced ``Account``.

        Yields the account (``id="user-1"``, ``current_tenant_id="tenant-1"``) so
        tests can compare audit fields such as ``updated_by`` against it.
        """
        account = create_autospec(Account, instance=True)
        account.id = "user-1"
        account.current_tenant_id = "tenant-1"
        with patch("services.dataset_service.current_user", account):
            yield account

    def test_create_child_chunk_assigns_next_position_and_commits(self, account_context):
        dataset = SimpleNamespace(id="dataset-1")
        document = _make_document()
        segment = _make_segment()

        with (
            patch("services.dataset_service.redis_client") as mock_redis,
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.uuid.uuid4", return_value="node-1"),
            patch("services.dataset_service.helper.generate_text_hash", return_value="hash-1"),
            patch("services.dataset_service.VectorService") as vector_service,
        ):
            mock_redis.lock.return_value = _make_lock_context()
            # The position query reports 2, so the new chunk is expected at position 3.
            mock_db.session.query.return_value.where.return_value.scalar.return_value = 2

            child_chunk = SegmentService.create_child_chunk("child content", segment, document, dataset)

        assert isinstance(child_chunk, ChildChunk)
        assert child_chunk.position == 3
        assert child_chunk.index_node_id == "node-1"
        assert child_chunk.index_node_hash == "hash-1"
        assert child_chunk.word_count == len("child content")
        mock_db.session.add.assert_called_once_with(child_chunk)
        vector_service.create_child_chunk_vector.assert_called_once_with(child_chunk, dataset)
        mock_db.session.commit.assert_called_once()

    def test_create_child_chunk_rolls_back_and_raises_on_vector_failure(self, account_context):
        dataset = SimpleNamespace(id="dataset-1")
        document = _make_document()
        segment = _make_segment()

        with (
            patch("services.dataset_service.redis_client") as mock_redis,
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.uuid.uuid4", return_value="node-1"),
            patch("services.dataset_service.helper.generate_text_hash", return_value="hash-1"),
            patch("services.dataset_service.VectorService") as vector_service,
        ):
            mock_redis.lock.return_value = _make_lock_context()
            mock_db.session.query.return_value.where.return_value.scalar.return_value = None
            vector_service.create_child_chunk_vector.side_effect = RuntimeError("vector failed")

            with pytest.raises(ChildChunkIndexingError, match="vector failed"):
                SegmentService.create_child_chunk("child content", segment, document, dataset)

        mock_db.session.rollback.assert_called_once()
        mock_db.session.commit.assert_not_called()

    def test_update_child_chunks_updates_deletes_and_creates_records(self, account_context):
        dataset = SimpleNamespace(id="dataset-1")
        document = _make_document()
        segment = _make_segment()
        existing_a = ChildChunk(
            id="child-a",
            tenant_id="tenant-1",
            dataset_id="dataset-1",
            document_id="doc-1",
            segment_id="segment-1",
            position=1,
            content="old content",
            word_count=11,
            created_by="user-1",
        )
        existing_b = ChildChunk(
            id="child-b",
            tenant_id="tenant-1",
            dataset_id="dataset-1",
            document_id="doc-1",
            segment_id="segment-1",
            position=2,
            content="remove me",
            word_count=9,
            created_by="user-1",
        )

        with (
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.uuid.uuid4", return_value="node-new"),
            patch("services.dataset_service.helper.generate_text_hash", return_value="hash-new"),
            patch("services.dataset_service.naive_utc_now", return_value="now"),
            patch("services.dataset_service.VectorService") as vector_service,
        ):
            mock_db.session.scalars.return_value.all.return_value = [existing_a, existing_b]

            # "child-a" is edited, "child-b" is omitted (so it should be deleted),
            # and the id-less entry should become a brand-new chunk.
            result = SegmentService.update_child_chunks(
                [
                    ChildChunkUpdateArgs(id="child-a", content="updated content"),
                    ChildChunkUpdateArgs(content="brand new"),
                ],
                segment,
                document,
                dataset,
            )

        assert [chunk.position for chunk in result] == [1, 3]
        assert existing_a.content == "updated content"
        assert existing_a.updated_by == account_context.id
        assert existing_a.updated_at == "now"
        mock_db.session.bulk_save_objects.assert_called_once_with([existing_a])
        mock_db.session.delete.assert_called_once_with(existing_b)

        new_chunk = result[1]
        assert isinstance(new_chunk, ChildChunk)
        assert new_chunk.position == 3
        assert new_chunk.index_node_id == "node-new"
        vector_service.update_child_chunk_vector.assert_called_once_with(
            [new_chunk], [existing_a], [existing_b], dataset
        )
        mock_db.session.commit.assert_called_once()

    def test_update_child_chunks_rolls_back_on_vector_failure(self, account_context):
        dataset = SimpleNamespace(id="dataset-1")
        document = _make_document()
        segment = _make_segment()
        existing_chunk = _make_child_chunk()

        with (
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.naive_utc_now", return_value="now"),
            patch("services.dataset_service.VectorService") as vector_service,
        ):
            mock_db.session.scalars.return_value.all.return_value = [existing_chunk]
            vector_service.update_child_chunk_vector.side_effect = RuntimeError("vector failed")

            with pytest.raises(ChildChunkIndexingError, match="vector failed"):
                SegmentService.update_child_chunks(
                    [ChildChunkUpdateArgs(id="child-a", content="updated content")],
                    segment,
                    document,
                    dataset,
                )

        mock_db.session.rollback.assert_called_once()
        # A failed re-index must not persist the partially applied changes.
        mock_db.session.commit.assert_not_called()

    def test_update_child_chunk_updates_vector_and_commits(self, account_context):
        dataset = SimpleNamespace(id="dataset-1")
        child_chunk = _make_child_chunk()

        with (
            patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")),
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.naive_utc_now", return_value="now"),
            patch("services.dataset_service.VectorService") as vector_service,
        ):
            result = SegmentService.update_child_chunk(
                "new content", child_chunk, _make_segment(), _make_document(), dataset
            )

        assert result is child_chunk
        assert child_chunk.content == "new content"
        assert child_chunk.word_count == len("new content")
        assert child_chunk.updated_by == "user-1"
        assert child_chunk.updated_at == "now"
        mock_db.session.add.assert_called_once_with(child_chunk)
        vector_service.update_child_chunk_vector.assert_called_once_with([], [child_chunk], [], dataset)
        mock_db.session.commit.assert_called_once()

    def test_delete_child_chunk_raises_delete_index_error_on_vector_failure(self):
        dataset = SimpleNamespace(id="dataset-1")
        child_chunk = _make_child_chunk()

        with (
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.VectorService") as vector_service,
        ):
            vector_service.delete_child_chunk_vector.side_effect = RuntimeError("delete failed")

            with pytest.raises(ChildChunkDeleteIndexError, match="delete failed"):
                SegmentService.delete_child_chunk(child_chunk, dataset)

        mock_db.session.delete.assert_called_once_with(child_chunk)
        mock_db.session.rollback.assert_called_once()
        # The row deletion is rolled back, so it must never be committed.
        mock_db.session.commit.assert_not_called()
class TestSegmentServiceQueries:
    """Unit tests for the read-side helpers that fetch child chunks and segments."""

    @pytest.fixture
    def account_context(self):
        """Yield an autospecced ``Account`` patched in as ``services.dataset_service.current_user``."""
        fake_account = create_autospec(Account, instance=True)
        fake_account.id = "user-1"
        fake_account.current_tenant_id = "tenant-1"
        with patch("services.dataset_service.current_user", fake_account):
            yield fake_account

    @staticmethod
    def _segment_record():
        """Build the ``DocumentSegment`` row shared by the segment lookup tests."""
        return DocumentSegment(
            id="segment-1",
            tenant_id="tenant-1",
            dataset_id="dataset-1",
            document_id="doc-1",
            position=1,
            content="segment",
            word_count=7,
            tokens=2,
            created_by="user-1",
        )

    def test_get_child_chunks_applies_keyword_filter_and_paginate(self, account_context):
        page_stub = SimpleNamespace(items=["chunk"], total=1)
        query_kwargs = {
            "segment_id": "segment-1",
            "document_id": "doc-1",
            "dataset_id": "dataset-1",
            "page": 2,
            "limit": 10,
            "keyword": "needle",
        }

        with (
            patch("services.dataset_service.db") as db_stub,
            patch("services.dataset_service.helper.escape_like_pattern", return_value="escaped") as escape_stub,
        ):
            db_stub.paginate.return_value = page_stub
            outcome = SegmentService.get_child_chunks(**query_kwargs)

        assert outcome is page_stub
        escape_stub.assert_called_once_with("needle")
        db_stub.paginate.assert_called_once()

    def test_get_child_chunk_by_id_returns_only_child_chunk_instances(self):
        chunk = _make_child_chunk()
        # Each pair is (row handed back by the query, value the service must return).
        scenarios = ((chunk, chunk), (SimpleNamespace(), None))

        for row, expected in scenarios:
            with patch("services.dataset_service.db") as db_stub:
                db_stub.session.query.return_value.where.return_value.first.return_value = row
                found = SegmentService.get_child_chunk_by_id("child-a", "tenant-1")
            assert found is expected

    def test_get_segments_uses_status_and_keyword_filters(self):
        page_stub = SimpleNamespace(items=["segment"], total=1)

        with (
            patch("services.dataset_service.db") as db_stub,
            patch("services.dataset_service.helper.escape_like_pattern", return_value="escaped") as escape_stub,
        ):
            db_stub.paginate.return_value = page_stub
            items, total = SegmentService.get_segments(
                document_id="doc-1",
                tenant_id="tenant-1",
                status_list=["completed"],
                keyword="needle",
                page=1,
                limit=20,
            )

        assert (items, total) == (["segment"], 1)
        escape_stub.assert_called_once_with("needle")
        db_stub.paginate.assert_called_once()

    def test_get_segment_by_id_returns_only_document_segment_instances(self):
        record = self._segment_record()
        # Each pair is (row handed back by the query, value the service must return).
        scenarios = ((record, record), (SimpleNamespace(), None))

        for row, expected in scenarios:
            with patch("services.dataset_service.db") as db_stub:
                db_stub.session.query.return_value.where.return_value.first.return_value = row
                found = SegmentService.get_segment_by_id("segment-1", "tenant-1")
            assert found is expected

    def test_get_segments_by_document_and_dataset_returns_scalars_result(self):
        record = self._segment_record()

        with patch("services.dataset_service.db") as db_stub:
            db_stub.session.scalars.return_value.all.return_value = [record]
            rows = SegmentService.get_segments_by_document_and_dataset(
                document_id="doc-1",
                dataset_id="dataset-1",
                status="completed",
                enabled=True,
            )

        assert rows == [record]
        db_stub.session.scalars.assert_called_once()
class TestSegmentServiceValidation:
    """Unit tests for ``SegmentService.segment_create_args_validate`` argument checks."""

    def test_segment_create_args_validate_requires_answer_for_qa_model(self):
        qa_document = _make_document(doc_form=IndexStructureType.QA_INDEX)
        question_only = {"content": "question"}

        with pytest.raises(ValueError, match="Answer is required"):
            SegmentService.segment_create_args_validate(question_only, qa_document)

    def test_segment_create_args_validate_requires_non_empty_content(self):
        paragraph_document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX)
        blank_payload = {"content": " "}

        with pytest.raises(ValueError, match="Content is empty"):
            SegmentService.segment_create_args_validate(blank_payload, paragraph_document)

    def test_segment_create_args_validate_enforces_attachment_limit(self):
        paragraph_document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX)
        # Two attachments against a configured limit of one.
        over_limit_payload = {"content": "hello", "attachment_ids": ["a-1", "a-2"]}

        with (
            patch("services.dataset_service.dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT", 1),
            pytest.raises(ValueError, match="Exceeded maximum attachment limit of 1"),
        ):
            SegmentService.segment_create_args_validate(over_limit_payload, paragraph_document)

    def test_segment_create_args_validate_requires_attachment_ids_list(self):
        paragraph_document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX)
        wrong_type_payload = {"content": "hello", "attachment_ids": "bad-type"}

        with pytest.raises(ValueError, match="Attachment IDs is invalid"):
            SegmentService.segment_create_args_validate(wrong_type_payload, paragraph_document)
class TestSegmentServiceMutations:
    """Unit tests for segment create, update, delete, and bulk status flows.

    The guard tests pin the exact redis keys that are consulted. A bare
    ``MagicMock`` ``get`` returns the stub for any key, so the tests would
    otherwise pass even if the service looked up the wrong key.
    """

    @pytest.fixture
    def account_context(self):
        """Patch ``services.dataset_service.current_user`` with an autospecced ``Account``.

        Yields the account (``id="user-1"``, ``current_tenant_id="tenant-1"``).
        """
        account = create_autospec(Account, instance=True)
        account.id = "user-1"
        account.current_tenant_id = "tenant-1"
        with patch("services.dataset_service.current_user", account):
            yield account

    def test_create_segment_creates_bindings_and_marks_segment_error_on_vector_failure(self, account_context):
        """A vector failure after insert marks the segment as errored and returns it instead of raising."""
        dataset = _make_dataset(indexing_technique="economy")
        document = _make_document(
            dataset_id=dataset.id,
            tenant_id=dataset.tenant_id,
            doc_form=IndexStructureType.QA_INDEX,
            word_count=0,
        )
        refreshed_segment = SimpleNamespace(id="segment-1")
        args = {
            "content": "question",
            "answer": "answer",
            "keywords": ["kw-1"],
            "attachment_ids": ["att-1", "att-2"],
        }

        with (
            patch("services.dataset_service.redis_client") as mock_redis,
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.VectorService") as vector_service,
            patch("services.dataset_service.helper.generate_text_hash", return_value="hash-1"),
            patch("services.dataset_service.uuid.uuid4", return_value="node-1"),
            patch("services.dataset_service.naive_utc_now", return_value="now"),
        ):
            mock_redis.lock.return_value = _make_lock_context()
            # session.query is served in order: the max-position query, then the refresh query.
            max_position_query = MagicMock()
            max_position_query.where.return_value.scalar.return_value = 2
            refresh_query = MagicMock()
            refresh_query.where.return_value.first.return_value = refreshed_segment
            mock_db.session.query.side_effect = [max_position_query, refresh_query]

            def add_side_effect(obj):
                # Give a DocumentSegment that was added without an id the id "segment-1".
                if obj.__class__.__name__ == "DocumentSegment" and getattr(obj, "id", None) is None:
                    obj.id = "segment-1"

            mock_db.session.add.side_effect = add_side_effect
            vector_service.create_segments_vector.side_effect = RuntimeError("vector failed")

            result = SegmentService.create_segment(args=args, document=document, dataset=dataset)

        created_segment = vector_service.create_segments_vector.call_args.args[1][0]
        attachment_bindings = [
            call.args[0]
            for call in mock_db.session.add.call_args_list
            if call.args and call.args[0].__class__.__name__ == "SegmentAttachmentBinding"
        ]

        assert result is refreshed_segment
        assert created_segment.position == 3
        assert created_segment.answer == "answer"
        assert created_segment.word_count == len("question") + len("answer")
        assert created_segment.status == "error"
        assert created_segment.enabled is False
        assert created_segment.error == "vector failed"
        assert document.word_count == len("question") + len("answer")
        assert len(attachment_bindings) == 2
        assert {binding.attachment_id for binding in attachment_bindings} == {"att-1", "att-2"}
        assert mock_db.session.commit.call_count == 3

    def test_multi_create_segment_high_quality_marks_segments_error_when_vector_creation_fails(self, account_context):
        """In the high-quality path, a vector failure marks every created segment as errored."""
        dataset = _make_dataset(indexing_technique="high_quality")
        document = _make_document(
            dataset_id=dataset.id,
            tenant_id=dataset.tenant_id,
            doc_form=IndexStructureType.QA_INDEX,
            word_count=5,
        )
        segments = [
            {"content": "question-1", "answer": "answer-1", "keywords": ["k1"]},
            {"content": "question-2", "answer": "answer-2"},
        ]
        embedding_model = MagicMock()
        # One token-count response per segment, consumed in order.
        embedding_model.get_text_embedding_num_tokens.side_effect = [[11], [13]]

        with (
            patch("services.dataset_service.redis_client") as mock_redis,
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch("services.dataset_service.VectorService") as vector_service,
            patch("services.dataset_service.helper.generate_text_hash", side_effect=["hash-1", "hash-2"]),
            patch("services.dataset_service.uuid.uuid4", side_effect=["node-1", "node-2"]),
            patch("services.dataset_service.naive_utc_now", return_value="now"),
        ):
            mock_redis.lock.return_value = _make_lock_context()
            model_manager_cls.return_value.get_model_instance.return_value = embedding_model
            mock_db.session.query.return_value.where.return_value.scalar.return_value = 1
            vector_service.create_segments_vector.side_effect = RuntimeError("vector failed")

            result = SegmentService.multi_create_segment(segments, document, dataset)

        assert len(result) == 2
        assert [segment.position for segment in result] == [2, 3]
        assert [segment.tokens for segment in result] == [11, 13]
        assert all(segment.status == "error" for segment in result)
        assert all(segment.enabled is False for segment in result)
        assert all(segment.error == "vector failed" for segment in result)
        assert document.word_count == 5 + sum(len(item["content"]) + len(item["answer"]) for item in segments)
        vector_service.create_segments_vector.assert_called_once_with(
            [["k1"], None], result, dataset, document.doc_form
        )
        mock_db.session.commit.assert_called_once()

    def test_update_segment_disables_enabled_segment_and_dispatches_index_cleanup(self, account_context):
        """Disabling an enabled segment stamps audit fields, sets the indexing key, and queues cleanup."""
        segment = _make_segment(enabled=True)
        document = _make_document()
        dataset = _make_dataset()
        args = SegmentUpdateArgs(enabled=False)

        with (
            patch("services.dataset_service.redis_client") as mock_redis,
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.naive_utc_now", return_value="now"),
            patch("services.dataset_service.disable_segment_from_index_task") as disable_task,
        ):
            mock_redis.get.return_value = None

            result = SegmentService.update_segment(args, segment, document, dataset)

        assert result is segment
        assert segment.enabled is False
        assert segment.disabled_at == "now"
        assert segment.disabled_by == account_context.id
        mock_db.session.add.assert_called_once_with(segment)
        mock_db.session.commit.assert_called_once()
        mock_redis.setex.assert_called_once_with(f"segment_{segment.id}_indexing", 600, 1)
        disable_task.delay.assert_called_once_with(segment.id)

    def test_update_segment_rejects_updates_for_disabled_segment(self, account_context):
        segment = _make_segment(enabled=False)
        document = _make_document()
        dataset = _make_dataset()

        with patch("services.dataset_service.redis_client") as mock_redis:
            mock_redis.get.return_value = None

            with pytest.raises(ValueError, match="Can't update disabled segment"):
                SegmentService.update_segment(SegmentUpdateArgs(content="new content"), segment, document, dataset)

        # A rejected update must not claim the indexing lock.
        mock_redis.setex.assert_not_called()

    def test_update_segment_rejects_when_indexing_cache_exists(self, account_context):
        segment = _make_segment(enabled=True)
        document = _make_document()
        dataset = _make_dataset()

        with patch("services.dataset_service.redis_client") as mock_redis:
            mock_redis.get.return_value = "1"

            with pytest.raises(ValueError, match="Segment is indexing"):
                SegmentService.update_segment(SegmentUpdateArgs(content="new content"), segment, document, dataset)

        # Pin the key: the same one the disable flow writes via setex.
        mock_redis.get.assert_called_once_with(f"segment_{segment.id}_indexing")
        mock_redis.setex.assert_not_called()

    def test_update_segment_updates_keywords_for_same_content_segment(self, account_context):
        segment = _make_segment(content="same content", keywords=["old"])
        document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=20)
        dataset = _make_dataset()
        refreshed_segment = SimpleNamespace(id=segment.id)
        args = SegmentUpdateArgs(content="same content", keywords=["new"])

        with (
            patch("services.dataset_service.redis_client") as mock_redis,
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.VectorService") as vector_service,
        ):
            mock_redis.get.return_value = None
            mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment

            result = SegmentService.update_segment(args, segment, document, dataset)

        assert result is refreshed_segment
        assert segment.keywords == ["new"]
        vector_service.update_segment_vector.assert_called_once_with(["new"], segment, dataset)
        vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset)

    def test_update_segment_regenerates_child_chunks_and_updates_manual_summary(self, account_context):
        segment = _make_segment(content="same content", word_count=len("same content"))
        document = _make_document(
            doc_form=IndexStructureType.PARENT_CHILD_INDEX,
            word_count=20,
        )
        dataset = _make_dataset(indexing_technique="high_quality")
        refreshed_segment = SimpleNamespace(id=segment.id)
        processing_rule = SimpleNamespace(id=document.dataset_process_rule_id)
        existing_summary = SimpleNamespace(summary_content="old summary")
        embedding_model_instance = object()
        args = SegmentUpdateArgs(
            content="same content",
            regenerate_child_chunks=True,
            summary="new summary",
        )

        with (
            patch("services.dataset_service.redis_client") as mock_redis,
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch("services.dataset_service.VectorService") as vector_service,
            patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary,
        ):
            mock_redis.get.return_value = None
            model_manager_cls.return_value.get_model_instance.return_value = embedding_model_instance
            # session.query is served in order: processing rule, existing summary, refreshed segment.
            processing_rule_query = MagicMock()
            processing_rule_query.where.return_value.first.return_value = processing_rule
            summary_query = MagicMock()
            summary_query.where.return_value.first.return_value = existing_summary
            refreshed_query = MagicMock()
            refreshed_query.where.return_value.first.return_value = refreshed_segment
            mock_db.session.query.side_effect = [processing_rule_query, summary_query, refreshed_query]

            result = SegmentService.update_segment(args, segment, document, dataset)

        assert result is refreshed_segment
        vector_service.generate_child_chunks.assert_called_once_with(
            segment,
            document,
            dataset,
            embedding_model_instance,
            processing_rule,
            True,
        )
        update_summary.assert_called_once_with(segment, dataset, "new summary")
        vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset)

    def test_update_segment_auto_regenerates_summary_after_content_change(self, account_context):
        segment = _make_segment(content="old", word_count=3)
        document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=10)
        dataset = _make_dataset(indexing_technique="high_quality")
        dataset.summary_index_setting = {"enable": True}
        refreshed_segment = SimpleNamespace(id=segment.id)
        existing_summary = SimpleNamespace(summary_content="old summary")
        embedding_model = MagicMock()
        embedding_model.get_text_embedding_num_tokens.return_value = [9]
        args = SegmentUpdateArgs(content="new content", keywords=["kw-1"])

        with (
            patch("services.dataset_service.redis_client") as mock_redis,
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch("services.dataset_service.VectorService") as vector_service,
            patch("services.dataset_service.helper.generate_text_hash", return_value="hash-1"),
            patch("services.dataset_service.naive_utc_now", return_value="now"),
            patch(
                "services.summary_index_service.SummaryIndexService.generate_and_vectorize_summary"
            ) as generate_summary,
        ):
            mock_redis.get.return_value = None
            model_manager_cls.return_value.get_model_instance.return_value = embedding_model
            summary_query = MagicMock()
            summary_query.where.return_value.first.return_value = existing_summary
            refreshed_query = MagicMock()
            refreshed_query.where.return_value.first.return_value = refreshed_segment
            mock_db.session.query.side_effect = [summary_query, refreshed_query]

            result = SegmentService.update_segment(args, segment, document, dataset)

        assert result is refreshed_segment
        assert segment.content == "new content"
        assert segment.index_node_hash == "hash-1"
        assert segment.tokens == 9
        # 10 - len("old") + len("new content") == 18
        assert document.word_count == 18
        vector_service.update_segment_vector.assert_called_once_with(["kw-1"], segment, dataset)
        generate_summary.assert_called_once_with(segment, dataset, {"enable": True})
        vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset)

    def test_update_segment_regenerates_summary_when_manual_summary_is_unchanged(self, account_context):
        segment = _make_segment(content="old", word_count=3)
        document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=10)
        dataset = _make_dataset(indexing_technique="high_quality")
        dataset.summary_index_setting = {"enable": True}
        refreshed_segment = SimpleNamespace(id=segment.id)
        existing_summary = SimpleNamespace(summary_content="same summary")
        embedding_model = MagicMock()
        embedding_model.get_text_embedding_num_tokens.return_value = [7]
        args = SegmentUpdateArgs(content="new text", summary="same summary")

        with (
            patch("services.dataset_service.redis_client") as mock_redis,
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch("services.dataset_service.VectorService") as vector_service,
            patch("services.dataset_service.helper.generate_text_hash", return_value="hash-2"),
            patch("services.dataset_service.naive_utc_now", return_value="now"),
            patch(
                "services.summary_index_service.SummaryIndexService.generate_and_vectorize_summary"
            ) as generate_summary,
            patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary,
        ):
            mock_redis.get.return_value = None
            model_manager_cls.return_value.get_model_instance.return_value = embedding_model
            summary_query = MagicMock()
            summary_query.where.return_value.first.return_value = existing_summary
            refreshed_query = MagicMock()
            refreshed_query.where.return_value.first.return_value = refreshed_segment
            mock_db.session.query.side_effect = [summary_query, refreshed_query]

            result = SegmentService.update_segment(args, segment, document, dataset)

        assert result is refreshed_segment
        generate_summary.assert_called_once_with(segment, dataset, {"enable": True})
        update_summary.assert_not_called()
        vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset)

    def test_delete_segment_removes_index_and_updates_document_word_count(self):
        segment = _make_segment(word_count=4, index_node_id="parent-node")
        document = _make_document(word_count=10)
        dataset = _make_dataset()

        with (
            patch("services.dataset_service.redis_client") as mock_redis,
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.delete_segment_from_index_task") as delete_task,
        ):
            mock_redis.get.return_value = None
            # Child-chunk index node ids come back as one-element row tuples.
            mock_db.session.query.return_value.where.return_value.all.return_value = [("child-1",), ("child-2",)]

            SegmentService.delete_segment(segment, document, dataset)

        assert document.word_count == 6
        mock_redis.setex.assert_called_once_with(f"segment_{segment.id}_delete_indexing", 600, 1)
        delete_task.delay.assert_called_once_with(
            ["parent-node"],
            dataset.id,
            document.id,
            [segment.id],
            ["child-1", "child-2"],
        )
        mock_db.session.delete.assert_called_once_with(segment)
        mock_db.session.add.assert_called_once_with(document)
        mock_db.session.commit.assert_called_once()

    def test_delete_segment_rejects_when_delete_is_already_in_progress(self):
        segment = _make_segment()
        document = _make_document()
        dataset = _make_dataset()

        with patch("services.dataset_service.redis_client") as mock_redis:
            mock_redis.get.return_value = "1"

            with pytest.raises(ValueError, match="Segment is deleting"):
                SegmentService.delete_segment(segment, document, dataset)

        # Pin the key (the same one the successful delete writes via setex), and make sure
        # a rejected delete does not take the lock again.
        mock_redis.get.assert_called_once_with(f"segment_{segment.id}_delete_indexing")
        mock_redis.setex.assert_not_called()

    def test_delete_segments_removes_records_and_clamps_document_word_count(self):
        dataset = _make_dataset()
        document = _make_document(word_count=3)
        current_user = SimpleNamespace(current_tenant_id="tenant-1")

        with (
            patch("services.dataset_service.current_user", current_user),
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.delete_segment_from_index_task") as delete_task,
        ):
            # session.query is served in order: segment rows, child-chunk ids, bulk delete.
            segments_query = MagicMock()
            segments_query.with_entities.return_value.where.return_value.all.return_value = [
                ("node-1", "segment-1", 2),
                ("node-2", "segment-2", 5),
            ]
            child_query = MagicMock()
            child_query.where.return_value.all.return_value = [("child-1",)]
            delete_query = MagicMock()
            delete_query.where.return_value.delete.return_value = 2
            mock_db.session.query.side_effect = [segments_query, child_query, delete_query]

            SegmentService.delete_segments(["segment-1", "segment-2"], document, dataset)

        # 3 - (2 + 5) would be negative; the word count is expected to clamp at 0.
        assert document.word_count == 0
        mock_db.session.add.assert_called_once_with(document)
        delete_task.delay.assert_called_once_with(
            ["node-1", "node-2"],
            dataset.id,
            document.id,
            ["segment-1", "segment-2"],
            ["child-1"],
        )
        delete_query.where.return_value.delete.assert_called_once()
        mock_db.session.commit.assert_called_once()

    def test_update_segments_status_enables_only_segments_without_indexing_cache(self):
        dataset = _make_dataset()
        document = _make_document()
        segment_a = _make_segment(segment_id="segment-a", enabled=False)
        segment_b = _make_segment(segment_id="segment-b", enabled=False)
        current_user = SimpleNamespace(id="user-1", current_tenant_id="tenant-1")

        with (
            patch("services.dataset_service.current_user", current_user),
            patch("services.dataset_service.redis_client") as mock_redis,
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.naive_utc_now", return_value="now"),
            patch("services.dataset_service.enable_segments_to_index_task") as enable_task,
        ):
            mock_db.session.scalars.return_value.all.return_value = [segment_a, segment_b]
            # segment_a has no indexing cache entry; segment_b is still indexing and must be skipped.
            mock_redis.get.side_effect = [None, "1"]

            SegmentService.update_segments_status(["segment-a", "segment-b"], "enable", dataset, document)

        assert segment_a.enabled is True
        assert segment_a.disabled_at is None
        assert segment_a.disabled_by is None
        assert segment_b.enabled is False
        mock_db.session.add.assert_called_once_with(segment_a)
        mock_db.session.commit.assert_called_once()
        enable_task.delay.assert_called_once_with(["segment-a"], dataset.id, document.id)

    def test_update_segments_status_disables_only_segments_without_indexing_cache(self):
        dataset = _make_dataset()
        document = _make_document()
        segment_a = _make_segment(segment_id="segment-a", enabled=True)
        segment_b = _make_segment(segment_id="segment-b", enabled=True)
        current_user = SimpleNamespace(id="user-1", current_tenant_id="tenant-1")

        with (
            patch("services.dataset_service.current_user", current_user),
            patch("services.dataset_service.redis_client") as mock_redis,
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.naive_utc_now", return_value="now"),
            patch("services.dataset_service.disable_segments_from_index_task") as disable_task,
        ):
            mock_db.session.scalars.return_value.all.return_value = [segment_a, segment_b]
            # segment_a has no indexing cache entry; segment_b is still indexing and must be skipped.
            mock_redis.get.side_effect = [None, "1"]

            SegmentService.update_segments_status(["segment-a", "segment-b"], "disable", dataset, document)

        assert segment_a.enabled is False
        assert segment_a.disabled_at == "now"
        assert segment_a.disabled_by == current_user.id
        assert segment_b.enabled is True
        mock_db.session.add.assert_called_once_with(segment_a)
        mock_db.session.commit.assert_called_once()
        disable_task.delay.assert_called_once_with(["segment-a"], dataset.id, document.id)
class TestSegmentServiceChildChunkTailHelpers:
    """Unit tests for the leftover branches of the single child-chunk update/delete helpers."""

    def test_update_child_chunk_rolls_back_on_vector_failure(self):
        target_dataset = SimpleNamespace(id="dataset-1")
        chunk = _make_child_chunk()
        acting_user = SimpleNamespace(id="user-1")

        with (
            patch("services.dataset_service.current_user", acting_user),
            patch("services.dataset_service.db") as db_stub,
            patch("services.dataset_service.naive_utc_now", return_value="now"),
            patch("services.dataset_service.VectorService") as vector_stub,
        ):
            vector_stub.update_child_chunk_vector.side_effect = RuntimeError("vector failed")
            with pytest.raises(ChildChunkIndexingError, match="vector failed"):
                SegmentService.update_child_chunk(
                    "new content", chunk, SimpleNamespace(), SimpleNamespace(), target_dataset
                )

        # On failure the session is rolled back and never committed.
        db_stub.session.rollback.assert_called_once()
        db_stub.session.commit.assert_not_called()

    def test_delete_child_chunk_commits_after_successful_vector_delete(self):
        target_dataset = SimpleNamespace(id="dataset-1")
        chunk = _make_child_chunk()

        with (
            patch("services.dataset_service.db") as db_stub,
            patch("services.dataset_service.VectorService") as vector_stub,
        ):
            SegmentService.delete_child_chunk(chunk, target_dataset)

        session = db_stub.session
        session.delete.assert_called_once_with(chunk)
        vector_stub.delete_child_chunk_vector.assert_called_once_with(chunk, target_dataset)
        session.commit.assert_called_once()
class TestSegmentServiceAdditionalRegenerationBranches:
"""Additional unit tests for segment update and regeneration edge cases."""
@pytest.fixture
def account_context(self):
account = create_autospec(Account, instance=True)
account.id = "user-1"
account.current_tenant_id = "tenant-1"
with patch("services.dataset_service.current_user", account):
yield account
def test_update_segment_same_content_updates_answer_and_document_word_count_for_qa_segments(self, account_context):
segment = _make_segment(content="question", word_count=8)
document = _make_document(doc_form=IndexStructureType.QA_INDEX, word_count=20)
dataset = _make_dataset()
refreshed_segment = SimpleNamespace(id=segment.id)
with (
patch("services.dataset_service.redis_client") as mock_redis,
patch("services.dataset_service.db") as mock_db,
patch("services.dataset_service.VectorService") as vector_service,
):
mock_redis.get.return_value = None
mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment
result = SegmentService.update_segment(
SegmentUpdateArgs(content="question", answer="new answer"),
segment,
document,
dataset,
)
assert result is refreshed_segment
assert segment.answer == "new answer"
assert segment.word_count == len("question") + len("new answer")
assert document.word_count == 20 + (len("question") + len("new answer") - 8)
vector_service.update_segment_vector.assert_not_called()
vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset)
def test_update_segment_content_change_uses_answer_when_counting_tokens_for_qa_segments(self, account_context):
segment = _make_segment(content="old", word_count=3)
document = _make_document(doc_form=IndexStructureType.QA_INDEX, word_count=10)
dataset = _make_dataset(indexing_technique="high_quality")
refreshed_segment = SimpleNamespace(id=segment.id)
embedding_model = MagicMock()
embedding_model.get_text_embedding_num_tokens.return_value = [21]
with (
patch("services.dataset_service.redis_client") as mock_redis,
patch("services.dataset_service.db") as mock_db,
patch("services.dataset_service.ModelManager") as model_manager_cls,
patch("services.dataset_service.VectorService") as vector_service,
patch("services.dataset_service.helper.generate_text_hash", return_value="hash-qa"),
patch("services.dataset_service.naive_utc_now", return_value="now"),
):
mock_redis.get.return_value = None
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
summary_query = MagicMock()
summary_query.where.return_value.first.return_value = None
refreshed_query = MagicMock()
refreshed_query.where.return_value.first.return_value = refreshed_segment
mock_db.session.query.side_effect = [summary_query, refreshed_query]
result = SegmentService.update_segment(
SegmentUpdateArgs(content="new question", answer="new answer", keywords=["kw-1"]),
segment,
document,
dataset,
)
assert result is refreshed_segment
embedding_model.get_text_embedding_num_tokens.assert_called_once_with(texts=["new questionnew answer"])
assert segment.answer == "new answer"
assert segment.tokens == 21
assert segment.word_count == len("new question") + len("new answer")
vector_service.update_segment_vector.assert_called_once_with(["kw-1"], segment, dataset)
vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset)
def test_update_segment_content_change_parent_child_uses_default_embedding_and_ignores_summary_failures(
self, account_context
):
segment = _make_segment(content="old", word_count=3)
document = _make_document(
doc_form=IndexStructureType.PARENT_CHILD_INDEX,
word_count=10,
)
dataset = _make_dataset(indexing_technique="high_quality")
dataset.embedding_model_provider = None
refreshed_segment = SimpleNamespace(id=segment.id)
processing_rule = SimpleNamespace(id=document.dataset_process_rule_id)
existing_summary = SimpleNamespace(summary_content="old summary")
embedding_model_instance = object()
with (
patch("services.dataset_service.redis_client") as mock_redis,
patch("services.dataset_service.db") as mock_db,
patch("services.dataset_service.ModelManager") as model_manager_cls,
patch("services.dataset_service.VectorService") as vector_service,
patch("services.dataset_service.helper.generate_text_hash", return_value="hash-parent"),
patch("services.dataset_service.naive_utc_now", return_value="now"),
patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary,
):
mock_redis.get.return_value = None
model_manager_cls.return_value.get_default_model_instance.return_value = embedding_model_instance
update_summary.side_effect = RuntimeError("summary failed")
processing_rule_query = MagicMock()
processing_rule_query.where.return_value.first.return_value = processing_rule
summary_query = MagicMock()
summary_query.where.return_value.first.return_value = existing_summary
refreshed_query = MagicMock()
refreshed_query.where.return_value.first.return_value = refreshed_segment
mock_db.session.query.side_effect = [processing_rule_query, summary_query, refreshed_query]
result = SegmentService.update_segment(
SegmentUpdateArgs(content="new parent content", regenerate_child_chunks=True, summary="new summary"),
segment,
document,
dataset,
)
assert result is refreshed_segment
model_manager_cls.return_value.get_default_model_instance.assert_called_once_with(
tenant_id="tenant-1",
model_type="text-embedding",
)
vector_service.generate_child_chunks.assert_called_once_with(
segment,
document,
dataset,
embedding_model_instance,
processing_rule,
True,
)
update_summary.assert_called_once_with(segment, dataset, "new summary")
vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset)
def test_update_segment_same_content_parent_child_marks_segment_error_for_non_high_quality_dataset(
self, account_context
):
segment = _make_segment(content="same content", word_count=len("same content"))
document = _make_document(
doc_form=IndexStructureType.PARENT_CHILD_INDEX,
word_count=20,
)
dataset = _make_dataset(indexing_technique="economy")
refreshed_segment = SimpleNamespace(id=segment.id)
with (
patch("services.dataset_service.redis_client") as mock_redis,
patch("services.dataset_service.db") as mock_db,
patch("services.dataset_service.naive_utc_now", return_value="now"),
patch("services.dataset_service.VectorService") as vector_service,
):
mock_redis.get.return_value = None
mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment
result = SegmentService.update_segment(
SegmentUpdateArgs(content="same content", regenerate_child_chunks=True),
segment,
document,
dataset,
)
assert result is refreshed_segment
assert segment.enabled is False
assert segment.disabled_at == "now"
assert segment.status == "error"
assert segment.error == "The knowledge base index technique is not high quality!"
vector_service.update_multimodel_vector.assert_not_called()