Merge branch 'fix/memeber-settings' into deploy/dev

# Conflicts:
#	web/app/components/app/type-selector/index.tsx
#	web/app/components/header/account-setting/data-source-page/data-source-notion/operate/index.tsx
#	web/app/components/header/account-setting/data-source-page/data-source-website/config-firecrawl-modal.tsx
#	web/app/components/header/account-setting/data-source-page/data-source-website/config-jina-reader-modal.tsx
#	web/app/components/header/account-setting/data-source-page/data-source-website/config-watercrawl-modal.tsx
#	web/app/components/header/account-setting/data-source-page/panel/config-item.tsx
#	web/app/components/header/account-setting/data-source-page/panel/index.tsx
#	web/app/components/tools/labels/filter.tsx
#	web/app/components/workflow/nodes/knowledge-retrieval/node.tsx
This commit is contained in:
yyh 2026-03-23 16:31:18 +08:00
commit b3bcaed9b9
No known key found for this signature in database
382 changed files with 20821 additions and 7919 deletions

View File

@ -94,11 +94,6 @@ jobs:
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
find . -name "*.py.bak" -type f -delete
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
- name: mdformat
run: |
uvx --python 3.13 mdformat . --exclude ".agents/skills/**"
- name: Setup web environment
if: steps.web-changes.outputs.any_changed == 'true'
uses: ./.github/actions/setup-web

View File

@ -120,7 +120,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.detect_changes.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@df37d2f0760a4b5683a6e617c9325bc1a36443f6 # v1.0.75
uses: anthropics/claude-code-action@6062f3709600659be5e47fcddf2cf76993c235c2 # v1.0.76
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -97,3 +97,8 @@ Feel free to reach out if you encounter any issues during the setup process.
## Getting Help
If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
## Automated Agent Contributions
> [!NOTE]
> If you are an automated agent just add 🤖🤖🤖 to the end of the PR title to opt-in.

View File

@ -1,7 +1,7 @@
import flask_restx
from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus
from sqlalchemy import select
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
@ -33,16 +33,10 @@ api_key_list_model = console_ns.model(
def _get_resource(resource_id, tenant_id, resource_model):
if resource_model == App:
with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
else:
with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
if resource is None:
flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.")
@ -80,10 +74,13 @@ class BaseApiKeyListResource(Resource):
resource_id = str(resource_id)
_, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
current_key_count = (
db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.count()
current_key_count: int = (
db.session.scalar(
select(func.count(ApiToken.id)).where(
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
)
)
or 0
)
if current_key_count >= self.max_keys:
@ -119,14 +116,14 @@ class BaseApiKeyResource(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
key = (
db.session.query(ApiToken)
key = db.session.scalar(
select(ApiToken)
.where(
getattr(ApiToken, self.resource_id_field) == resource_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
)
.first()
.limit(1)
)
if key is None:
@ -137,7 +134,7 @@ class BaseApiKeyResource(Resource):
assert key is not None # nosec - for type checker only
ApiTokenCache.delete(key.token, key.type)
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
db.session.commit()
return {"result": "success"}, 204

View File

@ -5,7 +5,7 @@ from flask import abort, request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
@ -376,8 +376,12 @@ class CompletionConversationApi(Resource):
# FIXME, the type ignore in this file
if args.annotation_status == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
query = (
query.options(selectinload(Conversation.message_annotations)) # type: ignore[arg-type]
.join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
.distinct()
)
elif args.annotation_status == "not_annotated":
query = (
@ -511,8 +515,12 @@ class ChatConversationApi(Resource):
match args.annotation_status:
case "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
query = (
query.options(selectinload(Conversation.message_annotations)) # type: ignore[arg-type]
.join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
.distinct()
)
case "not_annotated":
query = (

View File

@ -4,7 +4,7 @@ from typing import Literal
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, select
from sqlalchemy import exists, func, select
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
@ -244,27 +244,25 @@ class ChatMessageListApi(Resource):
def get(self, app_model):
args = ChatMessagesQuery.model_validate(request.args.to_dict())
conversation = (
db.session.query(Conversation)
conversation = db.session.scalar(
select(Conversation)
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
.first()
.limit(1)
)
if not conversation:
raise NotFound("Conversation Not Exists.")
if args.first_id:
first_message = (
db.session.query(Message)
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
.first()
first_message = db.session.scalar(
select(Message).where(Message.conversation_id == conversation.id, Message.id == args.first_id).limit(1)
)
if not first_message:
raise NotFound("First message not found")
history_messages = (
db.session.query(Message)
history_messages = db.session.scalars(
select(Message)
.where(
Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at,
@ -272,16 +270,14 @@ class ChatMessageListApi(Resource):
)
.order_by(Message.created_at.desc())
.limit(args.limit)
.all()
)
).all()
else:
history_messages = (
db.session.query(Message)
history_messages = db.session.scalars(
select(Message)
.where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(args.limit)
.all()
)
).all()
# Initialize has_more based on whether we have a full page
if len(history_messages) == args.limit:
@ -326,7 +322,9 @@ class MessageFeedbackApi(Resource):
message_id = str(args.message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
message = db.session.scalar(
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
)
if not message:
raise NotFound("Message Not Exists.")
@ -375,7 +373,9 @@ class MessageAnnotationCountApi(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()
count = db.session.scalar(
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id)
)
return {"count": count}
@ -479,7 +479,9 @@ class MessageApi(Resource):
def get(self, app_model, message_id: str):
message_id = str(message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
message = db.session.scalar(
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
)
if not message:
raise NotFound("Message Not Exists.")

View File

@ -7,7 +7,7 @@ from flask import abort, request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
import services
from controllers.console import console_ns
@ -46,13 +46,14 @@ from models import App
from models.model import AppMode
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
logger = logging.getLogger(__name__)
LISTENING_RETRY_IN = 2000
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@ -284,7 +285,9 @@ class DraftWorkflowApi(Resource):
workflow_service = WorkflowService()
try:
environment_variables_list = args.get("environment_variables") or []
environment_variables_list = Workflow.normalize_environment_variable_mappings(
args.get("environment_variables") or [],
)
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
@ -994,6 +997,43 @@ class PublishedAllWorkflowApi(Resource):
}
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>/restore")
class DraftWorkflowRestoreApi(Resource):
@console_ns.doc("restore_workflow_to_draft")
@console_ns.doc(description="Restore a published workflow version into the draft workflow")
@console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Published workflow ID"})
@console_ns.response(200, "Workflow restored successfully")
@console_ns.response(400, "Source workflow must be published")
@console_ns.response(404, "Workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, workflow_id: str):
current_user, _ = current_account_with_tenant()
workflow_service = WorkflowService()
try:
workflow = workflow_service.restore_published_workflow_to_draft(
app_model=app_model,
workflow_id=workflow_id,
account=current_user,
)
except IsDraftWorkflowError as exc:
raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc
except WorkflowNotFoundError as exc:
raise NotFound(str(exc)) from exc
except ValueError as exc:
raise BadRequest(str(exc)) from exc
return {
"result": "success",
"hash": workflow.unique_hash,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
}
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>")
class WorkflowByIdApi(Resource):
@console_ns.doc("update_workflow_by_id")

View File

@ -1,4 +1,5 @@
import logging
import urllib.parse
import httpx
from flask import current_app, redirect, request
@ -112,6 +113,9 @@ class OAuthCallback(Resource):
error_text = e.response.text
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
return {"error": "OAuth process failed"}, 400
except ValueError as e:
logger.warning("OAuth error with %s", provider, exc_info=True)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={urllib.parse.quote(str(e))}")
if invite_token and RegisterService.is_valid_invite_token(invite_token):
invitation = RegisterService.get_invitation_by_token(token=invite_token)

View File

@ -6,7 +6,7 @@ from flask import abort, request
from flask_restx import Resource, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
import services
from controllers.common.schema import register_schema_models
@ -16,7 +16,11 @@ from controllers.console.app.error import (
DraftWorkflowNotExist,
DraftWorkflowNotSync,
)
from controllers.console.app.workflow import workflow_model, workflow_pagination_model
from controllers.console.app.workflow import (
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE,
workflow_model,
workflow_pagination_model,
)
from controllers.console.app.workflow_run import (
workflow_run_detail_model,
workflow_run_node_execution_list_model,
@ -42,7 +46,8 @@ from libs.login import current_account_with_tenant, current_user, login_required
from models import Account
from models.dataset import Pipeline
from models.model import EndUser
from services.errors.app import WorkflowHashNotEqualError
from models.workflow import Workflow
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
from services.rag_pipeline.rag_pipeline import RagPipelineService
@ -203,9 +208,12 @@ class DraftRagPipelineApi(Resource):
abort(415)
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
rag_pipeline_service = RagPipelineService()
try:
environment_variables_list = payload.environment_variables or []
environment_variables_list = Workflow.normalize_environment_variable_mappings(
payload.environment_variables or [],
)
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
@ -213,7 +221,6 @@ class DraftRagPipelineApi(Resource):
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.sync_draft_workflow(
pipeline=pipeline,
graph=payload.graph,
@ -705,6 +712,36 @@ class PublishedAllRagPipelineApi(Resource):
}
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>/restore")
class RagPipelineDraftWorkflowRestoreApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
def post(self, pipeline: Pipeline, workflow_id: str):
current_user, _ = current_account_with_tenant()
rag_pipeline_service = RagPipelineService()
try:
workflow = rag_pipeline_service.restore_published_workflow_to_draft(
pipeline=pipeline,
workflow_id=workflow_id,
account=current_user,
)
except IsDraftWorkflowError as exc:
# Use a stable, predefined message to keep the 400 response consistent
raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc
except WorkflowNotFoundError as exc:
raise NotFound(str(exc)) from exc
return {
"result": "success",
"hash": workflow.unique_hash,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
}
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
class RagPipelineByIdApi(Resource):
@setup_required

View File

@ -1,5 +1,6 @@
from flask import request
from flask_restx import Resource
from sqlalchemy import select
from controllers.console import api
from controllers.console.explore.wraps import explore_banner_enabled
@ -17,14 +18,18 @@ class BannerApi(Resource):
language = request.args.get("language", "en-US")
# Build base query for enabled banners
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED)
base_query = select(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED)
# Try to get banners in the requested language
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
banners = db.session.scalars(
base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort)
).all()
# Fallback to en-US if no banners found and language is not en-US
if not banners and language != "en-US":
banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
banners = db.session.scalars(
base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort)
).all()
# Convert banners to serializable format
result = []
for banner in banners:

View File

@ -133,13 +133,15 @@ class InstalledAppsListApi(Resource):
def post(self):
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first()
recommended_app = db.session.scalar(
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).limit(1)
)
if recommended_app is None:
raise NotFound("Recommended app not found")
_, current_tenant_id = current_account_with_tenant()
app = db.session.query(App).where(App.id == payload.app_id).first()
app = db.session.get(App, payload.app_id)
if app is None:
raise NotFound("App entity not found")
@ -147,10 +149,10 @@ class InstalledAppsListApi(Resource):
if not app.is_public:
raise Forbidden("You can't install a non-public app")
installed_app = (
db.session.query(InstalledApp)
installed_app = db.session.scalar(
select(InstalledApp)
.where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id))
.first()
.limit(1)
)
if installed_app is None:

View File

@ -4,6 +4,7 @@ from typing import Any, Literal, cast
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
@ -476,7 +477,7 @@ class TrialSitApi(Resource):
Returns the site configuration for the application including theme, icons, and text.
"""
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if not site:
raise Forbidden()
@ -541,13 +542,7 @@ class AppWorkflowApi(Resource):
if not app_model.workflow_id:
raise AppUnavailableError()
workflow = (
db.session.query(Workflow)
.where(
Workflow.id == app_model.workflow_id,
)
.first()
)
workflow = db.session.get(Workflow, app_model.workflow_id)
return workflow

View File

@ -4,6 +4,7 @@ from typing import Concatenate, ParamSpec, TypeVar
from flask import abort
from flask_restx import Resource
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
@ -24,10 +25,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
@wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
installed_app = (
db.session.query(InstalledApp)
installed_app = db.session.scalar(
select(InstalledApp)
.where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if installed_app is None:
@ -78,7 +79,7 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
current_user, _ = current_account_with_tenant()
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
trial_app = db.session.scalar(select(TrialApp).where(TrialApp.app_id == str(app_id)).limit(1))
if trial_app is None:
raise TrialAppNotAllowed()
@ -87,10 +88,10 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
if app is None:
raise TrialAppNotAllowed()
account_trial_app_record = (
db.session.query(AccountTrialAppRecord)
account_trial_app_record = db.session.scalar(
select(AccountTrialAppRecord)
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
.first()
.limit(1)
)
if account_trial_app_record:
if account_trial_app_record.count >= trial_app.trial_limit:

View File

@ -2,6 +2,7 @@ from typing import Literal
from flask import request
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from configs import dify_config
from controllers.fastopenapi import console_router
@ -100,6 +101,6 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse:
def get_setup_status() -> DifySetup | bool | None:
if dify_config.EDITION == "SELF_HOSTED":
return db.session.query(DifySetup).first()
return db.session.scalar(select(DifySetup).limit(1))
return True

View File

@ -212,13 +212,13 @@ class AccountInitApi(Resource):
raise ValueError("invitation_code is required")
# check invitation code
invitation_code = (
db.session.query(InvitationCode)
invitation_code = db.session.scalar(
select(InvitationCode)
.where(
InvitationCode.code == args.invitation_code,
InvitationCode.status == InvitationCodeStatus.UNUSED,
)
.first()
.limit(1)
)
if not invitation_code:

View File

@ -171,7 +171,7 @@ class MemberCancelInviteApi(Resource):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.query(Account).where(Account.id == str(member_id)).first()
member = db.session.get(Account, str(member_id))
if member is None:
abort(404)
else:

View File

@ -7,6 +7,7 @@ from sqlalchemy import select
from werkzeug.exceptions import Unauthorized
import services
from configs import dify_config
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
@ -29,6 +30,7 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
from services.billing_service import BillingService, SubscriptionPlan
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.file_service import FileService
@ -108,9 +110,29 @@ class TenantListApi(Resource):
current_user, current_tenant_id = current_account_with_tenant()
tenants = TenantService.get_join_tenants(current_user)
tenant_dicts = []
is_enterprise_only = dify_config.ENTERPRISE_ENABLED and not dify_config.BILLING_ENABLED
is_saas = dify_config.EDITION == "CLOUD" and dify_config.BILLING_ENABLED
tenant_plans: dict[str, SubscriptionPlan] = {}
if is_saas:
tenant_ids = [tenant.id for tenant in tenants]
if tenant_ids:
tenant_plans = BillingService.get_plan_bulk(tenant_ids)
if not tenant_plans:
logger.warning("get_plan_bulk returned empty result, falling back to legacy feature path")
for tenant in tenants:
features = FeatureService.get_features(tenant.id)
plan: str = CloudPlan.SANDBOX
if is_saas:
tenant_plan = tenant_plans.get(tenant.id)
if tenant_plan:
plan = tenant_plan["plan"] or CloudPlan.SANDBOX
else:
features = FeatureService.get_features(tenant.id)
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
elif not is_enterprise_only:
features = FeatureService.get_features(tenant.id)
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
# Create a dictionary with tenant attributes
tenant_dict = {
@ -118,7 +140,7 @@ class TenantListApi(Resource):
"name": tenant.name,
"status": tenant.status,
"created_at": tenant.created_at,
"plan": features.billing.subscription.plan if features.billing.enabled else CloudPlan.SANDBOX,
"plan": plan,
"current": tenant.id == current_tenant_id if current_tenant_id else False,
}
@ -198,7 +220,7 @@ class SwitchWorkspaceApi(Resource):
except Exception:
raise AccountNotLinkTenantError("Account not link tenant")
new_tenant = db.session.query(Tenant).get(args.tenant_id) # Get new tenant
new_tenant = db.session.get(Tenant, args.tenant_id) # Get new tenant
if new_tenant is None:
raise ValueError("Tenant not found")

View File

@ -7,6 +7,7 @@ from functools import wraps
from typing import ParamSpec, TypeVar
from flask import abort, request
from sqlalchemy import select
from configs import dify_config
from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError
@ -218,13 +219,9 @@ def setup_required(view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
# check setup
if (
dify_config.EDITION == "SELF_HOSTED"
and os.environ.get("INIT_PASSWORD")
and not db.session.query(DifySetup).first()
):
raise NotInitValidateError()
elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first():
if dify_config.EDITION == "SELF_HOSTED" and not db.session.scalar(select(DifySetup).limit(1)):
if os.environ.get("INIT_PASSWORD"):
raise NotInitValidateError()
raise NotSetupError()
return view(*args, **kwargs)

View File

@ -33,7 +33,7 @@ from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from libs.datetime_utils import naive_utc_now
from models import Account
from models.enums import CreatorUserRole, MessageFileBelongsTo
from models.enums import ConversationFromSource, CreatorUserRole, MessageFileBelongsTo
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationNotExistsError
@ -130,10 +130,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
end_user_id = None
account_id = None
if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
from_source = "api"
from_source = ConversationFromSource.API
end_user_id = application_generate_entity.user_id
else:
from_source = "console"
from_source = ConversationFromSource.CONSOLE
account_id = application_generate_entity.user_id
if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):

View File

@ -6,7 +6,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from models.enums import CollectionBindingType
from models.enums import CollectionBindingType, ConversationFromSource
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.dataset_service import DatasetCollectionBindingService
@ -68,9 +68,9 @@ class AnnotationReplyFeature:
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
if annotation:
if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}:
from_source = "api"
from_source = ConversationFromSource.API
else:
from_source = "console"
from_source = ConversationFromSource.CONSOLE
# insert annotation history
AppAnnotationService.add_annotation_history(

View File

@ -181,10 +181,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
arize_phoenix_config: ArizeConfig | PhoenixConfig,
):
super().__init__(arize_phoenix_config)
import logging
logging.basicConfig()
logging.getLogger().setLevel(logging.DEBUG)
self.arize_phoenix_config = arize_phoenix_config
self.tracer, self.processor = setup_tracer(arize_phoenix_config)
self.project = arize_phoenix_config.project

View File

@ -284,27 +284,29 @@ class TidbOnQdrantVector(BaseVector):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
for node_id in ids:
try:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchValue(value=node_id),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
if not ids:
return
try:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchAny(any=ids),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def text_exists(self, id: str) -> bool:
all_collection_name = []

View File

@ -50,7 +50,7 @@ class BuiltinTool(Tool):
return ModelInvocationUtils.invoke(
user_id=user_id,
tenant_id=self.runtime.tenant_id or "",
tool_type="builtin",
tool_type=ToolProviderType.BUILT_IN,
tool_name=self.entity.identity.name,
prompt_messages=prompt_messages,
)

View File

@ -38,7 +38,7 @@ class ToolLabelManager:
db.session.add(
ToolLabelBinding(
tool_id=provider_id,
tool_type=controller.provider_type.value,
tool_type=controller.provider_type,
label_name=label,
)
)
@ -58,7 +58,7 @@ class ToolLabelManager:
raise ValueError("Unsupported tool type")
stmt = select(ToolLabelBinding.label_name).where(
ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value,
ToolLabelBinding.tool_type == controller.provider_type,
)
labels = db.session.scalars(stmt).all()

View File

@ -9,6 +9,7 @@ from decimal import Decimal
from typing import cast
from core.model_manager import ModelManager
from core.tools.entities.tool_entities import ToolProviderType
from dify_graph.model_runtime.entities.llm_entities import LLMResult
from dify_graph.model_runtime.entities.message_entities import PromptMessage
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
@ -78,7 +79,7 @@ class ModelInvocationUtils:
@staticmethod
def invoke(
user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
user_id: str, tenant_id: str, tool_type: ToolProviderType, tool_name: str, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
invoke model with parameters in user's own context

View File

@ -101,6 +101,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
http_request_config=self._http_request_config,
# Must be 0 to disable executor-level retries, as the graph engine handles them.
# This is critical to prevent nested retries.
max_retries=0,
ssl_verify=self.node_data.ssl_verify,
http_client=self._http_client,
file_manager=self._file_manager,

View File

@ -1,16 +1,19 @@
import logging
import sys
import urllib.parse
from dataclasses import dataclass
from typing import NotRequired
import httpx
from pydantic import TypeAdapter
from pydantic import TypeAdapter, ValidationError
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
logger = logging.getLogger(__name__)
JsonObject = dict[str, object]
JsonObjectList = list[JsonObject]
@ -30,8 +33,8 @@ class GitHubEmailRecord(TypedDict, total=False):
class GitHubRawUserInfo(TypedDict):
id: int | str
login: str
name: NotRequired[str]
email: NotRequired[str]
name: NotRequired[str | None]
email: NotRequired[str | None]
class GoogleRawUserInfo(TypedDict):
@ -127,9 +130,14 @@ class GitHubOAuth(OAuth):
response.raise_for_status()
user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response))
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
primary_email = next((email for email in email_info if email.get("primary") is True), None)
try:
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
email_response.raise_for_status()
email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
primary_email = next((email for email in email_info if email.get("primary") is True), None)
except (httpx.HTTPStatusError, ValidationError):
logger.warning("Failed to retrieve email from GitHub /user/emails endpoint", exc_info=True)
primary_email = None
return {**user_info, "email": primary_email.get("email", "") if primary_email else ""}
@ -137,8 +145,11 @@ class GitHubOAuth(OAuth):
payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
email = payload.get("email")
if not email:
email = f"{payload['id']}+{payload['login']}@users.noreply.github.com"
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email)
raise ValueError(
'Dify currently not supports the "Keep my email addresses private" feature,'
" please disable it and login again"
)
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name") or ""), email=email)
class GoogleOAuth(OAuth):

View File

@ -34,13 +34,16 @@ from .enums import (
AppMCPServerStatus,
AppStatus,
BannerStatus,
ConversationFromSource,
ConversationStatus,
CreatorUserRole,
FeedbackFromSource,
FeedbackRating,
InvokeFrom,
MessageChainType,
MessageFileBelongsTo,
MessageStatus,
TagType,
)
from .provider_ids import GenericProviderID
from .types import EnumText, LongText, StringUUID
@ -1022,10 +1025,12 @@ class Conversation(Base):
#
# Its value corresponds to the members of `InvokeFrom`.
# (api/core/app/entities/app_invoke_entities.py)
invoke_from = mapped_column(String(255), nullable=True)
invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True)
# ref: ConversationSource.
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
from_source: Mapped[ConversationFromSource] = mapped_column(
EnumText(ConversationFromSource, length=255), nullable=False
)
from_end_user_id = mapped_column(StringUUID)
from_account_id = mapped_column(StringUUID)
read_at = mapped_column(sa.DateTime)
@ -1374,8 +1379,10 @@ class Message(Base):
)
error: Mapped[str | None] = mapped_column(LongText)
message_metadata: Mapped[str | None] = mapped_column(LongText)
invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True)
from_source: Mapped[ConversationFromSource] = mapped_column(
EnumText(ConversationFromSource, length=255), nullable=False
)
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())
@ -2398,7 +2405,7 @@ class Tag(TypeBase):
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
type: Mapped[str] = mapped_column(String(16), nullable=False)
type: Mapped[TagType] = mapped_column(EnumText(TagType, length=16), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(

View File

@ -13,12 +13,16 @@ from sqlalchemy.orm import Mapped, mapped_column
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
from core.tools.entities.tool_entities import (
ApiProviderSchemaType,
ToolProviderType,
WorkflowToolParameterConfiguration,
)
from .base import TypeBase
from .engine import db
from .model import Account, App, Tenant
from .types import LongText, StringUUID
from .types import EnumText, LongText, StringUUID
if TYPE_CHECKING:
from core.entities.mcp_provider import MCPProviderEntity
@ -208,7 +212,7 @@ class ToolLabelBinding(TypeBase):
# tool id
tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
# tool type
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
# label name
label_name: Mapped[str] = mapped_column(String(40), nullable=False)
@ -386,7 +390,7 @@ class ToolModelInvoke(TypeBase):
# provider
provider: Mapped[str] = mapped_column(String(255), nullable=False)
# type
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
# tool name
tool_name: Mapped[str] = mapped_column(String(128), nullable=False)
# invoke parameters

View File

@ -1,3 +1,4 @@
import copy
import json
import logging
from collections.abc import Generator, Mapping, Sequence
@ -302,26 +303,40 @@ class Workflow(Base): # bug
def features(self) -> str:
"""
Convert old features structure to new features structure.
This property avoids rewriting the underlying JSON when normalization
produces no effective change, to prevent marking the row dirty on read.
"""
if not self._features:
return self._features
features = json.loads(self._features)
if features.get("file_upload", {}).get("image", {}).get("enabled", False):
image_enabled = True
image_number_limits = int(features["file_upload"]["image"].get("number_limits", DEFAULT_FILE_NUMBER_LIMITS))
image_transfer_methods = features["file_upload"]["image"].get(
"transfer_methods", ["remote_url", "local_file"]
)
features["file_upload"]["enabled"] = image_enabled
features["file_upload"]["number_limits"] = image_number_limits
features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods
features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"])
features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get(
"allowed_file_extensions", []
)
del features["file_upload"]["image"]
self._features = json.dumps(features)
# Parse once and deep-copy before normalization to detect in-place changes.
original_dict = self._decode_features_payload(self._features)
if original_dict is None:
return self._features
# Fast-path: if the legacy file_upload.image.enabled shape is absent, skip
# deep-copy and normalization entirely and return the stored JSON.
file_upload_payload = original_dict.get("file_upload")
if not isinstance(file_upload_payload, dict):
return self._features
file_upload = cast(dict[str, Any], file_upload_payload)
image_payload = file_upload.get("image")
if not isinstance(image_payload, dict):
return self._features
image = cast(dict[str, Any], image_payload)
if "enabled" not in image:
return self._features
normalized_dict = self._normalize_features_payload(copy.deepcopy(original_dict))
if normalized_dict == original_dict:
# No effective change; return stored JSON unchanged.
return self._features
# Normalization changed the payload: persist the normalized JSON.
self._features = json.dumps(normalized_dict)
return self._features
@features.setter
@ -332,6 +347,44 @@ class Workflow(Base): # bug
def features_dict(self) -> dict[str, Any]:
return json.loads(self.features) if self.features else {}
@property
def serialized_features(self) -> str:
"""Return the stored features JSON without triggering compatibility rewrites."""
return self._features
@property
def normalized_features_dict(self) -> dict[str, Any]:
"""Decode features with legacy normalization without mutating the model state."""
if not self._features:
return {}
features = self._decode_features_payload(self._features)
return self._normalize_features_payload(features) if features is not None else {}
@staticmethod
def _decode_features_payload(features: str) -> dict[str, Any] | None:
"""Decode workflow features JSON when it contains an object payload."""
payload = json.loads(features)
return cast(dict[str, Any], payload) if isinstance(payload, dict) else None
@staticmethod
def _normalize_features_payload(features: dict[str, Any]) -> dict[str, Any]:
if features.get("file_upload", {}).get("image", {}).get("enabled", False):
image_number_limits = int(features["file_upload"]["image"].get("number_limits", DEFAULT_FILE_NUMBER_LIMITS))
image_transfer_methods = features["file_upload"]["image"].get(
"transfer_methods", ["remote_url", "local_file"]
)
features["file_upload"]["enabled"] = True
features["file_upload"]["number_limits"] = image_number_limits
features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods
features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"])
features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get(
"allowed_file_extensions", []
)
del features["file_upload"]["image"]
return features
def walk_nodes(
self, specific_node_type: NodeType | None = None
) -> Generator[tuple[str, Mapping[str, Any]], None, None]:
@ -517,6 +570,31 @@ class Workflow(Base): # bug
)
self._environment_variables = environment_variables_json
@staticmethod
def normalize_environment_variable_mappings(
mappings: Sequence[Mapping[str, Any]],
) -> list[dict[str, Any]]:
"""Convert masked secret placeholders into the draft hidden sentinel.
Regular draft sync requests should preserve existing secrets without shipping
plaintext values back from the client. The dedicated restore endpoint now
copies published secrets server-side, so draft sync only needs to normalize
the UI mask into `HIDDEN_VALUE`.
"""
masked_secret_value = encrypter.full_mask_token()
normalized_mappings: list[dict[str, Any]] = []
for mapping in mappings:
normalized_mapping = dict(mapping)
if (
normalized_mapping.get("value_type") == SegmentType.SECRET.value
and normalized_mapping.get("value") == masked_secret_value
):
normalized_mapping["value"] = HIDDEN_VALUE
normalized_mappings.append(normalized_mapping)
return normalized_mappings
def to_dict(self, *, include_secret: bool = False) -> WorkflowContentDict:
environment_variables = list(self.environment_variables)
environment_variables = [
@ -564,6 +642,12 @@ class Workflow(Base): # bug
ensure_ascii=False,
)
def copy_serialized_variable_storage_from(self, source_workflow: "Workflow") -> None:
"""Copy stored variable JSON directly for same-tenant restore flows."""
self._environment_variables = source_workflow._environment_variables
self._conversation_variables = source_workflow._conversation_variables
self._rag_pipeline_variables = source_workflow._rag_pipeline_variables
@staticmethod
def version_from_datetime(d: datetime) -> str:
return str(d)

View File

@ -8,7 +8,7 @@ dependencies = [
"arize-phoenix-otel~=0.15.0",
"azure-identity==1.25.3",
"beautifulsoup4==4.14.3",
"boto3==1.42.68",
"boto3==1.42.73",
"bs4~=0.0.1",
"cachetools~=5.3.0",
"celery~=5.6.2",
@ -23,7 +23,7 @@ dependencies = [
"gevent~=25.9.1",
"gmpy2~=2.3.0",
"google-api-core>=2.19.1",
"google-api-python-client==2.192.0",
"google-api-python-client==2.193.0",
"google-auth>=2.47.0",
"google-auth-httplib2==0.3.0",
"google-cloud-aiplatform>=1.123.0",
@ -40,7 +40,7 @@ dependencies = [
"numpy~=1.26.4",
"openpyxl~=3.1.5",
"opik~=1.10.37",
"litellm==1.82.2", # Pinned to avoid madoka dependency issue
"litellm==1.82.6", # Pinned to avoid madoka dependency issue
"opentelemetry-api==1.28.0",
"opentelemetry-distro==0.49b0",
"opentelemetry-exporter-otlp==1.28.0",
@ -72,13 +72,14 @@ dependencies = [
"pyyaml~=6.0.1",
"readabilipy~=0.3.0",
"redis[hiredis]~=7.3.0",
"resend~=2.23.0",
"sentry-sdk[flask]~=2.54.0",
"resend~=2.26.0",
"sentry-sdk[flask]~=2.55.0",
"sqlalchemy~=2.0.29",
"starlette==0.52.1",
"starlette==1.0.0",
"tiktoken~=0.12.0",
"transformers~=5.3.0",
"unstructured[docx,epub,md,ppt,pptx]~=0.21.5",
"pypandoc~=1.13",
"yarl~=1.23.0",
"webvtt-py~=0.5.1",
"sseclient-py~=1.9.0",
@ -91,7 +92,7 @@ dependencies = [
"apscheduler>=3.11.0",
"weave>=0.52.16",
"fastopenapi[flask]>=0.7.0",
"bleach~=6.2.0",
"bleach~=6.3.0",
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.
@ -118,7 +119,7 @@ dev = [
"ruff~=0.15.5",
"pytest~=9.0.2",
"pytest-benchmark~=5.2.3",
"pytest-cov~=7.0.0",
"pytest-cov~=7.1.0",
"pytest-env~=1.6.0",
"pytest-mock~=3.15.1",
"testcontainers~=4.14.1",

View File

@ -385,7 +385,11 @@ class BillingService:
# Redis returns bytes, decode to string and parse JSON
json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value
plan_dict = json.loads(json_str)
# NOTE (hj24): New billing versions may return timestamp as str, and validate_python
# in non-strict mode will coerce it to the expected int type.
# To preserve compatibility, always keep non-strict mode here and avoid strict mode.
subscription_plan = subscription_adapter.validate_python(plan_dict)
# NOTE END
tenant_plans[tenant_id] = subscription_plan
except Exception:
logger.exception(

View File

@ -79,10 +79,11 @@ from services.entities.knowledge_entities.rag_pipeline_entities import (
KnowledgeConfiguration,
PipelineTemplateInfoEntity,
)
from services.errors.app import WorkflowHashNotEqualError
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader
from services.workflow_restore import apply_published_workflow_snapshot_to_draft
logger = logging.getLogger(__name__)
@ -234,6 +235,21 @@ class RagPipelineService:
return workflow
def get_published_workflow_by_id(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
"""Fetch a published workflow snapshot by ID for restore operations."""
workflow = (
db.session.query(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.id == workflow_id,
)
.first()
)
if workflow and workflow.version == Workflow.VERSION_DRAFT:
raise IsDraftWorkflowError("source workflow must be published")
return workflow
def get_all_published_workflow(
self,
*,
@ -327,6 +343,42 @@ class RagPipelineService:
# return draft workflow
return workflow
def restore_published_workflow_to_draft(
self,
*,
pipeline: Pipeline,
workflow_id: str,
account: Account,
) -> Workflow:
"""Restore a published pipeline workflow snapshot into the draft workflow.
Pipelines reuse the shared draft-restore field copy helper, but still own
the pipeline-specific flush/link step that wires a newly created draft
back onto ``pipeline.workflow_id``.
"""
source_workflow = self.get_published_workflow_by_id(pipeline=pipeline, workflow_id=workflow_id)
if not source_workflow:
raise WorkflowNotFoundError("Workflow not found.")
draft_workflow = self.get_draft_workflow(pipeline=pipeline)
draft_workflow, is_new_draft = apply_published_workflow_snapshot_to_draft(
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
source_workflow=source_workflow,
draft_workflow=draft_workflow,
account=account,
updated_at_factory=lambda: datetime.now(UTC).replace(tzinfo=None),
)
if is_new_draft:
db.session.add(draft_workflow)
db.session.flush()
pipeline.workflow_id = draft_workflow.id
db.session.commit()
return draft_workflow
def publish_workflow(
self,
*,

View File

@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from models.dataset import Dataset
from models.enums import TagType
from models.model import App, Tag, TagBinding
@ -83,7 +84,7 @@ class TagService:
raise ValueError("Tag name already exists")
tag = Tag(
name=args["name"],
type=args["type"],
type=TagType(args["type"]),
created_by=current_user.id,
tenant_id=current_user.current_tenant_id,
)

View File

@ -1,10 +1,10 @@
import json
import logging
from collections.abc import Mapping
from typing import Any, cast
from httpx import get
from sqlalchemy import select
from typing_extensions import TypedDict
from core.entities.provider_entities import ProviderConfig
from core.tools.__base.tool_runtime import ToolRuntime
@ -28,9 +28,16 @@ from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class ApiSchemaParseResult(TypedDict):
schema_type: str
parameters_schema: list[dict[str, Any]]
credentials_schema: list[dict[str, Any]]
warning: dict[str, str]
class ApiToolManageService:
@staticmethod
def parser_api_schema(schema: str) -> Mapping[str, Any]:
def parser_api_schema(schema: str) -> ApiSchemaParseResult:
"""
parse api schema to tool bundle
"""
@ -71,7 +78,7 @@ class ApiToolManageService:
]
return cast(
Mapping,
ApiSchemaParseResult,
jsonable_encoder(
{
"schema_type": schema_type,

View File

@ -18,6 +18,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.auth.auth_flow import auth
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.types import Tool as MCPTool
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.utils.encryption import ProviderConfigEncrypter
from models.tools import MCPToolProvider
@ -681,7 +682,7 @@ class MCPToolManageService:
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
def _build_tool_provider_response(
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list[MCPTool]
) -> ToolProviderApiEntity:
"""Build API response for tool provider."""
user = db_provider.load_user()
@ -703,7 +704,7 @@ class MCPToolManageService:
raise ValueError(f"MCP tool {server_url} already exists")
if "unique_mcp_provider_server_identifier" in error_msg:
raise ValueError(f"MCP tool {server_identifier} already exists")
raise
raise error
def _is_valid_url(self, url: str) -> bool:
"""Validate URL format."""

View File

@ -1,5 +1,7 @@
import json
from typing import Any, TypedDict
from typing import Any
from typing_extensions import TypedDict
from core.app.app_config.entities import (
DatasetEntity,
@ -34,6 +36,17 @@ class _NodeType(TypedDict):
data: dict[str, Any]
class _EdgeType(TypedDict):
id: str
source: str
target: str
class WorkflowGraph(TypedDict):
nodes: list[_NodeType]
edges: list[_EdgeType]
class WorkflowConverter:
"""
App Convert to Workflow Mode
@ -107,7 +120,7 @@ class WorkflowConverter:
app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config)
# init workflow graph
graph: dict[str, Any] = {"nodes": [], "edges": []}
graph: WorkflowGraph = {"nodes": [], "edges": []}
# Convert list:
# - variables -> start
@ -385,7 +398,7 @@ class WorkflowConverter:
self,
original_app_mode: AppMode,
new_app_mode: AppMode,
graph: dict,
graph: WorkflowGraph,
model_config: ModelConfigEntity,
prompt_template: PromptTemplateEntity,
file_upload: FileUploadConfig | None = None,
@ -595,7 +608,7 @@ class WorkflowConverter:
"data": {"title": "ANSWER", "type": BuiltinNodeTypes.ANSWER, "answer": "{{#llm.text#}}"},
}
def _create_edge(self, source: str, target: str):
def _create_edge(self, source: str, target: str) -> _EdgeType:
"""
Create Edge
:param source: source node id
@ -604,7 +617,7 @@ class WorkflowConverter:
"""
return {"id": f"{source}-{target}", "source": source, "target": target}
def _append_node(self, graph: dict[str, Any], node: _NodeType):
def _append_node(self, graph: WorkflowGraph, node: _NodeType):
"""
Append Node to Graph

View File

@ -5,6 +5,7 @@ from typing import Any
from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session
from typing_extensions import TypedDict
from dify_graph.enums import WorkflowExecutionStatus
from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
@ -14,6 +15,10 @@ from services.plugin.plugin_service import PluginService
from services.workflow.entities import TriggerMetadata
class LogViewDetails(TypedDict):
trigger_metadata: dict[str, Any] | None
# Since the workflow_app_log table has exceeded 100 million records, we use an additional details field to extend it
class LogView:
"""Lightweight wrapper for WorkflowAppLog with computed details.
@ -22,12 +27,12 @@ class LogView:
- Proxies all other attributes to the underlying `WorkflowAppLog`
"""
def __init__(self, log: WorkflowAppLog, details: dict | None):
def __init__(self, log: WorkflowAppLog, details: LogViewDetails | None):
self.log = log
self.details_ = details
@property
def details(self) -> dict | None:
def details(self) -> LogViewDetails | None:
return self.details_
def __getattr__(self, name):

View File

@ -35,7 +35,7 @@ from factories.variable_factory import build_segment, segment_to_variable
from libs.datetime_utils import naive_utc_now
from libs.uuid_utils import uuidv7
from models import Account, App, Conversation
from models.enums import DraftVariableType
from models.enums import ConversationFromSource, DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowDraftVariableFile, is_system_variable_editable
from repositories.factory import DifyAPIRepositoryFactory
from services.file_service import FileService
@ -601,7 +601,7 @@ class WorkflowDraftVariableService:
system_instruction_tokens=0,
status="normal",
invoke_from=InvokeFrom.DEBUGGER,
from_source="console",
from_source=ConversationFromSource.CONSOLE,
from_end_user_id=None,
from_account_id=account_id,
)

View File

@ -0,0 +1,58 @@
"""Shared helpers for restoring published workflow snapshots into drafts.
Both app workflows and RAG pipeline workflows restore the same workflow fields
from a published snapshot into a draft. Keeping that field-copy logic in one
place prevents the two restore paths from drifting when we add or adjust draft
state in the future. Restore stays within a tenant, so we can safely reuse the
serialized workflow storage blobs without decrypting and re-encrypting secrets.
"""
from collections.abc import Callable
from datetime import datetime
from models import Account
from models.workflow import Workflow, WorkflowType
UpdatedAtFactory = Callable[[], datetime]
def apply_published_workflow_snapshot_to_draft(
*,
tenant_id: str,
app_id: str,
source_workflow: Workflow,
draft_workflow: Workflow | None,
account: Account,
updated_at_factory: UpdatedAtFactory,
) -> tuple[Workflow, bool]:
"""Copy a published workflow snapshot into a draft workflow record.
The caller remains responsible for source lookup, validation, flushing, and
post-commit side effects. This helper only centralizes the shared draft
creation/update semantics used by both restore entry points. Features are
copied from the stored JSON payload so restore does not normalize and dirty
the published source row before the caller commits.
"""
if not draft_workflow:
workflow_type = (
source_workflow.type.value if isinstance(source_workflow.type, WorkflowType) else source_workflow.type
)
draft_workflow = Workflow(
tenant_id=tenant_id,
app_id=app_id,
type=workflow_type,
version=Workflow.VERSION_DRAFT,
graph=source_workflow.graph,
features=source_workflow.serialized_features,
created_by=account.id,
)
draft_workflow.copy_serialized_variable_storage_from(source_workflow)
return draft_workflow, True
draft_workflow.graph = source_workflow.graph
draft_workflow.features = source_workflow.serialized_features
draft_workflow.updated_by = account.id
draft_workflow.updated_at = updated_at_factory()
draft_workflow.copy_serialized_variable_storage_from(source_workflow)
return draft_workflow, False

View File

@ -63,7 +63,12 @@ from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeEx
from repositories.factory import DifyAPIRepositoryFactory
from services.billing_service import BillingService
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.errors.app import IsDraftWorkflowError, TriggerNodeLimitExceededError, WorkflowHashNotEqualError
from services.errors.app import (
IsDraftWorkflowError,
TriggerNodeLimitExceededError,
WorkflowHashNotEqualError,
WorkflowNotFoundError,
)
from services.workflow.workflow_converter import WorkflowConverter
from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
@ -75,6 +80,7 @@ from .human_input_delivery_test_service import (
HumanInputDeliveryTestService,
)
from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService
from .workflow_restore import apply_published_workflow_snapshot_to_draft
class WorkflowService:
@ -279,6 +285,43 @@ class WorkflowService:
# return draft workflow
return workflow
def restore_published_workflow_to_draft(
self,
*,
app_model: App,
workflow_id: str,
account: Account,
) -> Workflow:
"""Restore a published workflow snapshot into the draft workflow.
Secret environment variables are copied server-side from the selected
published workflow so the normal draft sync flow stays stateless.
"""
source_workflow = self.get_published_workflow_by_id(app_model=app_model, workflow_id=workflow_id)
if not source_workflow:
raise WorkflowNotFoundError("Workflow not found.")
self.validate_features_structure(app_model=app_model, features=source_workflow.normalized_features_dict)
self.validate_graph_structure(graph=source_workflow.graph_dict)
draft_workflow = self.get_draft_workflow(app_model=app_model)
draft_workflow, is_new_draft = apply_published_workflow_snapshot_to_draft(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
source_workflow=source_workflow,
draft_workflow=draft_workflow,
account=account,
updated_at_factory=naive_utc_now,
)
if is_new_draft:
db.session.add(draft_workflow)
db.session.commit()
app_draft_workflow_was_synced.send(app_model, synced_draft_workflow=draft_workflow)
return draft_workflow
def publish_workflow(
self,
*,

View File

@ -13,6 +13,7 @@ from controllers.console.app import wraps
from libs.datetime_utils import naive_utc_now
from models import App, Tenant
from models.account import Account, TenantAccountJoin, TenantAccountRole
from models.enums import ConversationFromSource
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@ -154,7 +155,7 @@ class TestChatMessageApiPermissions:
re_sign_file_url_answer="",
answer_tokens=0,
provider_response_latency=0.0,
from_source="console",
from_source=ConversationFromSource.CONSOLE,
from_end_user_id=None,
from_account_id=mock_account.id,
feedbacks=[],

View File

@ -165,8 +165,9 @@ class DifyTestContainers:
# Start Dify Sandbox container for code execution environment
# Dify Sandbox provides a secure environment for executing user code
# Use pinned version 0.2.12 to match production docker-compose configuration
logger.info("Initializing Dify Sandbox container...")
self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:latest").with_network(self.network)
self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:0.2.12").with_network(self.network)
self.dify_sandbox.with_exposed_ports(8194)
self.dify_sandbox.env = {
"API_KEY": "test_api_key",

View File

@ -13,7 +13,7 @@ from libs.datetime_utils import naive_utc_now
from libs.token import _real_cookie_name, generate_csrf_token
from models import Account, DifySetup, Tenant, TenantAccountJoin
from models.account import AccountStatus, TenantAccountRole
from models.enums import CreatorUserRole
from models.enums import ConversationFromSource, CreatorUserRole
from models.model import App, AppMode, Conversation, Message
from models.workflow import WorkflowRun
from services.account_service import AccountService
@ -75,7 +75,7 @@ def _create_conversation(db_session: Session, app_id: str, account_id: str) -> C
inputs={},
status="normal",
mode=AppMode.CHAT,
from_source=CreatorUserRole.ACCOUNT,
from_source=ConversationFromSource.CONSOLE,
from_account_id=account_id,
)
db_session.add(conversation)
@ -124,7 +124,7 @@ def _create_message(
answer_price_unit=0.001,
currency="USD",
status="normal",
from_source=CreatorUserRole.ACCOUNT,
from_source=ConversationFromSource.CONSOLE,
from_account_id=account_id,
workflow_run_id=workflow_run_id,
inputs={"query": "Hello"},

View File

@ -7,6 +7,7 @@ from uuid import uuid4
from dify_graph.nodes.human_input.entities import FormDefinition, UserAction
from models.account import Account, Tenant, TenantAccountJoin
from models.enums import ConversationFromSource, InvokeFrom
from models.execution_extra_content import HumanInputContent
from models.human_input import HumanInputForm, HumanInputFormStatus
from models.model import App, Conversation, Message
@ -78,8 +79,8 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture:
introduction="",
system_instruction="",
status="normal",
invoke_from="console",
from_source="console",
invoke_from=InvokeFrom.EXPLORE,
from_source=ConversationFromSource.CONSOLE,
from_account_id=account.id,
from_end_user_id=None,
)
@ -101,7 +102,7 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture:
answer_unit_price=Decimal("0.001"),
provider_response_latency=0.5,
currency="USD",
from_source="console",
from_source=ConversationFromSource.CONSOLE,
from_account_id=account.id,
workflow_run_id=workflow_run_id,
)

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import secrets
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from unittest.mock import Mock
@ -12,15 +13,26 @@ from sqlalchemy import Engine, delete, select
from sqlalchemy.orm import Session, sessionmaker
from dify_graph.entities import WorkflowExecution
from dify_graph.entities.pause_reason import PauseReasonType
from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction
from dify_graph.nodes.human_input.enums import DeliveryMethodType, FormInputType, HumanInputFormStatus
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.human_input import (
BackstageRecipientPayload,
HumanInputDelivery,
HumanInputForm,
HumanInputFormRecipient,
RecipientType,
)
from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.sqlalchemy_api_workflow_run_repository import (
DifyAPISQLAlchemyWorkflowRunRepository,
_build_human_input_required_reason,
_PrivateWorkflowPauseEntity,
_WorkflowRunError,
)
@ -90,6 +102,19 @@ def _cleanup_scope_data(session: Session, scope: _TestScope) -> None:
WorkflowRun.app_id == scope.app_id,
)
)
form_ids_subquery = select(HumanInputForm.id).where(
HumanInputForm.tenant_id == scope.tenant_id,
HumanInputForm.app_id == scope.app_id,
)
session.execute(delete(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids_subquery)))
session.execute(delete(HumanInputDelivery).where(HumanInputDelivery.form_id.in_(form_ids_subquery)))
session.execute(
delete(HumanInputForm).where(
HumanInputForm.tenant_id == scope.tenant_id,
HumanInputForm.app_id == scope.app_id,
)
)
session.commit()
for state_key in scope.state_keys:
@ -504,3 +529,200 @@ class TestDeleteWorkflowPause:
with pytest.raises(_WorkflowRunError, match="WorkflowPause not found"):
repository.delete_workflow_pause(pause_entity=pause_entity)
class TestPrivateWorkflowPauseEntity:
"""Integration tests for _PrivateWorkflowPauseEntity using real DB models."""
def test_properties(
self,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Entity properties delegate to the persisted WorkflowPause model."""
workflow_run = _create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.RUNNING,
)
pause = WorkflowPause(
id=str(uuid4()),
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
state_object_key=f"workflow-state-{uuid4()}.json",
)
db_session_with_containers.add(pause)
db_session_with_containers.commit()
db_session_with_containers.refresh(pause)
test_scope.state_keys.add(pause.state_object_key)
entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[])
assert entity.id == pause.id
assert entity.workflow_execution_id == workflow_run.id
assert entity.resumed_at is None
def test_get_state(
self,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""get_state loads state data from storage using the persisted state_object_key."""
workflow_run = _create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.RUNNING,
)
state_key = f"workflow-state-{uuid4()}.json"
pause = WorkflowPause(
id=str(uuid4()),
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
state_object_key=state_key,
)
db_session_with_containers.add(pause)
db_session_with_containers.commit()
db_session_with_containers.refresh(pause)
test_scope.state_keys.add(state_key)
expected_state = b'{"test": "state"}'
storage.save(state_key, expected_state)
entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[])
result = entity.get_state()
assert result == expected_state
def test_get_state_caching(
self,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""get_state caches the result so storage is only accessed once."""
workflow_run = _create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.RUNNING,
)
state_key = f"workflow-state-{uuid4()}.json"
pause = WorkflowPause(
id=str(uuid4()),
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
state_object_key=state_key,
)
db_session_with_containers.add(pause)
db_session_with_containers.commit()
db_session_with_containers.refresh(pause)
test_scope.state_keys.add(state_key)
expected_state = b'{"test": "state"}'
storage.save(state_key, expected_state)
entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[])
result1 = entity.get_state()
# Delete from storage to prove second call uses cache
storage.delete(state_key)
test_scope.state_keys.discard(state_key)
result2 = entity.get_state()
assert result1 == expected_state
assert result2 == expected_state
class TestBuildHumanInputRequiredReason:
"""Integration tests for _build_human_input_required_reason using real DB models."""
def test_prefers_backstage_token_when_available(
self,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Use backstage token when multiple recipient types may exist."""
expiration_time = naive_utc_now()
form_definition = FormDefinition(
form_content="content",
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
user_actions=[UserAction(id="approve", title="Approve")],
rendered_content="rendered",
expiration_time=expiration_time,
default_values={"name": "Alice"},
node_title="Ask Name",
display_in_ui=True,
)
form_model = HumanInputForm(
tenant_id=test_scope.tenant_id,
app_id=test_scope.app_id,
workflow_run_id=str(uuid4()),
node_id="node-1",
form_definition=form_definition.model_dump_json(),
rendered_content="rendered",
status=HumanInputFormStatus.WAITING,
expiration_time=expiration_time,
)
db_session_with_containers.add(form_model)
db_session_with_containers.flush()
delivery = HumanInputDelivery(
form_id=form_model.id,
delivery_method_type=DeliveryMethodType.WEBAPP,
channel_payload="{}",
)
db_session_with_containers.add(delivery)
db_session_with_containers.flush()
access_token = secrets.token_urlsafe(8)
recipient = HumanInputFormRecipient(
form_id=form_model.id,
delivery_id=delivery.id,
recipient_type=RecipientType.BACKSTAGE,
recipient_payload=BackstageRecipientPayload().model_dump_json(),
access_token=access_token,
)
db_session_with_containers.add(recipient)
db_session_with_containers.flush()
# Create a pause so the reason has a valid pause_id
workflow_run = _create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.RUNNING,
)
pause = WorkflowPause(
id=str(uuid4()),
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
state_object_key=f"workflow-state-{uuid4()}.json",
)
db_session_with_containers.add(pause)
db_session_with_containers.flush()
test_scope.state_keys.add(pause.state_object_key)
reason_model = WorkflowPauseReason(
pause_id=pause.id,
type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
form_id=form_model.id,
node_id="node-1",
message="",
)
db_session_with_containers.add(reason_model)
db_session_with_containers.commit()
# Refresh to ensure we have DB-round-tripped objects
db_session_with_containers.refresh(form_model)
db_session_with_containers.refresh(reason_model)
db_session_with_containers.refresh(recipient)
reason = _build_human_input_required_reason(reason_model, form_model, [recipient])
assert isinstance(reason, HumanInputRequired)
assert reason.form_token == access_token
assert reason.node_title == "Ask Name"
assert reason.form_content == "content"
assert reason.inputs[0].output_variable_name == "name"
assert reason.actions[0].id == "approve"

View File

@ -0,0 +1,391 @@
"""Integration tests for get_paginated_workflow_runs and get_workflow_runs_count using testcontainers."""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import timedelta
from uuid import uuid4
import pytest
from sqlalchemy import Engine, delete
from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import Session, sessionmaker
from dify_graph.entities import WorkflowExecution
from dify_graph.enums import WorkflowExecutionStatus
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun, WorkflowType
from repositories.sqlalchemy_api_workflow_run_repository import DifyAPISQLAlchemyWorkflowRunRepository
class _TestWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
"""Concrete repository for tests where save() is not under test."""
def save(self, execution: WorkflowExecution) -> None:
return None
@dataclass
class _TestScope:
"""Per-test data scope used to isolate DB rows."""
tenant_id: str = field(default_factory=lambda: str(uuid4()))
app_id: str = field(default_factory=lambda: str(uuid4()))
workflow_id: str = field(default_factory=lambda: str(uuid4()))
user_id: str = field(default_factory=lambda: str(uuid4()))
def _create_workflow_run(
session: Session,
scope: _TestScope,
*,
status: WorkflowExecutionStatus,
triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING,
created_at_offset: timedelta | None = None,
) -> WorkflowRun:
"""Create and persist a workflow run bound to the current test scope."""
now = naive_utc_now()
workflow_run = WorkflowRun(
id=str(uuid4()),
tenant_id=scope.tenant_id,
app_id=scope.app_id,
workflow_id=scope.workflow_id,
type=WorkflowType.WORKFLOW,
triggered_from=triggered_from,
version="draft",
graph="{}",
inputs="{}",
status=status,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=scope.user_id,
created_at=now + created_at_offset if created_at_offset is not None else now,
)
session.add(workflow_run)
session.commit()
return workflow_run
def _cleanup_scope_data(session: Session, scope: _TestScope) -> None:
"""Remove test-created DB rows for a test scope."""
session.execute(
delete(WorkflowRun).where(
WorkflowRun.tenant_id == scope.tenant_id,
WorkflowRun.app_id == scope.app_id,
)
)
session.commit()
@pytest.fixture
def repository(db_session_with_containers: Session) -> DifyAPISQLAlchemyWorkflowRunRepository:
"""Build a repository backed by the testcontainers database engine."""
engine = db_session_with_containers.get_bind()
assert isinstance(engine, Engine)
return _TestWorkflowRunRepository(session_maker=sessionmaker(bind=engine, expire_on_commit=False))
@pytest.fixture
def test_scope(db_session_with_containers: Session) -> _TestScope:
"""Provide an isolated scope and clean related data after each test."""
scope = _TestScope()
yield scope
_cleanup_scope_data(db_session_with_containers, scope)
class TestGetPaginatedWorkflowRuns:
"""Integration tests for get_paginated_workflow_runs."""
def test_returns_runs_without_status_filter(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Return all runs for the given tenant/app when no status filter is applied."""
for status in (
WorkflowExecutionStatus.SUCCEEDED,
WorkflowExecutionStatus.FAILED,
WorkflowExecutionStatus.RUNNING,
):
_create_workflow_run(db_session_with_containers, test_scope, status=status)
result = repository.get_paginated_workflow_runs(
tenant_id=test_scope.tenant_id,
app_id=test_scope.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
limit=20,
last_id=None,
status=None,
)
assert len(result.data) == 3
assert result.limit == 20
assert result.has_more is False
def test_filters_by_status(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Return only runs matching the requested status."""
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED)
result = repository.get_paginated_workflow_runs(
tenant_id=test_scope.tenant_id,
app_id=test_scope.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
limit=20,
last_id=None,
status="succeeded",
)
assert len(result.data) == 2
assert all(run.status == WorkflowExecutionStatus.SUCCEEDED for run in result.data)
def test_pagination_has_more(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Return has_more=True when more records exist beyond the limit."""
for i in range(5):
_create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.SUCCEEDED,
created_at_offset=timedelta(seconds=i),
)
result = repository.get_paginated_workflow_runs(
tenant_id=test_scope.tenant_id,
app_id=test_scope.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
limit=3,
last_id=None,
status=None,
)
assert len(result.data) == 3
assert result.has_more is True
def test_cursor_based_pagination(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Cursor-based pagination returns the next page of results."""
for i in range(5):
_create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.SUCCEEDED,
created_at_offset=timedelta(seconds=i),
)
# First page
page1 = repository.get_paginated_workflow_runs(
tenant_id=test_scope.tenant_id,
app_id=test_scope.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
limit=3,
last_id=None,
status=None,
)
assert len(page1.data) == 3
assert page1.has_more is True
# Second page using cursor
page2 = repository.get_paginated_workflow_runs(
tenant_id=test_scope.tenant_id,
app_id=test_scope.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
limit=3,
last_id=page1.data[-1].id,
status=None,
)
assert len(page2.data) == 2
assert page2.has_more is False
# No overlap between pages
page1_ids = {r.id for r in page1.data}
page2_ids = {r.id for r in page2.data}
assert page1_ids.isdisjoint(page2_ids)
def test_invalid_last_id_raises(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
test_scope: _TestScope,
) -> None:
"""Raise ValueError when last_id refers to a non-existent run."""
with pytest.raises(ValueError, match="Last workflow run not exists"):
repository.get_paginated_workflow_runs(
tenant_id=test_scope.tenant_id,
app_id=test_scope.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
limit=20,
last_id=str(uuid4()),
status=None,
)
def test_tenant_isolation(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Runs from other tenants are not returned."""
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
other_scope = _TestScope(app_id=test_scope.app_id)
try:
_create_workflow_run(db_session_with_containers, other_scope, status=WorkflowExecutionStatus.SUCCEEDED)
result = repository.get_paginated_workflow_runs(
tenant_id=test_scope.tenant_id,
app_id=test_scope.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
limit=20,
last_id=None,
status=None,
)
assert len(result.data) == 1
assert result.data[0].tenant_id == test_scope.tenant_id
finally:
_cleanup_scope_data(db_session_with_containers, other_scope)
class TestGetWorkflowRunsCount:
"""Integration tests for get_workflow_runs_count."""
def test_count_without_status_filter(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Count all runs grouped by status when no status filter is applied."""
for _ in range(3):
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
for _ in range(2):
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED)
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.RUNNING)
result = repository.get_workflow_runs_count(
tenant_id=test_scope.tenant_id,
app_id=test_scope.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
status=None,
)
assert result["total"] == 6
assert result["succeeded"] == 3
assert result["failed"] == 2
assert result["running"] == 1
assert result["stopped"] == 0
assert result["partial-succeeded"] == 0
def test_count_with_status_filter(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Count only runs matching the requested status."""
for _ in range(3):
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED)
result = repository.get_workflow_runs_count(
tenant_id=test_scope.tenant_id,
app_id=test_scope.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
status="succeeded",
)
assert result["total"] == 3
assert result["succeeded"] == 3
assert result["failed"] == 0
def test_count_with_invalid_status_raises(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Invalid status raises StatementError because the column uses an enum type."""
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
with pytest.raises(sa_exc.StatementError) as exc_info:
repository.get_workflow_runs_count(
tenant_id=test_scope.tenant_id,
app_id=test_scope.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
status="invalid_status",
)
assert isinstance(exc_info.value.orig, ValueError)
def test_count_with_time_range(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Time range filter excludes runs created outside the window."""
# Recent run (within 1 day)
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
# Old run (8 days ago)
_create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.SUCCEEDED,
created_at_offset=timedelta(days=-8),
)
result = repository.get_workflow_runs_count(
tenant_id=test_scope.tenant_id,
app_id=test_scope.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
status=None,
time_range="7d",
)
assert result["total"] == 1
assert result["succeeded"] == 1
def test_count_with_status_and_time_range(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Both status and time_range filters apply together."""
# Recent succeeded
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
# Recent failed
_create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED)
# Old succeeded (outside time range)
_create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.SUCCEEDED,
created_at_offset=timedelta(days=-8),
)
result = repository.get_workflow_runs_count(
tenant_id=test_scope.tenant_id,
app_id=test_scope.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
status="succeeded",
time_range="7d",
)
assert result["total"] == 1
assert result["succeeded"] == 1
assert result["failed"] == 0

View File

@ -7,7 +7,7 @@ from sqlalchemy.orm import Session
from core.plugin.impl.exc import PluginDaemonClientSideError
from models import Account
from models.enums import MessageFileBelongsTo
from models.enums import ConversationFromSource, MessageFileBelongsTo
from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought
from services.account_service import AccountService, TenantService
from services.agent_service import AgentService
@ -165,7 +165,7 @@ class TestAgentService:
inputs={},
status="normal",
mode="chat",
from_source="api",
from_source=ConversationFromSource.API,
)
db_session_with_containers.add(conversation)
db_session_with_containers.commit()
@ -204,7 +204,7 @@ class TestAgentService:
answer_unit_price=0.001,
provider_response_latency=1.5,
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
)
db_session_with_containers.add(message)
db_session_with_containers.commit()
@ -406,7 +406,7 @@ class TestAgentService:
inputs={},
status="normal",
mode="chat",
from_source="api",
from_source=ConversationFromSource.API,
)
db_session_with_containers.add(conversation)
db_session_with_containers.commit()
@ -445,7 +445,7 @@ class TestAgentService:
answer_unit_price=0.001,
provider_response_latency=1.5,
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
)
db_session_with_containers.add(message)
db_session_with_containers.commit()
@ -478,7 +478,7 @@ class TestAgentService:
inputs={},
status="normal",
mode="chat",
from_source="api",
from_source=ConversationFromSource.API,
)
db_session_with_containers.add(conversation)
db_session_with_containers.commit()
@ -517,7 +517,7 @@ class TestAgentService:
answer_unit_price=0.001,
provider_response_latency=1.5,
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
)
db_session_with_containers.add(message)
db_session_with_containers.commit()
@ -624,7 +624,7 @@ class TestAgentService:
inputs={},
status="normal",
mode="chat",
from_source="api",
from_source=ConversationFromSource.API,
app_model_config_id=None, # Explicitly set to None
)
db_session_with_containers.add(conversation)
@ -647,7 +647,7 @@ class TestAgentService:
answer_unit_price=0.001,
provider_response_latency=1.5,
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
)
db_session_with_containers.add(message)
db_session_with_containers.commit()

View File

@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from models import Account
from models.enums import ConversationFromSource, InvokeFrom
from models.model import MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.app_service import AppService
@ -136,8 +137,8 @@ class TestAnnotationService:
system_instruction="",
system_instruction_tokens=0,
status="normal",
invoke_from="console",
from_source="console",
invoke_from=InvokeFrom.EXPLORE,
from_source=ConversationFromSource.CONSOLE,
from_end_user_id=None,
from_account_id=account.id,
)
@ -174,8 +175,8 @@ class TestAnnotationService:
provider_response_latency=0,
total_price=0,
currency="USD",
invoke_from="console",
from_source="console",
invoke_from=InvokeFrom.EXPLORE,
from_source=ConversationFromSource.CONSOLE,
from_end_user_id=None,
from_account_id=account.id,
)
@ -721,7 +722,7 @@ class TestAnnotationService:
query=f"Query {i}: {fake.sentence()}",
user_id=account.id,
message_id=fake.uuid4(),
from_source="console",
from_source=ConversationFromSource.CONSOLE,
score=0.8 + (i * 0.1),
)
@ -772,7 +773,7 @@ class TestAnnotationService:
query=query,
user_id=account.id,
message_id=message_id,
from_source="console",
from_source=ConversationFromSource.CONSOLE,
score=score,
)

View File

@ -0,0 +1,80 @@
"""Testcontainers integration tests for AttachmentService."""
import base64
from datetime import UTC, datetime
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
import services.attachment_service as attachment_service_module
from extensions.ext_database import db
from extensions.storage.storage_type import StorageType
from models.enums import CreatorUserRole
from models.model import UploadFile
from services.attachment_service import AttachmentService
class TestAttachmentService:
def _create_upload_file(self, db_session_with_containers, *, tenant_id: str | None = None) -> UploadFile:
upload_file = UploadFile(
tenant_id=tenant_id or str(uuid4()),
storage_type=StorageType.OPENDAL,
key=f"upload/{uuid4()}.txt",
name="test-file.txt",
size=100,
extension="txt",
mime_type="text/plain",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
created_at=datetime.now(UTC),
used=False,
)
db_session_with_containers.add(upload_file)
db_session_with_containers.commit()
return upload_file
def test_should_initialize_with_sessionmaker(self):
session_factory = sessionmaker()
service = AttachmentService(session_factory=session_factory)
assert service._session_maker is session_factory
def test_should_initialize_with_engine(self):
engine = create_engine("sqlite:///:memory:")
service = AttachmentService(session_factory=engine)
session = service._session_maker()
try:
assert session.bind == engine
finally:
session.close()
engine.dispose()
@pytest.mark.parametrize("invalid_session_factory", [None, "not-a-session-factory", 1])
def test_should_raise_assertion_error_for_invalid_session_factory(self, invalid_session_factory):
with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."):
AttachmentService(session_factory=invalid_session_factory)
def test_should_return_base64_when_file_exists(self, db_session_with_containers):
upload_file = self._create_upload_file(db_session_with_containers)
service = AttachmentService(session_factory=sessionmaker(bind=db.engine))
with patch.object(attachment_service_module.storage, "load_once", return_value=b"binary-content") as mock_load:
result = service.get_file_base64(upload_file.id)
assert result == base64.b64encode(b"binary-content").decode()
mock_load.assert_called_once_with(upload_file.key)
def test_should_raise_not_found_when_file_missing(self, db_session_with_containers):
service = AttachmentService(session_factory=sessionmaker(bind=db.engine))
with patch.object(attachment_service_module.storage, "load_once") as mock_load:
with pytest.raises(NotFound, match="File not found"):
service.get_file_base64(str(uuid4()))
mock_load.assert_not_called()

View File

@ -10,6 +10,7 @@ from sqlalchemy import select
from core.app.entities.app_invoke_entities import InvokeFrom
from models.account import Account, Tenant, TenantAccountJoin
from models.enums import ConversationFromSource
from models.model import App, Conversation, EndUser, Message, MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.conversation_service import ConversationService
@ -107,7 +108,7 @@ class ConversationServiceIntegrationTestDataFactory:
system_instruction_tokens=0,
status="normal",
invoke_from=invoke_from.value,
from_source="api" if isinstance(user, EndUser) else "console",
from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE,
from_end_user_id=user.id if isinstance(user, EndUser) else None,
from_account_id=user.id if isinstance(user, Account) else None,
dialogue_count=0,
@ -154,7 +155,7 @@ class ConversationServiceIntegrationTestDataFactory:
currency="USD",
status="normal",
invoke_from=InvokeFrom.WEB_APP.value,
from_source="api" if isinstance(user, EndUser) else "console",
from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE,
from_end_user_id=user.id if isinstance(user, EndUser) else None,
from_account_id=user.id if isinstance(user, Account) else None,
)

View File

@ -0,0 +1,58 @@
"""Testcontainers integration tests for ConversationVariableUpdater."""
from uuid import uuid4
import pytest
from sqlalchemy.orm import sessionmaker
from dify_graph.variables import StringVariable
from extensions.ext_database import db
from models.workflow import ConversationVariable
from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater
class TestConversationVariableUpdater:
def _create_conversation_variable(
self, db_session_with_containers, *, conversation_id: str, variable: StringVariable, app_id: str | None = None
) -> ConversationVariable:
row = ConversationVariable(
id=variable.id,
conversation_id=conversation_id,
app_id=app_id or str(uuid4()),
data=variable.model_dump_json(),
)
db_session_with_containers.add(row)
db_session_with_containers.commit()
return row
def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers):
conversation_id = str(uuid4())
variable = StringVariable(id=str(uuid4()), name="topic", value="old value")
self._create_conversation_variable(
db_session_with_containers, conversation_id=conversation_id, variable=variable
)
updated_variable = StringVariable(id=variable.id, name="topic", value="new value")
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
updater.update(conversation_id=conversation_id, variable=updated_variable)
db_session_with_containers.expire_all()
row = db_session_with_containers.get(ConversationVariable, (variable.id, conversation_id))
assert row is not None
assert row.data == updated_variable.model_dump_json()
def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers):
conversation_id = str(uuid4())
variable = StringVariable(id=str(uuid4()), name="topic", value="value")
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"):
updater.update(conversation_id=conversation_id, variable=variable)
def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers):
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
result = updater.flush()
assert result is None

View File

@ -0,0 +1,103 @@
"""Testcontainers integration tests for CreditPoolService."""
from uuid import uuid4
import pytest
from core.errors.error import QuotaExceededError
from models import TenantCreditPool
from services.credit_pool_service import CreditPoolService
class TestCreditPoolService:
def _create_tenant_id(self) -> str:
return str(uuid4())
def test_create_default_pool(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
assert isinstance(pool, TenantCreditPool)
assert pool.tenant_id == tenant_id
assert pool.pool_type == "trial"
assert pool.quota_used == 0
assert pool.quota_limit > 0
def test_get_pool_returns_pool_when_exists(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
CreditPoolService.create_default_pool(tenant_id)
result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type="trial")
assert result is not None
assert result.tenant_id == tenant_id
assert result.pool_type == "trial"
def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers):
result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type="trial")
assert result is None
def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers):
result = CreditPoolService.check_credits_available(tenant_id=self._create_tenant_id(), credits_required=10)
assert result is False
def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
CreditPoolService.create_default_pool(tenant_id)
result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=10)
assert result is True
def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
# Exhaust credits
pool.quota_used = pool.quota_limit
db_session_with_containers.commit()
result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=1)
assert result is False
def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers):
with pytest.raises(QuotaExceededError, match="Credit pool not found"):
CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=10)
def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
pool.quota_used = pool.quota_limit
db_session_with_containers.commit()
with pytest.raises(QuotaExceededError, match="No credits remaining"):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=10)
def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
CreditPoolService.create_default_pool(tenant_id)
credits_required = 10
result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=credits_required)
assert result == credits_required
db_session_with_containers.expire_all()
pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert pool.quota_used == credits_required
def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
remaining = 5
pool.quota_used = pool.quota_limit - remaining
db_session_with_containers.commit()
result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200)
assert result == remaining
db_session_with_containers.expire_all()
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert updated_pool.quota_used == pool.quota_limit

View File

@ -397,6 +397,68 @@ class TestDatasetPermissionServiceClearPartialMemberList:
class TestDatasetServiceCheckDatasetPermission:
"""Verify dataset access checks against persisted partial-member permissions."""
def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers):
"""Test that users from different tenants cannot access dataset."""
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
other_user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)
dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id, owner.id, permission=DatasetPermissionEnum.ALL_TEAM
)
with pytest.raises(NoPermissionError):
DatasetService.check_dataset_permission(dataset, other_user)
def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers):
"""Test that tenant owners can access any dataset regardless of permission level."""
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
creator, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL, tenant=tenant
)
dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME
)
DatasetService.check_dataset_permission(dataset, owner)
def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers):
"""Test ONLY_ME permission allows only the dataset creator to access."""
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)
dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME
)
DatasetService.check_dataset_permission(dataset, creator)
def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers):
"""Test ONLY_ME permission denies access to non-creators."""
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
other, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL, tenant=tenant
)
dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME
)
with pytest.raises(NoPermissionError):
DatasetService.check_dataset_permission(dataset, other)
def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers):
"""Test ALL_TEAM permission allows any team member to access the dataset."""
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL, tenant=tenant
)
dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id, creator.id, permission=DatasetPermissionEnum.ALL_TEAM
)
DatasetService.check_dataset_permission(dataset, member)
def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers):
"""
Test that user with explicit permission can access partial_members dataset.
@ -443,6 +505,16 @@ class TestDatasetServiceCheckDatasetPermission:
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"):
DatasetService.check_dataset_permission(dataset, user)
def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers):
"""Test PARTIAL_TEAM permission allows creator to access without explicit permission."""
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)
dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id, creator.id, permission=DatasetPermissionEnum.PARTIAL_TEAM
)
DatasetService.check_dataset_permission(dataset, creator)
class TestDatasetServiceCheckDatasetOperatorPermission:
"""Verify operator permission checks against persisted partial-member permissions."""

View File

@ -7,7 +7,7 @@ import pytest
from sqlalchemy.orm import Session
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import FeedbackFromSource, FeedbackRating
from models.enums import ConversationFromSource, FeedbackFromSource, FeedbackRating
from models.model import (
App,
AppAnnotationHitHistory,
@ -94,7 +94,7 @@ class TestAppMessageExportServiceIntegration:
name="conv",
inputs={"seed": 1},
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id=str(uuid.uuid4()),
)
session.add(conversation)
@ -129,7 +129,7 @@ class TestAppMessageExportServiceIntegration:
total_price=Decimal("0.003"),
currency="USD",
message_metadata=message_metadata,
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id=conversation.from_end_user_id,
created_at=created_at,
)

View File

@ -4,7 +4,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models.enums import FeedbackRating
from models.enums import ConversationFromSource, FeedbackRating, InvokeFrom
from models.model import MessageFeedback
from services.app_service import AppService
from services.errors.message import (
@ -149,8 +149,8 @@ class TestMessageService:
system_instruction="",
system_instruction_tokens=0,
status="normal",
invoke_from="console",
from_source="console",
invoke_from=InvokeFrom.EXPLORE,
from_source=ConversationFromSource.CONSOLE,
from_end_user_id=None,
from_account_id=account.id,
)
@ -187,8 +187,8 @@ class TestMessageService:
provider_response_latency=0,
total_price=0,
currency="USD",
invoke_from="console",
from_source="console",
invoke_from=InvokeFrom.EXPLORE,
from_source=ConversationFromSource.CONSOLE,
from_end_user_id=None,
from_account_id=account.id,
)

View File

@ -4,6 +4,7 @@ from decimal import Decimal
import pytest
from models.enums import ConversationFromSource
from models.model import Message
from services import message_service
from tests.test_containers_integration_tests.helpers.execution_extra_content import (
@ -36,7 +37,7 @@ def test_attach_message_extra_contents_assigns_serialized_payload(db_session_wit
total_price=Decimal(0),
currency="USD",
status="normal",
from_source="console",
from_source=ConversationFromSource.CONSOLE,
from_account_id=fixture.account.id,
)
db_session_with_containers.add(message_without_extra_content)

View File

@ -11,7 +11,14 @@ from sqlalchemy.orm import Session
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import DataSourceType, FeedbackFromSource, FeedbackRating, MessageChainType, MessageFileBelongsTo
from models.enums import (
ConversationFromSource,
DataSourceType,
FeedbackFromSource,
FeedbackRating,
MessageChainType,
MessageFileBelongsTo,
)
from models.model import (
App,
AppAnnotationHitHistory,
@ -166,7 +173,7 @@ class TestMessagesCleanServiceIntegration:
name="Test conversation",
inputs={},
status="normal",
from_source=FeedbackFromSource.USER,
from_source=ConversationFromSource.API,
from_end_user_id=str(uuid.uuid4()),
)
db_session_with_containers.add(conversation)
@ -196,7 +203,7 @@ class TestMessagesCleanServiceIntegration:
answer_unit_price=Decimal("0.002"),
total_price=Decimal("0.003"),
currency="USD",
from_source=FeedbackFromSource.USER,
from_source=ConversationFromSource.API,
from_account_id=conversation.from_end_user_id,
created_at=created_at,
)

View File

@ -4,6 +4,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models.enums import ConversationFromSource
from models.model import EndUser, Message
from models.web import SavedMessage
from services.app_service import AppService
@ -132,11 +133,14 @@ class TestSavedMessageService:
# Create a simple conversation first
from models.model import Conversation
is_account = hasattr(user, "current_tenant")
from_source = ConversationFromSource.CONSOLE if is_account else ConversationFromSource.API
conversation = Conversation(
app_id=app.id,
from_source="account" if hasattr(user, "current_tenant") else "end_user",
from_end_user_id=user.id if not hasattr(user, "current_tenant") else None,
from_account_id=user.id if hasattr(user, "current_tenant") else None,
from_source=from_source,
from_end_user_id=user.id if not is_account else None,
from_account_id=user.id if is_account else None,
name=fake.sentence(nb_words=3),
inputs={},
status="normal",
@ -150,9 +154,9 @@ class TestSavedMessageService:
message = Message(
app_id=app.id,
conversation_id=conversation.id,
from_source="account" if hasattr(user, "current_tenant") else "end_user",
from_end_user_id=user.id if not hasattr(user, "current_tenant") else None,
from_account_id=user.id if hasattr(user, "current_tenant") else None,
from_source=from_source,
from_end_user_id=user.id if not is_account else None,
from_account_id=user.id if is_account else None,
inputs={},
query=fake.sentence(nb_words=5),
message=fake.text(max_nb_chars=100),

View File

@ -9,7 +9,7 @@ from werkzeug.exceptions import NotFound
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset
from models.enums import DataSourceType
from models.enums import DataSourceType, TagType
from models.model import App, Tag, TagBinding
from services.tag_service import TagService
@ -547,7 +547,7 @@ class TestTagService:
assert result is not None
assert len(result) == 1
assert result[0].name == "python_tag"
assert result[0].type == "app"
assert result[0].type == TagType.APP
assert result[0].tenant_id == tenant.id
def test_get_tag_by_tag_name_no_matches(
@ -638,7 +638,7 @@ class TestTagService:
# Verify all tags are returned
for tag in result:
assert tag.type == "app"
assert tag.type == TagType.APP
assert tag.tenant_id == tenant.id
assert tag.id in [t.id for t in tags]

View File

@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from models import Account
from models.enums import ConversationFromSource
from models.model import Conversation, EndUser
from models.web import PinnedConversation
from services.account_service import AccountService, TenantService
@ -145,7 +146,7 @@ class TestWebConversationService:
system_instruction_tokens=50,
status="normal",
invoke_from=InvokeFrom.WEB_APP,
from_source="console" if isinstance(user, Account) else "api",
from_source=ConversationFromSource.CONSOLE if isinstance(user, Account) else ConversationFromSource.API,
from_end_user_id=user.id if isinstance(user, EndUser) else None,
from_account_id=user.id if isinstance(user, Account) else None,
dialogue_count=0,

View File

@ -7,7 +7,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models.enums import CreatorUserRole
from models.enums import ConversationFromSource, CreatorUserRole
from models.model import (
Message,
)
@ -165,7 +165,7 @@ class TestWorkflowRunService:
inputs={},
status="normal",
mode="chat",
from_source=CreatorUserRole.ACCOUNT,
from_source=ConversationFromSource.CONSOLE,
from_account_id=account.id,
)
db_session_with_containers.add(conversation)
@ -186,7 +186,7 @@ class TestWorkflowRunService:
message.answer_price_unit = 0.001
message.currency = "USD"
message.status = "normal"
message.from_source = CreatorUserRole.ACCOUNT
message.from_source = ConversationFromSource.CONSOLE
message.from_account_id = account.id
message.workflow_run_id = workflow_run.id
message.inputs = {"input": "test input"}

View File

@ -802,6 +802,81 @@ class TestWorkflowService:
with pytest.raises(ValueError, match="No valid workflow found"):
workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account)
def test_restore_published_workflow_to_draft_does_not_persist_normalized_source_features(
self, db_session_with_containers: Session
):
"""Restore copies legacy feature JSON into draft without rewriting the source row."""
fake = Faker()
account = self._create_test_account(db_session_with_containers, fake)
app = self._create_test_app(db_session_with_containers, fake)
app.mode = AppMode.ADVANCED_CHAT
legacy_features = {
"file_upload": {
"image": {
"enabled": True,
"number_limits": 6,
"transfer_methods": ["remote_url", "local_file"],
}
},
"opening_statement": "",
"retriever_resource": {"enabled": True},
"sensitive_word_avoidance": {"enabled": False},
"speech_to_text": {"enabled": False},
"suggested_questions": [],
"suggested_questions_after_answer": {"enabled": False},
"text_to_speech": {"enabled": False, "language": "", "voice": ""},
}
published_workflow = Workflow(
id=fake.uuid4(),
tenant_id=app.tenant_id,
app_id=app.id,
type=WorkflowType.WORKFLOW,
version="2026.03.19.001",
graph=json.dumps({"nodes": [], "edges": []}),
features=json.dumps(legacy_features),
created_by=account.id,
updated_by=account.id,
environment_variables=[],
conversation_variables=[],
)
draft_workflow = Workflow(
id=fake.uuid4(),
tenant_id=app.tenant_id,
app_id=app.id,
type=WorkflowType.WORKFLOW,
version=Workflow.VERSION_DRAFT,
graph=json.dumps({"nodes": [], "edges": []}),
features=json.dumps({}),
created_by=account.id,
updated_by=account.id,
environment_variables=[],
conversation_variables=[],
)
db_session_with_containers.add(published_workflow)
db_session_with_containers.add(draft_workflow)
db_session_with_containers.commit()
workflow_service = WorkflowService()
restored_workflow = workflow_service.restore_published_workflow_to_draft(
app_model=app,
workflow_id=published_workflow.id,
account=account,
)
db_session_with_containers.expire_all()
refreshed_published_workflow = (
db_session_with_containers.query(Workflow).filter_by(id=published_workflow.id).first()
)
refreshed_draft_workflow = db_session_with_containers.query(Workflow).filter_by(id=draft_workflow.id).first()
assert restored_workflow.id == draft_workflow.id
assert refreshed_published_workflow is not None
assert refreshed_draft_workflow is not None
assert refreshed_published_workflow.serialized_features == json.dumps(legacy_features)
assert refreshed_draft_workflow.serialized_features == json.dumps(legacy_features)
def test_get_default_block_configs(self, db_session_with_containers: Session):
"""
Test retrieval of default block configurations for all node types.

View File

@ -0,0 +1,158 @@
"""Testcontainers integration tests for WorkflowService.delete_workflow."""
import json
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session, sessionmaker
from extensions.ext_database import db
from models.account import Account, Tenant, TenantAccountJoin
from models.model import App
from models.tools import WorkflowToolProvider
from models.workflow import Workflow
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
class TestWorkflowDeletion:
def _create_tenant_and_account(self, session: Session) -> tuple[Tenant, Account]:
tenant = Tenant(name=f"Tenant {uuid4()}")
session.add(tenant)
session.flush()
account = Account(
name=f"Account {uuid4()}",
email=f"wf_del_{uuid4()}@example.com",
password="hashed",
password_salt="salt",
interface_language="en-US",
timezone="UTC",
)
session.add(account)
session.flush()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role="owner",
current=True,
)
session.add(join)
session.flush()
return tenant, account
def _create_app(self, session: Session, *, tenant: Tenant, account: Account, workflow_id: str | None = None) -> App:
app = App(
tenant_id=tenant.id,
name=f"App {uuid4()}",
description="",
mode="workflow",
icon_type="emoji",
icon="bot",
icon_background="#FFFFFF",
enable_site=False,
enable_api=True,
api_rpm=100,
api_rph=100,
is_demo=False,
is_public=False,
is_universal=False,
created_by=account.id,
updated_by=account.id,
workflow_id=workflow_id,
)
session.add(app)
session.flush()
return app
def _create_workflow(
self, session: Session, *, tenant: Tenant, app: App, account: Account, version: str = "1.0"
) -> Workflow:
workflow = Workflow(
id=str(uuid4()),
tenant_id=tenant.id,
app_id=app.id,
type="workflow",
version=version,
graph=json.dumps({"nodes": [], "edges": []}),
_features=json.dumps({}),
created_by=account.id,
updated_by=account.id,
)
session.add(workflow)
session.flush()
return workflow
def _create_tool_provider(
self, session: Session, *, tenant: Tenant, app: App, account: Account, version: str
) -> WorkflowToolProvider:
provider = WorkflowToolProvider(
name=f"tool-{uuid4()}",
label=f"Tool {uuid4()}",
icon="wrench",
app_id=app.id,
version=version,
user_id=account.id,
tenant_id=tenant.id,
description="test tool provider",
)
session.add(provider)
session.flush()
return provider
def test_delete_workflow_success(self, db_session_with_containers):
tenant, account = self._create_tenant_and_account(db_session_with_containers)
app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
workflow = self._create_workflow(
db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0"
)
db_session_with_containers.commit()
workflow_id = workflow.id
service = WorkflowService(sessionmaker(bind=db.engine))
result = service.delete_workflow(
session=db_session_with_containers, workflow_id=workflow_id, tenant_id=tenant.id
)
assert result is True
db_session_with_containers.expire_all()
assert db_session_with_containers.get(Workflow, workflow_id) is None
def test_delete_draft_workflow_raises_error(self, db_session_with_containers):
tenant, account = self._create_tenant_and_account(db_session_with_containers)
app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
workflow = self._create_workflow(
db_session_with_containers, tenant=tenant, app=app, account=account, version="draft"
)
db_session_with_containers.commit()
service = WorkflowService(sessionmaker(bind=db.engine))
with pytest.raises(DraftWorkflowDeletionError):
service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id)
def test_delete_workflow_in_use_by_app_raises_error(self, db_session_with_containers):
tenant, account = self._create_tenant_and_account(db_session_with_containers)
app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
workflow = self._create_workflow(
db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0"
)
# Point app to this workflow
app.workflow_id = workflow.id
db_session_with_containers.commit()
service = WorkflowService(sessionmaker(bind=db.engine))
with pytest.raises(WorkflowInUseError, match="currently in use by app"):
service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id)
def test_delete_workflow_published_as_tool_raises_error(self, db_session_with_containers):
tenant, account = self._create_tenant_and_account(db_session_with_containers)
app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
workflow = self._create_workflow(
db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0"
)
self._create_tool_provider(db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0")
db_session_with_containers.commit()
service = WorkflowService(sessionmaker(bind=db.engine))
with pytest.raises(WorkflowInUseError, match="published as a tool"):
service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id)

View File

@ -0,0 +1,320 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, request
from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.local import LocalProxy
from controllers.console.app.error import (
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.app.message import (
ChatMessageListApi,
ChatMessagesQuery,
FeedbackExportQuery,
MessageAnnotationCountApi,
MessageApi,
MessageFeedbackApi,
MessageFeedbackExportApi,
MessageFeedbackPayload,
MessageSuggestedQuestionApi,
)
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from models import App, AppMode
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
flask_app.config["RESTX_MASK_HEADER"] = "X-Fields"
return flask_app
@pytest.fixture
def mock_account():
from models.account import Account, AccountStatus
account = MagicMock(spec=Account)
account.id = "user_123"
account.timezone = "UTC"
account.status = AccountStatus.ACTIVE
account.is_admin_or_owner = True
account.current_tenant.current_role = "owner"
account.has_edit_permission = True
return account
@pytest.fixture
def mock_app_model():
app_model = MagicMock(spec=App)
app_model.id = "app_123"
app_model.mode = AppMode.CHAT
app_model.tenant_id = "tenant_123"
return app_model
@pytest.fixture(autouse=True)
def mock_csrf():
with patch("libs.login.check_csrf_token") as mock:
yield mock
import contextlib
@contextlib.contextmanager
def setup_test_context(
test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None, qs=None
):
with (
patch("extensions.ext_database.db") as mock_db,
patch("controllers.console.app.wraps.db", mock_db),
patch("controllers.console.wraps.db", mock_db),
patch("controllers.console.app.message.db", mock_db),
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch("controllers.console.app.message.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
):
# Set up a generic query mock that usually returns mock_app_model when getting app
app_query_mock = MagicMock()
app_query_mock.filter.return_value.first.return_value = mock_app_model
app_query_mock.filter.return_value.filter.return_value.first.return_value = mock_app_model
app_query_mock.where.return_value.first.return_value = mock_app_model
app_query_mock.where.return_value.where.return_value.first.return_value = mock_app_model
data_query_mock = MagicMock()
def query_side_effect(*args, **kwargs):
if args and hasattr(args[0], "__name__") and args[0].__name__ == "App":
return app_query_mock
return data_query_mock
mock_db.session.query.side_effect = query_side_effect
mock_db.data_query = data_query_mock
# Let the caller override the stat db query logic
proxy_mock = LocalProxy(lambda: mock_account)
query_string = "&".join([f"{k}={v}" for k, v in (qs or {}).items()])
full_path = f"{route_path}?{query_string}" if qs else route_path
with (
patch("libs.login.current_user", proxy_mock),
patch("flask_login.current_user", proxy_mock),
patch("controllers.console.app.message.attach_message_extra_contents", return_value=None),
):
with test_app.test_request_context(full_path, method=method, json=payload):
request.view_args = {"app_id": "app_123"}
if "suggested-questions" in route_path:
# simplistic extraction for message_id
parts = route_path.split("chat-messages/")
if len(parts) > 1:
request.view_args["message_id"] = parts[1].split("/")[0]
elif "messages/" in route_path and "chat-messages" not in route_path:
parts = route_path.split("messages/")
if len(parts) > 1:
request.view_args["message_id"] = parts[1].split("/")[0]
api_instance = endpoint_class()
# Check if it has a dispatch_request or method
if hasattr(api_instance, method.lower()):
yield api_instance, mock_db, request.view_args
class TestMessageValidators:
def test_chat_messages_query_validators(self):
# Test empty_to_none
assert ChatMessagesQuery.empty_to_none("") is None
assert ChatMessagesQuery.empty_to_none("val") == "val"
# Test validate_uuid
assert ChatMessagesQuery.validate_uuid(None) is None
assert (
ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000")
== "123e4567-e89b-12d3-a456-426614174000"
)
def test_message_feedback_validators(self):
assert (
MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000")
== "123e4567-e89b-12d3-a456-426614174000"
)
def test_feedback_export_validators(self):
assert FeedbackExportQuery.parse_bool(None) is None
assert FeedbackExportQuery.parse_bool(True) is True
assert FeedbackExportQuery.parse_bool("1") is True
assert FeedbackExportQuery.parse_bool("0") is False
assert FeedbackExportQuery.parse_bool("off") is False
with pytest.raises(ValueError):
FeedbackExportQuery.parse_bool("invalid")
class TestMessageEndpoints:
def test_chat_message_list_not_found(self, app, mock_account, mock_app_model):
with setup_test_context(
app,
ChatMessageListApi,
"/apps/app_123/chat-messages",
"GET",
mock_account,
mock_app_model,
qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000"},
) as (api, mock_db, v_args):
mock_db.session.scalar.return_value = None
with pytest.raises(NotFound):
api.get(**v_args)
def test_chat_message_list_success(self, app, mock_account, mock_app_model):
with setup_test_context(
app,
ChatMessageListApi,
"/apps/app_123/chat-messages",
"GET",
mock_account,
mock_app_model,
qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000", "limit": 1},
) as (api, mock_db, v_args):
mock_conv = MagicMock()
mock_conv.id = "123e4567-e89b-12d3-a456-426614174000"
mock_msg = MagicMock()
mock_msg.id = "msg_123"
mock_msg.feedbacks = []
mock_msg.annotation = None
mock_msg.annotation_hit_history = None
mock_msg.agent_thoughts = []
mock_msg.message_files = []
mock_msg.extra_contents = []
mock_msg.message = {}
mock_msg.message_metadata_dict = {}
# scalar() is called twice: first for conversation lookup, second for has_more check
mock_db.session.scalar.side_effect = [mock_conv, False]
scalars_result = MagicMock()
scalars_result.all.return_value = [mock_msg]
mock_db.session.scalars.return_value = scalars_result
resp = api.get(**v_args)
assert resp["limit"] == 1
assert resp["has_more"] is False
assert len(resp["data"]) == 1
def test_message_feedback_not_found(self, app, mock_account, mock_app_model):
with setup_test_context(
app,
MessageFeedbackApi,
"/apps/app_123/feedbacks",
"POST",
mock_account,
mock_app_model,
payload={"message_id": "123e4567-e89b-12d3-a456-426614174000"},
) as (api, mock_db, v_args):
mock_db.session.scalar.return_value = None
with pytest.raises(NotFound):
api.post(**v_args)
def test_message_feedback_success(self, app, mock_account, mock_app_model):
payload = {"message_id": "123e4567-e89b-12d3-a456-426614174000", "rating": "like"}
with setup_test_context(
app, MessageFeedbackApi, "/apps/app_123/feedbacks", "POST", mock_account, mock_app_model, payload=payload
) as (api, mock_db, v_args):
mock_msg = MagicMock()
mock_msg.admin_feedback = None
mock_db.session.scalar.return_value = mock_msg
resp = api.post(**v_args)
assert resp == {"result": "success"}
def test_message_annotation_count(self, app, mock_account, mock_app_model):
with setup_test_context(
app, MessageAnnotationCountApi, "/apps/app_123/annotations/count", "GET", mock_account, mock_app_model
) as (api, mock_db, v_args):
mock_db.session.scalar.return_value = 5
resp = api.get(**v_args)
assert resp == {"count": 5}
@patch("controllers.console.app.message.MessageService")
def test_message_suggested_questions_success(self, mock_msg_srv, app, mock_account, mock_app_model):
mock_msg_srv.get_suggested_questions_after_answer.return_value = ["q1", "q2"]
with setup_test_context(
app,
MessageSuggestedQuestionApi,
"/apps/app_123/chat-messages/msg_123/suggested-questions",
"GET",
mock_account,
mock_app_model,
) as (api, mock_db, v_args):
resp = api.get(**v_args)
assert resp == {"data": ["q1", "q2"]}
@pytest.mark.parametrize(
("exc", "expected_exc"),
[
(MessageNotExistsError, NotFound),
(ConversationNotExistsError, NotFound),
(ProviderTokenNotInitError, ProviderNotInitializeError),
(QuotaExceededError, ProviderQuotaExceededError),
(ModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError),
(SuggestedQuestionsAfterAnswerDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError),
(Exception, InternalServerError),
],
)
@patch("controllers.console.app.message.MessageService")
def test_message_suggested_questions_errors(
self, mock_msg_srv, exc, expected_exc, app, mock_account, mock_app_model
):
mock_msg_srv.get_suggested_questions_after_answer.side_effect = exc()
with setup_test_context(
app,
MessageSuggestedQuestionApi,
"/apps/app_123/chat-messages/msg_123/suggested-questions",
"GET",
mock_account,
mock_app_model,
) as (api, mock_db, v_args):
with pytest.raises(expected_exc):
api.get(**v_args)
@patch("services.feedback_service.FeedbackService.export_feedbacks")
def test_message_feedback_export_success(self, mock_export, app, mock_account, mock_app_model):
mock_export.return_value = {"exported": True}
with setup_test_context(
app, MessageFeedbackExportApi, "/apps/app_123/feedbacks/export", "GET", mock_account, mock_app_model
) as (api, mock_db, v_args):
resp = api.get(**v_args)
assert resp == {"exported": True}
def test_message_api_get_success(self, app, mock_account, mock_app_model):
with setup_test_context(
app, MessageApi, "/apps/app_123/messages/msg_123", "GET", mock_account, mock_app_model
) as (api, mock_db, v_args):
mock_msg = MagicMock()
mock_msg.id = "msg_123"
mock_msg.feedbacks = []
mock_msg.annotation = None
mock_msg.annotation_hit_history = None
mock_msg.agent_thoughts = []
mock_msg.message_files = []
mock_msg.extra_contents = []
mock_msg.message = {}
mock_msg.message_metadata_dict = {}
mock_db.session.scalar.return_value = mock_msg
resp = api.get(**v_args)
assert resp["id"] == "msg_123"

View File

@ -0,0 +1,275 @@
from decimal import Decimal
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, request
from werkzeug.local import LocalProxy
from controllers.console.app.statistic import (
AverageResponseTimeStatistic,
AverageSessionInteractionStatistic,
DailyConversationStatistic,
DailyMessageStatistic,
DailyTerminalsStatistic,
DailyTokenCostStatistic,
TokensPerSecondStatistic,
UserSatisfactionRateStatistic,
)
from models import App, AppMode
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
return flask_app
@pytest.fixture
def mock_account():
from models.account import Account, AccountStatus
account = MagicMock(spec=Account)
account.id = "user_123"
account.timezone = "UTC"
account.status = AccountStatus.ACTIVE
account.is_admin_or_owner = True
account.current_tenant.current_role = "owner"
account.has_edit_permission = True
return account
@pytest.fixture
def mock_app_model():
app_model = MagicMock(spec=App)
app_model.id = "app_123"
app_model.mode = AppMode.CHAT
app_model.tenant_id = "tenant_123"
return app_model
@pytest.fixture(autouse=True)
def mock_csrf():
with patch("libs.login.check_csrf_token") as mock:
yield mock
def setup_test_context(
test_app, endpoint_class, route_path, mock_account, mock_app_model, mock_rs, mock_parse_ret=(None, None)
):
with (
patch("controllers.console.app.statistic.db") as mock_db_stat,
patch("controllers.console.app.wraps.db") as mock_db_wraps,
patch("controllers.console.wraps.db", mock_db_wraps),
patch(
"controllers.console.app.statistic.current_account_with_tenant", return_value=(mock_account, "tenant_123")
),
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
):
mock_conn = MagicMock()
mock_conn.execute.return_value = mock_rs
mock_begin = MagicMock()
mock_begin.__enter__.return_value = mock_conn
mock_db_stat.engine.begin.return_value = mock_begin
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = mock_app_model
mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model
mock_query.where.return_value.first.return_value = mock_app_model
mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
mock_db_wraps.session.query.return_value = mock_query
proxy_mock = LocalProxy(lambda: mock_account)
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
with test_app.test_request_context(route_path, method="GET"):
request.view_args = {"app_id": "app_123"}
api_instance = endpoint_class()
response = api_instance.get(app_id="app_123")
return response
class TestStatisticEndpoints:
def test_daily_message_statistic(self, app, mock_account, mock_app_model):
mock_row = MagicMock()
mock_row.date = "2023-01-01"
mock_row.message_count = 10
mock_row.interactions = Decimal(0)
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
response = setup_test_context(
app,
DailyMessageStatistic,
"/apps/app_123/statistics/daily-messages?start=2023-01-01 00:00&end=2023-01-02 00:00",
mock_account,
mock_app_model,
[mock_row],
)
assert response.status_code == 200
assert response.json["data"][0]["message_count"] == 10
def test_daily_conversation_statistic(self, app, mock_account, mock_app_model):
mock_row = MagicMock()
mock_row.date = "2023-01-01"
mock_row.conversation_count = 5
mock_row.interactions = Decimal(0)
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
response = setup_test_context(
app,
DailyConversationStatistic,
"/apps/app_123/statistics/daily-conversations",
mock_account,
mock_app_model,
[mock_row],
)
assert response.status_code == 200
assert response.json["data"][0]["conversation_count"] == 5
def test_daily_terminals_statistic(self, app, mock_account, mock_app_model):
mock_row = MagicMock()
mock_row.date = "2023-01-01"
mock_row.terminal_count = 2
mock_row.interactions = Decimal(0)
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
response = setup_test_context(
app,
DailyTerminalsStatistic,
"/apps/app_123/statistics/daily-end-users",
mock_account,
mock_app_model,
[mock_row],
)
assert response.status_code == 200
assert response.json["data"][0]["terminal_count"] == 2
def test_daily_token_cost_statistic(self, app, mock_account, mock_app_model):
mock_row = MagicMock()
mock_row.date = "2023-01-01"
mock_row.token_count = 100
mock_row.total_price = Decimal("0.02")
mock_row.interactions = Decimal(0)
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
response = setup_test_context(
app,
DailyTokenCostStatistic,
"/apps/app_123/statistics/token-costs",
mock_account,
mock_app_model,
[mock_row],
)
assert response.status_code == 200
assert response.json["data"][0]["token_count"] == 100
assert response.json["data"][0]["total_price"] == "0.02"
def test_average_session_interaction_statistic(self, app, mock_account, mock_app_model):
mock_row = MagicMock()
mock_row.date = "2023-01-01"
mock_row.interactions = Decimal("3.523")
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
response = setup_test_context(
app,
AverageSessionInteractionStatistic,
"/apps/app_123/statistics/average-session-interactions",
mock_account,
mock_app_model,
[mock_row],
)
assert response.status_code == 200
assert response.json["data"][0]["interactions"] == 3.52
def test_user_satisfaction_rate_statistic(self, app, mock_account, mock_app_model):
mock_row = MagicMock()
mock_row.date = "2023-01-01"
mock_row.message_count = 100
mock_row.feedback_count = 10
mock_row.interactions = Decimal(0)
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
response = setup_test_context(
app,
UserSatisfactionRateStatistic,
"/apps/app_123/statistics/user-satisfaction-rate",
mock_account,
mock_app_model,
[mock_row],
)
assert response.status_code == 200
assert response.json["data"][0]["rate"] == 100.0
def test_average_response_time_statistic(self, app, mock_account, mock_app_model):
mock_app_model.mode = AppMode.COMPLETION
mock_row = MagicMock()
mock_row.date = "2023-01-01"
mock_row.latency = 1.234
mock_row.interactions = Decimal(0)
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
response = setup_test_context(
app,
AverageResponseTimeStatistic,
"/apps/app_123/statistics/average-response-time",
mock_account,
mock_app_model,
[mock_row],
)
assert response.status_code == 200
assert response.json["data"][0]["latency"] == 1234.0
def test_tokens_per_second_statistic(self, app, mock_account, mock_app_model):
mock_row = MagicMock()
mock_row.date = "2023-01-01"
mock_row.tokens_per_second = 15.5
mock_row.interactions = Decimal(0)
with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
response = setup_test_context(
app,
TokensPerSecondStatistic,
"/apps/app_123/statistics/tokens-per-second",
mock_account,
mock_app_model,
[mock_row],
)
assert response.status_code == 200
assert response.json["data"][0]["tps"] == 15.5
@patch("controllers.console.app.statistic.parse_time_range")
def test_invalid_time_range(self, mock_parse, app, mock_account, mock_app_model):
mock_parse.side_effect = ValueError("Invalid time")
from werkzeug.exceptions import BadRequest
with pytest.raises(BadRequest):
setup_test_context(
app,
DailyMessageStatistic,
"/apps/app_123/statistics/daily-messages?start=invalid&end=invalid",
mock_account,
mock_app_model,
[],
)
@patch("controllers.console.app.statistic.parse_time_range")
def test_time_range_params_passed(self, mock_parse, app, mock_account, mock_app_model):
import datetime
start = datetime.datetime.now()
end = datetime.datetime.now()
mock_parse.return_value = (start, end)
response = setup_test_context(
app,
DailyMessageStatistic,
"/apps/app_123/statistics/daily-messages?start=something&end=something",
mock_account,
mock_app_model,
[],
)
assert response.status_code == 200
mock_parse.assert_called_once()

View File

@ -129,6 +129,136 @@ def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch)
handler(api, app_model=SimpleNamespace(id="app"))
def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow = SimpleNamespace(
unique_hash="restored-hash",
updated_at=None,
created_at=datetime(2024, 1, 1),
)
user = SimpleNamespace(id="account-1")
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1"))
monkeypatch.setattr(
workflow_module,
"WorkflowService",
lambda: SimpleNamespace(restore_published_workflow_to_draft=lambda **_kwargs: workflow),
)
api = workflow_module.DraftWorkflowRestoreApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/apps/app/workflows/published-workflow/restore",
method="POST",
):
response = handler(
api,
app_model=SimpleNamespace(id="app", tenant_id="tenant-1"),
workflow_id="published-workflow",
)
assert response["result"] == "success"
assert response["hash"] == "restored-hash"
def test_restore_published_workflow_to_draft_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(id="account-1")
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1"))
monkeypatch.setattr(
workflow_module,
"WorkflowService",
lambda: SimpleNamespace(
restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw(
workflow_module.WorkflowNotFoundError("Workflow not found")
)
),
)
api = workflow_module.DraftWorkflowRestoreApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/apps/app/workflows/published-workflow/restore",
method="POST",
):
with pytest.raises(NotFound):
handler(
api,
app_model=SimpleNamespace(id="app", tenant_id="tenant-1"),
workflow_id="published-workflow",
)
def test_restore_published_workflow_to_draft_returns_400_for_draft_source(app, monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(id="account-1")
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1"))
monkeypatch.setattr(
workflow_module,
"WorkflowService",
lambda: SimpleNamespace(
restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw(
workflow_module.IsDraftWorkflowError(
"Cannot use draft workflow version. Workflow ID: draft-workflow. "
"Please use a published workflow version or leave workflow_id empty."
)
)
),
)
api = workflow_module.DraftWorkflowRestoreApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/apps/app/workflows/draft-workflow/restore",
method="POST",
):
with pytest.raises(HTTPException) as exc:
handler(
api,
app_model=SimpleNamespace(id="app", tenant_id="tenant-1"),
workflow_id="draft-workflow",
)
assert exc.value.code == 400
assert exc.value.description == workflow_module.RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE
def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure(
app, monkeypatch: pytest.MonkeyPatch
) -> None:
user = SimpleNamespace(id="account-1")
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1"))
monkeypatch.setattr(
workflow_module,
"WorkflowService",
lambda: SimpleNamespace(
restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw(
ValueError("invalid workflow graph")
)
),
)
api = workflow_module.DraftWorkflowRestoreApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/apps/app/workflows/published-workflow/restore",
method="POST",
):
with pytest.raises(HTTPException) as exc:
handler(
api,
app_model=SimpleNamespace(id="app", tenant_id="tenant-1"),
workflow_id="published-workflow",
)
assert exc.value.code == 400
assert exc.value.description == "invalid workflow graph"
def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None)

View File

@ -0,0 +1,313 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, request
from werkzeug.local import LocalProxy
from controllers.console.app.error import DraftWorkflowNotExist
from controllers.console.app.workflow_draft_variable import (
ConversationVariableCollectionApi,
EnvironmentVariableCollectionApi,
NodeVariableCollectionApi,
SystemVariableCollectionApi,
VariableApi,
VariableResetApi,
WorkflowVariableCollectionApi,
)
from controllers.web.error import InvalidArgumentError, NotFoundError
from models import App, AppMode
from models.enums import DraftVariableType
@pytest.fixture
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
flask_app.config["RESTX_MASK_HEADER"] = "X-Fields"
return flask_app
@pytest.fixture
def mock_account():
from models.account import Account, AccountStatus
account = MagicMock(spec=Account)
account.id = "user_123"
account.timezone = "UTC"
account.status = AccountStatus.ACTIVE
account.is_admin_or_owner = True
account.current_tenant.current_role = "owner"
account.has_edit_permission = True
return account
@pytest.fixture
def mock_app_model():
app_model = MagicMock(spec=App)
app_model.id = "app_123"
app_model.mode = AppMode.WORKFLOW
app_model.tenant_id = "tenant_123"
return app_model
@pytest.fixture(autouse=True)
def mock_csrf():
with patch("libs.login.check_csrf_token") as mock:
yield mock
def setup_test_context(test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None):
with (
patch("controllers.console.app.wraps.db") as mock_db_wraps,
patch("controllers.console.wraps.db", mock_db_wraps),
patch("controllers.console.app.workflow_draft_variable.db"),
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
):
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = mock_app_model
mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model
mock_query.where.return_value.first.return_value = mock_app_model
mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
mock_db_wraps.session.query.return_value = mock_query
proxy_mock = LocalProxy(lambda: mock_account)
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
with test_app.test_request_context(route_path, method=method, json=payload):
request.view_args = {"app_id": "app_123"}
# extract node_id or variable_id from path manually since view_args overrides
if "nodes/" in route_path:
request.view_args["node_id"] = route_path.split("nodes/")[1].split("/")[0]
if "variables/" in route_path:
# simplistic extraction
parts = route_path.split("variables/")
if len(parts) > 1 and parts[1] and parts[1] != "reset":
request.view_args["variable_id"] = parts[1].split("/")[0]
api_instance = endpoint_class()
# we just call dispatch_request to avoid manual argument passing
if hasattr(api_instance, method.lower()):
func = getattr(api_instance, method.lower())
return func(**request.view_args)
class TestWorkflowDraftVariableEndpoints:
@staticmethod
def _mock_workflow_variable(variable_type: DraftVariableType = DraftVariableType.NODE) -> MagicMock:
class DummyValueType:
def exposed_type(self):
return DraftVariableType.NODE
mock_var = MagicMock()
mock_var.app_id = "app_123"
mock_var.id = "var_123"
mock_var.name = "test_var"
mock_var.description = ""
mock_var.get_variable_type.return_value = variable_type
mock_var.get_selector.return_value = []
mock_var.value_type = DummyValueType()
mock_var.edited = False
mock_var.visible = True
mock_var.file_id = None
mock_var.variable_file = None
mock_var.is_truncated.return_value = False
mock_var.get_value.return_value.model_copy.return_value.value = "test_value"
return mock_var
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_workflow_variable_collection_get_success(
self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model
):
mock_wf_srv.return_value.is_workflow_exist.return_value = True
from services.workflow_draft_variable_service import WorkflowDraftVariableList
mock_draft_srv.return_value.list_variables_without_values.return_value = WorkflowDraftVariableList(
variables=[], total=0
)
resp = setup_test_context(
app,
WorkflowVariableCollectionApi,
"/apps/app_123/workflows/draft/variables?page=1&limit=20",
"GET",
mock_account,
mock_app_model,
)
assert resp == {"items": [], "total": 0}
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
def test_workflow_variable_collection_get_not_exist(self, mock_wf_srv, app, mock_account, mock_app_model):
mock_wf_srv.return_value.is_workflow_exist.return_value = False
with pytest.raises(DraftWorkflowNotExist):
setup_test_context(
app,
WorkflowVariableCollectionApi,
"/apps/app_123/workflows/draft/variables",
"GET",
mock_account,
mock_app_model,
)
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_workflow_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model):
resp = setup_test_context(
app,
WorkflowVariableCollectionApi,
"/apps/app_123/workflows/draft/variables",
"DELETE",
mock_account,
mock_app_model,
)
assert resp.status_code == 204
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_node_variable_collection_get_success(self, mock_draft_srv, app, mock_account, mock_app_model):
from services.workflow_draft_variable_service import WorkflowDraftVariableList
mock_draft_srv.return_value.list_node_variables.return_value = WorkflowDraftVariableList(variables=[])
resp = setup_test_context(
app,
NodeVariableCollectionApi,
"/apps/app_123/workflows/draft/nodes/node_123/variables",
"GET",
mock_account,
mock_app_model,
)
assert resp == {"items": []}
def test_node_variable_collection_get_invalid_node_id(self, app, mock_account, mock_app_model):
with pytest.raises(InvalidArgumentError):
setup_test_context(
app,
NodeVariableCollectionApi,
"/apps/app_123/workflows/draft/nodes/sys/variables",
"GET",
mock_account,
mock_app_model,
)
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_node_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model):
resp = setup_test_context(
app,
NodeVariableCollectionApi,
"/apps/app_123/workflows/draft/nodes/node_123/variables",
"DELETE",
mock_account,
mock_app_model,
)
assert resp.status_code == 204
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_variable_api_get_success(self, mock_draft_srv, app, mock_account, mock_app_model):
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
resp = setup_test_context(
app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model
)
assert resp["id"] == "var_123"
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_variable_api_get_not_found(self, mock_draft_srv, app, mock_account, mock_app_model):
mock_draft_srv.return_value.get_variable.return_value = None
with pytest.raises(NotFoundError):
setup_test_context(
app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model
)
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_variable_api_patch_success(self, mock_draft_srv, app, mock_account, mock_app_model):
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
resp = setup_test_context(
app,
VariableApi,
"/apps/app_123/workflows/draft/variables/var_123",
"PATCH",
mock_account,
mock_app_model,
payload={"name": "new_name"},
)
assert resp["id"] == "var_123"
mock_draft_srv.return_value.update_variable.assert_called_once()
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_variable_api_delete_success(self, mock_draft_srv, app, mock_account, mock_app_model):
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
resp = setup_test_context(
app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "DELETE", mock_account, mock_app_model
)
assert resp.status_code == 204
mock_draft_srv.return_value.delete_variable.assert_called_once()
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_variable_reset_api_put_success(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model):
mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock()
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
mock_draft_srv.return_value.reset_variable.return_value = None # means no content
resp = setup_test_context(
app,
VariableResetApi,
"/apps/app_123/workflows/draft/variables/var_123/reset",
"PUT",
mock_account,
mock_app_model,
)
assert resp.status_code == 204
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_conversation_variable_collection_get(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model):
mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock()
from services.workflow_draft_variable_service import WorkflowDraftVariableList
mock_draft_srv.return_value.list_conversation_variables.return_value = WorkflowDraftVariableList(variables=[])
resp = setup_test_context(
app,
ConversationVariableCollectionApi,
"/apps/app_123/workflows/draft/conversation-variables",
"GET",
mock_account,
mock_app_model,
)
assert resp == {"items": []}
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
def test_system_variable_collection_get(self, mock_draft_srv, app, mock_account, mock_app_model):
from services.workflow_draft_variable_service import WorkflowDraftVariableList
mock_draft_srv.return_value.list_system_variables.return_value = WorkflowDraftVariableList(variables=[])
resp = setup_test_context(
app,
SystemVariableCollectionApi,
"/apps/app_123/workflows/draft/system-variables",
"GET",
mock_account,
mock_app_model,
)
assert resp == {"items": []}
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
def test_environment_variable_collection_get(self, mock_wf_srv, app, mock_account, mock_app_model):
mock_wf = MagicMock()
mock_wf.environment_variables = []
mock_wf_srv.return_value.get_draft_workflow.return_value = mock_wf
resp = setup_test_context(
app,
EnvironmentVariableCollectionApi,
"/apps/app_123/workflows/draft/environment-variables",
"GET",
mock_account,
mock_app_model,
)
assert resp == {"items": []}

View File

@ -0,0 +1,209 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.data_source_bearer_auth import (
ApiKeyAuthDataSource,
ApiKeyAuthDataSourceBinding,
ApiKeyAuthDataSourceBindingDelete,
)
from controllers.console.auth.error import ApiKeyAuthFailedError
class TestApiKeyAuthDataSource:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
app.config["WTF_CSRF_ENABLED"] = False
return app
@patch("libs.login.check_csrf_token")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list")
def test_get_api_key_auth_data_source(self, mock_get_list, mock_db, mock_csrf, app):
from models.account import Account, AccountStatus
mock_account = MagicMock(spec=Account)
mock_account.id = "user_123"
mock_account.status = AccountStatus.ACTIVE
mock_account.is_admin_or_owner = True
mock_account.current_tenant.current_role = "owner"
mock_binding = MagicMock()
mock_binding.id = "bind_123"
mock_binding.category = "api_key"
mock_binding.provider = "custom_provider"
mock_binding.disabled = False
mock_binding.created_at.timestamp.return_value = 1620000000
mock_binding.updated_at.timestamp.return_value = 1620000001
mock_get_list.return_value = [mock_binding]
with (
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch(
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
return_value=(mock_account, "tenant_123"),
),
):
with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
proxy_mock = MagicMock()
proxy_mock._get_current_object.return_value = mock_account
with patch("libs.login.current_user", proxy_mock):
api_instance = ApiKeyAuthDataSource()
response = api_instance.get()
assert "sources" in response
assert len(response["sources"]) == 1
assert response["sources"][0]["provider"] == "custom_provider"
@patch("libs.login.check_csrf_token")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list")
def test_get_api_key_auth_data_source_empty(self, mock_get_list, mock_db, mock_csrf, app):
from models.account import Account, AccountStatus
mock_account = MagicMock(spec=Account)
mock_account.id = "user_123"
mock_account.status = AccountStatus.ACTIVE
mock_account.is_admin_or_owner = True
mock_account.current_tenant.current_role = "owner"
mock_get_list.return_value = None
with (
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch(
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
return_value=(mock_account, "tenant_123"),
),
):
with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
proxy_mock = MagicMock()
proxy_mock._get_current_object.return_value = mock_account
with patch("libs.login.current_user", proxy_mock):
api_instance = ApiKeyAuthDataSource()
response = api_instance.get()
assert "sources" in response
assert len(response["sources"]) == 0
class TestApiKeyAuthDataSourceBinding:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
app.config["WTF_CSRF_ENABLED"] = False
return app
@patch("libs.login.check_csrf_token")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth")
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args")
def test_create_binding_successful(self, mock_validate, mock_create, mock_db, mock_csrf, app):
from models.account import Account, AccountStatus
mock_account = MagicMock(spec=Account)
mock_account.id = "user_123"
mock_account.status = AccountStatus.ACTIVE
mock_account.is_admin_or_owner = True
mock_account.current_tenant.current_role = "owner"
with (
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch(
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
return_value=(mock_account, "tenant_123"),
),
):
with app.test_request_context(
"/console/api/api-key-auth/data-source/binding",
method="POST",
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
):
proxy_mock = MagicMock()
proxy_mock._get_current_object.return_value = mock_account
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
api_instance = ApiKeyAuthDataSourceBinding()
response = api_instance.post()
assert response[0]["result"] == "success"
assert response[1] == 200
mock_validate.assert_called_once()
mock_create.assert_called_once()
@patch("libs.login.check_csrf_token")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth")
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args")
def test_create_binding_failure(self, mock_validate, mock_create, mock_db, mock_csrf, app):
from models.account import Account, AccountStatus
mock_account = MagicMock(spec=Account)
mock_account.id = "user_123"
mock_account.status = AccountStatus.ACTIVE
mock_account.is_admin_or_owner = True
mock_account.current_tenant.current_role = "owner"
mock_create.side_effect = ValueError("Invalid structure")
with (
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch(
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
return_value=(mock_account, "tenant_123"),
),
):
with app.test_request_context(
"/console/api/api-key-auth/data-source/binding",
method="POST",
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
):
proxy_mock = MagicMock()
proxy_mock._get_current_object.return_value = mock_account
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
api_instance = ApiKeyAuthDataSourceBinding()
with pytest.raises(ApiKeyAuthFailedError, match="Invalid structure"):
api_instance.post()
class TestApiKeyAuthDataSourceBindingDelete:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
app.config["WTF_CSRF_ENABLED"] = False
return app
@patch("libs.login.check_csrf_token")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth")
def test_delete_binding_successful(self, mock_delete, mock_db, mock_csrf, app):
from models.account import Account, AccountStatus
mock_account = MagicMock(spec=Account)
mock_account.id = "user_123"
mock_account.status = AccountStatus.ACTIVE
mock_account.is_admin_or_owner = True
mock_account.current_tenant.current_role = "owner"
with (
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
patch(
"controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
return_value=(mock_account, "tenant_123"),
),
):
with app.test_request_context("/console/api/api-key-auth/data-source/binding_123", method="DELETE"):
proxy_mock = MagicMock()
proxy_mock._get_current_object.return_value = mock_account
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
api_instance = ApiKeyAuthDataSourceBindingDelete()
response = api_instance.delete("binding_123")
assert response[0]["result"] == "success"
assert response[1] == 204
mock_delete.assert_called_once_with("tenant_123", "binding_123")

View File

@ -0,0 +1,192 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.local import LocalProxy
from controllers.console.auth.data_source_oauth import (
OAuthDataSource,
OAuthDataSourceBinding,
OAuthDataSourceCallback,
OAuthDataSourceSync,
)
class TestOAuthDataSource:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
@patch("flask_login.current_user")
@patch("libs.login.current_user")
@patch("libs.login.check_csrf_token")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None)
def test_get_oauth_url_successful(
self, mock_db, mock_csrf, mock_libs_user, mock_flask_user, mock_get_providers, app
):
mock_oauth_provider = MagicMock()
mock_oauth_provider.get_authorization_url.return_value = "http://oauth.provider/auth"
mock_get_providers.return_value = {"notion": mock_oauth_provider}
from models.account import Account, AccountStatus
mock_account = MagicMock(spec=Account)
mock_account.id = "user_123"
mock_account.status = AccountStatus.ACTIVE
mock_account.is_admin_or_owner = True
mock_account.current_tenant.current_role = "owner"
mock_libs_user.return_value = mock_account
mock_flask_user.return_value = mock_account
# also patch current_account_with_tenant
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
with app.test_request_context("/console/api/oauth/data-source/notion", method="GET"):
proxy_mock = LocalProxy(lambda: mock_account)
with patch("libs.login.current_user", proxy_mock):
api_instance = OAuthDataSource()
response = api_instance.get("notion")
assert response[0]["data"] == "http://oauth.provider/auth"
assert response[1] == 200
mock_oauth_provider.get_authorization_url.assert_called_once()
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
@patch("flask_login.current_user")
@patch("libs.login.check_csrf_token")
@patch("controllers.console.wraps.db")
def test_get_oauth_url_invalid_provider(self, mock_db, mock_csrf, mock_flask_user, mock_get_providers, app):
mock_get_providers.return_value = {"notion": MagicMock()}
from models.account import Account, AccountStatus
mock_account = MagicMock(spec=Account)
mock_account.id = "user_123"
mock_account.status = AccountStatus.ACTIVE
mock_account.is_admin_or_owner = True
mock_account.current_tenant.current_role = "owner"
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
with app.test_request_context("/console/api/oauth/data-source/unknown_provider", method="GET"):
proxy_mock = LocalProxy(lambda: mock_account)
with patch("libs.login.current_user", proxy_mock):
api_instance = OAuthDataSource()
response = api_instance.get("unknown_provider")
assert response[0]["error"] == "Invalid provider"
assert response[1] == 400
class TestOAuthDataSourceCallback:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
def test_oauth_callback_successful(self, mock_get_providers, app):
provider_mock = MagicMock()
mock_get_providers.return_value = {"notion": provider_mock}
with app.test_request_context("/console/api/oauth/data-source/notion/callback?code=mock_code", method="GET"):
api_instance = OAuthDataSourceCallback()
response = api_instance.get("notion")
assert response.status_code == 302
assert "code=mock_code" in response.location
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
def test_oauth_callback_missing_code(self, mock_get_providers, app):
provider_mock = MagicMock()
mock_get_providers.return_value = {"notion": provider_mock}
with app.test_request_context("/console/api/oauth/data-source/notion/callback", method="GET"):
api_instance = OAuthDataSourceCallback()
response = api_instance.get("notion")
assert response.status_code == 302
assert "error=Access denied" in response.location
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
def test_oauth_callback_invalid_provider(self, mock_get_providers, app):
mock_get_providers.return_value = {"notion": MagicMock()}
with app.test_request_context("/console/api/oauth/data-source/invalid/callback?code=mock_code", method="GET"):
api_instance = OAuthDataSourceCallback()
response = api_instance.get("invalid")
assert response[0]["error"] == "Invalid provider"
assert response[1] == 400
class TestOAuthDataSourceBinding:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
def test_get_binding_successful(self, mock_get_providers, app):
mock_provider = MagicMock()
mock_provider.get_access_token.return_value = None
mock_get_providers.return_value = {"notion": mock_provider}
with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=auth_code_123", method="GET"):
api_instance = OAuthDataSourceBinding()
response = api_instance.get("notion")
assert response[0]["result"] == "success"
assert response[1] == 200
mock_provider.get_access_token.assert_called_once_with("auth_code_123")
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
def test_get_binding_missing_code(self, mock_get_providers, app):
mock_get_providers.return_value = {"notion": MagicMock()}
with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=", method="GET"):
api_instance = OAuthDataSourceBinding()
response = api_instance.get("notion")
assert response[0]["error"] == "Invalid code"
assert response[1] == 400
class TestOAuthDataSourceSync:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
@patch("libs.login.check_csrf_token")
@patch("controllers.console.wraps.db")
def test_sync_successful(self, mock_db, mock_csrf, mock_get_providers, app):
mock_provider = MagicMock()
mock_provider.sync_data_source.return_value = None
mock_get_providers.return_value = {"notion": mock_provider}
from models.account import Account, AccountStatus
mock_account = MagicMock(spec=Account)
mock_account.id = "user_123"
mock_account.status = AccountStatus.ACTIVE
mock_account.is_admin_or_owner = True
mock_account.current_tenant.current_role = "owner"
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
with app.test_request_context("/console/api/oauth/data-source/notion/binding_123/sync", method="GET"):
proxy_mock = LocalProxy(lambda: mock_account)
with patch("libs.login.current_user", proxy_mock):
api_instance = OAuthDataSourceSync()
# The route pattern uses <uuid:binding_id>, so we just pass a string for unit testing
response = api_instance.get("notion", "binding_123")
assert response[0]["result"] == "success"
assert response[1] == 200
mock_provider.sync_data_source.assert_called_once_with("binding_123")

View File

@ -0,0 +1,417 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.auth.oauth_server import (
OAuthServerAppApi,
OAuthServerUserAccountApi,
OAuthServerUserAuthorizeApi,
OAuthServerUserTokenApi,
)
class TestOAuthServerAppApi:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_oauth_provider_app(self):
from models.model import OAuthProviderApp
oauth_app = MagicMock(spec=OAuthProviderApp)
oauth_app.client_id = "test_client_id"
oauth_app.redirect_uris = ["http://localhost/callback"]
oauth_app.app_icon = "icon_url"
oauth_app.app_label = "Test App"
oauth_app.scope = "read,write"
return oauth_app
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_successful_post(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context(
"/oauth/provider",
method="POST",
json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"},
):
api_instance = OAuthServerAppApi()
response = api_instance.post()
assert response["app_icon"] == "icon_url"
assert response["app_label"] == "Test App"
assert response["scope"] == "read,write"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context(
"/oauth/provider",
method="POST",
json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"},
):
api_instance = OAuthServerAppApi()
with pytest.raises(BadRequest, match="redirect_uri is invalid"):
api_instance.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_invalid_client_id(self, mock_get_app, mock_db, app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = None
with app.test_request_context(
"/oauth/provider",
method="POST",
json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"},
):
api_instance = OAuthServerAppApi()
with pytest.raises(NotFound, match="client_id is invalid"):
api_instance.post()
class TestOAuthServerUserAuthorizeApi:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_oauth_provider_app(self):
oauth_app = MagicMock()
oauth_app.client_id = "test_client_id"
return oauth_app
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
@patch("controllers.console.auth.oauth_server.current_account_with_tenant")
@patch("controllers.console.wraps.current_account_with_tenant")
@patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code")
@patch("libs.login.check_csrf_token")
def test_successful_authorize(
self, mock_csrf, mock_sign, mock_wrap_current, mock_current, mock_get_app, mock_db, app, mock_oauth_provider_app
):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
mock_account = MagicMock()
mock_account.id = "user_123"
from models.account import AccountStatus
mock_account.status = AccountStatus.ACTIVE
mock_current.return_value = (mock_account, MagicMock())
mock_wrap_current.return_value = (mock_account, MagicMock())
mock_sign.return_value = "auth_code_123"
with app.test_request_context("/oauth/provider/authorize", method="POST", json={"client_id": "test_client_id"}):
with patch("libs.login.current_user", mock_account):
api_instance = OAuthServerUserAuthorizeApi()
response = api_instance.post()
assert response["code"] == "auth_code_123"
mock_sign.assert_called_once_with("test_client_id", "user_123")
class TestOAuthServerUserTokenApi:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_oauth_provider_app(self):
from models.model import OAuthProviderApp
oauth_app = MagicMock(spec=OAuthProviderApp)
oauth_app.client_id = "test_client_id"
oauth_app.client_secret = "test_secret"
oauth_app.redirect_uris = ["http://localhost/callback"]
return oauth_app
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
@patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token")
def test_authorization_code_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
mock_sign.return_value = ("access_123", "refresh_123")
with app.test_request_context(
"/oauth/provider/token",
method="POST",
json={
"client_id": "test_client_id",
"grant_type": "authorization_code",
"code": "auth_code",
"client_secret": "test_secret",
"redirect_uri": "http://localhost/callback",
},
):
api_instance = OAuthServerUserTokenApi()
response = api_instance.post()
assert response["access_token"] == "access_123"
assert response["refresh_token"] == "refresh_123"
assert response["token_type"] == "Bearer"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_authorization_code_grant_missing_code(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context(
"/oauth/provider/token",
method="POST",
json={
"client_id": "test_client_id",
"grant_type": "authorization_code",
"client_secret": "test_secret",
"redirect_uri": "http://localhost/callback",
},
):
api_instance = OAuthServerUserTokenApi()
with pytest.raises(BadRequest, match="code is required"):
api_instance.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_authorization_code_grant_invalid_secret(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context(
"/oauth/provider/token",
method="POST",
json={
"client_id": "test_client_id",
"grant_type": "authorization_code",
"code": "auth_code",
"client_secret": "invalid_secret",
"redirect_uri": "http://localhost/callback",
},
):
api_instance = OAuthServerUserTokenApi()
with pytest.raises(BadRequest, match="client_secret is invalid"):
api_instance.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_authorization_code_grant_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context(
"/oauth/provider/token",
method="POST",
json={
"client_id": "test_client_id",
"grant_type": "authorization_code",
"code": "auth_code",
"client_secret": "test_secret",
"redirect_uri": "http://invalid/callback",
},
):
api_instance = OAuthServerUserTokenApi()
with pytest.raises(BadRequest, match="redirect_uri is invalid"):
api_instance.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
@patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token")
def test_refresh_token_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
mock_sign.return_value = ("new_access", "new_refresh")
with app.test_request_context(
"/oauth/provider/token",
method="POST",
json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"},
):
api_instance = OAuthServerUserTokenApi()
response = api_instance.post()
assert response["access_token"] == "new_access"
assert response["refresh_token"] == "new_refresh"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_refresh_token_grant_missing_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context(
"/oauth/provider/token",
method="POST",
json={
"client_id": "test_client_id",
"grant_type": "refresh_token",
},
):
api_instance = OAuthServerUserTokenApi()
with pytest.raises(BadRequest, match="refresh_token is required"):
api_instance.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_invalid_grant_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context(
"/oauth/provider/token",
method="POST",
json={
"client_id": "test_client_id",
"grant_type": "invalid_grant",
},
):
api_instance = OAuthServerUserTokenApi()
with pytest.raises(BadRequest, match="invalid grant_type"):
api_instance.post()
class TestOAuthServerUserAccountApi:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_oauth_provider_app(self):
from models.model import OAuthProviderApp
oauth_app = MagicMock(spec=OAuthProviderApp)
oauth_app.client_id = "test_client_id"
return oauth_app
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
@patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token")
def test_successful_account_retrieval(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
mock_account = MagicMock()
mock_account.name = "Test User"
mock_account.email = "test@example.com"
mock_account.avatar = "avatar_url"
mock_account.interface_language = "en-US"
mock_account.timezone = "UTC"
mock_validate.return_value = mock_account
with app.test_request_context(
"/oauth/provider/account",
method="POST",
json={"client_id": "test_client_id"},
headers={"Authorization": "Bearer valid_access_token"},
):
api_instance = OAuthServerUserAccountApi()
response = api_instance.post()
assert response["name"] == "Test User"
assert response["email"] == "test@example.com"
assert response["avatar"] == "avatar_url"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_missing_authorization_header(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context("/oauth/provider/account", method="POST", json={"client_id": "test_client_id"}):
api_instance = OAuthServerUserAccountApi()
response = api_instance.post()
assert response.status_code == 401
assert response.json["error"] == "Authorization header is required"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_invalid_authorization_header_format(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context(
"/oauth/provider/account",
method="POST",
json={"client_id": "test_client_id"},
headers={"Authorization": "InvalidFormat"},
):
api_instance = OAuthServerUserAccountApi()
response = api_instance.post()
assert response.status_code == 401
assert response.json["error"] == "Invalid Authorization header format"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_invalid_token_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context(
"/oauth/provider/account",
method="POST",
json={"client_id": "test_client_id"},
headers={"Authorization": "Basic something"},
):
api_instance = OAuthServerUserAccountApi()
response = api_instance.post()
assert response.status_code == 401
assert response.json["error"] == "token_type is invalid"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
def test_missing_access_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
with app.test_request_context(
"/oauth/provider/account",
method="POST",
json={"client_id": "test_client_id"},
headers={"Authorization": "Bearer "},
):
api_instance = OAuthServerUserAccountApi()
response = api_instance.post()
assert response.status_code == 401
assert response.json["error"] == "Invalid Authorization header format"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
@patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token")
def test_invalid_access_token(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_app.return_value = mock_oauth_provider_app
mock_validate.return_value = None
with app.test_request_context(
"/oauth/provider/account",
method="POST",
json={"client_id": "test_client_id"},
headers={"Authorization": "Bearer invalid_token"},
):
api_instance = OAuthServerUserAccountApi()
response = api_instance.post()
assert response.status_code == 401
assert response.json["error"] == "access_token or client_id is invalid"

View File

@ -2,7 +2,7 @@ from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden, NotFound
from werkzeug.exceptions import Forbidden, HTTPException, NotFound
import services
from controllers.console import console_ns
@ -19,13 +19,14 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import (
RagPipelineDraftNodeRunApi,
RagPipelineDraftRunIterationNodeApi,
RagPipelineDraftRunLoopNodeApi,
RagPipelineDraftWorkflowRestoreApi,
RagPipelineRecommendedPluginApi,
RagPipelineTaskStopApi,
RagPipelineTransformApi,
RagPipelineWorkflowLastRunApi,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from services.errors.app import WorkflowHashNotEqualError
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
@ -116,6 +117,86 @@ class TestDraftWorkflowApi:
response, status = method(api, pipeline)
assert status == 400
def test_restore_published_workflow_to_draft_success(self, app):
api = RagPipelineDraftWorkflowRestoreApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock(id="account-1")
workflow = MagicMock(unique_hash="restored-hash", updated_at=None, created_at=datetime(2024, 1, 1))
service = MagicMock()
service.restore_published_workflow_to_draft.return_value = workflow
with (
app.test_request_context("/", method="POST"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline, "published-workflow")
assert result["result"] == "success"
assert result["hash"] == "restored-hash"
def test_restore_published_workflow_to_draft_not_found(self, app):
api = RagPipelineDraftWorkflowRestoreApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock(id="account-1")
service = MagicMock()
service.restore_published_workflow_to_draft.side_effect = WorkflowNotFoundError("Workflow not found")
with (
app.test_request_context("/", method="POST"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
with pytest.raises(NotFound):
method(api, pipeline, "published-workflow")
def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app):
api = RagPipelineDraftWorkflowRestoreApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock(id="account-1")
service = MagicMock()
service.restore_published_workflow_to_draft.side_effect = IsDraftWorkflowError(
"source workflow must be published"
)
with (
app.test_request_context("/", method="POST"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
with pytest.raises(HTTPException) as exc:
method(api, pipeline, "draft-workflow")
assert exc.value.code == 400
assert exc.value.description == "source workflow must be published"
class TestDraftRunNodes:
def test_iteration_node_success(self, app):

View File

@ -24,13 +24,8 @@ class TestBannerApi:
banner.status = BannerStatus.ENABLED
banner.created_at = datetime(2024, 1, 1)
query = MagicMock()
query.where.return_value = query
query.order_by.return_value = query
query.all.return_value = [banner]
session = MagicMock()
session.query.return_value = query
session.scalars.return_value.all.return_value = [banner]
with app.test_request_context("/?language=fr-FR"), patch.object(banner_module.db, "session", session):
result = method(api)
@ -58,16 +53,14 @@ class TestBannerApi:
banner.status = BannerStatus.ENABLED
banner.created_at = None
query = MagicMock()
query.where.return_value = query
query.order_by.return_value = query
query.all.side_effect = [
scalars_result = MagicMock()
scalars_result.all.side_effect = [
[],
[banner],
]
session = MagicMock()
session.query.return_value = query
session.scalars.return_value = scalars_result
with app.test_request_context("/?language=es-ES"), patch.object(banner_module.db, "session", session):
result = method(api)
@ -87,13 +80,8 @@ class TestBannerApi:
api = banner_module.BannerApi()
method = unwrap(api.get)
query = MagicMock()
query.where.return_value = query
query.order_by.return_value = query
query.all.return_value = []
session = MagicMock()
session.query.return_value = query
session.scalars.return_value.all.return_value = []
with app.test_request_context("/"), patch.object(banner_module.db, "session", session):
result = method(api)

View File

@ -260,11 +260,10 @@ class TestInstalledAppsCreateApi:
app_entity.tenant_id = "t2"
session = MagicMock()
session.query.return_value.where.return_value.first.side_effect = [
recommended,
app_entity,
None,
]
# scalar() is called for recommended_app and installed_app lookups
session.scalar.side_effect = [recommended, None]
# get() is called for app PK lookup
session.get.return_value = app_entity
with (
app.test_request_context("/", json={"app_id": "a1"}),
@ -282,7 +281,7 @@ class TestInstalledAppsCreateApi:
method = unwrap(api.post)
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
with (
app.test_request_context("/", json={"app_id": "a1"}),
@ -300,10 +299,10 @@ class TestInstalledAppsCreateApi:
app_entity = MagicMock(is_public=False)
session = MagicMock()
session.query.return_value.where.return_value.first.side_effect = [
recommended,
app_entity,
]
# scalar() returns recommended_app
session.scalar.return_value = recommended
# get() returns the app entity
session.get.return_value = app_entity
with (
app.test_request_context("/", json={"app_id": "a1"}),

View File

@ -958,8 +958,8 @@ class TestTrialSitApi:
app_model = MagicMock()
app_model.id = "a1"
with app.test_request_context("/"), patch.object(module.db.session, "query") as mock_query:
mock_query.return_value.where.return_value.first.return_value = None
with app.test_request_context("/"), patch.object(module.db.session, "scalar") as mock_scalar:
mock_scalar.return_value = None
with pytest.raises(Forbidden):
method(api, app_model)
@ -973,8 +973,8 @@ class TestTrialSitApi:
app_model.tenant = MagicMock()
app_model.tenant.status = TenantStatus.ARCHIVE
with app.test_request_context("/"), patch.object(module.db.session, "query") as mock_query:
mock_query.return_value.where.return_value.first.return_value = site
with app.test_request_context("/"), patch.object(module.db.session, "scalar") as mock_scalar:
mock_scalar.return_value = site
with pytest.raises(Forbidden):
method(api, app_model)
@ -990,10 +990,10 @@ class TestTrialSitApi:
with (
app.test_request_context("/"),
patch.object(module.db.session, "query") as mock_query,
patch.object(module.db.session, "scalar") as mock_scalar,
patch.object(module.SiteResponse, "model_validate") as mock_validate,
):
mock_query.return_value.where.return_value.first.return_value = site
mock_scalar.return_value = site
mock_validate_result = MagicMock()
mock_validate_result.model_dump.return_value = {"name": "test", "icon": "icon"}
mock_validate.return_value = mock_validate_result

View File

@ -34,9 +34,9 @@ def test_installed_app_required_not_found():
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
):
q.return_value.where.return_value.first.return_value = None
scalar_mock.return_value = None
with pytest.raises(NotFound):
view("app-id")
@ -54,11 +54,11 @@ def test_installed_app_required_app_deleted():
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
patch("controllers.console.explore.wraps.db.session.delete"),
patch("controllers.console.explore.wraps.db.session.commit"),
):
q.return_value.where.return_value.first.return_value = installed_app
scalar_mock.return_value = installed_app
with pytest.raises(NotFound):
view("app-id")
@ -76,9 +76,9 @@ def test_installed_app_required_success():
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
):
q.return_value.where.return_value.first.return_value = installed_app
scalar_mock.return_value = installed_app
result = view("app-id")
assert result == installed_app
@ -149,9 +149,9 @@ def test_trial_app_required_not_allowed():
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(id="user-1"), None),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
):
q.return_value.where.return_value.first.return_value = None
scalar_mock.return_value = None
with pytest.raises(TrialAppNotAllowed):
view("app-id")
@ -170,9 +170,9 @@ def test_trial_app_required_limit_exceeded():
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(id="user-1"), None),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
):
q.return_value.where.return_value.first.side_effect = [
scalar_mock.side_effect = [
trial_app,
record,
]
@ -194,9 +194,9 @@ def test_trial_app_required_success():
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(id="user-1"), None),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
):
q.return_value.where.return_value.first.side_effect = [
scalar_mock.side_effect = [
trial_app,
record,
]

View File

@ -11,6 +11,7 @@ from controllers.console.tag.tags import (
TagListApi,
TagUpdateDeleteApi,
)
from models.enums import TagType
def unwrap(func):
@ -52,7 +53,7 @@ def tag():
tag = MagicMock()
tag.id = "tag-1"
tag.name = "test-tag"
tag.type = "knowledge"
tag.type = TagType.KNOWLEDGE
return tag

View File

@ -114,7 +114,7 @@ class TestBaseApiKeyResource:
def test_delete_key_not_found(self, tenant_context_admin, db_mock):
resource = DummyApiKeyResource()
db_mock.session.query.return_value.where.return_value.first.return_value = None
db_mock.session.scalar.return_value = None
with patch("controllers.console.apikey._get_resource"):
with pytest.raises(Exception) as exc_info:
@ -125,7 +125,7 @@ class TestBaseApiKeyResource:
def test_delete_success(self, tenant_context_admin, db_mock):
resource = DummyApiKeyResource()
db_mock.session.query.return_value.where.return_value.first.return_value = MagicMock()
db_mock.session.scalar.return_value = MagicMock()
with (
patch("controllers.console.apikey._get_resource"),

View File

@ -328,7 +328,7 @@ class TestSystemSetup:
def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db):
"""Test NotInitValidateError when INIT_PASSWORD is set but setup not complete"""
# Arrange
mock_db.session.query.return_value.first.return_value = None # No setup
mock_db.session.scalar.return_value = None # No setup
mock_environ_get.return_value = "some_password"
@setup_required
@ -345,7 +345,7 @@ class TestSystemSetup:
def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db):
"""Test NotSetupError when no INIT_PASSWORD and setup not complete"""
# Arrange
mock_db.session.query.return_value.first.return_value = None # No setup
mock_db.session.scalar.return_value = None # No setup
mock_environ_get.return_value = None # No INIT_PASSWORD
@setup_required

View File

@ -55,9 +55,9 @@ class TestAccountInitApi:
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
patch("controllers.console.workspace.account.db.session.commit", return_value=None),
patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"),
patch("controllers.console.workspace.account.db.session.query") as query_mock,
patch("controllers.console.workspace.account.db.session.scalar") as scalar_mock,
):
query_mock.return_value.where.return_value.first.return_value = MagicMock(status="unused")
scalar_mock.return_value = MagicMock(status="unused")
resp = method(api)
assert resp["result"] == "success"

View File

@ -207,10 +207,10 @@ class TestMemberCancelInviteApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch("controllers.console.workspace.members.db.session.get") as get_mock,
patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"),
):
q.return_value.where.return_value.first.return_value = member
get_mock.return_value = member
result, status = method(api, member.id)
assert status == 200
@ -226,9 +226,9 @@ class TestMemberCancelInviteApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch("controllers.console.workspace.members.db.session.get") as get_mock,
):
q.return_value.where.return_value.first.return_value = None
get_mock.return_value = None
with pytest.raises(HTTPException):
method(api, "x")
@ -244,13 +244,13 @@ class TestMemberCancelInviteApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch("controllers.console.workspace.members.db.session.get") as get_mock,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.CannotOperateSelfError("x"),
),
):
q.return_value.where.return_value.first.return_value = member
get_mock.return_value = member
result, status = method(api, member.id)
assert status == 400
@ -266,13 +266,13 @@ class TestMemberCancelInviteApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch("controllers.console.workspace.members.db.session.get") as get_mock,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.NoPermissionError("x"),
),
):
q.return_value.where.return_value.first.return_value = member
get_mock.return_value = member
result, status = method(api, member.id)
assert status == 403
@ -288,13 +288,13 @@ class TestMemberCancelInviteApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch("controllers.console.workspace.members.db.session.get") as get_mock,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.MemberNotInTenantError(),
),
):
q.return_value.where.return_value.first.return_value = member
get_mock.return_value = member
result, status = method(api, member.id)
assert status == 404

View File

@ -36,7 +36,115 @@ def unwrap(func):
class TestTenantListApi:
def test_get_success(self, app):
def test_get_success_saas_path(self, app):
api = TenantListApi()
method = unwrap(api.get)
tenant1 = MagicMock(
id="t1",
name="Tenant 1",
status="active",
created_at=datetime.utcnow(),
)
tenant2 = MagicMock(
id="t2",
name="Tenant 2",
status="active",
created_at=datetime.utcnow(),
)
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant1, tenant2],
),
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True),
patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"),
patch(
"controllers.console.workspace.workspace.BillingService.get_plan_bulk",
return_value={
"t1": {"plan": CloudPlan.TEAM, "expiration_date": 0},
"t2": {"plan": CloudPlan.PROFESSIONAL, "expiration_date": 0},
},
) as get_plan_bulk_mock,
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
):
result, status = method(api)
assert status == 200
assert len(result["workspaces"]) == 2
assert result["workspaces"][0]["current"] is True
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
assert result["workspaces"][1]["plan"] == CloudPlan.PROFESSIONAL
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
get_features_mock.assert_not_called()
def test_get_saas_path_partial_fallback_does_not_gate_plan_on_billing_enabled(self, app):
"""Bulk omits a tenant: resolve plan via subscription.plan only; billing.enabled is not used.
billing.enabled is mocked False to prove the endpoint does not gate on it for this path
(SaaS contract treats enabled as on; display follows subscription.plan).
"""
api = TenantListApi()
method = unwrap(api.get)
tenant1 = MagicMock(
id="t1",
name="Tenant 1",
status="active",
created_at=datetime.utcnow(),
)
tenant2 = MagicMock(
id="t2",
name="Tenant 2",
status="active",
created_at=datetime.utcnow(),
)
features_t2 = MagicMock()
features_t2.billing.enabled = False
features_t2.billing.subscription.plan = CloudPlan.PROFESSIONAL
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant1, tenant2],
),
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True),
patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"),
patch(
"controllers.console.workspace.workspace.BillingService.get_plan_bulk",
return_value={"t1": {"plan": CloudPlan.TEAM, "expiration_date": 0}},
) as get_plan_bulk_mock,
patch(
"controllers.console.workspace.workspace.FeatureService.get_features",
return_value=features_t2,
) as get_features_mock,
):
result, status = method(api)
assert status == 200
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
assert result["workspaces"][1]["plan"] == CloudPlan.PROFESSIONAL
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
get_features_mock.assert_called_once_with("t2")
def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app):
"""Test fallback to FeatureService when bulk billing returns empty result.
BillingService.get_plan_bulk catches exceptions internally and returns empty dict,
so we simulate the real failure mode by returning empty dict for non-empty input.
"""
api = TenantListApi()
method = unwrap(api.get)
@ -54,27 +162,41 @@ class TestTenantListApi:
)
features = MagicMock()
features.billing.enabled = True
features.billing.subscription.plan = CloudPlan.SANDBOX
features.billing.enabled = False
features.billing.subscription.plan = CloudPlan.TEAM
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2")
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant1, tenant2],
),
patch("controllers.console.workspace.workspace.FeatureService.get_features", return_value=features),
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True),
patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"),
patch(
"controllers.console.workspace.workspace.BillingService.get_plan_bulk",
return_value={}, # Simulates real failure: empty result for non-empty input
) as get_plan_bulk_mock,
patch(
"controllers.console.workspace.workspace.FeatureService.get_features",
return_value=features,
) as get_features_mock,
patch("controllers.console.workspace.workspace.logger.warning") as logger_warning_mock,
):
result, status = method(api)
assert status == 200
assert len(result["workspaces"]) == 2
assert result["workspaces"][0]["current"] is True
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
assert result["workspaces"][1]["plan"] == CloudPlan.TEAM
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
assert get_features_mock.call_count == 2
logger_warning_mock.assert_called_once()
def test_get_billing_disabled(self, app):
def test_get_billing_disabled_community_path(self, app):
api = TenantListApi()
method = unwrap(api.get)
@ -87,6 +209,7 @@ class TestTenantListApi:
features = MagicMock()
features.billing.enabled = False
features.billing.subscription.plan = CloudPlan.SANDBOX
with (
app.test_request_context("/workspaces"),
@ -98,15 +221,83 @@ class TestTenantListApi:
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant],
),
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
patch(
"controllers.console.workspace.workspace.FeatureService.get_features",
return_value=features,
),
) as get_features_mock,
):
result, status = method(api)
assert status == 200
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
get_features_mock.assert_called_once_with("t1")
def test_get_enterprise_only_skips_feature_service(self, app):
api = TenantListApi()
method = unwrap(api.get)
tenant1 = MagicMock(
id="t1",
name="Tenant 1",
status="active",
created_at=datetime.utcnow(),
)
tenant2 = MagicMock(
id="t2",
name="Tenant 2",
status="active",
created_at=datetime.utcnow(),
)
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2")
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant1, tenant2],
),
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", True),
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
):
result, status = method(api)
assert status == 200
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
assert result["workspaces"][1]["plan"] == CloudPlan.SANDBOX
assert result["workspaces"][0]["current"] is False
assert result["workspaces"][1]["current"] is True
get_features_mock.assert_not_called()
def test_get_enterprise_only_with_empty_tenants(self, app):
api = TenantListApi()
method = unwrap(api.get)
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), None)
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[],
),
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", True),
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
):
result, status = method(api)
assert status == 200
assert result["workspaces"] == []
get_features_mock.assert_not_called()
class TestWorkspaceListApi:
@ -258,12 +449,12 @@ class TestSwitchWorkspaceApi:
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
patch("controllers.console.workspace.workspace.db.session.get") as get_mock,
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t2"}
),
):
query_mock.return_value.get.return_value = tenant
get_mock.return_value = tenant
result = method(api)
assert result["result"] == "success"
@ -297,9 +488,9 @@ class TestSwitchWorkspaceApi:
return_value=(MagicMock(), "t1"),
),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
patch("controllers.console.workspace.workspace.db.session.get") as get_mock,
):
query_mock.return_value.get.return_value = None
get_mock.return_value = None
with pytest.raises(ValueError):
method(api)

View File

@ -35,6 +35,7 @@ from controllers.service_api.dataset.dataset import (
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
from models.account import Account
from models.dataset import DatasetPermissionEnum
from models.enums import TagType
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.tag_service import TagService
@ -277,7 +278,7 @@ class TestDatasetTagsApi:
mock_tag = Mock()
mock_tag.id = "tag_1"
mock_tag.name = "Test Tag"
mock_tag.type = "knowledge"
mock_tag.type = TagType.KNOWLEDGE
mock_tag.binding_count = "0" # Required for Pydantic validation - must be string
mock_tag_service.get_tags.return_value = [mock_tag]
@ -316,7 +317,7 @@ class TestDatasetTagsApi:
mock_tag = Mock()
mock_tag.id = "new_tag_1"
mock_tag.name = "New Tag"
mock_tag.type = "knowledge"
mock_tag.type = TagType.KNOWLEDGE
mock_tag_service.save_tags.return_value = mock_tag
mock_service_api_ns.payload = {"name": "New Tag"}
@ -378,7 +379,7 @@ class TestDatasetTagsApi:
mock_tag = Mock()
mock_tag.id = "tag_1"
mock_tag.name = "Updated Tag"
mock_tag.type = "knowledge"
mock_tag.type = TagType.KNOWLEDGE
mock_tag.binding_count = "5"
mock_tag_service.update_tags.return_value = mock_tag
mock_tag_service.get_tag_binding_count.return_value = 5
@ -866,7 +867,7 @@ class TestTagService:
mock_tag = Mock()
mock_tag.id = str(uuid.uuid4())
mock_tag.name = "New Tag"
mock_tag.type = "knowledge"
mock_tag.type = TagType.KNOWLEDGE
mock_save.return_value = mock_tag
result = TagService.save_tags({"name": "New Tag", "type": "knowledge"})

View File

@ -11,6 +11,7 @@ from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.task_pipeline import message_cycle_manager
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from models.enums import ConversationFromSource
from models.model import AppMode, Conversation, Message
@ -92,7 +93,7 @@ def test_init_generate_records_marks_existing_conversation():
system_instruction_tokens=0,
status="normal",
invoke_from=InvokeFrom.WEB_APP.value,
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id="user-id",
from_account_id=None,
)

View File

@ -0,0 +1,181 @@
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint
from core.moderation.api.api import ApiModeration, ModerationInputParams, ModerationOutputParams
from core.moderation.base import ModerationAction, ModerationInputsResult, ModerationOutputsResult
from models.api_based_extension import APIBasedExtension
class TestApiModeration:
@pytest.fixture
def api_config(self):
return {
"inputs_config": {
"enabled": True,
},
"outputs_config": {
"enabled": True,
},
"api_based_extension_id": "test-extension-id",
}
@pytest.fixture
def api_moderation(self, api_config):
return ApiModeration(app_id="test-app-id", tenant_id="test-tenant-id", config=api_config)
def test_moderation_input_params(self):
params = ModerationInputParams(app_id="app-1", inputs={"key": "val"}, query="test query")
assert params.app_id == "app-1"
assert params.inputs == {"key": "val"}
assert params.query == "test query"
# Test defaults
params_default = ModerationInputParams()
assert params_default.app_id == ""
assert params_default.inputs == {}
assert params_default.query == ""
def test_moderation_output_params(self):
params = ModerationOutputParams(app_id="app-1", text="test text")
assert params.app_id == "app-1"
assert params.text == "test text"
with pytest.raises(ValidationError):
ModerationOutputParams()
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
def test_validate_config_success(self, mock_get_extension, api_config):
mock_get_extension.return_value = MagicMock(spec=APIBasedExtension)
ApiModeration.validate_config("test-tenant-id", api_config)
mock_get_extension.assert_called_once_with("test-tenant-id", "test-extension-id")
def test_validate_config_missing_extension_id(self):
config = {
"inputs_config": {"enabled": True},
"outputs_config": {"enabled": True},
}
with pytest.raises(ValueError, match="api_based_extension_id is required"):
ApiModeration.validate_config("test-tenant-id", config)
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
def test_validate_config_extension_not_found(self, mock_get_extension, api_config):
mock_get_extension.return_value = None
with pytest.raises(ValueError, match="API-based Extension not found"):
ApiModeration.validate_config("test-tenant-id", api_config)
@patch("core.moderation.api.api.ApiModeration._get_config_by_requestor")
def test_moderation_for_inputs_enabled(self, mock_get_config, api_moderation):
mock_get_config.return_value = {"flagged": True, "action": "direct_output", "preset_response": "Blocked by API"}
result = api_moderation.moderation_for_inputs(inputs={"q": "a"}, query="hello")
assert isinstance(result, ModerationInputsResult)
assert result.flagged is True
assert result.action == ModerationAction.DIRECT_OUTPUT
assert result.preset_response == "Blocked by API"
mock_get_config.assert_called_once_with(
APIBasedExtensionPoint.APP_MODERATION_INPUT,
{"app_id": "test-app-id", "inputs": {"q": "a"}, "query": "hello"},
)
def test_moderation_for_inputs_disabled(self):
config = {
"inputs_config": {"enabled": False},
"outputs_config": {"enabled": True},
"api_based_extension_id": "ext-id",
}
moderation = ApiModeration("app-id", "tenant-id", config)
result = moderation.moderation_for_inputs(inputs={}, query="")
assert result.flagged is False
assert result.action == ModerationAction.DIRECT_OUTPUT
assert result.preset_response == ""
def test_moderation_for_inputs_no_config(self):
moderation = ApiModeration("app-id", "tenant-id", None)
with pytest.raises(ValueError, match="The config is not set"):
moderation.moderation_for_inputs({}, "")
@patch("core.moderation.api.api.ApiModeration._get_config_by_requestor")
def test_moderation_for_outputs_enabled(self, mock_get_config, api_moderation):
mock_get_config.return_value = {"flagged": False, "action": "direct_output", "preset_response": ""}
result = api_moderation.moderation_for_outputs(text="hello world")
assert isinstance(result, ModerationOutputsResult)
assert result.flagged is False
mock_get_config.assert_called_once_with(
APIBasedExtensionPoint.APP_MODERATION_OUTPUT, {"app_id": "test-app-id", "text": "hello world"}
)
def test_moderation_for_outputs_disabled(self):
config = {
"inputs_config": {"enabled": True},
"outputs_config": {"enabled": False},
"api_based_extension_id": "ext-id",
}
moderation = ApiModeration("app-id", "tenant-id", config)
result = moderation.moderation_for_outputs(text="test")
assert result.flagged is False
assert result.action == ModerationAction.DIRECT_OUTPUT
def test_moderation_for_outputs_no_config(self):
moderation = ApiModeration("app-id", "tenant-id", None)
with pytest.raises(ValueError, match="The config is not set"):
moderation.moderation_for_outputs("test")
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
@patch("core.moderation.api.api.decrypt_token")
@patch("core.moderation.api.api.APIBasedExtensionRequestor")
def test_get_config_by_requestor_success(self, mock_requestor_cls, mock_decrypt, mock_get_ext, api_moderation):
mock_ext = MagicMock(spec=APIBasedExtension)
mock_ext.api_endpoint = "http://api.test"
mock_ext.api_key = "encrypted-key"
mock_get_ext.return_value = mock_ext
mock_decrypt.return_value = "decrypted-key"
mock_requestor = MagicMock()
mock_requestor.request.return_value = {"flagged": True}
mock_requestor_cls.return_value = mock_requestor
params = {"some": "params"}
result = api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params)
assert result == {"flagged": True}
mock_get_ext.assert_called_once_with("test-tenant-id", "test-extension-id")
mock_decrypt.assert_called_once_with("test-tenant-id", "encrypted-key")
mock_requestor_cls.assert_called_once_with("http://api.test", "decrypted-key")
mock_requestor.request.assert_called_once_with(APIBasedExtensionPoint.APP_MODERATION_INPUT, params)
def test_get_config_by_requestor_no_config(self):
moderation = ApiModeration("app-id", "tenant-id", None)
with pytest.raises(ValueError, match="The config is not set"):
moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {})
@patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
def test_get_config_by_requestor_extension_not_found(self, mock_get_ext, api_moderation):
mock_get_ext.return_value = None
with pytest.raises(ValueError, match="API-based Extension not found"):
api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {})
@patch("core.moderation.api.api.db.session.scalar")
def test_get_api_based_extension(self, mock_scalar):
mock_ext = MagicMock(spec=APIBasedExtension)
mock_scalar.return_value = mock_ext
result = ApiModeration._get_api_based_extension("tenant-1", "ext-1")
assert result == mock_ext
mock_scalar.assert_called_once()
# Verify the call has the correct filters
args, kwargs = mock_scalar.call_args
stmt = args[0]
# We can't easily inspect the statement without complex sqlalchemy tricks,
# but calling it is usually enough for unit tests if we mock the result.

View File

@ -0,0 +1,207 @@
from unittest.mock import MagicMock, patch
import pytest
from core.app.app_config.entities import AppConfig, SensitiveWordAvoidanceEntity
from core.moderation.base import ModerationAction, ModerationError, ModerationInputsResult
from core.moderation.input_moderation import InputModeration
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager
class TestInputModeration:
@pytest.fixture
def app_config(self):
config = MagicMock(spec=AppConfig)
config.sensitive_word_avoidance = None
return config
@pytest.fixture
def input_moderation(self):
return InputModeration()
def test_check_no_sensitive_word_avoidance(self, app_config, input_moderation):
app_id = "test_app_id"
tenant_id = "test_tenant_id"
inputs = {"input_key": "input_value"}
query = "test query"
message_id = "test_message_id"
flagged, final_inputs, final_query = input_moderation.check(
app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
)
assert flagged is False
assert final_inputs == inputs
assert final_query == query
@patch("core.moderation.input_moderation.ModerationFactory")
def test_check_not_flagged(self, mock_factory_cls, app_config, input_moderation):
app_id = "test_app_id"
tenant_id = "test_tenant_id"
inputs = {"input_key": "input_value"}
query = "test query"
message_id = "test_message_id"
# Setup config
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
sensitive_word_config.type = "keywords"
sensitive_word_config.config = {"keywords": ["bad"]}
app_config.sensitive_word_avoidance = sensitive_word_config
# Setup factory mock
mock_factory = mock_factory_cls.return_value
mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
mock_factory.moderation_for_inputs.return_value = mock_result
flagged, final_inputs, final_query = input_moderation.check(
app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
)
assert flagged is False
assert final_inputs == inputs
assert final_query == query
mock_factory_cls.assert_called_once_with(
name="keywords", app_id=app_id, tenant_id=tenant_id, config={"keywords": ["bad"]}
)
mock_factory.moderation_for_inputs.assert_called_once_with(dict(inputs), query)
@patch("core.moderation.input_moderation.ModerationFactory")
@patch("core.moderation.input_moderation.TraceTask")
def test_check_with_trace_manager(self, mock_trace_task, mock_factory_cls, app_config, input_moderation):
app_id = "test_app_id"
tenant_id = "test_tenant_id"
inputs = {"input_key": "input_value"}
query = "test query"
message_id = "test_message_id"
trace_manager = MagicMock(spec=TraceQueueManager)
# Setup config
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
sensitive_word_config.type = "keywords"
sensitive_word_config.config = {}
app_config.sensitive_word_avoidance = sensitive_word_config
# Setup factory mock
mock_factory = mock_factory_cls.return_value
mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
mock_factory.moderation_for_inputs.return_value = mock_result
input_moderation.check(
app_id=app_id,
tenant_id=tenant_id,
app_config=app_config,
inputs=inputs,
query=query,
message_id=message_id,
trace_manager=trace_manager,
)
trace_manager.add_trace_task.assert_called_once_with(mock_trace_task.return_value)
mock_trace_task.assert_called_once()
call_kwargs = mock_trace_task.call_args.kwargs
call_args = mock_trace_task.call_args.args
assert call_args[0] == TraceTaskName.MODERATION_TRACE
assert call_kwargs["message_id"] == message_id
assert call_kwargs["moderation_result"] == mock_result
assert call_kwargs["inputs"] == inputs
assert "timer" in call_kwargs
@patch("core.moderation.input_moderation.ModerationFactory")
def test_check_flagged_direct_output(self, mock_factory_cls, app_config, input_moderation):
app_id = "test_app_id"
tenant_id = "test_tenant_id"
inputs = {"input_key": "input_value"}
query = "test query"
message_id = "test_message_id"
# Setup config
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
sensitive_word_config.type = "keywords"
sensitive_word_config.config = {}
app_config.sensitive_word_avoidance = sensitive_word_config
# Setup factory mock
mock_factory = mock_factory_cls.return_value
mock_result = ModerationInputsResult(
flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="Blocked content"
)
mock_factory.moderation_for_inputs.return_value = mock_result
with pytest.raises(ModerationError) as excinfo:
input_moderation.check(
app_id=app_id,
tenant_id=tenant_id,
app_config=app_config,
inputs=inputs,
query=query,
message_id=message_id,
)
assert str(excinfo.value) == "Blocked content"
@patch("core.moderation.input_moderation.ModerationFactory")
def test_check_flagged_overridden(self, mock_factory_cls, app_config, input_moderation):
app_id = "test_app_id"
tenant_id = "test_tenant_id"
inputs = {"input_key": "input_value"}
query = "test query"
message_id = "test_message_id"
# Setup config
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
sensitive_word_config.type = "keywords"
sensitive_word_config.config = {}
app_config.sensitive_word_avoidance = sensitive_word_config
# Setup factory mock
mock_factory = mock_factory_cls.return_value
mock_result = ModerationInputsResult(
flagged=True,
action=ModerationAction.OVERRIDDEN,
inputs={"input_key": "overridden_value"},
query="overridden query",
)
mock_factory.moderation_for_inputs.return_value = mock_result
flagged, final_inputs, final_query = input_moderation.check(
app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
)
assert flagged is True
assert final_inputs == {"input_key": "overridden_value"}
assert final_query == "overridden query"
@patch("core.moderation.input_moderation.ModerationFactory")
def test_check_flagged_other_action(self, mock_factory_cls, app_config, input_moderation):
app_id = "test_app_id"
tenant_id = "test_tenant_id"
inputs = {"input_key": "input_value"}
query = "test query"
message_id = "test_message_id"
# Setup config
sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
sensitive_word_config.type = "keywords"
sensitive_word_config.config = {}
app_config.sensitive_word_avoidance = sensitive_word_config
# Setup factory mock
mock_factory = mock_factory_cls.return_value
mock_result = MagicMock()
mock_result.flagged = True
mock_result.action = "NONE" # Some other action
mock_factory.moderation_for_inputs.return_value = mock_result
flagged, final_inputs, final_query = input_moderation.check(
app_id=app_id,
tenant_id=tenant_id,
app_config=app_config,
inputs=inputs,
query=query,
message_id=message_id,
)
assert flagged is True
assert final_inputs == inputs
assert final_query == query

View File

@ -0,0 +1,234 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import QueueMessageReplaceEvent
from core.moderation.base import ModerationAction, ModerationOutputsResult
from core.moderation.output_moderation import ModerationRule, OutputModeration
class TestOutputModeration:
@pytest.fixture
def mock_queue_manager(self):
return MagicMock(spec=AppQueueManager)
@pytest.fixture
def moderation_rule(self):
return ModerationRule(type="keywords", config={"keywords": "badword"})
@pytest.fixture
def output_moderation(self, mock_queue_manager, moderation_rule):
return OutputModeration(
tenant_id="test_tenant", app_id="test_app", rule=moderation_rule, queue_manager=mock_queue_manager
)
def test_should_direct_output(self, output_moderation):
assert output_moderation.should_direct_output() is False
output_moderation.final_output = "blocked"
assert output_moderation.should_direct_output() is True
def test_get_final_output(self, output_moderation):
assert output_moderation.get_final_output() == ""
output_moderation.final_output = "blocked"
assert output_moderation.get_final_output() == "blocked"
def test_append_new_token(self, output_moderation):
with patch.object(OutputModeration, "start_thread") as mock_start:
output_moderation.append_new_token("hello")
assert output_moderation.buffer == "hello"
mock_start.assert_called_once()
output_moderation.thread = MagicMock()
output_moderation.append_new_token(" world")
assert output_moderation.buffer == "hello world"
assert mock_start.call_count == 1
def test_moderation_completion_no_flag(self, output_moderation):
with patch.object(OutputModeration, "moderation") as mock_moderation:
mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
output, flagged = output_moderation.moderation_completion("safe content")
assert output == "safe content"
assert flagged is False
assert output_moderation.is_final_chunk is True
def test_moderation_completion_flagged_direct_output(self, output_moderation, mock_queue_manager):
with patch.object(OutputModeration, "moderation") as mock_moderation:
mock_moderation.return_value = ModerationOutputsResult(
flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset"
)
output, flagged = output_moderation.moderation_completion("badword content", public_event=True)
assert output == "preset"
assert flagged is True
mock_queue_manager.publish.assert_called_once()
args, _ = mock_queue_manager.publish.call_args
assert isinstance(args[0], QueueMessageReplaceEvent)
assert args[0].text == "preset"
assert args[1] == PublishFrom.TASK_PIPELINE
def test_moderation_completion_flagged_overridden(self, output_moderation, mock_queue_manager):
with patch.object(OutputModeration, "moderation") as mock_moderation:
mock_moderation.return_value = ModerationOutputsResult(
flagged=True, action=ModerationAction.OVERRIDDEN, text="masked content"
)
output, flagged = output_moderation.moderation_completion("badword content", public_event=True)
assert output == "masked content"
assert flagged is True
mock_queue_manager.publish.assert_called_once()
args, _ = mock_queue_manager.publish.call_args
assert args[0].text == "masked content"
def test_start_thread(self, output_moderation):
mock_app = MagicMock(spec=Flask)
with patch("core.moderation.output_moderation.current_app") as mock_current_app:
mock_current_app._get_current_object.return_value = mock_app
with patch("threading.Thread") as mock_thread_class:
mock_thread_instance = MagicMock()
mock_thread_class.return_value = mock_thread_instance
thread = output_moderation.start_thread()
assert thread == mock_thread_instance
mock_thread_class.assert_called_once()
mock_thread_instance.start.assert_called_once()
def test_stop_thread(self, output_moderation):
mock_thread = MagicMock()
mock_thread.is_alive.return_value = True
output_moderation.thread = mock_thread
output_moderation.stop_thread()
assert output_moderation.thread_running is False
output_moderation.thread_running = True
mock_thread.is_alive.return_value = False
output_moderation.stop_thread()
assert output_moderation.thread_running is True
@patch("core.moderation.output_moderation.ModerationFactory")
def test_moderation_success(self, mock_factory_class, output_moderation):
mock_factory = mock_factory_class.return_value
mock_result = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
mock_factory.moderation_for_outputs.return_value = mock_result
result = output_moderation.moderation("tenant", "app", "buffer")
assert result == mock_result
mock_factory_class.assert_called_once_with(
name="keywords", app_id="app", tenant_id="tenant", config={"keywords": "badword"}
)
@patch("core.moderation.output_moderation.ModerationFactory")
def test_moderation_exception(self, mock_factory_class, output_moderation):
mock_factory_class.side_effect = Exception("error")
result = output_moderation.moderation("tenant", "app", "buffer")
assert result is None
def test_worker_loop_and_exit(self, output_moderation, mock_queue_manager):
mock_app = MagicMock(spec=Flask)
# Test exit on thread_running=False
output_moderation.thread_running = False
output_moderation.worker(mock_app, 10)
# Should exit immediately
def test_worker_no_flag(self, output_moderation):
mock_app = MagicMock(spec=Flask)
with patch.object(OutputModeration, "moderation") as mock_moderation:
mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
output_moderation.buffer = "safe"
output_moderation.is_final_chunk = True
# To avoid infinite loop, we'll set thread_running to False after one iteration
def side_effect(*args, **kwargs):
output_moderation.thread_running = False
return mock_moderation.return_value
mock_moderation.side_effect = side_effect
output_moderation.worker(mock_app, 10)
assert mock_moderation.called
def test_worker_flagged_direct_output(self, output_moderation, mock_queue_manager):
mock_app = MagicMock(spec=Flask)
with patch.object(OutputModeration, "moderation") as mock_moderation:
mock_moderation.return_value = ModerationOutputsResult(
flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset"
)
output_moderation.buffer = "badword"
output_moderation.is_final_chunk = True
output_moderation.worker(mock_app, 10)
assert output_moderation.final_output == "preset"
mock_queue_manager.publish.assert_called_once()
# It breaks on DIRECT_OUTPUT
def test_worker_flagged_overridden(self, output_moderation, mock_queue_manager):
mock_app = MagicMock(spec=Flask)
with patch.object(OutputModeration, "moderation") as mock_moderation:
# Use side_effect to change thread_running on second call
def side_effect(*args, **kwargs):
if mock_moderation.call_count > 1:
output_moderation.thread_running = False
return None
return ModerationOutputsResult(flagged=True, action=ModerationAction.OVERRIDDEN, text="masked")
mock_moderation.side_effect = side_effect
output_moderation.buffer = "badword"
output_moderation.is_final_chunk = True
output_moderation.worker(mock_app, 10)
mock_queue_manager.publish.assert_called_once()
args, _ = mock_queue_manager.publish.call_args
assert args[0].text == "masked"
def test_worker_chunk_too_small(self, output_moderation):
mock_app = MagicMock(spec=Flask)
with patch("time.sleep") as mock_sleep:
# chunk_length < buffer_size and not is_final_chunk
output_moderation.buffer = "123" # length 3
output_moderation.is_final_chunk = False
def sleep_side_effect(seconds):
output_moderation.thread_running = False
mock_sleep.side_effect = sleep_side_effect
output_moderation.worker(mock_app, 10) # buffer_size 10
mock_sleep.assert_called_once_with(1)
def test_worker_empty_not_flagged(self, output_moderation, mock_queue_manager):
mock_app = MagicMock(spec=Flask)
with patch.object(OutputModeration, "moderation") as mock_moderation:
# Return None (exception or no rule)
mock_moderation.return_value = None
def side_effect(*args, **kwargs):
output_moderation.thread_running = False
mock_moderation.side_effect = side_effect
output_moderation.buffer = "something"
output_moderation.is_final_chunk = True
output_moderation.worker(mock_app, 10)
mock_queue_manager.publish.assert_not_called()

View File

@ -0,0 +1,160 @@
from unittest.mock import patch
import httpx
import pytest
from qdrant_client.http import models as rest
from qdrant_client.http.exceptions import UnexpectedResponse
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import (
TidbOnQdrantConfig,
TidbOnQdrantVector,
)
class TestTidbOnQdrantVectorDeleteByIds:
"""Unit tests for TidbOnQdrantVector.delete_by_ids method."""
@pytest.fixture
def vector_instance(self):
"""Create a TidbOnQdrantVector instance for testing."""
config = TidbOnQdrantConfig(
endpoint="http://localhost:6333",
api_key="test_api_key",
)
with patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector.qdrant_client.QdrantClient"):
vector = TidbOnQdrantVector(
collection_name="test_collection",
group_id="test_group",
config=config,
)
return vector
def test_delete_by_ids_with_multiple_ids(self, vector_instance):
"""Test batch deletion with multiple document IDs."""
ids = ["doc1", "doc2", "doc3"]
vector_instance.delete_by_ids(ids)
# Verify that delete was called once with MatchAny filter
vector_instance._client.delete.assert_called_once()
call_args = vector_instance._client.delete.call_args
# Check collection name
assert call_args[1]["collection_name"] == "test_collection"
# Verify filter uses MatchAny with all IDs
filter_selector = call_args[1]["points_selector"]
filter_obj = filter_selector.filter
assert len(filter_obj.must) == 1
field_condition = filter_obj.must[0]
assert field_condition.key == "metadata.doc_id"
assert isinstance(field_condition.match, rest.MatchAny)
assert set(field_condition.match.any) == {"doc1", "doc2", "doc3"}
def test_delete_by_ids_with_single_id(self, vector_instance):
"""Test deletion with a single document ID."""
ids = ["doc1"]
vector_instance.delete_by_ids(ids)
# Verify that delete was called once
vector_instance._client.delete.assert_called_once()
call_args = vector_instance._client.delete.call_args
# Verify filter uses MatchAny with single ID
filter_selector = call_args[1]["points_selector"]
filter_obj = filter_selector.filter
field_condition = filter_obj.must[0]
assert isinstance(field_condition.match, rest.MatchAny)
assert field_condition.match.any == ["doc1"]
def test_delete_by_ids_with_empty_list(self, vector_instance):
"""Test deletion with empty ID list returns early without API call."""
vector_instance.delete_by_ids([])
# Verify that delete was NOT called
vector_instance._client.delete.assert_not_called()
def test_delete_by_ids_with_404_error(self, vector_instance):
"""Test that 404 errors (collection not found) are handled gracefully."""
ids = ["doc1", "doc2"]
# Mock a 404 error
error = UnexpectedResponse(
status_code=404,
reason_phrase="Not Found",
content=b"Collection not found",
headers=httpx.Headers(),
)
vector_instance._client.delete.side_effect = error
# Should not raise an exception
vector_instance.delete_by_ids(ids)
# Verify delete was called
vector_instance._client.delete.assert_called_once()
def test_delete_by_ids_with_unexpected_error(self, vector_instance):
"""Test that non-404 errors are re-raised."""
ids = ["doc1", "doc2"]
# Mock a 500 error
error = UnexpectedResponse(
status_code=500,
reason_phrase="Internal Server Error",
content=b"Server error",
headers=httpx.Headers(),
)
vector_instance._client.delete.side_effect = error
# Should re-raise the exception
with pytest.raises(UnexpectedResponse) as exc_info:
vector_instance.delete_by_ids(ids)
assert exc_info.value.status_code == 500
def test_delete_by_ids_with_large_batch(self, vector_instance):
"""Test deletion with a large batch of IDs."""
# Create 1000 IDs
ids = [f"doc_{i}" for i in range(1000)]
vector_instance.delete_by_ids(ids)
# Verify single delete call with all IDs
vector_instance._client.delete.assert_called_once()
call_args = vector_instance._client.delete.call_args
filter_selector = call_args[1]["points_selector"]
filter_obj = filter_selector.filter
field_condition = filter_obj.must[0]
# Verify all 1000 IDs are in the batch
assert len(field_condition.match.any) == 1000
assert "doc_0" in field_condition.match.any
assert "doc_999" in field_condition.match.any
def test_delete_by_ids_filter_structure(self, vector_instance):
"""Test that the filter structure is correctly constructed."""
ids = ["doc1", "doc2"]
vector_instance.delete_by_ids(ids)
call_args = vector_instance._client.delete.call_args
filter_selector = call_args[1]["points_selector"]
filter_obj = filter_selector.filter
# Verify Filter structure
assert isinstance(filter_obj, rest.Filter)
assert filter_obj.must is not None
assert len(filter_obj.must) == 1
# Verify FieldCondition structure
field_condition = filter_obj.must[0]
assert isinstance(field_condition, rest.FieldCondition)
assert field_condition.key == "metadata.doc_id"
# Verify MatchAny structure
assert isinstance(field_condition.match, rest.MatchAny)
assert field_condition.match.any == ids

View File

@ -0,0 +1,677 @@
from __future__ import annotations
import dataclasses
import json
from collections.abc import Sequence
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock
import pytest
from core.repositories.human_input_repository import (
HumanInputFormRecord,
HumanInputFormRepositoryImpl,
HumanInputFormSubmissionRepository,
_HumanInputFormEntityImpl,
_HumanInputFormRecipientEntityImpl,
_InvalidTimeoutStatusError,
_WorkspaceMemberInfo,
)
from dify_graph.nodes.human_input.entities import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
HumanInputNodeData,
MemberRecipient,
UserAction,
WebAppDeliveryMethod,
)
from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from dify_graph.repositories.human_input_form_repository import FormCreateParams, FormNotFoundError
from libs.datetime_utils import naive_utc_now
from models.human_input import HumanInputFormRecipient, RecipientType
@pytest.fixture(autouse=True)
def _stub_select(monkeypatch: pytest.MonkeyPatch) -> None:
class _FakeSelect:
def join(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
return self
def where(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
return self
def options(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
return self
monkeypatch.setattr("core.repositories.human_input_repository.select", lambda *_args, **_kwargs: _FakeSelect())
monkeypatch.setattr("core.repositories.human_input_repository.selectinload", lambda *_args, **_kwargs: "_loader")
def _make_form_definition_json(*, include_expiration_time: bool) -> str:
payload: dict[str, Any] = {
"form_content": "hi",
"inputs": [],
"user_actions": [{"id": "submit", "title": "Submit"}],
"rendered_content": "<p>hi</p>",
}
if include_expiration_time:
payload["expiration_time"] = naive_utc_now()
return json.dumps(payload, default=str)
@dataclasses.dataclass
class _DummyForm:
id: str
workflow_run_id: str | None
node_id: str
tenant_id: str
app_id: str
form_definition: str
rendered_content: str
expiration_time: datetime
form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
created_at: datetime = dataclasses.field(default_factory=naive_utc_now)
selected_action_id: str | None = None
submitted_data: str | None = None
submitted_at: datetime | None = None
submission_user_id: str | None = None
submission_end_user_id: str | None = None
completed_by_recipient_id: str | None = None
status: HumanInputFormStatus = HumanInputFormStatus.WAITING
@dataclasses.dataclass
class _DummyRecipient:
id: str
form_id: str
recipient_type: RecipientType
access_token: str | None
class _FakeScalarResult:
def __init__(self, obj: Any):
self._obj = obj
def first(self) -> Any:
if isinstance(self._obj, list):
return self._obj[0] if self._obj else None
return self._obj
def all(self) -> list[Any]:
if self._obj is None:
return []
if isinstance(self._obj, list):
return list(self._obj)
return [self._obj]
class _FakeExecuteResult:
def __init__(self, rows: Sequence[tuple[Any, ...]]):
self._rows = list(rows)
def all(self) -> list[tuple[Any, ...]]:
return list(self._rows)
class _FakeSession:
def __init__(
self,
*,
scalars_result: Any = None,
scalars_results: list[Any] | None = None,
forms: dict[str, _DummyForm] | None = None,
recipients: dict[str, _DummyRecipient] | None = None,
execute_rows: Sequence[tuple[Any, ...]] = (),
):
if scalars_results is not None:
self._scalars_queue = list(scalars_results)
else:
self._scalars_queue = [scalars_result]
self._forms = forms or {}
self._recipients = recipients or {}
self._execute_rows = list(execute_rows)
self.added: list[Any] = []
def scalars(self, _query: Any) -> _FakeScalarResult:
if self._scalars_queue:
value = self._scalars_queue.pop(0)
else:
value = None
return _FakeScalarResult(value)
def execute(self, _stmt: Any) -> _FakeExecuteResult:
return _FakeExecuteResult(self._execute_rows)
def get(self, model_cls: Any, obj_id: str) -> Any:
name = getattr(model_cls, "__name__", "")
if name == "HumanInputForm":
return self._forms.get(obj_id)
if name == "HumanInputFormRecipient":
return self._recipients.get(obj_id)
return None
def add(self, obj: Any) -> None:
self.added.append(obj)
def add_all(self, objs: Sequence[Any]) -> None:
self.added.extend(list(objs))
def flush(self) -> None:
# Simulate DB default population for attributes referenced in entity wrappers.
for obj in self.added:
if hasattr(obj, "id") and obj.id in (None, ""):
obj.id = f"gen-{len(str(self.added))}"
if isinstance(obj, HumanInputFormRecipient) and obj.access_token is None:
if obj.recipient_type == RecipientType.CONSOLE:
obj.access_token = "token-console"
elif obj.recipient_type == RecipientType.BACKSTAGE:
obj.access_token = "token-backstage"
else:
obj.access_token = "token-webapp"
def refresh(self, _obj: Any) -> None:
return None
def begin(self) -> _FakeSession:
return self
def __enter__(self) -> _FakeSession:
return self
def __exit__(self, exc_type, exc, tb) -> None:
return None
class _SessionFactoryStub:
def __init__(self, session: _FakeSession):
self._session = session
def create_session(self) -> _FakeSession:
return self._session
def _patch_session_factory(monkeypatch: pytest.MonkeyPatch, session: _FakeSession) -> None:
monkeypatch.setattr("core.repositories.human_input_repository.session_factory", _SessionFactoryStub(session))
def test_recipient_entity_token_raises_when_missing() -> None:
recipient = SimpleNamespace(id="r1", access_token=None)
entity = _HumanInputFormRecipientEntityImpl(recipient) # type: ignore[arg-type]
with pytest.raises(AssertionError, match="access_token should not be None"):
_ = entity.token
def test_recipient_entity_id_and_token_success() -> None:
recipient = SimpleNamespace(id="r1", access_token="tok")
entity = _HumanInputFormRecipientEntityImpl(recipient) # type: ignore[arg-type]
assert entity.id == "r1"
assert entity.token == "tok"
def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> None:
form = _DummyForm(
id="f1",
workflow_run_id="run",
node_id="node",
tenant_id="tenant",
app_id="app",
form_definition=_make_form_definition_json(include_expiration_time=True),
rendered_content="<p>x</p>",
expiration_time=naive_utc_now(),
)
console = _DummyRecipient(id="c1", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="ctok")
webapp = _DummyRecipient(
id="w1", form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP, access_token="wtok"
)
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp, console]) # type: ignore[arg-type]
assert entity.web_app_token == "ctok"
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp]) # type: ignore[arg-type]
assert entity.web_app_token == "wtok"
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type]
assert entity.web_app_token is None
def test_form_entity_submitted_data_parsed() -> None:
form = _DummyForm(
id="f1",
workflow_run_id="run",
node_id="node",
tenant_id="tenant",
app_id="app",
form_definition=_make_form_definition_json(include_expiration_time=True),
rendered_content="<p>x</p>",
expiration_time=naive_utc_now(),
submitted_data='{"a": 1}',
submitted_at=naive_utc_now(),
)
entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type]
assert entity.submitted is True
assert entity.submitted_data == {"a": 1}
assert entity.rendered_content == "<p>x</p>"
assert entity.selected_action_id is None
assert entity.status == HumanInputFormStatus.WAITING
def test_form_record_from_models_injects_expiration_time_when_missing() -> None:
expiration = naive_utc_now()
form = _DummyForm(
id="f1",
workflow_run_id=None,
node_id="node",
tenant_id="tenant",
app_id="app",
form_definition=_make_form_definition_json(include_expiration_time=False),
rendered_content="<p>x</p>",
expiration_time=expiration,
submitted_data='{"k": "v"}',
)
record = HumanInputFormRecord.from_models(form, None) # type: ignore[arg-type]
assert record.definition.expiration_time == expiration
assert record.submitted_data == {"k": "v"}
assert record.submitted is False
def test_create_email_recipients_from_resolved_dedupes_and_skips_blank(monkeypatch: pytest.MonkeyPatch) -> None:
created: list[SimpleNamespace] = []
def fake_new(cls, form_id: str, delivery_id: str, payload: Any): # type: ignore[no-untyped-def]
recipient = SimpleNamespace(
id=f"{payload.TYPE}-{len(created)}",
form_id=form_id,
delivery_id=delivery_id,
recipient_type=payload.TYPE,
recipient_payload=payload.model_dump_json(),
access_token="tok",
)
created.append(recipient)
return recipient
monkeypatch.setattr("core.repositories.human_input_repository.HumanInputFormRecipient.new", classmethod(fake_new))
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
recipients = repo._create_email_recipients_from_resolved( # type: ignore[attr-defined]
form_id="f",
delivery_id="d",
members=[
_WorkspaceMemberInfo(user_id="u1", email=""),
_WorkspaceMemberInfo(user_id="u2", email="a@example.com"),
_WorkspaceMemberInfo(user_id="u3", email="a@example.com"),
],
external_emails=["", "a@example.com", "b@example.com", "b@example.com"],
)
assert [r.recipient_type for r in recipients] == [RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL]
def test_query_workspace_members_by_ids_empty_returns_empty() -> None:
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
assert repo._query_workspace_members_by_ids(session=MagicMock(), restrict_to_user_ids=["", ""]) == []
def test_query_workspace_members_by_ids_maps_rows() -> None:
session = _FakeSession(execute_rows=[("u1", "a@example.com"), ("u2", "b@example.com")])
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
rows = repo._query_workspace_members_by_ids(session=session, restrict_to_user_ids=["u1", "u2"])
assert rows == [
_WorkspaceMemberInfo(user_id="u1", email="a@example.com"),
_WorkspaceMemberInfo(user_id="u2", email="b@example.com"),
]
def test_query_all_workspace_members_maps_rows() -> None:
session = _FakeSession(execute_rows=[("u1", "a@example.com")])
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
rows = repo._query_all_workspace_members(session=session)
assert rows == [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")]
def test_repository_init_sets_tenant_id() -> None:
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
assert repo._tenant_id == "tenant"
def test_delivery_method_to_model_webapp_creates_delivery_and_recipient(monkeypatch: pytest.MonkeyPatch) -> None:
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1")
result = repo._delivery_method_to_model(
session=MagicMock(), form_id="form-1", delivery_method=WebAppDeliveryMethod()
)
assert result.delivery.id == "del-1"
assert result.delivery.form_id == "form-1"
assert len(result.recipients) == 1
assert result.recipients[0].recipient_type == RecipientType.STANDALONE_WEB_APP
def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch: pytest.MonkeyPatch) -> None:
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1")
called: dict[str, Any] = {}
def fake_build(*, session: Any, form_id: str, delivery_id: str, recipients_config: Any) -> list[Any]:
called.update(
{"session": session, "form_id": form_id, "delivery_id": delivery_id, "recipients_config": recipients_config}
)
return ["r"]
monkeypatch.setattr(repo, "_build_email_recipients", fake_build)
method = EmailDeliveryMethod(
config=EmailDeliveryConfig(
recipients=EmailRecipients(
whole_workspace=False,
items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
),
subject="s",
body="b",
)
)
result = repo._delivery_method_to_model(session="sess", form_id="form-1", delivery_method=method)
assert result.recipients == ["r"]
assert called["delivery_id"] == "del-1"
def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None:
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
monkeypatch.setattr(
repo,
"_query_all_workspace_members",
lambda *, session: [_WorkspaceMemberInfo(user_id="u", email="a@example.com")],
)
monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"])
recipients = repo._build_email_recipients(
session=MagicMock(),
form_id="f",
delivery_id="d",
recipients_config=EmailRecipients(whole_workspace=True, items=[ExternalRecipient(email="e@example.com")]),
)
assert recipients == ["ok"]
def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None:
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
def fake_query(*, session: Any, restrict_to_user_ids: Sequence[str]) -> list[_WorkspaceMemberInfo]:
assert restrict_to_user_ids == ["u1"]
return [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")]
monkeypatch.setattr(repo, "_query_workspace_members_by_ids", fake_query)
monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"])
recipients = repo._build_email_recipients(
session=MagicMock(),
form_id="f",
delivery_id="d",
recipients_config=EmailRecipients(
whole_workspace=False,
items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
),
)
assert recipients == ["ok"]
def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
_patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None]))
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
assert repo.get_form("run", "node") is None
form = _DummyForm(
id="f1",
workflow_run_id="run",
node_id="node",
tenant_id="tenant",
app_id="app",
form_definition=_make_form_definition_json(include_expiration_time=True),
rendered_content="<p>x</p>",
expiration_time=naive_utc_now(),
)
recipient = _DummyRecipient(
id="r1",
form_id=form.id,
recipient_type=RecipientType.STANDALONE_WEB_APP,
access_token="tok",
)
session = _FakeSession(scalars_results=[form, [recipient]])
_patch_session_factory(monkeypatch, session)
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
entity = repo.get_form("run", "node")
assert entity is not None
assert entity.id == "f1"
assert entity.recipients[0].id == "r1"
assert entity.recipients[0].token == "tok"
def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.MonkeyPatch) -> None:
fixed_now = datetime(2024, 1, 1, 0, 0, 0)
monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)
ids = iter(["form-id", "del-web", "del-console", "del-backstage"])
monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: next(ids))
session = _FakeSession()
_patch_session_factory(monkeypatch, session)
repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
form_config = HumanInputNodeData(
title="Title",
delivery_methods=[],
form_content="hello",
inputs=[],
user_actions=[UserAction(id="submit", title="Submit")],
)
params = FormCreateParams(
app_id="app",
workflow_execution_id="run",
node_id="node",
form_config=form_config,
rendered_content="<p>hello</p>",
delivery_methods=[WebAppDeliveryMethod()],
display_in_ui=True,
resolved_default_values={},
form_kind=HumanInputFormKind.RUNTIME,
console_recipient_required=True,
console_creator_account_id="acc-1",
backstage_recipient_required=True,
)
entity = repo.create_form(params)
assert entity.id == "form-id"
assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout)
# Console token should take precedence when console recipient is present.
assert entity.web_app_token == "token-console"
assert len(entity.recipients) == 3
def test_submission_get_by_token_returns_none_when_missing_or_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
_patch_session_factory(monkeypatch, _FakeSession(scalars_result=None))
repo = HumanInputFormSubmissionRepository()
assert repo.get_by_token("tok") is None
recipient = SimpleNamespace(form=None)
_patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
repo = HumanInputFormSubmissionRepository()
assert repo.get_by_token("tok") is None
def test_submission_repository_init_no_args() -> None:
repo = HumanInputFormSubmissionRepository()
assert isinstance(repo, HumanInputFormSubmissionRepository)
def test_submission_get_by_token_and_get_by_form_id_success_paths(monkeypatch: pytest.MonkeyPatch) -> None:
form = _DummyForm(
id="f1",
workflow_run_id=None,
node_id="node",
tenant_id="tenant",
app_id="app",
form_definition=_make_form_definition_json(include_expiration_time=True),
rendered_content="<p>x</p>",
expiration_time=naive_utc_now(),
)
recipient = SimpleNamespace(
id="r1",
form_id=form.id,
recipient_type=RecipientType.STANDALONE_WEB_APP,
access_token="tok",
form=form,
)
_patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
repo = HumanInputFormSubmissionRepository()
record = repo.get_by_token("tok")
assert record is not None
assert record.access_token == "tok"
_patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
repo = HumanInputFormSubmissionRepository()
record = repo.get_by_form_id_and_recipient_type(form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP)
assert record is not None
assert record.recipient_id == "r1"
def test_submission_get_by_form_id_returns_none_on_missing(monkeypatch: pytest.MonkeyPatch) -> None:
_patch_session_factory(monkeypatch, _FakeSession(scalars_result=None))
repo = HumanInputFormSubmissionRepository()
assert repo.get_by_form_id_and_recipient_type(form_id="f", recipient_type=RecipientType.CONSOLE) is None
def test_mark_submitted_updates_and_raises_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
fixed_now = datetime(2024, 1, 1, 0, 0, 0)
monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)
missing_session = _FakeSession(forms={})
_patch_session_factory(monkeypatch, missing_session)
repo = HumanInputFormSubmissionRepository()
with pytest.raises(FormNotFoundError, match="form not found"):
repo.mark_submitted(
form_id="missing",
recipient_id=None,
selected_action_id="a",
form_data={},
submission_user_id=None,
submission_end_user_id=None,
)
form = _DummyForm(
id="f",
workflow_run_id=None,
node_id="node",
tenant_id="tenant",
app_id="app",
form_definition=_make_form_definition_json(include_expiration_time=True),
rendered_content="<p>x</p>",
expiration_time=fixed_now,
)
recipient = _DummyRecipient(id="r", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="tok")
session = _FakeSession(forms={form.id: form}, recipients={recipient.id: recipient})
_patch_session_factory(monkeypatch, session)
repo = HumanInputFormSubmissionRepository()
record = repo.mark_submitted(
form_id=form.id,
recipient_id=recipient.id,
selected_action_id="approve",
form_data={"k": "v"},
submission_user_id="u",
submission_end_user_id="eu",
)
assert form.status == HumanInputFormStatus.SUBMITTED
assert form.submitted_at == fixed_now
assert record.submitted_data == {"k": "v"}
def test_mark_timeout_invalid_status_raises(monkeypatch: pytest.MonkeyPatch) -> None:
form = _DummyForm(
id="f",
workflow_run_id=None,
node_id="node",
tenant_id="tenant",
app_id="app",
form_definition=_make_form_definition_json(include_expiration_time=True),
rendered_content="<p>x</p>",
expiration_time=naive_utc_now(),
)
session = _FakeSession(forms={form.id: form})
_patch_session_factory(monkeypatch, session)
repo = HumanInputFormSubmissionRepository()
with pytest.raises(_InvalidTimeoutStatusError, match="invalid timeout status"):
repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.SUBMITTED) # type: ignore[arg-type]
def test_mark_timeout_already_timed_out_returns_record(monkeypatch: pytest.MonkeyPatch) -> None:
form = _DummyForm(
id="f",
workflow_run_id=None,
node_id="node",
tenant_id="tenant",
app_id="app",
form_definition=_make_form_definition_json(include_expiration_time=True),
rendered_content="<p>x</p>",
expiration_time=naive_utc_now(),
status=HumanInputFormStatus.TIMEOUT,
)
session = _FakeSession(forms={form.id: form})
_patch_session_factory(monkeypatch, session)
repo = HumanInputFormSubmissionRepository()
record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.TIMEOUT, reason="r")
assert record.status == HumanInputFormStatus.TIMEOUT
def test_mark_timeout_submitted_raises_form_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
form = _DummyForm(
id="f",
workflow_run_id=None,
node_id="node",
tenant_id="tenant",
app_id="app",
form_definition=_make_form_definition_json(include_expiration_time=True),
rendered_content="<p>x</p>",
expiration_time=naive_utc_now(),
status=HumanInputFormStatus.SUBMITTED,
)
session = _FakeSession(forms={form.id: form})
_patch_session_factory(monkeypatch, session)
repo = HumanInputFormSubmissionRepository()
with pytest.raises(FormNotFoundError, match="form already submitted"):
repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED)
def test_mark_timeout_updates_fields(monkeypatch: pytest.MonkeyPatch) -> None:
form = _DummyForm(
id="f",
workflow_run_id=None,
node_id="node",
tenant_id="tenant",
app_id="app",
form_definition=_make_form_definition_json(include_expiration_time=True),
rendered_content="<p>x</p>",
expiration_time=naive_utc_now(),
selected_action_id="a",
submitted_data="{}",
submission_user_id="u",
submission_end_user_id="eu",
completed_by_recipient_id="r",
status=HumanInputFormStatus.WAITING,
)
session = _FakeSession(forms={form.id: form})
_patch_session_factory(monkeypatch, session)
repo = HumanInputFormSubmissionRepository()
record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED)
assert form.status == HumanInputFormStatus.EXPIRED
assert form.selected_action_id is None
assert form.submitted_data is None
assert form.submission_user_id is None
assert form.submission_end_user_id is None
assert form.completed_by_recipient_id is None
assert record.status == HumanInputFormStatus.EXPIRED
def test_mark_timeout_raises_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
_patch_session_factory(monkeypatch, _FakeSession(forms={}))
repo = HumanInputFormSubmissionRepository()
with pytest.raises(FormNotFoundError, match="form not found"):
repo.mark_timeout(form_id="missing", timeout_status=HumanInputFormStatus.TIMEOUT)

View File

@ -1,84 +1,291 @@
from datetime import datetime
from datetime import UTC, datetime
from unittest.mock import MagicMock
from uuid import uuid4
from sqlalchemy import create_engine
import pytest
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType
from models import Account, WorkflowRun
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from models import Account, CreatorUserRole, EndUser, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom
def _build_repository_with_mocked_session(session: MagicMock) -> SQLAlchemyWorkflowExecutionRepository:
engine = create_engine("sqlite:///:memory:")
real_session_factory = sessionmaker(bind=engine, expire_on_commit=False)
user = MagicMock(spec=Account)
user.id = str(uuid4())
user.current_tenant_id = str(uuid4())
repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=real_session_factory,
user=user,
app_id="app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
session_context = MagicMock()
session_context.__enter__.return_value = session
session_context.__exit__.return_value = False
repository._session_factory = MagicMock(return_value=session_context)
return repository
def _build_execution(*, execution_id: str, started_at: datetime) -> WorkflowExecution:
return WorkflowExecution.new(
id_=execution_id,
workflow_id="workflow-id",
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0.0",
graph={"nodes": [], "edges": []},
inputs={"query": "hello"},
started_at=started_at,
)
def test_save_uses_execution_started_at_when_record_does_not_exist():
@pytest.fixture
def mock_session_factory():
"""Mock SQLAlchemy session factory."""
session_factory = MagicMock(spec=sessionmaker)
session = MagicMock()
session.get.return_value = None
repository = _build_repository_with_mocked_session(session)
started_at = datetime(2026, 1, 1, 12, 0, 0)
execution = _build_execution(execution_id=str(uuid4()), started_at=started_at)
repository.save(execution)
saved_model = session.merge.call_args.args[0]
assert saved_model.created_at == started_at
session.commit.assert_called_once()
session_factory.return_value.__enter__.return_value = session
return session_factory
def test_save_preserves_existing_created_at_when_record_already_exists():
session = MagicMock()
repository = _build_repository_with_mocked_session(session)
@pytest.fixture
def mock_engine():
"""Mock SQLAlchemy Engine."""
return MagicMock(spec=Engine)
execution_id = str(uuid4())
existing_created_at = datetime(2026, 1, 1, 12, 0, 0)
existing_run = WorkflowRun()
existing_run.id = execution_id
existing_run.tenant_id = repository._tenant_id
existing_run.created_at = existing_created_at
session.get.return_value = existing_run
execution = _build_execution(
execution_id=execution_id,
started_at=datetime(2026, 1, 1, 12, 30, 0),
@pytest.fixture
def mock_account():
"""Mock Account user."""
account = MagicMock(spec=Account)
account.id = str(uuid4())
account.current_tenant_id = str(uuid4())
return account
@pytest.fixture
def mock_end_user():
"""Mock EndUser."""
user = MagicMock(spec=EndUser)
user.id = str(uuid4())
user.tenant_id = str(uuid4())
return user
@pytest.fixture
def sample_workflow_execution():
"""Sample WorkflowExecution for testing."""
return WorkflowExecution(
id_=str(uuid4()),
workflow_id=str(uuid4()),
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input1": "value1"},
outputs={"output1": "result1"},
status=WorkflowExecutionStatus.SUCCEEDED,
error_message="",
total_tokens=100,
total_steps=5,
exceptions_count=0,
started_at=datetime.now(UTC),
finished_at=datetime.now(UTC),
)
repository.save(execution)
saved_model = session.merge.call_args.args[0]
assert saved_model.created_at == existing_created_at
session.commit.assert_called_once()
class TestSQLAlchemyWorkflowExecutionRepository:
def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
app_id = "test_app_id"
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory, user=mock_account, app_id=app_id, triggered_from=triggered_from
)
assert repo._session_factory == mock_session_factory
assert repo._tenant_id == mock_account.current_tenant_id
assert repo._app_id == app_id
assert repo._triggered_from == triggered_from
assert repo._creator_user_id == mock_account.id
assert repo._creator_user_role == CreatorUserRole.ACCOUNT
def test_init_with_engine(self, mock_engine, mock_account):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_engine,
user=mock_account,
app_id="test_app_id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert isinstance(repo._session_factory, sessionmaker)
assert repo._session_factory.kw["bind"] == mock_engine
def test_init_invalid_session_factory(self, mock_account):
with pytest.raises(ValueError, match="Invalid session_factory type"):
SQLAlchemyWorkflowExecutionRepository(
session_factory="invalid", user=mock_account, app_id=None, triggered_from=None
)
def test_init_no_tenant_id(self, mock_session_factory):
user = MagicMock(spec=Account)
user.current_tenant_id = None
with pytest.raises(ValueError, match="User must have a tenant_id"):
SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory, user=user, app_id=None, triggered_from=None
)
def test_init_with_end_user(self, mock_session_factory, mock_end_user):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory, user=mock_end_user, app_id=None, triggered_from=None
)
assert repo._tenant_id == mock_end_user.tenant_id
assert repo._creator_user_role == CreatorUserRole.END_USER
def test_to_domain_model(self, mock_session_factory, mock_account):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None
)
db_model = MagicMock(spec=WorkflowRun)
db_model.id = str(uuid4())
db_model.workflow_id = str(uuid4())
db_model.type = "workflow"
db_model.version = "1.0"
db_model.inputs_dict = {"in": "val"}
db_model.outputs_dict = {"out": "val"}
db_model.graph_dict = {"nodes": []}
db_model.status = "succeeded"
db_model.error = "some error"
db_model.total_tokens = 50
db_model.total_steps = 3
db_model.exceptions_count = 1
db_model.created_at = datetime.now(UTC)
db_model.finished_at = datetime.now(UTC)
domain_model = repo._to_domain_model(db_model)
assert domain_model.id_ == db_model.id
assert domain_model.workflow_id == db_model.workflow_id
assert domain_model.status == WorkflowExecutionStatus.SUCCEEDED
assert domain_model.inputs == db_model.inputs_dict
assert domain_model.error_message == "some error"
def test_to_db_model(self, mock_session_factory, mock_account, sample_workflow_execution):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test_app",
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Make elapsed time deterministic to avoid flaky tests
sample_workflow_execution.started_at = datetime(2023, 1, 1, 0, 0, 0, tzinfo=UTC)
sample_workflow_execution.finished_at = datetime(2023, 1, 1, 0, 0, 10, tzinfo=UTC)
db_model = repo._to_db_model(sample_workflow_execution)
assert db_model.id == sample_workflow_execution.id_
assert db_model.tenant_id == repo._tenant_id
assert db_model.app_id == "test_app"
assert db_model.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING
assert db_model.status == sample_workflow_execution.status.value
assert db_model.total_tokens == sample_workflow_execution.total_tokens
assert db_model.elapsed_time == 10.0
def test_to_db_model_edge_cases(self, mock_session_factory, mock_account, sample_workflow_execution):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test_app",
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Test with empty/None fields
sample_workflow_execution.graph = None
sample_workflow_execution.inputs = None
sample_workflow_execution.outputs = None
sample_workflow_execution.error_message = None
sample_workflow_execution.finished_at = None
db_model = repo._to_db_model(sample_workflow_execution)
assert db_model.graph is None
assert db_model.inputs is None
assert db_model.outputs is None
assert db_model.error is None
assert db_model.elapsed_time == 0
def test_to_db_model_app_id_none(self, mock_session_factory, mock_account, sample_workflow_execution):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id=None,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
db_model = repo._to_db_model(sample_workflow_execution)
assert not hasattr(db_model, "app_id") or db_model.app_id is None
assert db_model.tenant_id == repo._tenant_id
def test_to_db_model_missing_context(self, mock_session_factory, mock_account, sample_workflow_execution):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None
)
# Test triggered_from missing
with pytest.raises(ValueError, match="triggered_from is required"):
repo._to_db_model(sample_workflow_execution)
repo._triggered_from = WorkflowRunTriggeredFrom.APP_RUN
repo._creator_user_id = None
with pytest.raises(ValueError, match="created_by is required"):
repo._to_db_model(sample_workflow_execution)
repo._creator_user_id = "some_id"
repo._creator_user_role = None
with pytest.raises(ValueError, match="created_by_role is required"):
repo._to_db_model(sample_workflow_execution)
def test_save(self, mock_session_factory, mock_account, sample_workflow_execution):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test_app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
repo.save(sample_workflow_execution)
session = mock_session_factory.return_value.__enter__.return_value
session.merge.assert_called_once()
session.commit.assert_called_once()
# Check cache
assert sample_workflow_execution.id_ in repo._execution_cache
cached_model = repo._execution_cache[sample_workflow_execution.id_]
assert cached_model.id == sample_workflow_execution.id_
def test_save_uses_execution_started_at_when_record_does_not_exist(
self, mock_session_factory, mock_account, sample_workflow_execution
):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test_app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
started_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC)
sample_workflow_execution.started_at = started_at
session = mock_session_factory.return_value.__enter__.return_value
session.get.return_value = None
repo.save(sample_workflow_execution)
saved_model = session.merge.call_args.args[0]
assert saved_model.created_at == started_at
session.commit.assert_called_once()
def test_save_preserves_existing_created_at_when_record_already_exists(
self, mock_session_factory, mock_account, sample_workflow_execution
):
repo = SQLAlchemyWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test_app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
execution_id = sample_workflow_execution.id_
existing_created_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC)
existing_run = WorkflowRun()
existing_run.id = execution_id
existing_run.tenant_id = repo._tenant_id
existing_run.created_at = existing_created_at
session = mock_session_factory.return_value.__enter__.return_value
session.get.return_value = existing_run
sample_workflow_execution.started_at = datetime(2026, 1, 1, 12, 30, 0, tzinfo=UTC)
repo.save(sample_workflow_execution)
saved_model = session.merge.call_args.args[0]
assert saved_model.created_at == existing_created_at
session.commit.assert_called_once()

View File

@ -0,0 +1,772 @@
from __future__ import annotations
import json
import logging
from collections.abc import Mapping
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock, Mock
import psycopg2.errors
import pytest
from sqlalchemy import Engine, create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
SQLAlchemyWorkflowNodeExecutionRepository,
_deterministic_json_dump,
_filter_by_offload_type,
_find_first,
_replace_or_append_offload,
)
from dify_graph.entities import WorkflowNodeExecution
from dify_graph.enums import (
BuiltinNodeTypes,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.repositories.workflow_node_execution_repository import OrderConfig
from models import Account, EndUser
from models.enums import ExecutionOffLoadType
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom
def _mock_account(*, tenant_id: str = "tenant", user_id: str = "user") -> Account:
user = Mock(spec=Account)
user.id = user_id
user.current_tenant_id = tenant_id
return user
def _mock_end_user(*, tenant_id: str = "tenant", user_id: str = "user") -> EndUser:
user = Mock(spec=EndUser)
user.id = user_id
user.tenant_id = tenant_id
return user
def _execution(
*,
execution_id: str = "exec-id",
node_execution_id: str = "node-exec-id",
workflow_run_id: str = "run-id",
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.SUCCEEDED,
inputs: Mapping[str, Any] | None = None,
outputs: Mapping[str, Any] | None = None,
process_data: Mapping[str, Any] | None = None,
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None,
) -> WorkflowNodeExecution:
return WorkflowNodeExecution(
id=execution_id,
node_execution_id=node_execution_id,
workflow_id="workflow-id",
workflow_execution_id=workflow_run_id,
index=1,
predecessor_node_id=None,
node_id="node-id",
node_type=BuiltinNodeTypes.LLM,
title="Title",
inputs=inputs,
outputs=outputs,
process_data=process_data,
status=status,
error=None,
elapsed_time=1.0,
metadata=metadata,
created_at=datetime.now(UTC),
finished_at=None,
)
class _SessionCtx:
def __init__(self, session: Any):
self._session = session
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type, exc, tb) -> None:
return None
def _session_factory(session: Any) -> sessionmaker:
factory = Mock(spec=sessionmaker)
factory.return_value = _SessionCtx(session)
return factory
def test_init_accepts_engine_and_sessionmaker_and_sets_role(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
engine: Engine = create_engine("sqlite:///:memory:")
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=engine,
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
assert isinstance(repo._session_factory, sessionmaker)
sm = Mock(spec=sessionmaker)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=sm,
user=_mock_end_user(),
app_id="app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
assert repo._creator_user_role.value == "end_user"
def test_init_rejects_invalid_session_factory_type(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
with pytest.raises(ValueError, match="Invalid session_factory type"):
SQLAlchemyWorkflowNodeExecutionRepository( # type: ignore[arg-type]
session_factory=object(),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
def test_init_requires_tenant_id(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
user = _mock_account()
user.current_tenant_id = None
with pytest.raises(ValueError, match="User must have a tenant_id"):
SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=user,
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
def test_create_truncator_uses_config(monkeypatch: pytest.MonkeyPatch) -> None:
created: dict[str, Any] = {}
class FakeTruncator:
def __init__(self, *, max_size_bytes: int, array_element_limit: int, string_length_limit: int):
created.update(
{
"max_size_bytes": max_size_bytes,
"array_element_limit": array_element_limit,
"string_length_limit": string_length_limit,
}
)
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.VariableTruncator",
FakeTruncator,
)
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
_ = repo._create_truncator()
assert created["max_size_bytes"] == dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE
def test_helpers_find_first_and_replace_or_append_and_filter() -> None:
assert _deterministic_json_dump({"b": 1, "a": 2}) == '{"a": 2, "b": 1}'
assert _find_first([], lambda _: True) is None
assert _find_first([1, 2, 3], lambda x: x > 1) == 2
off1 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
off2 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
assert _find_first([off1, off2], _filter_by_offload_type(ExecutionOffLoadType.OUTPUTS)) is off2
replaced = _replace_or_append_offload([off1, off2], WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS))
assert len(replaced) == 2
assert [o.type_ for o in replaced] == [ExecutionOffLoadType.OUTPUTS, ExecutionOffLoadType.INPUTS]
def test_to_db_model_requires_constructor_context(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
execution = _execution(inputs={"b": 1, "a": 2}, metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 1})
# Happy path: deterministic json dump should be sorted
db_model = repo._to_db_model(execution)
assert json.loads(db_model.inputs or "{}") == {"a": 2, "b": 1}
assert json.loads(db_model.execution_metadata or "{}")["total_tokens"] == 1
repo._triggered_from = None
with pytest.raises(ValueError, match="triggered_from is required"):
repo._to_db_model(execution)
def test_to_db_model_requires_creator_user_id_and_role(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id="app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
execution = _execution()
db_model = repo._to_db_model(execution)
assert db_model.app_id == "app"
repo._creator_user_id = None
with pytest.raises(ValueError, match="created_by is required"):
repo._to_db_model(execution)
repo._creator_user_id = "user"
repo._creator_user_role = None
with pytest.raises(ValueError, match="created_by_role is required"):
repo._to_db_model(execution)
def test_is_duplicate_key_error_and_regenerate_id(
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
unique = Mock(spec=psycopg2.errors.UniqueViolation)
duplicate_error = IntegrityError("dup", params=None, orig=unique)
assert repo._is_duplicate_key_error(duplicate_error) is True
assert repo._is_duplicate_key_error(IntegrityError("other", params=None, orig=None)) is False
execution = _execution(execution_id="old-id")
db_model = WorkflowNodeExecutionModel()
db_model.id = "old-id"
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id")
caplog.set_level(logging.WARNING)
repo._regenerate_id_on_duplicate(execution, db_model)
assert execution.id == "new-id"
assert db_model.id == "new-id"
assert any("Duplicate key conflict" in r.message for r in caplog.records)
def test_persist_to_database_updates_existing_and_inserts_new(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
session = MagicMock()
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=_session_factory(session),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
db_model = WorkflowNodeExecutionModel()
db_model.id = "id1"
db_model.node_execution_id = "node1"
db_model.foo = "bar" # type: ignore[attr-defined]
db_model.__dict__["_private"] = "x"
existing = SimpleNamespace()
session.get.return_value = existing
repo._persist_to_database(db_model)
assert existing.foo == "bar"
session.add.assert_not_called()
assert repo._node_execution_cache["node1"] is db_model
session.reset_mock()
session.get.return_value = None
repo._node_execution_cache.clear()
repo._persist_to_database(db_model)
session.add.assert_called_once_with(db_model)
assert repo._node_execution_cache["node1"] is db_model
def test_truncate_and_upload_returns_none_when_no_values_or_not_truncated(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id="app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
assert repo._truncate_and_upload(None, "e", ExecutionOffLoadType.INPUTS) is None
class FakeTruncator:
def truncate_variable_mapping(self, value: Any): # type: ignore[no-untyped-def]
return value, False
monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator())
assert repo._truncate_and_upload({"a": 1}, "e", ExecutionOffLoadType.INPUTS) is None
def test_truncate_and_upload_uploads_and_builds_offload(monkeypatch: pytest.MonkeyPatch) -> None:
uploaded: dict[str, Any] = {}
class FakeFileService:
def upload_file(self, *, filename: str, content: bytes, mimetype: str, user: Any): # type: ignore[no-untyped-def]
uploaded.update({"filename": filename, "content": content, "mimetype": mimetype, "user": user})
return SimpleNamespace(id="file-id", key="file-key")
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", lambda *_: FakeFileService()
)
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "offload-id")
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id="app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
class FakeTruncator:
def truncate_variable_mapping(self, value: Any): # type: ignore[no-untyped-def]
return {"truncated": True}, True
monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator())
result = repo._truncate_and_upload({"a": 1}, "exec", ExecutionOffLoadType.INPUTS)
assert result is not None
assert result.truncated_value == {"truncated": True}
assert uploaded["filename"].startswith("node_execution_exec_inputs.json")
assert result.offload.file_id == "file-id"
assert result.offload.type_ == ExecutionOffLoadType.INPUTS
def test_to_domain_model_loads_offloaded_files(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
db_model = WorkflowNodeExecutionModel()
db_model.id = "id"
db_model.node_execution_id = "node-exec"
db_model.workflow_id = "wf"
db_model.workflow_run_id = "run"
db_model.index = 1
db_model.predecessor_node_id = None
db_model.node_id = "node"
db_model.node_type = BuiltinNodeTypes.LLM
db_model.title = "t"
db_model.inputs = json.dumps({"trunc": "i"})
db_model.process_data = json.dumps({"trunc": "p"})
db_model.outputs = json.dumps({"trunc": "o"})
db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
db_model.error = None
db_model.elapsed_time = 0.1
db_model.execution_metadata = json.dumps({"total_tokens": 3})
db_model.created_at = datetime.now(UTC)
db_model.finished_at = None
off_in = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
off_out = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
off_proc = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA)
off_in.file = SimpleNamespace(key="k-in")
off_out.file = SimpleNamespace(key="k-out")
off_proc.file = SimpleNamespace(key="k-proc")
db_model.offload_data = [off_out, off_in, off_proc]
def fake_load(key: str) -> bytes:
return json.dumps({"full": key}).encode()
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.storage.load", fake_load)
domain = repo._to_domain_model(db_model)
assert domain.inputs == {"full": "k-in"}
assert domain.outputs == {"full": "k-out"}
assert domain.process_data == {"full": "k-proc"}
assert domain.get_truncated_inputs() == {"trunc": "i"}
assert domain.get_truncated_outputs() == {"trunc": "o"}
assert domain.get_truncated_process_data() == {"trunc": "p"}
def test_to_domain_model_returns_early_when_no_offload_data(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
db_model = WorkflowNodeExecutionModel()
db_model.id = "id"
db_model.node_execution_id = "node-exec"
db_model.workflow_id = "wf"
db_model.workflow_run_id = "run"
db_model.index = 1
db_model.predecessor_node_id = None
db_model.node_id = "node"
db_model.node_type = BuiltinNodeTypes.LLM
db_model.title = "t"
db_model.inputs = json.dumps({"i": 1})
db_model.process_data = json.dumps({"p": 2})
db_model.outputs = json.dumps({"o": 3})
db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
db_model.error = None
db_model.elapsed_time = 0.1
db_model.execution_metadata = "{}"
db_model.created_at = datetime.now(UTC)
db_model.finished_at = None
db_model.offload_data = []
domain = repo._to_domain_model(db_model)
assert domain.inputs == {"i": 1}
assert domain.outputs == {"o": 3}
def test_json_encode_uses_runtime_converter(monkeypatch: pytest.MonkeyPatch) -> None:
class FakeConverter:
def to_json_encodable(self, values: Mapping[str, Any]) -> Mapping[str, Any]:
return {"wrapped": values["a"]}
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowRuntimeTypeConverter",
FakeConverter,
)
assert SQLAlchemyWorkflowNodeExecutionRepository._json_encode({"a": 1}) == '{"wrapped": 1}'
def test_save_execution_data_handles_existing_db_model_and_truncation(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
session = MagicMock()
session.execute.return_value.scalars.return_value.first.return_value = SimpleNamespace(
id="id",
offload_data=[WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)],
inputs=None,
outputs=None,
process_data=None,
)
session.merge = Mock()
session.flush = Mock()
session.begin.return_value.__enter__ = Mock(return_value=session)
session.begin.return_value.__exit__ = Mock(return_value=None)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=_session_factory(session),
user=_mock_account(),
app_id="app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3})
trunc_result = SimpleNamespace(
truncated_value={"trunc": True},
offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS, file_id="f1"),
)
monkeypatch.setattr(
repo, "_truncate_and_upload", lambda values, *_args, **_kwargs: trunc_result if values == {"a": 1} else None
)
monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True))
repo.save_execution_data(execution)
# Inputs should be truncated, outputs/process_data encoded directly
db_model = session.merge.call_args.args[0]
assert json.loads(db_model.inputs) == {"trunc": True}
assert json.loads(db_model.outputs) == {"b": 2}
assert json.loads(db_model.process_data) == {"c": 3}
assert any(off.type_ == ExecutionOffLoadType.INPUTS for off in db_model.offload_data)
assert execution.get_truncated_inputs() == {"trunc": True}
def test_save_execution_data_truncates_outputs_and_process_data(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
existing = SimpleNamespace(
id="id",
offload_data=[],
inputs=None,
outputs=None,
process_data=None,
)
session = MagicMock()
session.execute.return_value.scalars.return_value.first.return_value = existing
session.merge = Mock()
session.flush = Mock()
session.begin.return_value.__enter__ = Mock(return_value=session)
session.begin.return_value.__exit__ = Mock(return_value=None)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=_session_factory(session),
user=_mock_account(),
app_id="app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3})
def trunc(values: Mapping[str, Any], *_args: Any, **_kwargs: Any) -> Any:
if values == {"b": 2}:
return SimpleNamespace(
truncated_value={"b": "trunc"},
offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS, file_id="f2"),
)
if values == {"c": 3}:
return SimpleNamespace(
truncated_value={"c": "trunc"},
offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA, file_id="f3"),
)
return None
monkeypatch.setattr(repo, "_truncate_and_upload", trunc)
monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True))
repo.save_execution_data(execution)
db_model = session.merge.call_args.args[0]
assert json.loads(db_model.outputs) == {"b": "trunc"}
assert json.loads(db_model.process_data) == {"c": "trunc"}
assert execution.get_truncated_outputs() == {"b": "trunc"}
assert execution.get_truncated_process_data() == {"c": "trunc"}
def test_save_execution_data_handles_missing_db_model(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
session = MagicMock()
session.execute.return_value.scalars.return_value.first.return_value = None
session.merge = Mock()
session.flush = Mock()
session.begin.return_value.__enter__ = Mock(return_value=session)
session.begin.return_value.__exit__ = Mock(return_value=None)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=_session_factory(session),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
execution = _execution(inputs={"a": 1})
fake_db_model = SimpleNamespace(id=execution.id, offload_data=[], inputs=None, outputs=None, process_data=None)
monkeypatch.setattr(repo, "_to_db_model", lambda *_: fake_db_model)
monkeypatch.setattr(repo, "_truncate_and_upload", lambda *_args, **_kwargs: None)
monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values))
repo.save_execution_data(execution)
merged = session.merge.call_args.args[0]
assert merged.inputs == '{"a": 1}'
def test_save_retries_duplicate_and_logs_non_duplicate(
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
execution = _execution(execution_id="id")
unique = Mock(spec=psycopg2.errors.UniqueViolation)
duplicate_error = IntegrityError("dup", params=None, orig=unique)
other_error = IntegrityError("other", params=None, orig=None)
calls = {"n": 0}
def persist(_db_model: Any) -> None:
calls["n"] += 1
if calls["n"] == 1:
raise duplicate_error
monkeypatch.setattr(repo, "_persist_to_database", persist)
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id")
repo.save(execution)
assert execution.id == "new-id"
assert repo._node_execution_cache[execution.node_execution_id] is not None
caplog.set_level(logging.ERROR)
monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(other_error))
with pytest.raises(IntegrityError):
repo.save(_execution(execution_id="id2", node_execution_id="node2"))
assert any("Non-duplicate key integrity error" in r.message for r in caplog.records)
def test_save_logs_and_reraises_on_unexpected_error(
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
caplog.set_level(logging.ERROR)
monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(RuntimeError("boom")))
with pytest.raises(RuntimeError, match="boom"):
repo.save(_execution(execution_id="id3", node_execution_id="node3"))
assert any("Failed to save workflow node execution" in r.message for r in caplog.records)
def test_get_db_models_by_workflow_run_orders_and_caches(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
class FakeStmt:
def __init__(self) -> None:
self.where_calls = 0
self.order_by_args: tuple[Any, ...] | None = None
def where(self, *_args: Any) -> FakeStmt:
self.where_calls += 1
return self
def order_by(self, *args: Any) -> FakeStmt:
self.order_by_args = args
return self
stmt = FakeStmt()
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files",
lambda _q: stmt,
)
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select")
model1 = SimpleNamespace(node_execution_id="n1")
model2 = SimpleNamespace(node_execution_id=None)
session = MagicMock()
session.scalars.return_value.all.return_value = [model1, model2]
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=_session_factory(session),
user=_mock_account(),
app_id="app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
order = OrderConfig(order_by=["index", "missing"], order_direction="desc")
db_models = repo.get_db_models_by_workflow_run("run", order)
assert db_models == [model1, model2]
assert repo._node_execution_cache["n1"] is model1
assert stmt.order_by_args is not None
def test_get_db_models_by_workflow_run_uses_asc_order(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
class FakeStmt:
def where(self, *_args: Any) -> FakeStmt:
return self
def order_by(self, *args: Any) -> FakeStmt:
self.args = args # type: ignore[attr-defined]
return self
stmt = FakeStmt()
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files",
lambda _q: stmt,
)
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select")
session = MagicMock()
session.scalars.return_value.all.return_value = []
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=_session_factory(session),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
repo.get_db_models_by_workflow_run("run", OrderConfig(order_by=["index"], order_direction="asc"))
def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
lambda *_: SimpleNamespace(upload_file=Mock()),
)
repo = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=Mock(spec=sessionmaker),
user=_mock_account(),
app_id=None,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
db_models = [SimpleNamespace(id="db1"), SimpleNamespace(id="db2")]
monkeypatch.setattr(repo, "get_db_models_by_workflow_run", lambda *_args, **_kwargs: db_models)
monkeypatch.setattr(repo, "_to_domain_model", lambda m: f"domain:{m.id}")
class FakeExecutor:
def __enter__(self) -> FakeExecutor:
return self
def __exit__(self, exc_type, exc, tb) -> None:
return None
def map(self, func, items, timeout: int): # type: ignore[no-untyped-def]
assert timeout == 30
return list(map(func, items))
monkeypatch.setattr(
"core.repositories.sqlalchemy_workflow_node_execution_repository.ThreadPoolExecutor",
lambda max_workers: FakeExecutor(),
)
result = repo.get_by_workflow_run("run", order_config=None)
assert result == ["domain:db1", "domain:db2"]

View File

@ -0,0 +1,137 @@
import json
from unittest.mock import patch
from core.schemas.registry import SchemaRegistry
class TestSchemaRegistry:
def test_initialization(self, tmp_path):
base_dir = tmp_path / "schemas"
base_dir.mkdir()
registry = SchemaRegistry(str(base_dir))
assert registry.base_dir == base_dir
assert registry.versions == {}
assert registry.metadata == {}
def test_default_registry_singleton(self):
registry1 = SchemaRegistry.default_registry()
registry2 = SchemaRegistry.default_registry()
assert registry1 is registry2
assert isinstance(registry1, SchemaRegistry)
def test_load_all_versions_non_existent_dir(self, tmp_path):
base_dir = tmp_path / "non_existent"
registry = SchemaRegistry(str(base_dir))
registry.load_all_versions()
assert registry.versions == {}
def test_load_all_versions_filtering(self, tmp_path):
base_dir = tmp_path / "schemas"
base_dir.mkdir()
(base_dir / "not_a_version_dir").mkdir()
(base_dir / "v1").mkdir()
(base_dir / "some_file.txt").write_text("content")
registry = SchemaRegistry(str(base_dir))
with patch.object(registry, "_load_version_dir") as mock_load:
registry.load_all_versions()
mock_load.assert_called_once()
assert mock_load.call_args[0][0] == "v1"
def test_load_version_dir_filtering(self, tmp_path):
version_dir = tmp_path / "v1"
version_dir.mkdir()
(version_dir / "schema1.json").write_text("{}")
(version_dir / "not_a_schema.txt").write_text("content")
registry = SchemaRegistry(str(tmp_path))
with patch.object(registry, "_load_schema") as mock_load:
registry._load_version_dir("v1", version_dir)
mock_load.assert_called_once()
assert mock_load.call_args[0][1] == "schema1"
def test_load_version_dir_non_existent(self, tmp_path):
version_dir = tmp_path / "non_existent"
registry = SchemaRegistry(str(tmp_path))
registry._load_version_dir("v1", version_dir)
assert "v1" not in registry.versions
def test_load_schema_success(self, tmp_path):
schema_path = tmp_path / "test.json"
schema_content = {"title": "Test Schema", "description": "A test schema"}
schema_path.write_text(json.dumps(schema_content))
registry = SchemaRegistry(str(tmp_path))
registry.versions["v1"] = {}
registry._load_schema("v1", "test", schema_path)
assert registry.versions["v1"]["test"] == schema_content
uri = "https://dify.ai/schemas/v1/test.json"
assert registry.metadata[uri]["title"] == "Test Schema"
assert registry.metadata[uri]["version"] == "v1"
def test_load_schema_invalid_json(self, tmp_path, caplog):
schema_path = tmp_path / "invalid.json"
schema_path.write_text("invalid json")
registry = SchemaRegistry(str(tmp_path))
registry.versions["v1"] = {}
registry._load_schema("v1", "invalid", schema_path)
assert "Failed to load schema v1/invalid" in caplog.text
def test_load_schema_os_error(self, tmp_path, caplog):
schema_path = tmp_path / "error.json"
schema_path.write_text("{}")
registry = SchemaRegistry(str(tmp_path))
registry.versions["v1"] = {}
with patch("builtins.open", side_effect=OSError("Read error")):
registry._load_schema("v1", "error", schema_path)
assert "Failed to load schema v1/error" in caplog.text
def test_get_schema(self):
registry = SchemaRegistry("/tmp")
registry.versions = {"v1": {"test": {"type": "object"}}}
# Valid URI
assert registry.get_schema("https://dify.ai/schemas/v1/test.json") == {"type": "object"}
# Invalid URI
assert registry.get_schema("invalid-uri") is None
# Missing version
assert registry.get_schema("https://dify.ai/schemas/v2/test.json") is None
def test_list_versions(self):
registry = SchemaRegistry("/tmp")
registry.versions = {"v2": {}, "v1": {}}
assert registry.list_versions() == ["v1", "v2"]
def test_list_schemas(self):
registry = SchemaRegistry("/tmp")
registry.versions = {"v1": {"b": {}, "a": {}}}
assert registry.list_schemas("v1") == ["a", "b"]
assert registry.list_schemas("v2") == []
def test_get_all_schemas_for_version(self):
registry = SchemaRegistry("/tmp")
registry.versions = {"v1": {"test": {"title": "Test Label"}}}
results = registry.get_all_schemas_for_version("v1")
assert len(results) == 1
assert results[0]["name"] == "test"
assert results[0]["label"] == "Test Label"
assert results[0]["schema"] == {"title": "Test Label"}
# Default label if title missing
registry.versions["v1"]["no_title"] = {}
results = registry.get_all_schemas_for_version("v1")
item = next(r for r in results if r["name"] == "no_title")
assert item["label"] == "no_title"
# Empty if version missing
assert registry.get_all_schemas_for_version("v2") == []

View File

@ -0,0 +1,80 @@
from unittest.mock import MagicMock, patch
from core.schemas.registry import SchemaRegistry
from core.schemas.schema_manager import SchemaManager
def test_init_with_provided_registry():
mock_registry = MagicMock(spec=SchemaRegistry)
manager = SchemaManager(registry=mock_registry)
assert manager.registry == mock_registry
@patch("core.schemas.schema_manager.SchemaRegistry.default_registry")
def test_init_with_default_registry(mock_default_registry):
mock_registry = MagicMock(spec=SchemaRegistry)
mock_default_registry.return_value = mock_registry
manager = SchemaManager()
mock_default_registry.assert_called_once()
assert manager.registry == mock_registry
def test_get_all_schema_definitions():
mock_registry = MagicMock(spec=SchemaRegistry)
expected_definitions = [{"name": "schema1", "schema": {}}, {"name": "schema2", "schema": {}}]
mock_registry.get_all_schemas_for_version.return_value = expected_definitions
manager = SchemaManager(registry=mock_registry)
result = manager.get_all_schema_definitions(version="v2")
mock_registry.get_all_schemas_for_version.assert_called_once_with("v2")
assert result == expected_definitions
def test_get_schema_by_name_success():
mock_registry = MagicMock(spec=SchemaRegistry)
mock_schema = {"type": "object"}
mock_registry.get_schema.return_value = mock_schema
manager = SchemaManager(registry=mock_registry)
result = manager.get_schema_by_name("my_schema", version="v1")
expected_uri = "https://dify.ai/schemas/v1/my_schema.json"
mock_registry.get_schema.assert_called_once_with(expected_uri)
assert result == {"name": "my_schema", "schema": mock_schema}
def test_get_schema_by_name_not_found():
mock_registry = MagicMock(spec=SchemaRegistry)
mock_registry.get_schema.return_value = None
manager = SchemaManager(registry=mock_registry)
result = manager.get_schema_by_name("non_existent", version="v1")
assert result is None
def test_list_available_schemas():
mock_registry = MagicMock(spec=SchemaRegistry)
expected_schemas = ["schema1", "schema2"]
mock_registry.list_schemas.return_value = expected_schemas
manager = SchemaManager(registry=mock_registry)
result = manager.list_available_schemas(version="v1")
mock_registry.list_schemas.assert_called_once_with("v1")
assert result == expected_schemas
def test_list_available_versions():
mock_registry = MagicMock(spec=SchemaRegistry)
expected_versions = ["v1", "v2"]
mock_registry.list_versions.return_value = expected_versions
manager = SchemaManager(registry=mock_registry)
result = manager.list_available_versions()
mock_registry.list_versions.assert_called_once()
assert result == expected_versions

View File

@ -95,13 +95,11 @@ class TestGitHubOAuth(BaseOAuthTest):
],
"primary@example.com",
),
# User with no emails - fallback to noreply
({"id": 12345, "login": "testuser", "name": "Test User"}, [], "12345+testuser@users.noreply.github.com"),
# User with only secondary email - fallback to noreply
# User with private email (null email and name from API)
(
{"id": 12345, "login": "testuser", "name": "Test User"},
[{"email": "secondary@example.com", "primary": False}],
"12345+testuser@users.noreply.github.com",
{"id": 12345, "login": "testuser", "name": None, "email": None},
[{"email": "primary@example.com", "primary": True}],
"primary@example.com",
),
],
)
@ -118,9 +116,54 @@ class TestGitHubOAuth(BaseOAuthTest):
user_info = oauth.get_user_info("test_token")
assert user_info.id == str(user_data["id"])
assert user_info.name == user_data["name"]
assert user_info.name == (user_data["name"] or "")
assert user_info.email == expected_email
@pytest.mark.parametrize(
("user_data", "email_data"),
[
# User with no emails
({"id": 12345, "login": "testuser", "name": "Test User"}, []),
# User with only secondary email
(
{"id": 12345, "login": "testuser", "name": "Test User"},
[{"email": "secondary@example.com", "primary": False}],
),
# User with private email and no primary in emails endpoint
(
{"id": 12345, "login": "testuser", "name": None, "email": None},
[],
),
],
)
@patch("httpx.get", autospec=True)
def test_should_raise_error_when_no_primary_email(self, mock_get, oauth, user_data, email_data):
user_response = MagicMock()
user_response.json.return_value = user_data
email_response = MagicMock()
email_response.json.return_value = email_data
mock_get.side_effect = [user_response, email_response]
with pytest.raises(ValueError, match="Keep my email addresses private"):
oauth.get_user_info("test_token")
@patch("httpx.get", autospec=True)
def test_should_raise_error_when_email_endpoint_fails(self, mock_get, oauth):
user_response = MagicMock()
user_response.json.return_value = {"id": 12345, "login": "testuser", "name": "Test User"}
email_response = MagicMock()
email_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"Forbidden", request=MagicMock(), response=MagicMock()
)
mock_get.side_effect = [user_response, email_response]
with pytest.raises(ValueError, match="Keep my email addresses private"):
oauth.get_user_info("test_token")
@patch("httpx.get", autospec=True)
def test_should_handle_network_errors(self, mock_get, oauth):
mock_get.side_effect = httpx.RequestError("Network error")

View File

@ -16,6 +16,7 @@ from uuid import uuid4
import pytest
from models.enums import ConversationFromSource
from models.model import (
App,
AppAnnotationHitHistory,
@ -324,7 +325,7 @@ class TestConversationModel:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id=from_end_user_id,
)
@ -345,7 +346,7 @@ class TestConversationModel:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id=str(uuid4()),
)
conversation._inputs = inputs
@ -364,7 +365,7 @@ class TestConversationModel:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id=str(uuid4()),
)
inputs = {"query": "Hello", "context": "test"}
@ -383,7 +384,7 @@ class TestConversationModel:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id=str(uuid4()),
summary="Test summary",
)
@ -402,7 +403,7 @@ class TestConversationModel:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id=str(uuid4()),
summary=None,
)
@ -425,7 +426,7 @@ class TestConversationModel:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id=str(uuid4()),
override_model_configs='{"model": "gpt-4"}',
)
@ -446,7 +447,7 @@ class TestConversationModel:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id=from_end_user_id,
dialogue_count=5,
)
@ -487,7 +488,7 @@ class TestMessageModel:
message_unit_price=Decimal("0.0001"),
answer_unit_price=Decimal("0.0002"),
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
)
# Assert
@ -511,7 +512,7 @@ class TestMessageModel:
message_unit_price=Decimal("0.0001"),
answer_unit_price=Decimal("0.0002"),
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
)
message._inputs = inputs
@ -533,7 +534,7 @@ class TestMessageModel:
message_unit_price=Decimal("0.0001"),
answer_unit_price=Decimal("0.0002"),
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
)
inputs = {"query": "Hello", "context": "test"}
@ -555,7 +556,7 @@ class TestMessageModel:
message_unit_price=Decimal("0.0001"),
answer_unit_price=Decimal("0.0002"),
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
override_model_configs='{"model": "gpt-4"}',
)
@ -578,7 +579,7 @@ class TestMessageModel:
message_unit_price=Decimal("0.0001"),
answer_unit_price=Decimal("0.0002"),
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
message_metadata=json.dumps(metadata),
)
@ -600,7 +601,7 @@ class TestMessageModel:
message_unit_price=Decimal("0.0001"),
answer_unit_price=Decimal("0.0002"),
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
message_metadata=None,
)
@ -627,7 +628,7 @@ class TestMessageModel:
answer_unit_price=Decimal("0.0002"),
total_price=Decimal("0.0003"),
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
status="normal",
)
message.id = str(uuid4())
@ -988,7 +989,7 @@ class TestModelIntegration:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
from_end_user_id=str(uuid4()),
)
conversation.id = conversation_id
@ -1003,7 +1004,7 @@ class TestModelIntegration:
message_unit_price=Decimal("0.0001"),
answer_unit_price=Decimal("0.0002"),
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
)
message.id = message_id
@ -1064,7 +1065,7 @@ class TestModelIntegration:
message_unit_price=Decimal("0.0001"),
answer_unit_price=Decimal("0.0002"),
currency="USD",
from_source="api",
from_source=ConversationFromSource.API,
)
message.id = message_id
@ -1158,7 +1159,7 @@ class TestConversationStatusCount:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
)
conversation.id = str(uuid4())
@ -1183,7 +1184,7 @@ class TestConversationStatusCount:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
)
conversation.id = conversation_id
@ -1215,7 +1216,7 @@ class TestConversationStatusCount:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
)
conversation.id = conversation_id
@ -1307,7 +1308,7 @@ class TestConversationStatusCount:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
)
conversation.id = conversation_id
@ -1361,7 +1362,7 @@ class TestConversationStatusCount:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
)
conversation.id = conversation_id
@ -1418,7 +1419,7 @@ class TestConversationStatusCount:
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
from_source=ConversationFromSource.API,
)
conversation.id = conversation_id

View File

@ -12,7 +12,7 @@ This test suite covers:
import json
from uuid import uuid4
from core.tools.entities.tool_entities import ApiProviderSchemaType
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolProviderType
from models.tools import (
ApiToolProvider,
BuiltinToolProvider,
@ -631,7 +631,7 @@ class TestToolLabelBinding:
"""Test creating a tool label binding."""
# Arrange
tool_id = "google.search"
tool_type = "builtin"
tool_type = ToolProviderType.BUILT_IN
label_name = "search"
# Act
@ -655,7 +655,7 @@ class TestToolLabelBinding:
# Act
label_binding = ToolLabelBinding(
tool_id=tool_id,
tool_type="builtin",
tool_type=ToolProviderType.BUILT_IN,
label_name=label_name,
)
@ -667,7 +667,7 @@ class TestToolLabelBinding:
"""Test multiple labels can be bound to the same tool."""
# Arrange
tool_id = "google.search"
tool_type = "builtin"
tool_type = ToolProviderType.BUILT_IN
# Act
binding1 = ToolLabelBinding(
@ -688,7 +688,7 @@ class TestToolLabelBinding:
def test_tool_label_binding_different_tool_types(self):
"""Test label bindings for different tool types."""
# Arrange
tool_types = ["builtin", "api", "workflow"]
tool_types = [ToolProviderType.BUILT_IN, ToolProviderType.API, ToolProviderType.WORKFLOW]
# Act & Assert
for tool_type in tool_types:
@ -951,12 +951,12 @@ class TestToolProviderRelationships:
# Act
binding1 = ToolLabelBinding(
tool_id=tool_id,
tool_type="builtin",
tool_type=ToolProviderType.BUILT_IN,
label_name="search",
)
binding2 = ToolLabelBinding(
tool_id=tool_id,
tool_type="builtin",
tool_type=ToolProviderType.BUILT_IN,
label_name="web",
)

View File

@ -4,12 +4,18 @@ from unittest import mock
from uuid import uuid4
from constants import HIDDEN_VALUE
from core.helper import encrypter
from dify_graph.file.enums import FileTransferMethod, FileType
from dify_graph.file.models import File
from dify_graph.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
from dify_graph.variables.segments import IntegerSegment, Segment
from factories.variable_factory import build_segment
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
from models.workflow import (
Workflow,
WorkflowDraftVariable,
WorkflowNodeExecutionModel,
is_system_variable_editable,
)
def test_environment_variables():
@ -144,6 +150,36 @@ def test_to_dict():
assert workflow_dict["environment_variables"][1]["value"] == "text"
def test_normalize_environment_variable_mappings_converts_full_mask_to_hidden_value():
normalized = Workflow.normalize_environment_variable_mappings(
[
{
"id": str(uuid4()),
"name": "secret",
"value": encrypter.full_mask_token(),
"value_type": "secret",
}
]
)
assert normalized[0]["value"] == HIDDEN_VALUE
def test_normalize_environment_variable_mappings_keeps_hidden_value():
normalized = Workflow.normalize_environment_variable_mappings(
[
{
"id": str(uuid4()),
"name": "secret",
"value": HIDDEN_VALUE,
"value_type": "secret",
}
]
)
assert normalized[0]["value"] == HIDDEN_VALUE
class TestWorkflowNodeExecution:
def test_execution_metadata_dict(self):
node_exec = WorkflowNodeExecutionModel()

View File

@ -1,135 +0,0 @@
"""Unit tests for non-SQL helper logic in workflow run repository."""
import secrets
from datetime import UTC, datetime
from unittest.mock import Mock, patch
import pytest
from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType
from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction
from dify_graph.nodes.human_input.enums import FormInputType, HumanInputFormStatus
from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType
from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowPauseReason
from repositories.sqlalchemy_api_workflow_run_repository import (
_build_human_input_required_reason,
_PrivateWorkflowPauseEntity,
)
@pytest.fixture
def sample_workflow_pause() -> Mock:
"""Create a sample WorkflowPause model."""
pause = Mock(spec=WorkflowPauseModel)
pause.id = "pause-123"
pause.workflow_id = "workflow-123"
pause.workflow_run_id = "workflow-run-123"
pause.state_object_key = "workflow-state-123.json"
pause.resumed_at = None
pause.created_at = datetime.now(UTC)
return pause
class TestPrivateWorkflowPauseEntity:
"""Test _PrivateWorkflowPauseEntity class."""
def test_properties(self, sample_workflow_pause: Mock) -> None:
"""Test entity properties."""
# Arrange
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
# Assert
assert entity.id == sample_workflow_pause.id
assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id
assert entity.resumed_at == sample_workflow_pause.resumed_at
def test_get_state(self, sample_workflow_pause: Mock) -> None:
"""Test getting state from storage."""
# Arrange
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
expected_state = b'{"test": "state"}'
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
mock_storage.load.return_value = expected_state
# Act
result = entity.get_state()
# Assert
assert result == expected_state
mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key)
def test_get_state_caching(self, sample_workflow_pause: Mock) -> None:
"""Test state caching in get_state method."""
# Arrange
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
expected_state = b'{"test": "state"}'
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
mock_storage.load.return_value = expected_state
# Act
result1 = entity.get_state()
result2 = entity.get_state()
# Assert
assert result1 == expected_state
assert result2 == expected_state
mock_storage.load.assert_called_once()
class TestBuildHumanInputRequiredReason:
"""Test helper that builds HumanInputRequired pause reasons."""
def test_prefers_backstage_token_when_available(self) -> None:
"""Use backstage token when multiple recipient types may exist."""
# Arrange
expiration_time = datetime.now(UTC)
form_definition = FormDefinition(
form_content="content",
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
user_actions=[UserAction(id="approve", title="Approve")],
rendered_content="rendered",
expiration_time=expiration_time,
default_values={"name": "Alice"},
node_title="Ask Name",
display_in_ui=True,
)
form_model = HumanInputForm(
id="form-1",
tenant_id="tenant-1",
app_id="app-1",
workflow_run_id="run-1",
node_id="node-1",
form_definition=form_definition.model_dump_json(),
rendered_content="rendered",
status=HumanInputFormStatus.WAITING,
expiration_time=expiration_time,
)
reason_model = WorkflowPauseReason(
pause_id="pause-1",
type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
form_id="form-1",
node_id="node-1",
message="",
)
access_token = secrets.token_urlsafe(8)
backstage_recipient = HumanInputFormRecipient(
form_id="form-1",
delivery_id="delivery-1",
recipient_type=RecipientType.BACKSTAGE,
recipient_payload=BackstageRecipientPayload().model_dump_json(),
access_token=access_token,
)
# Act
reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient])
# Assert
assert isinstance(reason, HumanInputRequired)
assert reason.form_token == access_token
assert reason.node_title == "Ask Name"
assert reason.form_content == "content"
assert reason.inputs[0].output_variable_name == "name"
assert reason.actions[0].id == "approve"

View File

@ -1,251 +0,0 @@
"""Unit tests for workflow run repository with status filter."""
import uuid
from unittest.mock import MagicMock
import pytest
from sqlalchemy.orm import sessionmaker
from models import WorkflowRun, WorkflowRunTriggeredFrom
from repositories.sqlalchemy_api_workflow_run_repository import DifyAPISQLAlchemyWorkflowRunRepository
class TestDifyAPISQLAlchemyWorkflowRunRepository:
"""Test workflow run repository with status filtering."""
@pytest.fixture
def mock_session_maker(self):
"""Create a mock session maker."""
return MagicMock(spec=sessionmaker)
@pytest.fixture
def repository(self, mock_session_maker):
"""Create repository instance with mock session."""
return DifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker)
def test_get_paginated_workflow_runs_without_status(self, repository, mock_session_maker):
"""Test getting paginated workflow runs without status filter."""
# Arrange
tenant_id = str(uuid.uuid4())
app_id = str(uuid.uuid4())
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_runs = [MagicMock(spec=WorkflowRun) for _ in range(3)]
mock_session.scalars.return_value.all.return_value = mock_runs
# Act
result = repository.get_paginated_workflow_runs(
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
limit=20,
last_id=None,
status=None,
)
# Assert
assert len(result.data) == 3
assert result.limit == 20
assert result.has_more is False
def test_get_paginated_workflow_runs_with_status_filter(self, repository, mock_session_maker):
"""Test getting paginated workflow runs with status filter."""
# Arrange
tenant_id = str(uuid.uuid4())
app_id = str(uuid.uuid4())
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_runs = [MagicMock(spec=WorkflowRun, status="succeeded") for _ in range(2)]
mock_session.scalars.return_value.all.return_value = mock_runs
# Act
result = repository.get_paginated_workflow_runs(
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
limit=20,
last_id=None,
status="succeeded",
)
# Assert
assert len(result.data) == 2
assert all(run.status == "succeeded" for run in result.data)
def test_get_workflow_runs_count_without_status(self, repository, mock_session_maker):
"""Test getting workflow runs count without status filter."""
# Arrange
tenant_id = str(uuid.uuid4())
app_id = str(uuid.uuid4())
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
# Mock the GROUP BY query results
mock_results = [
("succeeded", 5),
("failed", 2),
("running", 1),
]
mock_session.execute.return_value.all.return_value = mock_results
# Act
result = repository.get_workflow_runs_count(
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
status=None,
)
# Assert
assert result["total"] == 8
assert result["succeeded"] == 5
assert result["failed"] == 2
assert result["running"] == 1
assert result["stopped"] == 0
assert result["partial-succeeded"] == 0
def test_get_workflow_runs_count_with_status_filter(self, repository, mock_session_maker):
"""Test getting workflow runs count with status filter."""
# Arrange
tenant_id = str(uuid.uuid4())
app_id = str(uuid.uuid4())
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
# Mock the count query for succeeded status
mock_session.scalar.return_value = 5
# Act
result = repository.get_workflow_runs_count(
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
status="succeeded",
)
# Assert
assert result["total"] == 5
assert result["succeeded"] == 5
assert result["running"] == 0
assert result["failed"] == 0
assert result["stopped"] == 0
assert result["partial-succeeded"] == 0
def test_get_workflow_runs_count_with_invalid_status(self, repository, mock_session_maker):
"""Test that invalid status is still counted in total but not in any specific status."""
# Arrange
tenant_id = str(uuid.uuid4())
app_id = str(uuid.uuid4())
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
# Mock count query returning 0 for invalid status
mock_session.scalar.return_value = 0
# Act
result = repository.get_workflow_runs_count(
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
status="invalid_status",
)
# Assert
assert result["total"] == 0
assert all(result[status] == 0 for status in ["running", "succeeded", "failed", "stopped", "partial-succeeded"])
def test_get_workflow_runs_count_with_time_range(self, repository, mock_session_maker):
"""Test getting workflow runs count with time range filter verifies SQL query construction."""
# Arrange
tenant_id = str(uuid.uuid4())
app_id = str(uuid.uuid4())
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
# Mock the GROUP BY query results
mock_results = [
("succeeded", 3),
("running", 2),
]
mock_session.execute.return_value.all.return_value = mock_results
# Act
result = repository.get_workflow_runs_count(
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
status=None,
time_range="1d",
)
# Assert results
assert result["total"] == 5
assert result["succeeded"] == 3
assert result["running"] == 2
assert result["failed"] == 0
# Verify that execute was called (which means GROUP BY query was used)
assert mock_session.execute.called, "execute should have been called for GROUP BY query"
# Verify SQL query includes time filter by checking the statement
call_args = mock_session.execute.call_args
assert call_args is not None, "execute should have been called with a statement"
# The first argument should be the SQL statement
stmt = call_args[0][0]
# Convert to string to inspect the query
query_str = str(stmt.compile(compile_kwargs={"literal_binds": True}))
# Verify the query includes created_at filter
# The query should have a WHERE clause with created_at comparison
assert "created_at" in query_str.lower() or "workflow_runs.created_at" in query_str.lower(), (
"Query should include created_at filter for time range"
)
def test_get_workflow_runs_count_with_status_and_time_range(self, repository, mock_session_maker):
"""Test getting workflow runs count with both status and time range filters verifies SQL query."""
# Arrange
tenant_id = str(uuid.uuid4())
app_id = str(uuid.uuid4())
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
# Mock the count query for running status within time range
mock_session.scalar.return_value = 2
# Act
result = repository.get_workflow_runs_count(
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
status="running",
time_range="1d",
)
# Assert results
assert result["total"] == 2
assert result["running"] == 2
assert result["succeeded"] == 0
assert result["failed"] == 0
# Verify that scalar was called (which means COUNT query was used)
assert mock_session.scalar.called, "scalar should have been called for count query"
# Verify SQL query includes both status and time filter
call_args = mock_session.scalar.call_args
assert call_args is not None, "scalar should have been called with a statement"
# The first argument should be the SQL statement
stmt = call_args[0][0]
# Convert to string to inspect the query
query_str = str(stmt.compile(compile_kwargs={"literal_binds": True}))
# Verify the query includes both filters
assert "created_at" in query_str.lower() or "workflow_runs.created_at" in query_str.lower(), (
"Query should include created_at filter for time range"
)
assert "status" in query_str.lower() or "workflow_runs.status" in query_str.lower(), (
"Query should include status filter"
)

Some files were not shown because too many files have changed in this diff Show More