From 8bbaa862f26a740cd0a9b7fef3bea2ac8131026a Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Thu, 19 Mar 2026 17:51:55 +0800 Subject: [PATCH 001/107] style(scroll-bar): align design (#33751) --- .../ui/scroll-area/__tests__/index.spec.tsx | 21 ++- .../base/ui/scroll-area/index.module.css | 75 +++++++++ .../base/ui/scroll-area/index.stories.tsx | 149 ++++++++++++++++++ .../components/base/ui/scroll-area/index.tsx | 11 +- 4 files changed, 239 insertions(+), 17 deletions(-) create mode 100644 web/app/components/base/ui/scroll-area/index.module.css diff --git a/web/app/components/base/ui/scroll-area/__tests__/index.spec.tsx b/web/app/components/base/ui/scroll-area/__tests__/index.spec.tsx index 2781a5844f..e506fe59d0 100644 --- a/web/app/components/base/ui/scroll-area/__tests__/index.spec.tsx +++ b/web/app/components/base/ui/scroll-area/__tests__/index.spec.tsx @@ -8,6 +8,7 @@ import { ScrollAreaThumb, ScrollAreaViewport, } from '../index' +import styles from '../index.module.css' const renderScrollArea = (options: { rootClassName?: string @@ -72,20 +73,19 @@ describe('scroll-area wrapper', () => { const thumb = screen.getByTestId('scroll-area-vertical-thumb') expect(scrollbar).toHaveAttribute('data-orientation', 'vertical') + expect(scrollbar).toHaveClass(styles.scrollbar) expect(scrollbar).toHaveClass( 'flex', + 'overflow-clip', + 'p-1', 'touch-none', 'select-none', - 'opacity-0', + 'opacity-100', 'transition-opacity', 'motion-reduce:transition-none', 'pointer-events-none', 'data-[hovering]:pointer-events-auto', - 'data-[hovering]:opacity-100', 'data-[scrolling]:pointer-events-auto', - 'data-[scrolling]:opacity-100', - 'hover:pointer-events-auto', - 'hover:opacity-100', 'data-[orientation=vertical]:absolute', 'data-[orientation=vertical]:inset-y-0', 'data-[orientation=vertical]:w-3', @@ -97,7 +97,6 @@ describe('scroll-area wrapper', () => { 'rounded-[4px]', 'bg-state-base-handle', 'transition-[background-color]', - 
'hover:bg-state-base-handle-hover', 'motion-reduce:transition-none', 'data-[orientation=vertical]:w-1', ) @@ -112,20 +111,19 @@ describe('scroll-area wrapper', () => { const thumb = screen.getByTestId('scroll-area-horizontal-thumb') expect(scrollbar).toHaveAttribute('data-orientation', 'horizontal') + expect(scrollbar).toHaveClass(styles.scrollbar) expect(scrollbar).toHaveClass( 'flex', + 'overflow-clip', + 'p-1', 'touch-none', 'select-none', - 'opacity-0', + 'opacity-100', 'transition-opacity', 'motion-reduce:transition-none', 'pointer-events-none', 'data-[hovering]:pointer-events-auto', - 'data-[hovering]:opacity-100', 'data-[scrolling]:pointer-events-auto', - 'data-[scrolling]:opacity-100', - 'hover:pointer-events-auto', - 'hover:opacity-100', 'data-[orientation=horizontal]:absolute', 'data-[orientation=horizontal]:inset-x-0', 'data-[orientation=horizontal]:h-3', @@ -137,7 +135,6 @@ describe('scroll-area wrapper', () => { 'rounded-[4px]', 'bg-state-base-handle', 'transition-[background-color]', - 'hover:bg-state-base-handle-hover', 'motion-reduce:transition-none', 'data-[orientation=horizontal]:h-1', ) diff --git a/web/app/components/base/ui/scroll-area/index.module.css b/web/app/components/base/ui/scroll-area/index.module.css new file mode 100644 index 0000000000..a81fd3d3c2 --- /dev/null +++ b/web/app/components/base/ui/scroll-area/index.module.css @@ -0,0 +1,75 @@ +.scrollbar::before, +.scrollbar::after { + content: ''; + position: absolute; + z-index: 1; + border-radius: 9999px; + pointer-events: none; + opacity: 0; + transition: opacity 150ms ease; +} + +.scrollbar[data-orientation='vertical']::before { + left: 50%; + top: 4px; + width: 4px; + height: 12px; + transform: translateX(-50%); + background: linear-gradient(to bottom, var(--scroll-area-edge-hint-bg, var(--color-components-panel-bg)), transparent); +} + +.scrollbar[data-orientation='vertical']::after { + left: 50%; + bottom: 4px; + width: 4px; + height: 12px; + transform: translateX(-50%); + 
background: linear-gradient(to top, var(--scroll-area-edge-hint-bg, var(--color-components-panel-bg)), transparent); +} + +.scrollbar[data-orientation='horizontal']::before { + top: 50%; + left: 4px; + width: 12px; + height: 4px; + transform: translateY(-50%); + background: linear-gradient(to right, var(--scroll-area-edge-hint-bg, var(--color-components-panel-bg)), transparent); +} + +.scrollbar[data-orientation='horizontal']::after { + top: 50%; + right: 4px; + width: 12px; + height: 4px; + transform: translateY(-50%); + background: linear-gradient(to left, var(--scroll-area-edge-hint-bg, var(--color-components-panel-bg)), transparent); +} + +.scrollbar[data-orientation='vertical']:not([data-overflow-y-start])::before { + opacity: 1; +} + +.scrollbar[data-orientation='vertical']:not([data-overflow-y-end])::after { + opacity: 1; +} + +.scrollbar[data-orientation='horizontal']:not([data-overflow-x-start])::before { + opacity: 1; +} + +.scrollbar[data-orientation='horizontal']:not([data-overflow-x-end])::after { + opacity: 1; +} + +.scrollbar[data-hovering] > [data-orientation], +.scrollbar[data-scrolling] > [data-orientation], +.scrollbar > [data-orientation]:active { + background-color: var(--scroll-area-thumb-bg-active, var(--color-state-base-handle-hover)); +} + +@media (prefers-reduced-motion: reduce) { + .scrollbar::before, + .scrollbar::after { + transition: none; + } +} diff --git a/web/app/components/base/ui/scroll-area/index.stories.tsx b/web/app/components/base/ui/scroll-area/index.stories.tsx index 8eb655a151..465e534921 100644 --- a/web/app/components/base/ui/scroll-area/index.stories.tsx +++ b/web/app/components/base/ui/scroll-area/index.stories.tsx @@ -1,5 +1,6 @@ import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { ReactNode } from 'react' +import * as React from 'react' import AppIcon from '@/app/components/base/app-icon' import { cn } from '@/utils/classnames' import { @@ -78,6 +79,16 @@ const activityRows = Array.from({ 
length: 14 }, (_, index) => ({ body: 'A short line of copy to mimic dense operational feeds in settings and debug panels.', })) +const scrollbarShowcaseRows = Array.from({ length: 18 }, (_, index) => ({ + title: `Scroll checkpoint ${index + 1}`, + body: 'Dedicated story content so the scrollbar can be inspected without sticky headers, masks, or clipped shells.', +})) + +const horizontalShowcaseCards = Array.from({ length: 8 }, (_, index) => ({ + title: `Lane ${index + 1}`, + body: 'Horizontal scrollbar reference without edge hints.', +})) + const webAppsRows = [ { id: 'invoice-copilot', name: 'Invoice Copilot', meta: 'Pinned', icon: '🧾', iconBackground: '#FFEAD5', selected: true, pinned: true }, { id: 'rag-ops', name: 'RAG Ops Console', meta: 'Ops', icon: '🛰️', iconBackground: '#E0F2FE', selected: false, pinned: true }, @@ -255,6 +266,112 @@ const HorizontalRailPane = () => ( ) +const ScrollbarStatePane = ({ + eyebrow, + title, + description, + initialPosition, +}: { + eyebrow: string + title: string + description: string + initialPosition: 'top' | 'middle' | 'bottom' +}) => { + const viewportId = React.useId() + + React.useEffect(() => { + let frameA = 0 + let frameB = 0 + + const syncScrollPosition = () => { + const viewport = document.getElementById(viewportId) + + if (!(viewport instanceof HTMLDivElement)) + return + + const maxScrollTop = Math.max(0, viewport.scrollHeight - viewport.clientHeight) + + if (initialPosition === 'top') + viewport.scrollTop = 0 + + if (initialPosition === 'middle') + viewport.scrollTop = maxScrollTop / 2 + + if (initialPosition === 'bottom') + viewport.scrollTop = maxScrollTop + } + + frameA = requestAnimationFrame(() => { + frameB = requestAnimationFrame(syncScrollPosition) + }) + + return () => { + cancelAnimationFrame(frameA) + cancelAnimationFrame(frameB) + } + }, [initialPosition, viewportId]) + + return ( +
{description}
+Current design delivery defines the horizontal scrollbar body, but not a horizontal edge hint.
+ = Omit ,
+): WorkflowHookTestResult
Installed apps
+{description}
Current design delivery defines the horizontal scrollbar body, but not a horizontal edge hint.
Content without ENDTHINKFLAG
hi
", + } + if include_expiration_time: + payload["expiration_time"] = naive_utc_now() + return json.dumps(payload, default=str) + + +@dataclasses.dataclass +class _DummyForm: + id: str + workflow_run_id: str | None + node_id: str + tenant_id: str + app_id: str + form_definition: str + rendered_content: str + expiration_time: datetime + form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME + created_at: datetime = dataclasses.field(default_factory=naive_utc_now) + selected_action_id: str | None = None + submitted_data: str | None = None + submitted_at: datetime | None = None + submission_user_id: str | None = None + submission_end_user_id: str | None = None + completed_by_recipient_id: str | None = None + status: HumanInputFormStatus = HumanInputFormStatus.WAITING + + +@dataclasses.dataclass +class _DummyRecipient: + id: str + form_id: str + recipient_type: RecipientType + access_token: str | None + + +class _FakeScalarResult: + def __init__(self, obj: Any): + self._obj = obj + + def first(self) -> Any: + if isinstance(self._obj, list): + return self._obj[0] if self._obj else None + return self._obj + + def all(self) -> list[Any]: + if self._obj is None: + return [] + if isinstance(self._obj, list): + return list(self._obj) + return [self._obj] + + +class _FakeExecuteResult: + def __init__(self, rows: Sequence[tuple[Any, ...]]): + self._rows = list(rows) + + def all(self) -> list[tuple[Any, ...]]: + return list(self._rows) + + +class _FakeSession: + def __init__( + self, + *, + scalars_result: Any = None, + scalars_results: list[Any] | None = None, + forms: dict[str, _DummyForm] | None = None, + recipients: dict[str, _DummyRecipient] | None = None, + execute_rows: Sequence[tuple[Any, ...]] = (), + ): + if scalars_results is not None: + self._scalars_queue = list(scalars_results) + else: + self._scalars_queue = [scalars_result] + self._forms = forms or {} + self._recipients = recipients or {} + self._execute_rows = list(execute_rows) + self.added: list[Any] = [] + 
+ def scalars(self, _query: Any) -> _FakeScalarResult: + if self._scalars_queue: + value = self._scalars_queue.pop(0) + else: + value = None + return _FakeScalarResult(value) + + def execute(self, _stmt: Any) -> _FakeExecuteResult: + return _FakeExecuteResult(self._execute_rows) + + def get(self, model_cls: Any, obj_id: str) -> Any: + name = getattr(model_cls, "__name__", "") + if name == "HumanInputForm": + return self._forms.get(obj_id) + if name == "HumanInputFormRecipient": + return self._recipients.get(obj_id) + return None + + def add(self, obj: Any) -> None: + self.added.append(obj) + + def add_all(self, objs: Sequence[Any]) -> None: + self.added.extend(list(objs)) + + def flush(self) -> None: + # Simulate DB default population for attributes referenced in entity wrappers. + for obj in self.added: + if hasattr(obj, "id") and obj.id in (None, ""): + obj.id = f"gen-{len(str(self.added))}" + if isinstance(obj, HumanInputFormRecipient) and obj.access_token is None: + if obj.recipient_type == RecipientType.CONSOLE: + obj.access_token = "token-console" + elif obj.recipient_type == RecipientType.BACKSTAGE: + obj.access_token = "token-backstage" + else: + obj.access_token = "token-webapp" + + def refresh(self, _obj: Any) -> None: + return None + + def begin(self) -> _FakeSession: + return self + + def __enter__(self) -> _FakeSession: + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + +class _SessionFactoryStub: + def __init__(self, session: _FakeSession): + self._session = session + + def create_session(self) -> _FakeSession: + return self._session + + +def _patch_session_factory(monkeypatch: pytest.MonkeyPatch, session: _FakeSession) -> None: + monkeypatch.setattr("core.repositories.human_input_repository.session_factory", _SessionFactoryStub(session)) + + +def test_recipient_entity_token_raises_when_missing() -> None: + recipient = SimpleNamespace(id="r1", access_token=None) + entity = _HumanInputFormRecipientEntityImpl(recipient) # 
type: ignore[arg-type] + with pytest.raises(AssertionError, match="access_token should not be None"): + _ = entity.token + + +def test_recipient_entity_id_and_token_success() -> None: + recipient = SimpleNamespace(id="r1", access_token="tok") + entity = _HumanInputFormRecipientEntityImpl(recipient) # type: ignore[arg-type] + assert entity.id == "r1" + assert entity.token == "tok" + + +def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> None: + form = _DummyForm( + id="f1", + workflow_run_id="run", + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="x
", + expiration_time=naive_utc_now(), + ) + console = _DummyRecipient(id="c1", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="ctok") + webapp = _DummyRecipient( + id="w1", form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP, access_token="wtok" + ) + + entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp, console]) # type: ignore[arg-type] + assert entity.web_app_token == "ctok" + + entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp]) # type: ignore[arg-type] + assert entity.web_app_token == "wtok" + + entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type] + assert entity.web_app_token is None + + +def test_form_entity_submitted_data_parsed() -> None: + form = _DummyForm( + id="f1", + workflow_run_id="run", + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="x
", + expiration_time=naive_utc_now(), + submitted_data='{"a": 1}', + submitted_at=naive_utc_now(), + ) + entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type] + assert entity.submitted is True + assert entity.submitted_data == {"a": 1} + assert entity.rendered_content == "x
" + assert entity.selected_action_id is None + assert entity.status == HumanInputFormStatus.WAITING + + +def test_form_record_from_models_injects_expiration_time_when_missing() -> None: + expiration = naive_utc_now() + form = _DummyForm( + id="f1", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=False), + rendered_content="x
", + expiration_time=expiration, + submitted_data='{"k": "v"}', + ) + record = HumanInputFormRecord.from_models(form, None) # type: ignore[arg-type] + assert record.definition.expiration_time == expiration + assert record.submitted_data == {"k": "v"} + assert record.submitted is False + + +def test_create_email_recipients_from_resolved_dedupes_and_skips_blank(monkeypatch: pytest.MonkeyPatch) -> None: + created: list[SimpleNamespace] = [] + + def fake_new(cls, form_id: str, delivery_id: str, payload: Any): # type: ignore[no-untyped-def] + recipient = SimpleNamespace( + id=f"{payload.TYPE}-{len(created)}", + form_id=form_id, + delivery_id=delivery_id, + recipient_type=payload.TYPE, + recipient_payload=payload.model_dump_json(), + access_token="tok", + ) + created.append(recipient) + return recipient + + monkeypatch.setattr("core.repositories.human_input_repository.HumanInputFormRecipient.new", classmethod(fake_new)) + + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + recipients = repo._create_email_recipients_from_resolved( # type: ignore[attr-defined] + form_id="f", + delivery_id="d", + members=[ + _WorkspaceMemberInfo(user_id="u1", email=""), + _WorkspaceMemberInfo(user_id="u2", email="a@example.com"), + _WorkspaceMemberInfo(user_id="u3", email="a@example.com"), + ], + external_emails=["", "a@example.com", "b@example.com", "b@example.com"], + ) + assert [r.recipient_type for r in recipients] == [RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL] + + +def test_query_workspace_members_by_ids_empty_returns_empty() -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + assert repo._query_workspace_members_by_ids(session=MagicMock(), restrict_to_user_ids=["", ""]) == [] + + +def test_query_workspace_members_by_ids_maps_rows() -> None: + session = _FakeSession(execute_rows=[("u1", "a@example.com"), ("u2", "b@example.com")]) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + rows = 
repo._query_workspace_members_by_ids(session=session, restrict_to_user_ids=["u1", "u2"]) + assert rows == [ + _WorkspaceMemberInfo(user_id="u1", email="a@example.com"), + _WorkspaceMemberInfo(user_id="u2", email="b@example.com"), + ] + + +def test_query_all_workspace_members_maps_rows() -> None: + session = _FakeSession(execute_rows=[("u1", "a@example.com")]) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + rows = repo._query_all_workspace_members(session=session) + assert rows == [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")] + + +def test_repository_init_sets_tenant_id() -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + assert repo._tenant_id == "tenant" + + +def test_delivery_method_to_model_webapp_creates_delivery_and_recipient(monkeypatch: pytest.MonkeyPatch) -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1") + result = repo._delivery_method_to_model( + session=MagicMock(), form_id="form-1", delivery_method=WebAppDeliveryMethod() + ) + assert result.delivery.id == "del-1" + assert result.delivery.form_id == "form-1" + assert len(result.recipients) == 1 + assert result.recipients[0].recipient_type == RecipientType.STANDALONE_WEB_APP + + +def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch: pytest.MonkeyPatch) -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1") + called: dict[str, Any] = {} + + def fake_build(*, session: Any, form_id: str, delivery_id: str, recipients_config: Any) -> list[Any]: + called.update( + {"session": session, "form_id": form_id, "delivery_id": delivery_id, "recipients_config": recipients_config} + ) + return ["r"] + + monkeypatch.setattr(repo, "_build_email_recipients", fake_build) + + method = EmailDeliveryMethod( + config=EmailDeliveryConfig( + 
recipients=EmailRecipients( + whole_workspace=False, + items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")], + ), + subject="s", + body="b", + ) + ) + result = repo._delivery_method_to_model(session="sess", form_id="form-1", delivery_method=method) + assert result.recipients == ["r"] + assert called["delivery_id"] == "del-1" + + +def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + monkeypatch.setattr( + repo, + "_query_all_workspace_members", + lambda *, session: [_WorkspaceMemberInfo(user_id="u", email="a@example.com")], + ) + monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"]) + recipients = repo._build_email_recipients( + session=MagicMock(), + form_id="f", + delivery_id="d", + recipients_config=EmailRecipients(whole_workspace=True, items=[ExternalRecipient(email="e@example.com")]), + ) + assert recipients == ["ok"] + + +def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + + def fake_query(*, session: Any, restrict_to_user_ids: Sequence[str]) -> list[_WorkspaceMemberInfo]: + assert restrict_to_user_ids == ["u1"] + return [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")] + + monkeypatch.setattr(repo, "_query_workspace_members_by_ids", fake_query) + monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"]) + recipients = repo._build_email_recipients( + session=MagicMock(), + form_id="f", + delivery_id="d", + recipients_config=EmailRecipients( + whole_workspace=False, + items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")], + ), + ) + assert recipients == ["ok"] + + +def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + 
_patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None])) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + assert repo.get_form("run", "node") is None + + form = _DummyForm( + id="f1", + workflow_run_id="run", + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="x
", + expiration_time=naive_utc_now(), + ) + recipient = _DummyRecipient( + id="r1", + form_id=form.id, + recipient_type=RecipientType.STANDALONE_WEB_APP, + access_token="tok", + ) + session = _FakeSession(scalars_results=[form, [recipient]]) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + entity = repo.get_form("run", "node") + assert entity is not None + assert entity.id == "f1" + assert entity.recipients[0].id == "r1" + assert entity.recipients[0].token == "tok" + + +def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.MonkeyPatch) -> None: + fixed_now = datetime(2024, 1, 1, 0, 0, 0) + monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now) + + ids = iter(["form-id", "del-web", "del-console", "del-backstage"]) + monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: next(ids)) + + session = _FakeSession() + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + + form_config = HumanInputNodeData( + title="Title", + delivery_methods=[], + form_content="hello", + inputs=[], + user_actions=[UserAction(id="submit", title="Submit")], + ) + params = FormCreateParams( + app_id="app", + workflow_execution_id="run", + node_id="node", + form_config=form_config, + rendered_content="hello
", + delivery_methods=[WebAppDeliveryMethod()], + display_in_ui=True, + resolved_default_values={}, + form_kind=HumanInputFormKind.RUNTIME, + console_recipient_required=True, + console_creator_account_id="acc-1", + backstage_recipient_required=True, + ) + + entity = repo.create_form(params) + assert entity.id == "form-id" + assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout) + # Console token should take precedence when console recipient is present. + assert entity.web_app_token == "token-console" + assert len(entity.recipients) == 3 + + +def test_submission_get_by_token_returns_none_when_missing_or_form_missing(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_session_factory(monkeypatch, _FakeSession(scalars_result=None)) + repo = HumanInputFormSubmissionRepository() + assert repo.get_by_token("tok") is None + + recipient = SimpleNamespace(form=None) + _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient)) + repo = HumanInputFormSubmissionRepository() + assert repo.get_by_token("tok") is None + + +def test_submission_repository_init_no_args() -> None: + repo = HumanInputFormSubmissionRepository() + assert isinstance(repo, HumanInputFormSubmissionRepository) + + +def test_submission_get_by_token_and_get_by_form_id_success_paths(monkeypatch: pytest.MonkeyPatch) -> None: + form = _DummyForm( + id="f1", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="x
", + expiration_time=naive_utc_now(), + ) + recipient = SimpleNamespace( + id="r1", + form_id=form.id, + recipient_type=RecipientType.STANDALONE_WEB_APP, + access_token="tok", + form=form, + ) + + _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient)) + repo = HumanInputFormSubmissionRepository() + record = repo.get_by_token("tok") + assert record is not None + assert record.access_token == "tok" + + _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient)) + repo = HumanInputFormSubmissionRepository() + record = repo.get_by_form_id_and_recipient_type(form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP) + assert record is not None + assert record.recipient_id == "r1" + + +def test_submission_get_by_form_id_returns_none_on_missing(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_session_factory(monkeypatch, _FakeSession(scalars_result=None)) + repo = HumanInputFormSubmissionRepository() + assert repo.get_by_form_id_and_recipient_type(form_id="f", recipient_type=RecipientType.CONSOLE) is None + + +def test_mark_submitted_updates_and_raises_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + fixed_now = datetime(2024, 1, 1, 0, 0, 0) + monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now) + + missing_session = _FakeSession(forms={}) + _patch_session_factory(monkeypatch, missing_session) + repo = HumanInputFormSubmissionRepository() + with pytest.raises(FormNotFoundError, match="form not found"): + repo.mark_submitted( + form_id="missing", + recipient_id=None, + selected_action_id="a", + form_data={}, + submission_user_id=None, + submission_end_user_id=None, + ) + + form = _DummyForm( + id="f", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="x
", + expiration_time=fixed_now, + ) + recipient = _DummyRecipient(id="r", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="tok") + session = _FakeSession(forms={form.id: form}, recipients={recipient.id: recipient}) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() + record = repo.mark_submitted( + form_id=form.id, + recipient_id=recipient.id, + selected_action_id="approve", + form_data={"k": "v"}, + submission_user_id="u", + submission_end_user_id="eu", + ) + assert form.status == HumanInputFormStatus.SUBMITTED + assert form.submitted_at == fixed_now + assert record.submitted_data == {"k": "v"} + + +def test_mark_timeout_invalid_status_raises(monkeypatch: pytest.MonkeyPatch) -> None: + form = _DummyForm( + id="f", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="x
", + expiration_time=naive_utc_now(), + ) + session = _FakeSession(forms={form.id: form}) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() + with pytest.raises(_InvalidTimeoutStatusError, match="invalid timeout status"): + repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.SUBMITTED) # type: ignore[arg-type] + + +def test_mark_timeout_already_timed_out_returns_record(monkeypatch: pytest.MonkeyPatch) -> None: + form = _DummyForm( + id="f", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="x
", + expiration_time=naive_utc_now(), + status=HumanInputFormStatus.TIMEOUT, + ) + session = _FakeSession(forms={form.id: form}) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() + record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.TIMEOUT, reason="r") + assert record.status == HumanInputFormStatus.TIMEOUT + + +def test_mark_timeout_submitted_raises_form_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + form = _DummyForm( + id="f", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="x
", + expiration_time=naive_utc_now(), + status=HumanInputFormStatus.SUBMITTED, + ) + session = _FakeSession(forms={form.id: form}) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() + with pytest.raises(FormNotFoundError, match="form already submitted"): + repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED) + + +def test_mark_timeout_updates_fields(monkeypatch: pytest.MonkeyPatch) -> None: + form = _DummyForm( + id="f", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="x
", + expiration_time=naive_utc_now(), + selected_action_id="a", + submitted_data="{}", + submission_user_id="u", + submission_end_user_id="eu", + completed_by_recipient_id="r", + status=HumanInputFormStatus.WAITING, + ) + session = _FakeSession(forms={form.id: form}) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() + record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED) + assert form.status == HumanInputFormStatus.EXPIRED + assert form.selected_action_id is None + assert form.submitted_data is None + assert form.submission_user_id is None + assert form.submission_end_user_id is None + assert form.completed_by_recipient_id is None + assert record.status == HumanInputFormStatus.EXPIRED + + +def test_mark_timeout_raises_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_session_factory(monkeypatch, _FakeSession(forms={})) + repo = HumanInputFormSubmissionRepository() + with pytest.raises(FormNotFoundError, match="form not found"): + repo.mark_timeout(form_id="missing", timeout_status=HumanInputFormStatus.TIMEOUT) diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py index c66e50437a..232ab07882 100644 --- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py @@ -1,84 +1,291 @@ -from datetime import datetime +from datetime import UTC, datetime from unittest.mock import MagicMock from uuid import uuid4 -from sqlalchemy import create_engine +import pytest +from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from dify_graph.entities.workflow_execution import WorkflowExecution, 
WorkflowType -from models import Account, WorkflowRun +from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType +from models import Account, CreatorUserRole, EndUser, WorkflowRun from models.enums import WorkflowRunTriggeredFrom -def _build_repository_with_mocked_session(session: MagicMock) -> SQLAlchemyWorkflowExecutionRepository: - engine = create_engine("sqlite:///:memory:") - real_session_factory = sessionmaker(bind=engine, expire_on_commit=False) - - user = MagicMock(spec=Account) - user.id = str(uuid4()) - user.current_tenant_id = str(uuid4()) - - repository = SQLAlchemyWorkflowExecutionRepository( - session_factory=real_session_factory, - user=user, - app_id="app-id", - triggered_from=WorkflowRunTriggeredFrom.APP_RUN, - ) - - session_context = MagicMock() - session_context.__enter__.return_value = session - session_context.__exit__.return_value = False - repository._session_factory = MagicMock(return_value=session_context) - return repository - - -def _build_execution(*, execution_id: str, started_at: datetime) -> WorkflowExecution: - return WorkflowExecution.new( - id_=execution_id, - workflow_id="workflow-id", - workflow_type=WorkflowType.WORKFLOW, - workflow_version="1.0.0", - graph={"nodes": [], "edges": []}, - inputs={"query": "hello"}, - started_at=started_at, - ) - - -def test_save_uses_execution_started_at_when_record_does_not_exist(): +@pytest.fixture +def mock_session_factory(): + """Mock SQLAlchemy session factory.""" + session_factory = MagicMock(spec=sessionmaker) session = MagicMock() session.get.return_value = None - repository = _build_repository_with_mocked_session(session) - - started_at = datetime(2026, 1, 1, 12, 0, 0) - execution = _build_execution(execution_id=str(uuid4()), started_at=started_at) - - repository.save(execution) - - saved_model = session.merge.call_args.args[0] - assert saved_model.created_at == started_at - session.commit.assert_called_once() + 
session_factory.return_value.__enter__.return_value = session + return session_factory -def test_save_preserves_existing_created_at_when_record_already_exists(): - session = MagicMock() - repository = _build_repository_with_mocked_session(session) +@pytest.fixture +def mock_engine(): + """Mock SQLAlchemy Engine.""" + return MagicMock(spec=Engine) - execution_id = str(uuid4()) - existing_created_at = datetime(2026, 1, 1, 12, 0, 0) - existing_run = WorkflowRun() - existing_run.id = execution_id - existing_run.tenant_id = repository._tenant_id - existing_run.created_at = existing_created_at - session.get.return_value = existing_run - execution = _build_execution( - execution_id=execution_id, - started_at=datetime(2026, 1, 1, 12, 30, 0), +@pytest.fixture +def mock_account(): + """Mock Account user.""" + account = MagicMock(spec=Account) + account.id = str(uuid4()) + account.current_tenant_id = str(uuid4()) + return account + + +@pytest.fixture +def mock_end_user(): + """Mock EndUser.""" + user = MagicMock(spec=EndUser) + user.id = str(uuid4()) + user.tenant_id = str(uuid4()) + return user + + +@pytest.fixture +def sample_workflow_execution(): + """Sample WorkflowExecution for testing.""" + return WorkflowExecution( + id_=str(uuid4()), + workflow_id=str(uuid4()), + workflow_type=WorkflowType.WORKFLOW, + workflow_version="1.0", + graph={"nodes": [], "edges": []}, + inputs={"input1": "value1"}, + outputs={"output1": "result1"}, + status=WorkflowExecutionStatus.SUCCEEDED, + error_message="", + total_tokens=100, + total_steps=5, + exceptions_count=0, + started_at=datetime.now(UTC), + finished_at=datetime.now(UTC), ) - repository.save(execution) - saved_model = session.merge.call_args.args[0] - assert saved_model.created_at == existing_created_at - session.commit.assert_called_once() +class TestSQLAlchemyWorkflowExecutionRepository: + def test_init_with_sessionmaker(self, mock_session_factory, mock_account): + app_id = "test_app_id" + triggered_from = 
WorkflowRunTriggeredFrom.APP_RUN + + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, user=mock_account, app_id=app_id, triggered_from=triggered_from + ) + + assert repo._session_factory == mock_session_factory + assert repo._tenant_id == mock_account.current_tenant_id + assert repo._app_id == app_id + assert repo._triggered_from == triggered_from + assert repo._creator_user_id == mock_account.id + assert repo._creator_user_role == CreatorUserRole.ACCOUNT + + def test_init_with_engine(self, mock_engine, mock_account): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_engine, + user=mock_account, + app_id="test_app_id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + assert isinstance(repo._session_factory, sessionmaker) + assert repo._session_factory.kw["bind"] == mock_engine + + def test_init_invalid_session_factory(self, mock_account): + with pytest.raises(ValueError, match="Invalid session_factory type"): + SQLAlchemyWorkflowExecutionRepository( + session_factory="invalid", user=mock_account, app_id=None, triggered_from=None + ) + + def test_init_no_tenant_id(self, mock_session_factory): + user = MagicMock(spec=Account) + user.current_tenant_id = None + + with pytest.raises(ValueError, match="User must have a tenant_id"): + SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, user=user, app_id=None, triggered_from=None + ) + + def test_init_with_end_user(self, mock_session_factory, mock_end_user): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, user=mock_end_user, app_id=None, triggered_from=None + ) + assert repo._tenant_id == mock_end_user.tenant_id + assert repo._creator_user_role == CreatorUserRole.END_USER + + def test_to_domain_model(self, mock_session_factory, mock_account): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None + ) + + 
db_model = MagicMock(spec=WorkflowRun) + db_model.id = str(uuid4()) + db_model.workflow_id = str(uuid4()) + db_model.type = "workflow" + db_model.version = "1.0" + db_model.inputs_dict = {"in": "val"} + db_model.outputs_dict = {"out": "val"} + db_model.graph_dict = {"nodes": []} + db_model.status = "succeeded" + db_model.error = "some error" + db_model.total_tokens = 50 + db_model.total_steps = 3 + db_model.exceptions_count = 1 + db_model.created_at = datetime.now(UTC) + db_model.finished_at = datetime.now(UTC) + + domain_model = repo._to_domain_model(db_model) + + assert domain_model.id_ == db_model.id + assert domain_model.workflow_id == db_model.workflow_id + assert domain_model.status == WorkflowExecutionStatus.SUCCEEDED + assert domain_model.inputs == db_model.inputs_dict + assert domain_model.error_message == "some error" + + def test_to_db_model(self, mock_session_factory, mock_account, sample_workflow_execution): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test_app", + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + ) + + # Make elapsed time deterministic to avoid flaky tests + sample_workflow_execution.started_at = datetime(2023, 1, 1, 0, 0, 0, tzinfo=UTC) + sample_workflow_execution.finished_at = datetime(2023, 1, 1, 0, 0, 10, tzinfo=UTC) + + db_model = repo._to_db_model(sample_workflow_execution) + + assert db_model.id == sample_workflow_execution.id_ + assert db_model.tenant_id == repo._tenant_id + assert db_model.app_id == "test_app" + assert db_model.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING + assert db_model.status == sample_workflow_execution.status.value + assert db_model.total_tokens == sample_workflow_execution.total_tokens + assert db_model.elapsed_time == 10.0 + + def test_to_db_model_edge_cases(self, mock_session_factory, mock_account, sample_workflow_execution): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + 
user=mock_account, + app_id="test_app", + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + ) + # Test with empty/None fields + sample_workflow_execution.graph = None + sample_workflow_execution.inputs = None + sample_workflow_execution.outputs = None + sample_workflow_execution.error_message = None + sample_workflow_execution.finished_at = None + + db_model = repo._to_db_model(sample_workflow_execution) + + assert db_model.graph is None + assert db_model.inputs is None + assert db_model.outputs is None + assert db_model.error is None + assert db_model.elapsed_time == 0 + + def test_to_db_model_app_id_none(self, mock_session_factory, mock_account, sample_workflow_execution): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id=None, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + db_model = repo._to_db_model(sample_workflow_execution) + assert not hasattr(db_model, "app_id") or db_model.app_id is None + assert db_model.tenant_id == repo._tenant_id + + def test_to_db_model_missing_context(self, mock_session_factory, mock_account, sample_workflow_execution): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None + ) + + # Test triggered_from missing + with pytest.raises(ValueError, match="triggered_from is required"): + repo._to_db_model(sample_workflow_execution) + + repo._triggered_from = WorkflowRunTriggeredFrom.APP_RUN + repo._creator_user_id = None + with pytest.raises(ValueError, match="created_by is required"): + repo._to_db_model(sample_workflow_execution) + + repo._creator_user_id = "some_id" + repo._creator_user_role = None + with pytest.raises(ValueError, match="created_by_role is required"): + repo._to_db_model(sample_workflow_execution) + + def test_save(self, mock_session_factory, mock_account, sample_workflow_execution): + repo = SQLAlchemyWorkflowExecutionRepository( + 
session_factory=mock_session_factory, + user=mock_account, + app_id="test_app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + repo.save(sample_workflow_execution) + + session = mock_session_factory.return_value.__enter__.return_value + session.merge.assert_called_once() + session.commit.assert_called_once() + + # Check cache + assert sample_workflow_execution.id_ in repo._execution_cache + cached_model = repo._execution_cache[sample_workflow_execution.id_] + assert cached_model.id == sample_workflow_execution.id_ + + def test_save_uses_execution_started_at_when_record_does_not_exist( + self, mock_session_factory, mock_account, sample_workflow_execution + ): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test_app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + started_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC) + sample_workflow_execution.started_at = started_at + + session = mock_session_factory.return_value.__enter__.return_value + session.get.return_value = None + + repo.save(sample_workflow_execution) + + saved_model = session.merge.call_args.args[0] + assert saved_model.created_at == started_at + session.commit.assert_called_once() + + def test_save_preserves_existing_created_at_when_record_already_exists( + self, mock_session_factory, mock_account, sample_workflow_execution + ): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test_app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + execution_id = sample_workflow_execution.id_ + existing_created_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC) + + existing_run = WorkflowRun() + existing_run.id = execution_id + existing_run.tenant_id = repo._tenant_id + existing_run.created_at = existing_created_at + + session = mock_session_factory.return_value.__enter__.return_value + session.get.return_value = existing_run + + 
sample_workflow_execution.started_at = datetime(2026, 1, 1, 12, 30, 0, tzinfo=UTC) + + repo.save(sample_workflow_execution) + + saved_model = session.merge.call_args.args[0] + assert saved_model.created_at == existing_created_at + session.commit.assert_called_once() diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py new file mode 100644 index 0000000000..c7af32789b --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py @@ -0,0 +1,772 @@ +from __future__ import annotations + +import json +import logging +from collections.abc import Mapping +from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock, Mock + +import psycopg2.errors +import pytest +from sqlalchemy import Engine, create_engine +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from core.repositories.sqlalchemy_workflow_node_execution_repository import ( + SQLAlchemyWorkflowNodeExecutionRepository, + _deterministic_json_dump, + _filter_by_offload_type, + _find_first, + _replace_or_append_offload, +) +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import ( + NodeType, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from dify_graph.repositories.workflow_node_execution_repository import OrderConfig +from models import Account, EndUser +from models.enums import ExecutionOffLoadType +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom + + +def _mock_account(*, tenant_id: str = "tenant", user_id: str = "user") -> Account: + user = Mock(spec=Account) + user.id = user_id + user.current_tenant_id = tenant_id + return user + + +def _mock_end_user(*, 
tenant_id: str = "tenant", user_id: str = "user") -> EndUser: + user = Mock(spec=EndUser) + user.id = user_id + user.tenant_id = tenant_id + return user + + +def _execution( + *, + execution_id: str = "exec-id", + node_execution_id: str = "node-exec-id", + workflow_run_id: str = "run-id", + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.SUCCEEDED, + inputs: Mapping[str, Any] | None = None, + outputs: Mapping[str, Any] | None = None, + process_data: Mapping[str, Any] | None = None, + metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None, +) -> WorkflowNodeExecution: + return WorkflowNodeExecution( + id=execution_id, + node_execution_id=node_execution_id, + workflow_id="workflow-id", + workflow_execution_id=workflow_run_id, + index=1, + predecessor_node_id=None, + node_id="node-id", + node_type=NodeType.LLM, + title="Title", + inputs=inputs, + outputs=outputs, + process_data=process_data, + status=status, + error=None, + elapsed_time=1.0, + metadata=metadata, + created_at=datetime.now(UTC), + finished_at=None, + ) + + +class _SessionCtx: + def __init__(self, session: Any): + self._session = session + + def __enter__(self) -> Any: + return self._session + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + +def _session_factory(session: Any) -> sessionmaker: + factory = Mock(spec=sessionmaker) + factory.return_value = _SessionCtx(session) + return factory + + +def test_init_accepts_engine_and_sessionmaker_and_sets_role(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + + engine: Engine = create_engine("sqlite:///:memory:") + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=engine, + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + assert isinstance(repo._session_factory, sessionmaker) + + sm 
= Mock(spec=sessionmaker) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=sm, + user=_mock_end_user(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, + ) + assert repo._creator_user_role.value == "end_user" + + +def test_init_rejects_invalid_session_factory_type(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + with pytest.raises(ValueError, match="Invalid session_factory type"): + SQLAlchemyWorkflowNodeExecutionRepository( # type: ignore[arg-type] + session_factory=object(), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + +def test_init_requires_tenant_id(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + user = _mock_account() + user.current_tenant_id = None + with pytest.raises(ValueError, match="User must have a tenant_id"): + SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=user, + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + +def test_create_truncator_uses_config(monkeypatch: pytest.MonkeyPatch) -> None: + created: dict[str, Any] = {} + + class FakeTruncator: + def __init__(self, *, max_size_bytes: int, array_element_limit: int, string_length_limit: int): + created.update( + { + "max_size_bytes": max_size_bytes, + "array_element_limit": array_element_limit, + "string_length_limit": string_length_limit, + } + ) + + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.VariableTruncator", + FakeTruncator, + ) + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: 
SimpleNamespace(upload_file=Mock()), + ) + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + _ = repo._create_truncator() + assert created["max_size_bytes"] == dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE + + +def test_helpers_find_first_and_replace_or_append_and_filter() -> None: + assert _deterministic_json_dump({"b": 1, "a": 2}) == '{"a": 2, "b": 1}' + assert _find_first([], lambda _: True) is None + assert _find_first([1, 2, 3], lambda x: x > 1) == 2 + + off1 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS) + off2 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS) + assert _find_first([off1, off2], _filter_by_offload_type(ExecutionOffLoadType.OUTPUTS)) is off2 + + replaced = _replace_or_append_offload([off1, off2], WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)) + assert len(replaced) == 2 + assert [o.type_ for o in replaced] == [ExecutionOffLoadType.OUTPUTS, ExecutionOffLoadType.INPUTS] + + +def test_to_db_model_requires_constructor_context(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + execution = _execution(inputs={"b": 1, "a": 2}, metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 1}) + + # Happy path: deterministic json dump should be sorted + db_model = repo._to_db_model(execution) + assert json.loads(db_model.inputs or "{}") == {"a": 2, "b": 1} + assert json.loads(db_model.execution_metadata or "{}")["total_tokens"] == 1 + + repo._triggered_from = None + with pytest.raises(ValueError, 
match="triggered_from is required"): + repo._to_db_model(execution) + + +def test_to_db_model_requires_creator_user_id_and_role(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + execution = _execution() + db_model = repo._to_db_model(execution) + assert db_model.app_id == "app" + + repo._creator_user_id = None + with pytest.raises(ValueError, match="created_by is required"): + repo._to_db_model(execution) + + repo._creator_user_id = "user" + repo._creator_user_role = None + with pytest.raises(ValueError, match="created_by_role is required"): + repo._to_db_model(execution) + + +def test_is_duplicate_key_error_and_regenerate_id( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + unique = Mock(spec=psycopg2.errors.UniqueViolation) + duplicate_error = IntegrityError("dup", params=None, orig=unique) + assert repo._is_duplicate_key_error(duplicate_error) is True + assert repo._is_duplicate_key_error(IntegrityError("other", params=None, orig=None)) is False + + execution = _execution(execution_id="old-id") + db_model = WorkflowNodeExecutionModel() + db_model.id = "old-id" + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id") + caplog.set_level(logging.WARNING) + 
repo._regenerate_id_on_duplicate(execution, db_model) + assert execution.id == "new-id" + assert db_model.id == "new-id" + assert any("Duplicate key conflict" in r.message for r in caplog.records) + + +def test_persist_to_database_updates_existing_and_inserts_new(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + session = MagicMock() + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=_session_factory(session), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + db_model = WorkflowNodeExecutionModel() + db_model.id = "id1" + db_model.node_execution_id = "node1" + db_model.foo = "bar" # type: ignore[attr-defined] + db_model.__dict__["_private"] = "x" + + existing = SimpleNamespace() + session.get.return_value = existing + repo._persist_to_database(db_model) + assert existing.foo == "bar" + session.add.assert_not_called() + assert repo._node_execution_cache["node1"] is db_model + + session.reset_mock() + session.get.return_value = None + repo._node_execution_cache.clear() + repo._persist_to_database(db_model) + session.add.assert_called_once_with(db_model) + assert repo._node_execution_cache["node1"] is db_model + + +def test_truncate_and_upload_returns_none_when_no_values_or_not_truncated(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + assert repo._truncate_and_upload(None, "e", ExecutionOffLoadType.INPUTS) is None + + class FakeTruncator: + def truncate_variable_mapping(self, value: 
Any): # type: ignore[no-untyped-def] + return value, False + + monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator()) + assert repo._truncate_and_upload({"a": 1}, "e", ExecutionOffLoadType.INPUTS) is None + + +def test_truncate_and_upload_uploads_and_builds_offload(monkeypatch: pytest.MonkeyPatch) -> None: + uploaded: dict[str, Any] = {} + + class FakeFileService: + def upload_file(self, *, filename: str, content: bytes, mimetype: str, user: Any): # type: ignore[no-untyped-def] + uploaded.update({"filename": filename, "content": content, "mimetype": mimetype, "user": user}) + return SimpleNamespace(id="file-id", key="file-key") + + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", lambda *_: FakeFileService() + ) + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "offload-id") + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + class FakeTruncator: + def truncate_variable_mapping(self, value: Any): # type: ignore[no-untyped-def] + return {"truncated": True}, True + + monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator()) + + result = repo._truncate_and_upload({"a": 1}, "exec", ExecutionOffLoadType.INPUTS) + assert result is not None + assert result.truncated_value == {"truncated": True} + assert uploaded["filename"].startswith("node_execution_exec_inputs.json") + assert result.offload.file_id == "file-id" + assert result.offload.type_ == ExecutionOffLoadType.INPUTS + + +def test_to_domain_model_loads_offloaded_files(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + 
session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + db_model = WorkflowNodeExecutionModel() + db_model.id = "id" + db_model.node_execution_id = "node-exec" + db_model.workflow_id = "wf" + db_model.workflow_run_id = "run" + db_model.index = 1 + db_model.predecessor_node_id = None + db_model.node_id = "node" + db_model.node_type = NodeType.LLM + db_model.title = "t" + db_model.inputs = json.dumps({"trunc": "i"}) + db_model.process_data = json.dumps({"trunc": "p"}) + db_model.outputs = json.dumps({"trunc": "o"}) + db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED + db_model.error = None + db_model.elapsed_time = 0.1 + db_model.execution_metadata = json.dumps({"total_tokens": 3}) + db_model.created_at = datetime.now(UTC) + db_model.finished_at = None + + off_in = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS) + off_out = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS) + off_proc = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA) + off_in.file = SimpleNamespace(key="k-in") + off_out.file = SimpleNamespace(key="k-out") + off_proc.file = SimpleNamespace(key="k-proc") + db_model.offload_data = [off_out, off_in, off_proc] + + def fake_load(key: str) -> bytes: + return json.dumps({"full": key}).encode() + + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.storage.load", fake_load) + + domain = repo._to_domain_model(db_model) + assert domain.inputs == {"full": "k-in"} + assert domain.outputs == {"full": "k-out"} + assert domain.process_data == {"full": "k-proc"} + assert domain.get_truncated_inputs() == {"trunc": "i"} + assert domain.get_truncated_outputs() == {"trunc": "o"} + assert domain.get_truncated_process_data() == {"trunc": "p"} + + +def test_to_domain_model_returns_early_when_no_offload_data(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + 
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + db_model = WorkflowNodeExecutionModel() + db_model.id = "id" + db_model.node_execution_id = "node-exec" + db_model.workflow_id = "wf" + db_model.workflow_run_id = "run" + db_model.index = 1 + db_model.predecessor_node_id = None + db_model.node_id = "node" + db_model.node_type = NodeType.LLM + db_model.title = "t" + db_model.inputs = json.dumps({"i": 1}) + db_model.process_data = json.dumps({"p": 2}) + db_model.outputs = json.dumps({"o": 3}) + db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED + db_model.error = None + db_model.elapsed_time = 0.1 + db_model.execution_metadata = "{}" + db_model.created_at = datetime.now(UTC) + db_model.finished_at = None + db_model.offload_data = [] + + domain = repo._to_domain_model(db_model) + assert domain.inputs == {"i": 1} + assert domain.outputs == {"o": 3} + + +def test_json_encode_uses_runtime_converter(monkeypatch: pytest.MonkeyPatch) -> None: + class FakeConverter: + def to_json_encodable(self, values: Mapping[str, Any]) -> Mapping[str, Any]: + return {"wrapped": values["a"]} + + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowRuntimeTypeConverter", + FakeConverter, + ) + assert SQLAlchemyWorkflowNodeExecutionRepository._json_encode({"a": 1}) == '{"wrapped": 1}' + + +def test_save_execution_data_handles_existing_db_model_and_truncation(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + session = MagicMock() + session.execute.return_value.scalars.return_value.first.return_value = 
SimpleNamespace( + id="id", + offload_data=[WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)], + inputs=None, + outputs=None, + process_data=None, + ) + session.merge = Mock() + session.flush = Mock() + session.begin.return_value.__enter__ = Mock(return_value=session) + session.begin.return_value.__exit__ = Mock(return_value=None) + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=_session_factory(session), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3}) + + trunc_result = SimpleNamespace( + truncated_value={"trunc": True}, + offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS, file_id="f1"), + ) + monkeypatch.setattr( + repo, "_truncate_and_upload", lambda values, *_args, **_kwargs: trunc_result if values == {"a": 1} else None + ) + monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True)) + + repo.save_execution_data(execution) + # Inputs should be truncated, outputs/process_data encoded directly + db_model = session.merge.call_args.args[0] + assert json.loads(db_model.inputs) == {"trunc": True} + assert json.loads(db_model.outputs) == {"b": 2} + assert json.loads(db_model.process_data) == {"c": 3} + assert any(off.type_ == ExecutionOffLoadType.INPUTS for off in db_model.offload_data) + assert execution.get_truncated_inputs() == {"trunc": True} + + +def test_save_execution_data_truncates_outputs_and_process_data(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + existing = SimpleNamespace( + id="id", + offload_data=[], + inputs=None, + outputs=None, + process_data=None, + ) + session = MagicMock() + session.execute.return_value.scalars.return_value.first.return_value = existing + 
    # NOTE(review): this chunk opens mid-test -- the enclosing `def` sits above
    # this view.  The statements below finish wiring a MagicMock session so the
    # repository's `with session.begin():` block, merge() and flush() calls all
    # work without touching a real database.
    session.merge = Mock()
    session.flush = Mock()
    session.begin.return_value.__enter__ = Mock(return_value=session)
    session.begin.return_value.__exit__ = Mock(return_value=None)

    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3})

    def trunc(values: Mapping[str, Any], *_args: Any, **_kwargs: Any) -> Any:
        # Stand-in for _truncate_and_upload: outputs and process_data come back
        # truncated together with an offload record; anything else (the inputs
        # payload) returns None, i.e. "not truncated".
        if values == {"b": 2}:
            return SimpleNamespace(
                truncated_value={"b": "trunc"},
                offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS, file_id="f2"),
            )
        if values == {"c": 3}:
            return SimpleNamespace(
                truncated_value={"c": "trunc"},
                offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA, file_id="f3"),
            )
        return None

    monkeypatch.setattr(repo, "_truncate_and_upload", trunc)
    monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True))

    repo.save_execution_data(execution)
    # The model handed to session.merge() must carry the truncated JSON ...
    db_model = session.merge.call_args.args[0]
    assert json.loads(db_model.outputs) == {"b": "trunc"}
    assert json.loads(db_model.process_data) == {"c": "trunc"}
    # ... and the domain object must expose the same truncated views.
    assert execution.get_truncated_outputs() == {"b": "trunc"}
    assert execution.get_truncated_process_data() == {"c": "trunc"}


def test_save_execution_data_handles_missing_db_model(monkeypatch: pytest.MonkeyPatch) -> None:
    """When no row exists yet, save_execution_data builds one via _to_db_model
    and stores the plain (un-truncated) JSON payload."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    session = MagicMock()
    # The lookup query finds no existing DB row.
    session.execute.return_value.scalars.return_value.first.return_value = None
    session.merge = Mock()
    session.flush = Mock()
    session.begin.return_value.__enter__ = Mock(return_value=session)
    session.begin.return_value.__exit__ = Mock(return_value=None)

    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    execution = _execution(inputs={"a": 1})
    fake_db_model = SimpleNamespace(id=execution.id, offload_data=[], inputs=None, outputs=None, process_data=None)
    monkeypatch.setattr(repo, "_to_db_model", lambda *_: fake_db_model)
    # No truncation in this scenario: _truncate_and_upload reports "unchanged".
    monkeypatch.setattr(repo, "_truncate_and_upload", lambda *_args, **_kwargs: None)
    monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values))

    repo.save_execution_data(execution)
    merged = session.merge.call_args.args[0]
    assert merged.inputs == '{"a": 1}'


def test_save_retries_duplicate_and_logs_non_duplicate(
    monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
    """save() retries with a fresh uuidv7-generated id when persistence raises a
    duplicate-key IntegrityError, but logs and re-raises any other
    IntegrityError."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    execution = _execution(execution_id="id")
    # A psycopg2 UniqueViolation as `orig` marks the duplicate-key case;
    # orig=None stands for some other integrity failure.
    unique = Mock(spec=psycopg2.errors.UniqueViolation)
    duplicate_error = IntegrityError("dup", params=None, orig=unique)
    other_error = IntegrityError("other", params=None, orig=None)

    calls = {"n": 0}

    def persist(_db_model: Any) -> None:
        # Fail only on the first attempt so the retry can succeed.
        calls["n"] += 1
        if calls["n"] == 1:
            raise duplicate_error

    monkeypatch.setattr(repo, "_persist_to_database", persist)
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id")
    repo.save(execution)
    assert execution.id == "new-id"  # id was regenerated for the retry
    assert repo._node_execution_cache[execution.node_execution_id] is not None

    caplog.set_level(logging.ERROR)
    # `(_ for _ in ()).throw(...)` raises from inside a lambda.
    monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(other_error))
    with pytest.raises(IntegrityError):
        repo.save(_execution(execution_id="id2", node_execution_id="node2"))
    assert any("Non-duplicate key integrity error" in r.message for r in caplog.records)


def test_save_logs_and_reraises_on_unexpected_error(
    monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
    """Unexpected (non-IntegrityError) persistence failures are logged at ERROR
    level and re-raised to the caller."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    caplog.set_level(logging.ERROR)
    monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(RuntimeError("boom")))
    with pytest.raises(RuntimeError, match="boom"):
        repo.save(_execution(execution_id="id3", node_execution_id="node3"))
    assert any("Failed to save workflow node execution" in r.message for r in caplog.records)


def test_get_db_models_by_workflow_run_orders_and_caches(monkeypatch: pytest.MonkeyPatch) -> None:
    """get_db_models_by_workflow_run applies the order config and caches models
    that carry a node_execution_id."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )

    class FakeStmt:
        # Minimal chainable stand-in for a SQLAlchemy Select statement that
        # records how it was filtered and ordered.
        def __init__(self) -> None:
            self.where_calls = 0
            self.order_by_args: tuple[Any, ...] | None = None

        def where(self, *_args: Any) -> FakeStmt:
            self.where_calls += 1
            return self

        def order_by(self, *args: Any) -> FakeStmt:
            self.order_by_args = args
            return self

    stmt = FakeStmt()
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files",
        lambda _q: stmt,
    )
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select")

    model1 = SimpleNamespace(node_execution_id="n1")
    model2 = SimpleNamespace(node_execution_id=None)
    session = MagicMock()
    session.scalars.return_value.all.return_value = [model1, model2]

    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    # "missing" presumably names a non-existent column the ordering should
    # tolerate -- only the result ordering and cache are asserted here.
    order = OrderConfig(order_by=["index", "missing"], order_direction="desc")
    db_models = repo.get_db_models_by_workflow_run("run", order)
    assert db_models == [model1, model2]
    assert repo._node_execution_cache["n1"] is model1
    assert stmt.order_by_args is not None


def test_get_db_models_by_workflow_run_uses_asc_order(monkeypatch: pytest.MonkeyPatch) -> None:
    """Smoke test: an ascending order direction is accepted without error."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )

    class FakeStmt:
        # Chainable statement stand-in; no assertions on the recorded args.
        def where(self, *_args: Any) -> FakeStmt:
            return self

        def order_by(self, *args: Any) -> FakeStmt:
            self.args = args  # type: ignore[attr-defined]
            return self

    stmt = FakeStmt()
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files",
        lambda _q: stmt,
    )
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select")

    session = MagicMock()
    session.scalars.return_value.all.return_value = []
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    # Only verifies the call completes; ordering details are asserted in the
    # descending-order test.
    repo.get_db_models_by_workflow_run("run", OrderConfig(order_by=["index"], order_direction="asc"))


def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) -> None:
    """get_by_workflow_run maps each DB model to a domain model, fanning the
    mapping out over a ThreadPoolExecutor with a 30-second map timeout."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )

    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    db_models = [SimpleNamespace(id="db1"), SimpleNamespace(id="db2")]
    monkeypatch.setattr(repo, "get_db_models_by_workflow_run", lambda *_args, **_kwargs: db_models)
    monkeypatch.setattr(repo, "_to_domain_model", lambda m: f"domain:{m.id}")

    class FakeExecutor:
        # Synchronous stand-in for ThreadPoolExecutor: runs map() inline and
        # checks the timeout the repository passes.
        def __enter__(self) -> FakeExecutor:
            return self

        def __exit__(self, exc_type, exc, tb) -> None:
            return None

        def map(self, func, items, timeout: int):  # type: ignore[no-untyped-def]
            assert timeout == 30
            return list(map(func, items))

    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.ThreadPoolExecutor",
        lambda max_workers: FakeExecutor(),
    )

    result = repo.get_by_workflow_run("run", order_config=None)
    assert result == ["domain:db1", "domain:db2"]
diff --git a/api/tests/unit_tests/core/schemas/test_registry.py b/api/tests/unit_tests/core/schemas/test_registry.py
new file mode 100644
index 0000000000..5749e72eb0
--- /dev/null
+++ b/api/tests/unit_tests/core/schemas/test_registry.py
@@ -0,0 +1,137 @@
import json
from unittest.mock import patch

from core.schemas.registry import SchemaRegistry


class TestSchemaRegistry:
    """Unit tests for SchemaRegistry loading, lookup and listing behavior."""

    def test_initialization(self, tmp_path):
        # A freshly constructed registry starts empty.
        base_dir = tmp_path / "schemas"
        base_dir.mkdir()
        registry = SchemaRegistry(str(base_dir))
        assert registry.base_dir == base_dir
        assert registry.versions == {}
        assert registry.metadata == {}

    def test_default_registry_singleton(self):
        # default_registry() must always hand back the same instance.
        registry1 = SchemaRegistry.default_registry()
        registry2 = SchemaRegistry.default_registry()
        assert registry1 is registry2
        assert isinstance(registry1, SchemaRegistry)

    def test_load_all_versions_non_existent_dir(self, tmp_path):
        # Loading from a missing base dir is a no-op, not an error.
        base_dir = tmp_path / "non_existent"
        registry = SchemaRegistry(str(base_dir))
        registry.load_all_versions()
        assert registry.versions == {}

    def test_load_all_versions_filtering(self, tmp_path):
        # Of the three entries only "v1" is loaded -- non-version directories
        # and plain files are skipped.
        base_dir = tmp_path / "schemas"
        base_dir.mkdir()
        (base_dir / "not_a_version_dir").mkdir()
        (base_dir / "v1").mkdir()
        (base_dir / "some_file.txt").write_text("content")

        registry = SchemaRegistry(str(base_dir))
        with patch.object(registry, "_load_version_dir") as mock_load:
            registry.load_all_versions()
            mock_load.assert_called_once()
            assert mock_load.call_args[0][0] == "v1"

    def test_load_version_dir_filtering(self, tmp_path):
        # Only *.json files inside a version dir are loaded as schemas.
        version_dir = tmp_path / "v1"
        version_dir.mkdir()
        (version_dir / "schema1.json").write_text("{}")
        (version_dir / "not_a_schema.txt").write_text("content")

        registry = SchemaRegistry(str(tmp_path))
        with patch.object(registry, "_load_schema") as mock_load:
            registry._load_version_dir("v1", version_dir)
            mock_load.assert_called_once()
            assert mock_load.call_args[0][1] == "schema1"

    def test_load_version_dir_non_existent(self, tmp_path):
        # A missing version dir leaves the registry untouched.
        version_dir = tmp_path / "non_existent"
        registry = SchemaRegistry(str(tmp_path))
        registry._load_version_dir("v1", version_dir)
        assert "v1" not in registry.versions

    def test_load_schema_success(self, tmp_path):
        # A valid schema file is stored under its version and indexed by URI.
        schema_path = tmp_path / "test.json"
        schema_content = {"title": "Test Schema", "description": "A test schema"}
        schema_path.write_text(json.dumps(schema_content))

        registry = SchemaRegistry(str(tmp_path))
        registry.versions["v1"] = {}
        registry._load_schema("v1", "test", schema_path)

        assert registry.versions["v1"]["test"] == schema_content
        # Metadata is keyed by the canonical schema URI.
        uri = "https://dify.ai/schemas/v1/test.json"
        assert registry.metadata[uri]["title"] == "Test Schema"
        assert registry.metadata[uri]["version"] == "v1"

    def test_load_schema_invalid_json(self, tmp_path, caplog):
        # Malformed JSON is logged, not raised.
        schema_path = tmp_path / "invalid.json"
        schema_path.write_text("invalid json")

        registry = SchemaRegistry(str(tmp_path))
        registry.versions["v1"] = {}
        registry._load_schema("v1", "invalid", schema_path)

        assert "Failed to load schema v1/invalid" in caplog.text

    def test_load_schema_os_error(self, tmp_path, caplog):
        # A filesystem read failure is logged, not raised.
        schema_path = tmp_path / "error.json"
        schema_path.write_text("{}")

        registry = SchemaRegistry(str(tmp_path))
        registry.versions["v1"] = {}

        with patch("builtins.open", side_effect=OSError("Read error")):
            registry._load_schema("v1", "error", schema_path)

        assert "Failed to load schema v1/error" in caplog.text

    def test_get_schema(self):
        registry = SchemaRegistry("/tmp")
        registry.versions = {"v1": {"test": {"type": "object"}}}

        # Valid URI
        assert registry.get_schema("https://dify.ai/schemas/v1/test.json") == {"type": "object"}

        # Invalid URI
        assert registry.get_schema("invalid-uri") is None

        # Missing version
        assert registry.get_schema("https://dify.ai/schemas/v2/test.json") is None

    def test_list_versions(self):
        # Versions are listed in sorted order.
        registry = SchemaRegistry("/tmp")
        registry.versions = {"v2": {}, "v1": {}}
        assert registry.list_versions() == ["v1", "v2"]

    def test_list_schemas(self):
        # Schema names are listed sorted; an unknown version yields [].
        registry = SchemaRegistry("/tmp")
        registry.versions = {"v1": {"b": {}, "a": {}}}

        assert registry.list_schemas("v1") == ["a", "b"]
        assert registry.list_schemas("v2") == []

    def test_get_all_schemas_for_version(self):
        # Each entry exposes name, label (schema title) and the raw schema.
        registry = SchemaRegistry("/tmp")
        registry.versions = {"v1": {"test": {"title": "Test Label"}}}

        results = registry.get_all_schemas_for_version("v1")
        assert len(results) == 1
        assert results[0]["name"] == "test"
        assert results[0]["label"] == "Test Label"
        assert results[0]["schema"] == {"title": "Test Label"}

        # Default label if title missing
        registry.versions["v1"]["no_title"] = {}
        results = registry.get_all_schemas_for_version("v1")
        item = next(r for r in results if r["name"] == "no_title")
        assert item["label"] == "no_title"

        # Empty if version missing
        assert registry.get_all_schemas_for_version("v2") == []
diff --git a/api/tests/unit_tests/core/schemas/test_schema_manager.py b/api/tests/unit_tests/core/schemas/test_schema_manager.py
new file mode 100644
index 0000000000..cb07340c6d
--- /dev/null
+++ b/api/tests/unit_tests/core/schemas/test_schema_manager.py
@@ -0,0 +1,80 @@
from unittest.mock import MagicMock, patch

from core.schemas.registry import SchemaRegistry
from core.schemas.schema_manager import SchemaManager


def test_init_with_provided_registry():
    """An explicitly supplied registry is used as-is."""
    mock_registry = MagicMock(spec=SchemaRegistry)
    manager = SchemaManager(registry=mock_registry)
    assert manager.registry == mock_registry


@patch("core.schemas.schema_manager.SchemaRegistry.default_registry")
def test_init_with_default_registry(mock_default_registry):
    """Without an explicit registry, the manager falls back to the default."""
    mock_registry = MagicMock(spec=SchemaRegistry)
    mock_default_registry.return_value = mock_registry

    manager = SchemaManager()

    mock_default_registry.assert_called_once()
    assert manager.registry == mock_registry


def test_get_all_schema_definitions():
    """get_all_schema_definitions delegates to the registry for the version."""
    mock_registry = MagicMock(spec=SchemaRegistry)
    expected_definitions = [{"name": "schema1", "schema": {}}, {"name": "schema2", "schema": {}}]
    mock_registry.get_all_schemas_for_version.return_value = expected_definitions

    manager = SchemaManager(registry=mock_registry)
    result = manager.get_all_schema_definitions(version="v2")
    mock_registry.get_all_schemas_for_version.assert_called_once_with("v2")
    assert result == expected_definitions


def test_get_schema_by_name_success():
    """get_schema_by_name builds the canonical URI and wraps the result."""
    mock_registry = MagicMock(spec=SchemaRegistry)
    mock_schema = {"type": "object"}
    mock_registry.get_schema.return_value = mock_schema

    manager = SchemaManager(registry=mock_registry)
    result = manager.get_schema_by_name("my_schema", version="v1")

    expected_uri = "https://dify.ai/schemas/v1/my_schema.json"
    mock_registry.get_schema.assert_called_once_with(expected_uri)
    assert result == {"name": "my_schema", "schema": mock_schema}


def test_get_schema_by_name_not_found():
    """A registry miss yields None instead of raising."""
    mock_registry = MagicMock(spec=SchemaRegistry)
    mock_registry.get_schema.return_value = None

    manager = SchemaManager(registry=mock_registry)
    result = manager.get_schema_by_name("non_existent", version="v1")

    assert result is None


def test_list_available_schemas():
    """list_available_schemas delegates to registry.list_schemas(version)."""
    mock_registry = MagicMock(spec=SchemaRegistry)
    expected_schemas = ["schema1", "schema2"]
    mock_registry.list_schemas.return_value = expected_schemas

    manager = SchemaManager(registry=mock_registry)
    result = manager.list_available_schemas(version="v1")

    mock_registry.list_schemas.assert_called_once_with("v1")
    assert result == expected_schemas


def test_list_available_versions():
    """list_available_versions delegates to registry.list_versions()."""
    mock_registry = MagicMock(spec=SchemaRegistry)
    expected_versions = ["v1", "v2"]
    mock_registry.list_versions.return_value = expected_versions

    manager = SchemaManager(registry=mock_registry)
    result = manager.list_available_versions()

    mock_registry.list_versions.assert_called_once()
    assert result == expected_versions
From b53675a16c910082536c8ae020c5ab7ea3d43caa Mon Sep 17 00:00:00 2001
From: Poojan