fix: test

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2026-03-23 14:50:01 +08:00
parent 93e8d856d8
commit 3b53f01378
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
10 changed files with 151 additions and 113 deletions

View File

@ -3,7 +3,6 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, request
from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.local import LocalProxy
from controllers.console.app.error import (
ProviderModelCurrentlyNotSupportError,
@ -76,6 +75,7 @@ def setup_test_context(
patch("extensions.ext_database.db") as mock_db,
patch("controllers.console.app.wraps.db", mock_db),
patch("controllers.console.wraps.db", mock_db),
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),
patch("controllers.console.app.message.db", mock_db),
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
@ -99,14 +99,12 @@ def setup_test_context(
mock_db.data_query = data_query_mock
# Let the caller override the stat db query logic
proxy_mock = LocalProxy(lambda: mock_account)
query_string = "&".join([f"{k}={v}" for k, v in (qs or {}).items()])
full_path = f"{route_path}?{query_string}" if qs else route_path
with (
patch("libs.login.current_user", proxy_mock),
patch("flask_login.current_user", proxy_mock),
patch("libs.login._get_user", return_value=mock_account),
patch("flask_login.current_user", mock_account),
patch("controllers.console.app.message.attach_message_extra_contents", return_value=None),
):
with test_app.test_request_context(full_path, method=method, json=payload):

View File

@ -3,7 +3,6 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, request
from werkzeug.local import LocalProxy
from controllers.console.app.statistic import (
AverageResponseTimeStatistic,
@ -81,9 +80,7 @@ def setup_test_context(
mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
mock_db_wraps.session.query.return_value = mock_query
proxy_mock = LocalProxy(lambda: mock_account)
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
with patch("libs.login._get_user", return_value=mock_account), patch("flask_login.current_user", mock_account):
with test_app.test_request_context(route_path, method="GET"):
request.view_args = {"app_id": "app_123"}
api_instance = endpoint_class()

View File

@ -2,7 +2,6 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, request
from werkzeug.local import LocalProxy
from controllers.console.app.error import DraftWorkflowNotExist
from controllers.console.app.workflow_draft_variable import (
@ -60,6 +59,7 @@ def setup_test_context(test_app, endpoint_class, route_path, method, mock_accoun
with (
patch("controllers.console.app.wraps.db") as mock_db_wraps,
patch("controllers.console.wraps.db", mock_db_wraps),
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),
patch("controllers.console.app.workflow_draft_variable.db"),
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
@ -71,9 +71,7 @@ def setup_test_context(test_app, endpoint_class, route_path, method, mock_accoun
mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
mock_db_wraps.session.query.return_value = mock_query
proxy_mock = LocalProxy(lambda: mock_account)
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
with patch("libs.login._get_user", return_value=mock_account), patch("flask_login.current_user", mock_account):
with test_app.test_request_context(route_path, method=method, json=payload):
request.view_args = {"app_id": "app_123"}
# extract node_id or variable_id from path manually since view_args overrides
@ -101,6 +99,7 @@ class TestWorkflowDraftVariableEndpoints:
mock_var = MagicMock()
mock_var.app_id = "app_123"
mock_var.user_id = "user_123"
mock_var.id = "var_123"
mock_var.name = "test_var"
mock_var.description = ""

View File

@ -49,9 +49,7 @@ class TestApiKeyAuthDataSource:
),
):
with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
proxy_mock = MagicMock()
proxy_mock._get_current_object.return_value = mock_account
with patch("libs.login.current_user", proxy_mock):
with patch("libs.login._get_user", return_value=mock_account):
api_instance = ApiKeyAuthDataSource()
response = api_instance.get()
@ -81,9 +79,7 @@ class TestApiKeyAuthDataSource:
),
):
with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
proxy_mock = MagicMock()
proxy_mock._get_current_object.return_value = mock_account
with patch("libs.login.current_user", proxy_mock):
with patch("libs.login._get_user", return_value=mock_account):
api_instance = ApiKeyAuthDataSource()
response = api_instance.get()
@ -124,9 +120,10 @@ class TestApiKeyAuthDataSourceBinding:
method="POST",
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
):
proxy_mock = MagicMock()
proxy_mock._get_current_object.return_value = mock_account
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
with (
patch("libs.login._get_user", return_value=mock_account),
patch("flask_login.current_user", mock_account),
):
api_instance = ApiKeyAuthDataSourceBinding()
response = api_instance.post()
@ -162,9 +159,10 @@ class TestApiKeyAuthDataSourceBinding:
method="POST",
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
):
proxy_mock = MagicMock()
proxy_mock._get_current_object.return_value = mock_account
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
with (
patch("libs.login._get_user", return_value=mock_account),
patch("flask_login.current_user", mock_account),
):
api_instance = ApiKeyAuthDataSourceBinding()
with pytest.raises(ApiKeyAuthFailedError, match="Invalid structure"):
api_instance.post()
@ -198,9 +196,10 @@ class TestApiKeyAuthDataSourceBindingDelete:
),
):
with app.test_request_context("/console/api/api-key-auth/data-source/binding_123", method="DELETE"):
proxy_mock = MagicMock()
proxy_mock._get_current_object.return_value = mock_account
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
with (
patch("libs.login._get_user", return_value=mock_account),
patch("flask_login.current_user", mock_account),
):
api_instance = ApiKeyAuthDataSourceBindingDelete()
response = api_instance.delete("binding_123")

View File

@ -2,7 +2,6 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.local import LocalProxy
from controllers.console.auth.data_source_oauth import (
OAuthDataSource,
@ -21,12 +20,12 @@ class TestOAuthDataSource:
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
@patch("flask_login.current_user")
@patch("libs.login.current_user")
@patch("libs.login._get_user")
@patch("libs.login.check_csrf_token")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None)
def test_get_oauth_url_successful(
self, mock_db, mock_csrf, mock_libs_user, mock_flask_user, mock_get_providers, app
self, mock_db, mock_csrf, mock_get_user, mock_flask_user, mock_get_providers, app
):
mock_oauth_provider = MagicMock()
mock_oauth_provider.get_authorization_url.return_value = "http://oauth.provider/auth"
@ -39,16 +38,14 @@ class TestOAuthDataSource:
mock_account.status = AccountStatus.ACTIVE
mock_account.is_admin_or_owner = True
mock_account.current_tenant.current_role = "owner"
mock_libs_user.return_value = mock_account
mock_get_user.return_value = mock_account
mock_flask_user.return_value = mock_account
# also patch current_account_with_tenant
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
with app.test_request_context("/console/api/oauth/data-source/notion", method="GET"):
proxy_mock = LocalProxy(lambda: mock_account)
with patch("libs.login.current_user", proxy_mock):
api_instance = OAuthDataSource()
response = api_instance.get("notion")
api_instance = OAuthDataSource()
response = api_instance.get("notion")
assert response[0]["data"] == "http://oauth.provider/auth"
assert response[1] == 200
@ -71,8 +68,7 @@ class TestOAuthDataSource:
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
with app.test_request_context("/console/api/oauth/data-source/unknown_provider", method="GET"):
proxy_mock = LocalProxy(lambda: mock_account)
with patch("libs.login.current_user", proxy_mock):
with patch("libs.login._get_user", return_value=mock_account):
api_instance = OAuthDataSource()
response = api_instance.get("unknown_provider")
@ -181,8 +177,7 @@ class TestOAuthDataSourceSync:
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
with app.test_request_context("/console/api/oauth/data-source/notion/binding_123/sync", method="GET"):
proxy_mock = LocalProxy(lambda: mock_account)
with patch("libs.login.current_user", proxy_mock):
with patch("libs.login._get_user", return_value=mock_account):
api_instance = OAuthDataSourceSync()
# The route pattern uses <uuid:binding_id>, so we just pass a string for unit testing
response = api_instance.get("notion", "binding_123")

View File

@ -117,7 +117,7 @@ class TestOAuthServerUserAuthorizeApi:
mock_sign.return_value = "auth_code_123"
with app.test_request_context("/oauth/provider/authorize", method="POST", json={"client_id": "test_client_id"}):
with patch("libs.login.current_user", mock_account):
with patch("libs.login._get_user", return_value=mock_account):
api_instance = OAuthServerUserAuthorizeApi()
response = api_instance.post()

View File

@ -9,9 +9,10 @@ from typing import Any
from unittest.mock import MagicMock
import pytest
from dify_graph.repositories.human_input_form_repository import FormCreateParams, FormNotFoundError
from core.repositories.human_input_repository import (
FormCreateParams,
FormNotFoundError,
HumanInputFormRecord,
HumanInputFormRepositoryImpl,
HumanInputFormSubmissionRepository,
@ -20,16 +21,15 @@ from core.repositories.human_input_repository import (
_InvalidTimeoutStatusError,
_WorkspaceMemberInfo,
)
from dify_graph.nodes.human_input.entities import (
from core.workflow.human_input_compat import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
HumanInputNodeData,
MemberRecipient,
UserAction,
WebAppDeliveryMethod,
)
from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction
from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from libs.datetime_utils import naive_utc_now
from models.human_input import HumanInputFormRecipient, RecipientType
@ -212,7 +212,7 @@ def test_recipient_entity_id_and_token_success() -> None:
assert entity.token == "tok"
def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> None:
def test_form_entity_submission_token_prefers_console_then_webapp_then_none() -> None:
form = _DummyForm(
id="f1",
workflow_run_id="run",
@ -229,13 +229,13 @@ def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> No
)
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp, console]) # type: ignore[arg-type]
assert entity.web_app_token == "ctok"
assert entity.submission_token == "ctok"
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp]) # type: ignore[arg-type]
assert entity.web_app_token == "wtok"
assert entity.submission_token == "wtok"
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type]
assert entity.web_app_token is None
assert entity.submission_token is None
def test_form_entity_submitted_data_parsed() -> None:
@ -364,8 +364,8 @@ def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch:
method = EmailDeliveryMethod(
config=EmailDeliveryConfig(
recipients=EmailRecipients(
whole_workspace=False,
items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
include_bound_group=False,
items=[MemberRecipient(reference_id="u1"), ExternalRecipient(email="e@example.com")],
),
subject="s",
body="b",
@ -388,7 +388,7 @@ def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatc
session=MagicMock(),
form_id="f",
delivery_id="d",
recipients_config=EmailRecipients(whole_workspace=True, items=[ExternalRecipient(email="e@example.com")]),
recipients_config=EmailRecipients(include_bound_group=True, items=[ExternalRecipient(email="e@example.com")]),
)
assert recipients == ["ok"]
@ -407,8 +407,8 @@ def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(m
form_id="f",
delivery_id="d",
recipients_config=EmailRecipients(
whole_workspace=False,
items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
include_bound_group=False,
items=[MemberRecipient(reference_id="u1"), ExternalRecipient(email="e@example.com")],
),
)
assert recipients == ["ok"]
@ -416,8 +416,8 @@ def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(m
def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
_patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None]))
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
assert repo.get_form("run", "node") is None
repo = HumanInputFormRepositoryImpl(tenant_id="tenant", workflow_execution_id="run")
assert repo.get_form("node") is None
form = _DummyForm(
id="f1",
@ -437,8 +437,8 @@ def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.Monke
)
session = _FakeSession(scalars_results=[form, [recipient]])
_patch_session_factory(monkeypatch, session)
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
entity = repo.get_form("run", "node")
repo = HumanInputFormRepositoryImpl(tenant_id="tenant", workflow_execution_id="run")
entity = repo.get_form("node")
assert entity is not None
assert entity.id == "f1"
assert entity.recipients[0].id == "r1"
@ -454,7 +454,13 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M
session = _FakeSession()
_patch_session_factory(monkeypatch, session)
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
repo = HumanInputFormRepositoryImpl(
tenant_id="tenant",
app_id="app",
workflow_execution_id="run",
invoke_source="debugger",
submission_actor_id="acc-1",
)
form_config = HumanInputNodeData(
title="Title",
@ -464,8 +470,7 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M
user_actions=[UserAction(id="submit", title="Submit")],
)
params = FormCreateParams(
app_id="app",
workflow_execution_id="run",
workflow_execution_id=None,
node_id="node",
form_config=form_config,
rendered_content="<p>hello</p>",
@ -473,16 +478,13 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M
display_in_ui=True,
resolved_default_values={},
form_kind=HumanInputFormKind.RUNTIME,
console_recipient_required=True,
console_creator_account_id="acc-1",
backstage_recipient_required=True,
)
entity = repo.create_form(params)
assert entity.id == "form-id"
assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout)
# Console token should take precedence when console recipient is present.
assert entity.web_app_token == "token-console"
assert entity.submission_token == "token-console"
assert len(entity.recipients) == 3

View File

@ -10,12 +10,12 @@ from unittest.mock import MagicMock, Mock
import psycopg2.errors
import pytest
from dify_graph.repositories.workflow_node_execution_repository import OrderConfig
from sqlalchemy import Engine, create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.repositories.factory import OrderConfig
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
SQLAlchemyWorkflowNodeExecutionRepository,
_deterministic_json_dump,
@ -25,7 +25,7 @@ from core.repositories.sqlalchemy_workflow_node_execution_repository import (
)
from dify_graph.entities import WorkflowNodeExecution
from dify_graph.enums import (
NodeType,
BuiltinNodeTypes,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
@ -67,7 +67,7 @@ def _execution(
index=1,
predecessor_node_id=None,
node_id="node-id",
node_type=NodeType.LLM,
node_type=BuiltinNodeTypes.LLM,
title="Title",
inputs=inputs,
outputs=outputs,
@ -387,7 +387,7 @@ def test_to_domain_model_loads_offloaded_files(monkeypatch: pytest.MonkeyPatch)
db_model.index = 1
db_model.predecessor_node_id = None
db_model.node_id = "node"
db_model.node_type = NodeType.LLM
db_model.node_type = BuiltinNodeTypes.LLM
db_model.title = "t"
db_model.inputs = json.dumps({"trunc": "i"})
db_model.process_data = json.dumps({"trunc": "p"})
@ -441,7 +441,7 @@ def test_to_domain_model_returns_early_when_no_offload_data(monkeypatch: pytest.
db_model.index = 1
db_model.predecessor_node_id = None
db_model.node_id = "node"
db_model.node_type = NodeType.LLM
db_model.node_type = BuiltinNodeTypes.LLM
db_model.title = "t"
db_model.inputs = json.dumps({"i": 1})
db_model.process_data = json.dumps({"p": 2})
@ -768,5 +768,5 @@ def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) ->
lambda max_workers: FakeExecutor(),
)
result = repo.get_by_workflow_run("run", order_config=None)
result = repo.get_by_workflow_execution("run", order_config=None)
assert result == ["domain:db1", "domain:db2"]

View File

@ -69,9 +69,10 @@ def _load_balancing_model_config(**kwargs: Any) -> LoadBalancingModelConfig:
def service(mocker: MockerFixture) -> ModelLoadBalancingService:
# Arrange
provider_manager = MagicMock()
mocker.patch("services.model_load_balancing_service.ProviderManager", return_value=provider_manager)
mocker.patch("services.model_load_balancing_service.create_plugin_provider_manager", return_value=provider_manager)
svc = ModelLoadBalancingService()
svc.provider_manager = provider_manager
svc._get_provider_manager = lambda _tenant_id: provider_manager # type: ignore[method-assign]
return svc
@ -708,7 +709,10 @@ def test_custom_credentials_validate_should_handle_invalid_original_json_and_val
load_balancing_model_config = _load_balancing_model_config(encrypted_config="not-json")
mock_factory = MagicMock()
mock_factory.model_credentials_validate.return_value = {"api_key": "validated"}
mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory)
mocker.patch(
"services.model_load_balancing_service.create_plugin_model_provider_factory",
return_value=mock_factory,
)
mock_encrypt = mocker.patch(
"services.model_load_balancing_service.encrypter.encrypt_token",
side_effect=lambda tenant_id, value: f"enc:{value}",
@ -740,7 +744,10 @@ def test_custom_credentials_validate_should_validate_with_provider_schema_when_m
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
mock_factory = MagicMock()
mock_factory.provider_credentials_validate.return_value = {"api_key": "provider-validated"}
mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory)
mocker.patch(
"services.model_load_balancing_service.create_plugin_model_provider_factory",
return_value=mock_factory,
)
mocker.patch(
"services.model_load_balancing_service.encrypter.encrypt_token",
side_effect=lambda tenant_id, value: f"enc:{value}",

View File

@ -12,7 +12,7 @@ This test suite covers:
import json
import uuid
from typing import Any, cast
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pytest
@ -2053,22 +2053,37 @@ class TestSetupVariablePool:
workflow = self._make_workflow()
# Act
with patch("services.workflow_service.VariablePool") as MockPool:
with (
patch("services.workflow_service.VariablePool") as MockPool,
patch("services.workflow_service.build_system_variables") as mock_build_system_variables,
patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables,
patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool,
patch("services.workflow_service.add_node_inputs_to_pool") as mock_add_node_inputs_to_pool,
):
_setup_variable_pool(
query="hello",
files=[],
user_id="u-1",
user_inputs={"k": "v"},
workflow=workflow,
node_id="start-node",
node_type=BuiltinNodeTypes.START,
conversation_id="conv-1",
conversation_variables=[],
)
# Assert — VariablePool should be called with a SystemVariable (non-default)
MockPool.assert_called_once()
call_kwargs = MockPool.call_args.kwargs
assert call_kwargs["user_inputs"] == {"k": "v"}
# Assert — start nodes should build bootstrap variables and attach node inputs.
MockPool.assert_called_once_with()
mock_build_system_variables.assert_called_once()
mock_add_variables_to_pool.assert_called_once_with(
MockPool.return_value,
mock_build_bootstrap_variables.return_value,
)
mock_add_node_inputs_to_pool.assert_called_once_with(
MockPool.return_value,
node_id="start-node",
inputs={"k": "v"},
)
def test_setup_variable_pool_should_use_default_system_variables_for_non_start_node(
self,
@ -2079,7 +2094,10 @@ class TestSetupVariablePool:
# Act
with (
patch("services.workflow_service.VariablePool") as MockPool,
patch("services.workflow_service.SystemVariable.default") as mock_default,
patch("services.workflow_service.default_system_variables") as mock_default_system_variables,
patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables,
patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool,
patch("services.workflow_service.add_node_inputs_to_pool") as mock_add_node_inputs_to_pool,
):
_setup_variable_pool(
query="",
@ -2087,14 +2105,20 @@ class TestSetupVariablePool:
user_id="u-1",
user_inputs={},
workflow=workflow,
node_id="llm-node",
node_type=BuiltinNodeTypes.LLM, # not a start/trigger node
conversation_id="conv-1",
conversation_variables=[],
)
# Assert — SystemVariable.default() should be used for non-start nodes
mock_default.assert_called_once()
MockPool.assert_called_once()
# Assert — default system variables should be used and node inputs should not be added.
mock_default_system_variables.assert_called_once()
MockPool.assert_called_once_with()
mock_add_variables_to_pool.assert_called_once_with(
MockPool.return_value,
mock_build_bootstrap_variables.return_value,
)
mock_add_node_inputs_to_pool.assert_not_called()
def test_setup_variable_pool_should_set_chatflow_specifics_for_non_workflow_type(
self,
@ -2106,20 +2130,31 @@ class TestSetupVariablePool:
workflow = self._make_workflow(workflow_type=WorkflowType.CHAT.value)
# Act
with patch("services.workflow_service.VariablePool") as MockPool:
with (
patch("services.workflow_service.VariablePool") as MockPool,
patch("services.workflow_service.build_system_variables") as mock_build_system_variables,
patch("services.workflow_service.build_bootstrap_variables"),
patch("services.workflow_service.add_variables_to_pool"),
patch("services.workflow_service.add_node_inputs_to_pool"),
):
_setup_variable_pool(
query="what is AI?",
files=[],
user_id="u-1",
user_inputs={},
workflow=workflow,
node_id="start-node",
node_type=BuiltinNodeTypes.START,
conversation_id="conv-abc",
conversation_variables=[],
)
# Assert — we just verify VariablePool was called (chatflow path executed)
MockPool.assert_called_once()
# Assert — chatflow system variables should include query, conversation_id and dialogue_count.
MockPool.assert_called_once_with()
system_variable_values = mock_build_system_variables.call_args.args[0]
assert system_variable_values["query"] == "what is AI?"
assert system_variable_values["conversation_id"] == "conv-abc"
assert system_variable_values["dialogue_count"] == 1
class TestRebuildSingleFile:
@ -2142,7 +2177,7 @@ class TestRebuildSingleFile:
# Assert
assert result is mock_file
mock_build.assert_called_once_with(mapping=value, tenant_id=tenant_id)
mock_build.assert_called_once_with(mapping=value, tenant_id=tenant_id, access_controller=ANY)
def test_rebuild_single_file_should_raise_when_file_value_not_dict(
self,
@ -2165,7 +2200,7 @@ class TestRebuildSingleFile:
# Assert
assert result is mock_files
mock_build.assert_called_once_with(mappings=value, tenant_id=tenant_id)
mock_build.assert_called_once_with(mappings=value, tenant_id=tenant_id, access_controller=ANY)
def test_rebuild_single_file_should_raise_when_file_list_value_not_list(
self,
@ -2279,13 +2314,12 @@ class TestWorkflowServiceResolveDeliveryMethod:
# Arrange
method_a = self._make_method("method-1")
method_b = self._make_method("method-2")
node_data = MagicMock()
node_data.delivery_methods = [method_a, method_b]
# Act
result = WorkflowService._resolve_human_input_delivery_method(
node_data=node_data, delivery_method_id="method-2"
)
with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[method_a, method_b]):
result = WorkflowService._resolve_human_input_delivery_method(
node_data=MagicMock(), delivery_method_id="method-2"
)
# Assert
assert result is method_b
@ -2293,26 +2327,22 @@ class TestWorkflowServiceResolveDeliveryMethod:
def test_resolve_delivery_method_should_return_none_when_no_match(self) -> None:
# Arrange
method_a = self._make_method("method-1")
node_data = MagicMock()
node_data.delivery_methods = [method_a]
# Act
result = WorkflowService._resolve_human_input_delivery_method(
node_data=node_data, delivery_method_id="does-not-exist"
)
with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[method_a]):
result = WorkflowService._resolve_human_input_delivery_method(
node_data=MagicMock(), delivery_method_id="does-not-exist"
)
# Assert
assert result is None
def test_resolve_delivery_method_should_return_none_for_empty_methods(self) -> None:
# Arrange
node_data = MagicMock()
node_data.delivery_methods = []
# Act
result = WorkflowService._resolve_human_input_delivery_method(
node_data=node_data, delivery_method_id="method-1"
)
with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[]):
result = WorkflowService._resolve_human_input_delivery_method(
node_data=MagicMock(), delivery_method_id="method-1"
)
# Assert
assert result is None
@ -2435,6 +2465,9 @@ class TestWorkflowServiceDraftExecution:
patch("services.workflow_service.Session"),
patch("services.workflow_service.WorkflowDraftVariableService"),
patch("services.workflow_service.VariablePool") as mock_pool_cls,
patch("services.workflow_service.default_system_variables") as mock_default_system_variables,
patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables,
patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool,
patch("services.workflow_service.DraftVarLoader"),
patch("services.workflow_service.WorkflowEntry.single_step_run") as mock_run,
patch("services.workflow_service.DifyCoreRepositoryFactory"),
@ -2475,10 +2508,16 @@ class TestWorkflowServiceDraftExecution:
)
# Assert
# For non-start nodes, VariablePool should be initialized with environment_variables
mock_pool_cls.assert_called_once()
args, kwargs = mock_pool_cls.call_args
assert "environment_variables" in kwargs
# For non-start nodes, bootstrap variables should be loaded into an empty pool.
mock_pool_cls.assert_called_once_with()
mock_default_system_variables.assert_called_once()
mock_build_bootstrap_variables.assert_called_once_with(
system_variables=mock_default_system_variables.return_value,
environment_variables=draft_workflow.environment_variables,
)
mock_add_variables_to_pool.assert_called_once_with(
mock_pool_cls.return_value, mock_build_bootstrap_variables.return_value
)
# ===========================================================================
@ -2588,7 +2627,7 @@ class TestWorkflowServiceHumanInputOperations:
patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT),
patch("services.workflow_service.HumanInputNodeData.model_validate"),
patch.object(service, "_resolve_human_input_delivery_method") as mock_resolve,
patch("services.workflow_service.apply_debug_email_recipient"),
patch("services.workflow_service.apply_dify_debug_email_recipient"),
patch.object(service, "_build_human_input_variable_pool"),
patch.object(service, "_build_human_input_node"),
patch.object(service, "_create_human_input_delivery_test_form", return_value=("form-1", [])),
@ -2730,13 +2769,15 @@ class TestWorkflowServiceFreeNodeExecution:
variable_pool = MagicMock()
with (
patch("services.workflow_service.GraphInitParams"),
patch("services.workflow_service.GraphInitParams") as mock_graph_init_params,
patch("services.workflow_service.GraphRuntimeState"),
patch("services.workflow_service.build_dify_run_context"),
patch("services.workflow_service.DifyHumanInputNodeRuntime") as mock_runtime_cls,
patch("services.workflow_service.HumanInputNode") as mock_node_cls,
patch("services.workflow_service.HumanInputFormRepositoryImpl"),
):
node = service._build_human_input_node(
workflow=workflow, account=account, node_config=node_config, variable_pool=variable_pool
)
assert node == mock_node_cls.return_value
mock_node_cls.assert_called_once()
mock_runtime_cls.assert_called_once_with(mock_graph_init_params.return_value.run_context)