diff --git a/api/tests/unit_tests/controllers/console/app/test_message.py b/api/tests/unit_tests/controllers/console/app/test_message.py index 3ffa53b6db..d5c7de143b 100644 --- a/api/tests/unit_tests/controllers/console/app/test_message.py +++ b/api/tests/unit_tests/controllers/console/app/test_message.py @@ -3,7 +3,6 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask, request from werkzeug.exceptions import InternalServerError, NotFound -from werkzeug.local import LocalProxy from controllers.console.app.error import ( ProviderModelCurrentlyNotSupportError, @@ -76,6 +75,7 @@ def setup_test_context( patch("extensions.ext_database.db") as mock_db, patch("controllers.console.app.wraps.db", mock_db), patch("controllers.console.wraps.db", mock_db), + patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), patch("controllers.console.app.message.db", mock_db), patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), @@ -99,14 +99,12 @@ def setup_test_context( mock_db.data_query = data_query_mock # Let the caller override the stat db query logic - proxy_mock = LocalProxy(lambda: mock_account) - query_string = "&".join([f"{k}={v}" for k, v in (qs or {}).items()]) full_path = f"{route_path}?{query_string}" if qs else route_path with ( - patch("libs.login.current_user", proxy_mock), - patch("flask_login.current_user", proxy_mock), + patch("libs.login._get_user", return_value=mock_account), + patch("flask_login.current_user", mock_account), patch("controllers.console.app.message.attach_message_extra_contents", return_value=None), ): with test_app.test_request_context(full_path, method=method, json=payload): diff --git a/api/tests/unit_tests/controllers/console/app/test_statistic.py b/api/tests/unit_tests/controllers/console/app/test_statistic.py index beba23385d..7d62a3de51 100644 --- a/api/tests/unit_tests/controllers/console/app/test_statistic.py +++ b/api/tests/unit_tests/controllers/console/app/test_statistic.py @@ -3,7 +3,6 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask, request -from werkzeug.local import LocalProxy from controllers.console.app.statistic import ( AverageResponseTimeStatistic, @@ -81,9 +80,7 @@ def setup_test_context( mock_query.where.return_value.where.return_value.first.return_value = mock_app_model mock_db_wraps.session.query.return_value = mock_query - proxy_mock = LocalProxy(lambda: mock_account) - - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): + with patch("libs.login._get_user", return_value=mock_account), patch("flask_login.current_user", mock_account): with test_app.test_request_context(route_path, method="GET"): request.view_args = {"app_id": "app_123"} api_instance = endpoint_class() diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py index 9b5d47c208..c8c605d6dc 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py @@ -2,7 +2,6 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask, request -from werkzeug.local import LocalProxy from controllers.console.app.error import DraftWorkflowNotExist from controllers.console.app.workflow_draft_variable import ( @@ -60,6 +59,7 @@ def setup_test_context(test_app, endpoint_class, route_path, method, mock_accoun with ( patch("controllers.console.app.wraps.db") as mock_db_wraps, patch("controllers.console.wraps.db", mock_db_wraps), + patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), patch("controllers.console.app.workflow_draft_variable.db"), patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), @@ -71,9 +71,7 @@ def setup_test_context(test_app, endpoint_class, route_path, method, mock_accoun mock_query.where.return_value.where.return_value.first.return_value = mock_app_model mock_db_wraps.session.query.return_value = mock_query - proxy_mock = LocalProxy(lambda: mock_account) - - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): + with patch("libs.login._get_user", return_value=mock_account), patch("flask_login.current_user", mock_account): with test_app.test_request_context(route_path, method=method, json=payload): request.view_args = {"app_id": "app_123"} # extract node_id or variable_id from path manually since view_args overrides @@ -101,6 +99,7 @@ class TestWorkflowDraftVariableEndpoints: mock_var = MagicMock() mock_var.app_id = "app_123" + mock_var.user_id = "user_123" mock_var.id = "var_123" mock_var.name = "test_var" mock_var.description = "" diff --git a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py index bc4c7e0993..0b68b7e1d9 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py +++ b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py @@ -49,9 +49,7 @@ class TestApiKeyAuthDataSource: ), ): with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock): + with patch("libs.login._get_user", return_value=mock_account): api_instance = ApiKeyAuthDataSource() response = api_instance.get() @@ -81,9 +79,7 @@ class TestApiKeyAuthDataSource: ), ): with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock): + with patch("libs.login._get_user", return_value=mock_account): api_instance = ApiKeyAuthDataSource() response = api_instance.get() @@ -124,9 +120,10 @@ class TestApiKeyAuthDataSourceBinding: method="POST", json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, ): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): + with ( + patch("libs.login._get_user", return_value=mock_account), + patch("flask_login.current_user", mock_account), + ): api_instance = ApiKeyAuthDataSourceBinding() response = api_instance.post() @@ -162,9 +159,10 @@ class TestApiKeyAuthDataSourceBinding: method="POST", json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, ): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): + with ( + patch("libs.login._get_user", return_value=mock_account), + patch("flask_login.current_user", mock_account), + ): api_instance = ApiKeyAuthDataSourceBinding() with pytest.raises(ApiKeyAuthFailedError, match="Invalid structure"): api_instance.post() @@ -198,9 +196,10 @@ class TestApiKeyAuthDataSourceBindingDelete: ), ): with app.test_request_context("/console/api/api-key-auth/data-source/binding_123", method="DELETE"): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): + with ( + patch("libs.login._get_user", return_value=mock_account), + patch("flask_login.current_user", mock_account), + ): api_instance = ApiKeyAuthDataSourceBindingDelete() response = api_instance.delete("binding_123") diff --git a/api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py index f369565946..db29495121 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py +++ b/api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py @@ -2,7 +2,6 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask -from werkzeug.local import LocalProxy from controllers.console.auth.data_source_oauth import ( OAuthDataSource, @@ -21,12 +20,12 @@ class TestOAuthDataSource: @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") @patch("flask_login.current_user") - @patch("libs.login.current_user") + @patch("libs.login._get_user") @patch("libs.login.check_csrf_token") @patch("controllers.console.wraps.db") @patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None) def test_get_oauth_url_successful( - self, mock_db, mock_csrf, mock_libs_user, mock_flask_user, mock_get_providers, app + self, mock_db, mock_csrf, mock_get_user, mock_flask_user, mock_get_providers, app ): mock_oauth_provider = MagicMock() mock_oauth_provider.get_authorization_url.return_value = "http://oauth.provider/auth" @@ -39,16 +38,14 @@ class TestOAuthDataSource: mock_account.status = AccountStatus.ACTIVE mock_account.is_admin_or_owner = True mock_account.current_tenant.current_role = "owner" - mock_libs_user.return_value = mock_account + mock_get_user.return_value = mock_account mock_flask_user.return_value = mock_account # also patch current_account_with_tenant with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): with app.test_request_context("/console/api/oauth/data-source/notion", method="GET"): - proxy_mock = LocalProxy(lambda: mock_account) - with patch("libs.login.current_user", proxy_mock): - api_instance = OAuthDataSource() - response = api_instance.get("notion") + api_instance = OAuthDataSource() + response = api_instance.get("notion") assert response[0]["data"] == "http://oauth.provider/auth" assert response[1] == 200 @@ -71,8 +68,7 @@ class TestOAuthDataSource: with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): with app.test_request_context("/console/api/oauth/data-source/unknown_provider", method="GET"): - proxy_mock = LocalProxy(lambda: mock_account) - with patch("libs.login.current_user", proxy_mock): + with patch("libs.login._get_user", return_value=mock_account): api_instance = OAuthDataSource() response = api_instance.get("unknown_provider") @@ -181,8 +177,7 @@ class TestOAuthDataSourceSync: with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): with app.test_request_context("/console/api/oauth/data-source/notion/binding_123/sync", method="GET"): - proxy_mock = LocalProxy(lambda: mock_account) - with patch("libs.login.current_user", proxy_mock): + with patch("libs.login._get_user", return_value=mock_account): api_instance = OAuthDataSourceSync() # The route pattern uses , so we just pass a string for unit testing response = api_instance.get("notion", "binding_123") diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py b/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py index fc5663e72d..5703f532bd 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py @@ -117,7 +117,7 @@ class TestOAuthServerUserAuthorizeApi: mock_sign.return_value = "auth_code_123" with app.test_request_context("/oauth/provider/authorize", method="POST", json={"client_id": "test_client_id"}): - with patch("libs.login.current_user", mock_account): + with patch("libs.login._get_user", return_value=mock_account): api_instance = OAuthServerUserAuthorizeApi() response = api_instance.post() diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py index 1530036395..4db29b1b2e 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_repository.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_repository.py @@ -9,9 +9,10 @@ from typing import Any from unittest.mock import MagicMock import pytest -from dify_graph.repositories.human_input_form_repository import FormCreateParams, FormNotFoundError from core.repositories.human_input_repository import ( + FormCreateParams, + FormNotFoundError, HumanInputFormRecord, HumanInputFormRepositoryImpl, HumanInputFormSubmissionRepository, @@ -20,16 +21,15 @@ from core.repositories.human_input_repository import ( _InvalidTimeoutStatusError, _WorkspaceMemberInfo, ) -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - HumanInputNodeData, MemberRecipient, - UserAction, WebAppDeliveryMethod, ) +from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import HumanInputFormRecipient, RecipientType @@ -212,7 +212,7 @@ def test_recipient_entity_id_and_token_success() -> None: assert entity.token == "tok" -def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> None: +def test_form_entity_submission_token_prefers_console_then_webapp_then_none() -> None: form = _DummyForm( id="f1", workflow_run_id="run", @@ -229,13 +229,13 @@ def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> No ) entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp, console]) # type: ignore[arg-type] - assert entity.web_app_token == "ctok" + assert entity.submission_token == "ctok" entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp]) # type: ignore[arg-type] - assert entity.web_app_token == "wtok" + assert entity.submission_token == "wtok" entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type] - assert entity.web_app_token is None + assert entity.submission_token is None def test_form_entity_submitted_data_parsed() -> None: @@ -364,8 +364,8 @@ def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch: method = EmailDeliveryMethod( config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, - items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")], + include_bound_group=False, + items=[MemberRecipient(reference_id="u1"), ExternalRecipient(email="e@example.com")], ), subject="s", body="b", @@ -388,7 +388,7 @@ def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatc session=MagicMock(), form_id="f", delivery_id="d", - recipients_config=EmailRecipients(whole_workspace=True, items=[ExternalRecipient(email="e@example.com")]), + recipients_config=EmailRecipients(include_bound_group=True, items=[ExternalRecipient(email="e@example.com")]), ) assert recipients == ["ok"] @@ -407,8 +407,8 @@ def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(m form_id="f", delivery_id="d", recipients_config=EmailRecipients( - whole_workspace=False, - items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")], + include_bound_group=False, + items=[MemberRecipient(reference_id="u1"), ExternalRecipient(email="e@example.com")], ), ) assert recipients == ["ok"] @@ -416,8 +416,8 @@ def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(m def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: _patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None])) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant") - assert repo.get_form("run", "node") is None + repo = HumanInputFormRepositoryImpl(tenant_id="tenant", workflow_execution_id="run") + assert repo.get_form("node") is None form = _DummyForm( id="f1", @@ -437,8 +437,8 @@ def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.Monke ) session = _FakeSession(scalars_results=[form, [recipient]]) _patch_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant") - entity = repo.get_form("run", "node") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant", workflow_execution_id="run") + entity = repo.get_form("node") assert entity is not None assert entity.id == "f1" assert entity.recipients[0].id == "r1" @@ -454,7 +454,13 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M session = _FakeSession() _patch_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + repo = HumanInputFormRepositoryImpl( + tenant_id="tenant", + app_id="app", + workflow_execution_id="run", + invoke_source="debugger", + submission_actor_id="acc-1", + ) form_config = HumanInputNodeData( title="Title", @@ -464,8 +470,7 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M user_actions=[UserAction(id="submit", title="Submit")], ) params = FormCreateParams( - app_id="app", - workflow_execution_id="run", + workflow_execution_id=None, node_id="node", form_config=form_config, rendered_content="

hello

", @@ -473,16 +478,13 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M display_in_ui=True, resolved_default_values={}, form_kind=HumanInputFormKind.RUNTIME, - console_recipient_required=True, - console_creator_account_id="acc-1", - backstage_recipient_required=True, ) entity = repo.create_form(params) assert entity.id == "form-id" assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout) # Console token should take precedence when console recipient is present. - assert entity.web_app_token == "token-console" + assert entity.submission_token == "token-console" assert len(entity.recipients) == 3 diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py index 527e4b1a8d..481487971a 100644 --- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py @@ -10,12 +10,12 @@ from unittest.mock import MagicMock, Mock import psycopg2.errors import pytest -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig from sqlalchemy import Engine, create_engine from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker from configs import dify_config +from core.repositories.factory import OrderConfig from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, _deterministic_json_dump, @@ -25,7 +25,7 @@ from core.repositories.sqlalchemy_workflow_node_execution_repository import ( ) from dify_graph.entities import WorkflowNodeExecution from dify_graph.enums import ( - NodeType, + BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) @@ -67,7 +67,7 @@ def _execution( index=1, predecessor_node_id=None, node_id="node-id", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, title="Title", inputs=inputs, outputs=outputs, @@ -387,7 +387,7 @@ def test_to_domain_model_loads_offloaded_files(monkeypatch: pytest.MonkeyPatch) db_model.index = 1 db_model.predecessor_node_id = None db_model.node_id = "node" - db_model.node_type = NodeType.LLM + db_model.node_type = BuiltinNodeTypes.LLM db_model.title = "t" db_model.inputs = json.dumps({"trunc": "i"}) db_model.process_data = json.dumps({"trunc": "p"}) @@ -441,7 +441,7 @@ def test_to_domain_model_returns_early_when_no_offload_data(monkeypatch: pytest. db_model.index = 1 db_model.predecessor_node_id = None db_model.node_id = "node" - db_model.node_type = NodeType.LLM + db_model.node_type = BuiltinNodeTypes.LLM db_model.title = "t" db_model.inputs = json.dumps({"i": 1}) db_model.process_data = json.dumps({"p": 2}) @@ -768,5 +768,5 @@ def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) -> lambda max_workers: FakeExecutor(), ) - result = repo.get_by_workflow_run("run", order_config=None) + result = repo.get_by_workflow_execution("run", order_config=None) assert result == ["domain:db1", "domain:db2"] diff --git a/api/tests/unit_tests/services/test_model_load_balancing_service.py b/api/tests/unit_tests/services/test_model_load_balancing_service.py index 49e572584b..fea2177f3e 100644 --- a/api/tests/unit_tests/services/test_model_load_balancing_service.py +++ b/api/tests/unit_tests/services/test_model_load_balancing_service.py @@ -69,9 +69,10 @@ def _load_balancing_model_config(**kwargs: Any) -> LoadBalancingModelConfig: def service(mocker: MockerFixture) -> ModelLoadBalancingService: # Arrange provider_manager = MagicMock() - mocker.patch("services.model_load_balancing_service.ProviderManager", return_value=provider_manager) + mocker.patch("services.model_load_balancing_service.create_plugin_provider_manager", return_value=provider_manager) svc = ModelLoadBalancingService() svc.provider_manager = provider_manager + svc._get_provider_manager = lambda _tenant_id: provider_manager # type: ignore[method-assign] return svc @@ -708,7 +709,10 @@ def test_custom_credentials_validate_should_handle_invalid_original_json_and_val load_balancing_model_config = _load_balancing_model_config(encrypted_config="not-json") mock_factory = MagicMock() mock_factory.model_credentials_validate.return_value = {"api_key": "validated"} - mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory) + mocker.patch( + "services.model_load_balancing_service.create_plugin_model_provider_factory", + return_value=mock_factory, + ) mock_encrypt = mocker.patch( "services.model_load_balancing_service.encrypter.encrypt_token", side_effect=lambda tenant_id, value: f"enc:{value}", @@ -740,7 +744,10 @@ def test_custom_credentials_validate_should_validate_with_provider_schema_when_m provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) mock_factory = MagicMock() mock_factory.provider_credentials_validate.return_value = {"api_key": "provider-validated"} - mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory) + mocker.patch( + "services.model_load_balancing_service.create_plugin_model_provider_factory", + return_value=mock_factory, + ) mocker.patch( "services.model_load_balancing_service.encrypter.encrypt_token", side_effect=lambda tenant_id, value: f"enc:{value}", diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index d26c2f674f..dde062bc7c 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -12,7 +12,7 @@ This test suite covers: import json import uuid from typing import Any, cast -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest @@ -2053,22 +2053,37 @@ class TestSetupVariablePool: workflow = self._make_workflow() # Act - with patch("services.workflow_service.VariablePool") as MockPool: + with ( + patch("services.workflow_service.VariablePool") as MockPool, + patch("services.workflow_service.build_system_variables") as mock_build_system_variables, + patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables, + patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool, + patch("services.workflow_service.add_node_inputs_to_pool") as mock_add_node_inputs_to_pool, + ): _setup_variable_pool( query="hello", files=[], user_id="u-1", user_inputs={"k": "v"}, workflow=workflow, + node_id="start-node", node_type=BuiltinNodeTypes.START, conversation_id="conv-1", conversation_variables=[], ) - # Assert — VariablePool should be called with a SystemVariable (non-default) - MockPool.assert_called_once() - call_kwargs = MockPool.call_args.kwargs - assert call_kwargs["user_inputs"] == {"k": "v"} + # Assert — start nodes should build bootstrap variables and attach node inputs. + MockPool.assert_called_once_with() + mock_build_system_variables.assert_called_once() + mock_add_variables_to_pool.assert_called_once_with( + MockPool.return_value, + mock_build_bootstrap_variables.return_value, + ) + mock_add_node_inputs_to_pool.assert_called_once_with( + MockPool.return_value, + node_id="start-node", + inputs={"k": "v"}, + ) def test_setup_variable_pool_should_use_default_system_variables_for_non_start_node( self, @@ -2079,7 +2094,10 @@ class TestSetupVariablePool: # Act with ( patch("services.workflow_service.VariablePool") as MockPool, - patch("services.workflow_service.SystemVariable.default") as mock_default, + patch("services.workflow_service.default_system_variables") as mock_default_system_variables, + patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables, + patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool, + patch("services.workflow_service.add_node_inputs_to_pool") as mock_add_node_inputs_to_pool, ): _setup_variable_pool( query="", @@ -2087,14 +2105,20 @@ class TestSetupVariablePool: user_id="u-1", user_inputs={}, workflow=workflow, + node_id="llm-node", node_type=BuiltinNodeTypes.LLM, # not a start/trigger node conversation_id="conv-1", conversation_variables=[], ) - # Assert — SystemVariable.default() should be used for non-start nodes - mock_default.assert_called_once() - MockPool.assert_called_once() + # Assert — default system variables should be used and node inputs should not be added. + mock_default_system_variables.assert_called_once() + MockPool.assert_called_once_with() + mock_add_variables_to_pool.assert_called_once_with( + MockPool.return_value, + mock_build_bootstrap_variables.return_value, + ) + mock_add_node_inputs_to_pool.assert_not_called() def test_setup_variable_pool_should_set_chatflow_specifics_for_non_workflow_type( self, @@ -2106,20 +2130,31 @@ class TestSetupVariablePool: workflow = self._make_workflow(workflow_type=WorkflowType.CHAT.value) # Act - with patch("services.workflow_service.VariablePool") as MockPool: + with ( + patch("services.workflow_service.VariablePool") as MockPool, + patch("services.workflow_service.build_system_variables") as mock_build_system_variables, + patch("services.workflow_service.build_bootstrap_variables"), + patch("services.workflow_service.add_variables_to_pool"), + patch("services.workflow_service.add_node_inputs_to_pool"), + ): _setup_variable_pool( query="what is AI?", files=[], user_id="u-1", user_inputs={}, workflow=workflow, + node_id="start-node", node_type=BuiltinNodeTypes.START, conversation_id="conv-abc", conversation_variables=[], ) - # Assert — we just verify VariablePool was called (chatflow path executed) - MockPool.assert_called_once() + # Assert — chatflow system variables should include query, conversation_id and dialogue_count. + MockPool.assert_called_once_with() + system_variable_values = mock_build_system_variables.call_args.args[0] + assert system_variable_values["query"] == "what is AI?" + assert system_variable_values["conversation_id"] == "conv-abc" + assert system_variable_values["dialogue_count"] == 1 class TestRebuildSingleFile: @@ -2142,7 +2177,7 @@ class TestRebuildSingleFile: # Assert assert result is mock_file - mock_build.assert_called_once_with(mapping=value, tenant_id=tenant_id) + mock_build.assert_called_once_with(mapping=value, tenant_id=tenant_id, access_controller=ANY) def test_rebuild_single_file_should_raise_when_file_value_not_dict( self, @@ -2165,7 +2200,7 @@ class TestRebuildSingleFile: # Assert assert result is mock_files - mock_build.assert_called_once_with(mappings=value, tenant_id=tenant_id) + mock_build.assert_called_once_with(mappings=value, tenant_id=tenant_id, access_controller=ANY) def test_rebuild_single_file_should_raise_when_file_list_value_not_list( self, @@ -2279,13 +2314,12 @@ class TestWorkflowServiceResolveDeliveryMethod: # Arrange method_a = self._make_method("method-1") method_b = self._make_method("method-2") - node_data = MagicMock() - node_data.delivery_methods = [method_a, method_b] # Act - result = WorkflowService._resolve_human_input_delivery_method( - node_data=node_data, delivery_method_id="method-2" - ) + with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[method_a, method_b]): + result = WorkflowService._resolve_human_input_delivery_method( + node_data=MagicMock(), delivery_method_id="method-2" + ) # Assert assert result is method_b @@ -2293,26 +2327,22 @@ class TestWorkflowServiceResolveDeliveryMethod: def test_resolve_delivery_method_should_return_none_when_no_match(self) -> None: # Arrange method_a = self._make_method("method-1") - node_data = MagicMock() - node_data.delivery_methods = [method_a] # Act - result = WorkflowService._resolve_human_input_delivery_method( - node_data=node_data, delivery_method_id="does-not-exist" - ) + with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[method_a]): + result = WorkflowService._resolve_human_input_delivery_method( + node_data=MagicMock(), delivery_method_id="does-not-exist" + ) # Assert assert result is None def test_resolve_delivery_method_should_return_none_for_empty_methods(self) -> None: - # Arrange - node_data = MagicMock() - node_data.delivery_methods = [] - # Act - result = WorkflowService._resolve_human_input_delivery_method( - node_data=node_data, delivery_method_id="method-1" - ) + with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[]): + result = WorkflowService._resolve_human_input_delivery_method( + node_data=MagicMock(), delivery_method_id="method-1" + ) # Assert assert result is None @@ -2435,6 +2465,9 @@ class TestWorkflowServiceDraftExecution: patch("services.workflow_service.Session"), patch("services.workflow_service.WorkflowDraftVariableService"), patch("services.workflow_service.VariablePool") as mock_pool_cls, + patch("services.workflow_service.default_system_variables") as mock_default_system_variables, + patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables, + patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool, patch("services.workflow_service.DraftVarLoader"), patch("services.workflow_service.WorkflowEntry.single_step_run") as mock_run, patch("services.workflow_service.DifyCoreRepositoryFactory"), @@ -2475,10 +2508,16 @@ class TestWorkflowServiceDraftExecution: ) # Assert - # For non-start nodes, VariablePool should be initialized with environment_variables - mock_pool_cls.assert_called_once() - args, kwargs = mock_pool_cls.call_args - assert "environment_variables" in kwargs + # For non-start nodes, bootstrap variables should be loaded into an empty pool. + mock_pool_cls.assert_called_once_with() + mock_default_system_variables.assert_called_once() + mock_build_bootstrap_variables.assert_called_once_with( + system_variables=mock_default_system_variables.return_value, + environment_variables=draft_workflow.environment_variables, + ) + mock_add_variables_to_pool.assert_called_once_with( + mock_pool_cls.return_value, mock_build_bootstrap_variables.return_value + ) # =========================================================================== @@ -2588,7 +2627,7 @@ class TestWorkflowServiceHumanInputOperations: patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT), patch("services.workflow_service.HumanInputNodeData.model_validate"), patch.object(service, "_resolve_human_input_delivery_method") as mock_resolve, - patch("services.workflow_service.apply_debug_email_recipient"), + patch("services.workflow_service.apply_dify_debug_email_recipient"), patch.object(service, "_build_human_input_variable_pool"), patch.object(service, "_build_human_input_node"), patch.object(service, "_create_human_input_delivery_test_form", return_value=("form-1", [])), @@ -2730,13 +2769,15 @@ class TestWorkflowServiceFreeNodeExecution: variable_pool = MagicMock() with ( - patch("services.workflow_service.GraphInitParams"), + patch("services.workflow_service.GraphInitParams") as mock_graph_init_params, patch("services.workflow_service.GraphRuntimeState"), + patch("services.workflow_service.build_dify_run_context"), + patch("services.workflow_service.DifyHumanInputNodeRuntime") as mock_runtime_cls, patch("services.workflow_service.HumanInputNode") as mock_node_cls, - patch("services.workflow_service.HumanInputFormRepositoryImpl"), ): node = service._build_human_input_node( workflow=workflow, account=account, node_config=node_config, variable_pool=variable_pool ) assert node == mock_node_cls.return_value mock_node_cls.assert_called_once() + mock_runtime_cls.assert_called_once_with(mock_graph_init_params.return_value.run_context)