test: add tests for core module moderation, repositories, schemas (#32514)

Co-authored-by: Rajat Agarwal <rajat.agarwal@infocusp.com>
This commit is contained in:
mahammadasim 2026-03-22 21:27:12 +05:30 committed by GitHub
parent 40846c262c
commit 31506b27ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 2561 additions and 66 deletions

View File

@ -0,0 +1,181 @@
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint
from core.moderation.api.api import ApiModeration, ModerationInputParams, ModerationOutputParams
from core.moderation.base import ModerationAction, ModerationInputsResult, ModerationOutputsResult
from models.api_based_extension import APIBasedExtension
class TestApiModeration:
    """Unit tests for ApiModeration: param schemas, config validation, and API-backed checks."""

    @pytest.fixture
    def api_config(self):
        # Minimal valid config: both input and output moderation enabled.
        return {
            "inputs_config": {
                "enabled": True,
            },
            "outputs_config": {
                "enabled": True,
            },
            "api_based_extension_id": "test-extension-id",
        }

    @pytest.fixture
    def api_moderation(self, api_config):
        return ApiModeration(app_id="test-app-id", tenant_id="test-tenant-id", config=api_config)

    def test_moderation_input_params(self):
        """ModerationInputParams keeps explicit values and supplies empty defaults."""
        params = ModerationInputParams(app_id="app-1", inputs={"key": "val"}, query="test query")
        assert params.app_id == "app-1"
        assert params.inputs == {"key": "val"}
        assert params.query == "test query"
        # Test defaults
        params_default = ModerationInputParams()
        assert params_default.app_id == ""
        assert params_default.inputs == {}
        assert params_default.query == ""

    def test_moderation_output_params(self):
        """`text` is required on ModerationOutputParams; omitting it fails validation."""
        params = ModerationOutputParams(app_id="app-1", text="test text")
        assert params.app_id == "app-1"
        assert params.text == "test text"
        with pytest.raises(ValidationError):
            ModerationOutputParams()

    @patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
    def test_validate_config_success(self, mock_get_extension, api_config):
        mock_get_extension.return_value = MagicMock(spec=APIBasedExtension)
        ApiModeration.validate_config("test-tenant-id", api_config)
        mock_get_extension.assert_called_once_with("test-tenant-id", "test-extension-id")

    def test_validate_config_missing_extension_id(self):
        config = {
            "inputs_config": {"enabled": True},
            "outputs_config": {"enabled": True},
        }
        with pytest.raises(ValueError, match="api_based_extension_id is required"):
            ApiModeration.validate_config("test-tenant-id", config)

    @patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
    def test_validate_config_extension_not_found(self, mock_get_extension, api_config):
        mock_get_extension.return_value = None
        with pytest.raises(ValueError, match="API-based Extension not found"):
            ApiModeration.validate_config("test-tenant-id", api_config)

    @patch("core.moderation.api.api.ApiModeration._get_config_by_requestor")
    def test_moderation_for_inputs_enabled(self, mock_get_config, api_moderation):
        """Enabled inputs moderation forwards app/inputs/query and maps the API response."""
        mock_get_config.return_value = {"flagged": True, "action": "direct_output", "preset_response": "Blocked by API"}
        result = api_moderation.moderation_for_inputs(inputs={"q": "a"}, query="hello")
        assert isinstance(result, ModerationInputsResult)
        assert result.flagged is True
        assert result.action == ModerationAction.DIRECT_OUTPUT
        assert result.preset_response == "Blocked by API"
        mock_get_config.assert_called_once_with(
            APIBasedExtensionPoint.APP_MODERATION_INPUT,
            {"app_id": "test-app-id", "inputs": {"q": "a"}, "query": "hello"},
        )

    def test_moderation_for_inputs_disabled(self):
        """Disabled inputs moderation short-circuits to an unflagged result."""
        config = {
            "inputs_config": {"enabled": False},
            "outputs_config": {"enabled": True},
            "api_based_extension_id": "ext-id",
        }
        moderation = ApiModeration("app-id", "tenant-id", config)
        result = moderation.moderation_for_inputs(inputs={}, query="")
        assert result.flagged is False
        assert result.action == ModerationAction.DIRECT_OUTPUT
        assert result.preset_response == ""

    def test_moderation_for_inputs_no_config(self):
        moderation = ApiModeration("app-id", "tenant-id", None)
        with pytest.raises(ValueError, match="The config is not set"):
            moderation.moderation_for_inputs({}, "")

    @patch("core.moderation.api.api.ApiModeration._get_config_by_requestor")
    def test_moderation_for_outputs_enabled(self, mock_get_config, api_moderation):
        mock_get_config.return_value = {"flagged": False, "action": "direct_output", "preset_response": ""}
        result = api_moderation.moderation_for_outputs(text="hello world")
        assert isinstance(result, ModerationOutputsResult)
        assert result.flagged is False
        mock_get_config.assert_called_once_with(
            APIBasedExtensionPoint.APP_MODERATION_OUTPUT, {"app_id": "test-app-id", "text": "hello world"}
        )

    def test_moderation_for_outputs_disabled(self):
        config = {
            "inputs_config": {"enabled": True},
            "outputs_config": {"enabled": False},
            "api_based_extension_id": "ext-id",
        }
        moderation = ApiModeration("app-id", "tenant-id", config)
        result = moderation.moderation_for_outputs(text="test")
        assert result.flagged is False
        assert result.action == ModerationAction.DIRECT_OUTPUT

    def test_moderation_for_outputs_no_config(self):
        moderation = ApiModeration("app-id", "tenant-id", None)
        with pytest.raises(ValueError, match="The config is not set"):
            moderation.moderation_for_outputs("test")

    @patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
    @patch("core.moderation.api.api.decrypt_token")
    @patch("core.moderation.api.api.APIBasedExtensionRequestor")
    def test_get_config_by_requestor_success(self, mock_requestor_cls, mock_decrypt, mock_get_ext, api_moderation):
        """Happy path: extension is looked up, its key decrypted, and the requestor invoked."""
        # Decorators apply bottom-up, so the mock arguments arrive in reverse order.
        mock_ext = MagicMock(spec=APIBasedExtension)
        mock_ext.api_endpoint = "http://api.test"
        mock_ext.api_key = "encrypted-key"
        mock_get_ext.return_value = mock_ext
        mock_decrypt.return_value = "decrypted-key"
        mock_requestor = MagicMock()
        mock_requestor.request.return_value = {"flagged": True}
        mock_requestor_cls.return_value = mock_requestor
        params = {"some": "params"}
        result = api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params)
        assert result == {"flagged": True}
        mock_get_ext.assert_called_once_with("test-tenant-id", "test-extension-id")
        mock_decrypt.assert_called_once_with("test-tenant-id", "encrypted-key")
        mock_requestor_cls.assert_called_once_with("http://api.test", "decrypted-key")
        mock_requestor.request.assert_called_once_with(APIBasedExtensionPoint.APP_MODERATION_INPUT, params)

    def test_get_config_by_requestor_no_config(self):
        moderation = ApiModeration("app-id", "tenant-id", None)
        with pytest.raises(ValueError, match="The config is not set"):
            moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {})

    @patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
    def test_get_config_by_requestor_extension_not_found(self, mock_get_ext, api_moderation):
        mock_get_ext.return_value = None
        with pytest.raises(ValueError, match="API-based Extension not found"):
            api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {})

    @patch("core.moderation.api.api.db.session.scalar")
    def test_get_api_based_extension(self, mock_scalar):
        mock_ext = MagicMock(spec=APIBasedExtension)
        mock_scalar.return_value = mock_ext
        result = ApiModeration._get_api_based_extension("tenant-1", "ext-1")
        assert result == mock_ext
        mock_scalar.assert_called_once()
        # Verify the call has the correct filters
        args, kwargs = mock_scalar.call_args
        stmt = args[0]
        # We can't easily inspect the statement without complex sqlalchemy tricks,
        # but calling it is usually enough for unit tests if we mock the result.

View File

@ -0,0 +1,207 @@
from unittest.mock import MagicMock, patch
import pytest
from core.app.app_config.entities import AppConfig, SensitiveWordAvoidanceEntity
from core.moderation.base import ModerationAction, ModerationError, ModerationInputsResult
from core.moderation.input_moderation import InputModeration
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager
class TestInputModeration:
    """Unit tests for InputModeration.check across rule configs and moderation outcomes."""

    @pytest.fixture
    def app_config(self):
        # Default app config: no sensitive-word rule configured.
        config = MagicMock(spec=AppConfig)
        config.sensitive_word_avoidance = None
        return config

    @pytest.fixture
    def input_moderation(self):
        return InputModeration()

    def test_check_no_sensitive_word_avoidance(self, app_config, input_moderation):
        """With no rule configured, inputs pass through unflagged and unchanged."""
        app_id = "test_app_id"
        tenant_id = "test_tenant_id"
        inputs = {"input_key": "input_value"}
        query = "test query"
        message_id = "test_message_id"
        flagged, final_inputs, final_query = input_moderation.check(
            app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
        )
        assert flagged is False
        assert final_inputs == inputs
        assert final_query == query

    @patch("core.moderation.input_moderation.ModerationFactory")
    def test_check_not_flagged(self, mock_factory_cls, app_config, input_moderation):
        """An unflagged moderation result leaves inputs and query untouched."""
        app_id = "test_app_id"
        tenant_id = "test_tenant_id"
        inputs = {"input_key": "input_value"}
        query = "test query"
        message_id = "test_message_id"
        # Setup config
        sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
        sensitive_word_config.type = "keywords"
        sensitive_word_config.config = {"keywords": ["bad"]}
        app_config.sensitive_word_avoidance = sensitive_word_config
        # Setup factory mock
        mock_factory = mock_factory_cls.return_value
        mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
        mock_factory.moderation_for_inputs.return_value = mock_result
        flagged, final_inputs, final_query = input_moderation.check(
            app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
        )
        assert flagged is False
        assert final_inputs == inputs
        assert final_query == query
        mock_factory_cls.assert_called_once_with(
            name="keywords", app_id=app_id, tenant_id=tenant_id, config={"keywords": ["bad"]}
        )
        mock_factory.moderation_for_inputs.assert_called_once_with(dict(inputs), query)

    @patch("core.moderation.input_moderation.ModerationFactory")
    @patch("core.moderation.input_moderation.TraceTask")
    def test_check_with_trace_manager(self, mock_trace_task, mock_factory_cls, app_config, input_moderation):
        """When a trace manager is supplied, a MODERATION_TRACE task is enqueued."""
        app_id = "test_app_id"
        tenant_id = "test_tenant_id"
        inputs = {"input_key": "input_value"}
        query = "test query"
        message_id = "test_message_id"
        trace_manager = MagicMock(spec=TraceQueueManager)
        # Setup config
        sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
        sensitive_word_config.type = "keywords"
        sensitive_word_config.config = {}
        app_config.sensitive_word_avoidance = sensitive_word_config
        # Setup factory mock
        mock_factory = mock_factory_cls.return_value
        mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
        mock_factory.moderation_for_inputs.return_value = mock_result
        input_moderation.check(
            app_id=app_id,
            tenant_id=tenant_id,
            app_config=app_config,
            inputs=inputs,
            query=query,
            message_id=message_id,
            trace_manager=trace_manager,
        )
        trace_manager.add_trace_task.assert_called_once_with(mock_trace_task.return_value)
        mock_trace_task.assert_called_once()
        call_kwargs = mock_trace_task.call_args.kwargs
        call_args = mock_trace_task.call_args.args
        assert call_args[0] == TraceTaskName.MODERATION_TRACE
        assert call_kwargs["message_id"] == message_id
        assert call_kwargs["moderation_result"] == mock_result
        assert call_kwargs["inputs"] == inputs
        assert "timer" in call_kwargs

    @patch("core.moderation.input_moderation.ModerationFactory")
    def test_check_flagged_direct_output(self, mock_factory_cls, app_config, input_moderation):
        """A flagged DIRECT_OUTPUT result raises ModerationError carrying the preset response."""
        app_id = "test_app_id"
        tenant_id = "test_tenant_id"
        inputs = {"input_key": "input_value"}
        query = "test query"
        message_id = "test_message_id"
        # Setup config
        sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
        sensitive_word_config.type = "keywords"
        sensitive_word_config.config = {}
        app_config.sensitive_word_avoidance = sensitive_word_config
        # Setup factory mock
        mock_factory = mock_factory_cls.return_value
        mock_result = ModerationInputsResult(
            flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="Blocked content"
        )
        mock_factory.moderation_for_inputs.return_value = mock_result
        with pytest.raises(ModerationError) as excinfo:
            input_moderation.check(
                app_id=app_id,
                tenant_id=tenant_id,
                app_config=app_config,
                inputs=inputs,
                query=query,
                message_id=message_id,
            )
        assert str(excinfo.value) == "Blocked content"

    @patch("core.moderation.input_moderation.ModerationFactory")
    def test_check_flagged_overridden(self, mock_factory_cls, app_config, input_moderation):
        """A flagged OVERRIDDEN result substitutes the rewritten inputs and query."""
        app_id = "test_app_id"
        tenant_id = "test_tenant_id"
        inputs = {"input_key": "input_value"}
        query = "test query"
        message_id = "test_message_id"
        # Setup config
        sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
        sensitive_word_config.type = "keywords"
        sensitive_word_config.config = {}
        app_config.sensitive_word_avoidance = sensitive_word_config
        # Setup factory mock
        mock_factory = mock_factory_cls.return_value
        mock_result = ModerationInputsResult(
            flagged=True,
            action=ModerationAction.OVERRIDDEN,
            inputs={"input_key": "overridden_value"},
            query="overridden query",
        )
        mock_factory.moderation_for_inputs.return_value = mock_result
        flagged, final_inputs, final_query = input_moderation.check(
            app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
        )
        assert flagged is True
        assert final_inputs == {"input_key": "overridden_value"}
        assert final_query == "overridden query"

    @patch("core.moderation.input_moderation.ModerationFactory")
    def test_check_flagged_other_action(self, mock_factory_cls, app_config, input_moderation):
        """An unrecognized action flags the input but leaves content unchanged."""
        app_id = "test_app_id"
        tenant_id = "test_tenant_id"
        inputs = {"input_key": "input_value"}
        query = "test query"
        message_id = "test_message_id"
        # Setup config
        sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
        sensitive_word_config.type = "keywords"
        sensitive_word_config.config = {}
        app_config.sensitive_word_avoidance = sensitive_word_config
        # Setup factory mock
        mock_factory = mock_factory_cls.return_value
        mock_result = MagicMock()
        mock_result.flagged = True
        mock_result.action = "NONE"  # Some other action
        mock_factory.moderation_for_inputs.return_value = mock_result
        flagged, final_inputs, final_query = input_moderation.check(
            app_id=app_id,
            tenant_id=tenant_id,
            app_config=app_config,
            inputs=inputs,
            query=query,
            message_id=message_id,
        )
        assert flagged is True
        assert final_inputs == inputs
        assert final_query == query

View File

@ -0,0 +1,234 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import QueueMessageReplaceEvent
from core.moderation.base import ModerationAction, ModerationOutputsResult
from core.moderation.output_moderation import ModerationRule, OutputModeration
class TestOutputModeration:
    """Unit tests for OutputModeration: buffering, completion checks, and the worker loop."""

    @pytest.fixture
    def mock_queue_manager(self):
        return MagicMock(spec=AppQueueManager)

    @pytest.fixture
    def moderation_rule(self):
        return ModerationRule(type="keywords", config={"keywords": "badword"})

    @pytest.fixture
    def output_moderation(self, mock_queue_manager, moderation_rule):
        return OutputModeration(
            tenant_id="test_tenant", app_id="test_app", rule=moderation_rule, queue_manager=mock_queue_manager
        )

    def test_should_direct_output(self, output_moderation):
        """should_direct_output is True only once a final_output has been set."""
        assert output_moderation.should_direct_output() is False
        output_moderation.final_output = "blocked"
        assert output_moderation.should_direct_output() is True

    def test_get_final_output(self, output_moderation):
        assert output_moderation.get_final_output() == ""
        output_moderation.final_output = "blocked"
        assert output_moderation.get_final_output() == "blocked"

    def test_append_new_token(self, output_moderation):
        """First token starts the worker thread; later tokens only extend the buffer."""
        with patch.object(OutputModeration, "start_thread") as mock_start:
            output_moderation.append_new_token("hello")
            assert output_moderation.buffer == "hello"
            mock_start.assert_called_once()
            # With a live thread already set, appending must not spawn another one.
            output_moderation.thread = MagicMock()
            output_moderation.append_new_token(" world")
            assert output_moderation.buffer == "hello world"
            assert mock_start.call_count == 1

    def test_moderation_completion_no_flag(self, output_moderation):
        with patch.object(OutputModeration, "moderation") as mock_moderation:
            mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
            output, flagged = output_moderation.moderation_completion("safe content")
            assert output == "safe content"
            assert flagged is False
            assert output_moderation.is_final_chunk is True

    def test_moderation_completion_flagged_direct_output(self, output_moderation, mock_queue_manager):
        """A DIRECT_OUTPUT flag swaps in the preset response and publishes a replace event."""
        with patch.object(OutputModeration, "moderation") as mock_moderation:
            mock_moderation.return_value = ModerationOutputsResult(
                flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset"
            )
            output, flagged = output_moderation.moderation_completion("badword content", public_event=True)
            assert output == "preset"
            assert flagged is True
            mock_queue_manager.publish.assert_called_once()
            args, _ = mock_queue_manager.publish.call_args
            assert isinstance(args[0], QueueMessageReplaceEvent)
            assert args[0].text == "preset"
            assert args[1] == PublishFrom.TASK_PIPELINE

    def test_moderation_completion_flagged_overridden(self, output_moderation, mock_queue_manager):
        with patch.object(OutputModeration, "moderation") as mock_moderation:
            mock_moderation.return_value = ModerationOutputsResult(
                flagged=True, action=ModerationAction.OVERRIDDEN, text="masked content"
            )
            output, flagged = output_moderation.moderation_completion("badword content", public_event=True)
            assert output == "masked content"
            assert flagged is True
            mock_queue_manager.publish.assert_called_once()
            args, _ = mock_queue_manager.publish.call_args
            assert args[0].text == "masked content"

    def test_start_thread(self, output_moderation):
        mock_app = MagicMock(spec=Flask)
        with patch("core.moderation.output_moderation.current_app") as mock_current_app:
            mock_current_app._get_current_object.return_value = mock_app
            with patch("threading.Thread") as mock_thread_class:
                mock_thread_instance = MagicMock()
                mock_thread_class.return_value = mock_thread_instance
                thread = output_moderation.start_thread()
                assert thread == mock_thread_instance
                mock_thread_class.assert_called_once()
                mock_thread_instance.start.assert_called_once()

    def test_stop_thread(self, output_moderation):
        """Stopping signals a live thread; a dead thread leaves the running flag alone."""
        mock_thread = MagicMock()
        mock_thread.is_alive.return_value = True
        output_moderation.thread = mock_thread
        output_moderation.stop_thread()
        assert output_moderation.thread_running is False
        output_moderation.thread_running = True
        mock_thread.is_alive.return_value = False
        output_moderation.stop_thread()
        assert output_moderation.thread_running is True

    @patch("core.moderation.output_moderation.ModerationFactory")
    def test_moderation_success(self, mock_factory_class, output_moderation):
        mock_factory = mock_factory_class.return_value
        mock_result = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
        mock_factory.moderation_for_outputs.return_value = mock_result
        result = output_moderation.moderation("tenant", "app", "buffer")
        assert result == mock_result
        mock_factory_class.assert_called_once_with(
            name="keywords", app_id="app", tenant_id="tenant", config={"keywords": "badword"}
        )

    @patch("core.moderation.output_moderation.ModerationFactory")
    def test_moderation_exception(self, mock_factory_class, output_moderation):
        # Factory errors are swallowed and surfaced as None.
        mock_factory_class.side_effect = Exception("error")
        result = output_moderation.moderation("tenant", "app", "buffer")
        assert result is None

    def test_worker_loop_and_exit(self, output_moderation, mock_queue_manager):
        mock_app = MagicMock(spec=Flask)
        # Test exit on thread_running=False
        output_moderation.thread_running = False
        output_moderation.worker(mock_app, 10)
        # Should exit immediately

    def test_worker_no_flag(self, output_moderation):
        mock_app = MagicMock(spec=Flask)
        with patch.object(OutputModeration, "moderation") as mock_moderation:
            mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
            output_moderation.buffer = "safe"
            output_moderation.is_final_chunk = True

            # To avoid infinite loop, we'll set thread_running to False after one iteration
            def side_effect(*args, **kwargs):
                output_moderation.thread_running = False
                return mock_moderation.return_value

            mock_moderation.side_effect = side_effect
            output_moderation.worker(mock_app, 10)
            assert mock_moderation.called

    def test_worker_flagged_direct_output(self, output_moderation, mock_queue_manager):
        mock_app = MagicMock(spec=Flask)
        with patch.object(OutputModeration, "moderation") as mock_moderation:
            mock_moderation.return_value = ModerationOutputsResult(
                flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset"
            )
            output_moderation.buffer = "badword"
            output_moderation.is_final_chunk = True
            output_moderation.worker(mock_app, 10)
            assert output_moderation.final_output == "preset"
            mock_queue_manager.publish.assert_called_once()
            # It breaks on DIRECT_OUTPUT

    def test_worker_flagged_overridden(self, output_moderation, mock_queue_manager):
        mock_app = MagicMock(spec=Flask)
        with patch.object(OutputModeration, "moderation") as mock_moderation:
            # Use side_effect to change thread_running on second call
            def side_effect(*args, **kwargs):
                if mock_moderation.call_count > 1:
                    output_moderation.thread_running = False
                    return None
                return ModerationOutputsResult(flagged=True, action=ModerationAction.OVERRIDDEN, text="masked")

            mock_moderation.side_effect = side_effect
            output_moderation.buffer = "badword"
            output_moderation.is_final_chunk = True
            output_moderation.worker(mock_app, 10)
            mock_queue_manager.publish.assert_called_once()
            args, _ = mock_queue_manager.publish.call_args
            assert args[0].text == "masked"

    def test_worker_chunk_too_small(self, output_moderation):
        mock_app = MagicMock(spec=Flask)
        with patch("time.sleep") as mock_sleep:
            # chunk_length < buffer_size and not is_final_chunk
            output_moderation.buffer = "123"  # length 3
            output_moderation.is_final_chunk = False

            def sleep_side_effect(seconds):
                output_moderation.thread_running = False

            mock_sleep.side_effect = sleep_side_effect
            output_moderation.worker(mock_app, 10)  # buffer_size 10
            mock_sleep.assert_called_once_with(1)

    def test_worker_empty_not_flagged(self, output_moderation, mock_queue_manager):
        mock_app = MagicMock(spec=Flask)
        with patch.object(OutputModeration, "moderation") as mock_moderation:
            # Return None (exception or no rule)
            mock_moderation.return_value = None

            def side_effect(*args, **kwargs):
                output_moderation.thread_running = False

            mock_moderation.side_effect = side_effect
            output_moderation.buffer = "something"
            output_moderation.is_final_chunk = True
            output_moderation.worker(mock_app, 10)
            mock_queue_manager.publish.assert_not_called()

View File

@ -0,0 +1,677 @@
from __future__ import annotations
import dataclasses
import json
from collections.abc import Sequence
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock
import pytest
from core.repositories.human_input_repository import (
HumanInputFormRecord,
HumanInputFormRepositoryImpl,
HumanInputFormSubmissionRepository,
_HumanInputFormEntityImpl,
_HumanInputFormRecipientEntityImpl,
_InvalidTimeoutStatusError,
_WorkspaceMemberInfo,
)
from dify_graph.nodes.human_input.entities import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
HumanInputNodeData,
MemberRecipient,
UserAction,
WebAppDeliveryMethod,
)
from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from dify_graph.repositories.human_input_form_repository import FormCreateParams, FormNotFoundError
from libs.datetime_utils import naive_utc_now
from models.human_input import HumanInputFormRecipient, RecipientType
@pytest.fixture(autouse=True)
def _stub_select(monkeypatch: pytest.MonkeyPatch) -> None:
    """Replace the repository's `select`/`selectinload` with inert fakes so no real DB query is built."""

    class _FakeSelect:
        # Chainable no-ops mirroring the query-builder methods the repository calls.
        def join(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
            return self

        def where(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
            return self

        def options(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
            return self

    monkeypatch.setattr("core.repositories.human_input_repository.select", lambda *_args, **_kwargs: _FakeSelect())
    monkeypatch.setattr("core.repositories.human_input_repository.selectinload", lambda *_args, **_kwargs: "_loader")
def _make_form_definition_json(*, include_expiration_time: bool) -> str:
    """Build a minimal serialized form definition, optionally embedding an expiration_time."""
    payload: dict[str, Any] = {
        "form_content": "hi",
        "inputs": [],
        "user_actions": [{"id": "submit", "title": "Submit"}],
        "rendered_content": "<p>hi</p>",
    }
    if include_expiration_time:
        payload["expiration_time"] = naive_utc_now()
    # default=str stringifies the datetime so json.dumps does not raise on it.
    return json.dumps(payload, default=str)
@dataclasses.dataclass
class _DummyForm:
    """Lightweight stand-in for the HumanInputForm ORM model used by these tests."""

    id: str
    workflow_run_id: str | None
    node_id: str
    tenant_id: str
    app_id: str
    form_definition: str
    rendered_content: str
    expiration_time: datetime
    form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
    created_at: datetime = dataclasses.field(default_factory=naive_utc_now)
    selected_action_id: str | None = None
    submitted_data: str | None = None
    submitted_at: datetime | None = None
    submission_user_id: str | None = None
    submission_end_user_id: str | None = None
    completed_by_recipient_id: str | None = None
    status: HumanInputFormStatus = HumanInputFormStatus.WAITING
@dataclasses.dataclass
class _DummyRecipient:
    """Lightweight stand-in for the HumanInputFormRecipient ORM model."""

    id: str
    form_id: str
    recipient_type: RecipientType
    access_token: str | None
class _FakeScalarResult:
def __init__(self, obj: Any):
self._obj = obj
def first(self) -> Any:
if isinstance(self._obj, list):
return self._obj[0] if self._obj else None
return self._obj
def all(self) -> list[Any]:
if self._obj is None:
return []
if isinstance(self._obj, list):
return list(self._obj)
return [self._obj]
class _FakeExecuteResult:
def __init__(self, rows: Sequence[tuple[Any, ...]]):
self._rows = list(rows)
def all(self) -> list[tuple[Any, ...]]:
return list(self._rows)
class _FakeSession:
    """In-memory stand-in for a SQLAlchemy session covering only the calls the repository makes."""

    def __init__(
        self,
        *,
        scalars_result: Any = None,
        scalars_results: list[Any] | None = None,
        forms: dict[str, _DummyForm] | None = None,
        recipients: dict[str, _DummyRecipient] | None = None,
        execute_rows: Sequence[tuple[Any, ...]] = (),
    ):
        # Either a queue of per-call scalar results, or one result for the first call.
        if scalars_results is not None:
            self._scalars_queue = list(scalars_results)
        else:
            self._scalars_queue = [scalars_result]
        self._forms = forms or {}
        self._recipients = recipients or {}
        self._execute_rows = list(execute_rows)
        # Objects handed to add()/add_all(), inspected by tests and by flush().
        self.added: list[Any] = []

    def scalars(self, _query: Any) -> _FakeScalarResult:
        # Consume queued results in FIFO order; fall back to None once exhausted.
        if self._scalars_queue:
            value = self._scalars_queue.pop(0)
        else:
            value = None
        return _FakeScalarResult(value)

    def execute(self, _stmt: Any) -> _FakeExecuteResult:
        return _FakeExecuteResult(self._execute_rows)

    def get(self, model_cls: Any, obj_id: str) -> Any:
        # Dispatch by model-class name so the dummies need not subclass the ORM models.
        name = getattr(model_cls, "__name__", "")
        if name == "HumanInputForm":
            return self._forms.get(obj_id)
        if name == "HumanInputFormRecipient":
            return self._recipients.get(obj_id)
        return None

    def add(self, obj: Any) -> None:
        self.added.append(obj)

    def add_all(self, objs: Sequence[Any]) -> None:
        self.added.extend(list(objs))

    def flush(self) -> None:
        # Simulate DB default population for attributes referenced in entity wrappers.
        for obj in self.added:
            if hasattr(obj, "id") and obj.id in (None, ""):
                # NOTE(review): len(str(self.added)) is the length of the list's repr,
                # not its element count — presumably len(self.added) was intended.
                # Harmless for these tests, but worth confirming.
                obj.id = f"gen-{len(str(self.added))}"
            if isinstance(obj, HumanInputFormRecipient) and obj.access_token is None:
                if obj.recipient_type == RecipientType.CONSOLE:
                    obj.access_token = "token-console"
                elif obj.recipient_type == RecipientType.BACKSTAGE:
                    obj.access_token = "token-backstage"
                else:
                    obj.access_token = "token-webapp"

    def refresh(self, _obj: Any) -> None:
        return None

    def begin(self) -> _FakeSession:
        # The session acts as its own no-op transaction context manager.
        return self

    def __enter__(self) -> _FakeSession:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        return None
class _SessionFactoryStub:
def __init__(self, session: _FakeSession):
self._session = session
def create_session(self) -> _FakeSession:
return self._session
def _patch_session_factory(monkeypatch: pytest.MonkeyPatch, session: _FakeSession) -> None:
    """Route the repository's session_factory to a stub that yields the given fake session."""
    monkeypatch.setattr("core.repositories.human_input_repository.session_factory", _SessionFactoryStub(session))
def test_recipient_entity_token_raises_when_missing() -> None:
    """Accessing .token on a recipient without an access_token must raise an AssertionError."""
    recipient = SimpleNamespace(id="r1", access_token=None)
    entity = _HumanInputFormRecipientEntityImpl(recipient)  # type: ignore[arg-type]
    with pytest.raises(AssertionError, match="access_token should not be None"):
        _ = entity.token
def test_recipient_entity_id_and_token_success() -> None:
    """id and token properties pass through the underlying model values."""
    recipient = SimpleNamespace(id="r1", access_token="tok")
    entity = _HumanInputFormRecipientEntityImpl(recipient)  # type: ignore[arg-type]
    assert entity.id == "r1"
    assert entity.token == "tok"
def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> None:
    """web_app_token resolution order: console token, then standalone web-app token, then None."""
    form = _DummyForm(
        id="f1",
        workflow_run_id="run",
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
    )
    console = _DummyRecipient(id="c1", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="ctok")
    webapp = _DummyRecipient(
        id="w1", form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP, access_token="wtok"
    )
    entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp, console])  # type: ignore[arg-type]
    assert entity.web_app_token == "ctok"
    entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp])  # type: ignore[arg-type]
    assert entity.web_app_token == "wtok"
    entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[])  # type: ignore[arg-type]
    assert entity.web_app_token is None
def test_form_entity_submitted_data_parsed() -> None:
    """Submitted JSON payload is parsed into a dict and submission state is exposed."""
    form = _DummyForm(
        id="f1",
        workflow_run_id="run",
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
        submitted_data='{"a": 1}',
        submitted_at=naive_utc_now(),
    )
    entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[])  # type: ignore[arg-type]
    assert entity.submitted is True
    assert entity.submitted_data == {"a": 1}
    assert entity.rendered_content == "<p>x</p>"
    assert entity.selected_action_id is None
    assert entity.status == HumanInputFormStatus.WAITING
def test_form_record_from_models_injects_expiration_time_when_missing() -> None:
    """from_models backfills expiration_time into the definition from the model column."""
    expiration = naive_utc_now()
    form = _DummyForm(
        id="f1",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=False),
        rendered_content="<p>x</p>",
        expiration_time=expiration,
        submitted_data='{"k": "v"}',
    )
    record = HumanInputFormRecord.from_models(form, None)  # type: ignore[arg-type]
    assert record.definition.expiration_time == expiration
    assert record.submitted_data == {"k": "v"}
    # submitted_at was not set, so the record counts as not submitted despite having data.
    assert record.submitted is False
def test_create_email_recipients_from_resolved_dedupes_and_skips_blank(monkeypatch: pytest.MonkeyPatch) -> None:
    """Blank e-mail addresses are skipped and duplicates collapse to one recipient per address."""
    created: list[SimpleNamespace] = []

    # Stand-in for HumanInputFormRecipient.new that records what would be created.
    def fake_new(cls, form_id: str, delivery_id: str, payload: Any):  # type: ignore[no-untyped-def]
        recipient = SimpleNamespace(
            id=f"{payload.TYPE}-{len(created)}",
            form_id=form_id,
            delivery_id=delivery_id,
            recipient_type=payload.TYPE,
            recipient_payload=payload.model_dump_json(),
            access_token="tok",
        )
        created.append(recipient)
        return recipient

    monkeypatch.setattr("core.repositories.human_input_repository.HumanInputFormRecipient.new", classmethod(fake_new))
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    recipients = repo._create_email_recipients_from_resolved(  # type: ignore[attr-defined]
        form_id="f",
        delivery_id="d",
        members=[
            _WorkspaceMemberInfo(user_id="u1", email=""),
            _WorkspaceMemberInfo(user_id="u2", email="a@example.com"),
            _WorkspaceMemberInfo(user_id="u3", email="a@example.com"),
        ],
        external_emails=["", "a@example.com", "b@example.com", "b@example.com"],
    )
    assert [r.recipient_type for r in recipients] == [RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL]
def test_query_workspace_members_by_ids_empty_returns_empty() -> None:
    """Blank-only user-id filters short-circuit to an empty member list."""
    repository = HumanInputFormRepositoryImpl(tenant_id="tenant")
    result = repository._query_workspace_members_by_ids(session=MagicMock(), restrict_to_user_ids=["", ""])
    assert result == []
def test_query_workspace_members_by_ids_maps_rows() -> None:
    """Each (user_id, email) row from the session maps to _WorkspaceMemberInfo."""
    fake_session = _FakeSession(execute_rows=[("u1", "a@example.com"), ("u2", "b@example.com")])
    repository = HumanInputFormRepositoryImpl(tenant_id="tenant")
    members = repository._query_workspace_members_by_ids(session=fake_session, restrict_to_user_ids=["u1", "u2"])
    expected = [
        _WorkspaceMemberInfo(user_id="u1", email="a@example.com"),
        _WorkspaceMemberInfo(user_id="u2", email="b@example.com"),
    ]
    assert members == expected
def test_query_all_workspace_members_maps_rows() -> None:
    """All rows from the session are translated to _WorkspaceMemberInfo entries."""
    fake_session = _FakeSession(execute_rows=[("u1", "a@example.com")])
    repository = HumanInputFormRepositoryImpl(tenant_id="tenant")
    members = repository._query_all_workspace_members(session=fake_session)
    assert members == [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")]
def test_repository_init_sets_tenant_id() -> None:
    """The constructor stores the supplied tenant id on the instance."""
    repository = HumanInputFormRepositoryImpl(tenant_id="tenant")
    assert repository._tenant_id == "tenant"
def test_delivery_method_to_model_webapp_creates_delivery_and_recipient(monkeypatch: pytest.MonkeyPatch) -> None:
    """A web-app delivery method yields one delivery row plus a single
    standalone-web-app recipient, with the stubbed uuid as delivery id."""
    repository = HumanInputFormRepositoryImpl(tenant_id="tenant")
    monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1")
    outcome = repository._delivery_method_to_model(
        session=MagicMock(),
        form_id="form-1",
        delivery_method=WebAppDeliveryMethod(),
    )
    assert outcome.delivery.form_id == "form-1"
    assert outcome.delivery.id == "del-1"
    assert len(outcome.recipients) == 1
    assert outcome.recipients[0].recipient_type == RecipientType.STANDALONE_WEB_APP
def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch: pytest.MonkeyPatch) -> None:
    """Email delivery methods must delegate recipient creation to
    _build_email_recipients, forwarding the generated delivery id."""
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1")
    called: dict[str, Any] = {}

    def fake_build(*, session: Any, form_id: str, delivery_id: str, recipients_config: Any) -> list[Any]:
        # Record the keyword arguments so the delegation can be asserted below.
        called.update(
            {"session": session, "form_id": form_id, "delivery_id": delivery_id, "recipients_config": recipients_config}
        )
        return ["r"]

    # Patch the bound method on this instance; the repo should call it instead
    # of building email recipients itself.
    monkeypatch.setattr(repo, "_build_email_recipients", fake_build)
    method = EmailDeliveryMethod(
        config=EmailDeliveryConfig(
            recipients=EmailRecipients(
                whole_workspace=False,
                items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
            ),
            subject="s",
            body="b",
        )
    )
    result = repo._delivery_method_to_model(session="sess", form_id="form-1", delivery_method=method)
    # Recipients come straight from the stubbed builder; the stubbed uuid is
    # the delivery id handed to it.
    assert result.recipients == ["r"]
    assert called["delivery_id"] == "del-1"
def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None:
    """With whole_workspace=True the full workspace member list is resolved
    (the items list is not used for member lookup)."""
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    monkeypatch.setattr(
        repo,
        "_query_all_workspace_members",
        lambda *, session: [_WorkspaceMemberInfo(user_id="u", email="a@example.com")],
    )
    # The creation step is stubbed out; this test only checks the resolution path.
    monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"])
    recipients = repo._build_email_recipients(
        session=MagicMock(),
        form_id="f",
        delivery_id="d",
        recipients_config=EmailRecipients(whole_workspace=True, items=[ExternalRecipient(email="e@example.com")]),
    )
    assert recipients == ["ok"]
def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None:
    """With whole_workspace=False only the member ids listed in items are queried."""
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")

    def fake_query(*, session: Any, restrict_to_user_ids: Sequence[str]) -> list[_WorkspaceMemberInfo]:
        # Only the explicitly selected member id should be looked up.
        assert restrict_to_user_ids == ["u1"]
        return [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")]

    monkeypatch.setattr(repo, "_query_workspace_members_by_ids", fake_query)
    monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"])
    recipients = repo._build_email_recipients(
        session=MagicMock(),
        form_id="f",
        delivery_id="d",
        recipients_config=EmailRecipients(
            whole_workspace=False,
            items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
        ),
    )
    assert recipients == ["ok"]
def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """get_form returns None when no row matches, otherwise an entity whose
    recipients expose their access tokens."""
    # Missing form: the first scalar lookup yields None.
    _patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None]))
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    assert repo.get_form("run", "node") is None
    form = _DummyForm(
        id="f1",
        workflow_run_id="run",
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
    )
    recipient = _DummyRecipient(
        id="r1",
        form_id=form.id,
        recipient_type=RecipientType.STANDALONE_WEB_APP,
        access_token="tok",
    )
    # The fake session answers two queries: the form itself, then its recipients.
    session = _FakeSession(scalars_results=[form, [recipient]])
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    entity = repo.get_form("run", "node")
    assert entity is not None
    assert entity.id == "f1"
    assert entity.recipients[0].id == "r1"
    assert entity.recipients[0].token == "tok"
def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.MonkeyPatch) -> None:
    """create_form adds console and backstage recipients on top of the requested
    web-app delivery, and derives expiration from the node timeout."""
    fixed_now = datetime(2024, 1, 1, 0, 0, 0)
    monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)
    # uuidv7 is consumed in a deterministic order: the form id first, then one
    # id per delivery (web app, console, backstage).
    ids = iter(["form-id", "del-web", "del-console", "del-backstage"])
    monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: next(ids))
    session = _FakeSession()
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    form_config = HumanInputNodeData(
        title="Title",
        delivery_methods=[],
        form_content="hello",
        inputs=[],
        user_actions=[UserAction(id="submit", title="Submit")],
    )
    params = FormCreateParams(
        app_id="app",
        workflow_execution_id="run",
        node_id="node",
        form_config=form_config,
        rendered_content="<p>hello</p>",
        delivery_methods=[WebAppDeliveryMethod()],
        display_in_ui=True,
        resolved_default_values={},
        form_kind=HumanInputFormKind.RUNTIME,
        console_recipient_required=True,
        console_creator_account_id="acc-1",
        backstage_recipient_required=True,
    )
    entity = repo.create_form(params)
    assert entity.id == "form-id"
    # Expiration is the frozen clock plus the node-configured timeout.
    assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout)
    # Console token should take precedence when console recipient is present.
    assert entity.web_app_token == "token-console"
    assert len(entity.recipients) == 3
def test_submission_get_by_token_returns_none_when_missing_or_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """get_by_token yields None both for an unknown token and for a recipient
    whose parent form no longer exists."""
    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=None))
    assert HumanInputFormSubmissionRepository().get_by_token("tok") is None
    orphan_recipient = SimpleNamespace(form=None)
    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=orphan_recipient))
    assert HumanInputFormSubmissionRepository().get_by_token("tok") is None
def test_submission_repository_init_no_args() -> None:
    """The submission repository is constructible without any arguments."""
    repository = HumanInputFormSubmissionRepository()
    assert isinstance(repository, HumanInputFormSubmissionRepository)
def test_submission_get_by_token_and_get_by_form_id_success_paths(monkeypatch: pytest.MonkeyPatch) -> None:
    """Both lookup paths resolve a recipient joined with its parent form."""
    form = _DummyForm(
        id="f1",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
    )
    # SimpleNamespace stands in for the ORM recipient, including the joined form.
    recipient = SimpleNamespace(
        id="r1",
        form_id=form.id,
        recipient_type=RecipientType.STANDALONE_WEB_APP,
        access_token="tok",
        form=form,
    )
    # Lookup by access token.
    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
    repo = HumanInputFormSubmissionRepository()
    record = repo.get_by_token("tok")
    assert record is not None
    assert record.access_token == "tok"
    # Lookup by form id plus recipient type.
    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
    repo = HumanInputFormSubmissionRepository()
    record = repo.get_by_form_id_and_recipient_type(form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP)
    assert record is not None
    assert record.recipient_id == "r1"
def test_submission_get_by_form_id_returns_none_on_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """An unmatched form id / recipient type pair resolves to None."""
    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=None))
    repository = HumanInputFormSubmissionRepository()
    result = repository.get_by_form_id_and_recipient_type(form_id="f", recipient_type=RecipientType.CONSOLE)
    assert result is None
def test_mark_submitted_updates_and_raises_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """mark_submitted raises for unknown forms and otherwise flips the form to
    SUBMITTED, stamping the submission time and the submitted payload."""
    fixed_now = datetime(2024, 1, 1, 0, 0, 0)
    monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)
    # Unknown form id -> FormNotFoundError.
    missing_session = _FakeSession(forms={})
    _patch_session_factory(monkeypatch, missing_session)
    repo = HumanInputFormSubmissionRepository()
    with pytest.raises(FormNotFoundError, match="form not found"):
        repo.mark_submitted(
            form_id="missing",
            recipient_id=None,
            selected_action_id="a",
            form_data={},
            submission_user_id=None,
            submission_end_user_id=None,
        )
    # Known form: submission mutates the ORM model in place and returns a record.
    form = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=fixed_now,
    )
    recipient = _DummyRecipient(id="r", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="tok")
    session = _FakeSession(forms={form.id: form}, recipients={recipient.id: recipient})
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormSubmissionRepository()
    record = repo.mark_submitted(
        form_id=form.id,
        recipient_id=recipient.id,
        selected_action_id="approve",
        form_data={"k": "v"},
        submission_user_id="u",
        submission_end_user_id="eu",
    )
    assert form.status == HumanInputFormStatus.SUBMITTED
    # submitted_at comes from the frozen naive_utc_now.
    assert form.submitted_at == fixed_now
    assert record.submitted_data == {"k": "v"}
def test_mark_timeout_invalid_status_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """Passing a non-timeout status to mark_timeout must be rejected."""
    waiting_form = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
    )
    _patch_session_factory(monkeypatch, _FakeSession(forms={waiting_form.id: waiting_form}))
    repository = HumanInputFormSubmissionRepository()
    with pytest.raises(_InvalidTimeoutStatusError, match="invalid timeout status"):
        repository.mark_timeout(form_id=waiting_form.id, timeout_status=HumanInputFormStatus.SUBMITTED)  # type: ignore[arg-type]
def test_mark_timeout_already_timed_out_returns_record(monkeypatch: pytest.MonkeyPatch) -> None:
    """Marking an already timed-out form is idempotent and returns its record."""
    timed_out_form = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
        status=HumanInputFormStatus.TIMEOUT,
    )
    _patch_session_factory(monkeypatch, _FakeSession(forms={timed_out_form.id: timed_out_form}))
    repository = HumanInputFormSubmissionRepository()
    record = repository.mark_timeout(form_id=timed_out_form.id, timeout_status=HumanInputFormStatus.TIMEOUT, reason="r")
    assert record.status == HumanInputFormStatus.TIMEOUT
def test_mark_timeout_submitted_raises_form_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
    """A form that was already submitted cannot be timed out."""
    submitted_form = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
        status=HumanInputFormStatus.SUBMITTED,
    )
    _patch_session_factory(monkeypatch, _FakeSession(forms={submitted_form.id: submitted_form}))
    repository = HumanInputFormSubmissionRepository()
    with pytest.raises(FormNotFoundError, match="form already submitted"):
        repository.mark_timeout(form_id=submitted_form.id, timeout_status=HumanInputFormStatus.EXPIRED)
def test_mark_timeout_updates_fields(monkeypatch: pytest.MonkeyPatch) -> None:
    """Timing out a WAITING form clears every submission artifact and applies
    the requested timeout status."""
    form = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
        # Pre-populate submission fields so the test proves they are wiped.
        selected_action_id="a",
        submitted_data="{}",
        submission_user_id="u",
        submission_end_user_id="eu",
        completed_by_recipient_id="r",
        status=HumanInputFormStatus.WAITING,
    )
    session = _FakeSession(forms={form.id: form})
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormSubmissionRepository()
    record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED)
    assert form.status == HumanInputFormStatus.EXPIRED
    # All submission-related fields must be reset to None.
    assert form.selected_action_id is None
    assert form.submitted_data is None
    assert form.submission_user_id is None
    assert form.submission_end_user_id is None
    assert form.completed_by_recipient_id is None
    assert record.status == HumanInputFormStatus.EXPIRED
def test_mark_timeout_raises_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """Timing out an unknown form id raises FormNotFoundError."""
    _patch_session_factory(monkeypatch, _FakeSession(forms={}))
    repository = HumanInputFormSubmissionRepository()
    with pytest.raises(FormNotFoundError, match="form not found"):
        repository.mark_timeout(form_id="missing", timeout_status=HumanInputFormStatus.TIMEOUT)

View File

@ -1,84 +1,291 @@
from datetime import datetime
from datetime import UTC, datetime
from unittest.mock import MagicMock
from uuid import uuid4
from sqlalchemy import create_engine
import pytest
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType
from models import Account, WorkflowRun
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from models import Account, CreatorUserRole, EndUser, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom
def _build_repository_with_mocked_session(session: MagicMock) -> SQLAlchemyWorkflowExecutionRepository:
    """Construct a repository whose session factory yields *session*.

    A real sessionmaker bound to an in-memory SQLite engine satisfies the
    constructor's type check; the factory is then swapped for a mock context
    manager so every `with self._session_factory() as s:` hands back *session*.
    """
    engine = create_engine("sqlite:///:memory:")
    real_session_factory = sessionmaker(bind=engine, expire_on_commit=False)
    user = MagicMock(spec=Account)
    user.id = str(uuid4())
    user.current_tenant_id = str(uuid4())
    repository = SQLAlchemyWorkflowExecutionRepository(
        session_factory=real_session_factory,
        user=user,
        app_id="app-id",
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
    )
    # Replace the factory so entering the context yields the caller's mock.
    session_context = MagicMock()
    session_context.__enter__.return_value = session
    session_context.__exit__.return_value = False
    repository._session_factory = MagicMock(return_value=session_context)
    return repository
def _build_execution(*, execution_id: str, started_at: datetime) -> WorkflowExecution:
    """Create a minimal WorkflowExecution with a controlled id and start time."""
    return WorkflowExecution.new(
        id_=execution_id,
        started_at=started_at,
        workflow_id="workflow-id",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_version="1.0.0",
        graph={"nodes": [], "edges": []},
        inputs={"query": "hello"},
    )
def test_save_uses_execution_started_at_when_record_does_not_exist():
@pytest.fixture
def mock_session_factory():
"""Mock SQLAlchemy session factory."""
session_factory = MagicMock(spec=sessionmaker)
session = MagicMock()
session.get.return_value = None
repository = _build_repository_with_mocked_session(session)
started_at = datetime(2026, 1, 1, 12, 0, 0)
execution = _build_execution(execution_id=str(uuid4()), started_at=started_at)
repository.save(execution)
saved_model = session.merge.call_args.args[0]
assert saved_model.created_at == started_at
session.commit.assert_called_once()
session_factory.return_value.__enter__.return_value = session
return session_factory
def test_save_preserves_existing_created_at_when_record_already_exists():
session = MagicMock()
repository = _build_repository_with_mocked_session(session)
@pytest.fixture
def mock_engine():
    """Provide a MagicMock standing in for a SQLAlchemy Engine."""
    return MagicMock(spec=Engine)
execution_id = str(uuid4())
existing_created_at = datetime(2026, 1, 1, 12, 0, 0)
existing_run = WorkflowRun()
existing_run.id = execution_id
existing_run.tenant_id = repository._tenant_id
existing_run.created_at = existing_created_at
session.get.return_value = existing_run
execution = _build_execution(
execution_id=execution_id,
started_at=datetime(2026, 1, 1, 12, 30, 0),
@pytest.fixture
def mock_account():
    """Provide a mocked Account carrying random user and tenant ids."""
    mocked = MagicMock(spec=Account)
    mocked.id = str(uuid4())
    mocked.current_tenant_id = str(uuid4())
    return mocked
@pytest.fixture
def mock_end_user():
    """Provide a mocked EndUser with random user and tenant ids."""
    end_user = MagicMock(spec=EndUser)
    end_user.id = str(uuid4())
    end_user.tenant_id = str(uuid4())
    return end_user
@pytest.fixture
def sample_workflow_execution():
    """Sample WorkflowExecution for testing.

    A fully-populated, already-succeeded execution; individual tests override
    fields (started_at, graph, inputs, ...) as needed.
    """
    return WorkflowExecution(
        id_=str(uuid4()),
        workflow_id=str(uuid4()),
        workflow_type=WorkflowType.WORKFLOW,
        workflow_version="1.0",
        graph={"nodes": [], "edges": []},
        inputs={"input1": "value1"},
        outputs={"output1": "result1"},
        status=WorkflowExecutionStatus.SUCCEEDED,
        error_message="",
        total_tokens=100,
        total_steps=5,
        exceptions_count=0,
        started_at=datetime.now(UTC),
        finished_at=datetime.now(UTC),
    )
repository.save(execution)
saved_model = session.merge.call_args.args[0]
assert saved_model.created_at == existing_created_at
session.commit.assert_called_once()
class TestSQLAlchemyWorkflowExecutionRepository:
    """Unit tests for SQLAlchemyWorkflowExecutionRepository: constructor
    validation, domain<->db model conversion, and save() semantics."""

    def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
        """The constructor wires factory, tenant, app, trigger and creator fields."""
        app_id = "test_app_id"
        triggered_from = WorkflowRunTriggeredFrom.APP_RUN
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory, user=mock_account, app_id=app_id, triggered_from=triggered_from
        )
        assert repo._session_factory == mock_session_factory
        assert repo._tenant_id == mock_account.current_tenant_id
        assert repo._app_id == app_id
        assert repo._triggered_from == triggered_from
        assert repo._creator_user_id == mock_account.id
        assert repo._creator_user_role == CreatorUserRole.ACCOUNT

    def test_init_with_engine(self, mock_engine, mock_account):
        """Passing an Engine builds an internal sessionmaker bound to it."""
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_engine,
            user=mock_account,
            app_id="test_app_id",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        assert isinstance(repo._session_factory, sessionmaker)
        assert repo._session_factory.kw["bind"] == mock_engine

    def test_init_invalid_session_factory(self, mock_account):
        """A factory that is neither Engine nor sessionmaker is rejected."""
        with pytest.raises(ValueError, match="Invalid session_factory type"):
            SQLAlchemyWorkflowExecutionRepository(
                session_factory="invalid", user=mock_account, app_id=None, triggered_from=None
            )

    def test_init_no_tenant_id(self, mock_session_factory):
        """A user without a tenant id cannot construct the repository."""
        user = MagicMock(spec=Account)
        user.current_tenant_id = None
        with pytest.raises(ValueError, match="User must have a tenant_id"):
            SQLAlchemyWorkflowExecutionRepository(
                session_factory=mock_session_factory, user=user, app_id=None, triggered_from=None
            )

    def test_init_with_end_user(self, mock_session_factory, mock_end_user):
        """EndUser users map to END_USER creator role and their own tenant id."""
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory, user=mock_end_user, app_id=None, triggered_from=None
        )
        assert repo._tenant_id == mock_end_user.tenant_id
        assert repo._creator_user_role == CreatorUserRole.END_USER

    def test_to_domain_model(self, mock_session_factory, mock_account):
        """A WorkflowRun row converts into the domain model field by field."""
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None
        )
        db_model = MagicMock(spec=WorkflowRun)
        db_model.id = str(uuid4())
        db_model.workflow_id = str(uuid4())
        db_model.type = "workflow"
        db_model.version = "1.0"
        db_model.inputs_dict = {"in": "val"}
        db_model.outputs_dict = {"out": "val"}
        db_model.graph_dict = {"nodes": []}
        db_model.status = "succeeded"
        db_model.error = "some error"
        db_model.total_tokens = 50
        db_model.total_steps = 3
        db_model.exceptions_count = 1
        db_model.created_at = datetime.now(UTC)
        db_model.finished_at = datetime.now(UTC)
        domain_model = repo._to_domain_model(db_model)
        assert domain_model.id_ == db_model.id
        assert domain_model.workflow_id == db_model.workflow_id
        # The string status is parsed into the enum.
        assert domain_model.status == WorkflowExecutionStatus.SUCCEEDED
        assert domain_model.inputs == db_model.inputs_dict
        assert domain_model.error_message == "some error"

    def test_to_db_model(self, mock_session_factory, mock_account, sample_workflow_execution):
        """The db model carries repository context plus computed elapsed time."""
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test_app",
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
        )
        # Make elapsed time deterministic to avoid flaky tests
        sample_workflow_execution.started_at = datetime(2023, 1, 1, 0, 0, 0, tzinfo=UTC)
        sample_workflow_execution.finished_at = datetime(2023, 1, 1, 0, 0, 10, tzinfo=UTC)
        db_model = repo._to_db_model(sample_workflow_execution)
        assert db_model.id == sample_workflow_execution.id_
        assert db_model.tenant_id == repo._tenant_id
        assert db_model.app_id == "test_app"
        assert db_model.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING
        assert db_model.status == sample_workflow_execution.status.value
        assert db_model.total_tokens == sample_workflow_execution.total_tokens
        # 10 seconds between start and finish.
        assert db_model.elapsed_time == 10.0

    def test_to_db_model_edge_cases(self, mock_session_factory, mock_account, sample_workflow_execution):
        """None-valued optional fields propagate as None; no finish -> 0 elapsed."""
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test_app",
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
        )
        # Test with empty/None fields
        sample_workflow_execution.graph = None
        sample_workflow_execution.inputs = None
        sample_workflow_execution.outputs = None
        sample_workflow_execution.error_message = None
        sample_workflow_execution.finished_at = None
        db_model = repo._to_db_model(sample_workflow_execution)
        assert db_model.graph is None
        assert db_model.inputs is None
        assert db_model.outputs is None
        assert db_model.error is None
        assert db_model.elapsed_time == 0

    def test_to_db_model_app_id_none(self, mock_session_factory, mock_account, sample_workflow_execution):
        """A repository without an app id produces a row without one."""
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id=None,
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        db_model = repo._to_db_model(sample_workflow_execution)
        assert not hasattr(db_model, "app_id") or db_model.app_id is None
        assert db_model.tenant_id == repo._tenant_id

    def test_to_db_model_missing_context(self, mock_session_factory, mock_account, sample_workflow_execution):
        """Each required context field raises its own ValueError when absent."""
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None
        )
        # Test triggered_from missing
        with pytest.raises(ValueError, match="triggered_from is required"):
            repo._to_db_model(sample_workflow_execution)
        repo._triggered_from = WorkflowRunTriggeredFrom.APP_RUN
        repo._creator_user_id = None
        with pytest.raises(ValueError, match="created_by is required"):
            repo._to_db_model(sample_workflow_execution)
        repo._creator_user_id = "some_id"
        repo._creator_user_role = None
        with pytest.raises(ValueError, match="created_by_role is required"):
            repo._to_db_model(sample_workflow_execution)

    def test_save(self, mock_session_factory, mock_account, sample_workflow_execution):
        """save() merges + commits and caches the persisted model by id."""
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test_app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        repo.save(sample_workflow_execution)
        session = mock_session_factory.return_value.__enter__.return_value
        session.merge.assert_called_once()
        session.commit.assert_called_once()
        # Check cache
        assert sample_workflow_execution.id_ in repo._execution_cache
        cached_model = repo._execution_cache[sample_workflow_execution.id_]
        assert cached_model.id == sample_workflow_execution.id_

    def test_save_uses_execution_started_at_when_record_does_not_exist(
        self, mock_session_factory, mock_account, sample_workflow_execution
    ):
        """A brand-new run's created_at is the execution's started_at."""
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test_app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        started_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC)
        sample_workflow_execution.started_at = started_at
        session = mock_session_factory.return_value.__enter__.return_value
        # No existing row for this id.
        session.get.return_value = None
        repo.save(sample_workflow_execution)
        saved_model = session.merge.call_args.args[0]
        assert saved_model.created_at == started_at
        session.commit.assert_called_once()

    def test_save_preserves_existing_created_at_when_record_already_exists(
        self, mock_session_factory, mock_account, sample_workflow_execution
    ):
        """Re-saving an existing run must not overwrite its original created_at."""
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test_app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        execution_id = sample_workflow_execution.id_
        existing_created_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC)
        existing_run = WorkflowRun()
        existing_run.id = execution_id
        existing_run.tenant_id = repo._tenant_id
        existing_run.created_at = existing_created_at
        session = mock_session_factory.return_value.__enter__.return_value
        session.get.return_value = existing_run
        # Later started_at must not replace the stored created_at.
        sample_workflow_execution.started_at = datetime(2026, 1, 1, 12, 30, 0, tzinfo=UTC)
        repo.save(sample_workflow_execution)
        saved_model = session.merge.call_args.args[0]
        assert saved_model.created_at == existing_created_at
        session.commit.assert_called_once()

View File

@ -0,0 +1,772 @@
from __future__ import annotations
import json
import logging
from collections.abc import Mapping
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock, Mock
import psycopg2.errors
import pytest
from sqlalchemy import Engine, create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
SQLAlchemyWorkflowNodeExecutionRepository,
_deterministic_json_dump,
_filter_by_offload_type,
_find_first,
_replace_or_append_offload,
)
from dify_graph.entities import WorkflowNodeExecution
from dify_graph.enums import (
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.repositories.workflow_node_execution_repository import OrderConfig
from models import Account, EndUser
from models.enums import ExecutionOffLoadType
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom
def _mock_account(*, tenant_id: str = "tenant", user_id: str = "user") -> Account:
    """Create an Account mock exposing ``id`` and ``current_tenant_id``."""
    account = Mock(spec=Account)
    account.id = user_id
    account.current_tenant_id = tenant_id
    return account
def _mock_end_user(*, tenant_id: str = "tenant", user_id: str = "user") -> EndUser:
    """Create an EndUser mock exposing ``id`` and ``tenant_id``."""
    end_user = Mock(spec=EndUser)
    end_user.id = user_id
    end_user.tenant_id = tenant_id
    return end_user
def _execution(
    *,
    execution_id: str = "exec-id",
    node_execution_id: str = "node-exec-id",
    workflow_run_id: str = "run-id",
    status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.SUCCEEDED,
    inputs: Mapping[str, Any] | None = None,
    outputs: Mapping[str, Any] | None = None,
    process_data: Mapping[str, Any] | None = None,
    metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None,
) -> WorkflowNodeExecution:
    """Build a WorkflowNodeExecution with test-friendly defaults.

    Only the fields a test cares about need to be overridden; everything else
    is fixed so assertions stay deterministic (except created_at, which is the
    current UTC time).
    """
    return WorkflowNodeExecution(
        id=execution_id,
        node_execution_id=node_execution_id,
        workflow_id="workflow-id",
        workflow_execution_id=workflow_run_id,
        index=1,
        predecessor_node_id=None,
        node_id="node-id",
        node_type=NodeType.LLM,
        title="Title",
        inputs=inputs,
        outputs=outputs,
        process_data=process_data,
        status=status,
        error=None,
        elapsed_time=1.0,
        metadata=metadata,
        created_at=datetime.now(UTC),
        finished_at=None,
    )
class _SessionCtx:
def __init__(self, session: Any):
self._session = session
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type, exc, tb) -> None:
return None
def _session_factory(session: Any) -> sessionmaker:
    """Build a sessionmaker mock whose call yields a context wrapping *session*."""
    maker = Mock(spec=sessionmaker)
    maker.return_value = _SessionCtx(session)
    return maker
def test_init_accepts_engine_and_sessionmaker_and_sets_role(monkeypatch: pytest.MonkeyPatch) -> None:
    """Both an Engine and a sessionmaker are accepted as session_factory, and
    EndUser users get the end_user creator role."""
    # FileService is stubbed so the constructor does not touch storage.
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    engine: Engine = create_engine("sqlite:///:memory:")
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=engine,
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    # An Engine should be wrapped into a sessionmaker internally.
    assert isinstance(repo._session_factory, sessionmaker)
    sm = Mock(spec=sessionmaker)
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=sm,
        user=_mock_end_user(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
    )
    assert repo._creator_user_role.value == "end_user"
def test_init_rejects_invalid_session_factory_type(monkeypatch: pytest.MonkeyPatch) -> None:
    """Anything other than an Engine or sessionmaker must raise ValueError."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    with pytest.raises(ValueError, match="Invalid session_factory type"):
        SQLAlchemyWorkflowNodeExecutionRepository(  # type: ignore[arg-type]
            session_factory=object(),
            user=_mock_account(),
            app_id=None,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )
def test_init_requires_tenant_id(monkeypatch: pytest.MonkeyPatch) -> None:
    """An Account without current_tenant_id cannot construct the repository."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    user = _mock_account()
    user.current_tenant_id = None
    with pytest.raises(ValueError, match="User must have a tenant_id"):
        SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=Mock(spec=sessionmaker),
            user=user,
            app_id=None,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )
def test_create_truncator_uses_config(monkeypatch: pytest.MonkeyPatch) -> None:
    """_create_truncator must feed the dify_config limits into VariableTruncator."""
    created: dict[str, Any] = {}

    class FakeTruncator:
        # Capture the constructor kwargs so they can be asserted below.
        def __init__(self, *, max_size_bytes: int, array_element_limit: int, string_length_limit: int):
            created.update(
                {
                    "max_size_bytes": max_size_bytes,
                    "array_element_limit": array_element_limit,
                    "string_length_limit": string_length_limit,
                }
            )

    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.VariableTruncator",
        FakeTruncator,
    )
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    _ = repo._create_truncator()
    assert created["max_size_bytes"] == dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE
def test_helpers_find_first_and_replace_or_append_and_filter() -> None:
    """Exercise the module-level helper utilities in one pass."""
    # JSON dumping sorts keys so output is deterministic.
    assert _deterministic_json_dump({"b": 1, "a": 2}) == '{"a": 2, "b": 1}'
    # _find_first returns the first match, or None when nothing qualifies.
    assert _find_first([], lambda _: True) is None
    assert _find_first([1, 2, 3], lambda item: item > 1) == 2
    inputs_offload = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
    outputs_offload = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
    matcher = _filter_by_offload_type(ExecutionOffLoadType.OUTPUTS)
    assert _find_first([inputs_offload, outputs_offload], matcher) is outputs_offload
    # Replacing an existing offload drops the old entry and appends the new one.
    replacement = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
    updated = _replace_or_append_offload([inputs_offload, outputs_offload], replacement)
    assert len(updated) == 2
    assert [entry.type_ for entry in updated] == [ExecutionOffLoadType.OUTPUTS, ExecutionOffLoadType.INPUTS]
def test_to_db_model_requires_constructor_context(monkeypatch: pytest.MonkeyPatch) -> None:
    """_to_db_model serializes deterministically and requires triggered_from."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    execution = _execution(inputs={"b": 1, "a": 2}, metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 1})
    # Happy path: deterministic json dump should be sorted
    db_model = repo._to_db_model(execution)
    assert json.loads(db_model.inputs or "{}") == {"a": 2, "b": 1}
    assert json.loads(db_model.execution_metadata or "{}")["total_tokens"] == 1
    # Clearing triggered_from after construction must make conversion fail.
    repo._triggered_from = None
    with pytest.raises(ValueError, match="triggered_from is required"):
        repo._to_db_model(execution)
def test_to_db_model_requires_creator_user_id_and_role(monkeypatch: pytest.MonkeyPatch) -> None:
    """Conversion fails when the creator identity fields are missing."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    execution = _execution()

    # Baseline: a fully configured repository converts successfully.
    assert repo._to_db_model(execution).app_id == "app"

    # Missing user id -> error mentioning created_by.
    repo._creator_user_id = None
    with pytest.raises(ValueError, match="created_by is required"):
        repo._to_db_model(execution)

    # Restore the id but drop the role -> error mentioning created_by_role.
    repo._creator_user_id = "user"
    repo._creator_user_role = None
    with pytest.raises(ValueError, match="created_by_role is required"):
        repo._to_db_model(execution)
def test_is_duplicate_key_error_and_regenerate_id(
    monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
    """Duplicate-key detection and id regeneration update both models and warn."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    # Only IntegrityErrors wrapping a psycopg2 UniqueViolation count as duplicates.
    unique = Mock(spec=psycopg2.errors.UniqueViolation)
    duplicate_error = IntegrityError("dup", params=None, orig=unique)
    assert repo._is_duplicate_key_error(duplicate_error) is True
    assert repo._is_duplicate_key_error(IntegrityError("other", params=None, orig=None)) is False
    execution = _execution(execution_id="old-id")
    db_model = WorkflowNodeExecutionModel()
    db_model.id = "old-id"
    # Pin uuidv7 so the regenerated id is predictable.
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id")
    caplog.set_level(logging.WARNING)
    repo._regenerate_id_on_duplicate(execution, db_model)
    # Both the domain execution and the DB row pick up the fresh id.
    assert execution.id == "new-id"
    assert db_model.id == "new-id"
    # The conflict is surfaced as a warning log entry.
    assert any("Duplicate key conflict" in r.message for r in caplog.records)
def test_persist_to_database_updates_existing_and_inserts_new(monkeypatch: pytest.MonkeyPatch) -> None:
    """_persist_to_database updates a row found by id, otherwise adds it; both paths cache."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    session = MagicMock()
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    db_model = WorkflowNodeExecutionModel()
    db_model.id = "id1"
    db_model.node_execution_id = "node1"
    db_model.foo = "bar"  # type: ignore[attr-defined]
    # Placed directly in __dict__ to simulate an internal/private attribute
    # (presumably excluded from the attribute copy; not asserted here).
    db_model.__dict__["_private"] = "x"
    # Case 1: session.get finds an existing row -> attrs copied, no add().
    existing = SimpleNamespace()
    session.get.return_value = existing
    repo._persist_to_database(db_model)
    assert existing.foo == "bar"
    session.add.assert_not_called()
    assert repo._node_execution_cache["node1"] is db_model
    # Case 2: no existing row -> the model is added and cached.
    session.reset_mock()
    session.get.return_value = None
    repo._node_execution_cache.clear()
    repo._persist_to_database(db_model)
    session.add.assert_called_once_with(db_model)
    assert repo._node_execution_cache["node1"] is db_model
def test_truncate_and_upload_returns_none_when_no_values_or_not_truncated(monkeypatch: pytest.MonkeyPatch) -> None:
    """No payload, or a truncator that leaves data intact, yields None."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    # Nothing to truncate at all.
    assert repo._truncate_and_upload(None, "e", ExecutionOffLoadType.INPUTS) is None

    class PassThroughTruncator:
        def truncate_variable_mapping(self, value: Any):  # type: ignore[no-untyped-def]
            # Report "not truncated" so no upload should happen.
            return value, False

    monkeypatch.setattr(repo, "_create_truncator", lambda: PassThroughTruncator())
    assert repo._truncate_and_upload({"a": 1}, "e", ExecutionOffLoadType.INPUTS) is None
def test_truncate_and_upload_uploads_and_builds_offload(monkeypatch: pytest.MonkeyPatch) -> None:
    """A truncated payload is uploaded via FileService and wrapped in an offload record."""
    uploaded: dict[str, Any] = {}

    class FakeFileService:
        def upload_file(self, *, filename: str, content: bytes, mimetype: str, user: Any):  # type: ignore[no-untyped-def]
            # Capture the upload arguments for later assertions.
            uploaded.update({"filename": filename, "content": content, "mimetype": mimetype, "user": user})
            return SimpleNamespace(id="file-id", key="file-key")

    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", lambda *_: FakeFileService()
    )
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "offload-id")
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    class FakeTruncator:
        def truncate_variable_mapping(self, value: Any):  # type: ignore[no-untyped-def]
            # Always report that truncation happened.
            return {"truncated": True}, True

    monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator())
    result = repo._truncate_and_upload({"a": 1}, "exec", ExecutionOffLoadType.INPUTS)
    assert result is not None
    assert result.truncated_value == {"truncated": True}
    # Filename embeds the execution id and offload type.
    assert uploaded["filename"].startswith("node_execution_exec_inputs.json")
    assert result.offload.file_id == "file-id"
    assert result.offload.type_ == ExecutionOffLoadType.INPUTS
def test_to_domain_model_loads_offloaded_files(monkeypatch: pytest.MonkeyPatch) -> None:
    """Offloaded payloads are re-hydrated from storage; DB columns stay as truncated views."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    db_model = WorkflowNodeExecutionModel()
    db_model.id = "id"
    db_model.node_execution_id = "node-exec"
    db_model.workflow_id = "wf"
    db_model.workflow_run_id = "run"
    db_model.index = 1
    db_model.predecessor_node_id = None
    db_model.node_id = "node"
    db_model.node_type = NodeType.LLM
    db_model.title = "t"
    # The DB columns hold only the truncated representations.
    db_model.inputs = json.dumps({"trunc": "i"})
    db_model.process_data = json.dumps({"trunc": "p"})
    db_model.outputs = json.dumps({"trunc": "o"})
    db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
    db_model.error = None
    db_model.elapsed_time = 0.1
    db_model.execution_metadata = json.dumps({"total_tokens": 3})
    db_model.created_at = datetime.now(UTC)
    db_model.finished_at = None
    # One offload record per payload kind, each pointing at a storage key.
    off_in = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
    off_out = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
    off_proc = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA)
    off_in.file = SimpleNamespace(key="k-in")
    off_out.file = SimpleNamespace(key="k-out")
    off_proc.file = SimpleNamespace(key="k-proc")
    # Deliberately out of order: lookup should be by type, not position.
    db_model.offload_data = [off_out, off_in, off_proc]

    def fake_load(key: str) -> bytes:
        # Echo the storage key back so assertions can tell which file was read.
        return json.dumps({"full": key}).encode()

    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.storage.load", fake_load)
    domain = repo._to_domain_model(db_model)
    # Full payloads come from storage; truncated accessors expose the columns.
    assert domain.inputs == {"full": "k-in"}
    assert domain.outputs == {"full": "k-out"}
    assert domain.process_data == {"full": "k-proc"}
    assert domain.get_truncated_inputs() == {"trunc": "i"}
    assert domain.get_truncated_outputs() == {"trunc": "o"}
    assert domain.get_truncated_process_data() == {"trunc": "p"}
def test_to_domain_model_returns_early_when_no_offload_data(monkeypatch: pytest.MonkeyPatch) -> None:
    """Without offload records the JSON columns are decoded directly."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    db_model = WorkflowNodeExecutionModel()
    db_model.id = "id"
    db_model.node_execution_id = "node-exec"
    db_model.workflow_id = "wf"
    db_model.workflow_run_id = "run"
    db_model.index = 1
    db_model.predecessor_node_id = None
    db_model.node_id = "node"
    db_model.node_type = NodeType.LLM
    db_model.title = "t"
    db_model.inputs = json.dumps({"i": 1})
    db_model.process_data = json.dumps({"p": 2})
    db_model.outputs = json.dumps({"o": 3})
    db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
    db_model.error = None
    db_model.elapsed_time = 0.1
    db_model.execution_metadata = "{}"
    db_model.created_at = datetime.now(UTC)
    db_model.finished_at = None
    # Empty offload list -> the early-return path (no storage access is mocked).
    db_model.offload_data = []
    domain = repo._to_domain_model(db_model)
    assert domain.inputs == {"i": 1}
    assert domain.outputs == {"o": 3}
def test_json_encode_uses_runtime_converter(monkeypatch: pytest.MonkeyPatch) -> None:
    """_json_encode routes values through WorkflowRuntimeTypeConverter first."""

    class StubConverter:
        def to_json_encodable(self, values: Mapping[str, Any]) -> Mapping[str, Any]:
            return {"wrapped": values["a"]}

    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowRuntimeTypeConverter",
        StubConverter,
    )
    encoded = SQLAlchemyWorkflowNodeExecutionRepository._json_encode({"a": 1})
    assert encoded == '{"wrapped": 1}'
def test_save_execution_data_handles_existing_db_model_and_truncation(monkeypatch: pytest.MonkeyPatch) -> None:
    """save_execution_data truncates inputs via _truncate_and_upload; other payloads are encoded directly."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    session = MagicMock()
    # The repository finds an existing row that already has an INPUTS offload.
    session.execute.return_value.scalars.return_value.first.return_value = SimpleNamespace(
        id="id",
        offload_data=[WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)],
        inputs=None,
        outputs=None,
        process_data=None,
    )
    session.merge = Mock()
    session.flush = Mock()
    # Make `with session.begin():` usable on the MagicMock.
    session.begin.return_value.__enter__ = Mock(return_value=session)
    session.begin.return_value.__exit__ = Mock(return_value=None)
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3})
    trunc_result = SimpleNamespace(
        truncated_value={"trunc": True},
        offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS, file_id="f1"),
    )
    # Only the inputs mapping triggers truncation; everything else passes through.
    monkeypatch.setattr(
        repo, "_truncate_and_upload", lambda values, *_args, **_kwargs: trunc_result if values == {"a": 1} else None
    )
    monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True))
    repo.save_execution_data(execution)
    # Inputs should be truncated, outputs/process_data encoded directly
    db_model = session.merge.call_args.args[0]
    assert json.loads(db_model.inputs) == {"trunc": True}
    assert json.loads(db_model.outputs) == {"b": 2}
    assert json.loads(db_model.process_data) == {"c": 3}
    assert any(off.type_ == ExecutionOffLoadType.INPUTS for off in db_model.offload_data)
    assert execution.get_truncated_inputs() == {"trunc": True}
def test_save_execution_data_truncates_outputs_and_process_data(monkeypatch: pytest.MonkeyPatch) -> None:
    """Outputs and process_data can each be truncated/offloaded independently of inputs."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    existing = SimpleNamespace(
        id="id",
        offload_data=[],
        inputs=None,
        outputs=None,
        process_data=None,
    )
    session = MagicMock()
    session.execute.return_value.scalars.return_value.first.return_value = existing
    session.merge = Mock()
    session.flush = Mock()
    # Make `with session.begin():` usable on the MagicMock.
    session.begin.return_value.__enter__ = Mock(return_value=session)
    session.begin.return_value.__exit__ = Mock(return_value=None)
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3})

    def trunc(values: Mapping[str, Any], *_args: Any, **_kwargs: Any) -> Any:
        # Truncate outputs and process_data; leave inputs untouched.
        if values == {"b": 2}:
            return SimpleNamespace(
                truncated_value={"b": "trunc"},
                offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS, file_id="f2"),
            )
        if values == {"c": 3}:
            return SimpleNamespace(
                truncated_value={"c": "trunc"},
                offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA, file_id="f3"),
            )
        return None

    monkeypatch.setattr(repo, "_truncate_and_upload", trunc)
    monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True))
    repo.save_execution_data(execution)
    db_model = session.merge.call_args.args[0]
    assert json.loads(db_model.outputs) == {"b": "trunc"}
    assert json.loads(db_model.process_data) == {"c": "trunc"}
    assert execution.get_truncated_outputs() == {"b": "trunc"}
    assert execution.get_truncated_process_data() == {"c": "trunc"}
def test_save_execution_data_handles_missing_db_model(monkeypatch: pytest.MonkeyPatch) -> None:
    """If no row exists yet, a fresh DB model is built via _to_db_model and merged."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    session = MagicMock()
    # No existing row for this execution.
    session.execute.return_value.scalars.return_value.first.return_value = None
    session.merge = Mock()
    session.flush = Mock()
    # Make `with session.begin():` usable on the MagicMock.
    session.begin.return_value.__enter__ = Mock(return_value=session)
    session.begin.return_value.__exit__ = Mock(return_value=None)
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    execution = _execution(inputs={"a": 1})
    fake_db_model = SimpleNamespace(id=execution.id, offload_data=[], inputs=None, outputs=None, process_data=None)
    monkeypatch.setattr(repo, "_to_db_model", lambda *_: fake_db_model)
    # No truncation: inputs are JSON-encoded straight onto the new model.
    monkeypatch.setattr(repo, "_truncate_and_upload", lambda *_args, **_kwargs: None)
    monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values))
    repo.save_execution_data(execution)
    merged = session.merge.call_args.args[0]
    assert merged.inputs == '{"a": 1}'
def test_save_retries_duplicate_and_logs_non_duplicate(
    monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
    """save() retries with a fresh id on duplicate-key errors; other integrity errors are logged and re-raised."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    execution = _execution(execution_id="id")
    unique = Mock(spec=psycopg2.errors.UniqueViolation)
    duplicate_error = IntegrityError("dup", params=None, orig=unique)
    other_error = IntegrityError("other", params=None, orig=None)
    calls = {"n": 0}

    def persist(_db_model: Any) -> None:
        # Fail with a duplicate-key error on the first attempt only.
        calls["n"] += 1
        if calls["n"] == 1:
            raise duplicate_error

    monkeypatch.setattr(repo, "_persist_to_database", persist)
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id")
    repo.save(execution)
    # The retry succeeded with the regenerated id and populated the cache.
    assert execution.id == "new-id"
    assert repo._node_execution_cache[execution.node_execution_id] is not None
    caplog.set_level(logging.ERROR)
    # A non-duplicate integrity error is logged and propagated unchanged.
    monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(other_error))
    with pytest.raises(IntegrityError):
        repo.save(_execution(execution_id="id2", node_execution_id="node2"))
    assert any("Non-duplicate key integrity error" in r.message for r in caplog.records)
def test_save_logs_and_reraises_on_unexpected_error(
    monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
    """Non-integrity failures during save() are logged and propagated unchanged."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    caplog.set_level(logging.ERROR)

    def explode(_db: Any) -> None:
        raise RuntimeError("boom")

    monkeypatch.setattr(repo, "_persist_to_database", explode)
    with pytest.raises(RuntimeError, match="boom"):
        repo.save(_execution(execution_id="id3", node_execution_id="node3"))
    assert any("Failed to save workflow node execution" in record.message for record in caplog.records)
def test_get_db_models_by_workflow_run_orders_and_caches(monkeypatch: pytest.MonkeyPatch) -> None:
    """Query results are returned in order and cached by node_execution_id (None ids skipped)."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )

    class FakeStmt:
        # Minimal chainable stand-in for a SQLAlchemy Select statement.
        def __init__(self) -> None:
            self.where_calls = 0
            self.order_by_args: tuple[Any, ...] | None = None

        def where(self, *_args: Any) -> FakeStmt:
            self.where_calls += 1
            return self

        def order_by(self, *args: Any) -> FakeStmt:
            self.order_by_args = args
            return self

    stmt = FakeStmt()
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files",
        lambda _q: stmt,
    )
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select")
    model1 = SimpleNamespace(node_execution_id="n1")
    model2 = SimpleNamespace(node_execution_id=None)
    session = MagicMock()
    session.scalars.return_value.all.return_value = [model1, model2]
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    # "missing" is not a model attribute — presumably ignored by the ordering
    # logic; only the fact that order_by was applied is asserted below.
    order = OrderConfig(order_by=["index", "missing"], order_direction="desc")
    db_models = repo.get_db_models_by_workflow_run("run", order)
    assert db_models == [model1, model2]
    # Only the model with a non-None node_execution_id lands in the cache.
    assert repo._node_execution_cache["n1"] is model1
    assert stmt.order_by_args is not None
def test_get_db_models_by_workflow_run_uses_asc_order(monkeypatch: pytest.MonkeyPatch) -> None:
    """Smoke test: the ascending order_direction code path runs without error."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )

    class FakeStmt:
        # Chainable stand-in for a SQLAlchemy Select statement.
        def where(self, *_args: Any) -> FakeStmt:
            return self

        def order_by(self, *args: Any) -> FakeStmt:
            self.args = args  # type: ignore[attr-defined]
            return self

    stmt = FakeStmt()
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files",
        lambda _q: stmt,
    )
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select")
    session = MagicMock()
    session.scalars.return_value.all.return_value = []
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    repo.get_db_models_by_workflow_run("run", OrderConfig(order_by=["index"], order_direction="asc"))
def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) -> None:
    """DB rows are mapped to domain models through a ThreadPoolExecutor with a 30s map timeout."""
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    db_models = [SimpleNamespace(id="db1"), SimpleNamespace(id="db2")]
    monkeypatch.setattr(repo, "get_db_models_by_workflow_run", lambda *_args, **_kwargs: db_models)
    monkeypatch.setattr(repo, "_to_domain_model", lambda m: f"domain:{m.id}")

    class FakeExecutor:
        # Synchronous stand-in for concurrent.futures.ThreadPoolExecutor.
        def __enter__(self) -> FakeExecutor:
            return self

        def __exit__(self, exc_type, exc, tb) -> None:
            return None

        def map(self, func, items, timeout: int):  # type: ignore[no-untyped-def]
            # The repository is expected to pass a 30-second timeout.
            assert timeout == 30
            return list(map(func, items))

    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.ThreadPoolExecutor",
        lambda max_workers: FakeExecutor(),
    )
    result = repo.get_by_workflow_run("run", order_config=None)
    assert result == ["domain:db1", "domain:db2"]

View File

@ -0,0 +1,137 @@
import json
from unittest.mock import patch
from core.schemas.registry import SchemaRegistry
class TestSchemaRegistry:
    """Unit tests for SchemaRegistry loading, lookup and listing behavior."""

    def test_initialization(self, tmp_path):
        """A fresh registry starts with empty version and metadata maps."""
        schemas_dir = tmp_path / "schemas"
        schemas_dir.mkdir()
        registry = SchemaRegistry(str(schemas_dir))
        assert registry.base_dir == schemas_dir
        assert registry.versions == {}
        assert registry.metadata == {}

    def test_default_registry_singleton(self):
        """default_registry() always hands back the same instance."""
        first = SchemaRegistry.default_registry()
        second = SchemaRegistry.default_registry()
        assert first is second
        assert isinstance(first, SchemaRegistry)

    def test_load_all_versions_non_existent_dir(self, tmp_path):
        """Loading from a missing base directory is a no-op."""
        registry = SchemaRegistry(str(tmp_path / "non_existent"))
        registry.load_all_versions()
        assert registry.versions == {}

    def test_load_all_versions_filtering(self, tmp_path):
        """Only version-named directories (e.g. v1) are handed to the loader."""
        schemas_dir = tmp_path / "schemas"
        schemas_dir.mkdir()
        (schemas_dir / "not_a_version_dir").mkdir()
        (schemas_dir / "v1").mkdir()
        (schemas_dir / "some_file.txt").write_text("content")
        registry = SchemaRegistry(str(schemas_dir))
        with patch.object(registry, "_load_version_dir") as load_mock:
            registry.load_all_versions()
        load_mock.assert_called_once()
        assert load_mock.call_args[0][0] == "v1"

    def test_load_version_dir_filtering(self, tmp_path):
        """Only *.json files in a version directory are treated as schemas."""
        version_dir = tmp_path / "v1"
        version_dir.mkdir()
        (version_dir / "schema1.json").write_text("{}")
        (version_dir / "not_a_schema.txt").write_text("content")
        registry = SchemaRegistry(str(tmp_path))
        with patch.object(registry, "_load_schema") as load_mock:
            registry._load_version_dir("v1", version_dir)
        load_mock.assert_called_once()
        assert load_mock.call_args[0][1] == "schema1"

    def test_load_version_dir_non_existent(self, tmp_path):
        """A missing version directory registers nothing."""
        registry = SchemaRegistry(str(tmp_path))
        registry._load_version_dir("v1", tmp_path / "non_existent")
        assert "v1" not in registry.versions

    def test_load_schema_success(self, tmp_path):
        """A valid schema file is stored and its metadata indexed by URI."""
        schema_path = tmp_path / "test.json"
        schema_content = {"title": "Test Schema", "description": "A test schema"}
        schema_path.write_text(json.dumps(schema_content))
        registry = SchemaRegistry(str(tmp_path))
        registry.versions["v1"] = {}
        registry._load_schema("v1", "test", schema_path)
        assert registry.versions["v1"]["test"] == schema_content
        uri = "https://dify.ai/schemas/v1/test.json"
        assert registry.metadata[uri]["title"] == "Test Schema"
        assert registry.metadata[uri]["version"] == "v1"

    def test_load_schema_invalid_json(self, tmp_path, caplog):
        """Malformed JSON is logged rather than raised."""
        schema_path = tmp_path / "invalid.json"
        schema_path.write_text("invalid json")
        registry = SchemaRegistry(str(tmp_path))
        registry.versions["v1"] = {}
        registry._load_schema("v1", "invalid", schema_path)
        assert "Failed to load schema v1/invalid" in caplog.text

    def test_load_schema_os_error(self, tmp_path, caplog):
        """I/O failures while reading a schema are logged rather than raised."""
        schema_path = tmp_path / "error.json"
        schema_path.write_text("{}")
        registry = SchemaRegistry(str(tmp_path))
        registry.versions["v1"] = {}
        with patch("builtins.open", side_effect=OSError("Read error")):
            registry._load_schema("v1", "error", schema_path)
        assert "Failed to load schema v1/error" in caplog.text

    def test_get_schema(self):
        """get_schema resolves well-formed URIs and returns None otherwise."""
        registry = SchemaRegistry("/tmp")
        registry.versions = {"v1": {"test": {"type": "object"}}}
        # Well-formed URI with a known version.
        assert registry.get_schema("https://dify.ai/schemas/v1/test.json") == {"type": "object"}
        # Malformed URI.
        assert registry.get_schema("invalid-uri") is None
        # Well-formed URI with an unknown version.
        assert registry.get_schema("https://dify.ai/schemas/v2/test.json") is None

    def test_list_versions(self):
        """Versions come back sorted."""
        registry = SchemaRegistry("/tmp")
        registry.versions = {"v2": {}, "v1": {}}
        assert registry.list_versions() == ["v1", "v2"]

    def test_list_schemas(self):
        """Schema names are sorted; unknown versions yield an empty list."""
        registry = SchemaRegistry("/tmp")
        registry.versions = {"v1": {"b": {}, "a": {}}}
        assert registry.list_schemas("v1") == ["a", "b"]
        assert registry.list_schemas("v2") == []

    def test_get_all_schemas_for_version(self):
        """Each schema is returned with its name, label (title or name) and body."""
        registry = SchemaRegistry("/tmp")
        registry.versions = {"v1": {"test": {"title": "Test Label"}}}
        results = registry.get_all_schemas_for_version("v1")
        assert len(results) == 1
        assert results[0]["name"] == "test"
        assert results[0]["label"] == "Test Label"
        assert results[0]["schema"] == {"title": "Test Label"}
        # The schema name doubles as the label when no title is present.
        registry.versions["v1"]["no_title"] = {}
        results = registry.get_all_schemas_for_version("v1")
        no_title_item = next(r for r in results if r["name"] == "no_title")
        assert no_title_item["label"] == "no_title"
        # Unknown versions produce an empty list.
        assert registry.get_all_schemas_for_version("v2") == []

View File

@ -0,0 +1,80 @@
from unittest.mock import MagicMock, patch
from core.schemas.registry import SchemaRegistry
from core.schemas.schema_manager import SchemaManager
def test_init_with_provided_registry():
    """An explicitly supplied registry is used as-is."""
    registry = MagicMock(spec=SchemaRegistry)
    manager = SchemaManager(registry=registry)
    assert manager.registry == registry
@patch("core.schemas.schema_manager.SchemaRegistry.default_registry")
def test_init_with_default_registry(mock_default_registry):
mock_registry = MagicMock(spec=SchemaRegistry)
mock_default_registry.return_value = mock_registry
manager = SchemaManager()
mock_default_registry.assert_called_once()
assert manager.registry == mock_registry
def test_get_all_schema_definitions():
    """Definitions for a version are fetched straight from the registry."""
    registry = MagicMock(spec=SchemaRegistry)
    definitions = [{"name": "schema1", "schema": {}}, {"name": "schema2", "schema": {}}]
    registry.get_all_schemas_for_version.return_value = definitions
    result = SchemaManager(registry=registry).get_all_schema_definitions(version="v2")
    registry.get_all_schemas_for_version.assert_called_once_with("v2")
    assert result == definitions
def test_get_schema_by_name_success():
    """Lookup builds the canonical URI and wraps the schema with its name."""
    registry = MagicMock(spec=SchemaRegistry)
    schema = {"type": "object"}
    registry.get_schema.return_value = schema
    result = SchemaManager(registry=registry).get_schema_by_name("my_schema", version="v1")
    registry.get_schema.assert_called_once_with("https://dify.ai/schemas/v1/my_schema.json")
    assert result == {"name": "my_schema", "schema": schema}
def test_get_schema_by_name_not_found():
    """Unknown schema names resolve to None."""
    registry = MagicMock(spec=SchemaRegistry)
    registry.get_schema.return_value = None
    assert SchemaManager(registry=registry).get_schema_by_name("non_existent", version="v1") is None
def test_list_available_schemas():
    """Schema listing is delegated to the registry for the given version."""
    registry = MagicMock(spec=SchemaRegistry)
    registry.list_schemas.return_value = ["schema1", "schema2"]
    result = SchemaManager(registry=registry).list_available_schemas(version="v1")
    registry.list_schemas.assert_called_once_with("v1")
    assert result == ["schema1", "schema2"]
def test_list_available_versions():
    """Version listing is delegated to the registry."""
    registry = MagicMock(spec=SchemaRegistry)
    registry.list_versions.return_value = ["v1", "v2"]
    result = SchemaManager(registry=registry).list_available_versions()
    registry.list_versions.assert_called_once()
    assert result == ["v1", "v2"]