Approve?
", + user_actions=[UserAction(id="approve", title="Approve")], + ) + return FormCreateParams( + app_id=str(uuid4()), + workflow_execution_id=str(uuid4()), + node_id="human-input-node", + form_config=form_config, + rendered_content="Approve?
", + delivery_methods=delivery_methods, + display_in_ui=False, + resolved_default_values={}, + ) + + +def _build_email_delivery( + whole_workspace: bool, recipients: list[MemberRecipient | ExternalRecipient] +) -> EmailDeliveryMethod: + return EmailDeliveryMethod( + config=EmailDeliveryConfig( + recipients=EmailRecipients(whole_workspace=whole_workspace, items=recipients), + subject="Approval Needed", + body="Please review", + ) + ) + + +class TestHumanInputFormRepositoryImplWithContainers: + def test_create_form_with_whole_workspace_recipients(self, db_session_with_containers: Session) -> None: + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + tenant, members = _create_tenant_with_members( + db_session_with_containers, + member_emails=["member1@example.com", "member2@example.com"], + ) + + repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + params = _build_form_params( + delivery_methods=[_build_email_delivery(whole_workspace=True, recipients=[])], + ) + + form_entity = repository.create_form(params) + + with Session(engine) as verification_session: + recipients = verification_session.scalars( + select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_entity.id) + ).all() + + assert len(recipients) == len(members) + member_payloads = [ + EmailMemberRecipientPayload.model_validate_json(recipient.recipient_payload) + for recipient in recipients + if recipient.recipient_type == RecipientType.EMAIL_MEMBER + ] + member_emails = {payload.email for payload in member_payloads} + assert member_emails == {member.email for member in members} + + def test_create_form_with_specific_members_and_external(self, db_session_with_containers: Session) -> None: + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + tenant, members = _create_tenant_with_members( + db_session_with_containers, + member_emails=["primary@example.com", "secondary@example.com"], + ) + + repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + params = _build_form_params( + delivery_methods=[ + _build_email_delivery( + whole_workspace=False, + recipients=[ + MemberRecipient(user_id=members[0].id), + ExternalRecipient(email="external@example.com"), + ], + ) + ], + ) + + form_entity = repository.create_form(params) + + with Session(engine) as verification_session: + recipients = verification_session.scalars( + select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_entity.id) + ).all() + + member_recipient_payloads = [ + EmailMemberRecipientPayload.model_validate_json(recipient.recipient_payload) + for recipient in recipients + if recipient.recipient_type == RecipientType.EMAIL_MEMBER + ] + assert len(member_recipient_payloads) == 1 + assert member_recipient_payloads[0].user_id == members[0].id + + external_payloads = [ + EmailExternalRecipientPayload.model_validate_json(recipient.recipient_payload) + for recipient in recipients + if recipient.recipient_type == RecipientType.EMAIL_EXTERNAL + ] + assert len(external_payloads) == 1 + assert external_payloads[0].email == "external@example.com" + + def test_create_form_persists_default_values(self, db_session_with_containers: Session) -> None: + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + tenant, _ = _create_tenant_with_members( + db_session_with_containers, + member_emails=["prefill@example.com"], + ) + + repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + resolved_values = {"greeting": "Hello!"} + params = FormCreateParams( + app_id=str(uuid4()), + workflow_execution_id=str(uuid4()), + node_id="human-input-node", + form_config=HumanInputNodeData( + title="Human Approval", + form_content="Approve?
", + inputs=[], + user_actions=[UserAction(id="approve", title="Approve")], + ), + rendered_content="Approve?
", + delivery_methods=[], + display_in_ui=False, + resolved_default_values=resolved_values, + ) + + form_entity = repository.create_form(params) + + with Session(engine) as verification_session: + form_model = verification_session.scalars( + select(HumanInputForm).where(HumanInputForm.id == form_entity.id) + ).first() + + assert form_model is not None + definition = FormDefinition.model_validate_json(form_model.form_definition) + assert definition.default_values == resolved_values + + def test_create_form_persists_display_in_ui(self, db_session_with_containers: Session) -> None: + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + tenant, _ = _create_tenant_with_members( + db_session_with_containers, + member_emails=["ui@example.com"], + ) + + repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + params = FormCreateParams( + app_id=str(uuid4()), + workflow_execution_id=str(uuid4()), + node_id="human-input-node", + form_config=HumanInputNodeData( + title="Human Approval", + form_content="Approve?
", + inputs=[], + user_actions=[UserAction(id="approve", title="Approve")], + delivery_methods=[WebAppDeliveryMethod()], + ), + rendered_content="Approve?
", + delivery_methods=[WebAppDeliveryMethod()], + display_in_ui=True, + resolved_default_values={}, + ) + + form_entity = repository.create_form(params) + + with Session(engine) as verification_session: + form_model = verification_session.scalars( + select(HumanInputForm).where(HumanInputForm.id == form_entity.id) + ).first() + + assert form_model is not None + definition = FormDefinition.model_validate_json(form_model.form_definition) + assert definition.display_in_ui is True diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py new file mode 100644 index 0000000000..06d55177eb --- /dev/null +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -0,0 +1,336 @@ +import time +import uuid +from datetime import timedelta +from unittest.mock import MagicMock + +import pytest +from sqlalchemy import delete, select +from sqlalchemy.orm import Session + +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.entities import GraphInitParams +from core.workflow.enums import WorkflowType +from core.workflow.graph import Graph +from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction +from core.workflow.nodes.human_input.enums import HumanInputFormStatus +from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from libs.datetime_utils import naive_utc_now +from models import Account +from models.account import Tenant, TenantAccountJoin, TenantAccountRole +from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.model import App, AppMode, IconType +from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun + + +def _mock_form_repository_without_submission() -> HumanInputFormRepository: + repo = MagicMock(spec=HumanInputFormRepository) + form_entity = MagicMock(spec=HumanInputFormEntity) + form_entity.id = "test-form-id" + form_entity.web_app_token = "test-form-token" + form_entity.recipients = [] + form_entity.rendered_content = "rendered" + form_entity.submitted = False + repo.create_form.return_value = form_entity + repo.get_form.return_value = None + return repo + + +def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository: + repo = MagicMock(spec=HumanInputFormRepository) + form_entity = MagicMock(spec=HumanInputFormEntity) + form_entity.id = "test-form-id" + form_entity.web_app_token = "test-form-token" + form_entity.recipients = [] + form_entity.rendered_content = "rendered" + form_entity.submitted = True + form_entity.selected_action_id = action_id + form_entity.submitted_data = {} + form_entity.status = HumanInputFormStatus.WAITING + form_entity.expiration_time = naive_utc_now() + timedelta(hours=1) + repo.get_form.return_value = form_entity + return repo + + +def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState: + variable_pool = VariablePool( + system_variables=SystemVariable( + workflow_execution_id=workflow_execution_id, + app_id=app_id, + workflow_id=workflow_id, + user_id=user_id, + ), + user_inputs={}, + conversation_variables=[], + ) + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +def _build_graph( + runtime_state: GraphRuntimeState, + tenant_id: str, + app_id: str, + workflow_id: str, + user_id: str, + form_repository: HumanInputFormRepository, +) -> Graph: + graph_config: dict[str, object] = {"nodes": [], "edges": []} + params = GraphInitParams( + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + graph_config=graph_config, + user_id=user_id, + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + start_data = StartNodeData(title="start", variables=[]) + start_node = StartNode( + id="start", + config={"id": "start", "data": start_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + ) + + human_data = HumanInputNodeData( + title="human", + form_content="Awaiting human input", + inputs=[], + user_actions=[ + UserAction(id="continue", title="Continue"), + ], + ) + human_node = HumanInputNode( + id="human", + config={"id": "human", "data": human_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + form_repository=form_repository, + ) + + end_data = EndNodeData( + title="end", + outputs=[], + desc=None, + ) + end_node = EndNode( + id="end", + config={"id": "end", "data": end_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + ) + + return ( + Graph.new() + .add_root(start_node) + .add_node(human_node) + .add_node(end_node, from_node_id="human", source_handle="continue") + .build() + ) + + +def _build_generate_entity( + tenant_id: str, + app_id: str, + workflow_id: str, + workflow_execution_id: str, + user_id: str, +) -> WorkflowAppGenerateEntity: + app_config = WorkflowUIBasedAppConfig( + tenant_id=tenant_id, + app_id=app_id, + app_mode=AppMode.WORKFLOW, + workflow_id=workflow_id, + ) + return WorkflowAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + inputs={}, + files=[], + user_id=user_id, + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + workflow_execution_id=workflow_execution_id, + ) + + +class TestHumanInputResumeNodeExecutionIntegration: + @pytest.fixture(autouse=True) + def setup_test_data(self, db_session_with_containers: Session): + tenant = Tenant( + name="Test Tenant", + status="normal", + ) + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + account = Account( + email="test@example.com", + name="Test User", + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.commit() + + tenant_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(tenant_join) + db_session_with_containers.commit() + + account.current_tenant = tenant + + app = App( + tenant_id=tenant.id, + name="Test App", + description="", + mode=AppMode.WORKFLOW.value, + icon_type=IconType.EMOJI.value, + icon="rocket", + icon_background="#4ECDC4", + enable_site=False, + enable_api=False, + api_rpm=0, + api_rph=0, + is_demo=False, + is_public=False, + is_universal=False, + max_active_requests=None, + created_by=account.id, + updated_by=account.id, + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + + workflow = Workflow( + tenant_id=tenant.id, + app_id=app.id, + type="workflow", + version="draft", + graph='{"nodes": [], "edges": []}', + features='{"file_upload": {"enabled": false}}', + created_by=account.id, + created_at=naive_utc_now(), + ) + db_session_with_containers.add(workflow) + db_session_with_containers.commit() + + self.session = db_session_with_containers + self.tenant = tenant + self.account = account + self.app = app + self.workflow = workflow + + yield + + self.session.execute(delete(WorkflowNodeExecutionModel)) + self.session.execute(delete(WorkflowRun)) + self.session.execute(delete(Workflow).where(Workflow.id == self.workflow.id)) + self.session.execute(delete(App).where(App.id == self.app.id)) + self.session.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == self.tenant.id)) + self.session.execute(delete(Account).where(Account.id == self.account.id)) + self.session.execute(delete(Tenant).where(Tenant.id == self.tenant.id)) + self.session.commit() + + def _build_persistence_layer(self, execution_id: str) -> WorkflowPersistenceLayer: + generate_entity = _build_generate_entity( + tenant_id=self.tenant.id, + app_id=self.app.id, + workflow_id=self.workflow.id, + workflow_execution_id=execution_id, + user_id=self.account.id, + ) + execution_repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=self.session.get_bind(), + user=self.account, + app_id=self.app.id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + ) + node_execution_repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=self.session.get_bind(), + user=self.account, + app_id=self.app.id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + return WorkflowPersistenceLayer( + application_generate_entity=generate_entity, + workflow_info=PersistenceWorkflowInfo( + workflow_id=self.workflow.id, + workflow_type=WorkflowType.WORKFLOW, + version=self.workflow.version, + graph_data=self.workflow.graph_dict, + ), + workflow_execution_repository=execution_repo, + workflow_node_execution_repository=node_execution_repo, + ) + + def _run_graph(self, graph: Graph, runtime_state: GraphRuntimeState, execution_id: str) -> None: + engine = GraphEngine( + workflow_id=self.workflow.id, + graph=graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + engine.layer(self._build_persistence_layer(execution_id)) + for _ in engine.run(): + continue + + def test_resume_human_input_does_not_create_duplicate_node_execution(self): + execution_id = str(uuid.uuid4()) + runtime_state = _build_runtime_state( + workflow_execution_id=execution_id, + app_id=self.app.id, + workflow_id=self.workflow.id, + user_id=self.account.id, + ) + pause_repo = _mock_form_repository_without_submission() + paused_graph = _build_graph( + runtime_state, + self.tenant.id, + self.app.id, + self.workflow.id, + self.account.id, + pause_repo, + ) + self._run_graph(paused_graph, runtime_state, execution_id) + + snapshot = runtime_state.dumps() + resumed_state = GraphRuntimeState.from_snapshot(snapshot) + resume_repo = _mock_form_repository_with_submission(action_id="continue") + resumed_graph = _build_graph( + resumed_state, + self.tenant.id, + self.app.id, + self.workflow.id, + self.account.id, + resume_repo, + ) + self._run_graph(resumed_graph, resumed_state, execution_id) + + stmt = select(WorkflowNodeExecutionModel).where( + WorkflowNodeExecutionModel.workflow_run_id == execution_id, + WorkflowNodeExecutionModel.node_id == "human", + ) + records = self.session.execute(stmt).scalars().all() + assert len(records) == 1 + assert records[0].status != "paused" + assert records[0].triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN + assert records[0].created_by_role == CreatorUserRole.ACCOUNT diff --git a/api/tests/test_containers_integration_tests/helpers/__init__.py b/api/tests/test_containers_integration_tests/helpers/__init__.py new file mode 100644 index 0000000000..40d03889a9 --- /dev/null +++ b/api/tests/test_containers_integration_tests/helpers/__init__.py @@ -0,0 +1 @@ +"""Helper utilities for integration tests.""" diff --git a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py new file mode 100644 index 0000000000..19d7772c39 --- /dev/null +++ b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timedelta +from decimal import Decimal +from uuid import uuid4 + +from core.workflow.nodes.human_input.entities import FormDefinition, UserAction +from models.account import Account, Tenant, TenantAccountJoin +from models.execution_extra_content import HumanInputContent +from models.human_input import HumanInputForm, HumanInputFormStatus +from models.model import App, Conversation, Message + + +@dataclass +class HumanInputMessageFixture: + app: App + account: Account + conversation: Conversation + message: Message + form: HumanInputForm + action_id: str + action_text: str + node_title: str + + +def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture: + tenant = Tenant(name=f"Tenant {uuid4()}") + db_session.add(tenant) + db_session.flush() + + account = Account( + name=f"Account {uuid4()}", + email=f"human_input_{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session.add(account) + db_session.flush() + + tenant_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role="owner", + current=True, + ) + db_session.add(tenant_join) + db_session.flush() + + app = App( + tenant_id=tenant.id, + name=f"App {uuid4()}", + description="", + mode="chat", + icon_type="emoji", + icon="🤖", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=account.id, + updated_by=account.id, + ) + db_session.add(app) + db_session.flush() + + conversation = Conversation( + app_id=app.id, + mode="chat", + name="Test Conversation", + summary="", + introduction="", + system_instruction="", + status="normal", + invoke_from="console", + from_source="console", + from_account_id=account.id, + from_end_user_id=None, + ) + conversation.inputs = {} + db_session.add(conversation) + db_session.flush() + + workflow_run_id = str(uuid4()) + message = Message( + app_id=app.id, + conversation_id=conversation.id, + inputs={}, + query="Human input query", + message={"messages": []}, + answer="Human input answer", + message_tokens=50, + message_unit_price=Decimal("0.001"), + answer_tokens=80, + answer_unit_price=Decimal("0.001"), + provider_response_latency=0.5, + currency="USD", + from_source="console", + from_account_id=account.id, + workflow_run_id=workflow_run_id, + ) + db_session.add(message) + db_session.flush() + + action_id = "approve" + action_text = "Approve request" + node_title = "Approval" + form_definition = FormDefinition( + form_content="content", + inputs=[], + user_actions=[UserAction(id=action_id, title=action_text)], + rendered_content="Rendered block", + expiration_time=datetime.utcnow() + timedelta(days=1), + node_title=node_title, + display_in_ui=True, + ) + form = HumanInputForm( + tenant_id=tenant.id, + app_id=app.id, + workflow_run_id=workflow_run_id, + node_id="node-id", + form_definition=form_definition.model_dump_json(), + rendered_content="Rendered block", + status=HumanInputFormStatus.SUBMITTED, + expiration_time=datetime.utcnow() + timedelta(days=1), + selected_action_id=action_id, + ) + db_session.add(form) + db_session.flush() + + content = HumanInputContent( + workflow_run_id=workflow_run_id, + message_id=message.id, + form_id=form.id, + ) + db_session.add(content) + db_session.commit() + + return HumanInputMessageFixture( + app=app, + account=account, + conversation=conversation, + message=message, + form=form, + action_id=action_id, + action_text=action_text, + node_title=node_title, + ) diff --git a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py index d612e70910..43915a204d 100644 --- a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py +++ b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py @@ -16,6 +16,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import pytest import redis +from redis.cluster import RedisCluster from testcontainers.redis import RedisContainer from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic @@ -332,3 +333,95 @@ class TestShardedRedisBroadcastChannelIntegration: # Verify subscriptions are cleaned up topic_subscribers_after = self._get_sharded_numsub(redis_client, topic_name) assert topic_subscribers_after == 0 + + +class TestShardedRedisBroadcastChannelClusterIntegration: + """Integration tests for sharded pub/sub with RedisCluster client.""" + + @pytest.fixture(scope="class") + def redis_cluster_container(self) -> Iterator[RedisContainer]: + """Create a Redis 7 container with cluster mode enabled.""" + command = ( + "redis-server --port 6379 " + "--cluster-enabled yes " + "--cluster-config-file nodes.conf " + "--cluster-node-timeout 5000 " + "--appendonly no " + "--protected-mode no" + ) + with RedisContainer(image="redis:7-alpine").with_command(command) as container: + yield container + + @classmethod + def _get_test_topic_name(cls) -> str: + return f"test_sharded_cluster_topic_{uuid.uuid4()}" + + @staticmethod + def _ensure_single_node_cluster(host: str, port: int) -> None: + client = redis.Redis(host=host, port=port, decode_responses=False) + client.config_set("cluster-announce-ip", host) + client.config_set("cluster-announce-port", port) + slots = client.execute_command("CLUSTER", "SLOTS") + if not slots: + client.execute_command("CLUSTER", "ADDSLOTSRANGE", 0, 16383) + + deadline = time.time() + 5.0 + while time.time() < deadline: + info = client.execute_command("CLUSTER", "INFO") + info_text = info.decode("utf-8") if isinstance(info, (bytes, bytearray)) else str(info) + if "cluster_state:ok" in info_text: + return + time.sleep(0.05) + raise RuntimeError("Redis cluster did not become ready in time") + + @pytest.fixture(scope="class") + def redis_cluster_client(self, redis_cluster_container: RedisContainer) -> RedisCluster: + host = redis_cluster_container.get_container_host_ip() + port = int(redis_cluster_container.get_exposed_port(6379)) + self._ensure_single_node_cluster(host, port) + return RedisCluster(host=host, port=port, decode_responses=False) + + @pytest.fixture + def broadcast_channel(self, redis_cluster_client: RedisCluster) -> BroadcastChannel: + return ShardedRedisBroadcastChannel(redis_cluster_client) + + def test_cluster_sharded_pubsub_delivers_message(self, broadcast_channel: BroadcastChannel): + """Ensure sharded subscription receives messages when using RedisCluster client.""" + topic_name = self._get_test_topic_name() + message = b"cluster sharded message" + + topic = broadcast_channel.topic(topic_name) + producer = topic.as_producer() + subscription = topic.subscribe() + ready_event = threading.Event() + + def consumer_thread() -> list[bytes]: + received = [] + try: + _ = subscription.receive(0.01) + except SubscriptionClosedError: + return received + ready_event.set() + deadline = time.time() + 5.0 + while time.time() < deadline: + msg = subscription.receive(timeout=0.1) + if msg is None: + continue + received.append(msg) + break + subscription.close() + return received + + def producer_thread(): + if not ready_event.wait(timeout=2.0): + pytest.fail("subscriber did not become ready before publish") + producer.publish(message) + + with ThreadPoolExecutor(max_workers=2) as executor: + consumer_future = executor.submit(consumer_thread) + producer_future = executor.submit(producer_thread) + + producer_future.result(timeout=5.0) + received_messages = consumer_future.result(timeout=5.0) + + assert received_messages == [message] diff --git a/api/tests/test_containers_integration_tests/libs/test_rate_limiter_integration.py b/api/tests/test_containers_integration_tests/libs/test_rate_limiter_integration.py new file mode 100644 index 0000000000..178fc2e4fb --- /dev/null +++ b/api/tests/test_containers_integration_tests/libs/test_rate_limiter_integration.py @@ -0,0 +1,25 @@ +""" +Integration tests for RateLimiter using testcontainers Redis. +""" + +import uuid + +import pytest + +from extensions.ext_redis import redis_client +from libs import helper as helper_module + + +@pytest.mark.usefixtures("flask_app_with_containers") +def test_rate_limiter_counts_multiple_attempts_in_same_second(monkeypatch): + prefix = f"test_rate_limit:{uuid.uuid4().hex}" + limiter = helper_module.RateLimiter(prefix=prefix, max_attempts=2, time_window=60) + key = limiter._get_key("203.0.113.10") + + redis_client.delete(key) + monkeypatch.setattr(helper_module.time, "time", lambda: 1_700_000_000) + + limiter.increment_rate_limit("203.0.113.10") + limiter.increment_rate_limit("203.0.113.10") + + assert limiter.is_rate_limited("203.0.113.10") is True diff --git a/api/tests/test_containers_integration_tests/models/test_account.py b/api/tests/test_containers_integration_tests/models/test_account.py new file mode 100644 index 0000000000..078dc0e8de --- /dev/null +++ b/api/tests/test_containers_integration_tests/models/test_account.py @@ -0,0 +1,79 @@ +# import secrets + +# import pytest +# from sqlalchemy import select +# from sqlalchemy.orm import Session +# from sqlalchemy.orm.exc import DetachedInstanceError + +# from libs.datetime_utils import naive_utc_now +# from models.account import Account, Tenant, TenantAccountJoin + + +# @pytest.fixture +# def session(db_session_with_containers): +# with Session(db_session_with_containers.get_bind()) as session: +# yield session + + +# @pytest.fixture +# def account(session): +# account = Account( +# name="test account", +# email=f"test_{secrets.token_hex(8)}@example.com", +# ) +# session.add(account) +# session.commit() +# return account + + +# @pytest.fixture +# def tenant(session): +# tenant = Tenant(name="test tenant") +# session.add(tenant) +# session.commit() +# return tenant + + +# @pytest.fixture +# def tenant_account_join(session, account, tenant): +# tenant_join = TenantAccountJoin(account_id=account.id, tenant_id=tenant.id) +# session.add(tenant_join) +# session.commit() +# yield tenant_join +# session.delete(tenant_join) +# session.commit() + + +# class TestAccountTenant: +# def test_set_current_tenant_should_reload_tenant( +# self, +# db_session_with_containers, +# account, +# tenant, +# tenant_account_join, +# ): +# with Session(db_session_with_containers.get_bind(), expire_on_commit=True) as session: +# scoped_tenant = session.scalars(select(Tenant).where(Tenant.id == tenant.id)).one() +# account.current_tenant = scoped_tenant +# scoped_tenant.created_at = naive_utc_now() +# # session.commit() + +# # Ensure the tenant used in assignment is detached. +# with pytest.raises(DetachedInstanceError): +# _ = scoped_tenant.name + +# assert account._current_tenant.id == tenant.id +# assert account._current_tenant.id == tenant.id + +# def test_set_tenant_id_should_load_tenant_as_not_expire( +# self, +# flask_app_with_containers, +# account, +# tenant, +# tenant_account_join, +# ): +# with flask_app_with_containers.test_request_context(): +# account.set_tenant_id(tenant.id) + +# assert account._current_tenant.id == tenant.id +# assert account._current_tenant.id == tenant.id diff --git a/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py new file mode 100644 index 0000000000..c9058626d1 --- /dev/null +++ b/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from sqlalchemy.orm import sessionmaker + +from extensions.ext_database import db +from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository +from tests.test_containers_integration_tests.helpers.execution_extra_content import ( + create_human_input_message_fixture, +) + + +def test_get_by_message_ids_returns_human_input_content(db_session_with_containers): + fixture = create_human_input_message_fixture(db_session_with_containers) + repository = SQLAlchemyExecutionExtraContentRepository( + session_maker=sessionmaker(bind=db.engine, expire_on_commit=False) + ) + + results = repository.get_by_message_ids([fixture.message.id]) + + assert len(results) == 1 + assert len(results[0]) == 1 + content = results[0][0] + assert content.submitted is True + assert content.form_submission_data is not None + assert content.form_submission_data.action_id == fixture.action_id + assert content.form_submission_data.action_text == fixture.action_text + assert content.form_submission_data.rendered_content == fixture.form.rendered_content diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index a09a6e5c65..606e7e0b57 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -2309,6 +2309,12 @@ class TestRegisterService: mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + from extensions.ext_database import db + from models.model import DifySetup + + db.session.query(DifySetup).delete() + db.session.commit() + # Execute setup RegisterService.setup( email=admin_email, @@ -2319,9 +2325,7 @@ class TestRegisterService: ) # Verify account was created - from extensions.ext_database import db from models import Account - from models.model import DifySetup account = db.session.query(Account).filter_by(email=admin_email).first() assert account is not None diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 476f58585d..81bfa0ea20 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -1,5 +1,5 @@ import uuid -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from faker import Faker @@ -26,6 +26,7 @@ class TestAppGenerateService: patch("services.app_generate_service.AgentChatAppGenerator") as mock_agent_chat_generator, patch("services.app_generate_service.AdvancedChatAppGenerator") as mock_advanced_chat_generator, patch("services.app_generate_service.WorkflowAppGenerator") as mock_workflow_generator, + patch("services.app_generate_service.MessageBasedAppGenerator") as mock_message_based_generator, patch("services.account_service.FeatureService") as mock_account_feature_service, patch("services.app_generate_service.dify_config") as mock_dify_config, patch("configs.dify_config") as mock_global_dify_config, @@ -38,9 +39,13 @@ class TestAppGenerateService: # Setup default mock returns for workflow service mock_workflow_service_instance = mock_workflow_service.return_value - mock_workflow_service_instance.get_published_workflow.return_value = MagicMock(spec=Workflow) - mock_workflow_service_instance.get_draft_workflow.return_value = MagicMock(spec=Workflow) - mock_workflow_service_instance.get_published_workflow_by_id.return_value = MagicMock(spec=Workflow) + mock_published_workflow = MagicMock(spec=Workflow) + mock_published_workflow.id = str(uuid.uuid4()) + mock_workflow_service_instance.get_published_workflow.return_value = mock_published_workflow + mock_draft_workflow = MagicMock(spec=Workflow) + mock_draft_workflow.id = str(uuid.uuid4()) + mock_workflow_service_instance.get_draft_workflow.return_value = mock_draft_workflow + mock_workflow_service_instance.get_published_workflow_by_id.return_value = mock_published_workflow # Setup default mock returns for rate limiting mock_rate_limit_instance = mock_rate_limit.return_value @@ -66,6 +71,8 @@ class TestAppGenerateService: mock_advanced_chat_generator_instance.generate.return_value = ["advanced_chat_response"] mock_advanced_chat_generator_instance.single_iteration_generate.return_value = ["single_iteration_response"] mock_advanced_chat_generator_instance.single_loop_generate.return_value = ["single_loop_response"] + mock_advanced_chat_generator_instance.retrieve_events.return_value = ["advanced_chat_events"] + mock_advanced_chat_generator_instance.convert_to_event_stream.return_value = ["advanced_chat_stream"] mock_advanced_chat_generator.convert_to_event_stream.return_value = ["advanced_chat_stream"] mock_workflow_generator_instance = mock_workflow_generator.return_value @@ -76,6 +83,8 @@ class TestAppGenerateService: mock_workflow_generator_instance.single_loop_generate.return_value = ["workflow_single_loop_response"] mock_workflow_generator.convert_to_event_stream.return_value = ["workflow_stream"] + mock_message_based_generator.retrieve_events.return_value = ["workflow_events"] + # Setup default mock returns for account service mock_account_feature_service.get_system_features.return_value.is_allow_register = True @@ -88,6 +97,7 @@ class TestAppGenerateService: mock_global_dify_config.BILLING_ENABLED = False mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100 mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000 + mock_global_dify_config.HOSTED_POOL_CREDITS = 1000 yield { "billing_service": mock_billing_service, @@ -98,6 +108,7 @@ class TestAppGenerateService: "agent_chat_generator": mock_agent_chat_generator, "advanced_chat_generator": mock_advanced_chat_generator, "workflow_generator": mock_workflow_generator, + "message_based_generator": mock_message_based_generator, "account_feature_service": mock_account_feature_service, "dify_config": mock_dify_config, "global_dify_config": mock_global_dify_config, @@ -280,8 +291,10 @@ class TestAppGenerateService: assert result == ["test_response"] # Verify advanced chat generator was called - mock_external_service_dependencies["advanced_chat_generator"].return_value.generate.assert_called_once() - mock_external_service_dependencies["advanced_chat_generator"].convert_to_event_stream.assert_called_once() + mock_external_service_dependencies["advanced_chat_generator"].return_value.retrieve_events.assert_called_once() + mock_external_service_dependencies[ + "advanced_chat_generator" + ].return_value.convert_to_event_stream.assert_called_once() def test_generate_workflow_mode_success(self, db_session_with_containers, mock_external_service_dependencies): """ @@ -304,7 +317,7 @@ class TestAppGenerateService: assert result == ["test_response"] # Verify workflow generator was called - mock_external_service_dependencies["workflow_generator"].return_value.generate.assert_called_once() + mock_external_service_dependencies["message_based_generator"].retrieve_events.assert_called_once() mock_external_service_dependencies["workflow_generator"].convert_to_event_stream.assert_called_once() def test_generate_with_specific_workflow_id(self, db_session_with_containers, mock_external_service_dependencies): @@ -970,14 +983,27 @@ class TestAppGenerateService: } # Execute the method under test - result = AppGenerateService.generate( - app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True - ) + with patch("services.app_generate_service.AppExecutionParams") as mock_exec_params: + mock_payload = MagicMock() + mock_payload.workflow_run_id = fake.uuid4() + mock_payload.model_dump_json.return_value = "{}" + mock_exec_params.new.return_value = mock_payload + + result = AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) # Verify the result assert result == ["test_response"] - # Verify workflow generator was called with complex args - mock_external_service_dependencies["workflow_generator"].return_value.generate.assert_called_once() - call_args = mock_external_service_dependencies["workflow_generator"].return_value.generate.call_args - assert call_args[1]["args"] == args + # Verify payload was built with complex args + mock_exec_params.new.assert_called_once() + call_kwargs = mock_exec_params.new.call_args.kwargs + assert call_kwargs["args"] == args + + # Verify workflow streaming event retrieval was used + mock_external_service_dependencies["message_based_generator"].retrieve_events.assert_called_once_with( + ANY, + mock_payload.workflow_run_id, + on_subscribe=ANY, + ) diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py new file mode 100644 index 0000000000..9c978f830f --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -0,0 +1,112 @@ +import json +import uuid +from unittest.mock import MagicMock + +import pytest + +from core.workflow.enums import NodeType +from core.workflow.nodes.human_input.entities import ( + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + ExternalRecipient, + HumanInputNodeData, +) +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.model import App, AppMode +from models.workflow import Workflow, WorkflowType +from services.workflow_service import WorkflowService + + +def _create_app_with_draft_workflow(session, *, delivery_method_id: uuid.UUID) -> tuple[App, Account]: + tenant = Tenant(name="Test Tenant") + account = Account(name="Tester", email="tester@example.com") + session.add_all([tenant, account]) + session.flush() + + session.add( + TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + current=True, + role=TenantAccountRole.OWNER.value, + ) + ) + + app = App( + tenant_id=tenant.id, + name="Test App", + description="", + mode=AppMode.WORKFLOW.value, + icon_type="emoji", + icon="app", + icon_background="#ffffff", + enable_site=True, + enable_api=True, + created_by=account.id, + updated_by=account.id, + ) + session.add(app) + session.flush() + + email_method = EmailDeliveryMethod( + id=delivery_method_id, + enabled=True, + config=EmailDeliveryConfig( + recipients=EmailRecipients( + whole_workspace=False, + items=[ExternalRecipient(email="recipient@example.com")], + ), + subject="Test {{recipient_email}}", + body="Body {{#url#}} {{form_content}}", + ), + ) + node_data = HumanInputNodeData( + title="Human Input", + delivery_methods=[email_method], + form_content="Hello Human Input", + inputs=[], + user_actions=[], + ).model_dump(mode="json") + node_data["type"] = NodeType.HUMAN_INPUT.value + graph = json.dumps({"nodes": [{"id": "human-node", "data": node_data}], "edges": []}) + + workflow = Workflow.new( + tenant_id=tenant.id, + app_id=app.id, + type=WorkflowType.WORKFLOW.value, + version=Workflow.VERSION_DRAFT, + graph=graph, + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + session.add(workflow) + session.commit() + + return app, account + + +def test_human_input_delivery_test_sends_email( + db_session_with_containers, + monkeypatch: pytest.MonkeyPatch, +) -> None: + delivery_method_id = uuid.uuid4() + app, account = _create_app_with_draft_workflow(db_session_with_containers, delivery_method_id=delivery_method_id) + + send_mock = MagicMock() + monkeypatch.setattr("services.human_input_delivery_test_service.mail.is_inited", lambda: True) + monkeypatch.setattr("services.human_input_delivery_test_service.mail.send", send_mock) + + service = WorkflowService() + service.test_human_input_delivery( + app_model=app, + account=account, + node_id="human-node", + delivery_method_id=str(delivery_method_id), + ) + + assert send_mock.call_count == 1 + assert send_mock.call_args.kwargs["to"] == "recipient@example.com" diff --git a/api/tests/test_containers_integration_tests/services/test_message_service_execution_extra_content.py b/api/tests/test_containers_integration_tests/services/test_message_service_execution_extra_content.py new file mode 100644 index 0000000000..44e5a82868 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_message_service_execution_extra_content.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import pytest + +from services.message_service import MessageService +from tests.test_containers_integration_tests.helpers.execution_extra_content import ( + create_human_input_message_fixture, +) + + +@pytest.mark.usefixtures("flask_req_ctx_with_containers") +def test_pagination_returns_extra_contents(db_session_with_containers): + fixture = create_human_input_message_fixture(db_session_with_containers) + + pagination = MessageService.pagination_by_first_id( + app_model=fixture.app, + user=fixture.account, + conversation_id=fixture.conversation.id, + first_id=None, + limit=10, + ) + + assert pagination.data + message = pagination.data[0] + assert message.extra_contents == [ + { + "type": "human_input", + "workflow_run_id": fixture.message.workflow_run_id, + "submitted": True, + "form_submission_data": { + "node_id": fixture.form.node_id, + "node_title": fixture.node_title, + "rendered_content": fixture.form.rendered_content, + "action_id": fixture.action_id, + "action_text": fixture.action_text, + }, + } + ] diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index 23c4eeb82f..3a88081db3 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -465,6 +465,27 @@ class TestWorkflowRunService: db.session.add(node_execution) node_executions.append(node_execution) + paused_node_execution = WorkflowNodeExecutionModel( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow_run.workflow_id, + triggered_from="workflow-run", + workflow_run_id=workflow_run.id, + index=99, + node_id="node_paused", + node_type="human_input", + title="Paused Node", + inputs=json.dumps({"input": "paused"}), + process_data=json.dumps({"process": "paused"}), + status="paused", + elapsed_time=0.5, + execution_metadata=json.dumps({"tokens": 0}), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(paused_node_execution) + db.session.commit() # Act: Execute the method under test @@ -473,16 +494,19 @@ class TestWorkflowRunService: # Assert: Verify the expected outcomes assert result is not None - assert len(result) == 3 + assert len(result) == 4 # Verify node execution properties + statuses = [node_execution.status for node_execution in result] + assert "paused" in statuses + assert statuses.count("succeeded") == 3 + assert statuses.count("paused") == 1 + for node_execution in result: assert node_execution.tenant_id == app.tenant_id assert node_execution.app_id == app.id assert node_execution.workflow_run_id == workflow_run.id - assert node_execution.index in [0, 1, 2] # Check that index is one of the expected values - assert node_execution.node_id.startswith("node_") # Check that node_id starts with "node_" - assert node_execution.status == "succeeded" + assert node_execution.node_id.startswith("node_") def test_get_workflow_run_node_executions_empty( self, db_session_with_containers, mock_external_service_dependencies diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 3c0a660e7c..24fe5c4670 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -6,6 +6,7 @@ from faker import Faker from pydantic import ValidationError from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration +from core.tools.errors import WorkflowToolHumanInputNotSupportedError from models.tools import WorkflowToolProvider from models.workflow import Workflow as WorkflowModel from services.account_service import AccountService, TenantService @@ -513,6 +514,62 @@ class TestWorkflowToolManageService: assert tool_count == 0 + def test_create_workflow_tool_human_input_node_error( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test workflow tool creation fails when workflow contains human input nodes. + + This test verifies: + - Human input nodes prevent workflow tool publishing + - Correct error message + - No database changes when workflow is invalid + """ + fake = Faker() + + # Create test data + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + workflow.graph = json.dumps( + { + "nodes": [ + { + "id": "human_input_node", + "data": {"type": "human-input"}, + } + ] + } + ) + + tool_parameters = self._create_test_workflow_tool_parameters() + with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: + WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=fake.word(), + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=tool_parameters, + ) + + assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" + + from extensions.ext_database import db + + tool_count = ( + db.session.query(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == account.current_tenant.id, + ) + .count() + ) + + assert tool_count == 0 + def test_update_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): """ Test successful workflow tool update with valid parameters. @@ -600,6 +657,80 @@ class TestWorkflowToolManageService: mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called() mock_external_service_dependencies["tool_transform_service"].workflow_provider_to_controller.assert_called() + def test_update_workflow_tool_human_input_node_error( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test workflow tool update fails when workflow contains human input nodes. + + This test verifies: + - Human input nodes prevent workflow tool updates + - Correct error message + - Existing tool data remains unchanged + """ + fake = Faker() + + # Create test data + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create initial workflow tool + initial_tool_name = fake.word() + initial_tool_parameters = self._create_test_workflow_tool_parameters() + WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=initial_tool_name, + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=initial_tool_parameters, + ) + + from extensions.ext_database import db + + created_tool = ( + db.session.query(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == account.current_tenant.id, + WorkflowToolProvider.app_id == app.id, + ) + .first() + ) + + original_name = created_tool.name + + workflow.graph = json.dumps( + { + "nodes": [ + { + "id": "human_input_node", + "data": {"type": "human-input"}, + } + ] + } + ) + db.session.commit() + + with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: + WorkflowToolManageService.update_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_tool_id=created_tool.id, + name=fake.word(), + label=fake.word(), + icon={"type": "emoji", "emoji": "⚙️"}, + description=fake.text(max_nb_chars=200), + parameters=initial_tool_parameters, + ) + + assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" + + db.session.refresh(created_tool) + assert created_tool.name == original_name + def test_update_workflow_tool_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): """ Test workflow tool update fails when tool does not exist. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py new file mode 100644 index 0000000000..5fd6c56f7a --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -0,0 +1,214 @@ +import uuid +from datetime import UTC, datetime +from unittest.mock import patch + +import pytest + +from configs import dify_config +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext +from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl +from core.workflow.enums import WorkflowExecutionStatus +from core.workflow.nodes.human_input.entities import ( + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + ExternalRecipient, + HumanInputNodeData, + MemberRecipient, +) +from core.workflow.runtime import GraphRuntimeState, VariablePool +from extensions.ext_storage import storage +from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole +from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.human_input import HumanInputDelivery, HumanInputForm, HumanInputFormRecipient +from models.model import AppMode +from models.workflow import WorkflowPause, WorkflowRun, WorkflowType +from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task + + +@pytest.fixture(autouse=True) +def cleanup_database(db_session_with_containers): + db_session_with_containers.query(HumanInputFormRecipient).delete() + db_session_with_containers.query(HumanInputDelivery).delete() + db_session_with_containers.query(HumanInputForm).delete() + db_session_with_containers.query(WorkflowPause).delete() + db_session_with_containers.query(WorkflowRun).delete() + db_session_with_containers.query(TenantAccountJoin).delete() + db_session_with_containers.query(Tenant).delete() + db_session_with_containers.query(Account).delete() + db_session_with_containers.commit() + + +def _create_workspace_member(db_session_with_containers): + account = Account( + email="owner@example.com", + name="Owner", + password="password", + interface_language="en-US", + status=AccountStatus.ACTIVE, + ) + account.created_at = datetime.now(UTC) + account.updated_at = datetime.now(UTC) + db_session_with_containers.add(account) + db_session_with_containers.commit() + db_session_with_containers.refresh(account) + + tenant = Tenant(name="Test Tenant") + tenant.created_at = datetime.now(UTC) + tenant.updated_at = datetime.now(UTC) + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + db_session_with_containers.refresh(tenant) + + tenant_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + ) + tenant_join.created_at = datetime.now(UTC) + tenant_join.updated_at = datetime.now(UTC) + db_session_with_containers.add(tenant_join) + db_session_with_containers.commit() + + return tenant, account + + +def _build_form(db_session_with_containers, tenant, account, *, app_id: str, workflow_execution_id: str): + delivery_method = EmailDeliveryMethod( + config=EmailDeliveryConfig( + recipients=EmailRecipients( + whole_workspace=False, + items=[ + MemberRecipient(user_id=account.id), + ExternalRecipient(email="external@example.com"), + ], + ), + subject="Action needed {{ node_title }} {{#node1.value#}}", + body="Token {{ form_token }} link {{#url#}} content {{#node1.value#}}", + ) + ) + + node_data = HumanInputNodeData( + title="Review", + form_content="Form content", + delivery_methods=[delivery_method], + ) + + engine = db_session_with_containers.get_bind() + repo = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + params = FormCreateParams( + app_id=app_id, + workflow_execution_id=workflow_execution_id, + node_id="node-1", + form_config=node_data, + rendered_content="Rendered", + delivery_methods=node_data.delivery_methods, + display_in_ui=False, + resolved_default_values={}, + ) + return repo.create_form(params) + + +def _create_workflow_pause_state( + db_session_with_containers, + *, + workflow_run_id: str, + workflow_id: str, + tenant_id: str, + app_id: str, + account_id: str, + variable_pool: VariablePool, +): + workflow_run = WorkflowRun( + id=workflow_run_id, + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + type=WorkflowType.WORKFLOW, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + version="1", + graph="{}", + inputs="{}", + status=WorkflowExecutionStatus.PAUSED, + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account_id, + created_at=datetime.now(UTC), + ) + db_session_with_containers.add(workflow_run) + + runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) + resumption_context = WorkflowResumptionContext( + generate_entity={ + "type": AppMode.WORKFLOW, + "entity": WorkflowAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=WorkflowUIBasedAppConfig( + tenant_id=tenant_id, + app_id=app_id, + app_mode=AppMode.WORKFLOW, + workflow_id=workflow_id, + ), + inputs={}, + files=[], + user_id=account_id, + stream=False, + invoke_from=InvokeFrom.WEB_APP, + workflow_execution_id=workflow_run_id, + ), + }, + serialized_graph_runtime_state=runtime_state.dumps(), + ) + + state_object_key = f"workflow_pause_states/{workflow_run_id}.json" + storage.save(state_object_key, resumption_context.dumps().encode()) + + pause_state = WorkflowPause( + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + state_object_key=state_object_key, + ) + db_session_with_containers.add(pause_state) + db_session_with_containers.commit() + + +def test_dispatch_human_input_email_task_integration(monkeypatch: pytest.MonkeyPatch, db_session_with_containers): + tenant, account = _create_workspace_member(db_session_with_containers) + workflow_run_id = str(uuid.uuid4()) + workflow_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + variable_pool = VariablePool() + variable_pool.add(["node1", "value"], "OK") + _create_workflow_pause_state( + db_session_with_containers, + workflow_run_id=workflow_run_id, + workflow_id=workflow_id, + tenant_id=tenant.id, + app_id=app_id, + account_id=account.id, + variable_pool=variable_pool, + ) + form_entity = _build_form( + db_session_with_containers, + tenant, + account, + app_id=app_id, + workflow_execution_id=workflow_run_id, + ) + + monkeypatch.setattr(dify_config, "APP_WEB_URL", "https://app.example.com") + + with patch("tasks.mail_human_input_delivery_task.mail") as mock_mail: + mock_mail.is_inited.return_value = True + + dispatch_human_input_email_task(form_id=form_entity.id, node_title="Approval") + + assert mock_mail.send.call_count == 2 + send_args = [call.kwargs for call in mock_mail.send.call_args_list] + recipients = {kwargs["to"] for kwargs in send_args} + assert recipients == {"owner@example.com", "external@example.com"} + assert all(kwargs["subject"] == "Action needed {{ node_title }} {{#node1.value#}}" for kwargs in send_args) + assert all("app.example.com/form/" in kwargs["html"] for kwargs in send_args) + assert all("content OK" in kwargs["html"] for kwargs in send_args) + assert all("{{ form_token }}" in kwargs["html"] for kwargs in send_args) diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index 889e3d1d83..5f4f28cf4f 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -94,11 +94,6 @@ class PrunePausesTestCase: def pause_workflow_failure_cases() -> list[PauseWorkflowFailureCase]: """Create test cases for pause workflow failure scenarios.""" return [ - PauseWorkflowFailureCase( - name="pause_already_paused_workflow", - initial_status=WorkflowExecutionStatus.PAUSED, - description="Should fail to pause an already paused workflow", - ), PauseWorkflowFailureCase( name="pause_completed_workflow", initial_status=WorkflowExecutionStatus.SUCCEEDED, diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index 6fce7849f9..cf52980e57 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -164,6 +164,62 @@ def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch): assert "timezone=UTC" in options +def test_pubsub_redis_url_default(monkeypatch: pytest.MonkeyPatch): + os.environ.clear() + + monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") + monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + monkeypatch.setenv("REDIS_HOST", "redis.example.com") + monkeypatch.setenv("REDIS_PORT", "6380") + monkeypatch.setenv("REDIS_USERNAME", "user") + monkeypatch.setenv("REDIS_PASSWORD", "pass@word") + monkeypatch.setenv("REDIS_DB", "2") + monkeypatch.setenv("REDIS_USE_SSL", "true") + + config = DifyConfig() + + assert config.normalized_pubsub_redis_url == "rediss://user:pass%40word@redis.example.com:6380/2" + assert config.PUBSUB_REDIS_CHANNEL_TYPE == "pubsub" + + +def test_pubsub_redis_url_override(monkeypatch: pytest.MonkeyPatch): + os.environ.clear() + + monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") + monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + monkeypatch.setenv("PUBSUB_REDIS_URL", "redis://pubsub-host:6381/5") + + config = DifyConfig() + + assert config.normalized_pubsub_redis_url == "redis://pubsub-host:6381/5" + + +def test_pubsub_redis_url_required_when_default_unavailable(monkeypatch: pytest.MonkeyPatch): + os.environ.clear() + + monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") + monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + monkeypatch.setenv("REDIS_HOST", "") + + with pytest.raises(ValueError, match="PUBSUB_REDIS_URL must be set"): + _ = DifyConfig().normalized_pubsub_redis_url + + @pytest.mark.parametrize( ("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"), [ diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index e3c1a617f7..da957d3a81 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -51,6 +51,8 @@ def _patch_redis_clients_on_loaded_modules(): continue if hasattr(module, "redis_client"): module.redis_client = redis_mock + if hasattr(module, "pubsub_redis_client"): + module.pubsub_redis_client = redis_mock @pytest.fixture @@ -68,7 +70,10 @@ def _provide_app_context(app: Flask): def _patch_redis_clients(): """Patch redis_client to MagicMock only for unit test executions.""" - with patch.object(ext_redis, "redis_client", redis_mock): + with ( + patch.object(ext_redis, "redis_client", redis_mock), + patch.object(ext_redis, "pubsub_redis_client", redis_mock), + ): _patch_redis_clients_on_loaded_modules() yield diff --git a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py index c557605916..2ac3dc037d 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py @@ -16,11 +16,9 @@ if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] -def _load_app_module(): +@pytest.fixture(scope="module") +def app_module(): module_name = "controllers.console.app.app" - if module_name in sys.modules: - return sys.modules[module_name] - root = Path(__file__).resolve().parents[5] module_path = root / "controllers" / "console" / "app" / "app.py" @@ -59,8 +57,12 @@ def _load_app_module(): stub_namespace = _StubNamespace() - original_console = sys.modules.get("controllers.console") - original_app_pkg = sys.modules.get("controllers.console.app") + original_modules: dict[str, ModuleType | None] = { + "controllers.console": sys.modules.get("controllers.console"), + "controllers.console.app": sys.modules.get("controllers.console.app"), + "controllers.common.schema": sys.modules.get("controllers.common.schema"), + module_name: sys.modules.get(module_name), + } stubbed_modules: list[tuple[str, ModuleType | None]] = [] console_module = ModuleType("controllers.console") @@ -105,35 +107,35 @@ def _load_app_module(): module = util.module_from_spec(spec) sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + try: - assert spec.loader is not None - spec.loader.exec_module(module) + yield module finally: for name, original in reversed(stubbed_modules): if original is not None: sys.modules[name] = original else: sys.modules.pop(name, None) - if original_console is not None: - sys.modules["controllers.console"] = original_console - else: - sys.modules.pop("controllers.console", None) - if original_app_pkg is not None: - sys.modules["controllers.console.app"] = original_app_pkg - else: - sys.modules.pop("controllers.console.app", None) - - return module + for name, original in original_modules.items(): + if original is not None: + sys.modules[name] = original + else: + sys.modules.pop(name, None) -_app_module = _load_app_module() -AppDetailWithSite = _app_module.AppDetailWithSite -AppPagination = _app_module.AppPagination -AppPartial = _app_module.AppPartial +@pytest.fixture(scope="module") +def app_models(app_module): + return SimpleNamespace( + AppDetailWithSite=app_module.AppDetailWithSite, + AppPagination=app_module.AppPagination, + AppPartial=app_module.AppPartial, + ) @pytest.fixture(autouse=True) -def patch_signed_url(monkeypatch): +def patch_signed_url(monkeypatch, app_module): """Ensure icon URL generation uses a deterministic helper for tests.""" def _fake_signed_url(key: str | None) -> str | None: @@ -141,7 +143,7 @@ def patch_signed_url(monkeypatch): return None return f"signed:{key}" - monkeypatch.setattr(_app_module.file_helpers, "get_signed_file_url", _fake_signed_url) + monkeypatch.setattr(app_module.file_helpers, "get_signed_file_url", _fake_signed_url) def _ts(hour: int = 12) -> datetime: @@ -169,7 +171,8 @@ def _dummy_workflow(): ) -def test_app_partial_serialization_uses_aliases(): +def test_app_partial_serialization_uses_aliases(app_models): + AppPartial = app_models.AppPartial created_at = _ts() app_obj = SimpleNamespace( id="app-1", @@ -204,7 +207,8 @@ def test_app_partial_serialization_uses_aliases(): assert serialized["tags"][0]["name"] == "Utilities" -def test_app_detail_with_site_includes_nested_serialization(): +def test_app_detail_with_site_includes_nested_serialization(app_models): + AppDetailWithSite = app_models.AppDetailWithSite timestamp = _ts(14) site = SimpleNamespace( code="site-code", @@ -253,7 +257,8 @@ def test_app_detail_with_site_includes_nested_serialization(): assert serialized["site"]["created_at"] == int(timestamp.timestamp()) -def test_app_pagination_aliases_per_page_and_has_next(): +def test_app_pagination_aliases_per_page_and_has_next(app_models): + AppPagination = app_models.AppPagination item_one = SimpleNamespace( id="app-10", name="Paginated One", diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py new file mode 100644 index 0000000000..86a3b2bd93 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py @@ -0,0 +1,229 @@ +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from flask import Flask +from pydantic import ValidationError + +from controllers.console import wraps as console_wraps +from controllers.console.app import workflow as workflow_module +from controllers.console.app import wraps as app_wraps +from libs import login as login_lib +from models.account import Account, AccountStatus, TenantAccountRole +from models.model import AppMode + + +def _make_account() -> Account: + account = Account(name="tester", email="tester@example.com") + account.status = AccountStatus.ACTIVE + account.role = TenantAccountRole.OWNER + account.id = "account-123" # type: ignore[assignment] + account._current_tenant = SimpleNamespace(id="tenant-123") # type: ignore[attr-defined] + account._get_current_object = lambda: account # type: ignore[attr-defined] + return account + + +def _make_app(mode: AppMode) -> SimpleNamespace: + return SimpleNamespace(id="app-123", tenant_id="tenant-123", mode=mode.value) + + +def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app_model: SimpleNamespace) -> None: + # Skip setup and auth guardrails + monkeypatch.setattr("configs.dify_config.EDITION", "CLOUD") + monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True) + monkeypatch.setattr(login_lib, "current_user", account) + monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None) + monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD") + monkeypatch.delenv("INIT_PASSWORD", raising=False) + + # Avoid hitting the database when resolving the app model + monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model) + + +@dataclass +class PreviewCase: + resource_cls: type + path: str + mode: AppMode + + +@pytest.mark.parametrize( + "case", + [ + PreviewCase( + resource_cls=workflow_module.AdvancedChatDraftHumanInputFormPreviewApi, + path="/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-42/form/preview", + mode=AppMode.ADVANCED_CHAT, + ), + PreviewCase( + resource_cls=workflow_module.WorkflowDraftHumanInputFormPreviewApi, + path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-42/form/preview", + mode=AppMode.WORKFLOW, + ), + ], +) +def test_human_input_preview_delegates_to_service( + app: Flask, monkeypatch: pytest.MonkeyPatch, case: PreviewCase +) -> None: + account = _make_account() + app_model = _make_app(case.mode) + _patch_console_guards(monkeypatch, account, app_model) + + preview_payload = { + "form_id": "node-42", + "form_content": "done
"}, "action": "approve"} + service_instance = MagicMock() + service_instance.submit_human_input_form_preview.return_value = result_payload + monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance)) + + with app.test_request_context( + case.path, + method="POST", + json={"form_inputs": {"answer": "42"}, "inputs": {"#node-1.result#": "LLM output"}, "action": "approve"}, + ): + response = case.resource_cls().post(app_id=app_model.id, node_id="node-99") + + assert response == result_payload + service_instance.submit_human_input_form_preview.assert_called_once_with( + app_model=app_model, + account=account, + node_id="node-99", + form_inputs={"answer": "42"}, + inputs={"#node-1.result#": "LLM output"}, + action="approve", + ) + + +@dataclass +class DeliveryTestCase: + resource_cls: type + path: str + mode: AppMode + + +@pytest.mark.parametrize( + "case", + [ + DeliveryTestCase( + resource_cls=workflow_module.WorkflowDraftHumanInputDeliveryTestApi, + path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-7/delivery-test", + mode=AppMode.ADVANCED_CHAT, + ), + DeliveryTestCase( + resource_cls=workflow_module.WorkflowDraftHumanInputDeliveryTestApi, + path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-7/delivery-test", + mode=AppMode.WORKFLOW, + ), + ], +) +def test_human_input_delivery_test_calls_service( + app: Flask, monkeypatch: pytest.MonkeyPatch, case: DeliveryTestCase +) -> None: + account = _make_account() + app_model = _make_app(case.mode) + _patch_console_guards(monkeypatch, account, app_model) + + service_instance = MagicMock() + monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance)) + + with app.test_request_context( + case.path, + method="POST", + json={"delivery_method_id": "delivery-123"}, + ): + response = case.resource_cls().post(app_id=app_model.id, node_id="node-7") + + assert response == {} + service_instance.test_human_input_delivery.assert_called_once_with( + app_model=app_model, + account=account, + node_id="node-7", + delivery_method_id="delivery-123", + inputs={}, + ) + + +def test_human_input_delivery_test_maps_validation_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + account = _make_account() + app_model = _make_app(AppMode.ADVANCED_CHAT) + _patch_console_guards(monkeypatch, account, app_model) + + service_instance = MagicMock() + service_instance.test_human_input_delivery.side_effect = ValueError("bad delivery method") + monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance)) + + with app.test_request_context( + "/console/api/apps/app-123/workflows/draft/human-input/nodes/node-1/delivery-test", + method="POST", + json={"delivery_method_id": "bad"}, + ): + with pytest.raises(ValueError): + workflow_module.WorkflowDraftHumanInputDeliveryTestApi().post(app_id=app_model.id, node_id="node-1") + + +def test_human_input_preview_rejects_non_mapping(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + account = _make_account() + app_model = _make_app(AppMode.ADVANCED_CHAT) + _patch_console_guards(monkeypatch, account, app_model) + + with app.test_request_context( + "/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-1/form/preview", + method="POST", + json={"inputs": ["not-a-dict"]}, + ): + with pytest.raises(ValidationError): + workflow_module.AdvancedChatDraftHumanInputFormPreviewApi().post(app_id=app_model.id, node_id="node-1") diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py new file mode 100644 index 0000000000..f9788e2e50 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest +from flask import Flask + +from controllers.console import wraps as console_wraps +from controllers.console.app import workflow_run as workflow_run_module +from controllers.web.error import NotFoundError +from core.workflow.entities.pause_reason import HumanInputRequired +from core.workflow.enums import WorkflowExecutionStatus +from core.workflow.nodes.human_input.entities import FormInput, UserAction +from core.workflow.nodes.human_input.enums import FormInputType +from libs import login as login_lib +from models.account import Account, AccountStatus, TenantAccountRole +from models.workflow import WorkflowRun + + +def _make_account() -> Account: + account = Account(name="tester", email="tester@example.com") + account.status = AccountStatus.ACTIVE + account.role = TenantAccountRole.OWNER + account.id = "account-123" # type: ignore[assignment] + account._current_tenant = SimpleNamespace(id="tenant-123") # type: ignore[attr-defined] + account._get_current_object = lambda: account # type: ignore[attr-defined] + return account + + +def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account) -> None: + monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True) + monkeypatch.setattr(login_lib, "current_user", account) + monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None) + monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(workflow_run_module, "current_user", account) + monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD") + + +class _PauseEntity: + def __init__(self, paused_at: datetime, reasons: list[HumanInputRequired]): + self.paused_at = paused_at + self._reasons = reasons + + def get_pause_reasons(self): + return self._reasons + + +def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + account = _make_account() + _patch_console_guards(monkeypatch, account) + monkeypatch.setattr(workflow_run_module.dify_config, "APP_WEB_URL", "https://web.example.com") + + workflow_run = Mock(spec=WorkflowRun) + workflow_run.tenant_id = "tenant-123" + workflow_run.status = WorkflowExecutionStatus.PAUSED + workflow_run.created_at = datetime(2024, 1, 1, 12, 0, 0) + fake_db = SimpleNamespace(engine=Mock(), session=SimpleNamespace(get=lambda *_: workflow_run)) + monkeypatch.setattr(workflow_run_module, "db", fake_db) + + reason = HumanInputRequired( + form_id="form-1", + form_content="content", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], + actions=[UserAction(id="approve", title="Approve")], + node_id="node-1", + node_title="Ask Name", + form_token="backstage-token", + ) + pause_entity = _PauseEntity(paused_at=datetime(2024, 1, 1, 12, 0, 0), reasons=[reason]) + + repo = Mock() + repo.get_workflow_pause.return_value = pause_entity + monkeypatch.setattr( + workflow_run_module.DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda *_, **__: repo, + ) + + with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"): + response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1") + + assert status == 200 + assert response["paused_at"] == "2024-01-01T12:00:00Z" + assert response["paused_nodes"][0]["node_id"] == "node-1" + assert response["paused_nodes"][0]["pause_type"]["type"] == "human_input" + assert ( + response["paused_nodes"][0]["pause_type"]["backstage_input_url"] + == "https://web.example.com/form/backstage-token" + ) + assert "pending_human_inputs" not in response + + +def test_pause_details_tenant_isolation(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + account = _make_account() + _patch_console_guards(monkeypatch, account) + monkeypatch.setattr(workflow_run_module.dify_config, "APP_WEB_URL", "https://web.example.com") + + workflow_run = Mock(spec=WorkflowRun) + workflow_run.tenant_id = "tenant-456" + workflow_run.status = WorkflowExecutionStatus.PAUSED + workflow_run.created_at = datetime(2024, 1, 1, 12, 0, 0) + fake_db = SimpleNamespace(engine=Mock(), session=SimpleNamespace(get=lambda *_: workflow_run)) + monkeypatch.setattr(workflow_run_module, "db", fake_db) + + with pytest.raises(NotFoundError): + with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"): + response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1") diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py new file mode 100644 index 0000000000..fcaa61a871 --- /dev/null +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py @@ -0,0 +1,25 @@ +from types import SimpleNamespace + +from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField +from core.workflow.enums import WorkflowExecutionStatus + + +def test_workflow_run_status_field_with_enum() -> None: + field = WorkflowRunStatusField() + obj = SimpleNamespace(status=WorkflowExecutionStatus.PAUSED) + + assert field.output("status", obj) == "paused" + + +def test_workflow_run_outputs_field_paused_returns_empty() -> None: + field = WorkflowRunOutputsField() + obj = SimpleNamespace(status=WorkflowExecutionStatus.PAUSED, outputs_dict={"foo": "bar"}) + + assert field.output("outputs", obj) == {} + + +def test_workflow_run_outputs_field_running_returns_outputs() -> None: + field = WorkflowRunOutputsField() + obj = SimpleNamespace(status=WorkflowExecutionStatus.RUNNING, outputs_dict={"foo": "bar"}) + + assert field.output("outputs", obj) == {"foo": "bar"} diff --git a/api/tests/unit_tests/controllers/web/test_human_input_form.py b/api/tests/unit_tests/controllers/web/test_human_input_form.py new file mode 100644 index 0000000000..4fb735b033 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_human_input_form.py @@ -0,0 +1,456 @@ +"""Unit tests for controllers.web.human_input_form endpoints.""" + +from __future__ import annotations + +import json +from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock + +import pytest +from flask import Flask +from werkzeug.exceptions import Forbidden + +import controllers.web.human_input_form as human_input_module +import controllers.web.site as site_module +from controllers.web.error import WebFormRateLimitExceededError +from models.human_input import RecipientType +from services.human_input_service import FormExpiredError + +HumanInputFormApi = human_input_module.HumanInputFormApi +TenantStatus = human_input_module.TenantStatus + + +@pytest.fixture +def app() -> Flask: + """Configure a minimal Flask app for request contexts.""" + + app = Flask(__name__) + app.config["TESTING"] = True + return app + + +class _FakeSession: + """Simple stand-in for db.session that returns pre-seeded objects.""" + + def __init__(self, mapping: dict[str, Any]): + self._mapping = mapping + self._model_name: str | None = None + + def query(self, model): + self._model_name = model.__name__ + return self + + def where(self, *args, **kwargs): + return self + + def first(self): + assert self._model_name is not None + return self._mapping.get(self._model_name) + + +class _FakeDB: + """Minimal db stub exposing engine and session.""" + + def __init__(self, session: _FakeSession): + self.session = session + self.engine = object() + + +def test_get_form_includes_site(monkeypatch: pytest.MonkeyPatch, app: Flask): + """GET returns form definition merged with site payload.""" + + expiration_time = datetime(2099, 1, 1, tzinfo=UTC) + + class _FakeDefinition: + def model_dump(self): + return { + "form_content": "Raw content", + "rendered_content": "Rendered {{#$output.name#}}", + "inputs": [{"type": "text", "output_variable_name": "name", "default": None}], + "default_values": {"name": "Alice", "age": 30, "meta": {"k": "v"}}, + "user_actions": [{"id": "approve", "title": "Approve", "button_style": "default"}], + } + + class _FakeForm: + def __init__(self, expiration: datetime): + self.workflow_run_id = "workflow-1" + self.app_id = "app-1" + self.tenant_id = "tenant-1" + self.expiration_time = expiration + self.recipient_type = RecipientType.BACKSTAGE + + def get_definition(self): + return _FakeDefinition() + + form = _FakeForm(expiration_time) + limiter_mock = MagicMock() + limiter_mock.is_rate_limited.return_value = False + monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock) + monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") + + tenant = SimpleNamespace( + id="tenant-1", + status=TenantStatus.NORMAL, + plan="basic", + custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": False}, + ) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) + workflow_run = SimpleNamespace(app_id="app-1") + site_model = SimpleNamespace( + title="My Site", + icon_type="emoji", + icon="robot", + icon_background="#fff", + description="desc", + default_language="en", + chat_color_theme="light", + chat_color_theme_inverted=False, + copyright=None, + privacy_policy=None, + custom_disclaimer=None, + prompt_public=False, + show_workflow_steps=True, + use_icon_as_answer_icon=False, + ) + + # Patch service to return fake form. + service_mock = MagicMock() + service_mock.get_form_by_token.return_value = form + monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) + + # Patch db session. + db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": site_model})) + monkeypatch.setattr(human_input_module, "db", db_stub) + + monkeypatch.setattr( + site_module.FeatureService, + "get_features", + lambda tenant_id: SimpleNamespace(can_replace_logo=True), + ) + + with app.test_request_context("/api/form/human_input/token-1", method="GET"): + response = HumanInputFormApi().get("token-1") + + body = json.loads(response.get_data(as_text=True)) + assert set(body.keys()) == { + "site", + "form_content", + "inputs", + "resolved_default_values", + "user_actions", + "expiration_time", + } + assert body["form_content"] == "Rendered {{#$output.name#}}" + assert body["inputs"] == [{"type": "text", "output_variable_name": "name", "default": None}] + assert body["resolved_default_values"] == {"name": "Alice", "age": "30", "meta": '{"k": "v"}'} + assert body["user_actions"] == [{"id": "approve", "title": "Approve", "button_style": "default"}] + assert body["expiration_time"] == int(expiration_time.timestamp()) + assert body["site"] == { + "app_id": "app-1", + "end_user_id": None, + "enable_site": True, + "site": { + "title": "My Site", + "chat_color_theme": "light", + "chat_color_theme_inverted": False, + "icon_type": "emoji", + "icon": "robot", + "icon_background": "#fff", + "icon_url": None, + "description": "desc", + "copyright": None, + "privacy_policy": None, + "custom_disclaimer": None, + "default_language": "en", + "prompt_public": False, + "show_workflow_steps": True, + "use_icon_as_answer_icon": False, + }, + "model_config": None, + "plan": "basic", + "can_replace_logo": True, + "custom_config": { + "remove_webapp_brand": True, + "replace_webapp_logo": None, + }, + } + service_mock.get_form_by_token.assert_called_once_with("token-1") + limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") + limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10") + + +def test_get_form_allows_backstage_token(monkeypatch: pytest.MonkeyPatch, app: Flask): + """GET returns form payload for backstage token.""" + + expiration_time = datetime(2099, 1, 2, tzinfo=UTC) + + class _FakeDefinition: + def model_dump(self): + return { + "form_content": "Raw content", + "rendered_content": "Rendered", + "inputs": [], + "default_values": {}, + "user_actions": [], + } + + class _FakeForm: + def __init__(self, expiration: datetime): + self.workflow_run_id = "workflow-1" + self.app_id = "app-1" + self.tenant_id = "tenant-1" + self.expiration_time = expiration + + def get_definition(self): + return _FakeDefinition() + + form = _FakeForm(expiration_time) + limiter_mock = MagicMock() + limiter_mock.is_rate_limited.return_value = False + monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock) + monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") + tenant = SimpleNamespace( + id="tenant-1", + status=TenantStatus.NORMAL, + plan="basic", + custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": False}, + ) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) + workflow_run = SimpleNamespace(app_id="app-1") + site_model = SimpleNamespace( + title="My Site", + icon_type="emoji", + icon="robot", + icon_background="#fff", + description="desc", + default_language="en", + chat_color_theme="light", + chat_color_theme_inverted=False, + copyright=None, + privacy_policy=None, + custom_disclaimer=None, + prompt_public=False, + show_workflow_steps=True, + use_icon_as_answer_icon=False, + ) + + service_mock = MagicMock() + service_mock.get_form_by_token.return_value = form + monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) + + db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": site_model})) + monkeypatch.setattr(human_input_module, "db", db_stub) + + monkeypatch.setattr( + site_module.FeatureService, + "get_features", + lambda tenant_id: SimpleNamespace(can_replace_logo=True), + ) + + with app.test_request_context("/api/form/human_input/token-1", method="GET"): + response = HumanInputFormApi().get("token-1") + + body = json.loads(response.get_data(as_text=True)) + assert set(body.keys()) == { + "site", + "form_content", + "inputs", + "resolved_default_values", + "user_actions", + "expiration_time", + } + assert body["form_content"] == "Rendered" + assert body["inputs"] == [] + assert body["resolved_default_values"] == {} + assert body["user_actions"] == [] + assert body["expiration_time"] == int(expiration_time.timestamp()) + assert body["site"] == { + "app_id": "app-1", + "end_user_id": None, + "enable_site": True, + "site": { + "title": "My Site", + "chat_color_theme": "light", + "chat_color_theme_inverted": False, + "icon_type": "emoji", + "icon": "robot", + "icon_background": "#fff", + "icon_url": None, + "description": "desc", + "copyright": None, + "privacy_policy": None, + "custom_disclaimer": None, + "default_language": "en", + "prompt_public": False, + "show_workflow_steps": True, + "use_icon_as_answer_icon": False, + }, + "model_config": None, + "plan": "basic", + "can_replace_logo": True, + "custom_config": { + "remove_webapp_brand": True, + "replace_webapp_logo": None, + }, + } + service_mock.get_form_by_token.assert_called_once_with("token-1") + limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") + limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10") + + +def test_get_form_raises_forbidden_when_site_missing(monkeypatch: pytest.MonkeyPatch, app: Flask): + """GET raises Forbidden if site cannot be resolved.""" + + expiration_time = datetime(2099, 1, 3, tzinfo=UTC) + + class _FakeDefinition: + def model_dump(self): + return { + "form_content": "Raw content", + "rendered_content": "Rendered", + "inputs": [], + "default_values": {}, + "user_actions": [], + } + + class _FakeForm: + def __init__(self, expiration: datetime): + self.workflow_run_id = "workflow-1" + self.app_id = "app-1" + self.tenant_id = "tenant-1" + self.expiration_time = expiration + + def get_definition(self): + return _FakeDefinition() + + form = _FakeForm(expiration_time) + limiter_mock = MagicMock() + limiter_mock.is_rate_limited.return_value = False + monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock) + monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") + tenant = SimpleNamespace(status=TenantStatus.NORMAL) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant) + workflow_run = SimpleNamespace(app_id="app-1") + + service_mock = MagicMock() + service_mock.get_form_by_token.return_value = form + monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) + + db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": None})) + monkeypatch.setattr(human_input_module, "db", db_stub) + + with app.test_request_context("/api/form/human_input/token-1", method="GET"): + with pytest.raises(Forbidden): + HumanInputFormApi().get("token-1") + limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") + limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10") + + +def test_submit_form_accepts_backstage_token(monkeypatch: pytest.MonkeyPatch, app: Flask): + """POST forwards backstage submissions to the service.""" + + class _FakeForm: + recipient_type = RecipientType.BACKSTAGE + + form = _FakeForm() + limiter_mock = MagicMock() + limiter_mock.is_rate_limited.return_value = False + monkeypatch.setattr(human_input_module, "_FORM_SUBMIT_RATE_LIMITER", limiter_mock) + monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") + service_mock = MagicMock() + service_mock.get_form_by_token.return_value = form + monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) + monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({}))) + + with app.test_request_context( + "/api/form/human_input/token-1", + method="POST", + json={"inputs": {"content": "ok"}, "action": "approve"}, + ): + response, status = HumanInputFormApi().post("token-1") + + assert status == 200 + assert response == {} + service_mock.submit_form_by_token.assert_called_once_with( + recipient_type=RecipientType.BACKSTAGE, + form_token="token-1", + selected_action_id="approve", + form_data={"content": "ok"}, + submission_end_user_id=None, + ) + limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") + limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10") + + +def test_submit_form_rate_limited(monkeypatch: pytest.MonkeyPatch, app: Flask): + """POST rejects submissions when rate limit is exceeded.""" + + limiter_mock = MagicMock() + limiter_mock.is_rate_limited.return_value = True + monkeypatch.setattr(human_input_module, "_FORM_SUBMIT_RATE_LIMITER", limiter_mock) + monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") + + service_mock = MagicMock() + service_mock.get_form_by_token.return_value = None + monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) + monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({}))) + + with app.test_request_context( + "/api/form/human_input/token-1", + method="POST", + json={"inputs": {"content": "ok"}, "action": "approve"}, + ): + with pytest.raises(WebFormRateLimitExceededError): + HumanInputFormApi().post("token-1") + + limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") + limiter_mock.increment_rate_limit.assert_not_called() + service_mock.get_form_by_token.assert_not_called() + + +def test_get_form_rate_limited(monkeypatch: pytest.MonkeyPatch, app: Flask): + """GET rejects requests when rate limit is exceeded.""" + + limiter_mock = MagicMock() + limiter_mock.is_rate_limited.return_value = True + monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock) + monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") + + service_mock = MagicMock() + service_mock.get_form_by_token.return_value = None + monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) + monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({}))) + + with app.test_request_context("/api/form/human_input/token-1", method="GET"): + with pytest.raises(WebFormRateLimitExceededError): + HumanInputFormApi().get("token-1") + + limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") + limiter_mock.increment_rate_limit.assert_not_called() + service_mock.get_form_by_token.assert_not_called() + + +def test_get_form_raises_expired(monkeypatch: pytest.MonkeyPatch, app: Flask): + class _FakeForm: + pass + + form = _FakeForm() + limiter_mock = MagicMock() + limiter_mock.is_rate_limited.return_value = False + monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock) + monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10") + service_mock = MagicMock() + service_mock.get_form_by_token.return_value = form + service_mock.ensure_form_active.side_effect = FormExpiredError("form-id") + monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock) + monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({}))) + + with app.test_request_context("/api/form/human_input/token-1", method="GET"): + with pytest.raises(FormExpiredError): + HumanInputFormApi().get("token-1") + + service_mock.ensure_form_active.assert_called_once_with(form) + limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10") + limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10") diff --git a/api/tests/unit_tests/controllers/web/test_message_list.py b/api/tests/unit_tests/controllers/web/test_message_list.py index 2835f7ffbf..1c096bfbcf 100644 --- a/api/tests/unit_tests/controllers/web/test_message_list.py +++ b/api/tests/unit_tests/controllers/web/test_message_list.py @@ -3,6 +3,7 @@ from __future__ import annotations import builtins +import uuid from datetime import datetime from types import ModuleType, SimpleNamespace from unittest.mock import patch @@ -12,6 +13,8 @@ import pytest from flask import Flask from flask.views import MethodView +from core.entities.execution_extra_content import HumanInputContent + # Ensure flask_restx.api finds MethodView during import. if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] @@ -137,6 +140,12 @@ def test_message_list_mapping(app: Flask) -> None: status="success", error=None, message_metadata_dict={"meta": "value"}, + extra_contents=[ + HumanInputContent( + workflow_run_id=str(uuid.uuid4()), + submitted=True, + ) + ], ) pagination = SimpleNamespace(limit=20, has_more=False, data=[message]) @@ -169,6 +178,8 @@ def test_message_list_mapping(app: Flask) -> None: assert item["agent_thoughts"][0]["chain_id"] == "chain-1" assert item["agent_thoughts"][0]["created_at"] == int(thought_created_at.timestamp()) + assert item["extra_contents"][0]["workflow_run_id"] == message.extra_contents[0].workflow_run_id + assert item["extra_contents"][0]["submitted"] == message.extra_contents[0].submitted assert item["message_files"][0]["id"] == "file-dict" assert item["message_files"][1]["id"] == "file-obj" diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py new file mode 100644 index 0000000000..a94b5445f7 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +from contextlib import contextmanager +from datetime import datetime +from types import SimpleNamespace +from unittest import mock + +import pytest + +from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueTextChunkEvent, QueueWorkflowPausedEvent +from core.workflow.entities.pause_reason import HumanInputRequired +from models.enums import MessageStatus +from models.execution_extra_content import HumanInputContent +from models.model import EndUser + + +def _build_pipeline() -> pipeline_module.AdvancedChatAppGenerateTaskPipeline: + pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline.__new__( + pipeline_module.AdvancedChatAppGenerateTaskPipeline + ) + pipeline._workflow_run_id = "run-1" + pipeline._message_id = "message-1" + pipeline._workflow_tenant_id = "tenant-1" + return pipeline + + +def test_persist_human_input_extra_content_adds_record(monkeypatch: pytest.MonkeyPatch) -> None: + pipeline = _build_pipeline() + monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1") + + captured_session: dict[str, mock.Mock] = {} + + @contextmanager + def fake_session(): + session = mock.Mock() + session.scalar.return_value = None + captured_session["session"] = session + yield session + + pipeline._database_session = fake_session # type: ignore[method-assign] + + pipeline._persist_human_input_extra_content(node_id="node-1") + + session = captured_session["session"] + session.add.assert_called_once() + content = session.add.call_args.args[0] + assert isinstance(content, HumanInputContent) + assert content.workflow_run_id == "run-1" + assert content.message_id == "message-1" + assert content.form_id == "form-1" + + +def test_persist_human_input_extra_content_skips_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None: + pipeline = _build_pipeline() + monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: None) + + called = {"value": False} + + @contextmanager + def fake_session(): + called["value"] = True + session = mock.Mock() + yield session + + pipeline._database_session = fake_session # type: ignore[method-assign] + + pipeline._persist_human_input_extra_content(node_id="node-1") + + assert called["value"] is False + + +def test_persist_human_input_extra_content_skips_when_existing(monkeypatch: pytest.MonkeyPatch) -> None: + pipeline = _build_pipeline() + monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1") + + captured_session: dict[str, mock.Mock] = {} + + @contextmanager + def fake_session(): + session = mock.Mock() + session.scalar.return_value = HumanInputContent( + workflow_run_id="run-1", + message_id="message-1", + form_id="form-1", + ) + captured_session["session"] = session + yield session + + pipeline._database_session = fake_session # type: ignore[method-assign] + + pipeline._persist_human_input_extra_content(node_id="node-1") + + session = captured_session["session"] + session.add.assert_not_called() + + +def test_handle_workflow_paused_event_persists_human_input_extra_content() -> None: + pipeline = _build_pipeline() + pipeline._application_generate_entity = SimpleNamespace(task_id="task-1") + pipeline._workflow_response_converter = mock.Mock() + pipeline._workflow_response_converter.workflow_pause_to_stream_response.return_value = [] + pipeline._ensure_graph_runtime_initialized = mock.Mock( + return_value=SimpleNamespace( + total_tokens=0, + node_run_steps=0, + ), + ) + pipeline._save_message = mock.Mock() + message = SimpleNamespace(status=MessageStatus.NORMAL) + pipeline._get_message = mock.Mock(return_value=message) + pipeline._persist_human_input_extra_content = mock.Mock() + pipeline._base_task_pipeline = mock.Mock() + pipeline._base_task_pipeline.queue_manager = mock.Mock() + pipeline._message_saved_on_pause = False + + @contextmanager + def fake_session(): + session = mock.Mock() + yield session + + pipeline._database_session = fake_session # type: ignore[method-assign] + + reason = HumanInputRequired( + form_id="form-1", + form_content="content", + inputs=[], + actions=[], + node_id="node-1", + node_title="Approval", + form_token="token-1", + resolved_default_values={}, + ) + event = QueueWorkflowPausedEvent(reasons=[reason], outputs={}, paused_nodes=["node-1"]) + + list(pipeline._handle_workflow_paused_event(event)) + + pipeline._persist_human_input_extra_content.assert_called_once_with(form_id="form-1", node_id="node-1") + assert message.status == MessageStatus.PAUSED + + +def test_resume_appends_chunks_to_paused_answer() -> None: + app_config = SimpleNamespace(app_id="app-1", tenant_id="tenant-1", sensitive_word_avoidance=None) + application_generate_entity = SimpleNamespace( + app_config=app_config, + files=[], + workflow_run_id="run-1", + query="hello", + invoke_from=InvokeFrom.WEB_APP, + inputs={}, + task_id="task-1", + ) + queue_manager = SimpleNamespace(graph_runtime_state=None) + conversation = SimpleNamespace(id="conversation-1", mode="advanced-chat") + message = SimpleNamespace( + id="message-1", + created_at=datetime(2024, 1, 1), + query="hello", + answer="before", + status=MessageStatus.PAUSED, + ) + user = EndUser() + user.id = "user-1" + user.session_id = "session-1" + workflow = SimpleNamespace(id="workflow-1", tenant_id="tenant-1", features_dict={}) + + pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=True, + dialogue_count=1, + draft_var_saver_factory=SimpleNamespace(), + ) + + pipeline._get_message = mock.Mock(return_value=message) + pipeline._recorded_files = [] + + list(pipeline._handle_text_chunk_event(QueueTextChunkEvent(text="after"))) + pipeline._save_message(session=mock.Mock()) + + assert message.answer == "beforeafter" + assert message.status == MessageStatus.NORMAL diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py new file mode 100644 index 0000000000..1c36b4d12b --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py @@ -0,0 +1,87 @@ +from datetime import UTC, datetime +from types import SimpleNamespace + +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent +from core.workflow.entities.workflow_start_reason import WorkflowStartReason +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable + + +def _build_converter(): + system_variables = SystemVariable( + files=[], + user_id="user-1", + app_id="app-1", + workflow_id="wf-1", + workflow_execution_id="run-1", + ) + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) + app_entity = SimpleNamespace( + task_id="task-1", + app_config=SimpleNamespace(app_id="app-1", tenant_id="tenant-1"), + invoke_from=InvokeFrom.EXPLORE, + files=[], + inputs={}, + workflow_execution_id="run-1", + call_depth=0, + ) + account = SimpleNamespace(id="acc-1", name="tester", email="tester@example.com") + return WorkflowResponseConverter( + application_generate_entity=app_entity, + user=account, + system_variables=system_variables, + ) + + +def test_human_input_form_filled_stream_response_contains_rendered_content(): + converter = _build_converter() + converter.workflow_start_to_stream_response( + task_id="task-1", + workflow_run_id="run-1", + workflow_id="wf-1", + reason=WorkflowStartReason.INITIAL, + ) + + queue_event = QueueHumanInputFormFilledEvent( + node_execution_id="exec-1", + node_id="node-1", + node_type="human-input", + node_title="Human Input", + rendered_content="# Title\nvalue", + action_id="Approve", + action_text="Approve", + ) + + resp = converter.human_input_form_filled_to_stream_response(event=queue_event, task_id="task-1") + + assert resp.workflow_run_id == "run-1" + assert resp.data.node_id == "node-1" + assert resp.data.node_title == "Human Input" + assert resp.data.rendered_content.startswith("# Title") + assert resp.data.action_id == "Approve" + + +def test_human_input_form_timeout_stream_response_contains_timeout_metadata(): + converter = _build_converter() + converter.workflow_start_to_stream_response( + task_id="task-1", + workflow_run_id="run-1", + workflow_id="wf-1", + reason=WorkflowStartReason.INITIAL, + ) + + queue_event = QueueHumanInputFormTimeoutEvent( + node_id="node-1", + node_type="human-input", + node_title="Human Input", + expiration_time=datetime(2025, 1, 1, tzinfo=UTC), + ) + + resp = converter.human_input_form_timeout_to_stream_response(event=queue_event, task_id="task-1") + + assert resp.workflow_run_id == "run-1" + assert resp.data.node_id == "node-1" + assert resp.data.node_title == "Human Input" + assert resp.data.expiration_time == 1735689600 diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py new file mode 100644 index 0000000000..0a9794e41c --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py @@ -0,0 +1,56 @@ +from types import SimpleNamespace + +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.workflow_start_reason import WorkflowStartReason +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable + + +def _build_converter() -> WorkflowResponseConverter: + """Construct a minimal WorkflowResponseConverter for testing.""" + system_variables = SystemVariable( + files=[], + user_id="user-1", + app_id="app-1", + workflow_id="wf-1", + workflow_execution_id="run-1", + ) + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) + app_entity = SimpleNamespace( + task_id="task-1", + app_config=SimpleNamespace(app_id="app-1", tenant_id="tenant-1"), + invoke_from=InvokeFrom.EXPLORE, + files=[], + inputs={}, + workflow_execution_id="run-1", + call_depth=0, + ) + account = SimpleNamespace(id="acc-1", name="tester", email="tester@example.com") + return WorkflowResponseConverter( + application_generate_entity=app_entity, + user=account, + system_variables=system_variables, + ) + + +def test_workflow_start_stream_response_carries_resumption_reason(): + converter = _build_converter() + resp = converter.workflow_start_to_stream_response( + task_id="task-1", + workflow_run_id="run-1", + workflow_id="wf-1", + reason=WorkflowStartReason.RESUMPTION, + ) + assert resp.data.reason is WorkflowStartReason.RESUMPTION + + +def test_workflow_start_stream_response_carries_initial_reason(): + converter = _build_converter() + resp = converter.workflow_start_to_stream_response( + task_id="task-1", + workflow_run_id="run-1", + workflow_id="wf-1", + reason=WorkflowStartReason.INITIAL, + ) + assert resp.data.reason is WorkflowStartReason.INITIAL diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index 6b40bf462b..d25bff92dc 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -23,6 +23,7 @@ from core.app.entities.queue_entities import ( QueueNodeStartedEvent, QueueNodeSucceededEvent, ) +from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.enums import NodeType from core.workflow.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now @@ -124,7 +125,12 @@ class TestWorkflowResponseConverter: original_data = {"large_field": "x" * 10000, "metadata": "info"} truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"} - converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") + converter.workflow_start_to_stream_response( + task_id="bootstrap", + workflow_run_id="run-id", + workflow_id="wf-id", + reason=WorkflowStartReason.INITIAL, + ) start_event = self.create_node_started_event() converter.workflow_node_start_to_stream_response( event=start_event, @@ -160,7 +166,12 @@ class TestWorkflowResponseConverter: original_data = {"small": "data"} - converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") + converter.workflow_start_to_stream_response( + task_id="bootstrap", + workflow_run_id="run-id", + workflow_id="wf-id", + reason=WorkflowStartReason.INITIAL, + ) start_event = self.create_node_started_event() converter.workflow_node_start_to_stream_response( event=start_event, @@ -191,7 +202,12 @@ class TestWorkflowResponseConverter: """Test node finish response when process_data is None.""" converter = self.create_workflow_response_converter() - converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") + converter.workflow_start_to_stream_response( + task_id="bootstrap", + workflow_run_id="run-id", + workflow_id="wf-id", + reason=WorkflowStartReason.INITIAL, + ) start_event = self.create_node_started_event() converter.workflow_node_start_to_stream_response( event=start_event, @@ -225,7 +241,12 @@ class TestWorkflowResponseConverter: original_data = {"large_field": "x" * 10000, "metadata": "info"} truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"} - converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") + converter.workflow_start_to_stream_response( + task_id="bootstrap", + workflow_run_id="run-id", + workflow_id="wf-id", + reason=WorkflowStartReason.INITIAL, + ) start_event = self.create_node_started_event() converter.workflow_node_start_to_stream_response( event=start_event, @@ -261,7 +282,12 @@ class TestWorkflowResponseConverter: original_data = {"small": "data"} - converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id") + converter.workflow_start_to_stream_response( + task_id="bootstrap", + workflow_run_id="run-id", + workflow_id="wf-id", + reason=WorkflowStartReason.INITIAL, + ) start_event = self.create_node_started_event() converter.workflow_node_start_to_stream_response( event=start_event, @@ -400,6 +426,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: task_id="test-task-id", workflow_run_id="test-workflow-run-id", workflow_id="test-workflow-id", + reason=WorkflowStartReason.INITIAL, ) return converter diff --git a/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py new file mode 100644 index 0000000000..f0d9afc0db --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig +from core.app.apps import message_based_app_generator +from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.app.task_pipeline import message_cycle_manager +from core.app.task_pipeline.message_cycle_manager import MessageCycleManager +from models.model import AppMode, Conversation, Message + + +def _make_app_config() -> WorkflowUIBasedAppConfig: + return WorkflowUIBasedAppConfig( + tenant_id="tenant-id", + app_id="app-id", + app_mode=AppMode.ADVANCED_CHAT, + workflow_id="workflow-id", + additional_features=AppAdditionalFeatures(), + variables=[], + ) + + +def _make_generate_entity(app_config: WorkflowUIBasedAppConfig) -> AdvancedChatAppGenerateEntity: + return AdvancedChatAppGenerateEntity( + task_id="task-id", + app_config=app_config, + file_upload_config=None, + conversation_id=None, + inputs={}, + query="hello", + files=[], + parent_message_id=None, + user_id="user-id", + stream=True, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + workflow_run_id="workflow-run-id", + ) + + +@pytest.fixture(autouse=True) +def _mock_db_session(monkeypatch): + session = MagicMock() + + def refresh_side_effect(obj): + if isinstance(obj, Conversation) and obj.id is None: + obj.id = "generated-conversation-id" + if isinstance(obj, Message) and obj.id is None: + obj.id = "generated-message-id" + + session.refresh.side_effect = refresh_side_effect + session.add.return_value = None + session.commit.return_value = None + + monkeypatch.setattr(message_based_app_generator, "db", SimpleNamespace(session=session)) + return session + + +def test_init_generate_records_sets_conversation_metadata(): + app_config = _make_app_config() + entity = _make_generate_entity(app_config) + + generator = AdvancedChatAppGenerator() + + conversation, _ = generator._init_generate_records(entity, conversation=None) + + assert entity.conversation_id == "generated-conversation-id" + assert conversation.id == "generated-conversation-id" + assert entity.is_new_conversation is True + + +def test_init_generate_records_marks_existing_conversation(): + app_config = _make_app_config() + entity = _make_generate_entity(app_config) + + existing_conversation = Conversation( + app_id=app_config.app_id, + app_model_config_id=None, + model_provider=None, + override_model_configs=None, + model_id=None, + mode=app_config.app_mode.value, + name="existing", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from=InvokeFrom.WEB_APP.value, + from_source="api", + from_end_user_id="user-id", + from_account_id=None, + ) + existing_conversation.id = "existing-conversation-id" + + generator = AdvancedChatAppGenerator() + + conversation, _ = generator._init_generate_records(entity, conversation=existing_conversation) + + assert entity.conversation_id == "existing-conversation-id" + assert conversation is existing_conversation + assert entity.is_new_conversation is False + + +def test_message_cycle_manager_uses_new_conversation_flag(monkeypatch): + app_config = _make_app_config() + entity = _make_generate_entity(app_config) + entity.conversation_id = "existing-conversation-id" + entity.is_new_conversation = True + entity.extras = {"auto_generate_conversation_name": True} + + captured = {} + + class DummyThread: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.started = False + + def start(self): + self.started = True + + def fake_thread(**kwargs): + thread = DummyThread(**kwargs) + captured["thread"] = thread + return thread + + monkeypatch.setattr(message_cycle_manager, "Thread", fake_thread) + + manager = MessageCycleManager(application_generate_entity=entity, task_state=MagicMock()) + thread = manager.generate_conversation_name(conversation_id="existing-conversation-id", query="hello") + + assert thread is captured["thread"] + assert thread.started is True + assert entity.is_new_conversation is False diff --git a/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py b/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py new file mode 100644 index 0000000000..87b8dc51e7 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.entities import ( + AppAdditionalFeatures, + EasyUIBasedAppConfig, + EasyUIBasedAppModelConfigFrom, + ModelConfigEntity, + PromptTemplateEntity, +) +from core.app.apps import message_based_app_generator +from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom +from models.model import AppMode, Conversation, Message + + +class DummyModelConf: + def __init__(self, provider: str = "mock-provider", model: str = "mock-model") -> None: + self.provider = provider + self.model = model + + +class DummyCompletionGenerateEntity: + __slots__ = ("app_config", "invoke_from", "user_id", "query", "inputs", "files", "model_conf") + app_config: EasyUIBasedAppConfig + invoke_from: InvokeFrom + user_id: str + query: str + inputs: dict + files: list + model_conf: DummyModelConf + + def __init__(self, app_config: EasyUIBasedAppConfig) -> None: + self.app_config = app_config + self.invoke_from = InvokeFrom.WEB_APP + self.user_id = "user-id" + self.query = "hello" + self.inputs = {} + self.files = [] + self.model_conf = DummyModelConf() + + +def _make_app_config(app_mode: AppMode) -> EasyUIBasedAppConfig: + return EasyUIBasedAppConfig( + tenant_id="tenant-id", + app_id="app-id", + app_mode=app_mode, + app_model_config_from=EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG, + app_model_config_id="model-config-id", + app_model_config_dict={}, + model=ModelConfigEntity(provider="mock-provider", model="mock-model", mode="chat"), + prompt_template=PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="Hello", + ), + additional_features=AppAdditionalFeatures(), + variables=[], + ) + + +def _make_chat_generate_entity(app_config: EasyUIBasedAppConfig) -> ChatAppGenerateEntity: + return ChatAppGenerateEntity.model_construct( + task_id="task-id", + app_config=app_config, + model_conf=DummyModelConf(), + file_upload_config=None, + conversation_id=None, + inputs={}, + query="hello", + files=[], + parent_message_id=None, + user_id="user-id", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + call_depth=0, + trace_manager=None, + ) + + +@pytest.fixture(autouse=True) +def _mock_db_session(monkeypatch): + session = MagicMock() + + def refresh_side_effect(obj): + if isinstance(obj, Conversation) and obj.id is None: + obj.id = "generated-conversation-id" + if isinstance(obj, Message) and obj.id is None: + obj.id = "generated-message-id" + + session.refresh.side_effect = refresh_side_effect + session.add.return_value = None + session.commit.return_value = None + + monkeypatch.setattr(message_based_app_generator, "db", SimpleNamespace(session=session)) + return session + + +def test_init_generate_records_skips_conversation_fields_for_non_conversation_entity(): + app_config = _make_app_config(AppMode.COMPLETION) + entity = DummyCompletionGenerateEntity(app_config=app_config) + + generator = MessageBasedAppGenerator() + + conversation, message = generator._init_generate_records(entity, conversation=None) + + assert conversation.id == "generated-conversation-id" + assert message.id == "generated-message-id" + assert hasattr(entity, "conversation_id") is False + assert hasattr(entity, "is_new_conversation") is False + + +def test_init_generate_records_sets_conversation_fields_for_chat_entity(): + app_config = _make_app_config(AppMode.CHAT) + entity = _make_chat_generate_entity(app_config) + + generator = MessageBasedAppGenerator() + + conversation, _ = generator._init_generate_records(entity, conversation=None) + + assert entity.conversation_id == "generated-conversation-id" + assert entity.is_new_conversation is True + assert conversation.id == "generated-conversation-id" diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py new file mode 100644 index 0000000000..97c993928e --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -0,0 +1,287 @@ +import sys +import time +from pathlib import Path +from types import ModuleType, SimpleNamespace +from typing import Any + +API_DIR = str(Path(__file__).resolve().parents[5]) +if API_DIR not in sys.path: + sys.path.insert(0, API_DIR) + +import core.workflow.nodes.human_input.entities # noqa: F401 +from core.app.apps.advanced_chat import app_generator as adv_app_gen_module +from core.app.apps.workflow import app_generator as wf_app_gen_module +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory +from core.workflow.entities import GraphInitParams +from core.workflow.entities.pause_reason import SchedulingPause +from core.workflow.entities.workflow_start_reason import WorkflowStartReason +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus +from core.workflow.graph import Graph +from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from core.workflow.graph_events import ( + GraphEngineEvent, + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunSucceededEvent, +) +from core.workflow.node_events import NodeRunResult, PauseRequestedEvent +from core.workflow.nodes.base.entities import BaseNodeData, OutputVariableEntity, RetryConfig +from core.workflow.nodes.base.node import Node +from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable + +if "core.ops.ops_trace_manager" not in sys.modules: + ops_stub = ModuleType("core.ops.ops_trace_manager") + + class _StubTraceQueueManager: + def __init__(self, *_, **__): + pass + + ops_stub.TraceQueueManager = _StubTraceQueueManager + sys.modules["core.ops.ops_trace_manager"] = ops_stub + + +class _StubToolNodeData(BaseNodeData): + pause_on: bool = False + + +class _StubToolNode(Node[_StubToolNodeData]): + node_type = NodeType.TOOL + + @classmethod + def version(cls) -> str: + return "1" + + def init_node_data(self, data): + self._node_data = _StubToolNodeData.model_validate(data) + + def _get_error_strategy(self): + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self): + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + + def _run(self): + if self.node_data.pause_on: + yield PauseRequestedEvent(reason=SchedulingPause(message="test pause")) + return + + result = NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={"value": f"{self.id}-done"}, + ) + yield self._convert_node_run_result_to_graph_node_event(result) + + +def _patch_tool_node(mocker): + original_create_node = DifyNodeFactory.create_node + + def _patched_create_node(self, node_config: dict[str, object]) -> Node: + node_data = node_config.get("data", {}) + if isinstance(node_data, dict) and node_data.get("type") == NodeType.TOOL.value: + return _StubToolNode( + id=str(node_config["id"]), + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + ) + return original_create_node(self, node_config) + + mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node) + + +def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]: + node_data = data.model_dump() + node_data["type"] = node_type.value + return node_data + + +def _build_graph_config(*, pause_on: str | None) -> dict[str, object]: + start_data = StartNodeData(title="start", variables=[]) + tool_data_a = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_a") + tool_data_b = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_b") + tool_data_c = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_c") + end_data = EndNodeData( + title="end", + outputs=[OutputVariableEntity(variable="result", value_selector=["tool_c", "value"])], + desc=None, + ) + + nodes = [ + {"id": "start", "data": _node_data(NodeType.START, start_data)}, + {"id": "tool_a", "data": _node_data(NodeType.TOOL, tool_data_a)}, + {"id": "tool_b", "data": _node_data(NodeType.TOOL, tool_data_b)}, + {"id": "tool_c", "data": _node_data(NodeType.TOOL, tool_data_c)}, + {"id": "end", "data": _node_data(NodeType.END, end_data)}, + ] + edges = [ + {"source": "start", "target": "tool_a"}, + {"source": "tool_a", "target": "tool_b"}, + {"source": "tool_b", "target": "tool_c"}, + {"source": "tool_c", "target": "end"}, + ] + return {"nodes": nodes, "edges": edges} + + +def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> Graph: + graph_config = _build_graph_config(pause_on=pause_on) + params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + user_from="account", + invoke_from="service-api", + call_depth=0, + ) + + node_factory = DifyNodeFactory( + graph_init_params=params, + graph_runtime_state=runtime_state, + ) + + return Graph.init(graph_config=graph_config, node_factory=node_factory) + + +def _build_runtime_state(run_id: str) -> GraphRuntimeState: + variable_pool = VariablePool( + system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + user_inputs={}, + conversation_variables=[], + ) + variable_pool.system_variables.workflow_execution_id = run_id + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +def _run_with_optional_pause(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> list[GraphEngineEvent]: + command_channel = InMemoryChannel() + graph = _build_graph(runtime_state, pause_on=pause_on) + engine = GraphEngine( + workflow_id="workflow", + graph=graph, + graph_runtime_state=runtime_state, + command_channel=command_channel, + ) + + events: list[GraphEngineEvent] = [] + for event in engine.run(): + events.append(event) + return events + + +def _node_successes(events: list[GraphEngineEvent]) -> list[str]: + return [evt.node_id for evt in events if isinstance(evt, NodeRunSucceededEvent)] + + +def test_workflow_app_pause_resume_matches_baseline(mocker): + _patch_tool_node(mocker) + + baseline_state = _build_runtime_state("baseline") + baseline_events = _run_with_optional_pause(baseline_state, pause_on=None) + assert isinstance(baseline_events[-1], GraphRunSucceededEvent) + baseline_nodes = _node_successes(baseline_events) + baseline_outputs = baseline_state.outputs + + paused_state = _build_runtime_state("paused-run") + paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a") + assert isinstance(paused_events[-1], GraphRunPausedEvent) + paused_nodes = _node_successes(paused_events) + snapshot = paused_state.dumps() + + resumed_state = GraphRuntimeState.from_snapshot(snapshot) + + generator = wf_app_gen_module.WorkflowAppGenerator() + + def _fake_generate(**kwargs): + state: GraphRuntimeState = kwargs["graph_runtime_state"] + events = _run_with_optional_pause(state, pause_on=None) + return _node_successes(events) + + mocker.patch.object(generator, "_generate", side_effect=_fake_generate) + + resumed_nodes = generator.resume( + app_model=SimpleNamespace(mode="workflow"), + workflow=SimpleNamespace(), + user=SimpleNamespace(), + application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API), + graph_runtime_state=resumed_state, + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + ) + + assert paused_nodes + resumed_nodes == baseline_nodes + assert resumed_state.outputs == baseline_outputs + + +def test_advanced_chat_pause_resume_matches_baseline(mocker): + _patch_tool_node(mocker) + + baseline_state = _build_runtime_state("adv-baseline") + baseline_events = _run_with_optional_pause(baseline_state, pause_on=None) + assert isinstance(baseline_events[-1], GraphRunSucceededEvent) + baseline_nodes = _node_successes(baseline_events) + baseline_outputs = baseline_state.outputs + + paused_state = _build_runtime_state("adv-paused") + paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a") + assert isinstance(paused_events[-1], GraphRunPausedEvent) + paused_nodes = _node_successes(paused_events) + snapshot = paused_state.dumps() + + resumed_state = GraphRuntimeState.from_snapshot(snapshot) + + generator = adv_app_gen_module.AdvancedChatAppGenerator() + + def _fake_generate(**kwargs): + state: GraphRuntimeState = kwargs["graph_runtime_state"] + events = _run_with_optional_pause(state, pause_on=None) + return _node_successes(events) + + mocker.patch.object(generator, "_generate", side_effect=_fake_generate) + + resumed_nodes = generator.resume( + app_model=SimpleNamespace(mode="workflow"), + workflow=SimpleNamespace(), + user=SimpleNamespace(), + conversation=SimpleNamespace(id="conv"), + message=SimpleNamespace(id="msg"), + application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_runtime_state=resumed_state, + ) + + assert paused_nodes + resumed_nodes == baseline_nodes + assert resumed_state.outputs == baseline_outputs + + +def test_resume_emits_resumption_start_reason(mocker) -> None: + _patch_tool_node(mocker) + + paused_state = _build_runtime_state("resume-reason") + paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a") + initial_start = next(event for event in paused_events if isinstance(event, GraphRunStartedEvent)) + assert initial_start.reason == WorkflowStartReason.INITIAL + + resumed_state = GraphRuntimeState.from_snapshot(paused_state.dumps()) + resumed_events = _run_with_optional_pause(resumed_state, pause_on=None) + resume_start = next(event for event in resumed_events if isinstance(event, GraphRunStartedEvent)) + assert resume_start.reason == WorkflowStartReason.RESUMPTION diff --git a/api/tests/unit_tests/core/app/apps/test_streaming_utils.py b/api/tests/unit_tests/core/app/apps/test_streaming_utils.py new file mode 100644 index 0000000000..7b5447c01e --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_streaming_utils.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import json +import queue + +import pytest + +from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.entities.task_entities import StreamEvent +from models.model import AppMode + + +class FakeSubscription: + def __init__(self, message_queue: queue.Queue[bytes], state: dict[str, bool]) -> None: + self._queue = message_queue + self._state = state + self._closed = False + + def __enter__(self): + self._state["subscribed"] = True + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def close(self) -> None: + self._closed = True + + def receive(self, timeout: float | None = 0.1) -> bytes | None: + if self._closed: + return None + try: + if timeout is None: + return self._queue.get() + return self._queue.get(timeout=timeout) + except queue.Empty: + return None + + +class FakeTopic: + def __init__(self) -> None: + self._queue: queue.Queue[bytes] = queue.Queue() + self._state = {"subscribed": False} + + def subscribe(self) -> FakeSubscription: + return FakeSubscription(self._queue, self._state) + + def publish(self, payload: bytes) -> None: + self._queue.put(payload) + + @property + def subscribed(self) -> bool: + return self._state["subscribed"] + + +def test_retrieve_events_calls_on_subscribe_after_subscription(monkeypatch): + topic = FakeTopic() + + def fake_get_response_topic(cls, app_mode, workflow_run_id): + return topic + + monkeypatch.setattr(MessageBasedAppGenerator, "get_response_topic", classmethod(fake_get_response_topic)) + + def on_subscribe() -> None: + assert topic.subscribed is True + event = {"event": StreamEvent.WORKFLOW_FINISHED.value} + topic.publish(json.dumps(event).encode()) + + generator = MessageBasedAppGenerator.retrieve_events( + AppMode.WORKFLOW, + "workflow-run-id", + idle_timeout=0.5, + on_subscribe=on_subscribe, + ) + + assert next(generator) == StreamEvent.PING.value + event = next(generator) + assert event["event"] == StreamEvent.WORKFLOW_FINISHED.value + with pytest.raises(StopIteration): + next(generator) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py index 83ac3a5591..7e8367c6c4 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py @@ -1,3 +1,6 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator @@ -17,3 +20,193 @@ def test_should_prepare_user_inputs_keeps_validation_when_flag_false(): args = {"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: False} assert WorkflowAppGenerator()._should_prepare_user_inputs(args) + + +def test_resume_delegates_to_generate(mocker): + generator = WorkflowAppGenerator() + mock_generate = mocker.patch.object(generator, "_generate", return_value="ok") + + application_generate_entity = SimpleNamespace(stream=False, invoke_from="debugger") + runtime_state = MagicMock(name="runtime-state") + pause_config = MagicMock(name="pause-config") + + result = generator.resume( + app_model=MagicMock(), + workflow=MagicMock(), + user=MagicMock(), + application_generate_entity=application_generate_entity, + graph_runtime_state=runtime_state, + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + graph_engine_layers=("layer",), + pause_state_config=pause_config, + variable_loader=MagicMock(), + ) + + assert result == "ok" + mock_generate.assert_called_once() + kwargs = mock_generate.call_args.kwargs + assert kwargs["graph_runtime_state"] is runtime_state + assert kwargs["pause_state_config"] is pause_config + assert kwargs["streaming"] is False + assert kwargs["invoke_from"] == "debugger" + + +def test_generate_appends_pause_layer_and_forwards_state(mocker): + generator = WorkflowAppGenerator() + + mock_queue_manager = MagicMock() + mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=mock_queue_manager) + + fake_current_app = MagicMock() + fake_current_app._get_current_object.return_value = MagicMock() + mocker.patch("core.app.apps.workflow.app_generator.current_app", fake_current_app) + + mocker.patch( + "core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert", + return_value="converted", + ) + mocker.patch.object(WorkflowAppGenerator, "_handle_response", return_value="response") + mocker.patch.object(WorkflowAppGenerator, "_get_draft_var_saver_factory", return_value=MagicMock()) + + pause_layer = MagicMock(name="pause-layer") + mocker.patch( + "core.app.apps.workflow.app_generator.PauseStatePersistenceLayer", + return_value=pause_layer, + ) + + dummy_session = MagicMock() + dummy_session.close = MagicMock() + mocker.patch("core.app.apps.workflow.app_generator.db.session", dummy_session) + + worker_kwargs: dict[str, object] = {} + + class DummyThread: + def __init__(self, target, kwargs): + worker_kwargs["target"] = target + worker_kwargs["kwargs"] = kwargs + + def start(self): + return None + + mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", DummyThread) + + app_model = SimpleNamespace(mode="workflow") + app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="wf") + application_generate_entity = SimpleNamespace( + task_id="task", + user_id="user", + invoke_from="service-api", + app_config=app_config, + files=[], + stream=True, + workflow_execution_id="run", + ) + + graph_runtime_state = MagicMock() + + result = generator._generate( + app_model=app_model, + workflow=MagicMock(), + user=MagicMock(), + application_generate_entity=application_generate_entity, + invoke_from="service-api", + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + streaming=True, + graph_engine_layers=("base-layer",), + graph_runtime_state=graph_runtime_state, + pause_state_config=SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner"), + ) + + assert result == "converted" + assert worker_kwargs["kwargs"]["graph_engine_layers"] == ("base-layer", pause_layer) + assert worker_kwargs["kwargs"]["graph_runtime_state"] is graph_runtime_state + + +def test_resume_path_runs_worker_with_runtime_state(mocker): + generator = WorkflowAppGenerator() + runtime_state = MagicMock(name="runtime-state") + + pause_layer = MagicMock(name="pause-layer") + mocker.patch("core.app.apps.workflow.app_generator.PauseStatePersistenceLayer", return_value=pause_layer) + + queue_manager = MagicMock() + mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=queue_manager) + + mocker.patch.object(generator, "_handle_response", return_value="raw-response") + mocker.patch( + "core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert", + side_effect=lambda response, invoke_from: response, + ) + + fake_db = SimpleNamespace(session=MagicMock(), engine=MagicMock()) + mocker.patch("core.app.apps.workflow.app_generator.db", fake_db) + + workflow = SimpleNamespace( + id="workflow", tenant_id="tenant", app_id="app", graph_dict={}, type="workflow", version="1" + ) + end_user = SimpleNamespace(session_id="end-user-session") + app_record = SimpleNamespace(id="app") + + session = MagicMock() + session.__enter__.return_value = session + session.__exit__.return_value = False + session.scalar.side_effect = [workflow, end_user, app_record] + mocker.patch("core.app.apps.workflow.app_generator.session_factory", return_value=session) + + runner_instance = MagicMock() + + def runner_ctor(**kwargs): + assert kwargs["graph_runtime_state"] is runtime_state + return runner_instance + + mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppRunner", side_effect=runner_ctor) + + class ImmediateThread: + def __init__(self, target, kwargs): + target(**kwargs) + + def start(self): + return None + + mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", ImmediateThread) + + mocker.patch( + "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + return_value=MagicMock(), + ) + mocker.patch( + "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + return_value=MagicMock(), + ) + + pause_config = SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner") + + app_model = SimpleNamespace(mode="workflow") + app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="workflow") + application_generate_entity = SimpleNamespace( + task_id="task", + user_id="user", + invoke_from="service-api", + app_config=app_config, + files=[], + stream=True, + workflow_execution_id="run", + trace_manager=MagicMock(), + ) + + result = generator.resume( + app_model=app_model, + workflow=workflow, + user=MagicMock(), + application_generate_entity=application_generate_entity, + graph_runtime_state=runtime_state, + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + pause_state_config=pause_config, + ) + + assert result == "raw-response" + runner_instance.run.assert_called_once() + queue_manager.graph_runtime_state = runtime_state diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py new file mode 100644 index 0000000000..f4efb240c0 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py @@ -0,0 +1,59 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner +from core.app.entities.queue_entities import QueueWorkflowPausedEvent +from core.workflow.entities.pause_reason import HumanInputRequired +from core.workflow.graph_events.graph import GraphRunPausedEvent + + +class _DummyQueueManager: + def __init__(self): + self.published = [] + + def publish(self, event, _from): + self.published.append(event) + + +class _DummyRuntimeState: + def get_paused_nodes(self): + return ["node-1"] + + +class _DummyGraphEngine: + def __init__(self): + self.graph_runtime_state = _DummyRuntimeState() + + +class _DummyWorkflowEntry: + def __init__(self): + self.graph_engine = _DummyGraphEngine() + + +def test_handle_pause_event_enqueues_email_task(monkeypatch: pytest.MonkeyPatch): + queue_manager = _DummyQueueManager() + runner = WorkflowBasedAppRunner(queue_manager=queue_manager, app_id="app-id") + workflow_entry = _DummyWorkflowEntry() + + reason = HumanInputRequired( + form_id="form-123", + form_content="content", + inputs=[], + actions=[], + node_id="node-1", + node_title="Review", + ) + event = GraphRunPausedEvent(reasons=[reason], outputs={}) + + email_task = MagicMock() + monkeypatch.setattr("core.app.apps.workflow_app_runner.dispatch_human_input_email_task", email_task) + + runner._handle_event(workflow_entry, event) + + email_task.apply_async.assert_called_once() + kwargs = email_task.apply_async.call_args.kwargs["kwargs"] + assert kwargs["form_id"] == "form-123" + assert kwargs["node_title"] == "Review" + + assert any(isinstance(evt, QueueWorkflowPausedEvent) for evt in queue_manager.published) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py new file mode 100644 index 0000000000..c30b925d88 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py @@ -0,0 +1,183 @@ +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.app.apps.common import workflow_response_converter +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.apps.workflow.app_runner import WorkflowAppRunner +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueWorkflowPausedEvent +from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse +from core.workflow.entities.pause_reason import HumanInputRequired +from core.workflow.entities.workflow_start_reason import WorkflowStartReason +from core.workflow.graph_events.graph import GraphRunPausedEvent +from core.workflow.nodes.human_input.entities import FormInput, UserAction +from core.workflow.nodes.human_input.enums import FormInputType +from core.workflow.system_variable import SystemVariable +from models.account import Account + + +class _RecordingWorkflowAppRunner(WorkflowAppRunner): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.published_events = [] + + def _publish_event(self, event): + self.published_events.append(event) + + +class _FakeRuntimeState: + def get_paused_nodes(self): + return ["node-pause-1"] + + +def _build_runner(): + app_entity = SimpleNamespace( + app_config=SimpleNamespace(app_id="app-id"), + inputs={}, + files=[], + invoke_from=InvokeFrom.SERVICE_API, + single_iteration_run=None, + single_loop_run=None, + workflow_execution_id="run-id", + user_id="user-id", + ) + workflow = SimpleNamespace( + graph_dict={}, + tenant_id="tenant-id", + environment_variables={}, + id="workflow-id", + ) + queue_manager = SimpleNamespace(publish=lambda event, pub_from: None) + return _RecordingWorkflowAppRunner( + application_generate_entity=app_entity, + queue_manager=queue_manager, + variable_loader=MagicMock(), + workflow=workflow, + system_user_id="sys-user", + root_node_id=None, + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + graph_engine_layers=(), + graph_runtime_state=None, + ) + + +def test_graph_run_paused_event_emits_queue_pause_event(): + runner = _build_runner() + reason = HumanInputRequired( + form_id="form-1", + form_content="content", + inputs=[], + actions=[], + node_id="node-human", + node_title="Human Step", + form_token="tok", + ) + event = GraphRunPausedEvent(reasons=[reason], outputs={"foo": "bar"}) + workflow_entry = SimpleNamespace( + graph_engine=SimpleNamespace(graph_runtime_state=_FakeRuntimeState()), + ) + + runner._handle_event(workflow_entry, event) + + assert len(runner.published_events) == 1 + queue_event = runner.published_events[0] + assert isinstance(queue_event, QueueWorkflowPausedEvent) + assert queue_event.reasons == [reason] + assert queue_event.outputs == {"foo": "bar"} + assert queue_event.paused_nodes == ["node-pause-1"] + + +def _build_converter(): + application_generate_entity = SimpleNamespace( + inputs={}, + files=[], + invoke_from=InvokeFrom.SERVICE_API, + app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"), + ) + system_variables = SystemVariable( + user_id="user", + app_id="app-id", + workflow_id="workflow-id", + workflow_execution_id="run-id", + ) + user = MagicMock(spec=Account) + user.id = "account-id" + user.name = "Tester" + user.email = "tester@example.com" + return WorkflowResponseConverter( + application_generate_entity=application_generate_entity, + user=user, + system_variables=system_variables, + ) + + +def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.MonkeyPatch): + converter = _build_converter() + converter.workflow_start_to_stream_response( + task_id="task", + workflow_run_id="run-id", + workflow_id="workflow-id", + reason=WorkflowStartReason.INITIAL, + ) + + expiration_time = datetime(2024, 1, 1, tzinfo=UTC) + + class _FakeSession: + def execute(self, _stmt): + return [("form-1", expiration_time)] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(workflow_response_converter, "Session", lambda **_: _FakeSession()) + monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object())) + + reason = HumanInputRequired( + form_id="form-1", + form_content="Rendered", + inputs=[ + FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None), + ], + actions=[UserAction(id="approve", title="Approve")], + display_in_ui=True, + node_id="node-id", + node_title="Human Step", + form_token="token", + ) + queue_event = QueueWorkflowPausedEvent( + reasons=[reason], + outputs={"answer": "value"}, + paused_nodes=["node-id"], + ) + + runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0) + responses = converter.workflow_pause_to_stream_response( + event=queue_event, + task_id="task", + graph_runtime_state=runtime_state, + ) + + assert isinstance(responses[-1], WorkflowPauseStreamResponse) + pause_resp = responses[-1] + assert pause_resp.workflow_run_id == "run-id" + assert pause_resp.data.paused_nodes == ["node-id"] + assert pause_resp.data.outputs == {} + assert pause_resp.data.reasons[0]["form_id"] == "form-1" + assert pause_resp.data.reasons[0]["display_in_ui"] is True + + assert isinstance(responses[0], HumanInputRequiredResponse) + hi_resp = responses[0] + assert hi_resp.data.form_id == "form-1" + assert hi_resp.data.node_id == "node-id" + assert hi_resp.data.node_title == "Human Step" + assert hi_resp.data.inputs[0].output_variable_name == "field" + assert hi_resp.data.actions[0].id == "approve" + assert hi_resp.data.display_in_ui is True + assert hi_resp.data.expiration_time == int(expiration_time.timestamp()) diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py new file mode 100644 index 0000000000..32cb1ed47c --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py @@ -0,0 +1,96 @@ +import time +from contextlib import contextmanager +from unittest.mock import MagicMock + +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.entities.queue_entities import QueueWorkflowStartedEvent +from core.workflow.entities.workflow_start_reason import WorkflowStartReason +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from models.account import Account +from models.model import AppMode + + +def _build_workflow_app_config() -> WorkflowUIBasedAppConfig: + return WorkflowUIBasedAppConfig( + tenant_id="tenant-id", + app_id="app-id", + app_mode=AppMode.WORKFLOW, + workflow_id="workflow-id", + ) + + +def _build_generate_entity(run_id: str) -> WorkflowAppGenerateEntity: + return WorkflowAppGenerateEntity( + task_id="task-id", + app_config=_build_workflow_app_config(), + inputs={}, + files=[], + user_id="user-id", + stream=False, + invoke_from=InvokeFrom.SERVICE_API, + workflow_execution_id=run_id, + ) + + +def _build_runtime_state(run_id: str) -> GraphRuntimeState: + variable_pool = VariablePool( + system_variables=SystemVariable(workflow_execution_id=run_id), + user_inputs={}, + conversation_variables=[], + ) + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +@contextmanager +def _noop_session(): + yield MagicMock() + + +def _build_pipeline(run_id: str) -> WorkflowAppGenerateTaskPipeline: + queue_manager = MagicMock(spec=AppQueueManager) + queue_manager.invoke_from = InvokeFrom.SERVICE_API + queue_manager.graph_runtime_state = _build_runtime_state(run_id) + workflow = MagicMock() + workflow.id = "workflow-id" + workflow.features_dict = {} + user = Account(name="user", email="user@example.com") + pipeline = WorkflowAppGenerateTaskPipeline( + application_generate_entity=_build_generate_entity(run_id), + workflow=workflow, + queue_manager=queue_manager, + user=user, + stream=False, + draft_var_saver_factory=MagicMock(), + ) + pipeline._database_session = _noop_session + return pipeline + + +def test_workflow_app_log_saved_only_on_initial_start() -> None: + run_id = "run-initial" + pipeline = _build_pipeline(run_id) + pipeline._save_workflow_app_log = MagicMock() + + event = QueueWorkflowStartedEvent(reason=WorkflowStartReason.INITIAL) + list(pipeline._handle_workflow_started_event(event)) + + pipeline._save_workflow_app_log.assert_called_once() + _, kwargs = pipeline._save_workflow_app_log.call_args + assert kwargs["workflow_run_id"] == run_id + assert pipeline._workflow_execution_id == run_id + + +def test_workflow_app_log_skipped_on_resumption_start() -> None: + run_id = "run-resume" + pipeline = _build_pipeline(run_id) + pipeline._save_workflow_app_log = MagicMock() + + event = QueueWorkflowStartedEvent(reason=WorkflowStartReason.RESUMPTION) + list(pipeline._handle_workflow_started_event(event)) + + pipeline._save_workflow_app_log.assert_not_called() + assert pipeline._workflow_execution_id == run_id diff --git a/api/tests/unit_tests/core/app/entities/test_app_invoke_entities.py b/api/tests/unit_tests/core/app/entities/test_app_invoke_entities.py new file mode 100644 index 0000000000..86c80985c4 --- /dev/null +++ b/api/tests/unit_tests/core/app/entities/test_app_invoke_entities.py @@ -0,0 +1,143 @@ +import json +from collections.abc import Callable +from dataclasses import dataclass + +import pytest + +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import ( + AdvancedChatAppGenerateEntity, + InvokeFrom, + WorkflowAppGenerateEntity, +) +from core.app.layers.pause_state_persist_layer import ( + WorkflowResumptionContext, + _AdvancedChatAppGenerateEntityWrapper, + _WorkflowGenerateEntityWrapper, +) +from core.ops.ops_trace_manager import TraceQueueManager +from models.model import AppMode + + +class TraceQueueManagerStub(TraceQueueManager): + """Minimal TraceQueueManager stub that avoids Flask dependencies.""" + + def __init__(self): + # Skip parent initialization to avoid starting timers or accessing Flask globals. + pass + + +def _build_workflow_app_config(app_mode: AppMode) -> WorkflowUIBasedAppConfig: + return WorkflowUIBasedAppConfig( + tenant_id="tenant-id", + app_id="app-id", + app_mode=app_mode, + workflow_id=f"{app_mode.value}-workflow-id", + ) + + +def _create_workflow_generate_entity(trace_manager: TraceQueueManager | None = None) -> WorkflowAppGenerateEntity: + return WorkflowAppGenerateEntity( + task_id="workflow-task", + app_config=_build_workflow_app_config(AppMode.WORKFLOW), + inputs={"topic": "serialization"}, + files=[], + user_id="user-workflow", + stream=True, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=1, + trace_manager=trace_manager, + workflow_execution_id="workflow-exec-id", + extras={"external_trace_id": "trace-id"}, + ) + + +def _create_advanced_chat_generate_entity( + trace_manager: TraceQueueManager | None = None, +) -> AdvancedChatAppGenerateEntity: + return AdvancedChatAppGenerateEntity( + task_id="advanced-task", + app_config=_build_workflow_app_config(AppMode.ADVANCED_CHAT), + conversation_id="conversation-id", + inputs={"topic": "roundtrip"}, + files=[], + user_id="user-advanced", + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + query="Explain serialization", + extras={"auto_generate_conversation_name": True}, + trace_manager=trace_manager, + workflow_run_id="workflow-run-id", + ) + + +def test_workflow_app_generate_entity_roundtrip_excludes_trace_manager(): + entity = _create_workflow_generate_entity(trace_manager=TraceQueueManagerStub()) + + serialized = entity.model_dump_json() + payload = json.loads(serialized) + + assert "trace_manager" not in payload + + restored = WorkflowAppGenerateEntity.model_validate_json(serialized) + + assert restored.model_dump() == entity.model_dump() + assert restored.trace_manager is None + + +def test_advanced_chat_generate_entity_roundtrip_excludes_trace_manager(): + entity = _create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub()) + + serialized = entity.model_dump_json() + payload = json.loads(serialized) + + assert "trace_manager" not in payload + + restored = AdvancedChatAppGenerateEntity.model_validate_json(serialized) + + assert restored.model_dump() == entity.model_dump() + assert restored.trace_manager is None + + +@dataclass(frozen=True) +class ResumptionContextCase: + name: str + context_factory: Callable[[], tuple[WorkflowResumptionContext, type]] + + +def _workflow_resumption_case() -> tuple[WorkflowResumptionContext, type]: + entity = _create_workflow_generate_entity(trace_manager=TraceQueueManagerStub()) + context = WorkflowResumptionContext( + serialized_graph_runtime_state=json.dumps({"state": "workflow"}), + generate_entity=_WorkflowGenerateEntityWrapper(entity=entity), + ) + return context, WorkflowAppGenerateEntity + + +def _advanced_chat_resumption_case() -> tuple[WorkflowResumptionContext, type]: + entity = _create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub()) + context = WorkflowResumptionContext( + serialized_graph_runtime_state=json.dumps({"state": "advanced"}), + generate_entity=_AdvancedChatAppGenerateEntityWrapper(entity=entity), + ) + return context, AdvancedChatAppGenerateEntity + + +@pytest.mark.parametrize( + "case", + [ + pytest.param(ResumptionContextCase("workflow", _workflow_resumption_case), id="workflow"), + pytest.param(ResumptionContextCase("advanced_chat", _advanced_chat_resumption_case), id="advanced_chat"), + ], +) +def test_workflow_resumption_context_roundtrip(case: ResumptionContextCase): + context, expected_type = case.context_factory() + + serialized = context.dumps() + restored = WorkflowResumptionContext.loads(serialized) + + assert restored.serialized_graph_runtime_state == context.serialized_graph_runtime_state + entity = restored.get_generate_entity() + assert isinstance(entity, expected_type) + assert entity.model_dump() == context.get_generate_entity().model_dump() + assert entity.trace_manager is None diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py new file mode 100644 index 0000000000..a380149554 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py @@ -0,0 +1,72 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig +from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation +from models.model import AppMode + + +def test_invoke_chat_app_advanced_chat_injects_pause_state_config(mocker): + workflow = MagicMock() + workflow.created_by = "owner-id" + + app = MagicMock() + app.mode = AppMode.ADVANCED_CHAT + app.workflow = workflow + + mocker.patch( + "core.plugin.backwards_invocation.app.db", + SimpleNamespace(engine=MagicMock()), + ) + generator_spy = mocker.patch( + "core.plugin.backwards_invocation.app.AdvancedChatAppGenerator.generate", + return_value={"result": "ok"}, + ) + + result = PluginAppBackwardsInvocation.invoke_chat_app( + app=app, + user=MagicMock(), + conversation_id="conv-1", + query="hello", + stream=False, + inputs={"k": "v"}, + files=[], + ) + + assert result == {"result": "ok"} + call_kwargs = generator_spy.call_args.kwargs + pause_state_config = call_kwargs.get("pause_state_config") + assert isinstance(pause_state_config, PauseStateLayerConfig) + assert pause_state_config.state_owner_user_id == "owner-id" + + +def test_invoke_workflow_app_injects_pause_state_config(mocker): + workflow = MagicMock() + workflow.created_by = "owner-id" + + app = MagicMock() + app.mode = AppMode.WORKFLOW + app.workflow = workflow + + mocker.patch( + "core.plugin.backwards_invocation.app.db", + SimpleNamespace(engine=MagicMock()), + ) + generator_spy = mocker.patch( + "core.plugin.backwards_invocation.app.WorkflowAppGenerator.generate", + return_value={"result": "ok"}, + ) + + result = PluginAppBackwardsInvocation.invoke_workflow_app( + app=app, + user=MagicMock(), + stream=False, + inputs={"k": "v"}, + files=[], + ) + + assert result == {"result": "ok"} + call_kwargs = generator_spy.call_args.kwargs + pause_state_config = call_kwargs.get("pause_state_config") + assert isinstance(pause_state_config, PauseStateLayerConfig) + assert pause_state_config.state_owner_user_id == "owner-id" diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py new file mode 100644 index 0000000000..811ed2143b --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -0,0 +1,574 @@ +"""Unit tests for HumanInputFormRepositoryImpl private helpers.""" + +from __future__ import annotations + +import dataclasses +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.repositories.human_input_repository import ( + HumanInputFormRecord, + HumanInputFormRepositoryImpl, + HumanInputFormSubmissionRepository, + _WorkspaceMemberInfo, +) +from core.workflow.nodes.human_input.entities import ( + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + ExternalRecipient, + FormDefinition, + MemberRecipient, + UserAction, +) +from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from libs.datetime_utils import naive_utc_now +from models.human_input import ( + EmailExternalRecipientPayload, + EmailMemberRecipientPayload, + HumanInputFormRecipient, + RecipientType, +) + + +def _build_repository() -> HumanInputFormRepositoryImpl: + return HumanInputFormRepositoryImpl(session_factory=MagicMock(), tenant_id="tenant-id") + + +def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleNamespace]: + created: list[SimpleNamespace] = [] + + def fake_new(cls, form_id: str, delivery_id: str, payload): # type: ignore[no-untyped-def] + recipient = SimpleNamespace( + form_id=form_id, + delivery_id=delivery_id, + recipient_type=payload.TYPE, + recipient_payload=payload.model_dump_json(), + ) + created.append(recipient) + return recipient + + monkeypatch.setattr(HumanInputFormRecipient, "new", classmethod(fake_new)) + return created + + +@pytest.fixture(autouse=True) +def _stub_selectinload(monkeypatch: pytest.MonkeyPatch) -> None: + """Avoid SQLAlchemy mapper configuration in tests using fake sessions.""" + + class _FakeSelect: + def options(self, *_args, **_kwargs): # type: ignore[no-untyped-def] + return self + + def where(self, *_args, **_kwargs): # type: ignore[no-untyped-def] + return self + + monkeypatch.setattr( + "core.repositories.human_input_repository.selectinload", lambda *args, **kwargs: "_loader_option" + ) + monkeypatch.setattr("core.repositories.human_input_repository.select", lambda *args, **kwargs: _FakeSelect()) + + +class TestHumanInputFormRepositoryImplHelpers: + def test_build_email_recipients_with_member_and_external(self, monkeypatch: pytest.MonkeyPatch) -> None: + repo = _build_repository() + session_stub = object() + _patch_recipient_factory(monkeypatch) + + def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def] + assert session is session_stub + assert restrict_to_user_ids == ["member-1"] + return [_WorkspaceMemberInfo(user_id="member-1", email="member@example.com")] + + monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query) + + recipients = repo._build_email_recipients( + session=session_stub, + form_id="form-id", + delivery_id="delivery-id", + recipients_config=EmailRecipients( + whole_workspace=False, + items=[ + MemberRecipient(user_id="member-1"), + ExternalRecipient(email="external@example.com"), + ], + ), + ) + + assert len(recipients) == 2 + member_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_MEMBER) + external_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_EXTERNAL) + + member_payload = EmailMemberRecipientPayload.model_validate_json(member_recipient.recipient_payload) + assert member_payload.user_id == "member-1" + assert member_payload.email == "member@example.com" + + external_payload = EmailExternalRecipientPayload.model_validate_json(external_recipient.recipient_payload) + assert external_payload.email == "external@example.com" + + def test_build_email_recipients_skips_unknown_members(self, monkeypatch: pytest.MonkeyPatch) -> None: + repo = _build_repository() + session_stub = object() + created = _patch_recipient_factory(monkeypatch) + + def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def] + assert session is session_stub + assert restrict_to_user_ids == ["missing-member"] + return [] + + monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query) + + recipients = repo._build_email_recipients( + session=session_stub, + form_id="form-id", + delivery_id="delivery-id", + recipients_config=EmailRecipients( + whole_workspace=False, + items=[ + MemberRecipient(user_id="missing-member"), + ExternalRecipient(email="external@example.com"), + ], + ), + ) + + assert len(recipients) == 1 + assert recipients[0].recipient_type == RecipientType.EMAIL_EXTERNAL + assert len(created) == 1 # only external recipient created via factory + + def test_build_email_recipients_whole_workspace_uses_all_members(self, monkeypatch: pytest.MonkeyPatch) -> None: + repo = _build_repository() + session_stub = object() + _patch_recipient_factory(monkeypatch) + + def fake_query(self, session): # type: ignore[no-untyped-def] + assert session is session_stub + return [ + _WorkspaceMemberInfo(user_id="member-1", email="member1@example.com"), + _WorkspaceMemberInfo(user_id="member-2", email="member2@example.com"), + ] + + monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_all_workspace_members", fake_query) + + recipients = repo._build_email_recipients( + session=session_stub, + form_id="form-id", + delivery_id="delivery-id", + recipients_config=EmailRecipients( + whole_workspace=True, + items=[], + ), + ) + + assert len(recipients) == 2 + emails = {EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email for r in recipients} + assert emails == {"member1@example.com", "member2@example.com"} + + def test_build_email_recipients_dedupes_external_by_email(self, monkeypatch: pytest.MonkeyPatch) -> None: + repo = _build_repository() + session_stub = object() + created = _patch_recipient_factory(monkeypatch) + + def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def] + assert session is session_stub + assert restrict_to_user_ids == [] + return [] + + monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query) + + recipients = repo._build_email_recipients( + session=session_stub, + form_id="form-id", + delivery_id="delivery-id", + recipients_config=EmailRecipients( + whole_workspace=False, + items=[ + ExternalRecipient(email="external@example.com"), + ExternalRecipient(email="external@example.com"), + ], + ), + ) + + assert len(recipients) == 1 + assert len(created) == 1 + + def test_build_email_recipients_prefers_member_over_external_by_email( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + repo = _build_repository() + session_stub = object() + _patch_recipient_factory(monkeypatch) + + def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def] + assert session is session_stub + assert restrict_to_user_ids == ["member-1"] + return [_WorkspaceMemberInfo(user_id="member-1", email="shared@example.com")] + + monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query) + + recipients = repo._build_email_recipients( + session=session_stub, + form_id="form-id", + delivery_id="delivery-id", + recipients_config=EmailRecipients( + whole_workspace=False, + items=[ + MemberRecipient(user_id="member-1"), + ExternalRecipient(email="shared@example.com"), + ], + ), + ) + + assert len(recipients) == 1 + assert recipients[0].recipient_type == RecipientType.EMAIL_MEMBER + + def test_delivery_method_to_model_includes_external_recipients_with_whole_workspace( + self, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + repo = _build_repository() + session_stub = object() + _patch_recipient_factory(monkeypatch) + + def fake_query(self, session): # type: ignore[no-untyped-def] + assert session is session_stub + return [ + _WorkspaceMemberInfo(user_id="member-1", email="member1@example.com"), + _WorkspaceMemberInfo(user_id="member-2", email="member2@example.com"), + ] + + monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_all_workspace_members", fake_query) + + method = EmailDeliveryMethod( + config=EmailDeliveryConfig( + recipients=EmailRecipients( + whole_workspace=True, + items=[ExternalRecipient(email="external@example.com")], + ), + subject="subject", + body="body", + ) + ) + + result = repo._delivery_method_to_model(session=session_stub, form_id="form-id", delivery_method=method) + + assert len(result.recipients) == 3 + member_emails = { + EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email + for r in result.recipients + if r.recipient_type == RecipientType.EMAIL_MEMBER + } + assert member_emails == {"member1@example.com", "member2@example.com"} + external_payload = EmailExternalRecipientPayload.model_validate_json( + next(r for r in result.recipients if r.recipient_type == RecipientType.EMAIL_EXTERNAL).recipient_payload + ) + assert external_payload.email == "external@example.com" + + +def _make_form_definition() -> str: + return FormDefinition( + form_content="hello", + inputs=[], + user_actions=[UserAction(id="submit", title="Submit")], + rendered_content="hello
", + expiration_time=datetime.utcnow(), + ).model_dump_json() + + +@dataclasses.dataclass +class _DummyForm: + id: str + workflow_run_id: str + node_id: str + tenant_id: str + app_id: str + form_definition: str + rendered_content: str + expiration_time: datetime + form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME + created_at: datetime = dataclasses.field(default_factory=naive_utc_now) + selected_action_id: str | None = None + submitted_data: str | None = None + submitted_at: datetime | None = None + submission_user_id: str | None = None + submission_end_user_id: str | None = None + completed_by_recipient_id: str | None = None + status: HumanInputFormStatus = HumanInputFormStatus.WAITING + + +@dataclasses.dataclass +class _DummyRecipient: + id: str + form_id: str + recipient_type: RecipientType + access_token: str + form: _DummyForm | None = None + + +class _FakeScalarResult: + def __init__(self, obj): + self._obj = obj + + def first(self): + if isinstance(self._obj, list): + return self._obj[0] if self._obj else None + return self._obj + + def all(self): + if isinstance(self._obj, list): + return list(self._obj) + if self._obj is None: + return [] + return [self._obj] + + +class _FakeSession: + def __init__( + self, + *, + scalars_result=None, + scalars_results: list[object] | None = None, + forms: dict[str, _DummyForm] | None = None, + recipients: dict[str, _DummyRecipient] | None = None, + ): + if scalars_results is not None: + self._scalars_queue = list(scalars_results) + elif scalars_result is not None: + self._scalars_queue = [scalars_result] + else: + self._scalars_queue = [] + self.forms = forms or {} + self.recipients = recipients or {} + + def scalars(self, _query): + if self._scalars_queue: + result = self._scalars_queue.pop(0) + else: + result = None + return _FakeScalarResult(result) + + def get(self, model_cls, obj_id): # type: ignore[no-untyped-def] + if getattr(model_cls, "__name__", None) == "HumanInputForm": + return self.forms.get(obj_id) + if getattr(model_cls, "__name__", None) == "HumanInputFormRecipient": + return self.recipients.get(obj_id) + return None + + def add(self, _obj): + return None + + def flush(self): + return None + + def refresh(self, _obj): + return None + + def begin(self): + return self + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return None + + +def _session_factory(session: _FakeSession): + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return None + + def _factory(*_args, **_kwargs): + return _SessionContext() + + return _factory + + +class TestHumanInputFormRepositoryImplPublicMethods: + def test_get_form_returns_entity_and_recipients(self): + form = _DummyForm( + id="form-1", + workflow_run_id="run-1", + node_id="node-1", + tenant_id="tenant-id", + app_id="app-id", + form_definition=_make_form_definition(), + rendered_content="hello
", + expiration_time=naive_utc_now(), + ) + recipient = _DummyRecipient( + id="recipient-1", + form_id=form.id, + recipient_type=RecipientType.STANDALONE_WEB_APP, + access_token="token-123", + ) + session = _FakeSession(scalars_results=[form, [recipient]]) + repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + + entity = repo.get_form(form.workflow_run_id, form.node_id) + + assert entity is not None + assert entity.id == form.id + assert entity.web_app_token == "token-123" + assert len(entity.recipients) == 1 + assert entity.recipients[0].token == "token-123" + + def test_get_form_returns_none_when_missing(self): + session = _FakeSession(scalars_results=[None]) + repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + + assert repo.get_form("run-1", "node-1") is None + + def test_get_form_returns_unsubmitted_state(self): + form = _DummyForm( + id="form-1", + workflow_run_id="run-1", + node_id="node-1", + tenant_id="tenant-id", + app_id="app-id", + form_definition=_make_form_definition(), + rendered_content="hello
", + expiration_time=naive_utc_now(), + ) + session = _FakeSession(scalars_results=[form, []]) + repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + + entity = repo.get_form(form.workflow_run_id, form.node_id) + + assert entity is not None + assert entity.submitted is False + assert entity.selected_action_id is None + assert entity.submitted_data is None + + def test_get_form_returns_submission_when_completed(self): + form = _DummyForm( + id="form-1", + workflow_run_id="run-1", + node_id="node-1", + tenant_id="tenant-id", + app_id="app-id", + form_definition=_make_form_definition(), + rendered_content="hello
", + expiration_time=naive_utc_now(), + selected_action_id="approve", + submitted_data='{"field": "value"}', + submitted_at=naive_utc_now(), + ) + session = _FakeSession(scalars_results=[form, []]) + repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + + entity = repo.get_form(form.workflow_run_id, form.node_id) + + assert entity is not None + assert entity.submitted is True + assert entity.selected_action_id == "approve" + assert entity.submitted_data == {"field": "value"} + + +class TestHumanInputFormSubmissionRepository: + def test_get_by_token_returns_record(self): + form = _DummyForm( + id="form-1", + workflow_run_id="run-1", + node_id="node-1", + tenant_id="tenant-1", + app_id="app-1", + form_definition=_make_form_definition(), + rendered_content="hello
", + expiration_time=naive_utc_now(), + ) + recipient = _DummyRecipient( + id="recipient-1", + form_id=form.id, + recipient_type=RecipientType.STANDALONE_WEB_APP, + access_token="token-123", + form=form, + ) + session = _FakeSession(scalars_result=recipient) + repo = HumanInputFormSubmissionRepository(_session_factory(session)) + + record = repo.get_by_token("token-123") + + assert record is not None + assert record.form_id == form.id + assert record.recipient_type == RecipientType.STANDALONE_WEB_APP + assert record.submitted is False + + def test_get_by_form_id_and_recipient_type_uses_recipient(self): + form = _DummyForm( + id="form-1", + workflow_run_id="run-1", + node_id="node-1", + tenant_id="tenant-1", + app_id="app-1", + form_definition=_make_form_definition(), + rendered_content="hello
", + expiration_time=naive_utc_now(), + ) + recipient = _DummyRecipient( + id="recipient-1", + form_id=form.id, + recipient_type=RecipientType.STANDALONE_WEB_APP, + access_token="token-123", + form=form, + ) + session = _FakeSession(scalars_result=recipient) + repo = HumanInputFormSubmissionRepository(_session_factory(session)) + + record = repo.get_by_form_id_and_recipient_type( + form_id=form.id, + recipient_type=RecipientType.STANDALONE_WEB_APP, + ) + + assert record is not None + assert record.recipient_id == recipient.id + assert record.access_token == recipient.access_token + + def test_mark_submitted_updates_fields(self, monkeypatch: pytest.MonkeyPatch): + fixed_now = datetime(2024, 1, 1, 0, 0, 0) + monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now) + + form = _DummyForm( + id="form-1", + workflow_run_id="run-1", + node_id="node-1", + tenant_id="tenant-1", + app_id="app-1", + form_definition=_make_form_definition(), + rendered_content="hello
", + expiration_time=fixed_now, + ) + recipient = _DummyRecipient( + id="recipient-1", + form_id="form-1", + recipient_type=RecipientType.STANDALONE_WEB_APP, + access_token="token-123", + ) + session = _FakeSession( + forms={form.id: form}, + recipients={recipient.id: recipient}, + ) + repo = HumanInputFormSubmissionRepository(_session_factory(session)) + + record: HumanInputFormRecord = repo.mark_submitted( + form_id=form.id, + recipient_id=recipient.id, + selected_action_id="approve", + form_data={"field": "value"}, + submission_user_id="user-1", + submission_end_user_id="end-user-1", + ) + + assert form.selected_action_id == "approve" + assert form.completed_by_recipient_id == recipient.id + assert form.submission_user_id == "user-1" + assert form.submission_end_user_id == "end-user-1" + assert form.submitted_at == fixed_now + assert record.submitted is True + assert record.selected_action_id == "approve" + assert record.submitted_data == {"field": "value"} diff --git a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py new file mode 100644 index 0000000000..c46e31d90f --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py @@ -0,0 +1,33 @@ +import pytest + +from core.tools.errors import WorkflowToolHumanInputNotSupportedError +from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils + + +def test_ensure_no_human_input_nodes_passes_for_non_human_input(): + graph = { + "nodes": [ + { + "id": "start_node", + "data": {"type": "start"}, + } + ] + } + + WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph) + + +def test_ensure_no_human_input_nodes_raises_for_human_input(): + graph = { + "nodes": [ + { + "id": "human_input_node", + "data": {"type": "human-input"}, + } + ] + } + + with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: + WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph) + + assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index cd45292488..bbedfdb6ae 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -55,6 +55,43 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel assert exc_info.value.args == ("oops",) +def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.MonkeyPatch): + entity = ToolEntity( + identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), + parameters=[], + description=None, + has_runtime_parameters=False, + ) + runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) + tool = WorkflowTool( + workflow_app_id="", + workflow_as_tool_id="", + version="1", + workflow_entities={}, + workflow_call_depth=1, + entity=entity, + runtime=runtime, + ) + + monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) + monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) + + from unittest.mock import MagicMock, Mock + + mock_user = Mock() + monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) + + generate_mock = MagicMock(return_value={"data": {}}) + monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock) + monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None) + + list(tool.invoke("test_user", {})) + + call_kwargs = generate_mock.call_args.kwargs + assert "pause_state_config" in call_kwargs + assert call_kwargs["pause_state_config"] is None + + def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch: pytest.MonkeyPatch): """Test that WorkflowTool should generate variable messages when there are outputs""" entity = ToolEntity( diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index deff06fc5d..1b6d03e36a 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -118,7 +118,6 @@ class TestGraphRuntimeState: from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue assert isinstance(queue, InMemoryReadyQueue) - assert state.ready_queue is queue def test_graph_execution_lazy_instantiation(self): state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) diff --git a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py new file mode 100644 index 0000000000..6144df06e0 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py @@ -0,0 +1,88 @@ +""" +Tests for PauseReason discriminated union serialization/deserialization. +""" + +import pytest +from pydantic import BaseModel, ValidationError + +from core.workflow.entities.pause_reason import ( + HumanInputRequired, + PauseReason, + SchedulingPause, +) + + +class _Holder(BaseModel): + """Helper model that embeds PauseReason for union tests.""" + + reason: PauseReason + + +class TestPauseReasonDiscriminator: + """Test suite for PauseReason union discriminator.""" + + @pytest.mark.parametrize( + ("dict_value", "expected"), + [ + pytest.param( + { + "reason": { + "TYPE": "human_input_required", + "form_id": "form_id", + "form_content": "form_content", + "node_id": "node_id", + "node_title": "node_title", + }, + }, + HumanInputRequired( + form_id="form_id", + form_content="form_content", + node_id="node_id", + node_title="node_title", + ), + id="HumanInputRequired", + ), + pytest.param( + { + "reason": { + "TYPE": "scheduled_pause", + "message": "Hold on", + } + }, + SchedulingPause(message="Hold on"), + id="SchedulingPause", + ), + ], + ) + def test_model_validate(self, dict_value, expected): + """Ensure scheduled pause payloads with lowercase TYPE deserialize.""" + holder = _Holder.model_validate(dict_value) + + assert type(holder.reason) == type(expected) + assert holder.reason == expected + + @pytest.mark.parametrize( + "reason", + [ + HumanInputRequired( + form_id="form_id", + form_content="form_content", + node_id="node_id", + node_title="node_title", + ), + SchedulingPause(message="Hold on"), + ], + ids=lambda x: type(x).__name__, + ) + def test_model_construct(self, reason): + holder = _Holder(reason=reason) + assert holder.reason == reason + + def test_model_construct_with_invalid_type(self): + with pytest.raises(ValidationError): + holder = _Holder(reason=object()) # type: ignore + + def test_unknown_type_fails_validation(self): + """Unknown TYPE values should raise a validation error.""" + with pytest.raises(ValidationError): + _Holder.model_validate({"reason": {"TYPE": "UNKNOWN"}}) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py new file mode 100644 index 0000000000..2ef23c7f0f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py @@ -0,0 +1,131 @@ +"""Utilities for testing HumanInputNode without database dependencies.""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Any + +from core.workflow.nodes.human_input.enums import HumanInputFormStatus +from core.workflow.repositories.human_input_form_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRecipientEntity, + HumanInputFormRepository, +) +from libs.datetime_utils import naive_utc_now + + +class _InMemoryFormRecipient(HumanInputFormRecipientEntity): + """Minimal recipient entity required by the repository interface.""" + + def __init__(self, recipient_id: str, token: str) -> None: + self._id = recipient_id + self._token = token + + @property + def id(self) -> str: + return self._id + + @property + def token(self) -> str: + return self._token + + +@dataclass +class _InMemoryFormEntity(HumanInputFormEntity): + form_id: str + rendered: str + token: str | None = None + action_id: str | None = None + data: Mapping[str, Any] | None = None + is_submitted: bool = False + status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING + expiration: datetime = naive_utc_now() + + @property + def id(self) -> str: + return self.form_id + + @property + def web_app_token(self) -> str | None: + return self.token + + @property + def recipients(self) -> list[HumanInputFormRecipientEntity]: + return [] + + @property + def rendered_content(self) -> str: + return self.rendered + + @property + def selected_action_id(self) -> str | None: + return self.action_id + + @property + def submitted_data(self) -> Mapping[str, Any] | None: + return self.data + + @property + def submitted(self) -> bool: + return self.is_submitted + + @property + def status(self) -> HumanInputFormStatus: + return self.status_value + + @property + def expiration_time(self) -> datetime: + return self.expiration + + +class InMemoryHumanInputFormRepository(HumanInputFormRepository): + """Pure in-memory repository used by workflow graph engine tests.""" + + def __init__(self) -> None: + self._form_counter = 0 + self.created_params: list[FormCreateParams] = [] + self.created_forms: list[_InMemoryFormEntity] = [] + self._forms_by_key: dict[tuple[str, str], _InMemoryFormEntity] = {} + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: + self.created_params.append(params) + self._form_counter += 1 + form_id = f"form-{self._form_counter}" + token = f"console-{form_id}" if params.console_recipient_required else f"token-{form_id}" + entity = _InMemoryFormEntity( + form_id=form_id, + rendered=params.rendered_content, + token=token, + ) + self.created_forms.append(entity) + self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity + return entity + + def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + return self._forms_by_key.get((workflow_execution_id, node_id)) + + # Convenience helpers for tests ------------------------------------- + + def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None: + """Simulate a human submission for the next repository lookup.""" + + if not self.created_forms: + raise AssertionError("no form has been created to attach submission data") + entity = self.created_forms[-1] + entity.action_id = action_id + entity.data = form_data or {} + entity.is_submitted = True + entity.status_value = HumanInputFormStatus.SUBMITTED + entity.expiration = naive_utc_now() + timedelta(days=1) + + def clear_submission(self) -> None: + if not self.created_forms: + return + for form in self.created_forms: + form.action_id = None + form.data = None + form.is_submitted = False + form.status_value = HumanInputFormStatus.WAITING diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py new file mode 100644 index 0000000000..6038a15211 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py @@ -0,0 +1,74 @@ +import queue +import threading +from datetime import datetime + +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus +from core.workflow.graph_engine.orchestration.dispatcher import Dispatcher +from core.workflow.graph_events import NodeRunSucceededEvent +from core.workflow.node_events import NodeRunResult + + +class StubExecutionCoordinator: + def __init__(self, paused: bool) -> None: + self._paused = paused + self.mark_complete_called = False + self.failed_error: Exception | None = None + + @property + def aborted(self) -> bool: + return False + + @property + def paused(self) -> bool: + return self._paused + + @property + def execution_complete(self) -> bool: + return False + + def check_scaling(self) -> None: + return None + + def process_commands(self) -> None: + return None + + def mark_complete(self) -> None: + self.mark_complete_called = True + + def mark_failed(self, error: Exception) -> None: + self.failed_error = error + + +class StubEventHandler: + def __init__(self) -> None: + self.events: list[object] = [] + + def dispatch(self, event: object) -> None: + self.events.append(event) + + +def test_dispatcher_drains_events_when_paused() -> None: + event_queue: queue.Queue = queue.Queue() + event = NodeRunSucceededEvent( + id="exec-1", + node_id="node-1", + node_type=NodeType.START, + start_at=datetime.utcnow(), + node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), + ) + event_queue.put(event) + + handler = StubEventHandler() + coordinator = StubExecutionCoordinator(paused=True) + dispatcher = Dispatcher( + event_queue=event_queue, + event_handler=handler, + execution_coordinator=coordinator, + event_emitter=None, + stop_event=threading.Event(), + ) + + dispatcher._dispatcher_loop() + + assert handler.events == [event] + assert coordinator.mark_complete_called is True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py index 0d67a76169..53de8908a8 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py @@ -2,6 +2,8 @@ from unittest.mock import MagicMock +import pytest + from core.workflow.graph_engine.command_processing.command_processor import CommandProcessor from core.workflow.graph_engine.domain.graph_execution import GraphExecution from core.workflow.graph_engine.graph_state_manager import GraphStateManager @@ -48,3 +50,13 @@ def test_handle_pause_noop_when_execution_running() -> None: worker_pool.stop.assert_not_called() state_manager.clear_executing.assert_not_called() + + +def test_has_executing_nodes_requires_pause() -> None: + graph_execution = GraphExecution(workflow_id="workflow") + graph_execution.start() + + coordinator, _, _ = _build_coordinator(graph_execution) + + with pytest.raises(AssertionError): + coordinator.has_executing_nodes() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py new file mode 100644 index 0000000000..65d34c2009 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py @@ -0,0 +1,189 @@ +import time +from collections.abc import Mapping + +from core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.entities.message_entities import PromptMessageRole +from core.workflow.entities import GraphInitParams +from core.workflow.enums import NodeState +from core.workflow.graph import Graph +from core.workflow.graph_engine.graph_state_manager import GraphStateManager +from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.llm.entities import ( + ContextConfig, + LLMNodeChatModelMessage, + LLMNodeData, + ModelConfig, + VisionConfig, +) +from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable + +from .test_mock_config import MockConfig +from .test_mock_nodes import MockLLMNode + + +def _build_runtime_state() -> GraphRuntimeState: + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-1", + ), + user_inputs={}, + conversation_variables=[], + ) + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +def _build_llm_node( + *, + node_id: str, + runtime_state: GraphRuntimeState, + graph_init_params: GraphInitParams, + mock_config: MockConfig, +) -> MockLLMNode: + llm_data = LLMNodeData( + title=f"LLM {node_id}", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), + prompt_template=[ + LLMNodeChatModelMessage( + text=f"Prompt {node_id}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context=ContextConfig(enabled=False, variable_selector=None), + vision=VisionConfig(enabled=False), + reasoning_format="tagged", + ) + llm_config = {"id": node_id, "data": llm_data.model_dump()} + return MockLLMNode( + id=llm_config["id"], + config=llm_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + mock_config=mock_config, + ) + + +def _build_graph(runtime_state: GraphRuntimeState) -> Graph: + graph_config: dict[str, object] = {"nodes": [], "edges": []} + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} + start_node = StartNode( + id=start_config["id"], + config=start_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + ) + + mock_config = MockConfig() + llm_a = _build_llm_node( + node_id="llm_a", + runtime_state=runtime_state, + graph_init_params=graph_init_params, + mock_config=mock_config, + ) + llm_b = _build_llm_node( + node_id="llm_b", + runtime_state=runtime_state, + graph_init_params=graph_init_params, + mock_config=mock_config, + ) + + end_data = EndNodeData(title="End", outputs=[], desc=None) + end_config = {"id": "end", "data": end_data.model_dump()} + end_node = EndNode( + id=end_config["id"], + config=end_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + ) + + builder = ( + Graph.new() + .add_root(start_node) + .add_node(llm_a, from_node_id="start") + .add_node(llm_b, from_node_id="start") + .add_node(end_node, from_node_id="llm_a") + ) + return builder.connect(tail="llm_b", head="end").build() + + +def _edge_state_map(graph: Graph) -> Mapping[tuple[str, str, str], NodeState]: + return {(edge.tail, edge.head, edge.source_handle): edge.state for edge in graph.edges.values()} + + +def test_runtime_state_snapshot_restores_graph_states() -> None: + runtime_state = _build_runtime_state() + graph = _build_graph(runtime_state) + runtime_state.attach_graph(graph) + + graph.nodes["llm_a"].state = NodeState.TAKEN + graph.nodes["llm_b"].state = NodeState.SKIPPED + + for edge in graph.edges.values(): + if edge.tail == "start" and edge.head == "llm_a": + edge.state = NodeState.TAKEN + elif edge.tail == "start" and edge.head == "llm_b": + edge.state = NodeState.SKIPPED + elif edge.head == "end" and edge.tail == "llm_a": + edge.state = NodeState.TAKEN + elif edge.head == "end" and edge.tail == "llm_b": + edge.state = NodeState.SKIPPED + + snapshot = runtime_state.dumps() + + resumed_state = GraphRuntimeState.from_snapshot(snapshot) + resumed_graph = _build_graph(resumed_state) + resumed_state.attach_graph(resumed_graph) + + assert resumed_graph.nodes["llm_a"].state == NodeState.TAKEN + assert resumed_graph.nodes["llm_b"].state == NodeState.SKIPPED + assert _edge_state_map(resumed_graph) == _edge_state_map(graph) + + +def test_join_readiness_uses_restored_edge_states() -> None: + runtime_state = _build_runtime_state() + graph = _build_graph(runtime_state) + runtime_state.attach_graph(graph) + + ready_queue = InMemoryReadyQueue() + state_manager = GraphStateManager(graph, ready_queue) + + for edge in graph.get_incoming_edges("end"): + if edge.tail == "llm_a": + edge.state = NodeState.TAKEN + if edge.tail == "llm_b": + edge.state = NodeState.UNKNOWN + + assert state_manager.is_node_ready("end") is False + + for edge in graph.get_incoming_edges("end"): + if edge.tail == "llm_b": + edge.state = NodeState.TAKEN + + assert state_manager.is_node_ready("end") is True + + snapshot = runtime_state.dumps() + resumed_state = GraphRuntimeState.from_snapshot(snapshot) + resumed_graph = _build_graph(resumed_state) + resumed_state.attach_graph(resumed_graph) + + resumed_state_manager = GraphStateManager(resumed_graph, InMemoryReadyQueue()) + assert resumed_state_manager.is_node_ready("end") is True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py index c398e4e8c1..194d009288 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py @@ -1,5 +1,7 @@ +import datetime import time from collections.abc import Iterable +from unittest.mock import MagicMock from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.message_entities import PromptMessageRole @@ -14,11 +16,12 @@ from core.workflow.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) +from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input import HumanInputNode -from core.workflow.nodes.human_input.entities import HumanInputNodeData +from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction +from core.workflow.nodes.human_input.human_input_node import HumanInputNode from core.workflow.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, @@ -28,15 +31,21 @@ from core.workflow.nodes.llm.entities import ( ) from core.workflow.nodes.start.entities import StartNodeData from core.workflow.nodes.start.start_node import StartNode +from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable +from libs.datetime_utils import naive_utc_now from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode from .test_table_runner import TableTestRunner, WorkflowTestCase -def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: +def _build_branching_graph( + mock_config: MockConfig, + form_repository: HumanInputFormRepository, + graph_runtime_state: GraphRuntimeState | None = None, +) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} graph_init_params = GraphInitParams( tenant_id="tenant", @@ -49,12 +58,18 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime call_depth=0, ) - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), - user_inputs={}, - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + if graph_runtime_state is None: + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="test-execution-id", + ), + user_inputs={}, + conversation_variables=[], + ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} start_node = StartNode( @@ -93,15 +108,21 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime human_data = HumanInputNodeData( title="Human Input", - required_variables=["human.input_ready"], - pause_reason="Awaiting human input", + form_content="Human input required", + inputs=[], + user_actions=[ + UserAction(id="primary", title="Primary"), + UserAction(id="secondary", title="Secondary"), + ], ) + human_config = {"id": "human", "data": human_data.model_dump()} human_node = HumanInputNode( id=human_config["id"], config=human_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, + form_repository=form_repository, ) llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") @@ -219,8 +240,18 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: for scenario in branch_scenarios: runner = TableTestRunner() - def initial_graph_factory() -> tuple[Graph, GraphRuntimeState]: - return _build_branching_graph(mock_config) + mock_create_repo = MagicMock(spec=HumanInputFormRepository) + mock_create_repo.get_form.return_value = None + mock_form_entity = MagicMock(spec=HumanInputFormEntity) + mock_form_entity.id = "test_form_id" + mock_form_entity.web_app_token = "test_web_app_token" + mock_form_entity.recipients = [] + mock_form_entity.rendered_content = "rendered" + mock_form_entity.submitted = False + mock_create_repo.create_form.return_value = mock_form_entity + + def initial_graph_factory(mock_create_repo=mock_create_repo) -> tuple[Graph, GraphRuntimeState]: + return _build_branching_graph(mock_config, mock_create_repo) initial_case = WorkflowTestCase( description="HumanInput pause before branching decision", @@ -242,23 +273,16 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: assert initial_result.success, initial_result.event_mismatch_details assert not any(isinstance(event, NodeRunStreamChunkEvent) for event in initial_result.events) - graph_runtime_state = initial_result.graph_runtime_state - graph = initial_result.graph - assert graph_runtime_state is not None - assert graph is not None - - graph_runtime_state.variable_pool.add(("human", "input_ready"), True) - graph_runtime_state.variable_pool.add(("human", "edge_source_handle"), scenario["handle"]) - graph_runtime_state.graph_execution.pause_reason = None - pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"]) post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"]) + expected_pre_chunk_events_in_resumption = [ + GraphRunStartedEvent, + NodeRunStartedEvent, + NodeRunHumanInputFormFilledEvent, + ] expected_resume_sequence: list[type] = ( - [ - GraphRunStartedEvent, - NodeRunStartedEvent, - ] + expected_pre_chunk_events_in_resumption + [NodeRunStreamChunkEvent] * pre_chunk_count + [ NodeRunSucceededEvent, @@ -273,11 +297,25 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: ] ) + mock_get_repo = MagicMock(spec=HumanInputFormRepository) + submitted_form = MagicMock(spec=HumanInputFormEntity) + submitted_form.id = mock_form_entity.id + submitted_form.web_app_token = mock_form_entity.web_app_token + submitted_form.recipients = [] + submitted_form.rendered_content = mock_form_entity.rendered_content + submitted_form.submitted = True + submitted_form.selected_action_id = scenario["handle"] + submitted_form.submitted_data = {} + submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1) + mock_get_repo.get_form.return_value = submitted_form + def resume_graph_factory( - graph_snapshot: Graph = graph, - state_snapshot: GraphRuntimeState = graph_runtime_state, + initial_result=initial_result, mock_get_repo=mock_get_repo ) -> tuple[Graph, GraphRuntimeState]: - return graph_snapshot, state_snapshot + assert initial_result.graph_runtime_state is not None + serialized_runtime_state = initial_result.graph_runtime_state.dumps() + resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state) + return _build_branching_graph(mock_config, mock_get_repo, resume_runtime_state) resume_case = WorkflowTestCase( description=f"HumanInput resumes via {scenario['handle']} branch", @@ -321,7 +359,8 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: for index, event in enumerate(resume_events) if isinstance(event, NodeRunStreamChunkEvent) and index < human_success_index ] - assert pre_indices == list(range(2, 2 + pre_chunk_count)) + expected_pre_chunk_events_count_in_resumption = len(expected_pre_chunk_events_in_resumption) + assert pre_indices == list(range(expected_pre_chunk_events_count_in_resumption, human_success_index)) resume_chunk_indices = [ index diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py index ece69b080b..d8f229205b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py @@ -1,4 +1,6 @@ +import datetime import time +from unittest.mock import MagicMock from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.message_entities import PromptMessageRole @@ -13,11 +15,12 @@ from core.workflow.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) +from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input import HumanInputNode -from core.workflow.nodes.human_input.entities import HumanInputNodeData +from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction +from core.workflow.nodes.human_input.human_input_node import HumanInputNode from core.workflow.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, @@ -27,15 +30,21 @@ from core.workflow.nodes.llm.entities import ( ) from core.workflow.nodes.start.entities import StartNodeData from core.workflow.nodes.start.start_node import StartNode +from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable +from libs.datetime_utils import naive_utc_now from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode from .test_table_runner import TableTestRunner, WorkflowTestCase -def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: +def _build_llm_human_llm_graph( + mock_config: MockConfig, + form_repository: HumanInputFormRepository, + graph_runtime_state: GraphRuntimeState | None = None, +) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} graph_init_params = GraphInitParams( tenant_id="tenant", @@ -48,12 +57,15 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun call_depth=0, ) - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), - user_inputs={}, - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + if graph_runtime_state is None: + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", app_id="app", workflow_id="workflow", workflow_execution_id="test-execution-id," + ), + user_inputs={}, + conversation_variables=[], + ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} start_node = StartNode( @@ -92,15 +104,21 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun human_data = HumanInputNodeData( title="Human Input", - required_variables=["human.input_ready"], - pause_reason="Awaiting human input", + form_content="Human input required", + inputs=[], + user_actions=[ + UserAction(id="accept", title="Accept"), + UserAction(id="reject", title="Reject"), + ], ) + human_config = {"id": "human", "data": human_data.model_dump()} human_node = HumanInputNode( id=human_config["id"], config=human_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, + form_repository=form_repository, ) llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt") @@ -130,7 +148,7 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun .add_root(start_node) .add_node(llm_first) .add_node(human_node) - .add_node(llm_second) + .add_node(llm_second, source_handle="accept") .add_node(end_node) .build() ) @@ -167,8 +185,18 @@ def test_human_input_llm_streaming_order_across_pause() -> None: GraphRunPausedEvent, # graph run pauses awaiting resume ] + mock_create_repo = MagicMock(spec=HumanInputFormRepository) + mock_create_repo.get_form.return_value = None + mock_form_entity = MagicMock(spec=HumanInputFormEntity) + mock_form_entity.id = "test_form_id" + mock_form_entity.web_app_token = "test_web_app_token" + mock_form_entity.recipients = [] + mock_form_entity.rendered_content = "rendered" + mock_form_entity.submitted = False + mock_create_repo.create_form.return_value = mock_form_entity + def graph_factory() -> tuple[Graph, GraphRuntimeState]: - return _build_llm_human_llm_graph(mock_config) + return _build_llm_human_llm_graph(mock_config, mock_create_repo) initial_case = WorkflowTestCase( description="HumanInput pause preserves LLM streaming order", @@ -210,6 +238,8 @@ def test_human_input_llm_streaming_order_across_pause() -> None: expected_resume_sequence: list[type] = [ GraphRunStartedEvent, # resumed graph run begins NodeRunStartedEvent, # human node restarts + # Form Filled should be generated first, then the node execution ends and stream chunk is generated. + NodeRunHumanInputFormFilledEvent, NodeRunStreamChunkEvent, # cached llm_initial chunk 1 NodeRunStreamChunkEvent, # cached llm_initial chunk 2 NodeRunStreamChunkEvent, # cached llm_initial final chunk @@ -225,12 +255,27 @@ def test_human_input_llm_streaming_order_across_pause() -> None: GraphRunSucceededEvent, # graph run succeeds after resume ] + mock_get_repo = MagicMock(spec=HumanInputFormRepository) + submitted_form = MagicMock(spec=HumanInputFormEntity) + submitted_form.id = mock_form_entity.id + submitted_form.web_app_token = mock_form_entity.web_app_token + submitted_form.recipients = [] + submitted_form.rendered_content = mock_form_entity.rendered_content + submitted_form.submitted = True + submitted_form.selected_action_id = "accept" + submitted_form.submitted_data = {} + submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1) + mock_get_repo.get_form.return_value = submitted_form + def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]: - assert graph_runtime_state is not None - assert graph is not None - graph_runtime_state.variable_pool.add(("human", "input_ready"), True) - graph_runtime_state.graph_execution.pause_reason = None - return graph, graph_runtime_state + # restruct the graph runtime state + serialized_runtime_state = initial_result.graph_runtime_state.dumps() + resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state) + return _build_llm_human_llm_graph( + mock_config, + mock_get_repo, + resume_runtime_state, + ) resume_case = WorkflowTestCase( description="HumanInput resume continues LLM streaming order", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py new file mode 100644 index 0000000000..a6aab81f6c --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -0,0 +1,270 @@ +import time +from collections.abc import Mapping +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Any, Protocol + +from core.workflow.entities import GraphInitParams +from core.workflow.entities.workflow_start_reason import WorkflowStartReason +from core.workflow.graph import Graph +from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from core.workflow.graph_engine.config import GraphEngineConfig +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.graph_events import ( + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunSucceededEvent, +) +from core.workflow.nodes.base.entities import OutputVariableEntity +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction +from core.workflow.nodes.human_input.enums import HumanInputFormStatus +from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.repositories.human_input_form_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from libs.datetime_utils import naive_utc_now + + +class PauseStateStore(Protocol): + def save(self, runtime_state: GraphRuntimeState) -> None: ... + + def load(self) -> GraphRuntimeState: ... + + +class InMemoryPauseStore: + def __init__(self) -> None: + self._snapshot: str | None = None + + def save(self, runtime_state: GraphRuntimeState) -> None: + self._snapshot = runtime_state.dumps() + + def load(self) -> GraphRuntimeState: + assert self._snapshot is not None + return GraphRuntimeState.from_snapshot(self._snapshot) + + +@dataclass +class StaticForm(HumanInputFormEntity): + form_id: str + rendered: str + is_submitted: bool + action_id: str | None = None + data: Mapping[str, Any] | None = None + status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING + expiration: datetime = naive_utc_now() + timedelta(days=1) + + @property + def id(self) -> str: + return self.form_id + + @property + def web_app_token(self) -> str | None: + return "token" + + @property + def recipients(self) -> list: + return [] + + @property + def rendered_content(self) -> str: + return self.rendered + + @property + def selected_action_id(self) -> str | None: + return self.action_id + + @property + def submitted_data(self) -> Mapping[str, Any] | None: + return self.data + + @property + def submitted(self) -> bool: + return self.is_submitted + + @property + def status(self) -> HumanInputFormStatus: + return self.status_value + + @property + def expiration_time(self) -> datetime: + return self.expiration + + +class StaticRepo(HumanInputFormRepository): + def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: + self._forms_by_node_id = dict(forms_by_node_id) + + def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + return self._forms_by_node_id.get(node_id) + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: + raise AssertionError("create_form should not be called in resume scenario") + + +def _build_runtime_state() -> GraphRuntimeState: + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-1", + ), + user_inputs={}, + conversation_variables=[], + ) + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository) -> Graph: + graph_config: dict[str, object] = {"nodes": [], "edges": []} + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} + start_node = StartNode( + id=start_config["id"], + config=start_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + ) + + human_data = HumanInputNodeData( + title="Human Input", + form_content="Human input required", + inputs=[], + user_actions=[UserAction(id="approve", title="Approve")], + ) + + human_a_config = {"id": "human_a", "data": human_data.model_dump()} + human_a = HumanInputNode( + id=human_a_config["id"], + config=human_a_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + form_repository=repo, + ) + + human_b_config = {"id": "human_b", "data": human_data.model_dump()} + human_b = HumanInputNode( + id=human_b_config["id"], + config=human_b_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + form_repository=repo, + ) + + end_data = EndNodeData( + title="End", + outputs=[ + OutputVariableEntity(variable="res_a", value_selector=["human_a", "__action_id"]), + OutputVariableEntity(variable="res_b", value_selector=["human_b", "__action_id"]), + ], + desc=None, + ) + end_config = {"id": "end", "data": end_data.model_dump()} + end_node = EndNode( + id=end_config["id"], + config=end_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + ) + + builder = ( + Graph.new() + .add_root(start_node) + .add_node(human_a, from_node_id="start") + .add_node(human_b, from_node_id="start") + .add_node(end_node, from_node_id="human_a", source_handle="approve") + ) + return builder.connect(tail="human_b", head="end", source_handle="approve").build() + + +def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[object]: + engine = GraphEngine( + workflow_id="workflow", + graph=graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig( + min_workers=2, + max_workers=2, + scale_up_threshold=1, + scale_down_idle_time=30.0, + ), + ) + return list(engine.run()) + + +def _form(submitted: bool, action_id: str | None) -> StaticForm: + return StaticForm( + form_id="form", + rendered="rendered", + is_submitted=submitted, + action_id=action_id, + data={}, + status_value=HumanInputFormStatus.SUBMITTED if submitted else HumanInputFormStatus.WAITING, + ) + + +def test_parallel_human_input_join_completes_after_second_resume() -> None: + pause_store: PauseStateStore = InMemoryPauseStore() + + initial_state = _build_runtime_state() + initial_repo = StaticRepo( + { + "human_a": _form(submitted=False, action_id=None), + "human_b": _form(submitted=False, action_id=None), + } + ) + initial_graph = _build_graph(initial_state, initial_repo) + initial_events = _run_graph(initial_graph, initial_state) + + assert isinstance(initial_events[-1], GraphRunPausedEvent) + pause_store.save(initial_state) + + first_resume_state = pause_store.load() + first_resume_repo = StaticRepo( + { + "human_a": _form(submitted=True, action_id="approve"), + "human_b": _form(submitted=False, action_id=None), + } + ) + first_resume_graph = _build_graph(first_resume_state, first_resume_repo) + first_resume_events = _run_graph(first_resume_graph, first_resume_state) + + assert isinstance(first_resume_events[0], GraphRunStartedEvent) + assert first_resume_events[0].reason is WorkflowStartReason.RESUMPTION + assert isinstance(first_resume_events[-1], GraphRunPausedEvent) + pause_store.save(first_resume_state) + + second_resume_state = pause_store.load() + second_resume_repo = StaticRepo( + { + "human_a": _form(submitted=True, action_id="approve"), + "human_b": _form(submitted=True, action_id="approve"), + } + ) + second_resume_graph = _build_graph(second_resume_state, second_resume_repo) + second_resume_events = _run_graph(second_resume_graph, second_resume_state) + + assert isinstance(second_resume_events[0], GraphRunStartedEvent) + assert second_resume_events[0].reason is WorkflowStartReason.RESUMPTION + assert isinstance(second_resume_events[-1], GraphRunSucceededEvent) + assert any(isinstance(event, NodeRunSucceededEvent) and event.node_id == "end" for event in second_resume_events) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py new file mode 100644 index 0000000000..62aa56fc57 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py @@ -0,0 +1,333 @@ +import time +from collections.abc import Mapping +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Any + +from core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.entities.message_entities import PromptMessageRole +from core.workflow.entities import GraphInitParams +from core.workflow.entities.workflow_start_reason import WorkflowStartReason +from core.workflow.graph import Graph +from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from core.workflow.graph_engine.config import GraphEngineConfig +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.graph_events import ( + GraphRunPausedEvent, + GraphRunStartedEvent, + NodeRunPauseRequestedEvent, + NodeRunStartedEvent, + NodeRunSucceededEvent, +) +from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction +from core.workflow.nodes.human_input.enums import HumanInputFormStatus +from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from core.workflow.nodes.llm.entities import ( + ContextConfig, + LLMNodeChatModelMessage, + LLMNodeData, + ModelConfig, + VisionConfig, +) +from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.repositories.human_input_form_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from libs.datetime_utils import naive_utc_now + +from .test_mock_config import MockConfig, NodeMockConfig +from .test_mock_nodes import MockLLMNode + + +@dataclass +class StaticForm(HumanInputFormEntity): + form_id: str + rendered: str + is_submitted: bool + action_id: str | None = None + data: Mapping[str, Any] | None = None + status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING + expiration: datetime = naive_utc_now() + timedelta(days=1) + + @property + def id(self) -> str: + return self.form_id + + @property + def web_app_token(self) -> str | None: + return "token" + + @property + def recipients(self) -> list: + return [] + + @property + def rendered_content(self) -> str: + return self.rendered + + @property + def selected_action_id(self) -> str | None: + return self.action_id + + @property + def submitted_data(self) -> Mapping[str, Any] | None: + return self.data + + @property + def submitted(self) -> bool: + return self.is_submitted + + @property + def status(self) -> HumanInputFormStatus: + return self.status_value + + @property + def expiration_time(self) -> datetime: + return self.expiration + + +class StaticRepo(HumanInputFormRepository): + def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: + self._forms_by_node_id = dict(forms_by_node_id) + + def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + return self._forms_by_node_id.get(node_id) + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: + raise AssertionError("create_form should not be called in resume scenario") + + +class DelayedHumanInputNode(HumanInputNode): + def __init__(self, delay_seconds: float, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._delay_seconds = delay_seconds + + def _run(self): + if self._delay_seconds > 0: + time.sleep(self._delay_seconds) + yield from super()._run() + + +def _build_runtime_state() -> GraphRuntimeState: + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-1", + ), + user_inputs={}, + conversation_variables=[], + ) + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: + graph_config: dict[str, object] = {"nodes": [], "edges": []} + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} + start_node = StartNode( + id=start_config["id"], + config=start_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + ) + + human_data = HumanInputNodeData( + title="Human Input", + form_content="Human input required", + inputs=[], + user_actions=[UserAction(id="approve", title="Approve")], + ) + + human_a_config = {"id": "human_a", "data": human_data.model_dump()} + human_a = HumanInputNode( + id=human_a_config["id"], + config=human_a_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + form_repository=repo, + ) + + human_b_config = {"id": "human_b", "data": human_data.model_dump()} + human_b = DelayedHumanInputNode( + id=human_b_config["id"], + config=human_b_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + form_repository=repo, + delay_seconds=0.2, + ) + + llm_data = LLMNodeData( + title="LLM A", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), + prompt_template=[ + LLMNodeChatModelMessage( + text="Prompt A", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context=ContextConfig(enabled=False, variable_selector=None), + vision=VisionConfig(enabled=False), + reasoning_format="tagged", + structured_output_enabled=False, + ) + llm_config = {"id": "llm_a", "data": llm_data.model_dump()} + llm_a = MockLLMNode( + id=llm_config["id"], + config=llm_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + mock_config=mock_config, + ) + + return ( + Graph.new() + .add_root(start_node) + .add_node(human_a, from_node_id="start") + .add_node(human_b, from_node_id="start") + .add_node(llm_a, from_node_id="human_a", source_handle="approve") + .build() + ) + + +def test_parallel_human_input_pause_preserves_node_finished() -> None: + runtime_state = _build_runtime_state() + + runtime_state.graph_execution.start() + runtime_state.register_paused_node("human_a") + runtime_state.register_paused_node("human_b") + + submitted = StaticForm( + form_id="form-a", + rendered="rendered", + is_submitted=True, + action_id="approve", + data={}, + status_value=HumanInputFormStatus.SUBMITTED, + ) + pending = StaticForm( + form_id="form-b", + rendered="rendered", + is_submitted=False, + action_id=None, + data=None, + status_value=HumanInputFormStatus.WAITING, + ) + repo = StaticRepo({"human_a": submitted, "human_b": pending}) + + mock_config = MockConfig() + mock_config.simulate_delays = True + mock_config.set_node_config( + "llm_a", + NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), + ) + + graph = _build_graph(runtime_state, repo, mock_config) + engine = GraphEngine( + workflow_id="workflow", + graph=graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig( + min_workers=2, + max_workers=2, + scale_up_threshold=1, + scale_down_idle_time=30.0, + ), + ) + + events = list(engine.run()) + + llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events) + llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events) + human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events) + graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events) + graph_started = any(isinstance(e, GraphRunStartedEvent) for e in events) + + assert graph_started + assert graph_paused + assert human_b_pause + assert llm_started + assert llm_succeeded + + +def test_parallel_human_input_pause_preserves_node_finished_after_snapshot_resume() -> None: + base_state = _build_runtime_state() + base_state.graph_execution.start() + base_state.register_paused_node("human_a") + base_state.register_paused_node("human_b") + snapshot = base_state.dumps() + + resumed_state = GraphRuntimeState.from_snapshot(snapshot) + + submitted = StaticForm( + form_id="form-a", + rendered="rendered", + is_submitted=True, + action_id="approve", + data={}, + status_value=HumanInputFormStatus.SUBMITTED, + ) + pending = StaticForm( + form_id="form-b", + rendered="rendered", + is_submitted=False, + action_id=None, + data=None, + status_value=HumanInputFormStatus.WAITING, + ) + repo = StaticRepo({"human_a": submitted, "human_b": pending}) + + mock_config = MockConfig() + mock_config.simulate_delays = True + mock_config.set_node_config( + "llm_a", + NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), + ) + + graph = _build_graph(resumed_state, repo, mock_config) + engine = GraphEngine( + workflow_id="workflow", + graph=graph, + graph_runtime_state=resumed_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig( + min_workers=2, + max_workers=2, + scale_up_threshold=1, + scale_down_idle_time=30.0, + ), + ) + + events = list(engine.run()) + + start_event = next(e for e in events if isinstance(e, GraphRunStartedEvent)) + assert start_event.reason is WorkflowStartReason.RESUMPTION + + llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events) + llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events) + human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events) + graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events) + + assert graph_paused + assert human_b_pause + assert llm_started + assert llm_succeeded diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py new file mode 100644 index 0000000000..156cfefcd6 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py @@ -0,0 +1,309 @@ +import time +from collections.abc import Mapping +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Any + +from core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.entities.message_entities import PromptMessageRole +from core.workflow.entities import GraphInitParams +from core.workflow.entities.workflow_start_reason import WorkflowStartReason +from core.workflow.graph import Graph +from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from core.workflow.graph_engine.config import GraphEngineConfig +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.graph_events import ( + GraphRunPausedEvent, + GraphRunStartedEvent, + NodeRunStartedEvent, + NodeRunSucceededEvent, +) +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction +from core.workflow.nodes.human_input.enums import HumanInputFormStatus +from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from core.workflow.nodes.llm.entities import ( + ContextConfig, + LLMNodeChatModelMessage, + LLMNodeData, + ModelConfig, + VisionConfig, +) +from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.repositories.human_input_form_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from libs.datetime_utils import naive_utc_now + +from .test_mock_config import MockConfig, NodeMockConfig +from .test_mock_nodes import MockLLMNode + + +@dataclass +class StaticForm(HumanInputFormEntity): + form_id: str + rendered: str + is_submitted: bool + action_id: str | None = None + data: Mapping[str, Any] | None = None + status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING + expiration: datetime = naive_utc_now() + timedelta(days=1) + + @property + def id(self) -> str: + return self.form_id + + @property + def web_app_token(self) -> str | None: + return "token" + + @property + def recipients(self) -> list: + return [] + + @property + def rendered_content(self) -> str: + return self.rendered + + @property + def selected_action_id(self) -> str | None: + return self.action_id + + @property + def submitted_data(self) -> Mapping[str, Any] | None: + return self.data + + @property + def submitted(self) -> bool: + return self.is_submitted + + @property + def status(self) -> HumanInputFormStatus: + return self.status_value + + @property + def expiration_time(self) -> datetime: + return self.expiration + + +class StaticRepo(HumanInputFormRepository): + def __init__(self, form: HumanInputFormEntity) -> None: + self._form = form + + def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + if node_id != "human_pause": + return None + return self._form + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: + raise AssertionError("create_form should not be called in this test") + + +def _build_runtime_state() -> GraphRuntimeState: + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-1", + ), + user_inputs={}, + conversation_variables=[], + ) + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: + graph_config: dict[str, object] = {"nodes": [], "edges": []} + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} + start_node = StartNode( + id=start_config["id"], + config=start_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + ) + + llm_a_data = LLMNodeData( + title="LLM A", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), + prompt_template=[ + LLMNodeChatModelMessage( + text="Prompt A", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context=ContextConfig(enabled=False, variable_selector=None), + vision=VisionConfig(enabled=False), + reasoning_format="tagged", + structured_output_enabled=False, + ) + llm_a_config = {"id": "llm_a", "data": llm_a_data.model_dump()} + llm_a = MockLLMNode( + id=llm_a_config["id"], + config=llm_a_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + mock_config=mock_config, + ) + + llm_b_data = LLMNodeData( + title="LLM B", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), + prompt_template=[ + LLMNodeChatModelMessage( + text="Prompt B", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context=ContextConfig(enabled=False, variable_selector=None), + vision=VisionConfig(enabled=False), + reasoning_format="tagged", + structured_output_enabled=False, + ) + llm_b_config = {"id": "llm_b", "data": llm_b_data.model_dump()} + llm_b = MockLLMNode( + id=llm_b_config["id"], + config=llm_b_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + mock_config=mock_config, + ) + + human_data = HumanInputNodeData( + title="Human Input", + form_content="Pause here", + inputs=[], + user_actions=[UserAction(id="approve", title="Approve")], + ) + human_config = {"id": "human_pause", "data": human_data.model_dump()} + human_node = HumanInputNode( + id=human_config["id"], + config=human_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + form_repository=repo, + ) + + end_human_data = EndNodeData(title="End Human", outputs=[], desc=None) + end_human_config = {"id": "end_human", "data": end_human_data.model_dump()} + end_human = EndNode( + id=end_human_config["id"], + config=end_human_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + ) + + return ( + Graph.new() + .add_root(start_node) + .add_node(llm_a, from_node_id="start") + .add_node(human_node, from_node_id="start") + .add_node(llm_b, from_node_id="llm_a") + .add_node(end_human, from_node_id="human_pause", source_handle="approve") + .build() + ) + + +def _get_node_started_event(events: list[object], node_id: str) -> NodeRunStartedEvent | None: + for event in events: + if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id: + return event + return None + + +def test_pause_defers_ready_nodes_until_resume() -> None: + runtime_state = _build_runtime_state() + + paused_form = StaticForm( + form_id="form-pause", + rendered="rendered", + is_submitted=False, + status_value=HumanInputFormStatus.WAITING, + ) + pause_repo = StaticRepo(paused_form) + + mock_config = MockConfig() + mock_config.simulate_delays = True + mock_config.set_node_config( + "llm_a", + NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), + ) + mock_config.set_node_config( + "llm_b", + NodeMockConfig(node_id="llm_b", outputs={"text": "LLM B output"}, delay=0.0), + ) + + graph = _build_graph(runtime_state, pause_repo, mock_config) + engine = GraphEngine( + workflow_id="workflow", + graph=graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig( + min_workers=2, + max_workers=2, + scale_up_threshold=1, + scale_down_idle_time=30.0, + ), + ) + + paused_events = list(engine.run()) + + assert any(isinstance(e, GraphRunPausedEvent) for e in paused_events) + assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in paused_events) + assert _get_node_started_event(paused_events, "llm_b") is None + + snapshot = runtime_state.dumps() + resumed_state = GraphRuntimeState.from_snapshot(snapshot) + + submitted_form = StaticForm( + form_id="form-pause", + rendered="rendered", + is_submitted=True, + action_id="approve", + data={}, + status_value=HumanInputFormStatus.SUBMITTED, + ) + resume_repo = StaticRepo(submitted_form) + + resumed_graph = _build_graph(resumed_state, resume_repo, mock_config) + resumed_engine = GraphEngine( + workflow_id="workflow", + graph=resumed_graph, + graph_runtime_state=resumed_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig( + min_workers=2, + max_workers=2, + scale_up_threshold=1, + scale_down_idle_time=30.0, + ), + ) + + resumed_events = list(resumed_engine.run()) + + start_event = next(e for e in resumed_events if isinstance(e, GraphRunStartedEvent)) + assert start_event.reason is WorkflowStartReason.RESUMPTION + + llm_b_started = _get_node_started_event(resumed_events, "llm_b") + assert llm_b_started is not None + assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_b" for e in resumed_events) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py new file mode 100644 index 0000000000..700b3f4b8b --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py @@ -0,0 +1,217 @@ +import datetime +import time +from typing import Any +from unittest.mock import MagicMock + +from core.workflow.entities import GraphInitParams +from core.workflow.entities.workflow_start_reason import WorkflowStartReason +from core.workflow.graph import Graph +from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.graph_events import ( + GraphEngineEvent, + GraphRunPausedEvent, + GraphRunSucceededEvent, + NodeRunStartedEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_events.graph import GraphRunStartedEvent +from core.workflow.nodes.base.entities import OutputVariableEntity +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction +from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.repositories.human_input_form_repository import ( + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from libs.datetime_utils import naive_utc_now + + +def _build_runtime_state() -> GraphRuntimeState: + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="test-execution-id", + ), + user_inputs={}, + conversation_variables=[], + ) + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository: + repo = MagicMock(spec=HumanInputFormRepository) + form_entity = MagicMock(spec=HumanInputFormEntity) + form_entity.id = "test-form-id" + form_entity.web_app_token = "test-form-token" + form_entity.recipients = [] + form_entity.rendered_content = "rendered" + form_entity.submitted = True + form_entity.selected_action_id = action_id + form_entity.submitted_data = {} + form_entity.expiration_time = naive_utc_now() + datetime.timedelta(days=1) + repo.get_form.return_value = form_entity + return repo + + +def _mock_form_repository_without_submission() -> HumanInputFormRepository: + repo = MagicMock(spec=HumanInputFormRepository) + form_entity = MagicMock(spec=HumanInputFormEntity) + form_entity.id = "test-form-id" + form_entity.web_app_token = "test-form-token" + form_entity.recipients = [] + form_entity.rendered_content = "rendered" + form_entity.submitted = False + repo.create_form.return_value = form_entity + repo.get_form.return_value = None + return repo + + +def _build_human_input_graph( + runtime_state: GraphRuntimeState, + form_repository: HumanInputFormRepository, +) -> Graph: + graph_config: dict[str, object] = {"nodes": [], "edges": []} + params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + user_from="account", + invoke_from="service-api", + call_depth=0, + ) + + start_data = StartNodeData(title="start", variables=[]) + start_node = StartNode( + id="start", + config={"id": "start", "data": start_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + ) + + human_data = HumanInputNodeData( + title="human", + form_content="Awaiting human input", + inputs=[], + user_actions=[ + UserAction(id="continue", title="Continue"), + ], + ) + human_node = HumanInputNode( + id="human", + config={"id": "human", "data": human_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + form_repository=form_repository, + ) + + end_data = EndNodeData( + title="end", + outputs=[ + OutputVariableEntity(variable="result", value_selector=["human", "action_id"]), + ], + desc=None, + ) + end_node = EndNode( + id="end", + config={"id": "end", "data": end_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + ) + + return ( + Graph.new() + .add_root(start_node) + .add_node(human_node) + .add_node(end_node, from_node_id="human", source_handle="continue") + .build() + ) + + +def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[GraphEngineEvent]: + engine = GraphEngine( + workflow_id="workflow", + graph=graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + return list(engine.run()) + + +def _node_successes(events: list[GraphEngineEvent]) -> list[str]: + return [event.node_id for event in events if isinstance(event, NodeRunSucceededEvent)] + + +def _node_start_event(events: list[GraphEngineEvent], node_id: str) -> NodeRunStartedEvent | None: + for event in events: + if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id: + return event + return None + + +def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> Any: + segment = variable_pool.get(selector) + assert segment is not None + return getattr(segment, "value", segment) + + +def test_engine_resume_restores_state_and_completion(): + # Baseline run without pausing + baseline_state = _build_runtime_state() + baseline_repo = _mock_form_repository_with_submission(action_id="continue") + baseline_graph = _build_human_input_graph(baseline_state, baseline_repo) + baseline_events = _run_graph(baseline_graph, baseline_state) + assert baseline_events + first_paused_event = baseline_events[0] + assert isinstance(first_paused_event, GraphRunStartedEvent) + assert first_paused_event.reason is WorkflowStartReason.INITIAL + assert isinstance(baseline_events[-1], GraphRunSucceededEvent) + baseline_success_nodes = _node_successes(baseline_events) + + # Run with pause + paused_state = _build_runtime_state() + pause_repo = _mock_form_repository_without_submission() + paused_graph = _build_human_input_graph(paused_state, pause_repo) + paused_events = _run_graph(paused_graph, paused_state) + assert paused_events + first_paused_event = paused_events[0] + assert isinstance(first_paused_event, GraphRunStartedEvent) + assert first_paused_event.reason is WorkflowStartReason.INITIAL + assert isinstance(paused_events[-1], GraphRunPausedEvent) + snapshot = paused_state.dumps() + + # Resume from snapshot + resumed_state = GraphRuntimeState.from_snapshot(snapshot) + resume_repo = _mock_form_repository_with_submission(action_id="continue") + resumed_graph = _build_human_input_graph(resumed_state, resume_repo) + resumed_events = _run_graph(resumed_graph, resumed_state) + assert resumed_events + first_resumed_event = resumed_events[0] + assert isinstance(first_resumed_event, GraphRunStartedEvent) + assert first_resumed_event.reason is WorkflowStartReason.RESUMPTION + assert isinstance(resumed_events[-1], GraphRunSucceededEvent) + + combined_success_nodes = _node_successes(paused_events) + _node_successes(resumed_events) + assert combined_success_nodes == baseline_success_nodes + + paused_human_started = _node_start_event(paused_events, "human") + resumed_human_started = _node_start_event(resumed_events, "human") + assert paused_human_started is not None + assert resumed_human_started is not None + assert paused_human_started.id == resumed_human_started.id + + assert baseline_state.outputs == resumed_state.outputs + assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value( + resumed_state.variable_pool, ("human", "__action_id") + ) + assert baseline_state.graph_execution.completed + assert resumed_state.graph_execution.completed diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index 488b47761b..21a642c2f8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -7,6 +7,7 @@ from core.workflow.nodes.base.node import Node # Ensures that all node classes are imported. from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +# Ensure `NODE_TYPE_CLASSES_MAPPING` is used and not automatically removed. _ = NODE_TYPE_CLASSES_MAPPING @@ -45,7 +46,9 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined assert isinstance(cls.node_type, NodeType) assert isinstance(node_version, str) node_type_and_version = (node_type, node_version) - assert node_type_and_version not in type_version_set + assert node_type_and_version not in type_version_set, ( + f"Duplicate node type and version for class: {cls=} {node_type_and_version=}" + ) type_version_set.add(node_type_and_version) diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/__init__.py b/api/tests/unit_tests/core/workflow/nodes/human_input/__init__.py new file mode 100644 index 0000000000..20807e9ef9 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/__init__.py @@ -0,0 +1 @@ +# Unit tests for human input node diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py new file mode 100644 index 0000000000..ca4a887d20 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py @@ -0,0 +1,16 @@ +from core.workflow.nodes.human_input.entities import EmailDeliveryConfig, EmailRecipients +from core.workflow.runtime import VariablePool + + +def test_render_body_template_replaces_variable_values(): + config = EmailDeliveryConfig( + recipients=EmailRecipients(), + subject="Subject", + body="Hello {{#node1.value#}} {{#url#}}", + ) + variable_pool = VariablePool() + variable_pool.add(["node1", "value"], "World") + + result = config.render_body_template(body=config.body, url="https://example.com", variable_pool=variable_pool) + + assert result == "Hello World https://example.com" diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py new file mode 100644 index 0000000000..bfe7b03c13 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -0,0 +1,597 @@ +""" +Unit tests for human input node entities. +""" + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from pydantic import ValidationError + +from core.workflow.entities import GraphInitParams +from core.workflow.node_events import PauseRequestedEvent +from core.workflow.node_events.node import StreamCompletedEvent +from core.workflow.nodes.human_input.entities import ( + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + ExternalRecipient, + FormInput, + FormInputDefault, + HumanInputNodeData, + MemberRecipient, + UserAction, + WebAppDeliveryMethod, + _WebAppDeliveryConfig, +) +from core.workflow.nodes.human_input.enums import ( + ButtonStyle, + DeliveryMethodType, + EmailRecipientType, + FormInputType, + PlaceholderType, + TimeoutUnit, +) +from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from tests.unit_tests.core.workflow.graph_engine.human_input_test_utils import InMemoryHumanInputFormRepository + + +class TestDeliveryMethod: + """Test DeliveryMethod entity.""" + + def test_webapp_delivery_method(self): + """Test webapp delivery method creation.""" + delivery_method = WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig()) + + assert delivery_method.type == DeliveryMethodType.WEBAPP + assert delivery_method.enabled is True + assert isinstance(delivery_method.config, _WebAppDeliveryConfig) + + def test_email_delivery_method(self): + """Test email delivery method creation.""" + recipients = EmailRecipients( + whole_workspace=False, + items=[ + MemberRecipient(type=EmailRecipientType.MEMBER, user_id="test-user-123"), + ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com"), + ], + ) + + config = EmailDeliveryConfig( + recipients=recipients, subject="Test Subject", body="Test body with {{#url#}} placeholder" + ) + + delivery_method = EmailDeliveryMethod(enabled=True, config=config) + + assert delivery_method.type == DeliveryMethodType.EMAIL + assert delivery_method.enabled is True + assert isinstance(delivery_method.config, EmailDeliveryConfig) + assert delivery_method.config.subject == "Test Subject" + assert len(delivery_method.config.recipients.items) == 2 + + +class TestFormInput: + """Test FormInput entity.""" + + def test_text_input_with_constant_default(self): + """Test text input with constant default value.""" + default = FormInputDefault(type=PlaceholderType.CONSTANT, value="Enter your response here...") + + form_input = FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="user_input", default=default) + + assert form_input.type == FormInputType.TEXT_INPUT + assert form_input.output_variable_name == "user_input" + assert form_input.default.type == PlaceholderType.CONSTANT + assert form_input.default.value == "Enter your response here..." + + def test_text_input_with_variable_default(self): + """Test text input with variable default value.""" + default = FormInputDefault(type=PlaceholderType.VARIABLE, selector=["node_123", "output_var"]) + + form_input = FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="user_input", default=default) + + assert form_input.default.type == PlaceholderType.VARIABLE + assert form_input.default.selector == ["node_123", "output_var"] + + def test_form_input_without_default(self): + """Test form input without default value.""" + form_input = FormInput(type=FormInputType.PARAGRAPH, output_variable_name="description") + + assert form_input.type == FormInputType.PARAGRAPH + assert form_input.output_variable_name == "description" + assert form_input.default is None + + +class TestUserAction: + """Test UserAction entity.""" + + def test_user_action_creation(self): + """Test user action creation.""" + action = UserAction(id="approve", title="Approve", button_style=ButtonStyle.PRIMARY) + + assert action.id == "approve" + assert action.title == "Approve" + assert action.button_style == ButtonStyle.PRIMARY + + def test_user_action_default_button_style(self): + """Test user action with default button style.""" + action = UserAction(id="cancel", title="Cancel") + + assert action.button_style == ButtonStyle.DEFAULT + + def test_user_action_length_boundaries(self): + """Test user action id and title length boundaries.""" + action = UserAction(id="a" * 20, title="b" * 20) + + assert action.id == "a" * 20 + assert action.title == "b" * 20 + + @pytest.mark.parametrize( + ("field_name", "value"), + [ + ("id", "a" * 21), + ("title", "b" * 21), + ], + ) + def test_user_action_length_limits(self, field_name: str, value: str): + """User action fields should enforce max length.""" + data = {"id": "approve", "title": "Approve"} + data[field_name] = value + + with pytest.raises(ValidationError) as exc_info: + UserAction(**data) + + errors = exc_info.value.errors() + assert any(error["loc"] == (field_name,) and error["type"] == "string_too_long" for error in errors) + + +class TestHumanInputNodeData: + """Test HumanInputNodeData entity.""" + + def test_valid_node_data_creation(self): + """Test creating valid human input node data.""" + delivery_methods = [WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig())] + + inputs = [ + FormInput( + type=FormInputType.TEXT_INPUT, + output_variable_name="content", + default=FormInputDefault(type=PlaceholderType.CONSTANT, value="Enter content..."), + ) + ] + + user_actions = [UserAction(id="submit", title="Submit", button_style=ButtonStyle.PRIMARY)] + + node_data = HumanInputNodeData( + title="Human Input Test", + desc="Test node description", + delivery_methods=delivery_methods, + form_content="# Test Form\n\nPlease provide input:\n\n{{#$output.content#}}", + inputs=inputs, + user_actions=user_actions, + timeout=24, + timeout_unit=TimeoutUnit.HOUR, + ) + + assert node_data.title == "Human Input Test" + assert node_data.desc == "Test node description" + assert len(node_data.delivery_methods) == 1 + assert node_data.form_content.startswith("# Test Form") + assert len(node_data.inputs) == 1 + assert len(node_data.user_actions) == 1 + assert node_data.timeout == 24 + assert node_data.timeout_unit == TimeoutUnit.HOUR + + def test_node_data_with_multiple_delivery_methods(self): + """Test node data with multiple delivery methods.""" + delivery_methods = [ + WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig()), + EmailDeliveryMethod( + enabled=False, # Disabled method should be fine + config=EmailDeliveryConfig( + subject="Hi there", body="", recipients=EmailRecipients(whole_workspace=True) + ), + ), + ] + + node_data = HumanInputNodeData( + title="Test Node", delivery_methods=delivery_methods, timeout=1, timeout_unit=TimeoutUnit.DAY + ) + + assert len(node_data.delivery_methods) == 2 + assert node_data.timeout == 1 + assert node_data.timeout_unit == TimeoutUnit.DAY + + def test_node_data_defaults(self): + """Test node data with default values.""" + node_data = HumanInputNodeData(title="Test Node") + + assert node_data.title == "Test Node" + assert node_data.desc is None + assert node_data.delivery_methods == [] + assert node_data.form_content == "" + assert node_data.inputs == [] + assert node_data.user_actions == [] + assert node_data.timeout == 36 + assert node_data.timeout_unit == TimeoutUnit.HOUR + + def test_duplicate_input_output_variable_name_raises_validation_error(self): + """Duplicate form input output_variable_name should raise validation error.""" + duplicate_inputs = [ + FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content"), + FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content"), + ] + + with pytest.raises(ValidationError, match="duplicated output_variable_name 'content'"): + HumanInputNodeData(title="Test Node", inputs=duplicate_inputs) + + def test_duplicate_user_action_ids_raise_validation_error(self): + """Duplicate user action ids should raise validation error.""" + duplicate_actions = [ + UserAction(id="submit", title="Submit"), + UserAction(id="submit", title="Submit Again"), + ] + + with pytest.raises(ValidationError, match="duplicated user action id 'submit'"): + HumanInputNodeData(title="Test Node", user_actions=duplicate_actions) + + def test_extract_outputs_field_names(self): + content = r"""This is titile {{#start.title#}} + + A content is required: + + {{#$output.content#}} + + A ending is required: + + {{#$output.ending#}} + """ + + node_data = HumanInputNodeData(title="Human Input", form_content=content) + field_names = node_data.outputs_field_names() + assert field_names == ["content", "ending"] + + +class TestRecipients: + """Test email recipient entities.""" + + def test_member_recipient(self): + """Test member recipient creation.""" + recipient = MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123") + + assert recipient.type == EmailRecipientType.MEMBER + assert recipient.user_id == "user-123" + + def test_external_recipient(self): + """Test external recipient creation.""" + recipient = ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com") + + assert recipient.type == EmailRecipientType.EXTERNAL + assert recipient.email == "test@example.com" + + def test_email_recipients_whole_workspace(self): + """Test email recipients with whole workspace enabled.""" + recipients = EmailRecipients( + whole_workspace=True, items=[MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")] + ) + + assert recipients.whole_workspace is True + assert len(recipients.items) == 1 # Items are preserved even when whole_workspace is True + + def test_email_recipients_specific_users(self): + """Test email recipients with specific users.""" + recipients = EmailRecipients( + whole_workspace=False, + items=[ + MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123"), + ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="external@example.com"), + ], + ) + + assert recipients.whole_workspace is False + assert len(recipients.items) == 2 + assert recipients.items[0].user_id == "user-123" + assert recipients.items[1].email == "external@example.com" + + +class TestHumanInputNodeVariableResolution: + """Tests for resolving variable-based defaults in HumanInputNode.""" + + def test_resolves_variable_defaults(self): + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-1", + ), + user_inputs={}, + conversation_variables=[], + ) + variable_pool.add(("start", "name"), "Jane Doe") + runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config={"nodes": [], "edges": []}, + user_id="user", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + node_data = HumanInputNodeData( + title="Human Input", + form_content="Provide your name", + inputs=[ + FormInput( + type=FormInputType.TEXT_INPUT, + output_variable_name="user_name", + default=FormInputDefault(type=PlaceholderType.VARIABLE, selector=["start", "name"]), + ), + FormInput( + type=FormInputType.TEXT_INPUT, + output_variable_name="user_email", + default=FormInputDefault(type=PlaceholderType.CONSTANT, value="foo@example.com"), + ), + ], + user_actions=[UserAction(id="submit", title="Submit")], + ) + config = {"id": "human", "data": node_data.model_dump()} + + mock_repo = MagicMock(spec=HumanInputFormRepository) + mock_repo.get_form.return_value = None + mock_repo.create_form.return_value = SimpleNamespace( + id="form-1", + rendered_content="Provide your name", + web_app_token="token", + recipients=[], + submitted=False, + ) + + node = HumanInputNode( + id=config["id"], + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + form_repository=mock_repo, + ) + + run_result = node._run() + pause_event = next(run_result) + + assert isinstance(pause_event, PauseRequestedEvent) + expected_values = {"user_name": "Jane Doe"} + assert pause_event.reason.resolved_default_values == expected_values + + params = mock_repo.create_form.call_args.args[0] + assert params.resolved_default_values == expected_values + + def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self): + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-2", + ), + user_inputs={}, + conversation_variables=[], + ) + runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config={"nodes": [], "edges": []}, + user_id="user", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + node_data = HumanInputNodeData( + title="Human Input", + form_content="Provide your name", + inputs=[], + user_actions=[UserAction(id="submit", title="Submit")], + ) + config = {"id": "human", "data": node_data.model_dump()} + + mock_repo = MagicMock(spec=HumanInputFormRepository) + mock_repo.get_form.return_value = None + mock_repo.create_form.return_value = SimpleNamespace( + id="form-2", + rendered_content="Provide your name", + web_app_token="console-token", + recipients=[SimpleNamespace(token="recipient-token")], + submitted=False, + ) + + node = HumanInputNode( + id=config["id"], + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + form_repository=mock_repo, + ) + + run_result = node._run() + pause_event = next(run_result) + + assert isinstance(pause_event, PauseRequestedEvent) + assert pause_event.reason.form_token == "console-token" + + def test_debugger_debug_mode_overrides_email_recipients(self): + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user-123", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-3", + ), + user_inputs={}, + conversation_variables=[], + ) + runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config={"nodes": [], "edges": []}, + user_id="user-123", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + node_data = HumanInputNodeData( + title="Human Input", + form_content="Provide your name", + inputs=[], + user_actions=[UserAction(id="submit", title="Submit")], + delivery_methods=[ + EmailDeliveryMethod( + enabled=True, + config=EmailDeliveryConfig( + recipients=EmailRecipients( + whole_workspace=False, + items=[ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="target@example.com")], + ), + subject="Subject", + body="Body", + debug_mode=True, + ), + ) + ], + ) + config = {"id": "human", "data": node_data.model_dump()} + + mock_repo = MagicMock(spec=HumanInputFormRepository) + mock_repo.get_form.return_value = None + mock_repo.create_form.return_value = SimpleNamespace( + id="form-3", + rendered_content="Provide your name", + web_app_token="token", + recipients=[], + submitted=False, + ) + + node = HumanInputNode( + id=config["id"], + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + form_repository=mock_repo, + ) + + run_result = node._run() + pause_event = next(run_result) + assert isinstance(pause_event, PauseRequestedEvent) + + params = mock_repo.create_form.call_args.args[0] + assert len(params.delivery_methods) == 1 + method = params.delivery_methods[0] + assert isinstance(method, EmailDeliveryMethod) + assert method.config.debug_mode is True + assert method.config.recipients.whole_workspace is False + assert len(method.config.recipients.items) == 1 + recipient = method.config.recipients.items[0] + assert isinstance(recipient, MemberRecipient) + assert recipient.user_id == "user-123" + + +class TestValidation: + """Test validation scenarios.""" + + def test_invalid_form_input_type(self): + """Test validation with invalid form input type.""" + with pytest.raises(ValidationError): + FormInput( + type="invalid-type", # Invalid type + output_variable_name="test", + ) + + def test_invalid_button_style(self): + """Test validation with invalid button style.""" + with pytest.raises(ValidationError): + UserAction( + id="test", + title="Test", + button_style="invalid-style", # Invalid style + ) + + def test_invalid_timeout_unit(self): + """Test validation with invalid timeout unit.""" + with pytest.raises(ValidationError): + HumanInputNodeData( + title="Test", + timeout_unit="invalid-unit", # Invalid unit + ) + + +class TestHumanInputNodeRenderedContent: + """Tests for rendering submitted content.""" + + def test_replaces_outputs_placeholders_after_submission(self): + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-1", + ), + user_inputs={}, + conversation_variables=[], + ) + runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config={"nodes": [], "edges": []}, + user_id="user", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + node_data = HumanInputNodeData( + title="Human Input", + form_content="Name: {{#$output.name#}}", + inputs=[ + FormInput( + type=FormInputType.TEXT_INPUT, + output_variable_name="name", + ) + ], + user_actions=[UserAction(id="approve", title="Approve")], + ) + config = {"id": "human", "data": node_data.model_dump()} + + form_repository = InMemoryHumanInputFormRepository() + node = HumanInputNode( + id=config["id"], + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + form_repository=form_repository, + ) + + pause_gen = node._run() + pause_event = next(pause_gen) + assert isinstance(pause_event, PauseRequestedEvent) + with pytest.raises(StopIteration): + next(pause_gen) + + form_repository.set_submission(action_id="approve", form_data={"name": "Alice"}) + + events = list(node._run()) + last_event = events[-1] + assert isinstance(last_event, StreamCompletedEvent) + node_run_result = last_event.node_run_result + assert node_run_result.outputs["__rendered_content"] == "Name: Alice" diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py new file mode 100644 index 0000000000..a19ee4dee3 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -0,0 +1,172 @@ +import datetime +from types import SimpleNamespace + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.graph_init_params import GraphInitParams +from core.workflow.enums import NodeType +from core.workflow.graph_events import ( + NodeRunHumanInputFormFilledEvent, + NodeRunHumanInputFormTimeoutEvent, + NodeRunStartedEvent, +) +from core.workflow.nodes.human_input.enums import HumanInputFormStatus +from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from libs.datetime_utils import naive_utc_now +from models.enums import UserFrom + + +class _FakeFormRepository: + def __init__(self, form): + self._form = form + + def get_form(self, *_args, **_kwargs): + return self._form + + +def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode: + system_variables = SystemVariable.default() + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), + start_at=0.0, + ) + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config={"nodes": [], "edges": []}, + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ) + + config = { + "id": "node-1", + "type": NodeType.HUMAN_INPUT.value, + "data": { + "title": "Human Input", + "form_content": form_content, + "inputs": [ + { + "type": "text_input", + "output_variable_name": "name", + "default": {"type": "constant", "value": ""}, + } + ], + "user_actions": [ + { + "id": "Accept", + "title": "Approve", + "button_style": "default", + } + ], + }, + } + + fake_form = SimpleNamespace( + id="form-1", + rendered_content=form_content, + submitted=True, + selected_action_id="Accept", + submitted_data={"name": "Alice"}, + status=HumanInputFormStatus.SUBMITTED, + expiration_time=naive_utc_now() + datetime.timedelta(days=1), + ) + + repo = _FakeFormRepository(fake_form) + return HumanInputNode( + id="node-1", + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + form_repository=repo, + ) + + +def _build_timeout_node() -> HumanInputNode: + system_variables = SystemVariable.default() + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), + start_at=0.0, + ) + graph_init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config={"nodes": [], "edges": []}, + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ) + + config = { + "id": "node-1", + "type": NodeType.HUMAN_INPUT.value, + "data": { + "title": "Human Input", + "form_content": "Please enter your name:\n\n{{#$output.name#}}", + "inputs": [ + { + "type": "text_input", + "output_variable_name": "name", + "default": {"type": "constant", "value": ""}, + } + ], + "user_actions": [ + { + "id": "Accept", + "title": "Approve", + "button_style": "default", + } + ], + }, + } + + fake_form = SimpleNamespace( + id="form-1", + rendered_content="content", + submitted=False, + selected_action_id=None, + submitted_data=None, + status=HumanInputFormStatus.TIMEOUT, + expiration_time=naive_utc_now() - datetime.timedelta(minutes=1), + ) + + repo = _FakeFormRepository(fake_form) + return HumanInputNode( + id="node-1", + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + form_repository=repo, + ) + + +def test_human_input_node_emits_form_filled_event_before_succeeded(): + node = _build_node() + + events = list(node.run()) + + assert isinstance(events[0], NodeRunStartedEvent) + assert isinstance(events[1], NodeRunHumanInputFormFilledEvent) + + filled_event = events[1] + assert filled_event.node_title == "Human Input" + assert filled_event.rendered_content.endswith("Alice") + assert filled_event.action_id == "Accept" + assert filled_event.action_text == "Approve" + + +def test_human_input_node_emits_timeout_event_before_succeeded(): + node = _build_timeout_node() + + events = list(node.run()) + + assert isinstance(events[0], NodeRunStartedEvent) + assert isinstance(events[1], NodeRunHumanInputFormTimeoutEvent) + + timeout_event = events[1] + assert timeout_event.node_title == "Human Input" diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool_conver.py b/api/tests/unit_tests/core/workflow/test_variable_pool_conver.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/extensions/test_celery_ssl.py b/api/tests/unit_tests/extensions/test_celery_ssl.py index 34d48fa94e..2ec7d6b4fc 100644 --- a/api/tests/unit_tests/extensions/test_celery_ssl.py +++ b/api/tests/unit_tests/extensions/test_celery_ssl.py @@ -104,6 +104,7 @@ class TestCelerySSLConfiguration: def test_celery_init_applies_ssl_to_broker_and_backend(self): """Test that SSL options are applied to both broker and backend when using Redis.""" mock_config = MagicMock() + mock_config.HUMAN_INPUT_TIMEOUT_TASK_INTERVAL = 1 mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" mock_config.CELERY_BACKEND = "redis" mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0" diff --git a/api/tests/unit_tests/extensions/test_pubsub_channel.py b/api/tests/unit_tests/extensions/test_pubsub_channel.py new file mode 100644 index 0000000000..a5b41a7266 --- /dev/null +++ b/api/tests/unit_tests/extensions/test_pubsub_channel.py @@ -0,0 +1,20 @@ +from configs import dify_config +from extensions import ext_redis +from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel +from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel + + +def test_get_pubsub_broadcast_channel_defaults_to_pubsub(monkeypatch): + monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub") + + channel = ext_redis.get_pubsub_broadcast_channel() + + assert isinstance(channel, RedisBroadcastChannel) + + +def test_get_pubsub_broadcast_channel_sharded(monkeypatch): + monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "sharded") + + channel = ext_redis.get_pubsub_broadcast_channel() + + assert isinstance(channel, ShardedRedisBroadcastChannel) diff --git a/api/tests/unit_tests/libs/_human_input/__init__.py b/api/tests/unit_tests/libs/_human_input/__init__.py new file mode 100644 index 0000000000..66714e72f8 --- /dev/null +++ b/api/tests/unit_tests/libs/_human_input/__init__.py @@ -0,0 +1 @@ +# Treat this directory as a package so support modules can be imported relatively. diff --git a/api/tests/unit_tests/libs/_human_input/support.py b/api/tests/unit_tests/libs/_human_input/support.py new file mode 100644 index 0000000000..bd86c13a2c --- /dev/null +++ b/api/tests/unit_tests/libs/_human_input/support.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Any + +from core.workflow.nodes.human_input.entities import FormInput +from core.workflow.nodes.human_input.enums import TimeoutUnit + + +# Exceptions +class HumanInputError(Exception): + error_code: str = "unknown" + + def __init__(self, message: str = "", error_code: str | None = None): + super().__init__(message) + self.message = message or self.__class__.__name__ + if error_code: + self.error_code = error_code + + +class FormNotFoundError(HumanInputError): + error_code = "form_not_found" + + +class FormExpiredError(HumanInputError): + error_code = "human_input_form_expired" + + +class FormAlreadySubmittedError(HumanInputError): + error_code = "human_input_form_submitted" + + +class InvalidFormDataError(HumanInputError): + error_code = "invalid_form_data" + + +# Models +@dataclass +class HumanInputForm: + form_id: str + workflow_run_id: str + node_id: str + tenant_id: str + app_id: str | None + form_content: str + inputs: list[FormInput] + user_actions: list[dict[str, Any]] + timeout: int + timeout_unit: TimeoutUnit + form_token: str | None = None + created_at: datetime = field(default_factory=datetime.utcnow) + expires_at: datetime | None = None + submitted_at: datetime | None = None + submitted_data: dict[str, Any] | None = None + submitted_action: str | None = None + + def __post_init__(self) -> None: + if self.expires_at is None: + self.calculate_expiration() + + @property + def is_expired(self) -> bool: + return self.expires_at is not None and datetime.utcnow() > self.expires_at + + @property + def is_submitted(self) -> bool: + return self.submitted_at is not None + + def mark_submitted(self, inputs: dict[str, Any], action: str) -> None: + self.submitted_data = inputs + self.submitted_action = action + self.submitted_at = datetime.utcnow() + + def submit(self, inputs: dict[str, Any], action: str) -> None: + self.mark_submitted(inputs, action) + + def calculate_expiration(self) -> None: + start = self.created_at + if self.timeout_unit == TimeoutUnit.HOUR: + self.expires_at = start + timedelta(hours=self.timeout) + elif self.timeout_unit == TimeoutUnit.DAY: + self.expires_at = start + timedelta(days=self.timeout) + else: + raise ValueError(f"Unsupported timeout unit {self.timeout_unit}") + + def to_response_dict(self, *, include_site_info: bool) -> dict[str, Any]: + inputs_response = [ + { + "type": form_input.type.name.lower().replace("_", "-"), + "output_variable_name": form_input.output_variable_name, + } + for form_input in self.inputs + ] + response = { + "form_content": self.form_content, + "inputs": inputs_response, + "user_actions": self.user_actions, + } + if include_site_info: + response["site"] = {"app_id": self.app_id, "title": "Workflow Form"} + return response + + +@dataclass +class FormSubmissionData: + form_id: str + inputs: dict[str, Any] + action: str + submitted_at: datetime = field(default_factory=datetime.utcnow) + + @classmethod + def from_request(cls, form_id: str, request: FormSubmissionRequest) -> FormSubmissionData: # type: ignore + return cls(form_id=form_id, inputs=request.inputs, action=request.action) + + +@dataclass +class FormSubmissionRequest: + inputs: dict[str, Any] + action: str + + +# Repository +class InMemoryFormRepository: + """ + Simple in-memory repository used by unit tests. + """ + + def __init__(self): + self._forms: dict[str, HumanInputForm] = {} + + @property + def forms(self) -> dict[str, HumanInputForm]: + return self._forms + + def save(self, form: HumanInputForm) -> None: + self._forms[form.form_id] = form + + def get_by_id(self, form_id: str) -> HumanInputForm | None: + return self._forms.get(form_id) + + def get_by_token(self, token: str) -> HumanInputForm | None: + for form in self._forms.values(): + if form.form_token == token: + return form + return None + + def delete(self, form_id: str) -> None: + self._forms.pop(form_id, None) + + +# Service +class FormService: + """Service layer for managing human input forms in tests.""" + + def __init__(self, repository: InMemoryFormRepository): + self.repository = repository + + def create_form( + self, + *, + form_id: str, + workflow_run_id: str, + node_id: str, + tenant_id: str, + app_id: str | None, + form_content: str, + inputs, + user_actions, + timeout: int, + timeout_unit: TimeoutUnit, + form_token: str | None = None, + ) -> HumanInputForm: + form = HumanInputForm( + form_id=form_id, + workflow_run_id=workflow_run_id, + node_id=node_id, + tenant_id=tenant_id, + app_id=app_id, + form_content=form_content, + inputs=list(inputs), + user_actions=[{"id": action.id, "title": action.title} for action in user_actions], + timeout=timeout, + timeout_unit=timeout_unit, + form_token=form_token, + ) + form.calculate_expiration() + self.repository.save(form) + return form + + def get_form_by_id(self, form_id: str) -> HumanInputForm: + form = self.repository.get_by_id(form_id) + if form is None: + raise FormNotFoundError() + return form + + def get_form_by_token(self, token: str) -> HumanInputForm: + form = self.repository.get_by_token(token) + if form is None: + raise FormNotFoundError() + return form + + def get_form_definition(self, form_id: str, *, is_token: bool) -> dict: + form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id) + if form.is_expired: + raise FormExpiredError() + if form.is_submitted: + raise FormAlreadySubmittedError() + + definition = { + "form_content": form.form_content, + "inputs": form.inputs, + "user_actions": form.user_actions, + } + if is_token: + definition["site"] = {"title": "Workflow Form"} + return definition + + def submit_form(self, form_id: str, submission_data: FormSubmissionData, *, is_token: bool) -> None: + form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id) + if form.is_expired: + raise FormExpiredError() + if form.is_submitted: + raise FormAlreadySubmittedError() + + self._validate_submission(form=form, submission_data=submission_data) + form.mark_submitted(inputs=submission_data.inputs, action=submission_data.action) + self.repository.save(form) + + def cleanup_expired_forms(self) -> int: + expired_ids = [form_id for form_id, form in list(self.repository.forms.items()) if form.is_expired] + for form_id in expired_ids: + self.repository.delete(form_id) + return len(expired_ids) + + def _validate_submission(self, form: HumanInputForm, submission_data: FormSubmissionData) -> None: + defined_actions = {action["id"] for action in form.user_actions} + if submission_data.action not in defined_actions: + raise InvalidFormDataError(f"Invalid action: {submission_data.action}") + + missing_inputs = [] + for form_input in form.inputs: + if form_input.output_variable_name not in submission_data.inputs: + missing_inputs.append(form_input.output_variable_name) + + if missing_inputs: + raise InvalidFormDataError(f"Missing required inputs: {', '.join(missing_inputs)}") + + # Extra inputs are allowed; no further validation required. diff --git a/api/tests/unit_tests/libs/_human_input/test_form_service.py b/api/tests/unit_tests/libs/_human_input/test_form_service.py new file mode 100644 index 0000000000..15e7d41e85 --- /dev/null +++ b/api/tests/unit_tests/libs/_human_input/test_form_service.py @@ -0,0 +1,326 @@ +""" +Unit tests for FormService. +""" + +from datetime import datetime, timedelta + +import pytest + +from core.workflow.nodes.human_input.entities import ( + FormInput, + UserAction, +) +from core.workflow.nodes.human_input.enums import ( + FormInputType, + TimeoutUnit, +) +from libs.datetime_utils import naive_utc_now + +from .support import ( + FormAlreadySubmittedError, + FormExpiredError, + FormNotFoundError, + FormService, + FormSubmissionData, + InMemoryFormRepository, + InvalidFormDataError, +) + + +class TestFormService: + """Test FormService functionality.""" + + @pytest.fixture + def repository(self): + """Create in-memory repository for testing.""" + return InMemoryFormRepository() + + @pytest.fixture + def form_service(self, repository): + """Create FormService with in-memory repository.""" + return FormService(repository) + + @pytest.fixture + def sample_form_data(self): + """Create sample form data.""" + return { + "form_id": "form-123", + "workflow_run_id": "run-456", + "node_id": "node-789", + "tenant_id": "tenant-abc", + "app_id": "app-def", + "form_content": "# Test Form\n\nInput: {{#$output.input#}}", + "inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="input", default=None)], + "user_actions": [UserAction(id="submit", title="Submit")], + "timeout": 1, + "timeout_unit": TimeoutUnit.HOUR, + "form_token": "token-xyz", + } + + def test_create_form(self, form_service, sample_form_data): + """Test form creation.""" + form = form_service.create_form(**sample_form_data) + + assert form.form_id == "form-123" + assert form.workflow_run_id == "run-456" + assert form.node_id == "node-789" + assert form.tenant_id == "tenant-abc" + assert form.app_id == "app-def" + assert form.form_token == "token-xyz" + assert form.timeout == 1 + assert form.timeout_unit == TimeoutUnit.HOUR + assert form.expires_at is not None + assert not form.is_expired + assert not form.is_submitted + + def test_get_form_by_id(self, form_service, sample_form_data): + """Test getting form by ID.""" + # Create form first + created_form = form_service.create_form(**sample_form_data) + + # Retrieve form + retrieved_form = form_service.get_form_by_id("form-123") + + assert retrieved_form.form_id == created_form.form_id + assert retrieved_form.workflow_run_id == created_form.workflow_run_id + + def test_get_form_by_id_not_found(self, form_service): + """Test getting non-existent form by ID.""" + with pytest.raises(FormNotFoundError) as exc_info: + form_service.get_form_by_id("non-existent-form") + + assert exc_info.value.error_code == "form_not_found" + + def test_get_form_by_token(self, form_service, sample_form_data): + """Test getting form by token.""" + # Create form first + created_form = form_service.create_form(**sample_form_data) + + # Retrieve form by token + retrieved_form = form_service.get_form_by_token("token-xyz") + + assert retrieved_form.form_id == created_form.form_id + assert retrieved_form.form_token == "token-xyz" + + def test_get_form_by_token_not_found(self, form_service): + """Test getting non-existent form by token.""" + with pytest.raises(FormNotFoundError) as exc_info: + form_service.get_form_by_token("non-existent-token") + + assert exc_info.value.error_code == "form_not_found" + + def test_get_form_definition_by_id(self, form_service, sample_form_data): + """Test getting form definition by ID.""" + # Create form first + form_service.create_form(**sample_form_data) + + # Get form definition + definition = form_service.get_form_definition("form-123", is_token=False) + + assert "form_content" in definition + assert "inputs" in definition + assert definition["form_content"] == "# Test Form\n\nInput: {{#$output.input#}}" + assert len(definition["inputs"]) == 1 + assert "site" not in definition # Should not include site info for ID-based access + + def test_get_form_definition_by_token(self, form_service, sample_form_data): + """Test getting form definition by token.""" + # Create form first + form_service.create_form(**sample_form_data) + + # Get form definition + definition = form_service.get_form_definition("token-xyz", is_token=True) + + assert "form_content" in definition + assert "inputs" in definition + assert "site" in definition # Should include site info for token-based access + + def test_get_form_definition_expired_form(self, form_service, sample_form_data): + """Test getting definition for expired form.""" + # Create form with past expiry + form_service.create_form(**sample_form_data) + + # Manually expire the form by modifying expiry time + form = form_service.get_form_by_id("form-123") + form.expires_at = datetime.utcnow() - timedelta(hours=1) + form_service.repository.save(form) + + # Should raise FormExpiredError + with pytest.raises(FormExpiredError) as exc_info: + form_service.get_form_definition("form-123", is_token=False) + + assert exc_info.value.error_code == "human_input_form_expired" + + def test_get_form_definition_submitted_form(self, form_service, sample_form_data): + """Test getting definition for already submitted form.""" + # Create form first + form_service.create_form(**sample_form_data) + + # Submit the form + submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit") + form_service.submit_form("form-123", submission_data, is_token=False) + + # Should raise FormAlreadySubmittedError + with pytest.raises(FormAlreadySubmittedError) as exc_info: + form_service.get_form_definition("form-123", is_token=False) + + assert exc_info.value.error_code == "human_input_form_submitted" + + def test_submit_form_success(self, form_service, sample_form_data): + """Test successful form submission.""" + # Create form first + form_service.create_form(**sample_form_data) + + # Submit form + submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit") + + # Should not raise any exception + form_service.submit_form("form-123", submission_data, is_token=False) + + # Verify form is marked as submitted + form = form_service.get_form_by_id("form-123") + assert form.is_submitted + assert form.submitted_data == {"input": "test value"} + assert form.submitted_action == "submit" + assert form.submitted_at is not None + + def test_submit_form_missing_inputs(self, form_service, sample_form_data): + """Test form submission with missing inputs.""" + # Create form first + form_service.create_form(**sample_form_data) + + # Submit form with missing required input + submission_data = FormSubmissionData( + form_id="form-123", + inputs={}, # Missing required "input" field + action="submit", + ) + + with pytest.raises(InvalidFormDataError) as exc_info: + form_service.submit_form("form-123", submission_data, is_token=False) + + assert "Missing required inputs" in exc_info.value.message + assert "input" in exc_info.value.message + + def test_submit_form_invalid_action(self, form_service, sample_form_data): + """Test form submission with invalid action.""" + # Create form first + form_service.create_form(**sample_form_data) + + # Submit form with invalid action + submission_data = FormSubmissionData( + form_id="form-123", + inputs={"input": "test value"}, + action="invalid_action", # Not in the allowed actions + ) + + with pytest.raises(InvalidFormDataError) as exc_info: + form_service.submit_form("form-123", submission_data, is_token=False) + + assert "Invalid action" in exc_info.value.message + assert "invalid_action" in exc_info.value.message + + def test_submit_form_expired(self, form_service, sample_form_data): + """Test submitting expired form.""" + # Create form first + form_service.create_form(**sample_form_data) + + # Manually expire the form + form = form_service.get_form_by_id("form-123") + form.expires_at = datetime.utcnow() - timedelta(hours=1) + form_service.repository.save(form) + + # Try to submit expired form + submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit") + + with pytest.raises(FormExpiredError) as exc_info: + form_service.submit_form("form-123", submission_data, is_token=False) + + assert exc_info.value.error_code == "human_input_form_expired" + + def test_submit_form_already_submitted(self, form_service, sample_form_data): + """Test submitting form that's already submitted.""" + # Create and submit form first + form_service.create_form(**sample_form_data) + + submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "first submission"}, action="submit") + form_service.submit_form("form-123", submission_data, is_token=False) + + # Try to submit again + second_submission = FormSubmissionData( + form_id="form-123", inputs={"input": "second submission"}, action="submit" + ) + + with pytest.raises(FormAlreadySubmittedError) as exc_info: + form_service.submit_form("form-123", second_submission, is_token=False) + + assert exc_info.value.error_code == "human_input_form_submitted" + + def test_cleanup_expired_forms(self, form_service, sample_form_data): + """Test cleanup of expired forms.""" + # Create multiple forms + for i in range(3): + data = sample_form_data.copy() + data["form_id"] = f"form-{i}" + data["form_token"] = f"token-{i}" + form_service.create_form(**data) + + # Manually expire some forms + for i in range(2): # Expire first 2 forms + form = form_service.get_form_by_id(f"form-{i}") + form.expires_at = naive_utc_now() - timedelta(hours=1) + form_service.repository.save(form) + + # Clean up expired forms + cleaned_count = form_service.cleanup_expired_forms() + + assert cleaned_count == 2 + + # Verify expired forms are gone + with pytest.raises(FormNotFoundError): + form_service.get_form_by_id("form-0") + + with pytest.raises(FormNotFoundError): + form_service.get_form_by_id("form-1") + + # Verify non-expired form still exists + form = form_service.get_form_by_id("form-2") + assert form.form_id == "form-2" + + +class TestFormValidation: + """Test form validation logic.""" + + def test_validate_submission_with_extra_inputs(self): + """Test validation allows extra inputs that aren't defined in form.""" + repository = InMemoryFormRepository() + form_service = FormService(repository) + + # Create form with one input + form_data = { + "form_id": "form-123", + "workflow_run_id": "run-456", + "node_id": "node-789", + "tenant_id": "tenant-abc", + "app_id": "app-def", + "form_content": "Test form", + "inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="required_input", default=None)], + "user_actions": [UserAction(id="submit", title="Submit")], + "timeout": 1, + "timeout_unit": TimeoutUnit.HOUR, + } + + form_service.create_form(**form_data) + + # Submit with extra input (should be allowed) + submission_data = FormSubmissionData( + form_id="form-123", + inputs={ + "required_input": "value1", + "extra_input": "value2", # Extra input not defined in form + }, + action="submit", + ) + + # Should not raise any exception + form_service.submit_form("form-123", submission_data, is_token=False) diff --git a/api/tests/unit_tests/libs/_human_input/test_models.py b/api/tests/unit_tests/libs/_human_input/test_models.py new file mode 100644 index 0000000000..962eeb9e11 --- /dev/null +++ b/api/tests/unit_tests/libs/_human_input/test_models.py @@ -0,0 +1,232 @@ +""" +Unit tests for human input form models. +""" + +from datetime import datetime, timedelta + +import pytest + +from core.workflow.nodes.human_input.entities import ( + FormInput, + UserAction, +) +from core.workflow.nodes.human_input.enums import ( + FormInputType, + TimeoutUnit, +) + +from .support import FormSubmissionData, FormSubmissionRequest, HumanInputForm + + +class TestHumanInputForm: + """Test HumanInputForm model.""" + + @pytest.fixture + def sample_form_data(self): + """Create sample form data.""" + return { + "form_id": "form-123", + "workflow_run_id": "run-456", + "node_id": "node-789", + "tenant_id": "tenant-abc", + "app_id": "app-def", + "form_content": "# Test Form\n\nInput: {{#$output.input#}}", + "inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="input", default=None)], + "user_actions": [UserAction(id="submit", title="Submit")], + "timeout": 2, + "timeout_unit": TimeoutUnit.HOUR, + "form_token": "token-xyz", + } + + def test_form_creation(self, sample_form_data): + """Test form creation.""" + form = HumanInputForm(**sample_form_data) + + assert form.form_id == "form-123" + assert form.workflow_run_id == "run-456" + assert form.node_id == "node-789" + assert form.tenant_id == "tenant-abc" + assert form.app_id == "app-def" + assert form.form_token == "token-xyz" + assert form.timeout == 2 + assert form.timeout_unit == TimeoutUnit.HOUR + assert form.created_at is not None + assert form.expires_at is not None + assert form.submitted_at is None + assert form.submitted_data is None + assert form.submitted_action is None + + def test_form_expiry_calculation_hours(self, sample_form_data): + """Test form expiry calculation for hours.""" + form = HumanInputForm(**sample_form_data) + + # Should expire 2 hours after creation + expected_expiry = form.created_at + timedelta(hours=2) + assert abs((form.expires_at - expected_expiry).total_seconds()) < 1 # Within 1 second + + def test_form_expiry_calculation_days(self, sample_form_data): + """Test form expiry calculation for days.""" + sample_form_data["timeout"] = 3 + sample_form_data["timeout_unit"] = TimeoutUnit.DAY + + form = HumanInputForm(**sample_form_data) + + # Should expire 3 days after creation + expected_expiry = form.created_at + timedelta(days=3) + assert abs((form.expires_at - expected_expiry).total_seconds()) < 1 # Within 1 second + + def test_form_expiry_property_not_expired(self, sample_form_data): + """Test is_expired property for non-expired form.""" + form = HumanInputForm(**sample_form_data) + assert not form.is_expired + + def test_form_expiry_property_expired(self, sample_form_data): + """Test is_expired property for expired form.""" + # Create form with past expiry + past_time = datetime.utcnow() - timedelta(hours=1) + sample_form_data["created_at"] = past_time + + form = HumanInputForm(**sample_form_data) + # Manually set expiry to past time + form.expires_at = past_time + + assert form.is_expired + + def test_form_submission_property_not_submitted(self, sample_form_data): + """Test is_submitted property for non-submitted form.""" + form = HumanInputForm(**sample_form_data) + assert not form.is_submitted + + def test_form_submission_property_submitted(self, sample_form_data): + """Test is_submitted property for submitted form.""" + form = HumanInputForm(**sample_form_data) + form.submit({"input": "test value"}, "submit") + + assert form.is_submitted + assert form.submitted_at is not None + assert form.submitted_data == {"input": "test value"} + assert form.submitted_action == "submit" + + def test_form_submit_method(self, sample_form_data): + """Test form submit method.""" + form = HumanInputForm(**sample_form_data) + + submission_time_before = datetime.utcnow() + form.submit({"input": "test value"}, "submit") + submission_time_after = datetime.utcnow() + + assert form.is_submitted + assert form.submitted_data == {"input": "test value"} + assert form.submitted_action == "submit" + assert submission_time_before <= form.submitted_at <= submission_time_after + + def test_form_to_response_dict_without_site_info(self, sample_form_data): + """Test converting form to response dict without site info.""" + form = HumanInputForm(**sample_form_data) + + response = form.to_response_dict(include_site_info=False) + + assert "form_content" in response + assert "inputs" in response + assert "site" not in response + assert response["form_content"] == "# Test Form\n\nInput: {{#$output.input#}}" + assert len(response["inputs"]) == 1 + assert response["inputs"][0]["type"] == "text-input" + assert response["inputs"][0]["output_variable_name"] == "input" + + def test_form_to_response_dict_with_site_info(self, sample_form_data): + """Test converting form to response dict with site info.""" + form = HumanInputForm(**sample_form_data) + + response = form.to_response_dict(include_site_info=True) + + assert "form_content" in response + assert "inputs" in response + assert "site" in response + assert response["site"]["app_id"] == "app-def" + assert response["site"]["title"] == "Workflow Form" + + def test_form_without_web_app_token(self, sample_form_data): + """Test form creation without web app token.""" + sample_form_data["form_token"] = None + + form = HumanInputForm(**sample_form_data) + + assert form.form_token is None + assert form.form_id == "form-123" # Other fields should still work + + def test_form_with_explicit_timestamps(self): + """Test form creation with explicit timestamps.""" + created_time = datetime(2024, 1, 15, 10, 30, 0) + expires_time = datetime(2024, 1, 15, 12, 30, 0) + + form = HumanInputForm( + form_id="form-123", + workflow_run_id="run-456", + node_id="node-789", + tenant_id="tenant-abc", + app_id="app-def", + form_content="Test content", + inputs=[], + user_actions=[], + timeout=2, + timeout_unit=TimeoutUnit.HOUR, + created_at=created_time, + expires_at=expires_time, + ) + + assert form.created_at == created_time + assert form.expires_at == expires_time + + +class TestFormSubmissionData: + """Test FormSubmissionData model.""" + + def test_submission_data_creation(self): + """Test submission data creation.""" + submission_data = FormSubmissionData( + form_id="form-123", inputs={"field1": "value1", "field2": "value2"}, action="submit" + ) + + assert submission_data.form_id == "form-123" + assert submission_data.inputs == {"field1": "value1", "field2": "value2"} + assert submission_data.action == "submit" + assert submission_data.submitted_at is not None + + def test_submission_data_from_request(self): + """Test creating submission data from API request.""" + request = FormSubmissionRequest(inputs={"input": "test value"}, action="confirm") + + submission_data = FormSubmissionData.from_request("form-456", request) + + assert submission_data.form_id == "form-456" + assert submission_data.inputs == {"input": "test value"} + assert submission_data.action == "confirm" + assert submission_data.submitted_at is not None + + def test_submission_data_with_empty_inputs(self): + """Test submission data with empty inputs.""" + submission_data = FormSubmissionData(form_id="form-123", inputs={}, action="cancel") + + assert submission_data.inputs == {} + assert submission_data.action == "cancel" + + def test_submission_data_timestamps(self): + """Test submission data timestamp handling.""" + before_time = datetime.utcnow() + + submission_data = FormSubmissionData(form_id="form-123", inputs={"test": "value"}, action="submit") + + after_time = datetime.utcnow() + + assert before_time <= submission_data.submitted_at <= after_time + + def test_submission_data_with_explicit_timestamp(self): + """Test submission data with explicit timestamp.""" + specific_time = datetime(2024, 1, 15, 14, 30, 0) + + submission_data = FormSubmissionData( + form_id="form-123", inputs={"test": "value"}, action="submit", submitted_at=specific_time + ) + + assert submission_data.submitted_at == specific_time diff --git a/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py b/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py index ccba075fdf..f206c411fd 100644 --- a/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py +++ b/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py @@ -181,6 +181,7 @@ class TestShardedTopic: subscription = sharded_topic.subscribe() assert isinstance(subscription, _RedisShardedSubscription) + assert subscription._client is mock_redis_client assert subscription._pubsub is mock_redis_client.pubsub.return_value assert subscription._topic == "test-sharded-topic" @@ -200,6 +201,11 @@ class SubscriptionTestCase: class TestRedisSubscription: """Test cases for the _RedisSubscription class.""" + @pytest.fixture + def mock_redis_client(self) -> MagicMock: + client = MagicMock() + return client + @pytest.fixture def mock_pubsub(self) -> MagicMock: """Create a mock PubSub instance for testing.""" @@ -211,9 +217,12 @@ class TestRedisSubscription: return pubsub @pytest.fixture - def subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisSubscription, None, None]: + def subscription( + self, mock_pubsub: MagicMock, mock_redis_client: MagicMock + ) -> Generator[_RedisSubscription, None, None]: """Create a _RedisSubscription instance for testing.""" subscription = _RedisSubscription( + client=mock_redis_client, pubsub=mock_pubsub, topic="test-topic", ) @@ -228,13 +237,15 @@ class TestRedisSubscription: # ==================== Lifecycle Tests ==================== - def test_subscription_initialization(self, mock_pubsub: MagicMock): + def test_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock): """Test that subscription is properly initialized.""" subscription = _RedisSubscription( + client=mock_redis_client, pubsub=mock_pubsub, topic="test-topic", ) + assert subscription._client is mock_redis_client assert subscription._pubsub is mock_pubsub assert subscription._topic == "test-topic" assert not subscription._closed.is_set() @@ -486,9 +497,12 @@ class TestRedisSubscription: ), ], ) - def test_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock): + def test_subscription_scenarios( + self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: MagicMock + ): """Test various subscription scenarios using table-driven approach.""" subscription = _RedisSubscription( + client=mock_redis_client, pubsub=mock_pubsub, topic="test-topic", ) @@ -572,7 +586,7 @@ class TestRedisSubscription: # Close should still work subscription.close() # Should not raise - def test_channel_name_variations(self, mock_pubsub: MagicMock): + def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock): """Test various channel name formats.""" channel_names = [ "simple", @@ -586,6 +600,7 @@ class TestRedisSubscription: for channel_name in channel_names: subscription = _RedisSubscription( + client=mock_redis_client, pubsub=mock_pubsub, topic=channel_name, ) @@ -604,6 +619,11 @@ class TestRedisSubscription: class TestRedisShardedSubscription: """Test cases for the _RedisShardedSubscription class.""" + @pytest.fixture + def mock_redis_client(self) -> MagicMock: + client = MagicMock() + return client + @pytest.fixture def mock_pubsub(self) -> MagicMock: """Create a mock PubSub instance for testing.""" @@ -615,9 +635,12 @@ class TestRedisShardedSubscription: return pubsub @pytest.fixture - def sharded_subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisShardedSubscription, None, None]: + def sharded_subscription( + self, mock_pubsub: MagicMock, mock_redis_client: MagicMock + ) -> Generator[_RedisShardedSubscription, None, None]: """Create a _RedisShardedSubscription instance for testing.""" subscription = _RedisShardedSubscription( + client=mock_redis_client, pubsub=mock_pubsub, topic="test-sharded-topic", ) @@ -634,13 +657,15 @@ class TestRedisShardedSubscription: # ==================== Lifecycle Tests ==================== - def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock): + def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock): """Test that sharded subscription is properly initialized.""" subscription = _RedisShardedSubscription( + client=mock_redis_client, pubsub=mock_pubsub, topic="test-sharded-topic", ) + assert subscription._client is mock_redis_client assert subscription._pubsub is mock_pubsub assert subscription._topic == "test-sharded-topic" assert not subscription._closed.is_set() @@ -808,6 +833,37 @@ class TestRedisShardedSubscription: assert not sharded_subscription._queue.empty() assert sharded_subscription._queue.get_nowait() == b"test sharded payload" + def test_get_message_uses_target_node_for_cluster_client(self, mock_pubsub: MagicMock, monkeypatch): + """Test that cluster clients use target_node for sharded messages.""" + + class DummyRedisCluster: + def __init__(self): + self.get_node_from_key = MagicMock(return_value="node-1") + + monkeypatch.setattr("libs.broadcast_channel.redis.sharded_channel.RedisCluster", DummyRedisCluster) + + client = DummyRedisCluster() + subscription = _RedisShardedSubscription( + client=client, + pubsub=mock_pubsub, + topic="test-sharded-topic", + ) + mock_pubsub.get_sharded_message.return_value = { + "type": "smessage", + "channel": "test-sharded-topic", + "data": b"payload", + } + + result = subscription._get_message() + + client.get_node_from_key.assert_called_once_with("test-sharded-topic") + mock_pubsub.get_sharded_message.assert_called_once_with( + ignore_subscribe_messages=False, + timeout=1, + target_node="node-1", + ) + assert result == mock_pubsub.get_sharded_message.return_value + def test_listener_thread_ignores_subscribe_messages( self, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock ): @@ -913,9 +969,12 @@ class TestRedisShardedSubscription: ), ], ) - def test_sharded_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock): + def test_sharded_subscription_scenarios( + self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: MagicMock + ): """Test various sharded subscription scenarios using table-driven approach.""" subscription = _RedisShardedSubscription( + client=mock_redis_client, pubsub=mock_pubsub, topic="test-sharded-topic", ) @@ -999,7 +1058,7 @@ class TestRedisShardedSubscription: # Close should still work sharded_subscription.close() # Should not raise - def test_channel_name_variations(self, mock_pubsub: MagicMock): + def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock): """Test various sharded channel name formats.""" channel_names = [ "simple", @@ -1013,6 +1072,7 @@ class TestRedisShardedSubscription: for channel_name in channel_names: subscription = _RedisShardedSubscription( + client=mock_redis_client, pubsub=mock_pubsub, topic=channel_name, ) @@ -1060,6 +1120,11 @@ class TestRedisSubscriptionCommon: """Parameterized fixture providing subscription type and class.""" return request.param + @pytest.fixture + def mock_redis_client(self) -> MagicMock: + client = MagicMock() + return client + @pytest.fixture def mock_pubsub(self) -> MagicMock: """Create a mock PubSub instance for testing.""" @@ -1075,11 +1140,12 @@ class TestRedisSubscriptionCommon: return pubsub @pytest.fixture - def subscription(self, subscription_params, mock_pubsub: MagicMock): + def subscription(self, subscription_params, mock_pubsub: MagicMock, mock_redis_client: MagicMock): """Create a subscription instance based on parameterized type.""" subscription_type, subscription_class = subscription_params topic_name = f"test-{subscription_type}-topic" subscription = subscription_class( + client=mock_redis_client, pubsub=mock_pubsub, topic=topic_name, ) diff --git a/api/tests/unit_tests/libs/test_helper.py b/api/tests/unit_tests/libs/test_helper.py index de74eff82f..1a93dbbca1 100644 --- a/api/tests/unit_tests/libs/test_helper.py +++ b/api/tests/unit_tests/libs/test_helper.py @@ -1,6 +1,8 @@ +from datetime import datetime + import pytest -from libs.helper import escape_like_pattern, extract_tenant_id +from libs.helper import OptionalTimestampField, escape_like_pattern, extract_tenant_id from models.account import Account from models.model import EndUser @@ -65,6 +67,19 @@ class TestExtractTenantId: extract_tenant_id(dict_user) +class TestOptionalTimestampField: + def test_format_returns_none_for_none(self): + field = OptionalTimestampField() + + assert field.format(None) is None + + def test_format_returns_unix_timestamp_for_datetime(self): + field = OptionalTimestampField() + value = datetime(2024, 1, 2, 3, 4, 5) + + assert field.format(value) == int(value.timestamp()) + + class TestEscapeLikePattern: """Test cases for the escape_like_pattern utility function.""" diff --git a/api/tests/unit_tests/libs/test_rate_limiter.py b/api/tests/unit_tests/libs/test_rate_limiter.py new file mode 100644 index 0000000000..9d44b07b5e --- /dev/null +++ b/api/tests/unit_tests/libs/test_rate_limiter.py @@ -0,0 +1,68 @@ +from unittest.mock import MagicMock + +from libs import helper as helper_module + + +class _FakeRedis: + def __init__(self) -> None: + self._zsets: dict[str, dict[str, float]] = {} + self._expiry: dict[str, int] = {} + + def zadd(self, key: str, mapping: dict[str, float]) -> int: + zset = self._zsets.setdefault(key, {}) + for member, score in mapping.items(): + zset[str(member)] = float(score) + return len(mapping) + + def zremrangebyscore(self, key: str, min_score: str | float, max_score: str | float) -> int: + zset = self._zsets.get(key, {}) + min_value = float("-inf") if min_score == "-inf" else float(min_score) + max_value = float("inf") if max_score == "+inf" else float(max_score) + to_delete = [member for member, score in zset.items() if min_value <= score <= max_value] + for member in to_delete: + del zset[member] + return len(to_delete) + + def zcard(self, key: str) -> int: + return len(self._zsets.get(key, {})) + + def expire(self, key: str, ttl: int) -> bool: + self._expiry[key] = ttl + return True + + +def test_rate_limiter_counts_attempts_within_same_second(monkeypatch): + fake_redis = _FakeRedis() + monkeypatch.setattr(helper_module.time, "time", lambda: 1000) + + limiter = helper_module.RateLimiter( + prefix="test_rate_limit", + max_attempts=2, + time_window=60, + redis_client=fake_redis, + ) + + limiter.increment_rate_limit("203.0.113.10") + limiter.increment_rate_limit("203.0.113.10") + + assert limiter.is_rate_limited("203.0.113.10") is True + + +def test_rate_limiter_uses_injected_redis(monkeypatch): + redis_client = MagicMock() + redis_client.zcard.return_value = 1 + monkeypatch.setattr(helper_module.time, "time", lambda: 1000) + + limiter = helper_module.RateLimiter( + prefix="test_rate_limit", + max_attempts=1, + time_window=60, + redis_client=redis_client, + ) + + limiter.increment_rate_limit("203.0.113.10") + limiter.is_rate_limited("203.0.113.10") + + assert redis_client.zadd.called is True + assert redis_client.zremrangebyscore.called is True + assert redis_client.zcard.called is True diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index 8be2eea121..c6dfd41803 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -1296,6 +1296,7 @@ class TestConversationStatusCount: assert result["success"] == 1 # One SUCCEEDED assert result["failed"] == 1 # One FAILED assert result["partial_success"] == 1 # One PARTIAL_SUCCEEDED + assert result["paused"] == 0 def test_status_count_app_id_filtering(self): """Test that status_count filters workflow runs by app_id for security.""" @@ -1350,6 +1351,7 @@ class TestConversationStatusCount: assert result["success"] == 0 assert result["failed"] == 0 assert result["partial_success"] == 0 + assert result["paused"] == 0 def test_status_count_handles_invalid_workflow_status(self): """Test that status_count gracefully handles invalid workflow status values.""" @@ -1404,3 +1406,57 @@ class TestConversationStatusCount: assert result["success"] == 0 assert result["failed"] == 0 assert result["partial_success"] == 0 + assert result["paused"] == 0 + + def test_status_count_paused(self): + """Test status_count includes paused workflow runs.""" + # Arrange + from core.workflow.enums import WorkflowExecutionStatus + + app_id = str(uuid4()) + conversation_id = str(uuid4()) + workflow_run_id = str(uuid4()) + + conversation = Conversation( + app_id=app_id, + mode=AppMode.CHAT, + name="Test Conversation", + status="normal", + from_source="api", + ) + conversation.id = conversation_id + + mock_messages = [ + MagicMock( + conversation_id=conversation_id, + workflow_run_id=workflow_run_id, + ), + ] + + mock_workflow_runs = [ + MagicMock( + id=workflow_run_id, + status=WorkflowExecutionStatus.PAUSED.value, + app_id=app_id, + ), + ] + + with patch("models.model.db.session.scalars") as mock_scalars: + + def mock_scalars_side_effect(query): + mock_result = MagicMock() + if "messages" in str(query): + mock_result.all.return_value = mock_messages + elif "workflow_runs" in str(query): + mock_result.all.return_value = mock_workflow_runs + else: + mock_result.all.return_value = [] + return mock_result + + mock_scalars.side_effect = mock_scalars_side_effect + + # Act + result = conversation.status_count + + # Assert + assert result["paused"] == 1 diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py new file mode 100644 index 0000000000..ceb1406a4b --- /dev/null +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py @@ -0,0 +1,40 @@ +"""Unit tests for DifyAPISQLAlchemyWorkflowNodeExecutionRepository implementation.""" + +from unittest.mock import Mock + +from sqlalchemy.orm import Session, sessionmaker + +from repositories.sqlalchemy_api_workflow_node_execution_repository import ( + DifyAPISQLAlchemyWorkflowNodeExecutionRepository, +) + + +class TestDifyAPISQLAlchemyWorkflowNodeExecutionRepository: + def test_get_executions_by_workflow_run_keeps_paused_records(self): + mock_session = Mock(spec=Session) + execute_result = Mock() + execute_result.scalars.return_value.all.return_value = [] + mock_session.execute.return_value = execute_result + + session_maker = Mock(spec=sessionmaker) + context_manager = Mock() + context_manager.__enter__ = Mock(return_value=mock_session) + context_manager.__exit__ = Mock(return_value=None) + session_maker.return_value = context_manager + + repository = DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker) + + repository.get_executions_by_workflow_run( + tenant_id="tenant-123", + app_id="app-123", + workflow_run_id="workflow-run-123", + ) + + stmt = mock_session.execute.call_args[0][0] + where_clauses = list(getattr(stmt, "_where_criteria", []) or []) + where_strs = [str(clause).lower() for clause in where_clauses] + + assert any("tenant_id" in clause for clause in where_strs) + assert any("app_id" in clause for clause in where_strs) + assert any("workflow_run_id" in clause for clause in where_strs) + assert not any("paused" in clause for clause in where_strs) diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index d443c4c9a5..4caaa056ff 100644 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -1,5 +1,6 @@ """Unit tests for DifyAPISQLAlchemyWorkflowRunRepository implementation.""" +import secrets from datetime import UTC, datetime from unittest.mock import Mock, patch @@ -7,12 +8,17 @@ import pytest from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Session, sessionmaker +from core.workflow.entities.pause_reason import HumanInputRequired, PauseReasonType from core.workflow.enums import WorkflowExecutionStatus +from core.workflow.nodes.human_input.entities import FormDefinition, FormInput, UserAction +from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormStatus +from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType from models.workflow import WorkflowPause as WorkflowPauseModel -from models.workflow import WorkflowRun +from models.workflow import WorkflowPauseReason, WorkflowRun from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.sqlalchemy_api_workflow_run_repository import ( DifyAPISQLAlchemyWorkflowRunRepository, + _build_human_input_required_reason, _PrivateWorkflowPauseEntity, _WorkflowRunError, ) @@ -205,11 +211,11 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): ): """Test workflow pause creation when workflow not in RUNNING status.""" # Arrange - sample_workflow_run.status = WorkflowExecutionStatus.PAUSED + sample_workflow_run.status = WorkflowExecutionStatus.SUCCEEDED mock_session.get.return_value = sample_workflow_run # Act & Assert - with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING status can be paused"): + with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING or PAUSED status can be paused"): repository.create_workflow_pause( workflow_run_id="workflow-run-123", state_owner_user_id="user-123", @@ -295,6 +301,7 @@ class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): sample_workflow_pause.resumed_at = None mock_session.scalar.return_value = sample_workflow_run + mock_session.scalars.return_value.all.return_value = [] with patch("repositories.sqlalchemy_api_workflow_run_repository.naive_utc_now") as mock_now: mock_now.return_value = datetime.now(UTC) @@ -455,3 +462,53 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository) assert result1 == expected_state assert result2 == expected_state mock_storage.load.assert_called_once() # Only called once due to caching + + +class TestBuildHumanInputRequiredReason: + def test_prefers_backstage_token_when_available(self): + expiration_time = datetime.now(UTC) + form_definition = FormDefinition( + form_content="content", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], + user_actions=[UserAction(id="approve", title="Approve")], + rendered_content="rendered", + expiration_time=expiration_time, + default_values={"name": "Alice"}, + node_title="Ask Name", + display_in_ui=True, + ) + form_model = HumanInputForm( + id="form-1", + tenant_id="tenant-1", + app_id="app-1", + workflow_run_id="run-1", + node_id="node-1", + form_definition=form_definition.model_dump_json(), + rendered_content="rendered", + status=HumanInputFormStatus.WAITING, + expiration_time=expiration_time, + ) + reason_model = WorkflowPauseReason( + pause_id="pause-1", + type_=PauseReasonType.HUMAN_INPUT_REQUIRED, + form_id="form-1", + node_id="node-1", + message="", + ) + access_token = secrets.token_urlsafe(8) + backstage_recipient = HumanInputFormRecipient( + form_id="form-1", + delivery_id="delivery-1", + recipient_type=RecipientType.BACKSTAGE, + recipient_payload=BackstageRecipientPayload().model_dump_json(), + access_token=access_token, + ) + + reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient]) + + assert isinstance(reason, HumanInputRequired) + assert reason.form_token == access_token + assert reason.node_title == "Ask Name" + assert reason.form_content == "content" + assert reason.inputs[0].output_variable_name == "name" + assert reason.actions[0].id == "approve" diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py new file mode 100644 index 0000000000..f5428b46ff --- /dev/null +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from datetime import UTC, datetime, timedelta + +from core.entities.execution_extra_content import HumanInputContent as HumanInputContentDomain +from core.entities.execution_extra_content import HumanInputFormSubmissionData +from core.workflow.nodes.human_input.entities import ( + FormDefinition, + UserAction, +) +from core.workflow.nodes.human_input.enums import HumanInputFormStatus +from models.execution_extra_content import HumanInputContent as HumanInputContentModel +from models.human_input import ConsoleRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType +from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository + + +class _FakeScalarResult: + def __init__(self, values: Sequence[HumanInputContentModel]): + self._values = list(values) + + def all(self) -> list[HumanInputContentModel]: + return list(self._values) + + +class _FakeSession: + def __init__(self, values: Sequence[Sequence[object]]): + self._values = list(values) + + def scalars(self, _stmt): + if not self._values: + return _FakeScalarResult([]) + return _FakeScalarResult(self._values.pop(0)) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + +@dataclass +class _FakeSessionMaker: + session: _FakeSession + + def __call__(self) -> _FakeSession: + return self.session + + +def _build_form(action_id: str, action_title: str, rendered_content: str) -> HumanInputForm: + expiration_time = datetime.now(UTC) + timedelta(days=1) + definition = FormDefinition( + form_content="content", + inputs=[], + user_actions=[UserAction(id=action_id, title=action_title)], + rendered_content="rendered", + expiration_time=expiration_time, + node_title="Approval", + display_in_ui=True, + ) + form = HumanInputForm( + id=f"form-{action_id}", + tenant_id="tenant-id", + app_id="app-id", + workflow_run_id="workflow-run", + node_id="node-id", + form_definition=definition.model_dump_json(), + rendered_content=rendered_content, + status=HumanInputFormStatus.SUBMITTED, + expiration_time=expiration_time, + ) + form.selected_action_id = action_id + return form + + +def _build_content(message_id: str, action_id: str, action_title: str) -> HumanInputContentModel: + form = _build_form( + action_id=action_id, + action_title=action_title, + rendered_content=f"Rendered {action_title}", + ) + content = HumanInputContentModel( + id=f"content-{message_id}", + form_id=form.id, + message_id=message_id, + workflow_run_id=form.workflow_run_id, + ) + content.form = form + return content + + +def test_get_by_message_ids_groups_contents_by_message() -> None: + message_ids = ["msg-1", "msg-2"] + contents = [_build_content("msg-1", "approve", "Approve")] + repository = SQLAlchemyExecutionExtraContentRepository( + session_maker=_FakeSessionMaker(session=_FakeSession(values=[contents, []])) + ) + + result = repository.get_by_message_ids(message_ids) + + assert len(result) == 2 + assert [content.model_dump(mode="json", exclude_none=True) for content in result[0]] == [ + HumanInputContentDomain( + workflow_run_id="workflow-run", + submitted=True, + form_submission_data=HumanInputFormSubmissionData( + node_id="node-id", + node_title="Approval", + rendered_content="Rendered Approve", + action_id="approve", + action_text="Approve", + ), + ).model_dump(mode="json", exclude_none=True) + ] + assert result[1] == [] + + +def test_get_by_message_ids_returns_unsubmitted_form_definition() -> None: + expiration_time = datetime.now(UTC) + timedelta(days=1) + definition = FormDefinition( + form_content="content", + inputs=[], + user_actions=[UserAction(id="approve", title="Approve")], + rendered_content="rendered", + expiration_time=expiration_time, + default_values={"name": "John"}, + node_title="Approval", + display_in_ui=True, + ) + form = HumanInputForm( + id="form-1", + tenant_id="tenant-id", + app_id="app-id", + workflow_run_id="workflow-run", + node_id="node-id", + form_definition=definition.model_dump_json(), + rendered_content="Rendered block", + status=HumanInputFormStatus.WAITING, + expiration_time=expiration_time, + ) + content = HumanInputContentModel( + id="content-msg-1", + form_id=form.id, + message_id="msg-1", + workflow_run_id=form.workflow_run_id, + ) + content.form = form + + recipient = HumanInputFormRecipient( + form_id=form.id, + delivery_id="delivery-1", + recipient_type=RecipientType.CONSOLE, + recipient_payload=ConsoleRecipientPayload(account_id=None).model_dump_json(), + access_token="token-1", + ) + + repository = SQLAlchemyExecutionExtraContentRepository( + session_maker=_FakeSessionMaker(session=_FakeSession(values=[[content], [recipient]])) + ) + + result = repository.get_by_message_ids(["msg-1"]) + + assert len(result) == 1 + assert len(result[0]) == 1 + domain_content = result[0][0] + assert domain_content.submitted is False + assert domain_content.workflow_run_id == "workflow-run" + assert domain_content.form_definition is not None + assert domain_content.form_definition.expiration_time == int(form.expiration_time.timestamp()) + assert domain_content.form_definition is not None + form_definition = domain_content.form_definition + assert form_definition.form_id == "form-1" + assert form_definition.node_id == "node-id" + assert form_definition.node_title == "Approval" + assert form_definition.form_content == "Rendered block" + assert form_definition.display_in_ui is True + assert form_definition.form_token == "token-1" + assert form_definition.resolved_default_values == {"name": "John"} + assert form_definition.expiration_time == int(form.expiration_time.timestamp()) diff --git a/api/tests/unit_tests/services/test_app_generate_service.py b/api/tests/unit_tests/services/test_app_generate_service.py new file mode 100644 index 0000000000..71134464e6 --- /dev/null +++ b/api/tests/unit_tests/services/test_app_generate_service.py @@ -0,0 +1,65 @@ +from unittest.mock import MagicMock + +import services.app_generate_service as app_generate_service_module +from models.model import AppMode +from services.app_generate_service import AppGenerateService + + +class _DummyRateLimit: + def __init__(self, client_id: str, max_active_requests: int) -> None: + self.client_id = client_id + self.max_active_requests = max_active_requests + + @staticmethod + def gen_request_key() -> str: + return "dummy-request-id" + + def enter(self, request_id: str | None = None) -> str: + return request_id or "dummy-request-id" + + def exit(self, request_id: str) -> None: + return None + + def generate(self, generator, request_id: str): + return generator + + +def test_workflow_blocking_injects_pause_state_config(mocker, monkeypatch): + monkeypatch.setattr(app_generate_service_module.dify_config, "BILLING_ENABLED", False) + mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit) + + workflow = MagicMock() + workflow.id = "workflow-id" + workflow.created_by = "owner-id" + + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + + generator_spy = mocker.patch( + "services.app_generate_service.WorkflowAppGenerator.generate", + return_value={"result": "ok"}, + ) + + app_model = MagicMock() + app_model.mode = AppMode.WORKFLOW + app_model.id = "app-id" + app_model.tenant_id = "tenant-id" + app_model.max_active_requests = 0 + app_model.is_agent = False + + user = MagicMock() + user.id = "user-id" + + result = AppGenerateService.generate( + app_model=app_model, + user=user, + args={"inputs": {"k": "v"}}, + invoke_from=MagicMock(), + streaming=False, + ) + + assert result == {"result": "ok"} + + call_kwargs = generator_spy.call_args.kwargs + pause_state_config = call_kwargs.get("pause_state_config") + assert pause_state_config is not None + assert pause_state_config.state_owner_user_id == "owner-id" diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py index 81135dbbdf..eca1d44d23 100644 --- a/api/tests/unit_tests/services/test_conversation_service.py +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -508,9 +508,12 @@ class TestConversationServiceMessageCreation: within conversations. """ + @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db.session") @patch("services.message_service.ConversationService.get_conversation") - def test_pagination_by_first_id_without_first_id(self, mock_get_conversation, mock_db_session): + def test_pagination_by_first_id_without_first_id( + self, mock_get_conversation, mock_db_session, mock_create_extra_repo + ): """ Test message pagination without specifying first_id. @@ -540,6 +543,9 @@ class TestConversationServiceMessageCreation: mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining mock_query.limit.return_value = mock_query # LIMIT returns self for chaining mock_query.all.return_value = messages # Final .all() returns the messages + mock_repository = MagicMock() + mock_repository.get_by_message_ids.return_value = [[] for _ in messages] + mock_create_extra_repo.return_value = mock_repository # Act - Call the pagination method without first_id result = MessageService.pagination_by_first_id( @@ -556,9 +562,10 @@ class TestConversationServiceMessageCreation: # Verify conversation was looked up with correct parameters mock_get_conversation.assert_called_once_with(app_model=app_model, user=user, conversation_id=conversation.id) + @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db.session") @patch("services.message_service.ConversationService.get_conversation") - def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session): + def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session, mock_create_extra_repo): """ Test message pagination with first_id specified. @@ -590,6 +597,9 @@ class TestConversationServiceMessageCreation: mock_query.limit.return_value = mock_query # LIMIT returns self for chaining mock_query.first.return_value = first_message # First message returned mock_query.all.return_value = messages # Remaining messages returned + mock_repository = MagicMock() + mock_repository.get_by_message_ids.return_value = [[] for _ in messages] + mock_create_extra_repo.return_value = mock_repository # Act - Call the pagination method with first_id result = MessageService.pagination_by_first_id( @@ -684,9 +694,10 @@ class TestConversationServiceMessageCreation: assert result.data == [] assert result.has_more is False + @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db.session") @patch("services.message_service.ConversationService.get_conversation") - def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session): + def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session, mock_create_extra_repo): """ Test that has_more flag is correctly set when there are more messages. @@ -716,6 +727,9 @@ class TestConversationServiceMessageCreation: mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining mock_query.limit.return_value = mock_query # LIMIT returns self for chaining mock_query.all.return_value = messages # Final .all() returns the messages + mock_repository = MagicMock() + mock_repository.get_by_message_ids.return_value = [[] for _ in messages] + mock_create_extra_repo.return_value = mock_repository # Act result = MessageService.pagination_by_first_id( @@ -730,9 +744,10 @@ class TestConversationServiceMessageCreation: assert len(result.data) == limit # Extra message should be removed assert result.has_more is True # Flag should be set + @patch("services.message_service._create_execution_extra_content_repository") @patch("services.message_service.db.session") @patch("services.message_service.ConversationService.get_conversation") - def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session): + def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session, mock_create_extra_repo): """ Test message pagination with ascending order. @@ -761,6 +776,9 @@ class TestConversationServiceMessageCreation: mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining mock_query.limit.return_value = mock_query # LIMIT returns self for chaining mock_query.all.return_value = messages # Final .all() returns the messages + mock_repository = MagicMock() + mock_repository.get_by_message_ids.return_value = [[] for _ in messages] + mock_create_extra_repo.return_value = mock_repository # Act result = MessageService.pagination_by_first_id( diff --git a/api/tests/unit_tests/services/test_feature_service_human_input_email_delivery.py b/api/tests/unit_tests/services/test_feature_service_human_input_email_delivery.py new file mode 100644 index 0000000000..ab141a7b2d --- /dev/null +++ b/api/tests/unit_tests/services/test_feature_service_human_input_email_delivery.py @@ -0,0 +1,104 @@ +from dataclasses import dataclass + +import pytest + +from enums.cloud_plan import CloudPlan +from services import feature_service as feature_service_module +from services.feature_service import FeatureModel, FeatureService + + +@dataclass(frozen=True) +class HumanInputEmailDeliveryCase: + name: str + enterprise_enabled: bool + billing_enabled: bool + tenant_id: str | None + billing_feature_enabled: bool + plan: str + expected: bool + + +CASES = [ + HumanInputEmailDeliveryCase( + name="enterprise_enabled", + enterprise_enabled=True, + billing_enabled=True, + tenant_id=None, + billing_feature_enabled=False, + plan=CloudPlan.SANDBOX, + expected=True, + ), + HumanInputEmailDeliveryCase( + name="billing_disabled", + enterprise_enabled=False, + billing_enabled=False, + tenant_id=None, + billing_feature_enabled=False, + plan=CloudPlan.SANDBOX, + expected=True, + ), + HumanInputEmailDeliveryCase( + name="billing_enabled_requires_tenant", + enterprise_enabled=False, + billing_enabled=True, + tenant_id=None, + billing_feature_enabled=True, + plan=CloudPlan.PROFESSIONAL, + expected=False, + ), + HumanInputEmailDeliveryCase( + name="billing_feature_off", + enterprise_enabled=False, + billing_enabled=True, + tenant_id="tenant-1", + billing_feature_enabled=False, + plan=CloudPlan.PROFESSIONAL, + expected=False, + ), + HumanInputEmailDeliveryCase( + name="professional_plan", + enterprise_enabled=False, + billing_enabled=True, + tenant_id="tenant-1", + billing_feature_enabled=True, + plan=CloudPlan.PROFESSIONAL, + expected=True, + ), + HumanInputEmailDeliveryCase( + name="team_plan", + enterprise_enabled=False, + billing_enabled=True, + tenant_id="tenant-1", + billing_feature_enabled=True, + plan=CloudPlan.TEAM, + expected=True, + ), + HumanInputEmailDeliveryCase( + name="sandbox_plan", + enterprise_enabled=False, + billing_enabled=True, + tenant_id="tenant-1", + billing_feature_enabled=True, + plan=CloudPlan.SANDBOX, + expected=False, + ), +] + + +@pytest.mark.parametrize("case", CASES, ids=lambda case: case.name) +def test_resolve_human_input_email_delivery_enabled_matrix( + monkeypatch: pytest.MonkeyPatch, + case: HumanInputEmailDeliveryCase, +): + monkeypatch.setattr(feature_service_module.dify_config, "ENTERPRISE_ENABLED", case.enterprise_enabled) + monkeypatch.setattr(feature_service_module.dify_config, "BILLING_ENABLED", case.billing_enabled) + features = FeatureModel() + features.billing.enabled = case.billing_feature_enabled + features.billing.subscription.plan = case.plan + + result = FeatureService._resolve_human_input_email_delivery_enabled( + features=features, + tenant_id=case.tenant_id, + ) + + assert result is case.expected diff --git a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py new file mode 100644 index 0000000000..e0d6ad1b39 --- /dev/null +++ b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py @@ -0,0 +1,97 @@ +from types import SimpleNamespace + +import pytest + +from core.workflow.nodes.human_input.entities import ( + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + ExternalRecipient, +) +from core.workflow.runtime import VariablePool +from services import human_input_delivery_test_service as service_module +from services.human_input_delivery_test_service import ( + DeliveryTestContext, + DeliveryTestError, + EmailDeliveryTestHandler, +) + + +def _make_email_method() -> EmailDeliveryMethod: + return EmailDeliveryMethod( + config=EmailDeliveryConfig( + recipients=EmailRecipients( + whole_workspace=False, + items=[ExternalRecipient(email="tester@example.com")], + ), + subject="Test subject", + body="Test body", + ) + ) + + +def test_email_delivery_test_handler_rejects_when_feature_disabled(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr( + service_module.FeatureService, + "get_features", + lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False), + ) + + handler = EmailDeliveryTestHandler(session_factory=object()) + context = DeliveryTestContext( + tenant_id="tenant-1", + app_id="app-1", + node_id="node-1", + node_title="Human Input", + rendered_content="content", + ) + method = _make_email_method() + + with pytest.raises(DeliveryTestError, match="Email delivery is not available"): + handler.send_test(context=context, method=method) + + +def test_email_delivery_test_handler_replaces_body_variables(monkeypatch: pytest.MonkeyPatch): + class DummyMail: + def __init__(self): + self.sent: list[dict[str, str]] = [] + + def is_inited(self) -> bool: + return True + + def send(self, *, to: str, subject: str, html: str): + self.sent.append({"to": to, "subject": subject, "html": html}) + + mail = DummyMail() + monkeypatch.setattr(service_module, "mail", mail) + monkeypatch.setattr(service_module, "render_email_template", lambda template, _substitutions: template) + monkeypatch.setattr( + service_module.FeatureService, + "get_features", + lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True), + ) + + handler = EmailDeliveryTestHandler(session_factory=object()) + handler._resolve_recipients = lambda **_kwargs: ["tester@example.com"] # type: ignore[assignment] + + method = EmailDeliveryMethod( + config=EmailDeliveryConfig( + recipients=EmailRecipients(whole_workspace=False, items=[ExternalRecipient(email="tester@example.com")]), + subject="Subject", + body="Value {{#node1.value#}}", + ) + ) + variable_pool = VariablePool() + variable_pool.add(["node1", "value"], "OK") + context = DeliveryTestContext( + tenant_id="tenant-1", + app_id="app-1", + node_id="node-1", + node_title="Human Input", + rendered_content="content", + variable_pool=variable_pool, + ) + + handler.send_test(context=context, method=method) + + assert mail.sent[0]["html"] == "Value OK" diff --git a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py new file mode 100644 index 0000000000..d2cf74daf3 --- /dev/null +++ b/api/tests/unit_tests/services/test_human_input_service.py @@ -0,0 +1,290 @@ +import dataclasses +from datetime import datetime, timedelta +from unittest.mock import MagicMock + +import pytest + +import services.human_input_service as human_input_service_module +from core.repositories.human_input_repository import ( + HumanInputFormRecord, + HumanInputFormSubmissionRepository, +) +from core.workflow.nodes.human_input.entities import ( + FormDefinition, + FormInput, + UserAction, +) +from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus +from models.human_input import RecipientType +from services.human_input_service import Form, FormExpiredError, HumanInputService, InvalidFormDataError +from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE + + +@pytest.fixture +def mock_session_factory(): + session = MagicMock() + session_cm = MagicMock() + session_cm.__enter__.return_value = session + session_cm.__exit__.return_value = None + + factory = MagicMock() + factory.return_value = session_cm + return factory, session + + +@pytest.fixture +def sample_form_record(): + return HumanInputFormRecord( + form_id="form-id", + workflow_run_id="workflow-run-id", + node_id="node-id", + tenant_id="tenant-id", + app_id="app-id", + form_kind=HumanInputFormKind.RUNTIME, + definition=FormDefinition( + form_content="hello", + inputs=[], + user_actions=[UserAction(id="submit", title="Submit")], + rendered_content="hello
", + expiration_time=datetime.utcnow() + timedelta(hours=1), + ), + rendered_content="hello
", + created_at=datetime.utcnow(), + expiration_time=datetime.utcnow() + timedelta(hours=1), + status=HumanInputFormStatus.WAITING, + selected_action_id=None, + submitted_data=None, + submitted_at=None, + submission_user_id=None, + submission_end_user_id=None, + completed_by_recipient_id=None, + recipient_id="recipient-id", + recipient_type=RecipientType.STANDALONE_WEB_APP, + access_token="token", + ) + + +def test_enqueue_resume_dispatches_task_for_workflow(mocker, mock_session_factory): + session_factory, session = mock_session_factory + service = HumanInputService(session_factory) + + workflow_run = MagicMock() + workflow_run.app_id = "app-id" + + workflow_run_repo = MagicMock() + workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run + mocker.patch( + "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=workflow_run_repo, + ) + + app = MagicMock() + app.mode = "workflow" + session.execute.return_value.scalar_one_or_none.return_value = app + + resume_task = mocker.patch("services.human_input_service.resume_app_execution") + + service.enqueue_resume("workflow-run-id") + + resume_task.apply_async.assert_called_once() + call_kwargs = resume_task.apply_async.call_args.kwargs + assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE + assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id" + + +def test_ensure_form_active_respects_global_timeout(monkeypatch, sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + service = HumanInputService(session_factory) + expired_record = dataclasses.replace( + sample_form_record, + created_at=datetime.utcnow() - timedelta(hours=2), + expiration_time=datetime.utcnow() + timedelta(hours=2), + ) + monkeypatch.setattr(human_input_service_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 3600) + + with pytest.raises(FormExpiredError): + service.ensure_form_active(Form(expired_record)) + + +def test_enqueue_resume_dispatches_task_for_advanced_chat(mocker, mock_session_factory): + session_factory, session = mock_session_factory + service = HumanInputService(session_factory) + + workflow_run = MagicMock() + workflow_run.app_id = "app-id" + + workflow_run_repo = MagicMock() + workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run + mocker.patch( + "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=workflow_run_repo, + ) + + app = MagicMock() + app.mode = "advanced-chat" + session.execute.return_value.scalar_one_or_none.return_value = app + + resume_task = mocker.patch("services.human_input_service.resume_app_execution") + + service.enqueue_resume("workflow-run-id") + + resume_task.apply_async.assert_called_once() + call_kwargs = resume_task.apply_async.call_args.kwargs + assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE + assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id" + + +def test_enqueue_resume_skips_unsupported_app_mode(mocker, mock_session_factory): + session_factory, session = mock_session_factory + service = HumanInputService(session_factory) + + workflow_run = MagicMock() + workflow_run.app_id = "app-id" + + workflow_run_repo = MagicMock() + workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run + mocker.patch( + "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=workflow_run_repo, + ) + + app = MagicMock() + app.mode = "completion" + session.execute.return_value.scalar_one_or_none.return_value = app + + resume_task = mocker.patch("services.human_input_service.resume_app_execution") + + service.enqueue_resume("workflow-run-id") + + resume_task.apply_async.assert_not_called() + + +def test_get_form_definition_by_token_for_console_uses_repository(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + console_record = dataclasses.replace(sample_form_record, recipient_type=RecipientType.CONSOLE) + repo.get_by_token.return_value = console_record + + service = HumanInputService(session_factory, form_repository=repo) + form = service.get_form_definition_by_token_for_console("token") + + repo.get_by_token.assert_called_once_with("token") + assert form is not None + assert form.get_definition() == console_record.definition + + +def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, mock_session_factory, mocker): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = sample_form_record + repo.mark_submitted.return_value = sample_form_record + service = HumanInputService(session_factory, form_repository=repo) + enqueue_spy = mocker.patch.object(service, "enqueue_resume") + + service.submit_form_by_token( + recipient_type=RecipientType.STANDALONE_WEB_APP, + form_token="token", + selected_action_id="submit", + form_data={"field": "value"}, + submission_end_user_id="end-user-id", + ) + + repo.get_by_token.assert_called_once_with("token") + repo.mark_submitted.assert_called_once() + call_kwargs = repo.mark_submitted.call_args.kwargs + assert call_kwargs["form_id"] == sample_form_record.form_id + assert call_kwargs["recipient_id"] == sample_form_record.recipient_id + assert call_kwargs["selected_action_id"] == "submit" + assert call_kwargs["form_data"] == {"field": "value"} + assert call_kwargs["submission_end_user_id"] == "end-user-id" + enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id) + + +def test_submit_form_by_token_skips_enqueue_for_delivery_test(sample_form_record, mock_session_factory, mocker): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + test_record = dataclasses.replace( + sample_form_record, + form_kind=HumanInputFormKind.DELIVERY_TEST, + workflow_run_id=None, + ) + repo.get_by_token.return_value = test_record + repo.mark_submitted.return_value = test_record + service = HumanInputService(session_factory, form_repository=repo) + enqueue_spy = mocker.patch.object(service, "enqueue_resume") + + service.submit_form_by_token( + recipient_type=RecipientType.STANDALONE_WEB_APP, + form_token="token", + selected_action_id="submit", + form_data={"field": "value"}, + ) + + enqueue_spy.assert_not_called() + + +def test_submit_form_by_token_passes_submission_user_id(sample_form_record, mock_session_factory, mocker): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = sample_form_record + repo.mark_submitted.return_value = sample_form_record + service = HumanInputService(session_factory, form_repository=repo) + enqueue_spy = mocker.patch.object(service, "enqueue_resume") + + service.submit_form_by_token( + recipient_type=RecipientType.STANDALONE_WEB_APP, + form_token="token", + selected_action_id="submit", + form_data={"field": "value"}, + submission_user_id="account-id", + ) + + call_kwargs = repo.mark_submitted.call_args.kwargs + assert call_kwargs["submission_user_id"] == "account-id" + assert call_kwargs["submission_end_user_id"] is None + enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id) + + +def test_submit_form_by_token_invalid_action(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = dataclasses.replace(sample_form_record) + service = HumanInputService(session_factory, form_repository=repo) + + with pytest.raises(InvalidFormDataError) as exc_info: + service.submit_form_by_token( + recipient_type=RecipientType.STANDALONE_WEB_APP, + form_token="token", + selected_action_id="invalid", + form_data={}, + ) + + assert "Invalid action" in str(exc_info.value) + repo.mark_submitted.assert_not_called() + + +def test_submit_form_by_token_missing_inputs(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + + definition_with_input = FormDefinition( + form_content="hello", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content")], + user_actions=sample_form_record.definition.user_actions, + rendered_content="hello
", + expiration_time=sample_form_record.expiration_time, + ) + form_with_input = dataclasses.replace(sample_form_record, definition=definition_with_input) + repo.get_by_token.return_value = form_with_input + service = HumanInputService(session_factory, form_repository=repo) + + with pytest.raises(InvalidFormDataError) as exc_info: + service.submit_form_by_token( + recipient_type=RecipientType.STANDALONE_WEB_APP, + form_token="token", + selected_action_id="submit", + form_data={}, + ) + + assert "Missing required inputs" in str(exc_info.value) + repo.mark_submitted.assert_not_called() diff --git a/api/tests/unit_tests/services/test_message_service_extra_contents.py b/api/tests/unit_tests/services/test_message_service_extra_contents.py new file mode 100644 index 0000000000..3c8e301caa --- /dev/null +++ b/api/tests/unit_tests/services/test_message_service_extra_contents.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import pytest + +from core.entities.execution_extra_content import HumanInputContent, HumanInputFormSubmissionData +from services import message_service + + +class _FakeMessage: + def __init__(self, message_id: str): + self.id = message_id + self.extra_contents = None + + def set_extra_contents(self, contents): + self.extra_contents = contents + + +def test_attach_message_extra_contents_assigns_serialized_payload(monkeypatch: pytest.MonkeyPatch) -> None: + messages = [_FakeMessage("msg-1"), _FakeMessage("msg-2")] + repo = type( + "Repo", + (), + { + "get_by_message_ids": lambda _self, message_ids: [ + [ + HumanInputContent( + workflow_run_id="workflow-run-1", + submitted=True, + form_submission_data=HumanInputFormSubmissionData( + node_id="node-1", + node_title="Approval", + rendered_content="Rendered", + action_id="approve", + action_text="Approve", + ), + ) + ], + [], + ] + }, + )() + + monkeypatch.setattr(message_service, "_create_execution_extra_content_repository", lambda: repo) + + message_service.attach_message_extra_contents(messages) + + assert messages[0].extra_contents == [ + { + "type": "human_input", + "workflow_run_id": "workflow-run-1", + "submitted": True, + "form_submission_data": { + "node_id": "node-1", + "node_title": "Approval", + "rendered_content": "Rendered", + "action_id": "approve", + "action_text": "Approve", + }, + } + ] + assert messages[1].extra_contents == [] diff --git a/api/tests/unit_tests/services/test_workflow_run_service_pause.py b/api/tests/unit_tests/services/test_workflow_run_service_pause.py index f45a72927e..ded141f01a 100644 --- a/api/tests/unit_tests/services/test_workflow_run_service_pause.py +++ b/api/tests/unit_tests/services/test_workflow_run_service_pause.py @@ -35,7 +35,6 @@ class TestDataFactory: app_id: str = "app-789", workflow_id: str = "workflow-101", status: str | WorkflowExecutionStatus = "paused", - pause_id: str | None = None, **kwargs, ) -> MagicMock: """Create a mock WorkflowRun object.""" @@ -45,7 +44,6 @@ class TestDataFactory: mock_run.app_id = app_id mock_run.workflow_id = workflow_id mock_run.status = status - mock_run.pause_id = pause_id for key, value in kwargs.items(): setattr(mock_run, key, value) diff --git a/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py new file mode 100644 index 0000000000..ae59da0a3d --- /dev/null +++ b/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py @@ -0,0 +1,162 @@ +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration +from core.tools.errors import WorkflowToolHumanInputNotSupportedError +from models.model import App +from models.tools import WorkflowToolProvider +from services.tools import workflow_tools_manage_service + + +class DummyWorkflow: + def __init__(self, graph_dict: dict, version: str = "1.0.0") -> None: + self._graph_dict = graph_dict + self.version = version + + @property + def graph_dict(self) -> dict: + return self._graph_dict + + +class FakeQuery: + def __init__(self, result): + self._result = result + + def where(self, *args, **kwargs): + return self + + def first(self): + return self._result + + +class DummySession: + def __init__(self) -> None: + self.added: list[object] = [] + + def __enter__(self) -> "DummySession": + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def add(self, obj) -> None: + self.added.append(obj) + + def begin(self): + return DummyBegin(self) + + +class DummyBegin: + def __init__(self, session: DummySession) -> None: + self._session = session + + def __enter__(self) -> DummySession: + return self._session + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + +class DummySessionContext: + def __init__(self, session: DummySession) -> None: + self._session = session + + def __enter__(self) -> DummySession: + return self._session + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + +class DummySessionFactory: + def __init__(self, session: DummySession) -> None: + self._session = session + + def create_session(self) -> DummySessionContext: + return DummySessionContext(self._session) + + +def _build_fake_session(app) -> SimpleNamespace: + def query(model): + if model is WorkflowToolProvider: + return FakeQuery(None) + if model is App: + return FakeQuery(app) + return FakeQuery(None) + + return SimpleNamespace(query=query) + + +def _build_parameters() -> list[WorkflowToolParameterConfiguration]: + return [ + WorkflowToolParameterConfiguration(name="input", description="input", form=ToolParameter.ToolParameterForm.LLM), + ] + + +def test_create_workflow_tool_rejects_human_input_nodes(monkeypatch): + workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "human-input"}}]}) + app = SimpleNamespace(workflow=workflow) + + fake_session = _build_fake_session(app) + monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session) + + mock_from_db = MagicMock() + monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) + mock_invalidate = MagicMock() + + with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: + workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool( + user_id="user-id", + tenant_id="tenant-id", + workflow_app_id="app-id", + name="tool_name", + label="Tool", + icon={"type": "emoji", "emoji": "tool"}, + description="desc", + parameters=_build_parameters(), + ) + + assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" + mock_from_db.assert_not_called() + mock_invalidate.assert_not_called() + + +def test_create_workflow_tool_success(monkeypatch): + workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "start"}}]}) + app = SimpleNamespace(workflow=workflow) + + fake_db = MagicMock() + fake_session = _build_fake_session(app) + fake_db.session = fake_session + monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) + + dummy_session = DummySession() + monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) + + mock_from_db = MagicMock() + monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) + + icon = {"type": "emoji", "emoji": "tool"} + + result = workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool( + user_id="user-id", + tenant_id="tenant-id", + workflow_app_id="app-id", + name="tool_name", + label="Tool", + icon=icon, + description="desc", + parameters=_build_parameters(), + ) + + assert result == {"result": "success"} + assert len(dummy_session.added) == 1 + created_provider = dummy_session.added[0] + assert created_provider.name == "tool_name" + assert created_provider.label == "Tool" + assert created_provider.icon == json.dumps(icon) + assert created_provider.version == workflow.version + mock_from_db.assert_called_once() diff --git a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py new file mode 100644 index 0000000000..844dab8976 --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py @@ -0,0 +1,226 @@ +from __future__ import annotations + +import json +import queue +from collections.abc import Sequence +from dataclasses import dataclass +from datetime import UTC, datetime +from threading import Event + +import pytest + +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper +from core.workflow.entities.pause_reason import HumanInputRequired +from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from core.workflow.runtime import GraphRuntimeState, VariablePool +from models.enums import CreatorUserRole +from models.model import AppMode +from models.workflow import WorkflowRun +from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot +from repositories.entities.workflow_pause import WorkflowPauseEntity +from services.workflow_event_snapshot_service import ( + BufferState, + MessageContext, + _build_snapshot_events, + _resolve_task_id, +) + + +@dataclass(frozen=True) +class _FakePauseEntity(WorkflowPauseEntity): + pause_id: str + workflow_run_id: str + paused_at_value: datetime + pause_reasons: Sequence[HumanInputRequired] + + @property + def id(self) -> str: + return self.pause_id + + @property + def workflow_execution_id(self) -> str: + return self.workflow_run_id + + def get_state(self) -> bytes: + raise AssertionError("state is not required for snapshot tests") + + @property + def resumed_at(self) -> datetime | None: + return None + + @property + def paused_at(self) -> datetime: + return self.paused_at_value + + def get_pause_reasons(self) -> Sequence[HumanInputRequired]: + return self.pause_reasons + + +def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun: + return WorkflowRun( + id="run-1", + tenant_id="tenant-1", + app_id="app-1", + workflow_id="workflow-1", + type="workflow", + triggered_from="app-run", + version="v1", + graph=None, + inputs=json.dumps({"input": "value"}), + status=status, + outputs=json.dumps({}), + error=None, + elapsed_time=0.0, + total_tokens=0, + total_steps=0, + created_by_role=CreatorUserRole.END_USER, + created_by="user-1", + created_at=datetime(2024, 1, 1, tzinfo=UTC), + ) + + +def _build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot: + created_at = datetime(2024, 1, 1, tzinfo=UTC) + finished_at = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC) + return WorkflowNodeExecutionSnapshot( + execution_id="exec-1", + node_id="node-1", + node_type="human-input", + title="Human Input", + index=1, + status=status.value, + elapsed_time=0.5, + created_at=created_at, + finished_at=finished_at, + iteration_id=None, + loop_id=None, + ) + + +def _build_resumption_context(task_id: str) -> WorkflowResumptionContext: + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant-1", + app_id="app-1", + app_mode=AppMode.WORKFLOW, + workflow_id="workflow-1", + ) + generate_entity = WorkflowAppGenerateEntity( + task_id=task_id, + app_config=app_config, + inputs={}, + files=[], + user_id="user-1", + stream=True, + invoke_from=InvokeFrom.EXPLORE, + call_depth=0, + workflow_execution_id="run-1", + ) + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) + runtime_state.register_paused_node("node-1") + runtime_state.outputs = {"result": "value"} + wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity) + return WorkflowResumptionContext( + generate_entity=wrapper, + serialized_graph_runtime_state=runtime_state.dumps(), + ) + + +def test_build_snapshot_events_includes_pause_event() -> None: + workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED) + snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED) + resumption_context = _build_resumption_context("task-ctx") + pause_entity = _FakePauseEntity( + pause_id="pause-1", + workflow_run_id="run-1", + paused_at_value=datetime(2024, 1, 1, tzinfo=UTC), + pause_reasons=[ + HumanInputRequired( + form_id="form-1", + form_content="content", + node_id="node-1", + node_title="Human Input", + ) + ], + ) + + events = _build_snapshot_events( + workflow_run=workflow_run, + node_snapshots=[snapshot], + task_id="task-ctx", + message_context=None, + pause_entity=pause_entity, + resumption_context=resumption_context, + ) + + assert [event["event"] for event in events] == [ + "workflow_started", + "node_started", + "node_finished", + "workflow_paused", + ] + assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value + pause_data = events[-1]["data"] + assert pause_data["paused_nodes"] == ["node-1"] + assert pause_data["outputs"] == {"result": "value"} + assert pause_data["status"] == WorkflowExecutionStatus.PAUSED.value + assert pause_data["created_at"] == int(workflow_run.created_at.timestamp()) + assert pause_data["elapsed_time"] == workflow_run.elapsed_time + assert pause_data["total_tokens"] == workflow_run.total_tokens + assert pause_data["total_steps"] == workflow_run.total_steps + + +def test_build_snapshot_events_applies_message_context() -> None: + workflow_run = _build_workflow_run(WorkflowExecutionStatus.RUNNING) + snapshot = _build_snapshot(WorkflowNodeExecutionStatus.SUCCEEDED) + message_context = MessageContext( + conversation_id="conv-1", + message_id="msg-1", + created_at=1700000000, + answer="snapshot message", + ) + + events = _build_snapshot_events( + workflow_run=workflow_run, + node_snapshots=[snapshot], + task_id="task-1", + message_context=message_context, + pause_entity=None, + resumption_context=None, + ) + + assert [event["event"] for event in events] == [ + "workflow_started", + "message_replace", + "node_started", + "node_finished", + ] + assert events[1]["answer"] == "snapshot message" + for event in events: + assert event["conversation_id"] == "conv-1" + assert event["message_id"] == "msg-1" + assert event["created_at"] == 1700000000 + + +@pytest.mark.parametrize( + ("context_task_id", "buffered_task_id", "expected"), + [ + ("task-ctx", "task-buffer", "task-ctx"), + (None, "task-buffer", "task-buffer"), + (None, None, "run-1"), + ], +) +def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) -> None: + resumption_context = _build_resumption_context(context_task_id) if context_task_id else None + buffer_state = BufferState( + queue=queue.Queue(), + stop_event=Event(), + done_event=Event(), + task_id_ready=Event(), + task_id_hint=buffered_task_id, + ) + if buffered_task_id: + buffer_state.task_id_ready.set() + task_id = _resolve_task_id(resumption_context, buffer_state, "run-1", wait_timeout=0.0) + assert task_id == expected diff --git a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py new file mode 100644 index 0000000000..5ac5ac8ad2 --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py @@ -0,0 +1,184 @@ +import uuid +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from sqlalchemy.orm import sessionmaker + +from core.workflow.enums import NodeType +from core.workflow.nodes.human_input.entities import ( + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + ExternalRecipient, + HumanInputNodeData, + MemberRecipient, +) +from services import workflow_service as workflow_service_module +from services.workflow_service import WorkflowService + + +def _make_service() -> WorkflowService: + return WorkflowService(session_maker=sessionmaker()) + + +def _build_node_config(delivery_methods): + node_data = HumanInputNodeData( + title="Human Input", + delivery_methods=delivery_methods, + form_content="Test content", + inputs=[], + user_actions=[], + ).model_dump(mode="json") + node_data["type"] = NodeType.HUMAN_INPUT.value + return {"id": "node-1", "data": node_data} + + +def _make_email_method(enabled: bool = True, debug_mode: bool = False) -> EmailDeliveryMethod: + return EmailDeliveryMethod( + id=uuid.uuid4(), + enabled=enabled, + config=EmailDeliveryConfig( + recipients=EmailRecipients( + whole_workspace=False, + items=[ExternalRecipient(email="tester@example.com")], + ), + subject="Test subject", + body="Test body", + debug_mode=debug_mode, + ), + ) + + +def test_human_input_delivery_requires_draft_workflow(): + service = _make_service() + service.get_draft_workflow = MagicMock(return_value=None) # type: ignore[method-assign] + app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1") + account = SimpleNamespace(id="account-1") + + with pytest.raises(ValueError, match="Workflow not initialized"): + service.test_human_input_delivery( + app_model=app_model, + account=account, + node_id="node-1", + delivery_method_id="delivery-1", + ) + + +def test_human_input_delivery_allows_disabled_method(monkeypatch: pytest.MonkeyPatch): + service = _make_service() + delivery_method = _make_email_method(enabled=False) + node_config = _build_node_config([delivery_method]) + workflow = MagicMock() + workflow.get_node_config_by_id.return_value = node_config + service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] + service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined] + node_stub = MagicMock() + node_stub._render_form_content_before_submission.return_value = "rendered" + node_stub._resolve_default_values.return_value = {} + service._build_human_input_node = MagicMock(return_value=node_stub) # type: ignore[attr-defined] + service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined] + return_value=("form-1", {}) + ) + + test_service_instance = MagicMock() + monkeypatch.setattr( + workflow_service_module, + "HumanInputDeliveryTestService", + MagicMock(return_value=test_service_instance), + ) + + app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1") + account = SimpleNamespace(id="account-1") + + service.test_human_input_delivery( + app_model=app_model, + account=account, + node_id="node-1", + delivery_method_id=str(delivery_method.id), + ) + + test_service_instance.send_test.assert_called_once() + + +def test_human_input_delivery_dispatches_to_test_service(monkeypatch: pytest.MonkeyPatch): + service = _make_service() + delivery_method = _make_email_method(enabled=True) + node_config = _build_node_config([delivery_method]) + workflow = MagicMock() + workflow.get_node_config_by_id.return_value = node_config + service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] + service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined] + node_stub = MagicMock() + node_stub._render_form_content_before_submission.return_value = "rendered" + node_stub._resolve_default_values.return_value = {} + service._build_human_input_node = MagicMock(return_value=node_stub) # type: ignore[attr-defined] + service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined] + return_value=("form-1", {}) + ) + + test_service_instance = MagicMock() + monkeypatch.setattr( + workflow_service_module, + "HumanInputDeliveryTestService", + MagicMock(return_value=test_service_instance), + ) + + app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1") + account = SimpleNamespace(id="account-1") + + service.test_human_input_delivery( + app_model=app_model, + account=account, + node_id="node-1", + delivery_method_id=str(delivery_method.id), + inputs={"#node-1.output#": "value"}, + ) + + pool_args = service._build_human_input_variable_pool.call_args.kwargs + assert pool_args["manual_inputs"] == {"#node-1.output#": "value"} + test_service_instance.send_test.assert_called_once() + + +def test_human_input_delivery_debug_mode_overrides_recipients(monkeypatch: pytest.MonkeyPatch): + service = _make_service() + delivery_method = _make_email_method(enabled=True, debug_mode=True) + node_config = _build_node_config([delivery_method]) + workflow = MagicMock() + workflow.get_node_config_by_id.return_value = node_config + service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] + service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined] + node_stub = MagicMock() + node_stub._render_form_content_before_submission.return_value = "rendered" + node_stub._resolve_default_values.return_value = {} + service._build_human_input_node = MagicMock(return_value=node_stub) # type: ignore[attr-defined] + service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined] + return_value=("form-1", {}) + ) + + test_service_instance = MagicMock() + monkeypatch.setattr( + workflow_service_module, + "HumanInputDeliveryTestService", + MagicMock(return_value=test_service_instance), + ) + + app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1") + account = SimpleNamespace(id="account-1") + + service.test_human_input_delivery( + app_model=app_model, + account=account, + node_id="node-1", + delivery_method_id=str(delivery_method.id), + ) + + test_service_instance.send_test.assert_called_once() + sent_method = test_service_instance.send_test.call_args.kwargs["method"] + assert isinstance(sent_method, EmailDeliveryMethod) + assert sent_method.config.debug_mode is True + assert sent_method.config.recipients.whole_workspace is False + assert len(sent_method.config.recipients.items) == 1 + recipient = sent_method.config.recipients.items[0] + assert isinstance(recipient, MemberRecipient) + assert recipient.user_id == account.id diff --git a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py index 32d2f8b7e0..70d7bde870 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -5,6 +5,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from core.workflow.enums import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionModel from repositories.sqlalchemy_api_workflow_node_execution_repository import ( DifyAPISQLAlchemyWorkflowNodeExecutionRepository, @@ -52,6 +53,9 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: call_args = mock_session.scalar.call_args[0][0] assert hasattr(call_args, "compile") # It's a SQLAlchemy statement + compiled = call_args.compile() + assert WorkflowNodeExecutionStatus.PAUSED in compiled.params.values() + def test_get_node_last_execution_not_found(self, repository): """Test getting the last execution for a node when it doesn't exist.""" # Arrange @@ -71,28 +75,6 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert result is None mock_session.scalar.assert_called_once() - def test_get_executions_by_workflow_run(self, repository, mock_execution): - """Test getting all executions for a workflow run.""" - # Arrange - mock_session = MagicMock(spec=Session) - repository._session_maker.return_value.__enter__.return_value = mock_session - executions = [mock_execution] - mock_session.execute.return_value.scalars.return_value.all.return_value = executions - - # Act - result = repository.get_executions_by_workflow_run( - tenant_id="tenant-123", - app_id="app-456", - workflow_run_id="run-101", - ) - - # Assert - assert result == executions - mock_session.execute.assert_called_once() - # Verify the query was constructed correctly - call_args = mock_session.execute.call_args[0][0] - assert hasattr(call_args, "compile") # It's a SQLAlchemy statement - def test_get_executions_by_workflow_run_empty(self, repository): """Test getting executions for a workflow run when none exist.""" # Arrange diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py index 9700cbaf0e..015dac257e 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py @@ -1,9 +1,15 @@ +from contextlib import nullcontext +from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from core.workflow.enums import NodeType +from core.workflow.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction +from core.workflow.nodes.human_input.enums import FormInputType from models.model import App from models.workflow import Workflow +from services import workflow_service as workflow_service_module from services.workflow_service import WorkflowService @@ -161,3 +167,120 @@ class TestWorkflowService: assert workflows == [] assert has_more is False mock_session.scalars.assert_called_once() + + def test_submit_human_input_form_preview_uses_rendered_content( + self, workflow_service: WorkflowService, monkeypatch: pytest.MonkeyPatch + ) -> None: + service = workflow_service + node_data = HumanInputNodeData( + title="Human Input", + form_content="{{#$output.name#}}
", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], + user_actions=[UserAction(id="approve", title="Approve")], + ) + node = MagicMock() + node.node_data = node_data + node.render_form_content_before_submission.return_value = "preview
" + node.render_form_content_with_outputs.return_value = "rendered
" + + service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] + + workflow = MagicMock() + workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} + workflow.get_enclosing_node_type_and_id.return_value = None + service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] + + saved_outputs: dict[str, object] = {} + + class DummySession: + def __init__(self, *args, **kwargs): + self.commit = MagicMock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def begin(self): + return nullcontext() + + class DummySaver: + def __init__(self, *args, **kwargs): + pass + + def save(self, outputs, process_data): + saved_outputs.update(outputs) + + monkeypatch.setattr(workflow_service_module, "Session", DummySession) + monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", DummySaver) + monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) + + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + account = SimpleNamespace(id="account-1") + + result = service.submit_human_input_form_preview( + app_model=app_model, + account=account, + node_id="node-1", + form_inputs={"name": "Ada", "extra": "ignored"}, + inputs={"#node-0.result#": "LLM output"}, + action="approve", + ) + + service._build_human_input_variable_pool.assert_called_once_with( + app_model=app_model, + workflow=workflow, + node_config={"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}, + manual_inputs={"#node-0.result#": "LLM output"}, + ) + + node.render_form_content_with_outputs.assert_called_once() + called_args = node.render_form_content_with_outputs.call_args.args + assert called_args[0] == "preview
" + assert called_args[2] == node_data.outputs_field_names() + rendered_outputs = called_args[1] + assert rendered_outputs["name"] == "Ada" + assert rendered_outputs["extra"] == "ignored" + assert "extra" in saved_outputs + assert "extra" in result + assert saved_outputs["name"] == "Ada" + assert result["name"] == "Ada" + assert result["__action_id"] == "approve" + assert "__rendered_content" in result + + def test_submit_human_input_form_preview_missing_inputs_message(self, workflow_service: WorkflowService) -> None: + service = workflow_service + node_data = HumanInputNodeData( + title="Human Input", + form_content="{{#$output.name#}}
", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], + user_actions=[UserAction(id="approve", title="Approve")], + ) + node = MagicMock() + node.node_data = node_data + node._render_form_content_before_submission.return_value = "preview
" + node._render_form_content_with_outputs.return_value = "rendered
" + + service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] + + workflow = MagicMock() + workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} + service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] + + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + account = SimpleNamespace(id="account-1") + + with pytest.raises(ValueError) as exc_info: + service.submit_human_input_form_preview( + app_model=app_model, + account=account, + node_id="node-1", + form_inputs={}, + inputs={}, + action="approve", + ) + + assert "Missing required inputs" in str(exc_info.value) diff --git a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py new file mode 100644 index 0000000000..ee0699ba2d --- /dev/null +++ b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +from datetime import datetime, timedelta +from types import SimpleNamespace +from typing import Any + +import pytest + +from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from tasks import human_input_timeout_tasks as task_module + + +class _FakeScalarResult: + def __init__(self, items: list[Any]): + self._items = items + + def all(self) -> list[Any]: + return self._items + + +class _FakeSession: + def __init__(self, items: list[Any], capture: dict[str, Any]): + self._items = items + self._capture = capture + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, stmt): + self._capture["stmt"] = stmt + return _FakeScalarResult(self._items) + + +class _FakeSessionFactory: + def __init__(self, items: list[Any], capture: dict[str, Any]): + self._items = items + self._capture = capture + self._capture["session_factory"] = self + + def __call__(self): + session = _FakeSession(self._items, self._capture) + self._capture["session"] = session + return session + + +class _FakeFormRepo: + def __init__(self, _session_factory, form_map: dict[str, Any] | None = None): + self.calls: list[dict[str, Any]] = [] + self._form_map = form_map or {} + + def mark_timeout(self, *, form_id: str, timeout_status: HumanInputFormStatus, reason: str | None = None): + self.calls.append( + { + "form_id": form_id, + "timeout_status": timeout_status, + "reason": reason, + } + ) + form = self._form_map.get(form_id) + return SimpleNamespace( + form_id=form_id, + workflow_run_id=getattr(form, "workflow_run_id", None), + node_id=getattr(form, "node_id", None), + ) + + +class _FakeService: + def __init__(self, _session_factory, form_repository=None): + self.enqueued: list[str] = [] + + def enqueue_resume(self, workflow_run_id: str | None) -> None: + if workflow_run_id is not None: + self.enqueued.append(workflow_run_id) + + +def _build_form( + *, + form_id: str, + form_kind: HumanInputFormKind, + created_at: datetime, + expiration_time: datetime, + workflow_run_id: str | None, + node_id: str, +) -> SimpleNamespace: + return SimpleNamespace( + id=form_id, + form_kind=form_kind, + created_at=created_at, + expiration_time=expiration_time, + workflow_run_id=workflow_run_id, + node_id=node_id, + status=HumanInputFormStatus.WAITING, + ) + + +def test_is_global_timeout_uses_created_at(): + now = datetime(2025, 1, 1, 12, 0, 0) + form = SimpleNamespace(created_at=now - timedelta(seconds=61), workflow_run_id="run-1") + + assert task_module._is_global_timeout(form, 60, now=now) is True + + form.workflow_run_id = None + assert task_module._is_global_timeout(form, 60, now=now) is False + + form.workflow_run_id = "run-1" + form.created_at = now - timedelta(seconds=59) + assert task_module._is_global_timeout(form, 60, now=now) is False + + assert task_module._is_global_timeout(form, 0, now=now) is False + + +def test_check_and_handle_human_input_timeouts_marks_and_routes(monkeypatch: pytest.MonkeyPatch): + now = datetime(2025, 1, 1, 12, 0, 0) + monkeypatch.setattr(task_module, "naive_utc_now", lambda: now) + monkeypatch.setattr(task_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 3600) + monkeypatch.setattr(task_module, "db", SimpleNamespace(engine=object())) + + forms = [ + _build_form( + form_id="form-global", + form_kind=HumanInputFormKind.RUNTIME, + created_at=now - timedelta(hours=2), + expiration_time=now + timedelta(hours=1), + workflow_run_id="run-global", + node_id="node-global", + ), + _build_form( + form_id="form-node", + form_kind=HumanInputFormKind.RUNTIME, + created_at=now - timedelta(minutes=5), + expiration_time=now - timedelta(seconds=1), + workflow_run_id="run-node", + node_id="node-node", + ), + _build_form( + form_id="form-delivery", + form_kind=HumanInputFormKind.DELIVERY_TEST, + created_at=now - timedelta(minutes=1), + expiration_time=now - timedelta(seconds=1), + workflow_run_id=None, + node_id="node-delivery", + ), + ] + + capture: dict[str, Any] = {} + monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory(forms, capture)) + + form_map = {form.id: form for form in forms} + repo = _FakeFormRepo(None, form_map=form_map) + + def _repo_factory(_session_factory): + return repo + + service = _FakeService(None) + + def _service_factory(_session_factory, form_repository=None): + return service + + global_calls: list[dict[str, Any]] = [] + + monkeypatch.setattr(task_module, "HumanInputFormSubmissionRepository", _repo_factory) + monkeypatch.setattr(task_module, "HumanInputService", _service_factory) + monkeypatch.setattr(task_module, "_handle_global_timeout", lambda **kwargs: global_calls.append(kwargs)) + + task_module.check_and_handle_human_input_timeouts(limit=100) + + assert {(call["form_id"], call["timeout_status"], call["reason"]) for call in repo.calls} == { + ("form-global", HumanInputFormStatus.EXPIRED, "global_timeout"), + ("form-node", HumanInputFormStatus.TIMEOUT, "node_timeout"), + ("form-delivery", HumanInputFormStatus.TIMEOUT, "delivery_test_timeout"), + } + assert service.enqueued == ["run-node"] + assert global_calls == [ + { + "form_id": "form-global", + "workflow_run_id": "run-global", + "node_id": "node-global", + "session_factory": capture.get("session_factory"), + } + ] + + stmt = capture.get("stmt") + assert stmt is not None + stmt_text = str(stmt) + assert "created_at <=" in stmt_text + assert "expiration_time <=" in stmt_text + assert "ORDER BY human_input_forms.id" in stmt_text + + +def test_check_and_handle_human_input_timeouts_omits_global_filter_when_disabled(monkeypatch: pytest.MonkeyPatch): + now = datetime(2025, 1, 1, 12, 0, 0) + monkeypatch.setattr(task_module, "naive_utc_now", lambda: now) + monkeypatch.setattr(task_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 0) + monkeypatch.setattr(task_module, "db", SimpleNamespace(engine=object())) + + capture: dict[str, Any] = {} + monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory([], capture)) + monkeypatch.setattr(task_module, "HumanInputFormSubmissionRepository", _FakeFormRepo) + monkeypatch.setattr(task_module, "HumanInputService", _FakeService) + monkeypatch.setattr(task_module, "_handle_global_timeout", lambda **_kwargs: None) + + task_module.check_and_handle_human_input_timeouts(limit=1) + + stmt = capture.get("stmt") + assert stmt is not None + stmt_text = str(stmt) + assert "created_at <=" not in stmt_text diff --git a/api/tests/unit_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/unit_tests/tasks/test_mail_human_input_delivery_task.py new file mode 100644 index 0000000000..20cb7a211e --- /dev/null +++ b/api/tests/unit_tests/tasks/test_mail_human_input_delivery_task.py @@ -0,0 +1,123 @@ +from collections.abc import Sequence +from types import SimpleNamespace + +import pytest + +from tasks import mail_human_input_delivery_task as task_module + + +class _DummyMail: + def __init__(self): + self.sent: list[dict[str, str]] = [] + self._inited = True + + def is_inited(self) -> bool: + return self._inited + + def send(self, *, to: str, subject: str, html: str): + self.sent.append({"to": to, "subject": subject, "html": html}) + + +class _DummySession: + def __init__(self, form): + self._form = form + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False + + def get(self, _model, _form_id): + return self._form + + +def _build_job(recipient_count: int = 1) -> task_module._EmailDeliveryJob: + recipients: list[task_module._EmailRecipient] = [] + for idx in range(recipient_count): + recipients.append(task_module._EmailRecipient(email=f"user{idx}@example.com", token=f"token-{idx}")) + + return task_module._EmailDeliveryJob( + form_id="form-1", + subject="Subject", + body="Body for {{#url}}", + form_content="content", + recipients=recipients, + ) + + +def test_dispatch_human_input_email_task_sends_to_each_recipient(monkeypatch: pytest.MonkeyPatch): + mail = _DummyMail() + form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id=None) + + monkeypatch.setattr(task_module, "mail", mail) + monkeypatch.setattr( + task_module.FeatureService, + "get_features", + lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True), + ) + jobs: Sequence[task_module._EmailDeliveryJob] = [_build_job(recipient_count=2)] + monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: jobs) + + task_module.dispatch_human_input_email_task( + form_id="form-1", + node_title="Approve", + session_factory=lambda: _DummySession(form), + ) + + assert len(mail.sent) == 2 + assert all(payload["subject"] == "Subject" for payload in mail.sent) + assert all("Body for" in payload["html"] for payload in mail.sent) + + +def test_dispatch_human_input_email_task_skips_when_feature_disabled(monkeypatch: pytest.MonkeyPatch): + mail = _DummyMail() + form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id=None) + + monkeypatch.setattr(task_module, "mail", mail) + monkeypatch.setattr( + task_module.FeatureService, + "get_features", + lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False), + ) + monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: []) + + task_module.dispatch_human_input_email_task( + form_id="form-1", + node_title="Approve", + session_factory=lambda: _DummySession(form), + ) + + assert mail.sent == [] + + +def test_dispatch_human_input_email_task_replaces_body_variables(monkeypatch: pytest.MonkeyPatch): + mail = _DummyMail() + form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id="run-1") + job = task_module._EmailDeliveryJob( + form_id="form-1", + subject="Subject", + body="Body {{#node1.value#}}", + form_content="content", + recipients=[task_module._EmailRecipient(email="user@example.com", token="token-1")], + ) + + variable_pool = task_module.VariablePool() + variable_pool.add(["node1", "value"], "OK") + + monkeypatch.setattr(task_module, "mail", mail) + monkeypatch.setattr( + task_module.FeatureService, + "get_features", + lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True), + ) + monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: [job]) + monkeypatch.setattr(task_module, "_load_variable_pool", lambda _workflow_run_id: variable_pool) + + task_module.dispatch_human_input_email_task( + form_id="form-1", + node_title="Approve", + session_factory=lambda: _DummySession(form), + ) + + assert mail.sent[0]["html"] == "Body OK" diff --git a/api/tests/unit_tests/tasks/test_workflow_execute_task.py b/api/tests/unit_tests/tasks/test_workflow_execute_task.py new file mode 100644 index 0000000000..161151305d --- /dev/null +++ b/api/tests/unit_tests/tasks/test_workflow_execute_task.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import json +import uuid +from unittest.mock import MagicMock + +import pytest + +from models.model import AppMode +from tasks.app_generate.workflow_execute_task import _publish_streaming_response + + +@pytest.fixture +def mock_topic(mocker) -> MagicMock: + topic = MagicMock() + mocker.patch( + "tasks.app_generate.workflow_execute_task.MessageBasedAppGenerator.get_response_topic", + return_value=topic, + ) + return topic + + +def test_publish_streaming_response_with_uuid(mock_topic: MagicMock): + workflow_run_id = uuid.uuid4() + response_stream = iter([{"event": "foo"}, "ping"]) + + _publish_streaming_response(response_stream, workflow_run_id, app_mode=AppMode.ADVANCED_CHAT) + + payloads = [call.args[0] for call in mock_topic.publish.call_args_list] + assert payloads == [json.dumps({"event": "foo"}).encode(), json.dumps("ping").encode()] + + +def test_publish_streaming_response_coerces_string_uuid(mock_topic: MagicMock): + workflow_run_id = uuid.uuid4() + response_stream = iter([{"event": "bar"}]) + + _publish_streaming_response(response_stream, str(workflow_run_id), app_mode=AppMode.ADVANCED_CHAT) + + mock_topic.publish.assert_called_once_with(json.dumps({"event": "bar"}).encode()) diff --git a/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py b/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py new file mode 100644 index 0000000000..fd5f0713a4 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py @@ -0,0 +1,488 @@ +# """ +# Unit tests for workflow node execution Celery tasks. + +# These tests verify the asynchronous storage functionality for workflow node execution data, +# including truncation and offloading logic. +# """ + +# import json +# from unittest.mock import MagicMock, Mock, patch +# from uuid import uuid4 + +# import pytest + +# from core.workflow.entities.workflow_node_execution import ( +# WorkflowNodeExecution, +# WorkflowNodeExecutionStatus, +# ) +# from core.workflow.enums import NodeType +# from libs.datetime_utils import naive_utc_now +# from models import WorkflowNodeExecutionModel +# from models.enums import ExecutionOffLoadType +# from models.model import UploadFile +# from models.workflow import WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom +# from tasks.workflow_node_execution_tasks import ( +# _create_truncator, +# _json_encode, +# _replace_or_append_offload, +# _truncate_and_upload_async, +# save_workflow_node_execution_data_task, +# save_workflow_node_execution_task, +# ) + + +# @pytest.fixture +# def sample_execution_data(): +# """Sample execution data for testing.""" +# execution = WorkflowNodeExecution( +# id=str(uuid4()), +# node_execution_id=str(uuid4()), +# workflow_id=str(uuid4()), +# workflow_execution_id=str(uuid4()), +# index=1, +# node_id="test_node", +# node_type=NodeType.LLM, +# title="Test Node", +# inputs={"input_key": "input_value"}, +# outputs={"output_key": "output_value"}, +# process_data={"process_key": "process_value"}, +# status=WorkflowNodeExecutionStatus.RUNNING, +# created_at=naive_utc_now(), +# ) +# return execution.model_dump() + + +# @pytest.fixture +# def mock_db_model(): +# """Mock database model for testing.""" +# db_model = Mock(spec=WorkflowNodeExecutionModel) +# db_model.id = "test-execution-id" +# db_model.offload_data = [] +# return db_model + + +# @pytest.fixture +# def mock_file_service(): +# """Mock file service for testing.""" +# file_service = Mock() +# mock_upload_file = Mock(spec=UploadFile) +# mock_upload_file.id = "mock-file-id" +# file_service.upload_file.return_value = mock_upload_file +# return file_service + + +# class TestSaveWorkflowNodeExecutionDataTask: +# """Test cases for save_workflow_node_execution_data_task.""" + +# @patch("tasks.workflow_node_execution_tasks.sessionmaker") +# @patch("tasks.workflow_node_execution_tasks.select") +# def test_save_execution_data_task_success( +# self, mock_select, mock_sessionmaker, sample_execution_data, mock_db_model +# ): +# """Test successful execution of save_workflow_node_execution_data_task.""" +# # Setup mocks +# mock_session = MagicMock() +# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session +# mock_session.execute.return_value.scalars.return_value.first.return_value = mock_db_model + +# # Execute task +# result = save_workflow_node_execution_data_task( +# execution_data=sample_execution_data, +# tenant_id="test-tenant-id", +# app_id="test-app-id", +# user_data={"user_id": "test-user-id", "user_type": "account"}, +# ) + +# # Verify success +# assert result is True +# mock_session.merge.assert_called_once_with(mock_db_model) +# mock_session.commit.assert_called_once() + +# @patch("tasks.workflow_node_execution_tasks.sessionmaker") +# @patch("tasks.workflow_node_execution_tasks.select") +# def test_save_execution_data_task_execution_not_found(self, mock_select, mock_sessionmaker, +# sample_execution_data): +# """Test task when execution is not found in database.""" +# # Setup mocks +# mock_session = MagicMock() +# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session +# mock_session.execute.return_value.scalars.return_value.first.return_value = None + +# # Execute task +# result = save_workflow_node_execution_data_task( +# execution_data=sample_execution_data, +# tenant_id="test-tenant-id", +# app_id="test-app-id", +# user_data={"user_id": "test-user-id", "user_type": "account"}, +# ) + +# # Verify failure +# assert result is False +# mock_session.merge.assert_not_called() +# mock_session.commit.assert_not_called() + +# @patch("tasks.workflow_node_execution_tasks.sessionmaker") +# @patch("tasks.workflow_node_execution_tasks.select") +# def test_save_execution_data_task_with_truncation(self, mock_select, mock_sessionmaker, mock_db_model): +# """Test task with data that requires truncation.""" +# # Create execution with large data +# large_data = {"large_field": "x" * 10000} +# execution = WorkflowNodeExecution( +# id=str(uuid4()), +# node_execution_id=str(uuid4()), +# workflow_id=str(uuid4()), +# workflow_execution_id=str(uuid4()), +# index=1, +# node_id="test_node", +# node_type=NodeType.LLM, +# title="Test Node", +# inputs=large_data, +# outputs=large_data, +# process_data=large_data, +# status=WorkflowNodeExecutionStatus.RUNNING, +# created_at=naive_utc_now(), +# ) +# execution_data = execution.model_dump() + +# # Setup mocks +# mock_session = MagicMock() +# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session +# mock_session.execute.return_value.scalars.return_value.first.return_value = mock_db_model + +# # Create mock upload file +# mock_upload_file = Mock(spec=UploadFile) +# mock_upload_file.id = "mock-file-id" + +# # Execute task +# with patch("tasks.workflow_node_execution_tasks._truncate_and_upload_async") as mock_truncate: +# # Mock truncation results +# mock_truncate.return_value = { +# "truncated_value": {"large_field": "[TRUNCATED]"}, +# "file": mock_upload_file, +# "offload": WorkflowNodeExecutionOffload( +# id=str(uuid4()), +# tenant_id="test-tenant-id", +# app_id="test-app-id", +# node_execution_id=execution.id, +# type_=ExecutionOffLoadType.INPUTS, +# file_id=mock_upload_file.id, +# ), +# } + +# result = save_workflow_node_execution_data_task( +# execution_data=execution_data, +# tenant_id="test-tenant-id", +# app_id="test-app-id", +# user_data={"user_id": "test-user-id", "user_type": "account"}, +# ) + +# # Verify success and truncation was called +# assert result is True +# assert mock_truncate.call_count == 3 # inputs, outputs, process_data +# mock_session.merge.assert_called_once_with(mock_db_model) +# mock_session.commit.assert_called_once() + +# @patch("tasks.workflow_node_execution_tasks.sessionmaker") +# def test_save_execution_data_task_retry_on_exception(self, mock_sessionmaker, sample_execution_data): +# """Test task retry mechanism on exception.""" +# # Setup mock to raise exception +# mock_sessionmaker.side_effect = Exception("Database error") + +# # Create a mock task instance with proper retry behavior +# with patch.object(save_workflow_node_execution_data_task, "retry") as mock_retry: +# mock_retry.side_effect = Exception("Retry called") + +# # Execute task and expect retry +# with pytest.raises(Exception, match="Retry called"): +# save_workflow_node_execution_data_task( +# execution_data=sample_execution_data, +# tenant_id="test-tenant-id", +# app_id="test-app-id", +# user_data={"user_id": "test-user-id", "user_type": "account"}, +# ) + +# # Verify retry was called +# mock_retry.assert_called_once() + + +# class TestTruncateAndUploadAsync: +# """Test cases for _truncate_and_upload_async function.""" + +# def test_truncate_and_upload_with_none_values(self, mock_file_service): +# """Test _truncate_and_upload_async with None values.""" +# # The function handles None values internally, so we test with empty dict instead +# result = _truncate_and_upload_async( +# values={}, +# execution_id="test-id", +# type_=ExecutionOffLoadType.INPUTS, +# tenant_id="test-tenant", +# app_id="test-app", +# user_data={"user_id": "test-user", "user_type": "account"}, +# file_service=mock_file_service, +# ) + +# # Empty dict should not require truncation +# assert result is None +# mock_file_service.upload_file.assert_not_called() + +# @patch("tasks.workflow_node_execution_tasks._create_truncator") +# def test_truncate_and_upload_no_truncation_needed(self, mock_create_truncator, mock_file_service): +# """Test _truncate_and_upload_async when no truncation is needed.""" +# # Mock truncator to return no truncation +# mock_truncator = Mock() +# mock_truncator.truncate_variable_mapping.return_value = ({"small": "data"}, False) +# mock_create_truncator.return_value = mock_truncator + +# small_values = {"small": "data"} +# result = _truncate_and_upload_async( +# values=small_values, +# execution_id="test-id", +# type_=ExecutionOffLoadType.INPUTS, +# tenant_id="test-tenant", +# app_id="test-app", +# user_data={"user_id": "test-user", "user_type": "account"}, +# file_service=mock_file_service, +# ) + +# assert result is None +# mock_file_service.upload_file.assert_not_called() + +# @patch("tasks.workflow_node_execution_tasks._create_truncator") +# @patch("models.Account") +# @patch("models.Tenant") +# def test_truncate_and_upload_with_account_user( +# self, mock_tenant_class, mock_account_class, mock_create_truncator, mock_file_service +# ): +# """Test _truncate_and_upload_async with account user.""" +# # Mock truncator to return truncation needed +# mock_truncator = Mock() +# mock_truncator.truncate_variable_mapping.return_value = ({"truncated": "data"}, True) +# mock_create_truncator.return_value = mock_truncator + +# # Mock user and tenant creation +# mock_account = Mock() +# mock_account.id = "test-user" +# mock_account_class.return_value = mock_account + +# mock_tenant = Mock() +# mock_tenant.id = "test-tenant" +# mock_tenant_class.return_value = mock_tenant + +# large_values = {"large": "x" * 10000} +# result = _truncate_and_upload_async( +# values=large_values, +# execution_id="test-id", +# type_=ExecutionOffLoadType.INPUTS, +# tenant_id="test-tenant", +# app_id="test-app", +# user_data={"user_id": "test-user", "user_type": "account"}, +# file_service=mock_file_service, +# ) + +# # Verify result structure +# assert result is not None +# assert "truncated_value" in result +# assert "file" in result +# assert "offload" in result +# assert result["truncated_value"] == {"truncated": "data"} + +# # Verify file upload was called +# mock_file_service.upload_file.assert_called_once() +# upload_call = mock_file_service.upload_file.call_args +# assert upload_call[1]["filename"] == "node_execution_test-id_inputs.json" +# assert upload_call[1]["mimetype"] == "application/json" +# assert upload_call[1]["user"] == mock_account + +# @patch("tasks.workflow_node_execution_tasks._create_truncator") +# @patch("models.EndUser") +# def test_truncate_and_upload_with_end_user(self, mock_end_user_class, mock_create_truncator, mock_file_service): +# """Test _truncate_and_upload_async with end user.""" +# # Mock truncator to return truncation needed +# mock_truncator = Mock() +# mock_truncator.truncate_variable_mapping.return_value = ({"truncated": "data"}, True) +# mock_create_truncator.return_value = mock_truncator + +# # Mock end user creation +# mock_end_user = Mock() +# mock_end_user.id = "test-user" +# mock_end_user.tenant_id = "test-tenant" +# mock_end_user_class.return_value = mock_end_user + +# large_values = {"large": "x" * 10000} +# result = _truncate_and_upload_async( +# values=large_values, +# execution_id="test-id", +# type_=ExecutionOffLoadType.OUTPUTS, +# tenant_id="test-tenant", +# app_id="test-app", +# user_data={"user_id": "test-user", "user_type": "end_user"}, +# file_service=mock_file_service, +# ) + +# # Verify result structure +# assert result is not None +# assert result["truncated_value"] == {"truncated": "data"} + +# # Verify file upload was called with end user +# mock_file_service.upload_file.assert_called_once() +# upload_call = mock_file_service.upload_file.call_args +# assert upload_call[1]["filename"] == "node_execution_test-id_outputs.json" +# assert upload_call[1]["user"] == mock_end_user + + +# class TestHelperFunctions: +# """Test cases for helper functions.""" + +# @patch("tasks.workflow_node_execution_tasks.dify_config") +# def test_create_truncator(self, mock_config): +# """Test _create_truncator function.""" +# mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000 +# mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100 +# mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500 + +# truncator = _create_truncator() + +# # Verify truncator was created with correct config +# assert truncator is not None + +# def test_json_encode(self): +# """Test _json_encode function.""" +# test_data = {"key": "value", "number": 42} +# result = _json_encode(test_data) + +# assert isinstance(result, str) +# decoded = json.loads(result) +# assert decoded == test_data + +# def test_replace_or_append_offload_replace_existing(self): +# """Test _replace_or_append_offload replaces existing offload of same type.""" +# existing_offload = WorkflowNodeExecutionOffload( +# id=str(uuid4()), +# tenant_id="test-tenant", +# app_id="test-app", +# node_execution_id="test-execution", +# type_=ExecutionOffLoadType.INPUTS, +# file_id="old-file-id", +# ) + +# new_offload = WorkflowNodeExecutionOffload( +# id=str(uuid4()), +# tenant_id="test-tenant", +# app_id="test-app", +# node_execution_id="test-execution", +# type_=ExecutionOffLoadType.INPUTS, +# file_id="new-file-id", +# ) + +# result = _replace_or_append_offload([existing_offload], new_offload) + +# assert len(result) == 1 +# assert result[0].file_id == "new-file-id" + +# def test_replace_or_append_offload_append_new_type(self): +# """Test _replace_or_append_offload appends new offload of different type.""" +# existing_offload = WorkflowNodeExecutionOffload( +# id=str(uuid4()), +# tenant_id="test-tenant", +# app_id="test-app", +# node_execution_id="test-execution", +# type_=ExecutionOffLoadType.INPUTS, +# file_id="inputs-file-id", +# ) + +# new_offload = WorkflowNodeExecutionOffload( +# id=str(uuid4()), +# tenant_id="test-tenant", +# app_id="test-app", +# node_execution_id="test-execution", +# type_=ExecutionOffLoadType.OUTPUTS, +# file_id="outputs-file-id", +# ) + +# result = _replace_or_append_offload([existing_offload], new_offload) + +# assert len(result) == 2 +# file_ids = [offload.file_id for offload in result] +# assert "inputs-file-id" in file_ids +# assert "outputs-file-id" in file_ids + + +# class TestSaveWorkflowNodeExecutionTask: +# """Test cases for save_workflow_node_execution_task.""" + +# @patch("tasks.workflow_node_execution_tasks.sessionmaker") +# @patch("tasks.workflow_node_execution_tasks.select") +# def test_save_workflow_node_execution_task_create_new(self, mock_select, mock_sessionmaker, +# sample_execution_data): +# """Test creating a new workflow node execution.""" +# # Setup mocks +# mock_session = MagicMock() +# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session +# mock_session.scalar.return_value = None # No existing execution + +# # Execute task +# result = save_workflow_node_execution_task( +# execution_data=sample_execution_data, +# tenant_id="test-tenant-id", +# app_id="test-app-id", +# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, +# creator_user_id="test-user-id", +# creator_user_role="account", +# ) + +# # Verify success +# assert result is True +# mock_session.add.assert_called_once() +# mock_session.commit.assert_called_once() + +# @patch("tasks.workflow_node_execution_tasks.sessionmaker") +# @patch("tasks.workflow_node_execution_tasks.select") +# def test_save_workflow_node_execution_task_update_existing( +# self, mock_select, mock_sessionmaker, sample_execution_data +# ): +# """Test updating an existing workflow node execution.""" +# # Setup mocks +# mock_session = MagicMock() +# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session + +# existing_execution = Mock(spec=WorkflowNodeExecutionModel) +# mock_session.scalar.return_value = existing_execution + +# # Execute task +# result = save_workflow_node_execution_task( +# execution_data=sample_execution_data, +# tenant_id="test-tenant-id", +# app_id="test-app-id", +# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, +# creator_user_id="test-user-id", +# creator_user_role="account", +# ) + +# # Verify success +# assert result is True +# mock_session.add.assert_not_called() # Should not add new, just update existing +# mock_session.commit.assert_called_once() + +# @patch("tasks.workflow_node_execution_tasks.sessionmaker") +# def test_save_workflow_node_execution_task_retry_on_exception(self, mock_sessionmaker, sample_execution_data): +# """Test task retry mechanism on exception.""" +# # Setup mock to raise exception +# mock_sessionmaker.side_effect = Exception("Database error") + +# # Create a mock task instance with proper retry behavior +# with patch.object(save_workflow_node_execution_task, "retry") as mock_retry: +# mock_retry.side_effect = Exception("Retry called") + +# # Execute task and expect retry +# with pytest.raises(Exception, match="Retry called"): +# save_workflow_node_execution_task( +# execution_data=sample_execution_data, +# tenant_id="test-tenant-id", +# app_id="test-app-id", +# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, +# creator_user_id="test-user-id", +# creator_user_role="account", +# ) + +# # Verify retry was called +# mock_retry.assert_called_once() diff --git a/api/ty.toml b/api/ty.toml index 380e14dbef..ace2b7c0e8 100644 --- a/api/ty.toml +++ b/api/ty.toml @@ -25,10 +25,26 @@ exclude = [ # non-producition or generated code "migrations", "tests", + # targeted ignores for current type-check errors + # TODO(QuantumGhost): suppress type errors in HITL related code. + # fix the type error later + "configs/middleware/cache/redis_pubsub_config.py", + "extensions/ext_redis.py", + "models/execution_extra_content.py", + "tasks/workflow_execution_tasks.py", + "core/workflow/nodes/base/node.py", + "services/human_input_delivery_test_service.py", + "core/app/apps/advanced_chat/app_generator.py", + "controllers/console/human_input_form.py", + "controllers/console/app/workflow_run.py", + "repositories/sqlalchemy_api_workflow_node_execution_repository.py", + "extensions/logstore/repositories/logstore_api_workflow_run_repository.py", + "controllers/web/workflow_events.py", + "tasks/app_generate/workflow_execute_task.py", ] [rules] deprecated = "ignore" unused-ignore-comment = "ignore" -# possibly-missing-attribute = "ignore" \ No newline at end of file +# possibly-missing-attribute = "ignore" diff --git a/docker/.env.example b/docker/.env.example index 41a0205bf5..93099347bd 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1399,9 +1399,9 @@ PLUGIN_STDIO_BUFFER_SIZE=1024 PLUGIN_STDIO_MAX_BUFFER_SIZE=5242880 PLUGIN_PYTHON_ENV_INIT_TIMEOUT=120 -# Plugin Daemon side timeout (configure to match the API side below) +# Plugin Daemon side timeout (configure to match the API side below) PLUGIN_MAX_EXECUTION_TIMEOUT=600 -# API side timeout (configure to match the Plugin Daemon side above) +# API side timeout (configure to match the Plugin Daemon side above) PLUGIN_DAEMON_TIMEOUT=600.0 # PIP_MIRROR_URL=https://pypi.tuna.tsinghua.edu.cn/simple PIP_MIRROR_URL= @@ -1519,4 +1519,31 @@ AMPLITUDE_API_KEY= SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21 SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000 SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30 + + +# Redis URL used for PubSub between API and +# celery worker +# defaults to url constructed from `REDIS_*` +# configurations +PUBSUB_REDIS_URL= +# Pub/sub channel type for streaming events. +# valid options are: +# +# - pubsub: for normal Pub/Sub +# - sharded: for sharded Pub/Sub +# +# It's highly recommended to use sharded Pub/Sub AND redis cluster +# for large deployments. +PUBSUB_REDIS_CHANNEL_TYPE=pubsub +# Whether to use Redis cluster mode while running +# PubSub. +# It's highly recommended to enable this for large deployments. +PUBSUB_REDIS_USE_CLUSTERS=false + +# Whether to Enable human input timeout check task +ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true +# Human input timeout check interval in minutes +HUMAN_INPUT_TIMEOUT_TASK_INTERVAL=1 + + SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 1886f848e0..161fdc6c3f 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -683,6 +683,11 @@ x-shared-env: &shared-api-worker-env SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: ${SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD:-21} SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: ${SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE:-1000} SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: ${SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS:-30} + PUBSUB_REDIS_URL: ${PUBSUB_REDIS_URL:-} + PUBSUB_REDIS_CHANNEL_TYPE: ${PUBSUB_REDIS_CHANNEL_TYPE:-pubsub} + PUBSUB_REDIS_USE_CLUSTERS: ${PUBSUB_REDIS_USE_CLUSTERS:-false} + ENABLE_HUMAN_INPUT_TIMEOUT_TASK: ${ENABLE_HUMAN_INPUT_TIMEOUT_TASK:-true} + HUMAN_INPUT_TIMEOUT_TASK_INTERVAL: ${HUMAN_INPUT_TIMEOUT_TASK_INTERVAL:-1} SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: ${SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL:-90000} services: diff --git a/web/__mocks__/provider-context.ts b/web/__mocks__/provider-context.ts index 373c2f86d3..d3296bacd0 100644 --- a/web/__mocks__/provider-context.ts +++ b/web/__mocks__/provider-context.ts @@ -35,6 +35,7 @@ export const baseProviderContextValue: ProviderContextState = { refreshLicenseLimit: noop, isAllowTransferWorkspace: false, isAllowPublishAsCustomKnowledgePipelineTemplate: false, + humanInputEmailDeliveryEnabled: false, } export const createMockProviderContextValue = (overrides: Partial+ {t('nodes.humanInput.deliveryMethod.upgradeTip', { ns: 'workflow' })} +
++ {t('nodes.humanInput.deliveryMethod.upgradeTipContent', { ns: 'workflow' })} +
++
{category}
-+
{title}