refactor: make message service integration fixtures self-contained to prevent flaky parallel test runs

update message cleaning integration test fixtures to prevent detached instances
add new test data for pagination and cascade deletion
This commit is contained in:
Dev Sharma 2026-03-24 19:34:15 +05:30
parent b14e34412e
commit 8775981d90
1 changed file with 203 additions and 147 deletions

View File

@ -1,4 +1,5 @@
import datetime
import math
import uuid
import pytest
@ -6,6 +7,7 @@ from sqlalchemy import delete
from core.db.session_factory import session_factory
from models import Tenant
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import (
App,
Conversation,
@ -16,14 +18,22 @@ from models.model import (
from services.retention.conversation.messages_clean_policy import BillingDisabledPolicy
from services.retention.conversation.messages_clean_service import MessagesCleanService
_NOW = datetime.datetime(2026, 1, 15, 12, 0, 0)
_NOW = datetime.datetime(2026, 1, 15, 12, 0, 0, tzinfo=datetime.UTC)
_OLD = _NOW - datetime.timedelta(days=60)
_VERY_OLD = _NOW - datetime.timedelta(days=90)
_RECENT = _NOW - datetime.timedelta(days=5)
_WINDOW_START = _VERY_OLD - datetime.timedelta(hours=1)
_WINDOW_END = _RECENT + datetime.timedelta(hours=1)
_DEFAULT_BATCH_SIZE = 100
_PAGINATION_MESSAGE_COUNT = 25
_PAGINATION_BATCH_SIZE = 8
@pytest.fixture
def tenant_and_app(flask_req_ctx):
"""Creates a Tenant, App and Conversation for the test and cleans up after."""
with session_factory.create_session() as session:
tenant = Tenant(name="retention_it_tenant")
session.add(tenant)
@ -50,12 +60,16 @@ def tenant_and_app(flask_req_ctx):
session.add(conv)
session.commit()
yield {"tenant": tenant, "app": app, "conversation": conv}
tenant_id = tenant.id
app_id = app.id
conv_id = conv.id
yield {"tenant_id": tenant_id, "app_id": app_id, "conversation_id": conv_id}
with session_factory.create_session() as session:
session.execute(delete(Conversation).where(Conversation.id == conv.id))
session.execute(delete(App).where(App.id == app.id))
session.execute(delete(Tenant).where(Tenant.id == tenant.id))
session.execute(delete(Conversation).where(Conversation.id == conv_id))
session.execute(delete(App).where(App.id == app_id))
session.execute(delete(Tenant).where(Tenant.id == tenant_id))
session.commit()
@ -80,13 +94,49 @@ def _make_message(app_id: str, conversation_id: str, created_at: datetime.dateti
class TestMessagesCleanServiceIntegration:
@pytest.fixture
def seed_messages(self, tenant_and_app):
"""Seeds one message at each of _VERY_OLD, _OLD, and _RECENT.
Yields a semantic mapping keyed by age label.
"""
data = tenant_and_app
app_id = data["app"].id
conv_id = data["conversation"].id
app_id = data["app_id"]
conv_id = data["conversation_id"]
# Ordered tuple of (label, timestamp) for deterministic seeding
timestamps = [
("very_old", _VERY_OLD),
("old", _OLD),
("recent", _RECENT),
]
msg_ids: dict[str, str] = {}
with session_factory.create_session() as session:
for label, ts in timestamps:
msg = _make_message(app_id, conv_id, ts)
session.add(msg)
session.flush()
msg_ids[label] = msg.id
session.commit()
yield {"msg_ids": msg_ids, **data}
with session_factory.create_session() as session:
session.execute(
delete(Message)
.where(Message.id.in_(list(msg_ids.values())))
.execution_options(synchronize_session=False)
)
session.commit()
@pytest.fixture
def paginated_seed_messages(self, tenant_and_app):
"""Seeds multiple messages separated by 1-second increments starting at _OLD."""
data = tenant_and_app
app_id = data["app_id"]
conv_id = data["conversation_id"]
msg_ids: list[str] = []
with session_factory.create_session() as session:
for ts in [_VERY_OLD, _OLD, _RECENT]:
for i in range(_PAGINATION_MESSAGE_COUNT):
ts = _OLD + datetime.timedelta(seconds=i)
msg = _make_message(app_id, conv_id, ts)
session.add(msg)
session.flush()
@ -99,124 +149,12 @@ class TestMessagesCleanServiceIntegration:
session.execute(delete(Message).where(Message.id.in_(msg_ids)).execution_options(synchronize_session=False))
session.commit()
def test_dry_run_does_not_delete(self, seed_messages):
data = seed_messages
app_id = data["app"].id
msg_ids = data["msg_ids"]
svc = MessagesCleanService(
policy=BillingDisabledPolicy(),
end_before=_NOW,
batch_size=100,
dry_run=True,
)
stats = svc.run()
assert stats["filtered_messages"] >= len(msg_ids)
assert stats["total_deleted"] == 0
with session_factory.create_session() as session:
remaining = session.query(Message).where(Message.id.in_(msg_ids)).count()
assert remaining == len(msg_ids)
def test_billing_disabled_deletes_all_in_range(self, seed_messages):
data = seed_messages
msg_ids = data["msg_ids"]
svc = MessagesCleanService(
policy=BillingDisabledPolicy(),
end_before=_NOW,
batch_size=100,
dry_run=False,
)
stats = svc.run()
assert stats["total_deleted"] >= len(msg_ids)
with session_factory.create_session() as session:
remaining = session.query(Message).where(Message.id.in_(msg_ids)).count()
assert remaining == 0
def test_start_from_filters_correctly(self, seed_messages):
data = seed_messages
msg_ids = data["msg_ids"]
start = _OLD - datetime.timedelta(hours=1)
end = _OLD + datetime.timedelta(hours=1)
svc = MessagesCleanService.from_time_range(
policy=BillingDisabledPolicy(),
start_from=start,
end_before=end,
batch_size=100,
)
stats = svc.run()
assert stats["total_deleted"] == 1
with session_factory.create_session() as session:
remaining_ids = {r[0] for r in session.query(Message.id).where(Message.id.in_(msg_ids)).all()}
assert msg_ids[1] not in remaining_ids
assert msg_ids[0] in remaining_ids
assert msg_ids[2] in remaining_ids
def test_cursor_pagination_across_batches(self, tenant_and_app):
@pytest.fixture
def cascade_test_data(self, tenant_and_app):
"""Seeds one Message with an associated Feedback and Annotation."""
data = tenant_and_app
app_id = data["app"].id
conv_id = data["conversation"].id
msg_ids: list[str] = []
with session_factory.create_session() as session:
for i in range(25):
ts = _OLD + datetime.timedelta(seconds=i)
msg = _make_message(app_id, conv_id, ts)
session.add(msg)
session.flush()
msg_ids.append(msg.id)
session.commit()
try:
svc = MessagesCleanService(
policy=BillingDisabledPolicy(),
end_before=_NOW,
start_from=_OLD - datetime.timedelta(seconds=1),
batch_size=8,
dry_run=False,
)
stats = svc.run()
assert stats["total_deleted"] == 25
assert stats["batches"] >= 4
with session_factory.create_session() as session:
remaining = session.query(Message).where(Message.id.in_(msg_ids)).count()
assert remaining == 0
finally:
with session_factory.create_session() as session:
session.execute(
delete(Message).where(Message.id.in_(msg_ids)).execution_options(synchronize_session=False)
)
session.commit()
def test_no_messages_in_range_returns_empty_stats(self, seed_messages):
far_future = _NOW + datetime.timedelta(days=365)
even_further = far_future + datetime.timedelta(days=1)
svc = MessagesCleanService.from_time_range(
policy=BillingDisabledPolicy(),
start_from=far_future,
end_before=even_further,
batch_size=100,
)
stats = svc.run()
assert stats["total_messages"] == 0
assert stats["total_deleted"] == 0
def test_relation_cascade_deletes(self, tenant_and_app):
data = tenant_and_app
app_id = data["app"].id
conv_id = data["conversation"].id
app_id = data["app_id"]
conv_id = data["conversation_id"]
with session_factory.create_session() as session:
msg = _make_message(app_id, conv_id, _OLD)
@ -227,8 +165,8 @@ class TestMessagesCleanServiceIntegration:
app_id=app_id,
conversation_id=conv_id,
message_id=msg.id,
rating="like",
from_source="user",
rating=FeedbackRating.LIKE,
from_source=FeedbackFromSource.USER,
)
annotation = MessageAnnotation(
app_id=app_id,
@ -240,32 +178,150 @@ class TestMessagesCleanServiceIntegration:
)
session.add_all([feedback, annotation])
session.commit()
msg_id = msg.id
fb_id = feedback.id
ann_id = annotation.id
try:
svc = MessagesCleanService(
policy=BillingDisabledPolicy(),
end_before=_NOW,
start_from=_OLD - datetime.timedelta(hours=1),
batch_size=100,
dry_run=False,
)
stats = svc.run()
yield {"msg_id": msg_id, "fb_id": fb_id, "ann_id": ann_id, **data}
assert stats["total_deleted"] == 1
with session_factory.create_session() as session:
session.execute(delete(MessageAnnotation).where(MessageAnnotation.id == ann_id))
session.execute(delete(MessageFeedback).where(MessageFeedback.id == fb_id))
session.execute(delete(Message).where(Message.id == msg_id))
session.commit()
with session_factory.create_session() as session:
assert session.query(Message).where(Message.id == msg_id).count() == 0
assert session.query(MessageFeedback).where(MessageFeedback.id == fb_id).count() == 0
assert session.query(MessageAnnotation).where(MessageAnnotation.id == ann_id).count() == 0
finally:
with session_factory.create_session() as session:
session.execute(delete(MessageAnnotation).where(MessageAnnotation.id == ann_id))
session.execute(delete(MessageFeedback).where(MessageFeedback.id == fb_id))
session.execute(delete(Message).where(Message.id == msg_id))
session.commit()
def test_dry_run_does_not_delete(self, seed_messages):
    """A dry run reports every eligible row but leaves the table untouched."""
    seeded_ids = list(seed_messages["msg_ids"].values())
    service = MessagesCleanService.from_time_range(
        policy=BillingDisabledPolicy(),
        start_from=_WINDOW_START,
        end_before=_WINDOW_END,
        batch_size=_DEFAULT_BATCH_SIZE,
        dry_run=True,
    )
    run_stats = service.run()
    # Counting happens, deletion does not.
    assert run_stats["filtered_messages"] == len(seeded_ids)
    assert run_stats["total_deleted"] == 0
    with session_factory.create_session() as session:
        survivors = session.query(Message).where(Message.id.in_(seeded_ids)).count()
        assert survivors == len(seeded_ids)
def test_billing_disabled_deletes_all_in_range(self, seed_messages):
    """Every seeded message sits inside the window, so all must be removed."""
    seeded_ids = list(seed_messages["msg_ids"].values())
    service = MessagesCleanService.from_time_range(
        policy=BillingDisabledPolicy(),
        start_from=_WINDOW_START,
        end_before=_WINDOW_END,
        batch_size=_DEFAULT_BATCH_SIZE,
        dry_run=False,
    )
    run_stats = service.run()
    assert run_stats["total_deleted"] == len(seeded_ids)
    # Confirm against the database, not just the reported stats.
    with session_factory.create_session() as session:
        assert session.query(Message).where(Message.id.in_(seeded_ids)).count() == 0
def test_start_from_filters_correctly(self, seed_messages):
    """A narrow window around _OLD must delete only the _OLD message."""
    msg_ids = seed_messages["msg_ids"]
    one_hour = datetime.timedelta(hours=1)
    service = MessagesCleanService.from_time_range(
        policy=BillingDisabledPolicy(),
        start_from=_OLD - one_hour,
        end_before=_OLD + one_hour,
        batch_size=_DEFAULT_BATCH_SIZE,
    )
    run_stats = service.run()
    assert run_stats["total_deleted"] == 1
    with session_factory.create_session() as session:
        rows = session.query(Message.id).where(Message.id.in_(list(msg_ids.values()))).all()
        survivors = {row[0] for row in rows}
    # Only the in-window row is gone; the neighbours survive.
    assert msg_ids["old"] not in survivors
    assert msg_ids["very_old"] in survivors
    assert msg_ids["recent"] in survivors
def test_cursor_pagination_across_batches(self, paginated_seed_messages):
    """Deletion must page through several batches when rows exceed batch_size."""
    seeded_ids = paginated_seed_messages["msg_ids"]
    # Seeded rows span _OLD .. _OLD + (_PAGINATION_MESSAGE_COUNT - 1) seconds;
    # the window below covers that whole range.
    window_start = _OLD - datetime.timedelta(seconds=1)
    window_end = _OLD + datetime.timedelta(seconds=_PAGINATION_MESSAGE_COUNT)
    service = MessagesCleanService.from_time_range(
        policy=BillingDisabledPolicy(),
        start_from=window_start,
        end_before=window_end,
        batch_size=_PAGINATION_BATCH_SIZE,
    )
    run_stats = service.run()
    assert run_stats["total_deleted"] == _PAGINATION_MESSAGE_COUNT
    # At least ceil(count / batch_size) passes are needed to drain the window.
    minimum_batches = math.ceil(_PAGINATION_MESSAGE_COUNT / _PAGINATION_BATCH_SIZE)
    assert run_stats["batches"] >= minimum_batches
    with session_factory.create_session() as session:
        assert session.query(Message).where(Message.id.in_(seeded_ids)).count() == 0
def test_no_messages_in_range_returns_empty_stats(self, seed_messages):
    """A window placed entirely in the future must match nothing."""
    window_start = _NOW + datetime.timedelta(days=365)
    window_end = window_start + datetime.timedelta(days=1)
    service = MessagesCleanService.from_time_range(
        policy=BillingDisabledPolicy(),
        start_from=window_start,
        end_before=window_end,
        batch_size=_DEFAULT_BATCH_SIZE,
    )
    run_stats = service.run()
    assert run_stats["total_messages"] == 0
    assert run_stats["total_deleted"] == 0
def test_relation_cascade_deletes(self, cascade_test_data):
    """Removing a Message must also remove its Feedback and Annotation rows."""
    message_id = cascade_test_data["msg_id"]
    feedback_id = cascade_test_data["fb_id"]
    annotation_id = cascade_test_data["ann_id"]
    one_hour = datetime.timedelta(hours=1)
    service = MessagesCleanService.from_time_range(
        policy=BillingDisabledPolicy(),
        start_from=_OLD - one_hour,
        end_before=_OLD + one_hour,
        batch_size=_DEFAULT_BATCH_SIZE,
    )
    run_stats = service.run()
    assert run_stats["total_deleted"] == 1
    # The message and both dependent rows must all be gone.
    with session_factory.create_session() as session:
        assert session.query(Message).where(Message.id == message_id).count() == 0
        assert session.query(MessageFeedback).where(MessageFeedback.id == feedback_id).count() == 0
        assert session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).count() == 0
def test_factory_from_time_range_validation(self):
with pytest.raises(ValueError, match="start_from"):