diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index b38dfdfc1f..66037696af 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -138,20 +138,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): query = self.application_generate_entity.query # moderation - if self.handle_input_moderation( + stop, new_inputs, new_query = self.handle_input_moderation( app_record=self._app, app_generate_entity=self.application_generate_entity, inputs=inputs, query=query, message_id=self.message.id, - ): + ) + if stop: return + self.application_generate_entity.inputs = new_inputs + self.application_generate_entity.query = new_query + system_inputs.query = new_query + # annotation reply if self.handle_annotation_reply( app_record=self._app, message=self.message, - query=query, + query=new_query, app_generate_entity=self.application_generate_entity, ): return @@ -163,7 +168,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # init variable pool variable_pool = VariablePool( system_variables=system_inputs, - user_inputs=inputs, + user_inputs=new_inputs, environment_variables=self._workflow.environment_variables, # Based on the definition of `Variable`, # `VariableBase` instances can be safely used as `Variable` since they are compatible. @@ -240,10 +245,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): inputs: Mapping[str, Any], query: str, message_id: str, - ) -> bool: + ) -> tuple[bool, Mapping[str, Any], str]: try: # process sensitive_word_avoidance - _, inputs, query = self.moderation_for_inputs( + _, new_inputs, new_query = self.moderation_for_inputs( app_id=app_record.id, tenant_id=app_generate_entity.app_config.tenant_id, app_generate_entity=app_generate_entity, @@ -253,9 +258,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): ) except ModerationError as e: self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION) - return True + return True, inputs, query - return False + return False, new_inputs, new_query def handle_annotation_reply( self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 12ab587564..15aceef2c7 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -125,7 +125,11 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, patch.object(runner, "_init_graph") as mock_init_graph, - patch.object(runner, "handle_input_moderation", return_value=False), + patch.object( + runner, + "handle_input_moderation", + return_value=(False, mock_app_generate_entity.inputs, mock_app_generate_entity.query), + ), patch.object(runner, "handle_annotation_reply", return_value=False), patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class, @@ -265,7 +269,11 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, patch.object(runner, "_init_graph") as mock_init_graph, - patch.object(runner, "handle_input_moderation", return_value=False), + patch.object( + runner, + "handle_input_moderation", + return_value=(False, mock_app_generate_entity.inputs, mock_app_generate_entity.query), + ), patch.object(runner, "handle_annotation_reply", return_value=False), patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class, @@ -412,7 +420,11 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, patch.object(runner, "_init_graph") as mock_init_graph, - patch.object(runner, "handle_input_moderation", return_value=False), + patch.object( + runner, + "handle_input_moderation", + return_value=(False, mock_app_generate_entity.inputs, mock_app_generate_entity.query), + ), patch.object(runner, "handle_annotation_reply", return_value=False), patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class, diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py new file mode 100644 index 0000000000..5792a2f1e2 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py @@ -0,0 +1,170 @@ +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.app.entities.queue_entities import QueueStopEvent +from core.moderation.base import ModerationError + + +@pytest.fixture +def build_runner(): + """Construct a minimal AdvancedChatAppRunner with heavy dependencies mocked.""" + app_id = str(uuid4()) + workflow_id = str(uuid4()) + + # Mocks for constructor args + mock_queue_manager = MagicMock() + + mock_conversation = MagicMock() + mock_conversation.id = str(uuid4()) + mock_conversation.app_id = app_id + + mock_message = MagicMock() + mock_message.id = str(uuid4()) + + mock_workflow = MagicMock() + mock_workflow.id = workflow_id + mock_workflow.tenant_id = str(uuid4()) + mock_workflow.app_id = app_id + mock_workflow.type = "chat" + mock_workflow.graph_dict = {} + mock_workflow.environment_variables = [] + + mock_app_config = MagicMock() + mock_app_config.app_id = app_id + mock_app_config.workflow_id = workflow_id + mock_app_config.tenant_id = str(uuid4()) + + gen = MagicMock(spec=AdvancedChatAppGenerateEntity) + gen.app_config = mock_app_config + gen.inputs = {"q": "raw"} + gen.query = "raw-query" + gen.files = [] + gen.user_id = str(uuid4()) + gen.invoke_from = InvokeFrom.SERVICE_API + gen.workflow_run_id = str(uuid4()) + gen.task_id = str(uuid4()) + gen.call_depth = 0 + gen.single_iteration_run = None + gen.single_loop_run = None + gen.trace_manager = None + + runner = AdvancedChatAppRunner( + application_generate_entity=gen, + queue_manager=mock_queue_manager, + conversation=mock_conversation, + message=mock_message, + dialogue_count=1, + variable_loader=MagicMock(), + workflow=mock_workflow, + system_user_id=str(uuid4()), + app=MagicMock(), + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + + return runner + + +def _patch_common_run_deps(runner: AdvancedChatAppRunner): + """Context manager that patches common heavy deps used by run().""" + return patch.multiple( + "core.app.apps.advanced_chat.app_runner", + Session=MagicMock( + return_value=MagicMock( + __enter__=lambda s: s, + __exit__=lambda *a, **k: False, + scalar=lambda *a, **k: MagicMock(), + ), + ), + select=MagicMock(), + db=MagicMock(engine=MagicMock()), + RedisChannel=MagicMock(), + redis_client=MagicMock(), + WorkflowEntry=MagicMock(**{"return_value.run.return_value": iter([])}), + GraphRuntimeState=MagicMock(), + ) + + +def test_handle_input_moderation_stops_on_moderation_error(build_runner): + runner = build_runner + + # moderation_for_inputs raises ModerationError -> should stop and emit stop event + with ( + patch.object(runner, "moderation_for_inputs", side_effect=ModerationError("blocked")), + patch.object(runner, "_complete_with_stream_output") as mock_complete, + ): + stop, new_inputs, new_query = runner.handle_input_moderation( + app_record=MagicMock(), + app_generate_entity=runner.application_generate_entity, + inputs={"k": "v"}, + query="hello", + message_id="mid", + ) + + assert stop is True + # inputs/query should be unchanged on error path + assert new_inputs == {"k": "v"} + assert new_query == "hello" + # ensure stopped_by reason is INPUT_MODERATION + assert mock_complete.called + args, kwargs = mock_complete.call_args + assert kwargs.get("stopped_by") == QueueStopEvent.StopBy.INPUT_MODERATION + + +def test_run_applies_overridden_inputs_and_query_from_moderation(build_runner): + runner = build_runner + + overridden_inputs = {"q": "sanitized"} + overridden_query = "sanitized-query" + + with ( + _patch_common_run_deps(runner), + patch.object( + runner, + "moderation_for_inputs", + return_value=(True, overridden_inputs, overridden_query), + ) as mock_moderate, + patch.object(runner, "handle_annotation_reply", return_value=False) as mock_anno, + patch.object(runner, "_init_graph", return_value=MagicMock()) as mock_init_graph, + ): + runner.run() + + # moderation called with original values + mock_moderate.assert_called_once() + + # application_generate_entity should be updated to overridden values + assert runner.application_generate_entity.inputs == overridden_inputs + assert runner.application_generate_entity.query == overridden_query + + # annotation reply should use the new query + mock_anno.assert_called() + assert mock_anno.call_args.kwargs.get("query") == overridden_query + + # since not stopped, graph initialization should proceed + assert mock_init_graph.called + + +def test_run_returns_early_when_direct_output_via_handle_input_moderation(build_runner): + runner = build_runner + + with ( + _patch_common_run_deps(runner), + # Simulate handle_input_moderation signalling to stop + patch.object( + runner, + "handle_input_moderation", + return_value=(True, runner.application_generate_entity.inputs, runner.application_generate_entity.query), + ) as mock_handle, + patch.object(runner, "_init_graph") as mock_init_graph, + patch.object(runner, "handle_annotation_reply") as mock_anno, + ): + runner.run() + + mock_handle.assert_called_once() + # Ensure no further steps executed + mock_anno.assert_not_called() + mock_init_graph.assert_not_called()