mirror of https://github.com/langgenius/dify.git
test: added for core module moderation, repositories, schemas (#32514)
Co-authored-by: Rajat Agarwal <rajat.agarwal@infocusp.com>
This commit is contained in:
parent
40846c262c
commit
31506b27ab
|
|
@ -0,0 +1,181 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint
|
||||
from core.moderation.api.api import ApiModeration, ModerationInputParams, ModerationOutputParams
|
||||
from core.moderation.base import ModerationAction, ModerationInputsResult, ModerationOutputsResult
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
|
||||
|
||||
class TestApiModeration:
|
||||
@pytest.fixture
|
||||
def api_config(self):
|
||||
return {
|
||||
"inputs_config": {
|
||||
"enabled": True,
|
||||
},
|
||||
"outputs_config": {
|
||||
"enabled": True,
|
||||
},
|
||||
"api_based_extension_id": "test-extension-id",
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def api_moderation(self, api_config):
|
||||
return ApiModeration(app_id="test-app-id", tenant_id="test-tenant-id", config=api_config)
|
||||
|
||||
def test_moderation_input_params(self):
|
||||
params = ModerationInputParams(app_id="app-1", inputs={"key": "val"}, query="test query")
|
||||
assert params.app_id == "app-1"
|
||||
assert params.inputs == {"key": "val"}
|
||||
assert params.query == "test query"
|
||||
|
||||
# Test defaults
|
||||
params_default = ModerationInputParams()
|
||||
assert params_default.app_id == ""
|
||||
assert params_default.inputs == {}
|
||||
assert params_default.query == ""
|
||||
|
||||
def test_moderation_output_params(self):
|
||||
params = ModerationOutputParams(app_id="app-1", text="test text")
|
||||
assert params.app_id == "app-1"
|
||||
assert params.text == "test text"
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
ModerationOutputParams()
|
||||
|
||||
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
|
||||
def test_validate_config_success(self, mock_get_extension, api_config):
|
||||
mock_get_extension.return_value = MagicMock(spec=APIBasedExtension)
|
||||
ApiModeration.validate_config("test-tenant-id", api_config)
|
||||
mock_get_extension.assert_called_once_with("test-tenant-id", "test-extension-id")
|
||||
|
||||
def test_validate_config_missing_extension_id(self):
|
||||
config = {
|
||||
"inputs_config": {"enabled": True},
|
||||
"outputs_config": {"enabled": True},
|
||||
}
|
||||
with pytest.raises(ValueError, match="api_based_extension_id is required"):
|
||||
ApiModeration.validate_config("test-tenant-id", config)
|
||||
|
||||
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
|
||||
def test_validate_config_extension_not_found(self, mock_get_extension, api_config):
|
||||
mock_get_extension.return_value = None
|
||||
with pytest.raises(ValueError, match="API-based Extension not found"):
|
||||
ApiModeration.validate_config("test-tenant-id", api_config)
|
||||
|
||||
@patch("core.moderation.api.api.ApiModeration._get_config_by_requestor")
|
||||
def test_moderation_for_inputs_enabled(self, mock_get_config, api_moderation):
|
||||
mock_get_config.return_value = {"flagged": True, "action": "direct_output", "preset_response": "Blocked by API"}
|
||||
|
||||
result = api_moderation.moderation_for_inputs(inputs={"q": "a"}, query="hello")
|
||||
|
||||
assert isinstance(result, ModerationInputsResult)
|
||||
assert result.flagged is True
|
||||
assert result.action == ModerationAction.DIRECT_OUTPUT
|
||||
assert result.preset_response == "Blocked by API"
|
||||
|
||||
mock_get_config.assert_called_once_with(
|
||||
APIBasedExtensionPoint.APP_MODERATION_INPUT,
|
||||
{"app_id": "test-app-id", "inputs": {"q": "a"}, "query": "hello"},
|
||||
)
|
||||
|
||||
def test_moderation_for_inputs_disabled(self):
|
||||
config = {
|
||||
"inputs_config": {"enabled": False},
|
||||
"outputs_config": {"enabled": True},
|
||||
"api_based_extension_id": "ext-id",
|
||||
}
|
||||
moderation = ApiModeration("app-id", "tenant-id", config)
|
||||
result = moderation.moderation_for_inputs(inputs={}, query="")
|
||||
|
||||
assert result.flagged is False
|
||||
assert result.action == ModerationAction.DIRECT_OUTPUT
|
||||
assert result.preset_response == ""
|
||||
|
||||
def test_moderation_for_inputs_no_config(self):
|
||||
moderation = ApiModeration("app-id", "tenant-id", None)
|
||||
with pytest.raises(ValueError, match="The config is not set"):
|
||||
moderation.moderation_for_inputs({}, "")
|
||||
|
||||
@patch("core.moderation.api.api.ApiModeration._get_config_by_requestor")
|
||||
def test_moderation_for_outputs_enabled(self, mock_get_config, api_moderation):
|
||||
mock_get_config.return_value = {"flagged": False, "action": "direct_output", "preset_response": ""}
|
||||
|
||||
result = api_moderation.moderation_for_outputs(text="hello world")
|
||||
|
||||
assert isinstance(result, ModerationOutputsResult)
|
||||
assert result.flagged is False
|
||||
|
||||
mock_get_config.assert_called_once_with(
|
||||
APIBasedExtensionPoint.APP_MODERATION_OUTPUT, {"app_id": "test-app-id", "text": "hello world"}
|
||||
)
|
||||
|
||||
def test_moderation_for_outputs_disabled(self):
|
||||
config = {
|
||||
"inputs_config": {"enabled": True},
|
||||
"outputs_config": {"enabled": False},
|
||||
"api_based_extension_id": "ext-id",
|
||||
}
|
||||
moderation = ApiModeration("app-id", "tenant-id", config)
|
||||
result = moderation.moderation_for_outputs(text="test")
|
||||
|
||||
assert result.flagged is False
|
||||
assert result.action == ModerationAction.DIRECT_OUTPUT
|
||||
|
||||
def test_moderation_for_outputs_no_config(self):
|
||||
moderation = ApiModeration("app-id", "tenant-id", None)
|
||||
with pytest.raises(ValueError, match="The config is not set"):
|
||||
moderation.moderation_for_outputs("test")
|
||||
|
||||
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
|
||||
@patch("core.moderation.api.api.decrypt_token")
|
||||
@patch("core.moderation.api.api.APIBasedExtensionRequestor")
|
||||
def test_get_config_by_requestor_success(self, mock_requestor_cls, mock_decrypt, mock_get_ext, api_moderation):
|
||||
mock_ext = MagicMock(spec=APIBasedExtension)
|
||||
mock_ext.api_endpoint = "http://api.test"
|
||||
mock_ext.api_key = "encrypted-key"
|
||||
mock_get_ext.return_value = mock_ext
|
||||
|
||||
mock_decrypt.return_value = "decrypted-key"
|
||||
|
||||
mock_requestor = MagicMock()
|
||||
mock_requestor.request.return_value = {"flagged": True}
|
||||
mock_requestor_cls.return_value = mock_requestor
|
||||
|
||||
params = {"some": "params"}
|
||||
result = api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params)
|
||||
|
||||
assert result == {"flagged": True}
|
||||
mock_get_ext.assert_called_once_with("test-tenant-id", "test-extension-id")
|
||||
mock_decrypt.assert_called_once_with("test-tenant-id", "encrypted-key")
|
||||
mock_requestor_cls.assert_called_once_with("http://api.test", "decrypted-key")
|
||||
mock_requestor.request.assert_called_once_with(APIBasedExtensionPoint.APP_MODERATION_INPUT, params)
|
||||
|
||||
def test_get_config_by_requestor_no_config(self):
|
||||
moderation = ApiModeration("app-id", "tenant-id", None)
|
||||
with pytest.raises(ValueError, match="The config is not set"):
|
||||
moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {})
|
||||
|
||||
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
|
||||
def test_get_config_by_requestor_extension_not_found(self, mock_get_ext, api_moderation):
|
||||
mock_get_ext.return_value = None
|
||||
with pytest.raises(ValueError, match="API-based Extension not found"):
|
||||
api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {})
|
||||
|
||||
@patch("core.moderation.api.api.db.session.scalar")
|
||||
def test_get_api_based_extension(self, mock_scalar):
|
||||
mock_ext = MagicMock(spec=APIBasedExtension)
|
||||
mock_scalar.return_value = mock_ext
|
||||
|
||||
result = ApiModeration._get_api_based_extension("tenant-1", "ext-1")
|
||||
|
||||
assert result == mock_ext
|
||||
mock_scalar.assert_called_once()
|
||||
# Verify the call has the correct filters
|
||||
args, kwargs = mock_scalar.call_args
|
||||
stmt = args[0]
|
||||
# We can't easily inspect the statement without complex sqlalchemy tricks,
|
||||
# but calling it is usually enough for unit tests if we mock the result.
|
||||
|
|
@ -0,0 +1,207 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import AppConfig, SensitiveWordAvoidanceEntity
|
||||
from core.moderation.base import ModerationAction, ModerationError, ModerationInputsResult
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
|
||||
|
||||
class TestInputModeration:
|
||||
@pytest.fixture
|
||||
def app_config(self):
|
||||
config = MagicMock(spec=AppConfig)
|
||||
config.sensitive_word_avoidance = None
|
||||
return config
|
||||
|
||||
@pytest.fixture
|
||||
def input_moderation(self):
|
||||
return InputModeration()
|
||||
|
||||
def test_check_no_sensitive_word_avoidance(self, app_config, input_moderation):
|
||||
app_id = "test_app_id"
|
||||
tenant_id = "test_tenant_id"
|
||||
inputs = {"input_key": "input_value"}
|
||||
query = "test query"
|
||||
message_id = "test_message_id"
|
||||
|
||||
flagged, final_inputs, final_query = input_moderation.check(
|
||||
app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
|
||||
)
|
||||
|
||||
assert flagged is False
|
||||
assert final_inputs == inputs
|
||||
assert final_query == query
|
||||
|
||||
@patch("core.moderation.input_moderation.ModerationFactory")
|
||||
def test_check_not_flagged(self, mock_factory_cls, app_config, input_moderation):
|
||||
app_id = "test_app_id"
|
||||
tenant_id = "test_tenant_id"
|
||||
inputs = {"input_key": "input_value"}
|
||||
query = "test query"
|
||||
message_id = "test_message_id"
|
||||
|
||||
# Setup config
|
||||
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
|
||||
sensitive_word_config.type = "keywords"
|
||||
sensitive_word_config.config = {"keywords": ["bad"]}
|
||||
app_config.sensitive_word_avoidance = sensitive_word_config
|
||||
|
||||
# Setup factory mock
|
||||
mock_factory = mock_factory_cls.return_value
|
||||
mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
|
||||
mock_factory.moderation_for_inputs.return_value = mock_result
|
||||
|
||||
flagged, final_inputs, final_query = input_moderation.check(
|
||||
app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
|
||||
)
|
||||
|
||||
assert flagged is False
|
||||
assert final_inputs == inputs
|
||||
assert final_query == query
|
||||
mock_factory_cls.assert_called_once_with(
|
||||
name="keywords", app_id=app_id, tenant_id=tenant_id, config={"keywords": ["bad"]}
|
||||
)
|
||||
mock_factory.moderation_for_inputs.assert_called_once_with(dict(inputs), query)
|
||||
|
||||
@patch("core.moderation.input_moderation.ModerationFactory")
|
||||
@patch("core.moderation.input_moderation.TraceTask")
|
||||
def test_check_with_trace_manager(self, mock_trace_task, mock_factory_cls, app_config, input_moderation):
|
||||
app_id = "test_app_id"
|
||||
tenant_id = "test_tenant_id"
|
||||
inputs = {"input_key": "input_value"}
|
||||
query = "test query"
|
||||
message_id = "test_message_id"
|
||||
trace_manager = MagicMock(spec=TraceQueueManager)
|
||||
|
||||
# Setup config
|
||||
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
|
||||
sensitive_word_config.type = "keywords"
|
||||
sensitive_word_config.config = {}
|
||||
app_config.sensitive_word_avoidance = sensitive_word_config
|
||||
|
||||
# Setup factory mock
|
||||
mock_factory = mock_factory_cls.return_value
|
||||
mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
|
||||
mock_factory.moderation_for_inputs.return_value = mock_result
|
||||
|
||||
input_moderation.check(
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
app_config=app_config,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=message_id,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
trace_manager.add_trace_task.assert_called_once_with(mock_trace_task.return_value)
|
||||
mock_trace_task.assert_called_once()
|
||||
call_kwargs = mock_trace_task.call_args.kwargs
|
||||
call_args = mock_trace_task.call_args.args
|
||||
assert call_args[0] == TraceTaskName.MODERATION_TRACE
|
||||
assert call_kwargs["message_id"] == message_id
|
||||
assert call_kwargs["moderation_result"] == mock_result
|
||||
assert call_kwargs["inputs"] == inputs
|
||||
assert "timer" in call_kwargs
|
||||
|
||||
@patch("core.moderation.input_moderation.ModerationFactory")
|
||||
def test_check_flagged_direct_output(self, mock_factory_cls, app_config, input_moderation):
|
||||
app_id = "test_app_id"
|
||||
tenant_id = "test_tenant_id"
|
||||
inputs = {"input_key": "input_value"}
|
||||
query = "test query"
|
||||
message_id = "test_message_id"
|
||||
|
||||
# Setup config
|
||||
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
|
||||
sensitive_word_config.type = "keywords"
|
||||
sensitive_word_config.config = {}
|
||||
app_config.sensitive_word_avoidance = sensitive_word_config
|
||||
|
||||
# Setup factory mock
|
||||
mock_factory = mock_factory_cls.return_value
|
||||
mock_result = ModerationInputsResult(
|
||||
flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="Blocked content"
|
||||
)
|
||||
mock_factory.moderation_for_inputs.return_value = mock_result
|
||||
|
||||
with pytest.raises(ModerationError) as excinfo:
|
||||
input_moderation.check(
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
app_config=app_config,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
assert str(excinfo.value) == "Blocked content"
|
||||
|
||||
@patch("core.moderation.input_moderation.ModerationFactory")
|
||||
def test_check_flagged_overridden(self, mock_factory_cls, app_config, input_moderation):
|
||||
app_id = "test_app_id"
|
||||
tenant_id = "test_tenant_id"
|
||||
inputs = {"input_key": "input_value"}
|
||||
query = "test query"
|
||||
message_id = "test_message_id"
|
||||
|
||||
# Setup config
|
||||
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
|
||||
sensitive_word_config.type = "keywords"
|
||||
sensitive_word_config.config = {}
|
||||
app_config.sensitive_word_avoidance = sensitive_word_config
|
||||
|
||||
# Setup factory mock
|
||||
mock_factory = mock_factory_cls.return_value
|
||||
mock_result = ModerationInputsResult(
|
||||
flagged=True,
|
||||
action=ModerationAction.OVERRIDDEN,
|
||||
inputs={"input_key": "overridden_value"},
|
||||
query="overridden query",
|
||||
)
|
||||
mock_factory.moderation_for_inputs.return_value = mock_result
|
||||
|
||||
flagged, final_inputs, final_query = input_moderation.check(
|
||||
app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
|
||||
)
|
||||
|
||||
assert flagged is True
|
||||
assert final_inputs == {"input_key": "overridden_value"}
|
||||
assert final_query == "overridden query"
|
||||
|
||||
@patch("core.moderation.input_moderation.ModerationFactory")
|
||||
def test_check_flagged_other_action(self, mock_factory_cls, app_config, input_moderation):
|
||||
app_id = "test_app_id"
|
||||
tenant_id = "test_tenant_id"
|
||||
inputs = {"input_key": "input_value"}
|
||||
query = "test query"
|
||||
message_id = "test_message_id"
|
||||
|
||||
# Setup config
|
||||
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
|
||||
sensitive_word_config.type = "keywords"
|
||||
sensitive_word_config.config = {}
|
||||
app_config.sensitive_word_avoidance = sensitive_word_config
|
||||
|
||||
# Setup factory mock
|
||||
mock_factory = mock_factory_cls.return_value
|
||||
mock_result = MagicMock()
|
||||
mock_result.flagged = True
|
||||
mock_result.action = "NONE" # Some other action
|
||||
mock_factory.moderation_for_inputs.return_value = mock_result
|
||||
|
||||
flagged, final_inputs, final_query = input_moderation.check(
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
app_config=app_config,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
assert flagged is True
|
||||
assert final_inputs == inputs
|
||||
assert final_query == query
|
||||
|
|
@ -0,0 +1,234 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import QueueMessageReplaceEvent
|
||||
from core.moderation.base import ModerationAction, ModerationOutputsResult
|
||||
from core.moderation.output_moderation import ModerationRule, OutputModeration
|
||||
|
||||
|
||||
class TestOutputModeration:
|
||||
@pytest.fixture
|
||||
def mock_queue_manager(self):
|
||||
return MagicMock(spec=AppQueueManager)
|
||||
|
||||
@pytest.fixture
|
||||
def moderation_rule(self):
|
||||
return ModerationRule(type="keywords", config={"keywords": "badword"})
|
||||
|
||||
@pytest.fixture
|
||||
def output_moderation(self, mock_queue_manager, moderation_rule):
|
||||
return OutputModeration(
|
||||
tenant_id="test_tenant", app_id="test_app", rule=moderation_rule, queue_manager=mock_queue_manager
|
||||
)
|
||||
|
||||
def test_should_direct_output(self, output_moderation):
|
||||
assert output_moderation.should_direct_output() is False
|
||||
output_moderation.final_output = "blocked"
|
||||
assert output_moderation.should_direct_output() is True
|
||||
|
||||
def test_get_final_output(self, output_moderation):
|
||||
assert output_moderation.get_final_output() == ""
|
||||
output_moderation.final_output = "blocked"
|
||||
assert output_moderation.get_final_output() == "blocked"
|
||||
|
||||
def test_append_new_token(self, output_moderation):
|
||||
with patch.object(OutputModeration, "start_thread") as mock_start:
|
||||
output_moderation.append_new_token("hello")
|
||||
assert output_moderation.buffer == "hello"
|
||||
mock_start.assert_called_once()
|
||||
|
||||
output_moderation.thread = MagicMock()
|
||||
output_moderation.append_new_token(" world")
|
||||
assert output_moderation.buffer == "hello world"
|
||||
assert mock_start.call_count == 1
|
||||
|
||||
def test_moderation_completion_no_flag(self, output_moderation):
|
||||
with patch.object(OutputModeration, "moderation") as mock_moderation:
|
||||
mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
|
||||
|
||||
output, flagged = output_moderation.moderation_completion("safe content")
|
||||
|
||||
assert output == "safe content"
|
||||
assert flagged is False
|
||||
assert output_moderation.is_final_chunk is True
|
||||
|
||||
def test_moderation_completion_flagged_direct_output(self, output_moderation, mock_queue_manager):
|
||||
with patch.object(OutputModeration, "moderation") as mock_moderation:
|
||||
mock_moderation.return_value = ModerationOutputsResult(
|
||||
flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset"
|
||||
)
|
||||
|
||||
output, flagged = output_moderation.moderation_completion("badword content", public_event=True)
|
||||
|
||||
assert output == "preset"
|
||||
assert flagged is True
|
||||
mock_queue_manager.publish.assert_called_once()
|
||||
args, _ = mock_queue_manager.publish.call_args
|
||||
assert isinstance(args[0], QueueMessageReplaceEvent)
|
||||
assert args[0].text == "preset"
|
||||
assert args[1] == PublishFrom.TASK_PIPELINE
|
||||
|
||||
def test_moderation_completion_flagged_overridden(self, output_moderation, mock_queue_manager):
|
||||
with patch.object(OutputModeration, "moderation") as mock_moderation:
|
||||
mock_moderation.return_value = ModerationOutputsResult(
|
||||
flagged=True, action=ModerationAction.OVERRIDDEN, text="masked content"
|
||||
)
|
||||
|
||||
output, flagged = output_moderation.moderation_completion("badword content", public_event=True)
|
||||
|
||||
assert output == "masked content"
|
||||
assert flagged is True
|
||||
mock_queue_manager.publish.assert_called_once()
|
||||
args, _ = mock_queue_manager.publish.call_args
|
||||
assert args[0].text == "masked content"
|
||||
|
||||
def test_start_thread(self, output_moderation):
|
||||
mock_app = MagicMock(spec=Flask)
|
||||
with patch("core.moderation.output_moderation.current_app") as mock_current_app:
|
||||
mock_current_app._get_current_object.return_value = mock_app
|
||||
with patch("threading.Thread") as mock_thread_class:
|
||||
mock_thread_instance = MagicMock()
|
||||
mock_thread_class.return_value = mock_thread_instance
|
||||
|
||||
thread = output_moderation.start_thread()
|
||||
|
||||
assert thread == mock_thread_instance
|
||||
mock_thread_class.assert_called_once()
|
||||
mock_thread_instance.start.assert_called_once()
|
||||
|
||||
def test_stop_thread(self, output_moderation):
|
||||
mock_thread = MagicMock()
|
||||
mock_thread.is_alive.return_value = True
|
||||
output_moderation.thread = mock_thread
|
||||
|
||||
output_moderation.stop_thread()
|
||||
assert output_moderation.thread_running is False
|
||||
|
||||
output_moderation.thread_running = True
|
||||
mock_thread.is_alive.return_value = False
|
||||
output_moderation.stop_thread()
|
||||
assert output_moderation.thread_running is True
|
||||
|
||||
@patch("core.moderation.output_moderation.ModerationFactory")
|
||||
def test_moderation_success(self, mock_factory_class, output_moderation):
|
||||
mock_factory = mock_factory_class.return_value
|
||||
mock_result = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
|
||||
mock_factory.moderation_for_outputs.return_value = mock_result
|
||||
|
||||
result = output_moderation.moderation("tenant", "app", "buffer")
|
||||
|
||||
assert result == mock_result
|
||||
mock_factory_class.assert_called_once_with(
|
||||
name="keywords", app_id="app", tenant_id="tenant", config={"keywords": "badword"}
|
||||
)
|
||||
|
||||
@patch("core.moderation.output_moderation.ModerationFactory")
|
||||
def test_moderation_exception(self, mock_factory_class, output_moderation):
|
||||
mock_factory_class.side_effect = Exception("error")
|
||||
|
||||
result = output_moderation.moderation("tenant", "app", "buffer")
|
||||
assert result is None
|
||||
|
||||
def test_worker_loop_and_exit(self, output_moderation, mock_queue_manager):
|
||||
mock_app = MagicMock(spec=Flask)
|
||||
|
||||
# Test exit on thread_running=False
|
||||
output_moderation.thread_running = False
|
||||
output_moderation.worker(mock_app, 10)
|
||||
# Should exit immediately
|
||||
|
||||
def test_worker_no_flag(self, output_moderation):
|
||||
mock_app = MagicMock(spec=Flask)
|
||||
|
||||
with patch.object(OutputModeration, "moderation") as mock_moderation:
|
||||
mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
|
||||
|
||||
output_moderation.buffer = "safe"
|
||||
output_moderation.is_final_chunk = True
|
||||
|
||||
# To avoid infinite loop, we'll set thread_running to False after one iteration
|
||||
def side_effect(*args, **kwargs):
|
||||
output_moderation.thread_running = False
|
||||
return mock_moderation.return_value
|
||||
|
||||
mock_moderation.side_effect = side_effect
|
||||
|
||||
output_moderation.worker(mock_app, 10)
|
||||
|
||||
assert mock_moderation.called
|
||||
|
||||
def test_worker_flagged_direct_output(self, output_moderation, mock_queue_manager):
|
||||
mock_app = MagicMock(spec=Flask)
|
||||
|
||||
with patch.object(OutputModeration, "moderation") as mock_moderation:
|
||||
mock_moderation.return_value = ModerationOutputsResult(
|
||||
flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset"
|
||||
)
|
||||
|
||||
output_moderation.buffer = "badword"
|
||||
output_moderation.is_final_chunk = True
|
||||
|
||||
output_moderation.worker(mock_app, 10)
|
||||
|
||||
assert output_moderation.final_output == "preset"
|
||||
mock_queue_manager.publish.assert_called_once()
|
||||
# It breaks on DIRECT_OUTPUT
|
||||
|
||||
def test_worker_flagged_overridden(self, output_moderation, mock_queue_manager):
|
||||
mock_app = MagicMock(spec=Flask)
|
||||
|
||||
with patch.object(OutputModeration, "moderation") as mock_moderation:
|
||||
# Use side_effect to change thread_running on second call
|
||||
def side_effect(*args, **kwargs):
|
||||
if mock_moderation.call_count > 1:
|
||||
output_moderation.thread_running = False
|
||||
return None
|
||||
return ModerationOutputsResult(flagged=True, action=ModerationAction.OVERRIDDEN, text="masked")
|
||||
|
||||
mock_moderation.side_effect = side_effect
|
||||
|
||||
output_moderation.buffer = "badword"
|
||||
output_moderation.is_final_chunk = True
|
||||
|
||||
output_moderation.worker(mock_app, 10)
|
||||
|
||||
mock_queue_manager.publish.assert_called_once()
|
||||
args, _ = mock_queue_manager.publish.call_args
|
||||
assert args[0].text == "masked"
|
||||
|
||||
def test_worker_chunk_too_small(self, output_moderation):
|
||||
mock_app = MagicMock(spec=Flask)
|
||||
with patch("time.sleep") as mock_sleep:
|
||||
# chunk_length < buffer_size and not is_final_chunk
|
||||
output_moderation.buffer = "123" # length 3
|
||||
output_moderation.is_final_chunk = False
|
||||
|
||||
def sleep_side_effect(seconds):
|
||||
output_moderation.thread_running = False
|
||||
|
||||
mock_sleep.side_effect = sleep_side_effect
|
||||
|
||||
output_moderation.worker(mock_app, 10) # buffer_size 10
|
||||
|
||||
mock_sleep.assert_called_once_with(1)
|
||||
|
||||
def test_worker_empty_not_flagged(self, output_moderation, mock_queue_manager):
|
||||
mock_app = MagicMock(spec=Flask)
|
||||
with patch.object(OutputModeration, "moderation") as mock_moderation:
|
||||
# Return None (exception or no rule)
|
||||
mock_moderation.return_value = None
|
||||
|
||||
def side_effect(*args, **kwargs):
|
||||
output_moderation.thread_running = False
|
||||
|
||||
mock_moderation.side_effect = side_effect
|
||||
|
||||
output_moderation.buffer = "something"
|
||||
output_moderation.is_final_chunk = True
|
||||
|
||||
output_moderation.worker(mock_app, 10)
|
||||
|
||||
mock_queue_manager.publish.assert_not_called()
|
||||
|
|
@ -0,0 +1,677 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.repositories.human_input_repository import (
|
||||
HumanInputFormRecord,
|
||||
HumanInputFormRepositoryImpl,
|
||||
HumanInputFormSubmissionRepository,
|
||||
_HumanInputFormEntityImpl,
|
||||
_HumanInputFormRecipientEntityImpl,
|
||||
_InvalidTimeoutStatusError,
|
||||
_WorkspaceMemberInfo,
|
||||
)
|
||||
from dify_graph.nodes.human_input.entities import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
HumanInputNodeData,
|
||||
MemberRecipient,
|
||||
UserAction,
|
||||
WebAppDeliveryMethod,
|
||||
)
|
||||
from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
from dify_graph.repositories.human_input_form_repository import FormCreateParams, FormNotFoundError
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.human_input import HumanInputFormRecipient, RecipientType
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _stub_select(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class _FakeSelect:
|
||||
def join(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
|
||||
return self
|
||||
|
||||
def where(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
|
||||
return self
|
||||
|
||||
def options(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
|
||||
return self
|
||||
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.select", lambda *_args, **_kwargs: _FakeSelect())
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.selectinload", lambda *_args, **_kwargs: "_loader")
|
||||
|
||||
|
||||
def _make_form_definition_json(*, include_expiration_time: bool) -> str:
|
||||
payload: dict[str, Any] = {
|
||||
"form_content": "hi",
|
||||
"inputs": [],
|
||||
"user_actions": [{"id": "submit", "title": "Submit"}],
|
||||
"rendered_content": "<p>hi</p>",
|
||||
}
|
||||
if include_expiration_time:
|
||||
payload["expiration_time"] = naive_utc_now()
|
||||
return json.dumps(payload, default=str)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DummyForm:
|
||||
id: str
|
||||
workflow_run_id: str | None
|
||||
node_id: str
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
form_definition: str
|
||||
rendered_content: str
|
||||
expiration_time: datetime
|
||||
form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
|
||||
created_at: datetime = dataclasses.field(default_factory=naive_utc_now)
|
||||
selected_action_id: str | None = None
|
||||
submitted_data: str | None = None
|
||||
submitted_at: datetime | None = None
|
||||
submission_user_id: str | None = None
|
||||
submission_end_user_id: str | None = None
|
||||
completed_by_recipient_id: str | None = None
|
||||
status: HumanInputFormStatus = HumanInputFormStatus.WAITING
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DummyRecipient:
|
||||
id: str
|
||||
form_id: str
|
||||
recipient_type: RecipientType
|
||||
access_token: str | None
|
||||
|
||||
|
||||
class _FakeScalarResult:
|
||||
def __init__(self, obj: Any):
|
||||
self._obj = obj
|
||||
|
||||
def first(self) -> Any:
|
||||
if isinstance(self._obj, list):
|
||||
return self._obj[0] if self._obj else None
|
||||
return self._obj
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
if self._obj is None:
|
||||
return []
|
||||
if isinstance(self._obj, list):
|
||||
return list(self._obj)
|
||||
return [self._obj]
|
||||
|
||||
|
||||
class _FakeExecuteResult:
|
||||
def __init__(self, rows: Sequence[tuple[Any, ...]]):
|
||||
self._rows = list(rows)
|
||||
|
||||
def all(self) -> list[tuple[Any, ...]]:
|
||||
return list(self._rows)
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
scalars_result: Any = None,
|
||||
scalars_results: list[Any] | None = None,
|
||||
forms: dict[str, _DummyForm] | None = None,
|
||||
recipients: dict[str, _DummyRecipient] | None = None,
|
||||
execute_rows: Sequence[tuple[Any, ...]] = (),
|
||||
):
|
||||
if scalars_results is not None:
|
||||
self._scalars_queue = list(scalars_results)
|
||||
else:
|
||||
self._scalars_queue = [scalars_result]
|
||||
self._forms = forms or {}
|
||||
self._recipients = recipients or {}
|
||||
self._execute_rows = list(execute_rows)
|
||||
self.added: list[Any] = []
|
||||
|
||||
def scalars(self, _query: Any) -> _FakeScalarResult:
|
||||
if self._scalars_queue:
|
||||
value = self._scalars_queue.pop(0)
|
||||
else:
|
||||
value = None
|
||||
return _FakeScalarResult(value)
|
||||
|
||||
def execute(self, _stmt: Any) -> _FakeExecuteResult:
|
||||
return _FakeExecuteResult(self._execute_rows)
|
||||
|
||||
def get(self, model_cls: Any, obj_id: str) -> Any:
|
||||
name = getattr(model_cls, "__name__", "")
|
||||
if name == "HumanInputForm":
|
||||
return self._forms.get(obj_id)
|
||||
if name == "HumanInputFormRecipient":
|
||||
return self._recipients.get(obj_id)
|
||||
return None
|
||||
|
||||
def add(self, obj: Any) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
def add_all(self, objs: Sequence[Any]) -> None:
|
||||
self.added.extend(list(objs))
|
||||
|
||||
def flush(self) -> None:
|
||||
# Simulate DB default population for attributes referenced in entity wrappers.
|
||||
for obj in self.added:
|
||||
if hasattr(obj, "id") and obj.id in (None, ""):
|
||||
obj.id = f"gen-{len(str(self.added))}"
|
||||
if isinstance(obj, HumanInputFormRecipient) and obj.access_token is None:
|
||||
if obj.recipient_type == RecipientType.CONSOLE:
|
||||
obj.access_token = "token-console"
|
||||
elif obj.recipient_type == RecipientType.BACKSTAGE:
|
||||
obj.access_token = "token-backstage"
|
||||
else:
|
||||
obj.access_token = "token-webapp"
|
||||
|
||||
def refresh(self, _obj: Any) -> None:
|
||||
return None
|
||||
|
||||
def begin(self) -> _FakeSession:
|
||||
return self
|
||||
|
||||
def __enter__(self) -> _FakeSession:
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class _SessionFactoryStub:
|
||||
def __init__(self, session: _FakeSession):
|
||||
self._session = session
|
||||
|
||||
def create_session(self) -> _FakeSession:
|
||||
return self._session
|
||||
|
||||
|
||||
def _patch_session_factory(monkeypatch: pytest.MonkeyPatch, session: _FakeSession) -> None:
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.session_factory", _SessionFactoryStub(session))
|
||||
|
||||
|
||||
def test_recipient_entity_token_raises_when_missing() -> None:
|
||||
recipient = SimpleNamespace(id="r1", access_token=None)
|
||||
entity = _HumanInputFormRecipientEntityImpl(recipient) # type: ignore[arg-type]
|
||||
with pytest.raises(AssertionError, match="access_token should not be None"):
|
||||
_ = entity.token
|
||||
|
||||
|
||||
def test_recipient_entity_id_and_token_success() -> None:
|
||||
recipient = SimpleNamespace(id="r1", access_token="tok")
|
||||
entity = _HumanInputFormRecipientEntityImpl(recipient) # type: ignore[arg-type]
|
||||
assert entity.id == "r1"
|
||||
assert entity.token == "tok"
|
||||
|
||||
|
||||
def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> None:
|
||||
form = _DummyForm(
|
||||
id="f1",
|
||||
workflow_run_id="run",
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
console = _DummyRecipient(id="c1", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="ctok")
|
||||
webapp = _DummyRecipient(
|
||||
id="w1", form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP, access_token="wtok"
|
||||
)
|
||||
|
||||
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp, console]) # type: ignore[arg-type]
|
||||
assert entity.web_app_token == "ctok"
|
||||
|
||||
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp]) # type: ignore[arg-type]
|
||||
assert entity.web_app_token == "wtok"
|
||||
|
||||
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type]
|
||||
assert entity.web_app_token is None
|
||||
|
||||
|
||||
def test_form_entity_submitted_data_parsed() -> None:
|
||||
form = _DummyForm(
|
||||
id="f1",
|
||||
workflow_run_id="run",
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
submitted_data='{"a": 1}',
|
||||
submitted_at=naive_utc_now(),
|
||||
)
|
||||
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type]
|
||||
assert entity.submitted is True
|
||||
assert entity.submitted_data == {"a": 1}
|
||||
assert entity.rendered_content == "<p>x</p>"
|
||||
assert entity.selected_action_id is None
|
||||
assert entity.status == HumanInputFormStatus.WAITING
|
||||
|
||||
|
||||
def test_form_record_from_models_injects_expiration_time_when_missing() -> None:
|
||||
expiration = naive_utc_now()
|
||||
form = _DummyForm(
|
||||
id="f1",
|
||||
workflow_run_id=None,
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=False),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=expiration,
|
||||
submitted_data='{"k": "v"}',
|
||||
)
|
||||
record = HumanInputFormRecord.from_models(form, None) # type: ignore[arg-type]
|
||||
assert record.definition.expiration_time == expiration
|
||||
assert record.submitted_data == {"k": "v"}
|
||||
assert record.submitted is False
|
||||
|
||||
|
||||
def test_create_email_recipients_from_resolved_dedupes_and_skips_blank(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
created: list[SimpleNamespace] = []
|
||||
|
||||
def fake_new(cls, form_id: str, delivery_id: str, payload: Any): # type: ignore[no-untyped-def]
|
||||
recipient = SimpleNamespace(
|
||||
id=f"{payload.TYPE}-{len(created)}",
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
recipient_type=payload.TYPE,
|
||||
recipient_payload=payload.model_dump_json(),
|
||||
access_token="tok",
|
||||
)
|
||||
created.append(recipient)
|
||||
return recipient
|
||||
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.HumanInputFormRecipient.new", classmethod(fake_new))
|
||||
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
recipients = repo._create_email_recipients_from_resolved( # type: ignore[attr-defined]
|
||||
form_id="f",
|
||||
delivery_id="d",
|
||||
members=[
|
||||
_WorkspaceMemberInfo(user_id="u1", email=""),
|
||||
_WorkspaceMemberInfo(user_id="u2", email="a@example.com"),
|
||||
_WorkspaceMemberInfo(user_id="u3", email="a@example.com"),
|
||||
],
|
||||
external_emails=["", "a@example.com", "b@example.com", "b@example.com"],
|
||||
)
|
||||
assert [r.recipient_type for r in recipients] == [RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL]
|
||||
|
||||
|
||||
def test_query_workspace_members_by_ids_empty_returns_empty() -> None:
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
assert repo._query_workspace_members_by_ids(session=MagicMock(), restrict_to_user_ids=["", ""]) == []
|
||||
|
||||
|
||||
def test_query_workspace_members_by_ids_maps_rows() -> None:
|
||||
session = _FakeSession(execute_rows=[("u1", "a@example.com"), ("u2", "b@example.com")])
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
rows = repo._query_workspace_members_by_ids(session=session, restrict_to_user_ids=["u1", "u2"])
|
||||
assert rows == [
|
||||
_WorkspaceMemberInfo(user_id="u1", email="a@example.com"),
|
||||
_WorkspaceMemberInfo(user_id="u2", email="b@example.com"),
|
||||
]
|
||||
|
||||
|
||||
def test_query_all_workspace_members_maps_rows() -> None:
|
||||
session = _FakeSession(execute_rows=[("u1", "a@example.com")])
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
rows = repo._query_all_workspace_members(session=session)
|
||||
assert rows == [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")]
|
||||
|
||||
|
||||
def test_repository_init_sets_tenant_id() -> None:
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
assert repo._tenant_id == "tenant"
|
||||
|
||||
|
||||
def test_delivery_method_to_model_webapp_creates_delivery_and_recipient(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1")
|
||||
result = repo._delivery_method_to_model(
|
||||
session=MagicMock(), form_id="form-1", delivery_method=WebAppDeliveryMethod()
|
||||
)
|
||||
assert result.delivery.id == "del-1"
|
||||
assert result.delivery.form_id == "form-1"
|
||||
assert len(result.recipients) == 1
|
||||
assert result.recipients[0].recipient_type == RecipientType.STANDALONE_WEB_APP
|
||||
|
||||
|
||||
def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1")
|
||||
called: dict[str, Any] = {}
|
||||
|
||||
def fake_build(*, session: Any, form_id: str, delivery_id: str, recipients_config: Any) -> list[Any]:
|
||||
called.update(
|
||||
{"session": session, "form_id": form_id, "delivery_id": delivery_id, "recipients_config": recipients_config}
|
||||
)
|
||||
return ["r"]
|
||||
|
||||
monkeypatch.setattr(repo, "_build_email_recipients", fake_build)
|
||||
|
||||
method = EmailDeliveryMethod(
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
|
||||
),
|
||||
subject="s",
|
||||
body="b",
|
||||
)
|
||||
)
|
||||
result = repo._delivery_method_to_model(session="sess", form_id="form-1", delivery_method=method)
|
||||
assert result.recipients == ["r"]
|
||||
assert called["delivery_id"] == "del-1"
|
||||
|
||||
|
||||
def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
monkeypatch.setattr(
|
||||
repo,
|
||||
"_query_all_workspace_members",
|
||||
lambda *, session: [_WorkspaceMemberInfo(user_id="u", email="a@example.com")],
|
||||
)
|
||||
monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"])
|
||||
recipients = repo._build_email_recipients(
|
||||
session=MagicMock(),
|
||||
form_id="f",
|
||||
delivery_id="d",
|
||||
recipients_config=EmailRecipients(whole_workspace=True, items=[ExternalRecipient(email="e@example.com")]),
|
||||
)
|
||||
assert recipients == ["ok"]
|
||||
|
||||
|
||||
def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
|
||||
def fake_query(*, session: Any, restrict_to_user_ids: Sequence[str]) -> list[_WorkspaceMemberInfo]:
|
||||
assert restrict_to_user_ids == ["u1"]
|
||||
return [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")]
|
||||
|
||||
monkeypatch.setattr(repo, "_query_workspace_members_by_ids", fake_query)
|
||||
monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"])
|
||||
recipients = repo._build_email_recipients(
|
||||
session=MagicMock(),
|
||||
form_id="f",
|
||||
delivery_id="d",
|
||||
recipients_config=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
|
||||
),
|
||||
)
|
||||
assert recipients == ["ok"]
|
||||
|
||||
|
||||
def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None]))
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
assert repo.get_form("run", "node") is None
|
||||
|
||||
form = _DummyForm(
|
||||
id="f1",
|
||||
workflow_run_id="run",
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
recipient = _DummyRecipient(
|
||||
id="r1",
|
||||
form_id=form.id,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
access_token="tok",
|
||||
)
|
||||
session = _FakeSession(scalars_results=[form, [recipient]])
|
||||
_patch_session_factory(monkeypatch, session)
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
entity = repo.get_form("run", "node")
|
||||
assert entity is not None
|
||||
assert entity.id == "f1"
|
||||
assert entity.recipients[0].id == "r1"
|
||||
assert entity.recipients[0].token == "tok"
|
||||
|
||||
|
||||
def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
fixed_now = datetime(2024, 1, 1, 0, 0, 0)
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)
|
||||
|
||||
ids = iter(["form-id", "del-web", "del-console", "del-backstage"])
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: next(ids))
|
||||
|
||||
session = _FakeSession()
|
||||
_patch_session_factory(monkeypatch, session)
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
|
||||
form_config = HumanInputNodeData(
|
||||
title="Title",
|
||||
delivery_methods=[],
|
||||
form_content="hello",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="submit", title="Submit")],
|
||||
)
|
||||
params = FormCreateParams(
|
||||
app_id="app",
|
||||
workflow_execution_id="run",
|
||||
node_id="node",
|
||||
form_config=form_config,
|
||||
rendered_content="<p>hello</p>",
|
||||
delivery_methods=[WebAppDeliveryMethod()],
|
||||
display_in_ui=True,
|
||||
resolved_default_values={},
|
||||
form_kind=HumanInputFormKind.RUNTIME,
|
||||
console_recipient_required=True,
|
||||
console_creator_account_id="acc-1",
|
||||
backstage_recipient_required=True,
|
||||
)
|
||||
|
||||
entity = repo.create_form(params)
|
||||
assert entity.id == "form-id"
|
||||
assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout)
|
||||
# Console token should take precedence when console recipient is present.
|
||||
assert entity.web_app_token == "token-console"
|
||||
assert len(entity.recipients) == 3
|
||||
|
||||
|
||||
def test_submission_get_by_token_returns_none_when_missing_or_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_patch_session_factory(monkeypatch, _FakeSession(scalars_result=None))
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
assert repo.get_by_token("tok") is None
|
||||
|
||||
recipient = SimpleNamespace(form=None)
|
||||
_patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
assert repo.get_by_token("tok") is None
|
||||
|
||||
|
||||
def test_submission_repository_init_no_args() -> None:
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
assert isinstance(repo, HumanInputFormSubmissionRepository)
|
||||
|
||||
|
||||
def test_submission_get_by_token_and_get_by_form_id_success_paths(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = _DummyForm(
|
||||
id="f1",
|
||||
workflow_run_id=None,
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
recipient = SimpleNamespace(
|
||||
id="r1",
|
||||
form_id=form.id,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
access_token="tok",
|
||||
form=form,
|
||||
)
|
||||
|
||||
_patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
record = repo.get_by_token("tok")
|
||||
assert record is not None
|
||||
assert record.access_token == "tok"
|
||||
|
||||
_patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
record = repo.get_by_form_id_and_recipient_type(form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP)
|
||||
assert record is not None
|
||||
assert record.recipient_id == "r1"
|
||||
|
||||
|
||||
def test_submission_get_by_form_id_returns_none_on_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_patch_session_factory(monkeypatch, _FakeSession(scalars_result=None))
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
assert repo.get_by_form_id_and_recipient_type(form_id="f", recipient_type=RecipientType.CONSOLE) is None
|
||||
|
||||
|
||||
def test_mark_submitted_updates_and_raises_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
fixed_now = datetime(2024, 1, 1, 0, 0, 0)
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)
|
||||
|
||||
missing_session = _FakeSession(forms={})
|
||||
_patch_session_factory(monkeypatch, missing_session)
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
with pytest.raises(FormNotFoundError, match="form not found"):
|
||||
repo.mark_submitted(
|
||||
form_id="missing",
|
||||
recipient_id=None,
|
||||
selected_action_id="a",
|
||||
form_data={},
|
||||
submission_user_id=None,
|
||||
submission_end_user_id=None,
|
||||
)
|
||||
|
||||
form = _DummyForm(
|
||||
id="f",
|
||||
workflow_run_id=None,
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=fixed_now,
|
||||
)
|
||||
recipient = _DummyRecipient(id="r", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="tok")
|
||||
session = _FakeSession(forms={form.id: form}, recipients={recipient.id: recipient})
|
||||
_patch_session_factory(monkeypatch, session)
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
record = repo.mark_submitted(
|
||||
form_id=form.id,
|
||||
recipient_id=recipient.id,
|
||||
selected_action_id="approve",
|
||||
form_data={"k": "v"},
|
||||
submission_user_id="u",
|
||||
submission_end_user_id="eu",
|
||||
)
|
||||
assert form.status == HumanInputFormStatus.SUBMITTED
|
||||
assert form.submitted_at == fixed_now
|
||||
assert record.submitted_data == {"k": "v"}
|
||||
|
||||
|
||||
def test_mark_timeout_invalid_status_raises(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = _DummyForm(
|
||||
id="f",
|
||||
workflow_run_id=None,
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
session = _FakeSession(forms={form.id: form})
|
||||
_patch_session_factory(monkeypatch, session)
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
with pytest.raises(_InvalidTimeoutStatusError, match="invalid timeout status"):
|
||||
repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.SUBMITTED) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_mark_timeout_already_timed_out_returns_record(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = _DummyForm(
|
||||
id="f",
|
||||
workflow_run_id=None,
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
status=HumanInputFormStatus.TIMEOUT,
|
||||
)
|
||||
session = _FakeSession(forms={form.id: form})
|
||||
_patch_session_factory(monkeypatch, session)
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.TIMEOUT, reason="r")
|
||||
assert record.status == HumanInputFormStatus.TIMEOUT
|
||||
|
||||
|
||||
def test_mark_timeout_submitted_raises_form_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = _DummyForm(
|
||||
id="f",
|
||||
workflow_run_id=None,
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
status=HumanInputFormStatus.SUBMITTED,
|
||||
)
|
||||
session = _FakeSession(forms={form.id: form})
|
||||
_patch_session_factory(monkeypatch, session)
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
with pytest.raises(FormNotFoundError, match="form already submitted"):
|
||||
repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED)
|
||||
|
||||
|
||||
def test_mark_timeout_updates_fields(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = _DummyForm(
|
||||
id="f",
|
||||
workflow_run_id=None,
|
||||
node_id="node",
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
form_definition=_make_form_definition_json(include_expiration_time=True),
|
||||
rendered_content="<p>x</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
selected_action_id="a",
|
||||
submitted_data="{}",
|
||||
submission_user_id="u",
|
||||
submission_end_user_id="eu",
|
||||
completed_by_recipient_id="r",
|
||||
status=HumanInputFormStatus.WAITING,
|
||||
)
|
||||
session = _FakeSession(forms={form.id: form})
|
||||
_patch_session_factory(monkeypatch, session)
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED)
|
||||
assert form.status == HumanInputFormStatus.EXPIRED
|
||||
assert form.selected_action_id is None
|
||||
assert form.submitted_data is None
|
||||
assert form.submission_user_id is None
|
||||
assert form.submission_end_user_id is None
|
||||
assert form.completed_by_recipient_id is None
|
||||
assert record.status == HumanInputFormStatus.EXPIRED
|
||||
|
||||
|
||||
def test_mark_timeout_raises_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_patch_session_factory(monkeypatch, _FakeSession(forms={}))
|
||||
repo = HumanInputFormSubmissionRepository()
|
||||
with pytest.raises(FormNotFoundError, match="form not found"):
|
||||
repo.mark_timeout(form_id="missing", timeout_status=HumanInputFormStatus.TIMEOUT)
|
||||
|
|
@ -1,84 +1,291 @@
|
|||
from datetime import datetime
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
import pytest
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType
|
||||
from models import Account, WorkflowRun
|
||||
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
||||
from models import Account, CreatorUserRole, EndUser, WorkflowRun
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
|
||||
|
||||
def _build_repository_with_mocked_session(session: MagicMock) -> SQLAlchemyWorkflowExecutionRepository:
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
real_session_factory = sessionmaker(bind=engine, expire_on_commit=False)
|
||||
|
||||
user = MagicMock(spec=Account)
|
||||
user.id = str(uuid4())
|
||||
user.current_tenant_id = str(uuid4())
|
||||
|
||||
repository = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=real_session_factory,
|
||||
user=user,
|
||||
app_id="app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
session_context.__exit__.return_value = False
|
||||
repository._session_factory = MagicMock(return_value=session_context)
|
||||
return repository
|
||||
|
||||
|
||||
def _build_execution(*, execution_id: str, started_at: datetime) -> WorkflowExecution:
|
||||
return WorkflowExecution.new(
|
||||
id_=execution_id,
|
||||
workflow_id="workflow-id",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "hello"},
|
||||
started_at=started_at,
|
||||
)
|
||||
|
||||
|
||||
def test_save_uses_execution_started_at_when_record_does_not_exist():
|
||||
@pytest.fixture
|
||||
def mock_session_factory():
|
||||
"""Mock SQLAlchemy session factory."""
|
||||
session_factory = MagicMock(spec=sessionmaker)
|
||||
session = MagicMock()
|
||||
session.get.return_value = None
|
||||
repository = _build_repository_with_mocked_session(session)
|
||||
|
||||
started_at = datetime(2026, 1, 1, 12, 0, 0)
|
||||
execution = _build_execution(execution_id=str(uuid4()), started_at=started_at)
|
||||
|
||||
repository.save(execution)
|
||||
|
||||
saved_model = session.merge.call_args.args[0]
|
||||
assert saved_model.created_at == started_at
|
||||
session.commit.assert_called_once()
|
||||
session_factory.return_value.__enter__.return_value = session
|
||||
return session_factory
|
||||
|
||||
|
||||
def test_save_preserves_existing_created_at_when_record_already_exists():
|
||||
session = MagicMock()
|
||||
repository = _build_repository_with_mocked_session(session)
|
||||
@pytest.fixture
|
||||
def mock_engine():
|
||||
"""Mock SQLAlchemy Engine."""
|
||||
return MagicMock(spec=Engine)
|
||||
|
||||
execution_id = str(uuid4())
|
||||
existing_created_at = datetime(2026, 1, 1, 12, 0, 0)
|
||||
existing_run = WorkflowRun()
|
||||
existing_run.id = execution_id
|
||||
existing_run.tenant_id = repository._tenant_id
|
||||
existing_run.created_at = existing_created_at
|
||||
session.get.return_value = existing_run
|
||||
|
||||
execution = _build_execution(
|
||||
execution_id=execution_id,
|
||||
started_at=datetime(2026, 1, 1, 12, 30, 0),
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
"""Mock Account user."""
|
||||
account = MagicMock(spec=Account)
|
||||
account.id = str(uuid4())
|
||||
account.current_tenant_id = str(uuid4())
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_end_user():
|
||||
"""Mock EndUser."""
|
||||
user = MagicMock(spec=EndUser)
|
||||
user.id = str(uuid4())
|
||||
user.tenant_id = str(uuid4())
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workflow_execution():
|
||||
"""Sample WorkflowExecution for testing."""
|
||||
return WorkflowExecution(
|
||||
id_=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"input1": "value1"},
|
||||
outputs={"output1": "result1"},
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
error_message="",
|
||||
total_tokens=100,
|
||||
total_steps=5,
|
||||
exceptions_count=0,
|
||||
started_at=datetime.now(UTC),
|
||||
finished_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
repository.save(execution)
|
||||
|
||||
saved_model = session.merge.call_args.args[0]
|
||||
assert saved_model.created_at == existing_created_at
|
||||
session.commit.assert_called_once()
|
||||
class TestSQLAlchemyWorkflowExecutionRepository:
|
||||
def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
|
||||
app_id = "test_app_id"
|
||||
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory, user=mock_account, app_id=app_id, triggered_from=triggered_from
|
||||
)
|
||||
|
||||
assert repo._session_factory == mock_session_factory
|
||||
assert repo._tenant_id == mock_account.current_tenant_id
|
||||
assert repo._app_id == app_id
|
||||
assert repo._triggered_from == triggered_from
|
||||
assert repo._creator_user_id == mock_account.id
|
||||
assert repo._creator_user_role == CreatorUserRole.ACCOUNT
|
||||
|
||||
def test_init_with_engine(self, mock_engine, mock_account):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_engine,
|
||||
user=mock_account,
|
||||
app_id="test_app_id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
assert isinstance(repo._session_factory, sessionmaker)
|
||||
assert repo._session_factory.kw["bind"] == mock_engine
|
||||
|
||||
def test_init_invalid_session_factory(self, mock_account):
|
||||
with pytest.raises(ValueError, match="Invalid session_factory type"):
|
||||
SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory="invalid", user=mock_account, app_id=None, triggered_from=None
|
||||
)
|
||||
|
||||
def test_init_no_tenant_id(self, mock_session_factory):
|
||||
user = MagicMock(spec=Account)
|
||||
user.current_tenant_id = None
|
||||
|
||||
with pytest.raises(ValueError, match="User must have a tenant_id"):
|
||||
SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory, user=user, app_id=None, triggered_from=None
|
||||
)
|
||||
|
||||
def test_init_with_end_user(self, mock_session_factory, mock_end_user):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory, user=mock_end_user, app_id=None, triggered_from=None
|
||||
)
|
||||
assert repo._tenant_id == mock_end_user.tenant_id
|
||||
assert repo._creator_user_role == CreatorUserRole.END_USER
|
||||
|
||||
def test_to_domain_model(self, mock_session_factory, mock_account):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None
|
||||
)
|
||||
|
||||
db_model = MagicMock(spec=WorkflowRun)
|
||||
db_model.id = str(uuid4())
|
||||
db_model.workflow_id = str(uuid4())
|
||||
db_model.type = "workflow"
|
||||
db_model.version = "1.0"
|
||||
db_model.inputs_dict = {"in": "val"}
|
||||
db_model.outputs_dict = {"out": "val"}
|
||||
db_model.graph_dict = {"nodes": []}
|
||||
db_model.status = "succeeded"
|
||||
db_model.error = "some error"
|
||||
db_model.total_tokens = 50
|
||||
db_model.total_steps = 3
|
||||
db_model.exceptions_count = 1
|
||||
db_model.created_at = datetime.now(UTC)
|
||||
db_model.finished_at = datetime.now(UTC)
|
||||
|
||||
domain_model = repo._to_domain_model(db_model)
|
||||
|
||||
assert domain_model.id_ == db_model.id
|
||||
assert domain_model.workflow_id == db_model.workflow_id
|
||||
assert domain_model.status == WorkflowExecutionStatus.SUCCEEDED
|
||||
assert domain_model.inputs == db_model.inputs_dict
|
||||
assert domain_model.error_message == "some error"
|
||||
|
||||
def test_to_db_model(self, mock_session_factory, mock_account, sample_workflow_execution):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test_app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
)
|
||||
|
||||
# Make elapsed time deterministic to avoid flaky tests
|
||||
sample_workflow_execution.started_at = datetime(2023, 1, 1, 0, 0, 0, tzinfo=UTC)
|
||||
sample_workflow_execution.finished_at = datetime(2023, 1, 1, 0, 0, 10, tzinfo=UTC)
|
||||
|
||||
db_model = repo._to_db_model(sample_workflow_execution)
|
||||
|
||||
assert db_model.id == sample_workflow_execution.id_
|
||||
assert db_model.tenant_id == repo._tenant_id
|
||||
assert db_model.app_id == "test_app"
|
||||
assert db_model.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING
|
||||
assert db_model.status == sample_workflow_execution.status.value
|
||||
assert db_model.total_tokens == sample_workflow_execution.total_tokens
|
||||
assert db_model.elapsed_time == 10.0
|
||||
|
||||
def test_to_db_model_edge_cases(self, mock_session_factory, mock_account, sample_workflow_execution):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test_app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
)
|
||||
# Test with empty/None fields
|
||||
sample_workflow_execution.graph = None
|
||||
sample_workflow_execution.inputs = None
|
||||
sample_workflow_execution.outputs = None
|
||||
sample_workflow_execution.error_message = None
|
||||
sample_workflow_execution.finished_at = None
|
||||
|
||||
db_model = repo._to_db_model(sample_workflow_execution)
|
||||
|
||||
assert db_model.graph is None
|
||||
assert db_model.inputs is None
|
||||
assert db_model.outputs is None
|
||||
assert db_model.error is None
|
||||
assert db_model.elapsed_time == 0
|
||||
|
||||
def test_to_db_model_app_id_none(self, mock_session_factory, mock_account, sample_workflow_execution):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id=None,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
db_model = repo._to_db_model(sample_workflow_execution)
|
||||
assert not hasattr(db_model, "app_id") or db_model.app_id is None
|
||||
assert db_model.tenant_id == repo._tenant_id
|
||||
|
||||
def test_to_db_model_missing_context(self, mock_session_factory, mock_account, sample_workflow_execution):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None
|
||||
)
|
||||
|
||||
# Test triggered_from missing
|
||||
with pytest.raises(ValueError, match="triggered_from is required"):
|
||||
repo._to_db_model(sample_workflow_execution)
|
||||
|
||||
repo._triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
repo._creator_user_id = None
|
||||
with pytest.raises(ValueError, match="created_by is required"):
|
||||
repo._to_db_model(sample_workflow_execution)
|
||||
|
||||
repo._creator_user_id = "some_id"
|
||||
repo._creator_user_role = None
|
||||
with pytest.raises(ValueError, match="created_by_role is required"):
|
||||
repo._to_db_model(sample_workflow_execution)
|
||||
|
||||
def test_save(self, mock_session_factory, mock_account, sample_workflow_execution):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test_app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
repo.save(sample_workflow_execution)
|
||||
|
||||
session = mock_session_factory.return_value.__enter__.return_value
|
||||
session.merge.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
# Check cache
|
||||
assert sample_workflow_execution.id_ in repo._execution_cache
|
||||
cached_model = repo._execution_cache[sample_workflow_execution.id_]
|
||||
assert cached_model.id == sample_workflow_execution.id_
|
||||
|
||||
def test_save_uses_execution_started_at_when_record_does_not_exist(
|
||||
self, mock_session_factory, mock_account, sample_workflow_execution
|
||||
):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test_app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
started_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC)
|
||||
sample_workflow_execution.started_at = started_at
|
||||
|
||||
session = mock_session_factory.return_value.__enter__.return_value
|
||||
session.get.return_value = None
|
||||
|
||||
repo.save(sample_workflow_execution)
|
||||
|
||||
saved_model = session.merge.call_args.args[0]
|
||||
assert saved_model.created_at == started_at
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_save_preserves_existing_created_at_when_record_already_exists(
|
||||
self, mock_session_factory, mock_account, sample_workflow_execution
|
||||
):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test_app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
execution_id = sample_workflow_execution.id_
|
||||
existing_created_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC)
|
||||
|
||||
existing_run = WorkflowRun()
|
||||
existing_run.id = execution_id
|
||||
existing_run.tenant_id = repo._tenant_id
|
||||
existing_run.created_at = existing_created_at
|
||||
|
||||
session = mock_session_factory.return_value.__enter__.return_value
|
||||
session.get.return_value = existing_run
|
||||
|
||||
sample_workflow_execution.started_at = datetime(2026, 1, 1, 12, 30, 0, tzinfo=UTC)
|
||||
|
||||
repo.save(sample_workflow_execution)
|
||||
|
||||
saved_model = session.merge.call_args.args[0]
|
||||
assert saved_model.created_at == existing_created_at
|
||||
session.commit.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,772 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import psycopg2.errors
|
||||
import pytest
|
||||
from sqlalchemy import Engine, create_engine
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
|
||||
SQLAlchemyWorkflowNodeExecutionRepository,
|
||||
_deterministic_json_dump,
|
||||
_filter_by_offload_type,
|
||||
_find_first,
|
||||
_replace_or_append_offload,
|
||||
)
|
||||
from dify_graph.entities import WorkflowNodeExecution
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from dify_graph.repositories.workflow_node_execution_repository import OrderConfig
|
||||
from models import Account, EndUser
|
||||
from models.enums import ExecutionOffLoadType
|
||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
def _mock_account(*, tenant_id: str = "tenant", user_id: str = "user") -> Account:
|
||||
user = Mock(spec=Account)
|
||||
user.id = user_id
|
||||
user.current_tenant_id = tenant_id
|
||||
return user
|
||||
|
||||
|
||||
def _mock_end_user(*, tenant_id: str = "tenant", user_id: str = "user") -> EndUser:
|
||||
user = Mock(spec=EndUser)
|
||||
user.id = user_id
|
||||
user.tenant_id = tenant_id
|
||||
return user
|
||||
|
||||
|
||||
def _execution(
|
||||
*,
|
||||
execution_id: str = "exec-id",
|
||||
node_execution_id: str = "node-exec-id",
|
||||
workflow_run_id: str = "run-id",
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs: Mapping[str, Any] | None = None,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
process_data: Mapping[str, Any] | None = None,
|
||||
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None,
|
||||
) -> WorkflowNodeExecution:
|
||||
return WorkflowNodeExecution(
|
||||
id=execution_id,
|
||||
node_execution_id=node_execution_id,
|
||||
workflow_id="workflow-id",
|
||||
workflow_execution_id=workflow_run_id,
|
||||
index=1,
|
||||
predecessor_node_id=None,
|
||||
node_id="node-id",
|
||||
node_type=NodeType.LLM,
|
||||
title="Title",
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
process_data=process_data,
|
||||
status=status,
|
||||
error=None,
|
||||
elapsed_time=1.0,
|
||||
metadata=metadata,
|
||||
created_at=datetime.now(UTC),
|
||||
finished_at=None,
|
||||
)
|
||||
|
||||
|
||||
class _SessionCtx:
|
||||
def __init__(self, session: Any):
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def _session_factory(session: Any) -> sessionmaker:
|
||||
factory = Mock(spec=sessionmaker)
|
||||
factory.return_value = _SessionCtx(session)
|
||||
return factory
|
||||
|
||||
|
||||
def test_init_accepts_engine_and_sessionmaker_and_sets_role(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
|
||||
engine: Engine = create_engine("sqlite:///:memory:")
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=engine,
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
assert isinstance(repo._session_factory, sessionmaker)
|
||||
|
||||
sm = Mock(spec=sessionmaker)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=sm,
|
||||
user=_mock_end_user(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
assert repo._creator_user_role.value == "end_user"
|
||||
|
||||
|
||||
def test_init_rejects_invalid_session_factory_type(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
with pytest.raises(ValueError, match="Invalid session_factory type"):
|
||||
SQLAlchemyWorkflowNodeExecutionRepository( # type: ignore[arg-type]
|
||||
session_factory=object(),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
|
||||
def test_init_requires_tenant_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
user = _mock_account()
|
||||
user.current_tenant_id = None
|
||||
with pytest.raises(ValueError, match="User must have a tenant_id"):
|
||||
SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=user,
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
|
||||
def test_create_truncator_uses_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
created: dict[str, Any] = {}
|
||||
|
||||
class FakeTruncator:
|
||||
def __init__(self, *, max_size_bytes: int, array_element_limit: int, string_length_limit: int):
|
||||
created.update(
|
||||
{
|
||||
"max_size_bytes": max_size_bytes,
|
||||
"array_element_limit": array_element_limit,
|
||||
"string_length_limit": string_length_limit,
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.VariableTruncator",
|
||||
FakeTruncator,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
_ = repo._create_truncator()
|
||||
assert created["max_size_bytes"] == dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE
|
||||
|
||||
|
||||
def test_helpers_find_first_and_replace_or_append_and_filter() -> None:
|
||||
assert _deterministic_json_dump({"b": 1, "a": 2}) == '{"a": 2, "b": 1}'
|
||||
assert _find_first([], lambda _: True) is None
|
||||
assert _find_first([1, 2, 3], lambda x: x > 1) == 2
|
||||
|
||||
off1 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
|
||||
off2 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
|
||||
assert _find_first([off1, off2], _filter_by_offload_type(ExecutionOffLoadType.OUTPUTS)) is off2
|
||||
|
||||
replaced = _replace_or_append_offload([off1, off2], WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS))
|
||||
assert len(replaced) == 2
|
||||
assert [o.type_ for o in replaced] == [ExecutionOffLoadType.OUTPUTS, ExecutionOffLoadType.INPUTS]
|
||||
|
||||
|
||||
def test_to_db_model_requires_constructor_context(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
execution = _execution(inputs={"b": 1, "a": 2}, metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 1})
|
||||
|
||||
# Happy path: deterministic json dump should be sorted
|
||||
db_model = repo._to_db_model(execution)
|
||||
assert json.loads(db_model.inputs or "{}") == {"a": 2, "b": 1}
|
||||
assert json.loads(db_model.execution_metadata or "{}")["total_tokens"] == 1
|
||||
|
||||
repo._triggered_from = None
|
||||
with pytest.raises(ValueError, match="triggered_from is required"):
|
||||
repo._to_db_model(execution)
|
||||
|
||||
|
||||
def test_to_db_model_requires_creator_user_id_and_role(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
execution = _execution()
|
||||
db_model = repo._to_db_model(execution)
|
||||
assert db_model.app_id == "app"
|
||||
|
||||
repo._creator_user_id = None
|
||||
with pytest.raises(ValueError, match="created_by is required"):
|
||||
repo._to_db_model(execution)
|
||||
|
||||
repo._creator_user_id = "user"
|
||||
repo._creator_user_role = None
|
||||
with pytest.raises(ValueError, match="created_by_role is required"):
|
||||
repo._to_db_model(execution)
|
||||
|
||||
|
||||
def test_is_duplicate_key_error_and_regenerate_id(
|
||||
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
unique = Mock(spec=psycopg2.errors.UniqueViolation)
|
||||
duplicate_error = IntegrityError("dup", params=None, orig=unique)
|
||||
assert repo._is_duplicate_key_error(duplicate_error) is True
|
||||
assert repo._is_duplicate_key_error(IntegrityError("other", params=None, orig=None)) is False
|
||||
|
||||
execution = _execution(execution_id="old-id")
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.id = "old-id"
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id")
|
||||
caplog.set_level(logging.WARNING)
|
||||
repo._regenerate_id_on_duplicate(execution, db_model)
|
||||
assert execution.id == "new-id"
|
||||
assert db_model.id == "new-id"
|
||||
assert any("Duplicate key conflict" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_persist_to_database_updates_existing_and_inserts_new(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
session = MagicMock()
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.id = "id1"
|
||||
db_model.node_execution_id = "node1"
|
||||
db_model.foo = "bar" # type: ignore[attr-defined]
|
||||
db_model.__dict__["_private"] = "x"
|
||||
|
||||
existing = SimpleNamespace()
|
||||
session.get.return_value = existing
|
||||
repo._persist_to_database(db_model)
|
||||
assert existing.foo == "bar"
|
||||
session.add.assert_not_called()
|
||||
assert repo._node_execution_cache["node1"] is db_model
|
||||
|
||||
session.reset_mock()
|
||||
session.get.return_value = None
|
||||
repo._node_execution_cache.clear()
|
||||
repo._persist_to_database(db_model)
|
||||
session.add.assert_called_once_with(db_model)
|
||||
assert repo._node_execution_cache["node1"] is db_model
|
||||
|
||||
|
||||
def test_truncate_and_upload_returns_none_when_no_values_or_not_truncated(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
assert repo._truncate_and_upload(None, "e", ExecutionOffLoadType.INPUTS) is None
|
||||
|
||||
class FakeTruncator:
|
||||
def truncate_variable_mapping(self, value: Any): # type: ignore[no-untyped-def]
|
||||
return value, False
|
||||
|
||||
monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator())
|
||||
assert repo._truncate_and_upload({"a": 1}, "e", ExecutionOffLoadType.INPUTS) is None
|
||||
|
||||
|
||||
def test_truncate_and_upload_uploads_and_builds_offload(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
uploaded: dict[str, Any] = {}
|
||||
|
||||
class FakeFileService:
|
||||
def upload_file(self, *, filename: str, content: bytes, mimetype: str, user: Any): # type: ignore[no-untyped-def]
|
||||
uploaded.update({"filename": filename, "content": content, "mimetype": mimetype, "user": user})
|
||||
return SimpleNamespace(id="file-id", key="file-key")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", lambda *_: FakeFileService()
|
||||
)
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "offload-id")
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
class FakeTruncator:
|
||||
def truncate_variable_mapping(self, value: Any): # type: ignore[no-untyped-def]
|
||||
return {"truncated": True}, True
|
||||
|
||||
monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator())
|
||||
|
||||
result = repo._truncate_and_upload({"a": 1}, "exec", ExecutionOffLoadType.INPUTS)
|
||||
assert result is not None
|
||||
assert result.truncated_value == {"truncated": True}
|
||||
assert uploaded["filename"].startswith("node_execution_exec_inputs.json")
|
||||
assert result.offload.file_id == "file-id"
|
||||
assert result.offload.type_ == ExecutionOffLoadType.INPUTS
|
||||
|
||||
|
||||
def test_to_domain_model_loads_offloaded_files(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.id = "id"
|
||||
db_model.node_execution_id = "node-exec"
|
||||
db_model.workflow_id = "wf"
|
||||
db_model.workflow_run_id = "run"
|
||||
db_model.index = 1
|
||||
db_model.predecessor_node_id = None
|
||||
db_model.node_id = "node"
|
||||
db_model.node_type = NodeType.LLM
|
||||
db_model.title = "t"
|
||||
db_model.inputs = json.dumps({"trunc": "i"})
|
||||
db_model.process_data = json.dumps({"trunc": "p"})
|
||||
db_model.outputs = json.dumps({"trunc": "o"})
|
||||
db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
db_model.error = None
|
||||
db_model.elapsed_time = 0.1
|
||||
db_model.execution_metadata = json.dumps({"total_tokens": 3})
|
||||
db_model.created_at = datetime.now(UTC)
|
||||
db_model.finished_at = None
|
||||
|
||||
off_in = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
|
||||
off_out = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
|
||||
off_proc = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA)
|
||||
off_in.file = SimpleNamespace(key="k-in")
|
||||
off_out.file = SimpleNamespace(key="k-out")
|
||||
off_proc.file = SimpleNamespace(key="k-proc")
|
||||
db_model.offload_data = [off_out, off_in, off_proc]
|
||||
|
||||
def fake_load(key: str) -> bytes:
|
||||
return json.dumps({"full": key}).encode()
|
||||
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.storage.load", fake_load)
|
||||
|
||||
domain = repo._to_domain_model(db_model)
|
||||
assert domain.inputs == {"full": "k-in"}
|
||||
assert domain.outputs == {"full": "k-out"}
|
||||
assert domain.process_data == {"full": "k-proc"}
|
||||
assert domain.get_truncated_inputs() == {"trunc": "i"}
|
||||
assert domain.get_truncated_outputs() == {"trunc": "o"}
|
||||
assert domain.get_truncated_process_data() == {"trunc": "p"}
|
||||
|
||||
|
||||
def test_to_domain_model_returns_early_when_no_offload_data(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.id = "id"
|
||||
db_model.node_execution_id = "node-exec"
|
||||
db_model.workflow_id = "wf"
|
||||
db_model.workflow_run_id = "run"
|
||||
db_model.index = 1
|
||||
db_model.predecessor_node_id = None
|
||||
db_model.node_id = "node"
|
||||
db_model.node_type = NodeType.LLM
|
||||
db_model.title = "t"
|
||||
db_model.inputs = json.dumps({"i": 1})
|
||||
db_model.process_data = json.dumps({"p": 2})
|
||||
db_model.outputs = json.dumps({"o": 3})
|
||||
db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
db_model.error = None
|
||||
db_model.elapsed_time = 0.1
|
||||
db_model.execution_metadata = "{}"
|
||||
db_model.created_at = datetime.now(UTC)
|
||||
db_model.finished_at = None
|
||||
db_model.offload_data = []
|
||||
|
||||
domain = repo._to_domain_model(db_model)
|
||||
assert domain.inputs == {"i": 1}
|
||||
assert domain.outputs == {"o": 3}
|
||||
|
||||
|
||||
def test_json_encode_uses_runtime_converter(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class FakeConverter:
|
||||
def to_json_encodable(self, values: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
return {"wrapped": values["a"]}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowRuntimeTypeConverter",
|
||||
FakeConverter,
|
||||
)
|
||||
assert SQLAlchemyWorkflowNodeExecutionRepository._json_encode({"a": 1}) == '{"wrapped": 1}'
|
||||
|
||||
|
||||
def test_save_execution_data_handles_existing_db_model_and_truncation(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalars.return_value.first.return_value = SimpleNamespace(
|
||||
id="id",
|
||||
offload_data=[WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)],
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
process_data=None,
|
||||
)
|
||||
session.merge = Mock()
|
||||
session.flush = Mock()
|
||||
session.begin.return_value.__enter__ = Mock(return_value=session)
|
||||
session.begin.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3})
|
||||
|
||||
trunc_result = SimpleNamespace(
|
||||
truncated_value={"trunc": True},
|
||||
offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS, file_id="f1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
repo, "_truncate_and_upload", lambda values, *_args, **_kwargs: trunc_result if values == {"a": 1} else None
|
||||
)
|
||||
monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True))
|
||||
|
||||
repo.save_execution_data(execution)
|
||||
# Inputs should be truncated, outputs/process_data encoded directly
|
||||
db_model = session.merge.call_args.args[0]
|
||||
assert json.loads(db_model.inputs) == {"trunc": True}
|
||||
assert json.loads(db_model.outputs) == {"b": 2}
|
||||
assert json.loads(db_model.process_data) == {"c": 3}
|
||||
assert any(off.type_ == ExecutionOffLoadType.INPUTS for off in db_model.offload_data)
|
||||
assert execution.get_truncated_inputs() == {"trunc": True}
|
||||
|
||||
|
||||
def test_save_execution_data_truncates_outputs_and_process_data(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
existing = SimpleNamespace(
|
||||
id="id",
|
||||
offload_data=[],
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
process_data=None,
|
||||
)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalars.return_value.first.return_value = existing
|
||||
session.merge = Mock()
|
||||
session.flush = Mock()
|
||||
session.begin.return_value.__enter__ = Mock(return_value=session)
|
||||
session.begin.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3})
|
||||
|
||||
def trunc(values: Mapping[str, Any], *_args: Any, **_kwargs: Any) -> Any:
|
||||
if values == {"b": 2}:
|
||||
return SimpleNamespace(
|
||||
truncated_value={"b": "trunc"},
|
||||
offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS, file_id="f2"),
|
||||
)
|
||||
if values == {"c": 3}:
|
||||
return SimpleNamespace(
|
||||
truncated_value={"c": "trunc"},
|
||||
offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA, file_id="f3"),
|
||||
)
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(repo, "_truncate_and_upload", trunc)
|
||||
monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True))
|
||||
|
||||
repo.save_execution_data(execution)
|
||||
db_model = session.merge.call_args.args[0]
|
||||
assert json.loads(db_model.outputs) == {"b": "trunc"}
|
||||
assert json.loads(db_model.process_data) == {"c": "trunc"}
|
||||
assert execution.get_truncated_outputs() == {"b": "trunc"}
|
||||
assert execution.get_truncated_process_data() == {"c": "trunc"}
|
||||
|
||||
|
||||
def test_save_execution_data_handles_missing_db_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalars.return_value.first.return_value = None
|
||||
session.merge = Mock()
|
||||
session.flush = Mock()
|
||||
session.begin.return_value.__enter__ = Mock(return_value=session)
|
||||
session.begin.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
execution = _execution(inputs={"a": 1})
|
||||
fake_db_model = SimpleNamespace(id=execution.id, offload_data=[], inputs=None, outputs=None, process_data=None)
|
||||
monkeypatch.setattr(repo, "_to_db_model", lambda *_: fake_db_model)
|
||||
monkeypatch.setattr(repo, "_truncate_and_upload", lambda *_args, **_kwargs: None)
|
||||
monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values))
|
||||
|
||||
repo.save_execution_data(execution)
|
||||
merged = session.merge.call_args.args[0]
|
||||
assert merged.inputs == '{"a": 1}'
|
||||
|
||||
|
||||
def test_save_retries_duplicate_and_logs_non_duplicate(
|
||||
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
execution = _execution(execution_id="id")
|
||||
unique = Mock(spec=psycopg2.errors.UniqueViolation)
|
||||
duplicate_error = IntegrityError("dup", params=None, orig=unique)
|
||||
other_error = IntegrityError("other", params=None, orig=None)
|
||||
|
||||
calls = {"n": 0}
|
||||
|
||||
def persist(_db_model: Any) -> None:
|
||||
calls["n"] += 1
|
||||
if calls["n"] == 1:
|
||||
raise duplicate_error
|
||||
|
||||
monkeypatch.setattr(repo, "_persist_to_database", persist)
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id")
|
||||
repo.save(execution)
|
||||
assert execution.id == "new-id"
|
||||
assert repo._node_execution_cache[execution.node_execution_id] is not None
|
||||
|
||||
caplog.set_level(logging.ERROR)
|
||||
monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(other_error))
|
||||
with pytest.raises(IntegrityError):
|
||||
repo.save(_execution(execution_id="id2", node_execution_id="node2"))
|
||||
assert any("Non-duplicate key integrity error" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_save_logs_and_reraises_on_unexpected_error(
|
||||
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
caplog.set_level(logging.ERROR)
|
||||
monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(RuntimeError("boom")))
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
repo.save(_execution(execution_id="id3", node_execution_id="node3"))
|
||||
assert any("Failed to save workflow node execution" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_get_db_models_by_workflow_run_orders_and_caches(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
|
||||
class FakeStmt:
|
||||
def __init__(self) -> None:
|
||||
self.where_calls = 0
|
||||
self.order_by_args: tuple[Any, ...] | None = None
|
||||
|
||||
def where(self, *_args: Any) -> FakeStmt:
|
||||
self.where_calls += 1
|
||||
return self
|
||||
|
||||
def order_by(self, *args: Any) -> FakeStmt:
|
||||
self.order_by_args = args
|
||||
return self
|
||||
|
||||
stmt = FakeStmt()
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files",
|
||||
lambda _q: stmt,
|
||||
)
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select")
|
||||
|
||||
model1 = SimpleNamespace(node_execution_id="n1")
|
||||
model2 = SimpleNamespace(node_execution_id=None)
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = [model1, model2]
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
order = OrderConfig(order_by=["index", "missing"], order_direction="desc")
|
||||
db_models = repo.get_db_models_by_workflow_run("run", order)
|
||||
assert db_models == [model1, model2]
|
||||
assert repo._node_execution_cache["n1"] is model1
|
||||
assert stmt.order_by_args is not None
|
||||
|
||||
|
||||
def test_get_db_models_by_workflow_run_uses_asc_order(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
|
||||
class FakeStmt:
|
||||
def where(self, *_args: Any) -> FakeStmt:
|
||||
return self
|
||||
|
||||
def order_by(self, *args: Any) -> FakeStmt:
|
||||
self.args = args # type: ignore[attr-defined]
|
||||
return self
|
||||
|
||||
stmt = FakeStmt()
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files",
|
||||
lambda _q: stmt,
|
||||
)
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select")
|
||||
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
repo.get_db_models_by_workflow_run("run", OrderConfig(order_by=["index"], order_direction="asc"))
|
||||
|
||||
|
||||
def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
db_models = [SimpleNamespace(id="db1"), SimpleNamespace(id="db2")]
|
||||
monkeypatch.setattr(repo, "get_db_models_by_workflow_run", lambda *_args, **_kwargs: db_models)
|
||||
monkeypatch.setattr(repo, "_to_domain_model", lambda m: f"domain:{m.id}")
|
||||
|
||||
class FakeExecutor:
|
||||
def __enter__(self) -> FakeExecutor:
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None:
|
||||
return None
|
||||
|
||||
def map(self, func, items, timeout: int): # type: ignore[no-untyped-def]
|
||||
assert timeout == 30
|
||||
return list(map(func, items))
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.ThreadPoolExecutor",
|
||||
lambda max_workers: FakeExecutor(),
|
||||
)
|
||||
|
||||
result = repo.get_by_workflow_run("run", order_config=None)
|
||||
assert result == ["domain:db1", "domain:db2"]
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.schemas.registry import SchemaRegistry
|
||||
|
||||
|
||||
class TestSchemaRegistry:
|
||||
def test_initialization(self, tmp_path):
|
||||
base_dir = tmp_path / "schemas"
|
||||
base_dir.mkdir()
|
||||
registry = SchemaRegistry(str(base_dir))
|
||||
assert registry.base_dir == base_dir
|
||||
assert registry.versions == {}
|
||||
assert registry.metadata == {}
|
||||
|
||||
def test_default_registry_singleton(self):
|
||||
registry1 = SchemaRegistry.default_registry()
|
||||
registry2 = SchemaRegistry.default_registry()
|
||||
assert registry1 is registry2
|
||||
assert isinstance(registry1, SchemaRegistry)
|
||||
|
||||
def test_load_all_versions_non_existent_dir(self, tmp_path):
|
||||
base_dir = tmp_path / "non_existent"
|
||||
registry = SchemaRegistry(str(base_dir))
|
||||
registry.load_all_versions()
|
||||
assert registry.versions == {}
|
||||
|
||||
def test_load_all_versions_filtering(self, tmp_path):
|
||||
base_dir = tmp_path / "schemas"
|
||||
base_dir.mkdir()
|
||||
(base_dir / "not_a_version_dir").mkdir()
|
||||
(base_dir / "v1").mkdir()
|
||||
(base_dir / "some_file.txt").write_text("content")
|
||||
|
||||
registry = SchemaRegistry(str(base_dir))
|
||||
with patch.object(registry, "_load_version_dir") as mock_load:
|
||||
registry.load_all_versions()
|
||||
mock_load.assert_called_once()
|
||||
assert mock_load.call_args[0][0] == "v1"
|
||||
|
||||
def test_load_version_dir_filtering(self, tmp_path):
|
||||
version_dir = tmp_path / "v1"
|
||||
version_dir.mkdir()
|
||||
(version_dir / "schema1.json").write_text("{}")
|
||||
(version_dir / "not_a_schema.txt").write_text("content")
|
||||
|
||||
registry = SchemaRegistry(str(tmp_path))
|
||||
with patch.object(registry, "_load_schema") as mock_load:
|
||||
registry._load_version_dir("v1", version_dir)
|
||||
mock_load.assert_called_once()
|
||||
assert mock_load.call_args[0][1] == "schema1"
|
||||
|
||||
def test_load_version_dir_non_existent(self, tmp_path):
|
||||
version_dir = tmp_path / "non_existent"
|
||||
registry = SchemaRegistry(str(tmp_path))
|
||||
registry._load_version_dir("v1", version_dir)
|
||||
assert "v1" not in registry.versions
|
||||
|
||||
def test_load_schema_success(self, tmp_path):
|
||||
schema_path = tmp_path / "test.json"
|
||||
schema_content = {"title": "Test Schema", "description": "A test schema"}
|
||||
schema_path.write_text(json.dumps(schema_content))
|
||||
|
||||
registry = SchemaRegistry(str(tmp_path))
|
||||
registry.versions["v1"] = {}
|
||||
registry._load_schema("v1", "test", schema_path)
|
||||
|
||||
assert registry.versions["v1"]["test"] == schema_content
|
||||
uri = "https://dify.ai/schemas/v1/test.json"
|
||||
assert registry.metadata[uri]["title"] == "Test Schema"
|
||||
assert registry.metadata[uri]["version"] == "v1"
|
||||
|
||||
def test_load_schema_invalid_json(self, tmp_path, caplog):
|
||||
schema_path = tmp_path / "invalid.json"
|
||||
schema_path.write_text("invalid json")
|
||||
|
||||
registry = SchemaRegistry(str(tmp_path))
|
||||
registry.versions["v1"] = {}
|
||||
registry._load_schema("v1", "invalid", schema_path)
|
||||
|
||||
assert "Failed to load schema v1/invalid" in caplog.text
|
||||
|
||||
def test_load_schema_os_error(self, tmp_path, caplog):
|
||||
schema_path = tmp_path / "error.json"
|
||||
schema_path.write_text("{}")
|
||||
|
||||
registry = SchemaRegistry(str(tmp_path))
|
||||
registry.versions["v1"] = {}
|
||||
|
||||
with patch("builtins.open", side_effect=OSError("Read error")):
|
||||
registry._load_schema("v1", "error", schema_path)
|
||||
|
||||
assert "Failed to load schema v1/error" in caplog.text
|
||||
|
||||
def test_get_schema(self):
|
||||
registry = SchemaRegistry("/tmp")
|
||||
registry.versions = {"v1": {"test": {"type": "object"}}}
|
||||
|
||||
# Valid URI
|
||||
assert registry.get_schema("https://dify.ai/schemas/v1/test.json") == {"type": "object"}
|
||||
|
||||
# Invalid URI
|
||||
assert registry.get_schema("invalid-uri") is None
|
||||
|
||||
# Missing version
|
||||
assert registry.get_schema("https://dify.ai/schemas/v2/test.json") is None
|
||||
|
||||
def test_list_versions(self):
|
||||
registry = SchemaRegistry("/tmp")
|
||||
registry.versions = {"v2": {}, "v1": {}}
|
||||
assert registry.list_versions() == ["v1", "v2"]
|
||||
|
||||
def test_list_schemas(self):
|
||||
registry = SchemaRegistry("/tmp")
|
||||
registry.versions = {"v1": {"b": {}, "a": {}}}
|
||||
|
||||
assert registry.list_schemas("v1") == ["a", "b"]
|
||||
assert registry.list_schemas("v2") == []
|
||||
|
||||
def test_get_all_schemas_for_version(self):
|
||||
registry = SchemaRegistry("/tmp")
|
||||
registry.versions = {"v1": {"test": {"title": "Test Label"}}}
|
||||
|
||||
results = registry.get_all_schemas_for_version("v1")
|
||||
assert len(results) == 1
|
||||
assert results[0]["name"] == "test"
|
||||
assert results[0]["label"] == "Test Label"
|
||||
assert results[0]["schema"] == {"title": "Test Label"}
|
||||
|
||||
# Default label if title missing
|
||||
registry.versions["v1"]["no_title"] = {}
|
||||
results = registry.get_all_schemas_for_version("v1")
|
||||
item = next(r for r in results if r["name"] == "no_title")
|
||||
assert item["label"] == "no_title"
|
||||
|
||||
# Empty if version missing
|
||||
assert registry.get_all_schemas_for_version("v2") == []
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.schemas.registry import SchemaRegistry
|
||||
from core.schemas.schema_manager import SchemaManager
|
||||
|
||||
|
||||
def test_init_with_provided_registry():
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
manager = SchemaManager(registry=mock_registry)
|
||||
assert manager.registry == mock_registry
|
||||
|
||||
|
||||
@patch("core.schemas.schema_manager.SchemaRegistry.default_registry")
|
||||
def test_init_with_default_registry(mock_default_registry):
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
mock_default_registry.return_value = mock_registry
|
||||
|
||||
manager = SchemaManager()
|
||||
|
||||
mock_default_registry.assert_called_once()
|
||||
assert manager.registry == mock_registry
|
||||
|
||||
|
||||
def test_get_all_schema_definitions():
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
expected_definitions = [{"name": "schema1", "schema": {}}, {"name": "schema2", "schema": {}}]
|
||||
mock_registry.get_all_schemas_for_version.return_value = expected_definitions
|
||||
|
||||
manager = SchemaManager(registry=mock_registry)
|
||||
result = manager.get_all_schema_definitions(version="v2")
|
||||
|
||||
mock_registry.get_all_schemas_for_version.assert_called_once_with("v2")
|
||||
assert result == expected_definitions
|
||||
|
||||
|
||||
def test_get_schema_by_name_success():
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
mock_schema = {"type": "object"}
|
||||
mock_registry.get_schema.return_value = mock_schema
|
||||
|
||||
manager = SchemaManager(registry=mock_registry)
|
||||
result = manager.get_schema_by_name("my_schema", version="v1")
|
||||
|
||||
expected_uri = "https://dify.ai/schemas/v1/my_schema.json"
|
||||
mock_registry.get_schema.assert_called_once_with(expected_uri)
|
||||
assert result == {"name": "my_schema", "schema": mock_schema}
|
||||
|
||||
|
||||
def test_get_schema_by_name_not_found():
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
mock_registry.get_schema.return_value = None
|
||||
|
||||
manager = SchemaManager(registry=mock_registry)
|
||||
result = manager.get_schema_by_name("non_existent", version="v1")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_list_available_schemas():
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
expected_schemas = ["schema1", "schema2"]
|
||||
mock_registry.list_schemas.return_value = expected_schemas
|
||||
|
||||
manager = SchemaManager(registry=mock_registry)
|
||||
result = manager.list_available_schemas(version="v1")
|
||||
|
||||
mock_registry.list_schemas.assert_called_once_with("v1")
|
||||
assert result == expected_schemas
|
||||
|
||||
|
||||
def test_list_available_versions():
|
||||
mock_registry = MagicMock(spec=SchemaRegistry)
|
||||
expected_versions = ["v1", "v2"]
|
||||
mock_registry.list_versions.return_value = expected_versions
|
||||
|
||||
manager = SchemaManager(registry=mock_registry)
|
||||
result = manager.list_available_versions()
|
||||
|
||||
mock_registry.list_versions.assert_called_once()
|
||||
assert result == expected_versions
|
||||
Loading…
Reference in New Issue