fix: test

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2026-03-23 14:50:01 +08:00
parent 93e8d856d8
commit 3b53f01378
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
10 changed files with 151 additions and 113 deletions

View File

@ -3,7 +3,6 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from flask import Flask, request from flask import Flask, request
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.local import LocalProxy
from controllers.console.app.error import ( from controllers.console.app.error import (
ProviderModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError,
@ -76,6 +75,7 @@ def setup_test_context(
patch("extensions.ext_database.db") as mock_db, patch("extensions.ext_database.db") as mock_db,
patch("controllers.console.app.wraps.db", mock_db), patch("controllers.console.app.wraps.db", mock_db),
patch("controllers.console.wraps.db", mock_db), patch("controllers.console.wraps.db", mock_db),
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),
patch("controllers.console.app.message.db", mock_db), patch("controllers.console.app.message.db", mock_db),
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
@ -99,14 +99,12 @@ def setup_test_context(
mock_db.data_query = data_query_mock mock_db.data_query = data_query_mock
# Let the caller override the stat db query logic # Let the caller override the stat db query logic
proxy_mock = LocalProxy(lambda: mock_account)
query_string = "&".join([f"{k}={v}" for k, v in (qs or {}).items()]) query_string = "&".join([f"{k}={v}" for k, v in (qs or {}).items()])
full_path = f"{route_path}?{query_string}" if qs else route_path full_path = f"{route_path}?{query_string}" if qs else route_path
with ( with (
patch("libs.login.current_user", proxy_mock), patch("libs.login._get_user", return_value=mock_account),
patch("flask_login.current_user", proxy_mock), patch("flask_login.current_user", mock_account),
patch("controllers.console.app.message.attach_message_extra_contents", return_value=None), patch("controllers.console.app.message.attach_message_extra_contents", return_value=None),
): ):
with test_app.test_request_context(full_path, method=method, json=payload): with test_app.test_request_context(full_path, method=method, json=payload):

View File

@ -3,7 +3,6 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from flask import Flask, request from flask import Flask, request
from werkzeug.local import LocalProxy
from controllers.console.app.statistic import ( from controllers.console.app.statistic import (
AverageResponseTimeStatistic, AverageResponseTimeStatistic,
@ -81,9 +80,7 @@ def setup_test_context(
mock_query.where.return_value.where.return_value.first.return_value = mock_app_model mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
mock_db_wraps.session.query.return_value = mock_query mock_db_wraps.session.query.return_value = mock_query
proxy_mock = LocalProxy(lambda: mock_account) with patch("libs.login._get_user", return_value=mock_account), patch("flask_login.current_user", mock_account):
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
with test_app.test_request_context(route_path, method="GET"): with test_app.test_request_context(route_path, method="GET"):
request.view_args = {"app_id": "app_123"} request.view_args = {"app_id": "app_123"}
api_instance = endpoint_class() api_instance = endpoint_class()

View File

@ -2,7 +2,6 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from flask import Flask, request from flask import Flask, request
from werkzeug.local import LocalProxy
from controllers.console.app.error import DraftWorkflowNotExist from controllers.console.app.error import DraftWorkflowNotExist
from controllers.console.app.workflow_draft_variable import ( from controllers.console.app.workflow_draft_variable import (
@ -60,6 +59,7 @@ def setup_test_context(test_app, endpoint_class, route_path, method, mock_accoun
with ( with (
patch("controllers.console.app.wraps.db") as mock_db_wraps, patch("controllers.console.app.wraps.db") as mock_db_wraps,
patch("controllers.console.wraps.db", mock_db_wraps), patch("controllers.console.wraps.db", mock_db_wraps),
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),
patch("controllers.console.app.workflow_draft_variable.db"), patch("controllers.console.app.workflow_draft_variable.db"),
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
@ -71,9 +71,7 @@ def setup_test_context(test_app, endpoint_class, route_path, method, mock_accoun
mock_query.where.return_value.where.return_value.first.return_value = mock_app_model mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
mock_db_wraps.session.query.return_value = mock_query mock_db_wraps.session.query.return_value = mock_query
proxy_mock = LocalProxy(lambda: mock_account) with patch("libs.login._get_user", return_value=mock_account), patch("flask_login.current_user", mock_account):
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
with test_app.test_request_context(route_path, method=method, json=payload): with test_app.test_request_context(route_path, method=method, json=payload):
request.view_args = {"app_id": "app_123"} request.view_args = {"app_id": "app_123"}
# extract node_id or variable_id from path manually since view_args overrides # extract node_id or variable_id from path manually since view_args overrides
@ -101,6 +99,7 @@ class TestWorkflowDraftVariableEndpoints:
mock_var = MagicMock() mock_var = MagicMock()
mock_var.app_id = "app_123" mock_var.app_id = "app_123"
mock_var.user_id = "user_123"
mock_var.id = "var_123" mock_var.id = "var_123"
mock_var.name = "test_var" mock_var.name = "test_var"
mock_var.description = "" mock_var.description = ""

View File

@ -49,9 +49,7 @@ class TestApiKeyAuthDataSource:
), ),
): ):
with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"): with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
proxy_mock = MagicMock() with patch("libs.login._get_user", return_value=mock_account):
proxy_mock._get_current_object.return_value = mock_account
with patch("libs.login.current_user", proxy_mock):
api_instance = ApiKeyAuthDataSource() api_instance = ApiKeyAuthDataSource()
response = api_instance.get() response = api_instance.get()
@ -81,9 +79,7 @@ class TestApiKeyAuthDataSource:
), ),
): ):
with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"): with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
proxy_mock = MagicMock() with patch("libs.login._get_user", return_value=mock_account):
proxy_mock._get_current_object.return_value = mock_account
with patch("libs.login.current_user", proxy_mock):
api_instance = ApiKeyAuthDataSource() api_instance = ApiKeyAuthDataSource()
response = api_instance.get() response = api_instance.get()
@ -124,9 +120,10 @@ class TestApiKeyAuthDataSourceBinding:
method="POST", method="POST",
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
): ):
proxy_mock = MagicMock() with (
proxy_mock._get_current_object.return_value = mock_account patch("libs.login._get_user", return_value=mock_account),
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): patch("flask_login.current_user", mock_account),
):
api_instance = ApiKeyAuthDataSourceBinding() api_instance = ApiKeyAuthDataSourceBinding()
response = api_instance.post() response = api_instance.post()
@ -162,9 +159,10 @@ class TestApiKeyAuthDataSourceBinding:
method="POST", method="POST",
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
): ):
proxy_mock = MagicMock() with (
proxy_mock._get_current_object.return_value = mock_account patch("libs.login._get_user", return_value=mock_account),
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): patch("flask_login.current_user", mock_account),
):
api_instance = ApiKeyAuthDataSourceBinding() api_instance = ApiKeyAuthDataSourceBinding()
with pytest.raises(ApiKeyAuthFailedError, match="Invalid structure"): with pytest.raises(ApiKeyAuthFailedError, match="Invalid structure"):
api_instance.post() api_instance.post()
@ -198,9 +196,10 @@ class TestApiKeyAuthDataSourceBindingDelete:
), ),
): ):
with app.test_request_context("/console/api/api-key-auth/data-source/binding_123", method="DELETE"): with app.test_request_context("/console/api/api-key-auth/data-source/binding_123", method="DELETE"):
proxy_mock = MagicMock() with (
proxy_mock._get_current_object.return_value = mock_account patch("libs.login._get_user", return_value=mock_account),
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): patch("flask_login.current_user", mock_account),
):
api_instance = ApiKeyAuthDataSourceBindingDelete() api_instance = ApiKeyAuthDataSourceBindingDelete()
response = api_instance.delete("binding_123") response = api_instance.delete("binding_123")

View File

@ -2,7 +2,6 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from flask import Flask from flask import Flask
from werkzeug.local import LocalProxy
from controllers.console.auth.data_source_oauth import ( from controllers.console.auth.data_source_oauth import (
OAuthDataSource, OAuthDataSource,
@ -21,12 +20,12 @@ class TestOAuthDataSource:
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers") @patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
@patch("flask_login.current_user") @patch("flask_login.current_user")
@patch("libs.login.current_user") @patch("libs.login._get_user")
@patch("libs.login.check_csrf_token") @patch("libs.login.check_csrf_token")
@patch("controllers.console.wraps.db") @patch("controllers.console.wraps.db")
@patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None) @patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None)
def test_get_oauth_url_successful( def test_get_oauth_url_successful(
self, mock_db, mock_csrf, mock_libs_user, mock_flask_user, mock_get_providers, app self, mock_db, mock_csrf, mock_get_user, mock_flask_user, mock_get_providers, app
): ):
mock_oauth_provider = MagicMock() mock_oauth_provider = MagicMock()
mock_oauth_provider.get_authorization_url.return_value = "http://oauth.provider/auth" mock_oauth_provider.get_authorization_url.return_value = "http://oauth.provider/auth"
@ -39,16 +38,14 @@ class TestOAuthDataSource:
mock_account.status = AccountStatus.ACTIVE mock_account.status = AccountStatus.ACTIVE
mock_account.is_admin_or_owner = True mock_account.is_admin_or_owner = True
mock_account.current_tenant.current_role = "owner" mock_account.current_tenant.current_role = "owner"
mock_libs_user.return_value = mock_account mock_get_user.return_value = mock_account
mock_flask_user.return_value = mock_account mock_flask_user.return_value = mock_account
# also patch current_account_with_tenant # also patch current_account_with_tenant
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
with app.test_request_context("/console/api/oauth/data-source/notion", method="GET"): with app.test_request_context("/console/api/oauth/data-source/notion", method="GET"):
proxy_mock = LocalProxy(lambda: mock_account) api_instance = OAuthDataSource()
with patch("libs.login.current_user", proxy_mock): response = api_instance.get("notion")
api_instance = OAuthDataSource()
response = api_instance.get("notion")
assert response[0]["data"] == "http://oauth.provider/auth" assert response[0]["data"] == "http://oauth.provider/auth"
assert response[1] == 200 assert response[1] == 200
@ -71,8 +68,7 @@ class TestOAuthDataSource:
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
with app.test_request_context("/console/api/oauth/data-source/unknown_provider", method="GET"): with app.test_request_context("/console/api/oauth/data-source/unknown_provider", method="GET"):
proxy_mock = LocalProxy(lambda: mock_account) with patch("libs.login._get_user", return_value=mock_account):
with patch("libs.login.current_user", proxy_mock):
api_instance = OAuthDataSource() api_instance = OAuthDataSource()
response = api_instance.get("unknown_provider") response = api_instance.get("unknown_provider")
@ -181,8 +177,7 @@ class TestOAuthDataSourceSync:
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
with app.test_request_context("/console/api/oauth/data-source/notion/binding_123/sync", method="GET"): with app.test_request_context("/console/api/oauth/data-source/notion/binding_123/sync", method="GET"):
proxy_mock = LocalProxy(lambda: mock_account) with patch("libs.login._get_user", return_value=mock_account):
with patch("libs.login.current_user", proxy_mock):
api_instance = OAuthDataSourceSync() api_instance = OAuthDataSourceSync()
# The route pattern uses <uuid:binding_id>, so we just pass a string for unit testing # The route pattern uses <uuid:binding_id>, so we just pass a string for unit testing
response = api_instance.get("notion", "binding_123") response = api_instance.get("notion", "binding_123")

View File

@ -117,7 +117,7 @@ class TestOAuthServerUserAuthorizeApi:
mock_sign.return_value = "auth_code_123" mock_sign.return_value = "auth_code_123"
with app.test_request_context("/oauth/provider/authorize", method="POST", json={"client_id": "test_client_id"}): with app.test_request_context("/oauth/provider/authorize", method="POST", json={"client_id": "test_client_id"}):
with patch("libs.login.current_user", mock_account): with patch("libs.login._get_user", return_value=mock_account):
api_instance = OAuthServerUserAuthorizeApi() api_instance = OAuthServerUserAuthorizeApi()
response = api_instance.post() response = api_instance.post()

View File

@ -9,9 +9,10 @@ from typing import Any
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from dify_graph.repositories.human_input_form_repository import FormCreateParams, FormNotFoundError
from core.repositories.human_input_repository import ( from core.repositories.human_input_repository import (
FormCreateParams,
FormNotFoundError,
HumanInputFormRecord, HumanInputFormRecord,
HumanInputFormRepositoryImpl, HumanInputFormRepositoryImpl,
HumanInputFormSubmissionRepository, HumanInputFormSubmissionRepository,
@ -20,16 +21,15 @@ from core.repositories.human_input_repository import (
_InvalidTimeoutStatusError, _InvalidTimeoutStatusError,
_WorkspaceMemberInfo, _WorkspaceMemberInfo,
) )
from dify_graph.nodes.human_input.entities import ( from core.workflow.human_input_compat import (
EmailDeliveryConfig, EmailDeliveryConfig,
EmailDeliveryMethod, EmailDeliveryMethod,
EmailRecipients, EmailRecipients,
ExternalRecipient, ExternalRecipient,
HumanInputNodeData,
MemberRecipient, MemberRecipient,
UserAction,
WebAppDeliveryMethod, WebAppDeliveryMethod,
) )
from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction
from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from models.human_input import HumanInputFormRecipient, RecipientType from models.human_input import HumanInputFormRecipient, RecipientType
@ -212,7 +212,7 @@ def test_recipient_entity_id_and_token_success() -> None:
assert entity.token == "tok" assert entity.token == "tok"
def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> None: def test_form_entity_submission_token_prefers_console_then_webapp_then_none() -> None:
form = _DummyForm( form = _DummyForm(
id="f1", id="f1",
workflow_run_id="run", workflow_run_id="run",
@ -229,13 +229,13 @@ def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> No
) )
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp, console]) # type: ignore[arg-type] entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp, console]) # type: ignore[arg-type]
assert entity.web_app_token == "ctok" assert entity.submission_token == "ctok"
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp]) # type: ignore[arg-type] entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp]) # type: ignore[arg-type]
assert entity.web_app_token == "wtok" assert entity.submission_token == "wtok"
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type] entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type]
assert entity.web_app_token is None assert entity.submission_token is None
def test_form_entity_submitted_data_parsed() -> None: def test_form_entity_submitted_data_parsed() -> None:
@ -364,8 +364,8 @@ def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch:
method = EmailDeliveryMethod( method = EmailDeliveryMethod(
config=EmailDeliveryConfig( config=EmailDeliveryConfig(
recipients=EmailRecipients( recipients=EmailRecipients(
whole_workspace=False, include_bound_group=False,
items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")], items=[MemberRecipient(reference_id="u1"), ExternalRecipient(email="e@example.com")],
), ),
subject="s", subject="s",
body="b", body="b",
@ -388,7 +388,7 @@ def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatc
session=MagicMock(), session=MagicMock(),
form_id="f", form_id="f",
delivery_id="d", delivery_id="d",
recipients_config=EmailRecipients(whole_workspace=True, items=[ExternalRecipient(email="e@example.com")]), recipients_config=EmailRecipients(include_bound_group=True, items=[ExternalRecipient(email="e@example.com")]),
) )
assert recipients == ["ok"] assert recipients == ["ok"]
@ -407,8 +407,8 @@ def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(m
form_id="f", form_id="f",
delivery_id="d", delivery_id="d",
recipients_config=EmailRecipients( recipients_config=EmailRecipients(
whole_workspace=False, include_bound_group=False,
items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")], items=[MemberRecipient(reference_id="u1"), ExternalRecipient(email="e@example.com")],
), ),
) )
assert recipients == ["ok"] assert recipients == ["ok"]
@ -416,8 +416,8 @@ def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(m
def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
_patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None])) _patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None]))
repo = HumanInputFormRepositoryImpl(tenant_id="tenant") repo = HumanInputFormRepositoryImpl(tenant_id="tenant", workflow_execution_id="run")
assert repo.get_form("run", "node") is None assert repo.get_form("node") is None
form = _DummyForm( form = _DummyForm(
id="f1", id="f1",
@ -437,8 +437,8 @@ def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.Monke
) )
session = _FakeSession(scalars_results=[form, [recipient]]) session = _FakeSession(scalars_results=[form, [recipient]])
_patch_session_factory(monkeypatch, session) _patch_session_factory(monkeypatch, session)
repo = HumanInputFormRepositoryImpl(tenant_id="tenant") repo = HumanInputFormRepositoryImpl(tenant_id="tenant", workflow_execution_id="run")
entity = repo.get_form("run", "node") entity = repo.get_form("node")
assert entity is not None assert entity is not None
assert entity.id == "f1" assert entity.id == "f1"
assert entity.recipients[0].id == "r1" assert entity.recipients[0].id == "r1"
@ -454,7 +454,13 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M
session = _FakeSession() session = _FakeSession()
_patch_session_factory(monkeypatch, session) _patch_session_factory(monkeypatch, session)
repo = HumanInputFormRepositoryImpl(tenant_id="tenant") repo = HumanInputFormRepositoryImpl(
tenant_id="tenant",
app_id="app",
workflow_execution_id="run",
invoke_source="debugger",
submission_actor_id="acc-1",
)
form_config = HumanInputNodeData( form_config = HumanInputNodeData(
title="Title", title="Title",
@ -464,8 +470,7 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M
user_actions=[UserAction(id="submit", title="Submit")], user_actions=[UserAction(id="submit", title="Submit")],
) )
params = FormCreateParams( params = FormCreateParams(
app_id="app", workflow_execution_id=None,
workflow_execution_id="run",
node_id="node", node_id="node",
form_config=form_config, form_config=form_config,
rendered_content="<p>hello</p>", rendered_content="<p>hello</p>",
@ -473,16 +478,13 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M
display_in_ui=True, display_in_ui=True,
resolved_default_values={}, resolved_default_values={},
form_kind=HumanInputFormKind.RUNTIME, form_kind=HumanInputFormKind.RUNTIME,
console_recipient_required=True,
console_creator_account_id="acc-1",
backstage_recipient_required=True,
) )
entity = repo.create_form(params) entity = repo.create_form(params)
assert entity.id == "form-id" assert entity.id == "form-id"
assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout) assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout)
# Console token should take precedence when console recipient is present. # Console token should take precedence when console recipient is present.
assert entity.web_app_token == "token-console" assert entity.submission_token == "token-console"
assert len(entity.recipients) == 3 assert len(entity.recipients) == 3

View File

@ -10,12 +10,12 @@ from unittest.mock import MagicMock, Mock
import psycopg2.errors import psycopg2.errors
import pytest import pytest
from dify_graph.repositories.workflow_node_execution_repository import OrderConfig
from sqlalchemy import Engine, create_engine from sqlalchemy import Engine, create_engine
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from configs import dify_config from configs import dify_config
from core.repositories.factory import OrderConfig
from core.repositories.sqlalchemy_workflow_node_execution_repository import ( from core.repositories.sqlalchemy_workflow_node_execution_repository import (
SQLAlchemyWorkflowNodeExecutionRepository, SQLAlchemyWorkflowNodeExecutionRepository,
_deterministic_json_dump, _deterministic_json_dump,
@ -25,7 +25,7 @@ from core.repositories.sqlalchemy_workflow_node_execution_repository import (
) )
from dify_graph.entities import WorkflowNodeExecution from dify_graph.entities import WorkflowNodeExecution
from dify_graph.enums import ( from dify_graph.enums import (
NodeType, BuiltinNodeTypes,
WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus, WorkflowNodeExecutionStatus,
) )
@ -67,7 +67,7 @@ def _execution(
index=1, index=1,
predecessor_node_id=None, predecessor_node_id=None,
node_id="node-id", node_id="node-id",
node_type=NodeType.LLM, node_type=BuiltinNodeTypes.LLM,
title="Title", title="Title",
inputs=inputs, inputs=inputs,
outputs=outputs, outputs=outputs,
@ -387,7 +387,7 @@ def test_to_domain_model_loads_offloaded_files(monkeypatch: pytest.MonkeyPatch)
db_model.index = 1 db_model.index = 1
db_model.predecessor_node_id = None db_model.predecessor_node_id = None
db_model.node_id = "node" db_model.node_id = "node"
db_model.node_type = NodeType.LLM db_model.node_type = BuiltinNodeTypes.LLM
db_model.title = "t" db_model.title = "t"
db_model.inputs = json.dumps({"trunc": "i"}) db_model.inputs = json.dumps({"trunc": "i"})
db_model.process_data = json.dumps({"trunc": "p"}) db_model.process_data = json.dumps({"trunc": "p"})
@ -441,7 +441,7 @@ def test_to_domain_model_returns_early_when_no_offload_data(monkeypatch: pytest.
db_model.index = 1 db_model.index = 1
db_model.predecessor_node_id = None db_model.predecessor_node_id = None
db_model.node_id = "node" db_model.node_id = "node"
db_model.node_type = NodeType.LLM db_model.node_type = BuiltinNodeTypes.LLM
db_model.title = "t" db_model.title = "t"
db_model.inputs = json.dumps({"i": 1}) db_model.inputs = json.dumps({"i": 1})
db_model.process_data = json.dumps({"p": 2}) db_model.process_data = json.dumps({"p": 2})
@ -768,5 +768,5 @@ def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) ->
lambda max_workers: FakeExecutor(), lambda max_workers: FakeExecutor(),
) )
result = repo.get_by_workflow_run("run", order_config=None) result = repo.get_by_workflow_execution("run", order_config=None)
assert result == ["domain:db1", "domain:db2"] assert result == ["domain:db1", "domain:db2"]

View File

@ -69,9 +69,10 @@ def _load_balancing_model_config(**kwargs: Any) -> LoadBalancingModelConfig:
def service(mocker: MockerFixture) -> ModelLoadBalancingService: def service(mocker: MockerFixture) -> ModelLoadBalancingService:
# Arrange # Arrange
provider_manager = MagicMock() provider_manager = MagicMock()
mocker.patch("services.model_load_balancing_service.ProviderManager", return_value=provider_manager) mocker.patch("services.model_load_balancing_service.create_plugin_provider_manager", return_value=provider_manager)
svc = ModelLoadBalancingService() svc = ModelLoadBalancingService()
svc.provider_manager = provider_manager svc.provider_manager = provider_manager
svc._get_provider_manager = lambda _tenant_id: provider_manager # type: ignore[method-assign]
return svc return svc
@ -708,7 +709,10 @@ def test_custom_credentials_validate_should_handle_invalid_original_json_and_val
load_balancing_model_config = _load_balancing_model_config(encrypted_config="not-json") load_balancing_model_config = _load_balancing_model_config(encrypted_config="not-json")
mock_factory = MagicMock() mock_factory = MagicMock()
mock_factory.model_credentials_validate.return_value = {"api_key": "validated"} mock_factory.model_credentials_validate.return_value = {"api_key": "validated"}
mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory) mocker.patch(
"services.model_load_balancing_service.create_plugin_model_provider_factory",
return_value=mock_factory,
)
mock_encrypt = mocker.patch( mock_encrypt = mocker.patch(
"services.model_load_balancing_service.encrypter.encrypt_token", "services.model_load_balancing_service.encrypter.encrypt_token",
side_effect=lambda tenant_id, value: f"enc:{value}", side_effect=lambda tenant_id, value: f"enc:{value}",
@ -740,7 +744,10 @@ def test_custom_credentials_validate_should_validate_with_provider_schema_when_m
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
mock_factory = MagicMock() mock_factory = MagicMock()
mock_factory.provider_credentials_validate.return_value = {"api_key": "provider-validated"} mock_factory.provider_credentials_validate.return_value = {"api_key": "provider-validated"}
mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory) mocker.patch(
"services.model_load_balancing_service.create_plugin_model_provider_factory",
return_value=mock_factory,
)
mocker.patch( mocker.patch(
"services.model_load_balancing_service.encrypter.encrypt_token", "services.model_load_balancing_service.encrypter.encrypt_token",
side_effect=lambda tenant_id, value: f"enc:{value}", side_effect=lambda tenant_id, value: f"enc:{value}",

View File

@ -12,7 +12,7 @@ This test suite covers:
import json import json
import uuid import uuid
from typing import Any, cast from typing import Any, cast
from unittest.mock import MagicMock, patch from unittest.mock import ANY, MagicMock, patch
import pytest import pytest
@ -2053,22 +2053,37 @@ class TestSetupVariablePool:
workflow = self._make_workflow() workflow = self._make_workflow()
# Act # Act
with patch("services.workflow_service.VariablePool") as MockPool: with (
patch("services.workflow_service.VariablePool") as MockPool,
patch("services.workflow_service.build_system_variables") as mock_build_system_variables,
patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables,
patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool,
patch("services.workflow_service.add_node_inputs_to_pool") as mock_add_node_inputs_to_pool,
):
_setup_variable_pool( _setup_variable_pool(
query="hello", query="hello",
files=[], files=[],
user_id="u-1", user_id="u-1",
user_inputs={"k": "v"}, user_inputs={"k": "v"},
workflow=workflow, workflow=workflow,
node_id="start-node",
node_type=BuiltinNodeTypes.START, node_type=BuiltinNodeTypes.START,
conversation_id="conv-1", conversation_id="conv-1",
conversation_variables=[], conversation_variables=[],
) )
# Assert — VariablePool should be called with a SystemVariable (non-default) # Assert — start nodes should build bootstrap variables and attach node inputs.
MockPool.assert_called_once() MockPool.assert_called_once_with()
call_kwargs = MockPool.call_args.kwargs mock_build_system_variables.assert_called_once()
assert call_kwargs["user_inputs"] == {"k": "v"} mock_add_variables_to_pool.assert_called_once_with(
MockPool.return_value,
mock_build_bootstrap_variables.return_value,
)
mock_add_node_inputs_to_pool.assert_called_once_with(
MockPool.return_value,
node_id="start-node",
inputs={"k": "v"},
)
def test_setup_variable_pool_should_use_default_system_variables_for_non_start_node( def test_setup_variable_pool_should_use_default_system_variables_for_non_start_node(
self, self,
@ -2079,7 +2094,10 @@ class TestSetupVariablePool:
# Act # Act
with ( with (
patch("services.workflow_service.VariablePool") as MockPool, patch("services.workflow_service.VariablePool") as MockPool,
patch("services.workflow_service.SystemVariable.default") as mock_default, patch("services.workflow_service.default_system_variables") as mock_default_system_variables,
patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables,
patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool,
patch("services.workflow_service.add_node_inputs_to_pool") as mock_add_node_inputs_to_pool,
): ):
_setup_variable_pool( _setup_variable_pool(
query="", query="",
@ -2087,14 +2105,20 @@ class TestSetupVariablePool:
user_id="u-1", user_id="u-1",
user_inputs={}, user_inputs={},
workflow=workflow, workflow=workflow,
node_id="llm-node",
node_type=BuiltinNodeTypes.LLM, # not a start/trigger node node_type=BuiltinNodeTypes.LLM, # not a start/trigger node
conversation_id="conv-1", conversation_id="conv-1",
conversation_variables=[], conversation_variables=[],
) )
# Assert — SystemVariable.default() should be used for non-start nodes # Assert — default system variables should be used and node inputs should not be added.
mock_default.assert_called_once() mock_default_system_variables.assert_called_once()
MockPool.assert_called_once() MockPool.assert_called_once_with()
mock_add_variables_to_pool.assert_called_once_with(
MockPool.return_value,
mock_build_bootstrap_variables.return_value,
)
mock_add_node_inputs_to_pool.assert_not_called()
def test_setup_variable_pool_should_set_chatflow_specifics_for_non_workflow_type( def test_setup_variable_pool_should_set_chatflow_specifics_for_non_workflow_type(
self, self,
@ -2106,20 +2130,31 @@ class TestSetupVariablePool:
workflow = self._make_workflow(workflow_type=WorkflowType.CHAT.value) workflow = self._make_workflow(workflow_type=WorkflowType.CHAT.value)
# Act # Act
with patch("services.workflow_service.VariablePool") as MockPool: with (
patch("services.workflow_service.VariablePool") as MockPool,
patch("services.workflow_service.build_system_variables") as mock_build_system_variables,
patch("services.workflow_service.build_bootstrap_variables"),
patch("services.workflow_service.add_variables_to_pool"),
patch("services.workflow_service.add_node_inputs_to_pool"),
):
_setup_variable_pool( _setup_variable_pool(
query="what is AI?", query="what is AI?",
files=[], files=[],
user_id="u-1", user_id="u-1",
user_inputs={}, user_inputs={},
workflow=workflow, workflow=workflow,
node_id="start-node",
node_type=BuiltinNodeTypes.START, node_type=BuiltinNodeTypes.START,
conversation_id="conv-abc", conversation_id="conv-abc",
conversation_variables=[], conversation_variables=[],
) )
# Assert — we just verify VariablePool was called (chatflow path executed) # Assert — chatflow system variables should include query, conversation_id and dialogue_count.
MockPool.assert_called_once() MockPool.assert_called_once_with()
system_variable_values = mock_build_system_variables.call_args.args[0]
assert system_variable_values["query"] == "what is AI?"
assert system_variable_values["conversation_id"] == "conv-abc"
assert system_variable_values["dialogue_count"] == 1
class TestRebuildSingleFile: class TestRebuildSingleFile:
@ -2142,7 +2177,7 @@ class TestRebuildSingleFile:
# Assert # Assert
assert result is mock_file assert result is mock_file
mock_build.assert_called_once_with(mapping=value, tenant_id=tenant_id) mock_build.assert_called_once_with(mapping=value, tenant_id=tenant_id, access_controller=ANY)
def test_rebuild_single_file_should_raise_when_file_value_not_dict( def test_rebuild_single_file_should_raise_when_file_value_not_dict(
self, self,
@ -2165,7 +2200,7 @@ class TestRebuildSingleFile:
# Assert # Assert
assert result is mock_files assert result is mock_files
mock_build.assert_called_once_with(mappings=value, tenant_id=tenant_id) mock_build.assert_called_once_with(mappings=value, tenant_id=tenant_id, access_controller=ANY)
def test_rebuild_single_file_should_raise_when_file_list_value_not_list( def test_rebuild_single_file_should_raise_when_file_list_value_not_list(
self, self,
@ -2279,13 +2314,12 @@ class TestWorkflowServiceResolveDeliveryMethod:
# Arrange # Arrange
method_a = self._make_method("method-1") method_a = self._make_method("method-1")
method_b = self._make_method("method-2") method_b = self._make_method("method-2")
node_data = MagicMock()
node_data.delivery_methods = [method_a, method_b]
# Act # Act
result = WorkflowService._resolve_human_input_delivery_method( with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[method_a, method_b]):
node_data=node_data, delivery_method_id="method-2" result = WorkflowService._resolve_human_input_delivery_method(
) node_data=MagicMock(), delivery_method_id="method-2"
)
# Assert # Assert
assert result is method_b assert result is method_b
@ -2293,26 +2327,22 @@ class TestWorkflowServiceResolveDeliveryMethod:
def test_resolve_delivery_method_should_return_none_when_no_match(self) -> None: def test_resolve_delivery_method_should_return_none_when_no_match(self) -> None:
# Arrange # Arrange
method_a = self._make_method("method-1") method_a = self._make_method("method-1")
node_data = MagicMock()
node_data.delivery_methods = [method_a]
# Act # Act
result = WorkflowService._resolve_human_input_delivery_method( with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[method_a]):
node_data=node_data, delivery_method_id="does-not-exist" result = WorkflowService._resolve_human_input_delivery_method(
) node_data=MagicMock(), delivery_method_id="does-not-exist"
)
# Assert # Assert
assert result is None assert result is None
def test_resolve_delivery_method_should_return_none_for_empty_methods(self) -> None: def test_resolve_delivery_method_should_return_none_for_empty_methods(self) -> None:
# Arrange
node_data = MagicMock()
node_data.delivery_methods = []
# Act # Act
result = WorkflowService._resolve_human_input_delivery_method( with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[]):
node_data=node_data, delivery_method_id="method-1" result = WorkflowService._resolve_human_input_delivery_method(
) node_data=MagicMock(), delivery_method_id="method-1"
)
# Assert # Assert
assert result is None assert result is None
@ -2435,6 +2465,9 @@ class TestWorkflowServiceDraftExecution:
patch("services.workflow_service.Session"), patch("services.workflow_service.Session"),
patch("services.workflow_service.WorkflowDraftVariableService"), patch("services.workflow_service.WorkflowDraftVariableService"),
patch("services.workflow_service.VariablePool") as mock_pool_cls, patch("services.workflow_service.VariablePool") as mock_pool_cls,
patch("services.workflow_service.default_system_variables") as mock_default_system_variables,
patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables,
patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool,
patch("services.workflow_service.DraftVarLoader"), patch("services.workflow_service.DraftVarLoader"),
patch("services.workflow_service.WorkflowEntry.single_step_run") as mock_run, patch("services.workflow_service.WorkflowEntry.single_step_run") as mock_run,
patch("services.workflow_service.DifyCoreRepositoryFactory"), patch("services.workflow_service.DifyCoreRepositoryFactory"),
@ -2475,10 +2508,16 @@ class TestWorkflowServiceDraftExecution:
) )
# Assert # Assert
# For non-start nodes, VariablePool should be initialized with environment_variables # For non-start nodes, bootstrap variables should be loaded into an empty pool.
mock_pool_cls.assert_called_once() mock_pool_cls.assert_called_once_with()
args, kwargs = mock_pool_cls.call_args mock_default_system_variables.assert_called_once()
assert "environment_variables" in kwargs mock_build_bootstrap_variables.assert_called_once_with(
system_variables=mock_default_system_variables.return_value,
environment_variables=draft_workflow.environment_variables,
)
mock_add_variables_to_pool.assert_called_once_with(
mock_pool_cls.return_value, mock_build_bootstrap_variables.return_value
)
# =========================================================================== # ===========================================================================
@ -2588,7 +2627,7 @@ class TestWorkflowServiceHumanInputOperations:
patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT), patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT),
patch("services.workflow_service.HumanInputNodeData.model_validate"), patch("services.workflow_service.HumanInputNodeData.model_validate"),
patch.object(service, "_resolve_human_input_delivery_method") as mock_resolve, patch.object(service, "_resolve_human_input_delivery_method") as mock_resolve,
patch("services.workflow_service.apply_debug_email_recipient"), patch("services.workflow_service.apply_dify_debug_email_recipient"),
patch.object(service, "_build_human_input_variable_pool"), patch.object(service, "_build_human_input_variable_pool"),
patch.object(service, "_build_human_input_node"), patch.object(service, "_build_human_input_node"),
patch.object(service, "_create_human_input_delivery_test_form", return_value=("form-1", [])), patch.object(service, "_create_human_input_delivery_test_form", return_value=("form-1", [])),
@ -2730,13 +2769,15 @@ class TestWorkflowServiceFreeNodeExecution:
variable_pool = MagicMock() variable_pool = MagicMock()
with ( with (
patch("services.workflow_service.GraphInitParams"), patch("services.workflow_service.GraphInitParams") as mock_graph_init_params,
patch("services.workflow_service.GraphRuntimeState"), patch("services.workflow_service.GraphRuntimeState"),
patch("services.workflow_service.build_dify_run_context"),
patch("services.workflow_service.DifyHumanInputNodeRuntime") as mock_runtime_cls,
patch("services.workflow_service.HumanInputNode") as mock_node_cls, patch("services.workflow_service.HumanInputNode") as mock_node_cls,
patch("services.workflow_service.HumanInputFormRepositoryImpl"),
): ):
node = service._build_human_input_node( node = service._build_human_input_node(
workflow=workflow, account=account, node_config=node_config, variable_pool=variable_pool workflow=workflow, account=account, node_config=node_config, variable_pool=variable_pool
) )
assert node == mock_node_cls.return_value assert node == mock_node_cls.return_value
mock_node_cls.assert_called_once() mock_node_cls.assert_called_once()
mock_runtime_cls.assert_called_once_with(mock_graph_init_params.return_value.run_context)