mirror of https://github.com/langgenius/dify.git
parent
93e8d856d8
commit
3b53f01378
|
|
@ -3,7 +3,6 @@ from unittest.mock import MagicMock, patch
|
|||
import pytest
|
||||
from flask import Flask, request
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.app.error import (
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
|
|
@ -76,6 +75,7 @@ def setup_test_context(
|
|||
patch("extensions.ext_database.db") as mock_db,
|
||||
patch("controllers.console.app.wraps.db", mock_db),
|
||||
patch("controllers.console.wraps.db", mock_db),
|
||||
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),
|
||||
patch("controllers.console.app.message.db", mock_db),
|
||||
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
|
|
@ -99,14 +99,12 @@ def setup_test_context(
|
|||
mock_db.data_query = data_query_mock
|
||||
|
||||
# Let the caller override the stat db query logic
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
|
||||
query_string = "&".join([f"{k}={v}" for k, v in (qs or {}).items()])
|
||||
full_path = f"{route_path}?{query_string}" if qs else route_path
|
||||
|
||||
with (
|
||||
patch("libs.login.current_user", proxy_mock),
|
||||
patch("flask_login.current_user", proxy_mock),
|
||||
patch("libs.login._get_user", return_value=mock_account),
|
||||
patch("flask_login.current_user", mock_account),
|
||||
patch("controllers.console.app.message.attach_message_extra_contents", return_value=None),
|
||||
):
|
||||
with test_app.test_request_context(full_path, method=method, json=payload):
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import pytest
|
||||
from flask import Flask, request
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.app.statistic import (
|
||||
AverageResponseTimeStatistic,
|
||||
|
|
@ -81,9 +80,7 @@ def setup_test_context(
|
|||
mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
|
||||
mock_db_wraps.session.query.return_value = mock_query
|
||||
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
with patch("libs.login._get_user", return_value=mock_account), patch("flask_login.current_user", mock_account):
|
||||
with test_app.test_request_context(route_path, method="GET"):
|
||||
request.view_args = {"app_id": "app_123"}
|
||||
api_instance = endpoint_class()
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import pytest
|
||||
from flask import Flask, request
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.app.error import DraftWorkflowNotExist
|
||||
from controllers.console.app.workflow_draft_variable import (
|
||||
|
|
@ -60,6 +59,7 @@ def setup_test_context(test_app, endpoint_class, route_path, method, mock_accoun
|
|||
with (
|
||||
patch("controllers.console.app.wraps.db") as mock_db_wraps,
|
||||
patch("controllers.console.wraps.db", mock_db_wraps),
|
||||
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),
|
||||
patch("controllers.console.app.workflow_draft_variable.db"),
|
||||
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
|
|
@ -71,9 +71,7 @@ def setup_test_context(test_app, endpoint_class, route_path, method, mock_accoun
|
|||
mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
|
||||
mock_db_wraps.session.query.return_value = mock_query
|
||||
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
with patch("libs.login._get_user", return_value=mock_account), patch("flask_login.current_user", mock_account):
|
||||
with test_app.test_request_context(route_path, method=method, json=payload):
|
||||
request.view_args = {"app_id": "app_123"}
|
||||
# extract node_id or variable_id from path manually since view_args overrides
|
||||
|
|
@ -101,6 +99,7 @@ class TestWorkflowDraftVariableEndpoints:
|
|||
|
||||
mock_var = MagicMock()
|
||||
mock_var.app_id = "app_123"
|
||||
mock_var.user_id = "user_123"
|
||||
mock_var.id = "var_123"
|
||||
mock_var.name = "test_var"
|
||||
mock_var.description = ""
|
||||
|
|
|
|||
|
|
@ -49,9 +49,7 @@ class TestApiKeyAuthDataSource:
|
|||
),
|
||||
):
|
||||
with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
with patch("libs.login._get_user", return_value=mock_account):
|
||||
api_instance = ApiKeyAuthDataSource()
|
||||
response = api_instance.get()
|
||||
|
||||
|
|
@ -81,9 +79,7 @@ class TestApiKeyAuthDataSource:
|
|||
),
|
||||
):
|
||||
with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
with patch("libs.login._get_user", return_value=mock_account):
|
||||
api_instance = ApiKeyAuthDataSource()
|
||||
response = api_instance.get()
|
||||
|
||||
|
|
@ -124,9 +120,10 @@ class TestApiKeyAuthDataSourceBinding:
|
|||
method="POST",
|
||||
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
|
||||
):
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
with (
|
||||
patch("libs.login._get_user", return_value=mock_account),
|
||||
patch("flask_login.current_user", mock_account),
|
||||
):
|
||||
api_instance = ApiKeyAuthDataSourceBinding()
|
||||
response = api_instance.post()
|
||||
|
||||
|
|
@ -162,9 +159,10 @@ class TestApiKeyAuthDataSourceBinding:
|
|||
method="POST",
|
||||
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
|
||||
):
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
with (
|
||||
patch("libs.login._get_user", return_value=mock_account),
|
||||
patch("flask_login.current_user", mock_account),
|
||||
):
|
||||
api_instance = ApiKeyAuthDataSourceBinding()
|
||||
with pytest.raises(ApiKeyAuthFailedError, match="Invalid structure"):
|
||||
api_instance.post()
|
||||
|
|
@ -198,9 +196,10 @@ class TestApiKeyAuthDataSourceBindingDelete:
|
|||
),
|
||||
):
|
||||
with app.test_request_context("/console/api/api-key-auth/data-source/binding_123", method="DELETE"):
|
||||
proxy_mock = MagicMock()
|
||||
proxy_mock._get_current_object.return_value = mock_account
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
with (
|
||||
patch("libs.login._get_user", return_value=mock_account),
|
||||
patch("flask_login.current_user", mock_account),
|
||||
):
|
||||
api_instance = ApiKeyAuthDataSourceBindingDelete()
|
||||
response = api_instance.delete("binding_123")
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.auth.data_source_oauth import (
|
||||
OAuthDataSource,
|
||||
|
|
@ -21,12 +20,12 @@ class TestOAuthDataSource:
|
|||
|
||||
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
|
||||
@patch("flask_login.current_user")
|
||||
@patch("libs.login.current_user")
|
||||
@patch("libs.login._get_user")
|
||||
@patch("libs.login.check_csrf_token")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None)
|
||||
def test_get_oauth_url_successful(
|
||||
self, mock_db, mock_csrf, mock_libs_user, mock_flask_user, mock_get_providers, app
|
||||
self, mock_db, mock_csrf, mock_get_user, mock_flask_user, mock_get_providers, app
|
||||
):
|
||||
mock_oauth_provider = MagicMock()
|
||||
mock_oauth_provider.get_authorization_url.return_value = "http://oauth.provider/auth"
|
||||
|
|
@ -39,16 +38,14 @@ class TestOAuthDataSource:
|
|||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_account.is_admin_or_owner = True
|
||||
mock_account.current_tenant.current_role = "owner"
|
||||
mock_libs_user.return_value = mock_account
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_flask_user.return_value = mock_account
|
||||
|
||||
# also patch current_account_with_tenant
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion", method="GET"):
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
api_instance = OAuthDataSource()
|
||||
response = api_instance.get("notion")
|
||||
api_instance = OAuthDataSource()
|
||||
response = api_instance.get("notion")
|
||||
|
||||
assert response[0]["data"] == "http://oauth.provider/auth"
|
||||
assert response[1] == 200
|
||||
|
|
@ -71,8 +68,7 @@ class TestOAuthDataSource:
|
|||
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
|
||||
with app.test_request_context("/console/api/oauth/data-source/unknown_provider", method="GET"):
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
with patch("libs.login._get_user", return_value=mock_account):
|
||||
api_instance = OAuthDataSource()
|
||||
response = api_instance.get("unknown_provider")
|
||||
|
||||
|
|
@ -181,8 +177,7 @@ class TestOAuthDataSourceSync:
|
|||
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
|
||||
with app.test_request_context("/console/api/oauth/data-source/notion/binding_123/sync", method="GET"):
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
with patch("libs.login.current_user", proxy_mock):
|
||||
with patch("libs.login._get_user", return_value=mock_account):
|
||||
api_instance = OAuthDataSourceSync()
|
||||
# The route pattern uses <uuid:binding_id>, so we just pass a string for unit testing
|
||||
response = api_instance.get("notion", "binding_123")
|
||||
|
|
|
|||
|
|
@ -117,7 +117,7 @@ class TestOAuthServerUserAuthorizeApi:
|
|||
mock_sign.return_value = "auth_code_123"
|
||||
|
||||
with app.test_request_context("/oauth/provider/authorize", method="POST", json={"client_id": "test_client_id"}):
|
||||
with patch("libs.login.current_user", mock_account):
|
||||
with patch("libs.login._get_user", return_value=mock_account):
|
||||
api_instance = OAuthServerUserAuthorizeApi()
|
||||
response = api_instance.post()
|
||||
|
||||
|
|
|
|||
|
|
@ -9,9 +9,10 @@ from typing import Any
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from dify_graph.repositories.human_input_form_repository import FormCreateParams, FormNotFoundError
|
||||
|
||||
from core.repositories.human_input_repository import (
|
||||
FormCreateParams,
|
||||
FormNotFoundError,
|
||||
HumanInputFormRecord,
|
||||
HumanInputFormRepositoryImpl,
|
||||
HumanInputFormSubmissionRepository,
|
||||
|
|
@ -20,16 +21,15 @@ from core.repositories.human_input_repository import (
|
|||
_InvalidTimeoutStatusError,
|
||||
_WorkspaceMemberInfo,
|
||||
)
|
||||
from dify_graph.nodes.human_input.entities import (
|
||||
from core.workflow.human_input_compat import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
HumanInputNodeData,
|
||||
MemberRecipient,
|
||||
UserAction,
|
||||
WebAppDeliveryMethod,
|
||||
)
|
||||
from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction
|
||||
from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.human_input import HumanInputFormRecipient, RecipientType
|
||||
|
|
@ -212,7 +212,7 @@ def test_recipient_entity_id_and_token_success() -> None:
|
|||
assert entity.token == "tok"
|
||||
|
||||
|
||||
def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> None:
|
||||
def test_form_entity_submission_token_prefers_console_then_webapp_then_none() -> None:
|
||||
form = _DummyForm(
|
||||
id="f1",
|
||||
workflow_run_id="run",
|
||||
|
|
@ -229,13 +229,13 @@ def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> No
|
|||
)
|
||||
|
||||
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp, console]) # type: ignore[arg-type]
|
||||
assert entity.web_app_token == "ctok"
|
||||
assert entity.submission_token == "ctok"
|
||||
|
||||
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp]) # type: ignore[arg-type]
|
||||
assert entity.web_app_token == "wtok"
|
||||
assert entity.submission_token == "wtok"
|
||||
|
||||
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type]
|
||||
assert entity.web_app_token is None
|
||||
assert entity.submission_token is None
|
||||
|
||||
|
||||
def test_form_entity_submitted_data_parsed() -> None:
|
||||
|
|
@ -364,8 +364,8 @@ def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch:
|
|||
method = EmailDeliveryMethod(
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
|
||||
include_bound_group=False,
|
||||
items=[MemberRecipient(reference_id="u1"), ExternalRecipient(email="e@example.com")],
|
||||
),
|
||||
subject="s",
|
||||
body="b",
|
||||
|
|
@ -388,7 +388,7 @@ def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatc
|
|||
session=MagicMock(),
|
||||
form_id="f",
|
||||
delivery_id="d",
|
||||
recipients_config=EmailRecipients(whole_workspace=True, items=[ExternalRecipient(email="e@example.com")]),
|
||||
recipients_config=EmailRecipients(include_bound_group=True, items=[ExternalRecipient(email="e@example.com")]),
|
||||
)
|
||||
assert recipients == ["ok"]
|
||||
|
||||
|
|
@ -407,8 +407,8 @@ def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(m
|
|||
form_id="f",
|
||||
delivery_id="d",
|
||||
recipients_config=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
|
||||
include_bound_group=False,
|
||||
items=[MemberRecipient(reference_id="u1"), ExternalRecipient(email="e@example.com")],
|
||||
),
|
||||
)
|
||||
assert recipients == ["ok"]
|
||||
|
|
@ -416,8 +416,8 @@ def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(m
|
|||
|
||||
def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None]))
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
assert repo.get_form("run", "node") is None
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant", workflow_execution_id="run")
|
||||
assert repo.get_form("node") is None
|
||||
|
||||
form = _DummyForm(
|
||||
id="f1",
|
||||
|
|
@ -437,8 +437,8 @@ def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.Monke
|
|||
)
|
||||
session = _FakeSession(scalars_results=[form, [recipient]])
|
||||
_patch_session_factory(monkeypatch, session)
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
entity = repo.get_form("run", "node")
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant", workflow_execution_id="run")
|
||||
entity = repo.get_form("node")
|
||||
assert entity is not None
|
||||
assert entity.id == "f1"
|
||||
assert entity.recipients[0].id == "r1"
|
||||
|
|
@ -454,7 +454,13 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M
|
|||
|
||||
session = _FakeSession()
|
||||
_patch_session_factory(monkeypatch, session)
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
|
||||
repo = HumanInputFormRepositoryImpl(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_execution_id="run",
|
||||
invoke_source="debugger",
|
||||
submission_actor_id="acc-1",
|
||||
)
|
||||
|
||||
form_config = HumanInputNodeData(
|
||||
title="Title",
|
||||
|
|
@ -464,8 +470,7 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M
|
|||
user_actions=[UserAction(id="submit", title="Submit")],
|
||||
)
|
||||
params = FormCreateParams(
|
||||
app_id="app",
|
||||
workflow_execution_id="run",
|
||||
workflow_execution_id=None,
|
||||
node_id="node",
|
||||
form_config=form_config,
|
||||
rendered_content="<p>hello</p>",
|
||||
|
|
@ -473,16 +478,13 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M
|
|||
display_in_ui=True,
|
||||
resolved_default_values={},
|
||||
form_kind=HumanInputFormKind.RUNTIME,
|
||||
console_recipient_required=True,
|
||||
console_creator_account_id="acc-1",
|
||||
backstage_recipient_required=True,
|
||||
)
|
||||
|
||||
entity = repo.create_form(params)
|
||||
assert entity.id == "form-id"
|
||||
assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout)
|
||||
# Console token should take precedence when console recipient is present.
|
||||
assert entity.web_app_token == "token-console"
|
||||
assert entity.submission_token == "token-console"
|
||||
assert len(entity.recipients) == 3
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -10,12 +10,12 @@ from unittest.mock import MagicMock, Mock
|
|||
|
||||
import psycopg2.errors
|
||||
import pytest
|
||||
from dify_graph.repositories.workflow_node_execution_repository import OrderConfig
|
||||
from sqlalchemy import Engine, create_engine
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.repositories.factory import OrderConfig
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
|
||||
SQLAlchemyWorkflowNodeExecutionRepository,
|
||||
_deterministic_json_dump,
|
||||
|
|
@ -25,7 +25,7 @@ from core.repositories.sqlalchemy_workflow_node_execution_repository import (
|
|||
)
|
||||
from dify_graph.entities import WorkflowNodeExecution
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
BuiltinNodeTypes,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
|
|
@ -67,7 +67,7 @@ def _execution(
|
|||
index=1,
|
||||
predecessor_node_id=None,
|
||||
node_id="node-id",
|
||||
node_type=NodeType.LLM,
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
title="Title",
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
|
|
@ -387,7 +387,7 @@ def test_to_domain_model_loads_offloaded_files(monkeypatch: pytest.MonkeyPatch)
|
|||
db_model.index = 1
|
||||
db_model.predecessor_node_id = None
|
||||
db_model.node_id = "node"
|
||||
db_model.node_type = NodeType.LLM
|
||||
db_model.node_type = BuiltinNodeTypes.LLM
|
||||
db_model.title = "t"
|
||||
db_model.inputs = json.dumps({"trunc": "i"})
|
||||
db_model.process_data = json.dumps({"trunc": "p"})
|
||||
|
|
@ -441,7 +441,7 @@ def test_to_domain_model_returns_early_when_no_offload_data(monkeypatch: pytest.
|
|||
db_model.index = 1
|
||||
db_model.predecessor_node_id = None
|
||||
db_model.node_id = "node"
|
||||
db_model.node_type = NodeType.LLM
|
||||
db_model.node_type = BuiltinNodeTypes.LLM
|
||||
db_model.title = "t"
|
||||
db_model.inputs = json.dumps({"i": 1})
|
||||
db_model.process_data = json.dumps({"p": 2})
|
||||
|
|
@ -768,5 +768,5 @@ def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) ->
|
|||
lambda max_workers: FakeExecutor(),
|
||||
)
|
||||
|
||||
result = repo.get_by_workflow_run("run", order_config=None)
|
||||
result = repo.get_by_workflow_execution("run", order_config=None)
|
||||
assert result == ["domain:db1", "domain:db2"]
|
||||
|
|
|
|||
|
|
@ -69,9 +69,10 @@ def _load_balancing_model_config(**kwargs: Any) -> LoadBalancingModelConfig:
|
|||
def service(mocker: MockerFixture) -> ModelLoadBalancingService:
|
||||
# Arrange
|
||||
provider_manager = MagicMock()
|
||||
mocker.patch("services.model_load_balancing_service.ProviderManager", return_value=provider_manager)
|
||||
mocker.patch("services.model_load_balancing_service.create_plugin_provider_manager", return_value=provider_manager)
|
||||
svc = ModelLoadBalancingService()
|
||||
svc.provider_manager = provider_manager
|
||||
svc._get_provider_manager = lambda _tenant_id: provider_manager # type: ignore[method-assign]
|
||||
return svc
|
||||
|
||||
|
||||
|
|
@ -708,7 +709,10 @@ def test_custom_credentials_validate_should_handle_invalid_original_json_and_val
|
|||
load_balancing_model_config = _load_balancing_model_config(encrypted_config="not-json")
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.model_credentials_validate.return_value = {"api_key": "validated"}
|
||||
mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory)
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.create_plugin_model_provider_factory",
|
||||
return_value=mock_factory,
|
||||
)
|
||||
mock_encrypt = mocker.patch(
|
||||
"services.model_load_balancing_service.encrypter.encrypt_token",
|
||||
side_effect=lambda tenant_id, value: f"enc:{value}",
|
||||
|
|
@ -740,7 +744,10 @@ def test_custom_credentials_validate_should_validate_with_provider_schema_when_m
|
|||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.provider_credentials_validate.return_value = {"api_key": "provider-validated"}
|
||||
mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory)
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.create_plugin_model_provider_factory",
|
||||
return_value=mock_factory,
|
||||
)
|
||||
mocker.patch(
|
||||
"services.model_load_balancing_service.encrypter.encrypt_token",
|
||||
side_effect=lambda tenant_id, value: f"enc:{value}",
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ This test suite covers:
|
|||
import json
|
||||
import uuid
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -2053,22 +2053,37 @@ class TestSetupVariablePool:
|
|||
workflow = self._make_workflow()
|
||||
|
||||
# Act
|
||||
with patch("services.workflow_service.VariablePool") as MockPool:
|
||||
with (
|
||||
patch("services.workflow_service.VariablePool") as MockPool,
|
||||
patch("services.workflow_service.build_system_variables") as mock_build_system_variables,
|
||||
patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables,
|
||||
patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool,
|
||||
patch("services.workflow_service.add_node_inputs_to_pool") as mock_add_node_inputs_to_pool,
|
||||
):
|
||||
_setup_variable_pool(
|
||||
query="hello",
|
||||
files=[],
|
||||
user_id="u-1",
|
||||
user_inputs={"k": "v"},
|
||||
workflow=workflow,
|
||||
node_id="start-node",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
conversation_id="conv-1",
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
# Assert — VariablePool should be called with a SystemVariable (non-default)
|
||||
MockPool.assert_called_once()
|
||||
call_kwargs = MockPool.call_args.kwargs
|
||||
assert call_kwargs["user_inputs"] == {"k": "v"}
|
||||
# Assert — start nodes should build bootstrap variables and attach node inputs.
|
||||
MockPool.assert_called_once_with()
|
||||
mock_build_system_variables.assert_called_once()
|
||||
mock_add_variables_to_pool.assert_called_once_with(
|
||||
MockPool.return_value,
|
||||
mock_build_bootstrap_variables.return_value,
|
||||
)
|
||||
mock_add_node_inputs_to_pool.assert_called_once_with(
|
||||
MockPool.return_value,
|
||||
node_id="start-node",
|
||||
inputs={"k": "v"},
|
||||
)
|
||||
|
||||
def test_setup_variable_pool_should_use_default_system_variables_for_non_start_node(
|
||||
self,
|
||||
|
|
@ -2079,7 +2094,10 @@ class TestSetupVariablePool:
|
|||
# Act
|
||||
with (
|
||||
patch("services.workflow_service.VariablePool") as MockPool,
|
||||
patch("services.workflow_service.SystemVariable.default") as mock_default,
|
||||
patch("services.workflow_service.default_system_variables") as mock_default_system_variables,
|
||||
patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables,
|
||||
patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool,
|
||||
patch("services.workflow_service.add_node_inputs_to_pool") as mock_add_node_inputs_to_pool,
|
||||
):
|
||||
_setup_variable_pool(
|
||||
query="",
|
||||
|
|
@ -2087,14 +2105,20 @@ class TestSetupVariablePool:
|
|||
user_id="u-1",
|
||||
user_inputs={},
|
||||
workflow=workflow,
|
||||
node_id="llm-node",
|
||||
node_type=BuiltinNodeTypes.LLM, # not a start/trigger node
|
||||
conversation_id="conv-1",
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
# Assert — SystemVariable.default() should be used for non-start nodes
|
||||
mock_default.assert_called_once()
|
||||
MockPool.assert_called_once()
|
||||
# Assert — default system variables should be used and node inputs should not be added.
|
||||
mock_default_system_variables.assert_called_once()
|
||||
MockPool.assert_called_once_with()
|
||||
mock_add_variables_to_pool.assert_called_once_with(
|
||||
MockPool.return_value,
|
||||
mock_build_bootstrap_variables.return_value,
|
||||
)
|
||||
mock_add_node_inputs_to_pool.assert_not_called()
|
||||
|
||||
def test_setup_variable_pool_should_set_chatflow_specifics_for_non_workflow_type(
|
||||
self,
|
||||
|
|
@ -2106,20 +2130,31 @@ class TestSetupVariablePool:
|
|||
workflow = self._make_workflow(workflow_type=WorkflowType.CHAT.value)
|
||||
|
||||
# Act
|
||||
with patch("services.workflow_service.VariablePool") as MockPool:
|
||||
with (
|
||||
patch("services.workflow_service.VariablePool") as MockPool,
|
||||
patch("services.workflow_service.build_system_variables") as mock_build_system_variables,
|
||||
patch("services.workflow_service.build_bootstrap_variables"),
|
||||
patch("services.workflow_service.add_variables_to_pool"),
|
||||
patch("services.workflow_service.add_node_inputs_to_pool"),
|
||||
):
|
||||
_setup_variable_pool(
|
||||
query="what is AI?",
|
||||
files=[],
|
||||
user_id="u-1",
|
||||
user_inputs={},
|
||||
workflow=workflow,
|
||||
node_id="start-node",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
conversation_id="conv-abc",
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
# Assert — we just verify VariablePool was called (chatflow path executed)
|
||||
MockPool.assert_called_once()
|
||||
# Assert — chatflow system variables should include query, conversation_id and dialogue_count.
|
||||
MockPool.assert_called_once_with()
|
||||
system_variable_values = mock_build_system_variables.call_args.args[0]
|
||||
assert system_variable_values["query"] == "what is AI?"
|
||||
assert system_variable_values["conversation_id"] == "conv-abc"
|
||||
assert system_variable_values["dialogue_count"] == 1
|
||||
|
||||
|
||||
class TestRebuildSingleFile:
|
||||
|
|
@ -2142,7 +2177,7 @@ class TestRebuildSingleFile:
|
|||
|
||||
# Assert
|
||||
assert result is mock_file
|
||||
mock_build.assert_called_once_with(mapping=value, tenant_id=tenant_id)
|
||||
mock_build.assert_called_once_with(mapping=value, tenant_id=tenant_id, access_controller=ANY)
|
||||
|
||||
def test_rebuild_single_file_should_raise_when_file_value_not_dict(
|
||||
self,
|
||||
|
|
@ -2165,7 +2200,7 @@ class TestRebuildSingleFile:
|
|||
|
||||
# Assert
|
||||
assert result is mock_files
|
||||
mock_build.assert_called_once_with(mappings=value, tenant_id=tenant_id)
|
||||
mock_build.assert_called_once_with(mappings=value, tenant_id=tenant_id, access_controller=ANY)
|
||||
|
||||
def test_rebuild_single_file_should_raise_when_file_list_value_not_list(
|
||||
self,
|
||||
|
|
@ -2279,13 +2314,12 @@ class TestWorkflowServiceResolveDeliveryMethod:
|
|||
# Arrange
|
||||
method_a = self._make_method("method-1")
|
||||
method_b = self._make_method("method-2")
|
||||
node_data = MagicMock()
|
||||
node_data.delivery_methods = [method_a, method_b]
|
||||
|
||||
# Act
|
||||
result = WorkflowService._resolve_human_input_delivery_method(
|
||||
node_data=node_data, delivery_method_id="method-2"
|
||||
)
|
||||
with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[method_a, method_b]):
|
||||
result = WorkflowService._resolve_human_input_delivery_method(
|
||||
node_data=MagicMock(), delivery_method_id="method-2"
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is method_b
|
||||
|
|
@ -2293,26 +2327,22 @@ class TestWorkflowServiceResolveDeliveryMethod:
|
|||
def test_resolve_delivery_method_should_return_none_when_no_match(self) -> None:
|
||||
# Arrange
|
||||
method_a = self._make_method("method-1")
|
||||
node_data = MagicMock()
|
||||
node_data.delivery_methods = [method_a]
|
||||
|
||||
# Act
|
||||
result = WorkflowService._resolve_human_input_delivery_method(
|
||||
node_data=node_data, delivery_method_id="does-not-exist"
|
||||
)
|
||||
with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[method_a]):
|
||||
result = WorkflowService._resolve_human_input_delivery_method(
|
||||
node_data=MagicMock(), delivery_method_id="does-not-exist"
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_resolve_delivery_method_should_return_none_for_empty_methods(self) -> None:
|
||||
# Arrange
|
||||
node_data = MagicMock()
|
||||
node_data.delivery_methods = []
|
||||
|
||||
# Act
|
||||
result = WorkflowService._resolve_human_input_delivery_method(
|
||||
node_data=node_data, delivery_method_id="method-1"
|
||||
)
|
||||
with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[]):
|
||||
result = WorkflowService._resolve_human_input_delivery_method(
|
||||
node_data=MagicMock(), delivery_method_id="method-1"
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
|
@ -2435,6 +2465,9 @@ class TestWorkflowServiceDraftExecution:
|
|||
patch("services.workflow_service.Session"),
|
||||
patch("services.workflow_service.WorkflowDraftVariableService"),
|
||||
patch("services.workflow_service.VariablePool") as mock_pool_cls,
|
||||
patch("services.workflow_service.default_system_variables") as mock_default_system_variables,
|
||||
patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables,
|
||||
patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool,
|
||||
patch("services.workflow_service.DraftVarLoader"),
|
||||
patch("services.workflow_service.WorkflowEntry.single_step_run") as mock_run,
|
||||
patch("services.workflow_service.DifyCoreRepositoryFactory"),
|
||||
|
|
@ -2475,10 +2508,16 @@ class TestWorkflowServiceDraftExecution:
|
|||
)
|
||||
|
||||
# Assert
|
||||
# For non-start nodes, VariablePool should be initialized with environment_variables
|
||||
mock_pool_cls.assert_called_once()
|
||||
args, kwargs = mock_pool_cls.call_args
|
||||
assert "environment_variables" in kwargs
|
||||
# For non-start nodes, bootstrap variables should be loaded into an empty pool.
|
||||
mock_pool_cls.assert_called_once_with()
|
||||
mock_default_system_variables.assert_called_once()
|
||||
mock_build_bootstrap_variables.assert_called_once_with(
|
||||
system_variables=mock_default_system_variables.return_value,
|
||||
environment_variables=draft_workflow.environment_variables,
|
||||
)
|
||||
mock_add_variables_to_pool.assert_called_once_with(
|
||||
mock_pool_cls.return_value, mock_build_bootstrap_variables.return_value
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
|
|
@ -2588,7 +2627,7 @@ class TestWorkflowServiceHumanInputOperations:
|
|||
patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT),
|
||||
patch("services.workflow_service.HumanInputNodeData.model_validate"),
|
||||
patch.object(service, "_resolve_human_input_delivery_method") as mock_resolve,
|
||||
patch("services.workflow_service.apply_debug_email_recipient"),
|
||||
patch("services.workflow_service.apply_dify_debug_email_recipient"),
|
||||
patch.object(service, "_build_human_input_variable_pool"),
|
||||
patch.object(service, "_build_human_input_node"),
|
||||
patch.object(service, "_create_human_input_delivery_test_form", return_value=("form-1", [])),
|
||||
|
|
@ -2730,13 +2769,15 @@ class TestWorkflowServiceFreeNodeExecution:
|
|||
variable_pool = MagicMock()
|
||||
|
||||
with (
|
||||
patch("services.workflow_service.GraphInitParams"),
|
||||
patch("services.workflow_service.GraphInitParams") as mock_graph_init_params,
|
||||
patch("services.workflow_service.GraphRuntimeState"),
|
||||
patch("services.workflow_service.build_dify_run_context"),
|
||||
patch("services.workflow_service.DifyHumanInputNodeRuntime") as mock_runtime_cls,
|
||||
patch("services.workflow_service.HumanInputNode") as mock_node_cls,
|
||||
patch("services.workflow_service.HumanInputFormRepositoryImpl"),
|
||||
):
|
||||
node = service._build_human_input_node(
|
||||
workflow=workflow, account=account, node_config=node_config, variable_pool=variable_pool
|
||||
)
|
||||
assert node == mock_node_cls.return_value
|
||||
mock_node_cls.assert_called_once()
|
||||
mock_runtime_cls.assert_called_once_with(mock_graph_init_params.return_value.run_context)
|
||||
|
|
|
|||
Loading…
Reference in New Issue