mirror of https://github.com/langgenius/dify.git
Merge branch 'fix/memeber-settings' into deploy/dev
# Conflicts:
#	web/app/components/app/type-selector/index.tsx
#	web/app/components/header/account-setting/data-source-page/data-source-notion/operate/index.tsx
#	web/app/components/header/account-setting/data-source-page/data-source-website/config-firecrawl-modal.tsx
#	web/app/components/header/account-setting/data-source-page/data-source-website/config-jina-reader-modal.tsx
#	web/app/components/header/account-setting/data-source-page/data-source-website/config-watercrawl-modal.tsx
#	web/app/components/header/account-setting/data-source-page/panel/config-item.tsx
#	web/app/components/header/account-setting/data-source-page/panel/index.tsx
#	web/app/components/tools/labels/filter.tsx
#	web/app/components/workflow/nodes/knowledge-retrieval/node.tsx
Commit: b3bcaed9b9
@@ -94,11 +94,6 @@ jobs:
          find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
          find . -name "*.py.bak" -type f -delete

      # mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
      - name: mdformat
        run: |
          uvx --python 3.13 mdformat . --exclude ".agents/skills/**"

      - name: Setup web environment
        if: steps.web-changes.outputs.any_changed == 'true'
        uses: ./.github/actions/setup-web
@@ -120,7 +120,7 @@ jobs:

      - name: Run Claude Code for Translation Sync
        if: steps.detect_changes.outputs.CHANGED_FILES != ''
        uses: anthropics/claude-code-action@df37d2f0760a4b5683a6e617c9325bc1a36443f6 # v1.0.75
        uses: anthropics/claude-code-action@6062f3709600659be5e47fcddf2cf76993c235c2 # v1.0.76
        with:
          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
          github_token: ${{ secrets.GITHUB_TOKEN }}
@@ -97,3 +97,8 @@ Feel free to reach out if you encounter any issues during the setup process.
## Getting Help

If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.

## Automated Agent Contributions

> [!NOTE]
> If you are an automated agent, just add 🤖🤖🤖 to the end of the PR title to opt-in.
@@ -1,7 +1,7 @@
import flask_restx
from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus
from sqlalchemy import select
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
@@ -33,16 +33,10 @@ api_key_list_model = console_ns.model(


def _get_resource(resource_id, tenant_id, resource_model):
    if resource_model == App:
        with Session(db.engine) as session:
            resource = session.execute(
                select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
            ).scalar_one_or_none()
    else:
        with Session(db.engine) as session:
            resource = session.execute(
                select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
            ).scalar_one_or_none()
    with Session(db.engine) as session:
        resource = session.execute(
            select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
        ).scalar_one_or_none()

    if resource is None:
        flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.")
@@ -80,10 +74,13 @@ class BaseApiKeyListResource(Resource):
        resource_id = str(resource_id)
        _, current_tenant_id = current_account_with_tenant()
        _get_resource(resource_id, current_tenant_id, self.resource_model)
        current_key_count = (
            db.session.query(ApiToken)
            .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
            .count()
        current_key_count: int = (
            db.session.scalar(
                select(func.count(ApiToken.id)).where(
                    ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
                )
            )
            or 0
        )

        if current_key_count >= self.max_keys:
@@ -119,14 +116,14 @@ class BaseApiKeyResource(Resource):
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        key = (
            db.session.query(ApiToken)
        key = db.session.scalar(
            select(ApiToken)
            .where(
                getattr(ApiToken, self.resource_id_field) == resource_id,
                ApiToken.type == self.resource_type,
                ApiToken.id == api_key_id,
            )
            .first()
            .limit(1)
        )

        if key is None:
@@ -137,7 +134,7 @@ class BaseApiKeyResource(Resource):
        assert key is not None  # nosec - for type checker only
        ApiTokenCache.delete(key.token, key.type)

        db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
        db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
        db.session.commit()

        return {"result": "success"}, 204
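A note on the pattern in this file, since it recurs throughout the commit: legacy `db.session.query(...)` chains are replaced with SQLAlchemy 2.0-style `select()` / `delete()` statements executed through `Session.scalar()` and `Session.execute()`. A minimal, self-contained sketch of the before/after shape (the `Token` model is a hypothetical stand-in, not from this diff):

```python
from sqlalchemy import create_engine, delete, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Token(Base):  # hypothetical stand-in for ApiToken
    __tablename__ = "tokens"
    id: Mapped[int] = mapped_column(primary_key=True)
    type: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # Legacy Query API: session.query(Token).where(...).first()
    # 2.0 style: build a select() and fetch one scalar result (or None).
    token = session.scalar(select(Token).where(Token.type == "app").limit(1))

    # Bulk delete expressed as a statement, mirroring the delete() change above.
    session.execute(delete(Token).where(Token.type == "app"))
    session.commit()
```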
@@ -5,7 +5,7 @@ from flask import abort, request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from werkzeug.exceptions import NotFound

from controllers.console import console_ns
@@ -376,8 +376,12 @@ class CompletionConversationApi(Resource):

        # FIXME, the type ignore in this file
        if args.annotation_status == "annotated":
            query = query.options(joinedload(Conversation.message_annotations)).join(  # type: ignore
                MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
            query = (
                query.options(selectinload(Conversation.message_annotations))  # type: ignore[arg-type]
                .join(  # type: ignore
                    MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
                )
                .distinct()
            )
        elif args.annotation_status == "not_annotated":
            query = (
@@ -511,8 +515,12 @@ class ChatConversationApi(Resource):

        match args.annotation_status:
            case "annotated":
                query = query.options(joinedload(Conversation.message_annotations)).join(  # type: ignore
                    MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
                query = (
                    query.options(selectinload(Conversation.message_annotations))  # type: ignore[arg-type]
                    .join(  # type: ignore
                        MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
                    )
                    .distinct()
                )
            case "not_annotated":
                query = (
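Context for the `joinedload` → `selectinload` swap above: `joinedload` folds the collection into the main query with a LEFT JOIN, duplicating parent rows (hence the `.distinct()`), while `selectinload` issues a second `SELECT ... WHERE parent_id IN (...)` and leaves the main statement untouched. A runnable sketch with hypothetical `Parent`/`Child` models standing in for `Conversation`/`MessageAnnotation`:

```python
from sqlalchemy import ForeignKey, create_engine, select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    mapped_column,
    relationship,
    selectinload,
)


class Base(DeclarativeBase):
    pass


class Parent(Base):  # stands in for Conversation
    __tablename__ = "parents"
    id: Mapped[int] = mapped_column(primary_key=True)
    children: Mapped[list["Child"]] = relationship(back_populates="parent")


class Child(Base):  # stands in for MessageAnnotation
    __tablename__ = "children"
    id: Mapped[int] = mapped_column(primary_key=True)
    parent_id: Mapped[int] = mapped_column(ForeignKey("parents.id"))
    parent: Mapped[Parent] = relationship(back_populates="children")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # One query for parents, one IN-list query for all their children;
    # no parent-row duplication, so the eager load itself needs no .distinct().
    parents = session.scalars(select(Parent).options(selectinload(Parent.children))).all()
```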
@@ -4,7 +4,7 @@ from typing import Literal
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, select
from sqlalchemy import exists, func, select
from werkzeug.exceptions import InternalServerError, NotFound

from controllers.common.schema import register_schema_models
@@ -244,27 +244,25 @@ class ChatMessageListApi(Resource):
    def get(self, app_model):
        args = ChatMessagesQuery.model_validate(request.args.to_dict())

        conversation = (
            db.session.query(Conversation)
        conversation = db.session.scalar(
            select(Conversation)
            .where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
            .first()
            .limit(1)
        )

        if not conversation:
            raise NotFound("Conversation Not Exists.")

        if args.first_id:
            first_message = (
                db.session.query(Message)
                .where(Message.conversation_id == conversation.id, Message.id == args.first_id)
                .first()
            first_message = db.session.scalar(
                select(Message).where(Message.conversation_id == conversation.id, Message.id == args.first_id).limit(1)
            )

            if not first_message:
                raise NotFound("First message not found")

            history_messages = (
                db.session.query(Message)
            history_messages = db.session.scalars(
                select(Message)
                .where(
                    Message.conversation_id == conversation.id,
                    Message.created_at < first_message.created_at,
@@ -272,16 +270,14 @@ class ChatMessageListApi(Resource):
                )
                .order_by(Message.created_at.desc())
                .limit(args.limit)
                .all()
            )
            ).all()
        else:
            history_messages = (
                db.session.query(Message)
            history_messages = db.session.scalars(
                select(Message)
                .where(Message.conversation_id == conversation.id)
                .order_by(Message.created_at.desc())
                .limit(args.limit)
                .all()
            )
            ).all()

        # Initialize has_more based on whether we have a full page
        if len(history_messages) == args.limit:
@@ -326,7 +322,9 @@ class MessageFeedbackApi(Resource):

        message_id = str(args.message_id)

        message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
        message = db.session.scalar(
            select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
        )

        if not message:
            raise NotFound("Message Not Exists.")
@@ -375,7 +373,9 @@ class MessageAnnotationCountApi(Resource):
    @login_required
    @account_initialization_required
    def get(self, app_model):
        count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()
        count = db.session.scalar(
            select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id)
        )

        return {"count": count}
@@ -479,7 +479,9 @@ class MessageApi(Resource):
    def get(self, app_model, message_id: str):
        message_id = str(message_id)

        message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
        message = db.session.scalar(
            select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
        )

        if not message:
            raise NotFound("Message Not Exists.")
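Two more variants of the same migration appear above: list queries move to `session.scalars(stmt).all()`, and `Query.count()` (which wraps the whole query in a subquery) becomes an explicit `select(func.count(...))` scalar. A sketch with a hypothetical `Msg` model:

```python
from datetime import datetime

from sqlalchemy import create_engine, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Msg(Base):  # hypothetical stand-in for Message
    __tablename__ = "msgs"
    id: Mapped[int] = mapped_column(primary_key=True)
    created_at: Mapped[datetime] = mapped_column(default=datetime.now)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # List results: scalars(...).all() returns ORM objects, newest first.
    page = session.scalars(select(Msg).order_by(Msg.created_at.desc()).limit(20)).all()

    # Counts: COUNT(...) computed in the database and returned as a plain int,
    # without the wrapping subquery the legacy Query.count() generated.
    total = session.scalar(select(func.count(Msg.id))) or 0
```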
@@ -7,7 +7,7 @@ from flask import abort, request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound

import services
from controllers.console import console_ns
@@ -46,13 +46,14 @@ from models import App
from models.model import AppMode
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService

logger = logging.getLogger(__name__)
LISTENING_RETRY_IN = 2000
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"

# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@@ -284,7 +285,9 @@ class DraftWorkflowApi(Resource):
        workflow_service = WorkflowService()

        try:
            environment_variables_list = args.get("environment_variables") or []
            environment_variables_list = Workflow.normalize_environment_variable_mappings(
                args.get("environment_variables") or [],
            )
            environment_variables = [
                variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
            ]
@@ -994,6 +997,43 @@ class PublishedAllWorkflowApi(Resource):
        }


@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>/restore")
class DraftWorkflowRestoreApi(Resource):
    @console_ns.doc("restore_workflow_to_draft")
    @console_ns.doc(description="Restore a published workflow version into the draft workflow")
    @console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Published workflow ID"})
    @console_ns.response(200, "Workflow restored successfully")
    @console_ns.response(400, "Source workflow must be published")
    @console_ns.response(404, "Workflow not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    @edit_permission_required
    def post(self, app_model: App, workflow_id: str):
        current_user, _ = current_account_with_tenant()
        workflow_service = WorkflowService()

        try:
            workflow = workflow_service.restore_published_workflow_to_draft(
                app_model=app_model,
                workflow_id=workflow_id,
                account=current_user,
            )
        except IsDraftWorkflowError as exc:
            raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc
        except WorkflowNotFoundError as exc:
            raise NotFound(str(exc)) from exc
        except ValueError as exc:
            raise BadRequest(str(exc)) from exc

        return {
            "result": "success",
            "hash": workflow.unique_hash,
            "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
        }


@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>")
class WorkflowByIdApi(Resource):
    @console_ns.doc("update_workflow_by_id")
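For orientation, the new restore endpoint takes no request body and responds with the refreshed draft's hash and timestamp. A hypothetical client call (base URL, path prefix, and auth header are assumptions for illustration, not taken from this diff):

```python
import httpx

# Hypothetical IDs, host, and auth scheme for illustration only.
resp = httpx.post(
    "https://dify.example.com/console/api/apps/<app_id>/workflows/<workflow_id>/restore",
    headers={"Authorization": "Bearer <console-session-token>"},
)
resp.raise_for_status()
data = resp.json()  # {"result": "success", "hash": "...", "updated_at": ...}
```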
@@ -1,4 +1,5 @@
import logging
import urllib.parse

import httpx
from flask import current_app, redirect, request
@@ -112,6 +113,9 @@ class OAuthCallback(Resource):
            error_text = e.response.text
            logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
            return {"error": "OAuth process failed"}, 400
        except ValueError as e:
            logger.warning("OAuth error with %s", provider, exc_info=True)
            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={urllib.parse.quote(str(e))}")

        if invite_token and RegisterService.is_valid_invite_token(invite_token):
            invitation = RegisterService.get_invitation_by_token(token=invite_token)
@@ -6,7 +6,7 @@ from flask import abort, request
from flask_restx import Resource, marshal_with  # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound

import services
from controllers.common.schema import register_schema_models
@@ -16,7 +16,11 @@ from controllers.console.app.error import (
    DraftWorkflowNotExist,
    DraftWorkflowNotSync,
)
from controllers.console.app.workflow import workflow_model, workflow_pagination_model
from controllers.console.app.workflow import (
    RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE,
    workflow_model,
    workflow_pagination_model,
)
from controllers.console.app.workflow_run import (
    workflow_run_detail_model,
    workflow_run_node_execution_list_model,
@@ -42,7 +46,8 @@ from libs.login import current_account_with_tenant, current_user, login_required
from models import Account
from models.dataset import Pipeline
from models.model import EndUser
from services.errors.app import WorkflowHashNotEqualError
from models.workflow import Workflow
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
from services.rag_pipeline.rag_pipeline import RagPipelineService
@@ -203,9 +208,12 @@ class DraftRagPipelineApi(Resource):
            abort(415)

        payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
        rag_pipeline_service = RagPipelineService()

        try:
            environment_variables_list = payload.environment_variables or []
            environment_variables_list = Workflow.normalize_environment_variable_mappings(
                payload.environment_variables or [],
            )
            environment_variables = [
                variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
            ]
@@ -213,7 +221,6 @@ class DraftRagPipelineApi(Resource):
            conversation_variables = [
                variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
            ]
            rag_pipeline_service = RagPipelineService()
            workflow = rag_pipeline_service.sync_draft_workflow(
                pipeline=pipeline,
                graph=payload.graph,
@@ -705,6 +712,36 @@ class PublishedAllRagPipelineApi(Resource):
        }


@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>/restore")
class RagPipelineDraftWorkflowRestoreApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @edit_permission_required
    @get_rag_pipeline
    def post(self, pipeline: Pipeline, workflow_id: str):
        current_user, _ = current_account_with_tenant()
        rag_pipeline_service = RagPipelineService()

        try:
            workflow = rag_pipeline_service.restore_published_workflow_to_draft(
                pipeline=pipeline,
                workflow_id=workflow_id,
                account=current_user,
            )
        except IsDraftWorkflowError as exc:
            # Use a stable, predefined message to keep the 400 response consistent
            raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc
        except WorkflowNotFoundError as exc:
            raise NotFound(str(exc)) from exc

        return {
            "result": "success",
            "hash": workflow.unique_hash,
            "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
        }


@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
class RagPipelineByIdApi(Resource):
    @setup_required
@@ -1,5 +1,6 @@
from flask import request
from flask_restx import Resource
from sqlalchemy import select

from controllers.console import api
from controllers.console.explore.wraps import explore_banner_enabled
@@ -17,14 +18,18 @@ class BannerApi(Resource):
        language = request.args.get("language", "en-US")

        # Build base query for enabled banners
        base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED)
        base_query = select(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED)

        # Try to get banners in the requested language
        banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
        banners = db.session.scalars(
            base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort)
        ).all()

        # Fallback to en-US if no banners found and language is not en-US
        if not banners and language != "en-US":
            banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
            banners = db.session.scalars(
                base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort)
            ).all()
        # Convert banners to serializable format
        result = []
        for banner in banners:
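One practical payoff of this change: a `select()` statement is a detached expression object, so `base_query` can be built once and refined per branch, whereas the legacy `Query` was bound to the session and its filter chain. A sketch with a hypothetical `Banner` model standing in for `ExporleBanner`:

```python
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Banner(Base):  # hypothetical stand-in for ExporleBanner
    __tablename__ = "banners"
    id: Mapped[int] = mapped_column(primary_key=True)
    language: Mapped[str]
    sort: Mapped[int]
    enabled: Mapped[bool]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # Build the shared statement once; each .where() returns a new statement.
    base = select(Banner).where(Banner.enabled.is_(True))
    banners = session.scalars(base.where(Banner.language == "fr-FR").order_by(Banner.sort)).all()
    if not banners:
        # The fallback reuses the same base statement, unchanged.
        banners = session.scalars(base.where(Banner.language == "en-US").order_by(Banner.sort)).all()
```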
@@ -133,13 +133,15 @@ class InstalledAppsListApi(Resource):
    def post(self):
        payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})

        recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first()
        recommended_app = db.session.scalar(
            select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).limit(1)
        )
        if recommended_app is None:
            raise NotFound("Recommended app not found")

        _, current_tenant_id = current_account_with_tenant()

        app = db.session.query(App).where(App.id == payload.app_id).first()
        app = db.session.get(App, payload.app_id)

        if app is None:
            raise NotFound("App entity not found")
@@ -147,10 +149,10 @@ class InstalledAppsListApi(Resource):
        if not app.is_public:
            raise Forbidden("You can't install a non-public app")

        installed_app = (
            db.session.query(InstalledApp)
        installed_app = db.session.scalar(
            select(InstalledApp)
            .where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id))
            .first()
            .limit(1)
        )

        if installed_app is None:
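Where the lookup is by primary key, the commit goes a step further than `select(...).limit(1)` and uses `Session.get()`, which can be satisfied from the identity map without emitting SQL at all. A minimal sketch (hypothetical model):

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class AppRow(Base):  # hypothetical stand-in for App
    __tablename__ = "apps"
    id: Mapped[str] = mapped_column(primary_key=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(AppRow(id="a1"))
    session.commit()
    # Primary-key lookup: no select()/limit(1) needed; a second call within
    # the same session is served straight from the identity map.
    app = session.get(AppRow, "a1")
    assert app is not None
```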
@@ -4,6 +4,7 @@ from typing import Any, Literal, cast
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

import services
@@ -476,7 +477,7 @@ class TrialSitApi(Resource):

        Returns the site configuration for the application including theme, icons, and text.
        """
        site = db.session.query(Site).where(Site.app_id == app_model.id).first()
        site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))

        if not site:
            raise Forbidden()
@@ -541,13 +542,7 @@ class AppWorkflowApi(Resource):
        if not app_model.workflow_id:
            raise AppUnavailableError()

        workflow = (
            db.session.query(Workflow)
            .where(
                Workflow.id == app_model.workflow_id,
            )
            .first()
        )
        workflow = db.session.get(Workflow, app_model.workflow_id)
        return workflow
@@ -4,6 +4,7 @@ from typing import Concatenate, ParamSpec, TypeVar

from flask import abort
from flask_restx import Resource
from sqlalchemy import select
from werkzeug.exceptions import NotFound

from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
@@ -24,10 +25,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | None = None):
    @wraps(view)
    def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
        _, current_tenant_id = current_account_with_tenant()
        installed_app = (
            db.session.query(InstalledApp)
        installed_app = db.session.scalar(
            select(InstalledApp)
            .where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id)
            .first()
            .limit(1)
        )

        if installed_app is None:
@@ -78,7 +79,7 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
    def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
        current_user, _ = current_account_with_tenant()

        trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
        trial_app = db.session.scalar(select(TrialApp).where(TrialApp.app_id == str(app_id)).limit(1))

        if trial_app is None:
            raise TrialAppNotAllowed()
@@ -87,10 +88,10 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
        if app is None:
            raise TrialAppNotAllowed()

        account_trial_app_record = (
            db.session.query(AccountTrialAppRecord)
        account_trial_app_record = db.session.scalar(
            select(AccountTrialAppRecord)
            .where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
            .first()
            .limit(1)
        )
        if account_trial_app_record:
            if account_trial_app_record.count >= trial_app.trial_limit:
@@ -2,6 +2,7 @@ from typing import Literal

from flask import request
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select

from configs import dify_config
from controllers.fastopenapi import console_router
@@ -100,6 +101,6 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse:

def get_setup_status() -> DifySetup | bool | None:
    if dify_config.EDITION == "SELF_HOSTED":
        return db.session.query(DifySetup).first()
        return db.session.scalar(select(DifySetup).limit(1))

    return True
@@ -212,13 +212,13 @@ class AccountInitApi(Resource):
            raise ValueError("invitation_code is required")

        # check invitation code
        invitation_code = (
            db.session.query(InvitationCode)
        invitation_code = db.session.scalar(
            select(InvitationCode)
            .where(
                InvitationCode.code == args.invitation_code,
                InvitationCode.status == InvitationCodeStatus.UNUSED,
            )
            .first()
            .limit(1)
        )

        if not invitation_code:
@@ -171,7 +171,7 @@ class MemberCancelInviteApi(Resource):
        current_user, _ = current_account_with_tenant()
        if not current_user.current_tenant:
            raise ValueError("No current tenant")
        member = db.session.query(Account).where(Account.id == str(member_id)).first()
        member = db.session.get(Account, str(member_id))
        if member is None:
            abort(404)
        else:
@@ -7,6 +7,7 @@ from sqlalchemy import select
from werkzeug.exceptions import Unauthorized

import services
from configs import dify_config
from controllers.common.errors import (
    FilenameNotExistsError,
    FileTooLargeError,
@@ -29,6 +30,7 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
from services.billing_service import BillingService, SubscriptionPlan
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.file_service import FileService
@@ -108,9 +110,29 @@ class TenantListApi(Resource):
        current_user, current_tenant_id = current_account_with_tenant()
        tenants = TenantService.get_join_tenants(current_user)
        tenant_dicts = []
        is_enterprise_only = dify_config.ENTERPRISE_ENABLED and not dify_config.BILLING_ENABLED
        is_saas = dify_config.EDITION == "CLOUD" and dify_config.BILLING_ENABLED
        tenant_plans: dict[str, SubscriptionPlan] = {}

        if is_saas:
            tenant_ids = [tenant.id for tenant in tenants]
            if tenant_ids:
                tenant_plans = BillingService.get_plan_bulk(tenant_ids)
                if not tenant_plans:
                    logger.warning("get_plan_bulk returned empty result, falling back to legacy feature path")

        for tenant in tenants:
            features = FeatureService.get_features(tenant.id)
            plan: str = CloudPlan.SANDBOX
            if is_saas:
                tenant_plan = tenant_plans.get(tenant.id)
                if tenant_plan:
                    plan = tenant_plan["plan"] or CloudPlan.SANDBOX
                else:
                    features = FeatureService.get_features(tenant.id)
                    plan = features.billing.subscription.plan or CloudPlan.SANDBOX
            elif not is_enterprise_only:
                features = FeatureService.get_features(tenant.id)
                plan = features.billing.subscription.plan or CloudPlan.SANDBOX

            # Create a dictionary with tenant attributes
            tenant_dict = {
@@ -118,7 +140,7 @@ class TenantListApi(Resource):
                "name": tenant.name,
                "status": tenant.status,
                "created_at": tenant.created_at,
                "plan": features.billing.subscription.plan if features.billing.enabled else CloudPlan.SANDBOX,
                "plan": plan,
                "current": tenant.id == current_tenant_id if current_tenant_id else False,
            }
@@ -198,7 +220,7 @@ class SwitchWorkspaceApi(Resource):
        except Exception:
            raise AccountNotLinkTenantError("Account not link tenant")

        new_tenant = db.session.query(Tenant).get(args.tenant_id)  # Get new tenant
        new_tenant = db.session.get(Tenant, args.tenant_id)  # Get new tenant
        if new_tenant is None:
            raise ValueError("Tenant not found")
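The workspace-list change above is essentially an N+1 fix: rather than calling `FeatureService.get_features` once per tenant just to read its plan, SaaS deployments fetch all plans in one `BillingService.get_plan_bulk` call and fall back per tenant only when the bulk result is missing. A generic sketch of that shape (all names here are illustrative):

```python
from collections.abc import Callable, Iterable


def plans_for(
    tenant_ids: Iterable[str],
    fetch_bulk: Callable[[list[str]], dict[str, str]],  # e.g. a get_plan_bulk-style call
    fetch_one: Callable[[str], str | None],             # per-tenant fallback path
    default: str = "sandbox",
) -> dict[str, str]:
    """Resolve one plan per tenant with a single bulk call plus a per-item fallback."""
    ids = list(tenant_ids)
    bulk = fetch_bulk(ids) if ids else {}
    return {tenant_id: (bulk.get(tenant_id) or fetch_one(tenant_id) or default) for tenant_id in ids}
```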
@@ -7,6 +7,7 @@ from functools import wraps
from typing import ParamSpec, TypeVar

from flask import abort, request
from sqlalchemy import select

from configs import dify_config
from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError
@@ -218,13 +219,9 @@ def setup_required(view: Callable[P, R]) -> Callable[P, R]:
    @wraps(view)
    def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
        # check setup
        if (
            dify_config.EDITION == "SELF_HOSTED"
            and os.environ.get("INIT_PASSWORD")
            and not db.session.query(DifySetup).first()
        ):
            raise NotInitValidateError()
        elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first():
        if dify_config.EDITION == "SELF_HOSTED" and not db.session.scalar(select(DifySetup).limit(1)):
            if os.environ.get("INIT_PASSWORD"):
                raise NotInitValidateError()
            raise NotSetupError()

        return view(*args, **kwargs)
@@ -33,7 +33,7 @@ from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from libs.datetime_utils import naive_utc_now
from models import Account
from models.enums import CreatorUserRole, MessageFileBelongsTo
from models.enums import ConversationFromSource, CreatorUserRole, MessageFileBelongsTo
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationNotExistsError
@@ -130,10 +130,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
        end_user_id = None
        account_id = None
        if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
            from_source = "api"
            from_source = ConversationFromSource.API
            end_user_id = application_generate_entity.user_id
        else:
            from_source = "console"
            from_source = ConversationFromSource.CONSOLE
            account_id = application_generate_entity.user_id

        if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):
@@ -6,7 +6,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from models.enums import CollectionBindingType
from models.enums import CollectionBindingType, ConversationFromSource
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.dataset_service import DatasetCollectionBindingService
@@ -68,9 +68,9 @@ class AnnotationReplyFeature:
        annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
        if annotation:
            if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}:
                from_source = "api"
                from_source = ConversationFromSource.API
            else:
                from_source = "console"
                from_source = ConversationFromSource.CONSOLE

            # insert annotation history
            AppAnnotationService.add_annotation_history(
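`ConversationFromSource` replaces the bare `"api"` / `"console"` literals at every write site. Assuming it is a `str`-valued enum (which the `EnumText` column changes later in this diff suggest), the stored values stay wire-compatible while typos become attribute errors. Sketch (the actual members live in `models/enums.py`):

```python
from enum import StrEnum  # Python 3.11+


class ConversationFromSource(StrEnum):  # illustrative mirror of the real enum
    API = "api"
    CONSOLE = "console"


# Members compare equal to the plain strings already stored in the database.
assert ConversationFromSource.API == "api"
```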
@@ -181,10 +181,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
        arize_phoenix_config: ArizeConfig | PhoenixConfig,
    ):
        super().__init__(arize_phoenix_config)
        import logging

        logging.basicConfig()
        logging.getLogger().setLevel(logging.DEBUG)
        self.arize_phoenix_config = arize_phoenix_config
        self.tracer, self.processor = setup_tracer(arize_phoenix_config)
        self.project = arize_phoenix_config.project
@@ -284,27 +284,29 @@ class TidbOnQdrantVector(BaseVector):
        from qdrant_client.http import models
        from qdrant_client.http.exceptions import UnexpectedResponse

        for node_id in ids:
            try:
                filter = models.Filter(
                    must=[
                        models.FieldCondition(
                            key="metadata.doc_id",
                            match=models.MatchValue(value=node_id),
                        ),
                    ],
                )
                self._client.delete(
                    collection_name=self._collection_name,
                    points_selector=FilterSelector(filter=filter),
                )
            except UnexpectedResponse as e:
                # Collection does not exist, so return
                if e.status_code == 404:
                    return
                # Some other error occurred, so re-raise the exception
                else:
                    raise e
        if not ids:
            return

        try:
            filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key="metadata.doc_id",
                        match=models.MatchAny(any=ids),
                    ),
                ],
            )
            self._client.delete(
                collection_name=self._collection_name,
                points_selector=FilterSelector(filter=filter),
            )
        except UnexpectedResponse as e:
            # Collection does not exist, so return
            if e.status_code == 404:
                return
            # Some other error occurred, so re-raise the exception
            else:
                raise e

    def text_exists(self, id: str) -> bool:
        all_collection_name = []
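The Qdrant change collapses N delete round-trips (one `MatchValue` filter per id) into a single call with a `MatchAny` filter, plus an early return for empty input. The batched call against a throwaway in-memory client (collection name and payload key are illustrative):

```python
from qdrant_client import QdrantClient
from qdrant_client.http import models

client = QdrantClient(":memory:")  # in-process instance for illustration
client.create_collection(
    collection_name="docs",
    vectors_config=models.VectorParams(size=4, distance=models.Distance.COSINE),
)

doc_ids = ["doc-1", "doc-2", "doc-3"]
# One round-trip removes every point whose metadata.doc_id is in the list.
client.delete(
    collection_name="docs",
    points_selector=models.FilterSelector(
        filter=models.Filter(
            must=[
                models.FieldCondition(
                    key="metadata.doc_id",
                    match=models.MatchAny(any=doc_ids),
                )
            ]
        )
    ),
)
```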
@@ -50,7 +50,7 @@ class BuiltinTool(Tool):
        return ModelInvocationUtils.invoke(
            user_id=user_id,
            tenant_id=self.runtime.tenant_id or "",
            tool_type="builtin",
            tool_type=ToolProviderType.BUILT_IN,
            tool_name=self.entity.identity.name,
            prompt_messages=prompt_messages,
        )
@@ -38,7 +38,7 @@ class ToolLabelManager:
            db.session.add(
                ToolLabelBinding(
                    tool_id=provider_id,
                    tool_type=controller.provider_type.value,
                    tool_type=controller.provider_type,
                    label_name=label,
                )
            )
@@ -58,7 +58,7 @@ class ToolLabelManager:
            raise ValueError("Unsupported tool type")
        stmt = select(ToolLabelBinding.label_name).where(
            ToolLabelBinding.tool_id == provider_id,
            ToolLabelBinding.tool_type == controller.provider_type.value,
            ToolLabelBinding.tool_type == controller.provider_type,
        )
        labels = db.session.scalars(stmt).all()
@@ -9,6 +9,7 @@ from decimal import Decimal
from typing import cast

from core.model_manager import ModelManager
from core.tools.entities.tool_entities import ToolProviderType
from dify_graph.model_runtime.entities.llm_entities import LLMResult
from dify_graph.model_runtime.entities.message_entities import PromptMessage
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
@@ -78,7 +79,7 @@ class ModelInvocationUtils:

    @staticmethod
    def invoke(
        user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
        user_id: str, tenant_id: str, tool_type: ToolProviderType, tool_name: str, prompt_messages: list[PromptMessage]
    ) -> LLMResult:
        """
        invoke model with parameters in user's own context
@@ -101,6 +101,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
            timeout=self._get_request_timeout(self.node_data),
            variable_pool=self.graph_runtime_state.variable_pool,
            http_request_config=self._http_request_config,
            # Must be 0 to disable executor-level retries, as the graph engine handles them.
            # This is critical to prevent nested retries.
            max_retries=0,
            ssl_verify=self.node_data.ssl_verify,
            http_client=self._http_client,
            file_manager=self._file_manager,
@@ -1,16 +1,19 @@
import logging
import sys
import urllib.parse
from dataclasses import dataclass
from typing import NotRequired

import httpx
from pydantic import TypeAdapter
from pydantic import TypeAdapter, ValidationError

if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict

logger = logging.getLogger(__name__)

JsonObject = dict[str, object]
JsonObjectList = list[JsonObject]
@@ -30,8 +33,8 @@ class GitHubEmailRecord(TypedDict, total=False):
class GitHubRawUserInfo(TypedDict):
    id: int | str
    login: str
    name: NotRequired[str]
    email: NotRequired[str]
    name: NotRequired[str | None]
    email: NotRequired[str | None]


class GoogleRawUserInfo(TypedDict):
@@ -127,9 +130,14 @@ class GitHubOAuth(OAuth):
        response.raise_for_status()
        user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response))

        email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
        email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
        primary_email = next((email for email in email_info if email.get("primary") is True), None)
        try:
            email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
            email_response.raise_for_status()
            email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
            primary_email = next((email for email in email_info if email.get("primary") is True), None)
        except (httpx.HTTPStatusError, ValidationError):
            logger.warning("Failed to retrieve email from GitHub /user/emails endpoint", exc_info=True)
            primary_email = None

        return {**user_info, "email": primary_email.get("email", "") if primary_email else ""}
@@ -137,8 +145,11 @@ class GitHubOAuth(OAuth):
        payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
        email = payload.get("email")
        if not email:
            raise ValueError(
                'Dify currently not supports the "Keep my email addresses private" feature,'
                " please disable it and login again"
            )
            email = f"{payload['id']}+{payload['login']}@users.noreply.github.com"
        return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email)
        return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name") or ""), email=email)


class GoogleOAuth(OAuth):
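The `NotRequired[str | None]` widening matters because GitHub returns `"name": null` for accounts without a display name; with plain `NotRequired[str]`, `TypeAdapter` validation rejects such payloads instead of merely tolerating the missing key. A check you can run with pydantic:

```python
from pydantic import TypeAdapter
from typing_extensions import NotRequired, TypedDict


class RawUser(TypedDict):  # illustrative mirror of the GitHubRawUserInfo shape above
    id: int | str
    login: str
    name: NotRequired[str | None]


adapter = TypeAdapter(RawUser)
user = adapter.validate_python({"id": 1, "login": "octocat", "name": None})
# `or ""` turns an explicit null into an empty display name, as the diff does.
display_name = str(user.get("name") or "")
assert display_name == ""
```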
@@ -34,13 +34,16 @@ from .enums import (
    AppMCPServerStatus,
    AppStatus,
    BannerStatus,
    ConversationFromSource,
    ConversationStatus,
    CreatorUserRole,
    FeedbackFromSource,
    FeedbackRating,
    InvokeFrom,
    MessageChainType,
    MessageFileBelongsTo,
    MessageStatus,
    TagType,
)
from .provider_ids import GenericProviderID
from .types import EnumText, LongText, StringUUID
@@ -1022,10 +1025,12 @@ class Conversation(Base):
    #
    # Its value corresponds to the members of `InvokeFrom`.
    # (api/core/app/entities/app_invoke_entities.py)
    invoke_from = mapped_column(String(255), nullable=True)
    invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True)

    # ref: ConversationSource.
    from_source: Mapped[str] = mapped_column(String(255), nullable=False)
    from_source: Mapped[ConversationFromSource] = mapped_column(
        EnumText(ConversationFromSource, length=255), nullable=False
    )
    from_end_user_id = mapped_column(StringUUID)
    from_account_id = mapped_column(StringUUID)
    read_at = mapped_column(sa.DateTime)
@@ -1374,8 +1379,10 @@ class Message(Base):
    )
    error: Mapped[str | None] = mapped_column(LongText)
    message_metadata: Mapped[str | None] = mapped_column(LongText)
    invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
    from_source: Mapped[str] = mapped_column(String(255), nullable=False)
    invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True)
    from_source: Mapped[ConversationFromSource] = mapped_column(
        EnumText(ConversationFromSource, length=255), nullable=False
    )
    from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
    from_account_id: Mapped[str | None] = mapped_column(StringUUID)
    created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())
@@ -2398,7 +2405,7 @@ class Tag(TypeBase):
        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
    )
    tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    type: Mapped[str] = mapped_column(String(16), nullable=False)
    type: Mapped[TagType] = mapped_column(EnumText(TagType, length=16), nullable=False)
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(
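`EnumText` comes from `models/types` and its implementation is not shown in this diff; the contract implied by these columns is "store the enum's value as VARCHAR, hand the ORM a typed member". A hypothetical `TypeDecorator` with that contract might look like:

```python
import enum

from sqlalchemy import String
from sqlalchemy.types import TypeDecorator


class EnumTextSketch(TypeDecorator):
    """Illustrative only: persist enum.value as VARCHAR, return the enum member on load."""

    impl = String
    cache_ok = True

    def __init__(self, enum_cls: type[enum.Enum], length: int = 255):
        super().__init__(length)  # forwarded to the underlying String type
        self._enum_cls = enum_cls

    def process_bind_param(self, value, dialect):
        if value is None:
            return None
        # Accept either an enum member or its raw value on the way in.
        return self._enum_cls(value).value

    def process_result_value(self, value, dialect):
        return None if value is None else self._enum_cls(value)
```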
@@ -13,12 +13,16 @@ from sqlalchemy.orm import Mapped, mapped_column

from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
from core.tools.entities.tool_entities import (
    ApiProviderSchemaType,
    ToolProviderType,
    WorkflowToolParameterConfiguration,
)

from .base import TypeBase
from .engine import db
from .model import Account, App, Tenant
from .types import LongText, StringUUID
from .types import EnumText, LongText, StringUUID

if TYPE_CHECKING:
    from core.entities.mcp_provider import MCPProviderEntity
@@ -208,7 +212,7 @@ class ToolLabelBinding(TypeBase):
    # tool id
    tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
    # tool type
    tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
    tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
    # label name
    label_name: Mapped[str] = mapped_column(String(40), nullable=False)
@@ -386,7 +390,7 @@ class ToolModelInvoke(TypeBase):
    # provider
    provider: Mapped[str] = mapped_column(String(255), nullable=False)
    # type
    tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
    tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
    # tool name
    tool_name: Mapped[str] = mapped_column(String(128), nullable=False)
    # invoke parameters
@@ -1,3 +1,4 @@
import copy
import json
import logging
from collections.abc import Generator, Mapping, Sequence
@@ -302,26 +303,40 @@ class Workflow(Base):  # bug
    def features(self) -> str:
        """
        Convert old features structure to new features structure.

        This property avoids rewriting the underlying JSON when normalization
        produces no effective change, to prevent marking the row dirty on read.
        """
        if not self._features:
            return self._features

        features = json.loads(self._features)
        if features.get("file_upload", {}).get("image", {}).get("enabled", False):
            image_enabled = True
            image_number_limits = int(features["file_upload"]["image"].get("number_limits", DEFAULT_FILE_NUMBER_LIMITS))
            image_transfer_methods = features["file_upload"]["image"].get(
                "transfer_methods", ["remote_url", "local_file"]
            )
            features["file_upload"]["enabled"] = image_enabled
            features["file_upload"]["number_limits"] = image_number_limits
            features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods
            features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"])
            features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get(
                "allowed_file_extensions", []
            )
            del features["file_upload"]["image"]
            self._features = json.dumps(features)
        # Parse once and deep-copy before normalization to detect in-place changes.
        original_dict = self._decode_features_payload(self._features)
        if original_dict is None:
            return self._features

        # Fast-path: if the legacy file_upload.image.enabled shape is absent, skip
        # deep-copy and normalization entirely and return the stored JSON.
        file_upload_payload = original_dict.get("file_upload")
        if not isinstance(file_upload_payload, dict):
            return self._features
        file_upload = cast(dict[str, Any], file_upload_payload)

        image_payload = file_upload.get("image")
        if not isinstance(image_payload, dict):
            return self._features
        image = cast(dict[str, Any], image_payload)
        if "enabled" not in image:
            return self._features

        normalized_dict = self._normalize_features_payload(copy.deepcopy(original_dict))

        if normalized_dict == original_dict:
            # No effective change; return stored JSON unchanged.
            return self._features

        # Normalization changed the payload: persist the normalized JSON.
        self._features = json.dumps(normalized_dict)
        return self._features

    @features.setter
@@ -332,6 +347,44 @@ class Workflow(Base):  # bug
    def features_dict(self) -> dict[str, Any]:
        return json.loads(self.features) if self.features else {}

    @property
    def serialized_features(self) -> str:
        """Return the stored features JSON without triggering compatibility rewrites."""
        return self._features

    @property
    def normalized_features_dict(self) -> dict[str, Any]:
        """Decode features with legacy normalization without mutating the model state."""
        if not self._features:
            return {}

        features = self._decode_features_payload(self._features)
        return self._normalize_features_payload(features) if features is not None else {}

    @staticmethod
    def _decode_features_payload(features: str) -> dict[str, Any] | None:
        """Decode workflow features JSON when it contains an object payload."""
        payload = json.loads(features)
        return cast(dict[str, Any], payload) if isinstance(payload, dict) else None

    @staticmethod
    def _normalize_features_payload(features: dict[str, Any]) -> dict[str, Any]:
        if features.get("file_upload", {}).get("image", {}).get("enabled", False):
            image_number_limits = int(features["file_upload"]["image"].get("number_limits", DEFAULT_FILE_NUMBER_LIMITS))
            image_transfer_methods = features["file_upload"]["image"].get(
                "transfer_methods", ["remote_url", "local_file"]
            )
            features["file_upload"]["enabled"] = True
            features["file_upload"]["number_limits"] = image_number_limits
            features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods
            features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"])
            features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get(
                "allowed_file_extensions", []
            )
            del features["file_upload"]["image"]

        return features

    def walk_nodes(
        self, specific_node_type: NodeType | None = None
    ) -> Generator[tuple[str, Mapping[str, Any]], None, None]:
@@ -517,6 +570,31 @@ class Workflow(Base):  # bug
        )
        self._environment_variables = environment_variables_json

    @staticmethod
    def normalize_environment_variable_mappings(
        mappings: Sequence[Mapping[str, Any]],
    ) -> list[dict[str, Any]]:
        """Convert masked secret placeholders into the draft hidden sentinel.

        Regular draft sync requests should preserve existing secrets without shipping
        plaintext values back from the client. The dedicated restore endpoint now
        copies published secrets server-side, so draft sync only needs to normalize
        the UI mask into `HIDDEN_VALUE`.
        """
        masked_secret_value = encrypter.full_mask_token()
        normalized_mappings: list[dict[str, Any]] = []

        for mapping in mappings:
            normalized_mapping = dict(mapping)
            if (
                normalized_mapping.get("value_type") == SegmentType.SECRET.value
                and normalized_mapping.get("value") == masked_secret_value
            ):
                normalized_mapping["value"] = HIDDEN_VALUE
            normalized_mappings.append(normalized_mapping)

        return normalized_mappings

    def to_dict(self, *, include_secret: bool = False) -> WorkflowContentDict:
        environment_variables = list(self.environment_variables)
        environment_variables = [
@@ -564,6 +642,12 @@ class Workflow(Base):  # bug
            ensure_ascii=False,
        )

    def copy_serialized_variable_storage_from(self, source_workflow: "Workflow") -> None:
        """Copy stored variable JSON directly for same-tenant restore flows."""
        self._environment_variables = source_workflow._environment_variables
        self._conversation_variables = source_workflow._conversation_variables
        self._rag_pipeline_variables = source_workflow._rag_pipeline_variables

    @staticmethod
    def version_from_datetime(d: datetime) -> str:
        return str(d)
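Usage-wise, `normalize_environment_variable_mappings` only rewrites secret-typed variables whose value equals the UI's full mask, so freshly edited secrets still pass through. An illustrative round-trip (the mask and sentinel strings are placeholders, not the real constants):

```python
MASK = "******"            # placeholder for encrypter.full_mask_token()
HIDDEN_SENTINEL = "[HIDDEN]"  # placeholder for the real HIDDEN_VALUE constant


def normalize(mappings: list[dict]) -> list[dict]:
    """Mirror of the normalization rule: masked secrets become the sentinel."""
    out = []
    for m in mappings:
        m = dict(m)
        if m.get("value_type") == "secret" and m.get("value") == MASK:
            m["value"] = HIDDEN_SENTINEL  # tells the server to keep the stored secret
        out.append(m)
    return out


normalized = normalize([
    {"name": "API_KEY", "value_type": "secret", "value": MASK},       # untouched -> sentinel
    {"name": "API_KEY_2", "value_type": "secret", "value": "sk-new"}, # edited secret passes through
    {"name": "REGION", "value_type": "string", "value": "eu-west-1"}, # non-secret untouched
])
```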
@@ -8,7 +8,7 @@ dependencies = [
    "arize-phoenix-otel~=0.15.0",
    "azure-identity==1.25.3",
    "beautifulsoup4==4.14.3",
    "boto3==1.42.68",
    "boto3==1.42.73",
    "bs4~=0.0.1",
    "cachetools~=5.3.0",
    "celery~=5.6.2",
@@ -23,7 +23,7 @@ dependencies = [
    "gevent~=25.9.1",
    "gmpy2~=2.3.0",
    "google-api-core>=2.19.1",
    "google-api-python-client==2.192.0",
    "google-api-python-client==2.193.0",
    "google-auth>=2.47.0",
    "google-auth-httplib2==0.3.0",
    "google-cloud-aiplatform>=1.123.0",
@@ -40,7 +40,7 @@ dependencies = [
    "numpy~=1.26.4",
    "openpyxl~=3.1.5",
    "opik~=1.10.37",
    "litellm==1.82.2", # Pinned to avoid madoka dependency issue
    "litellm==1.82.6", # Pinned to avoid madoka dependency issue
    "opentelemetry-api==1.28.0",
    "opentelemetry-distro==0.49b0",
    "opentelemetry-exporter-otlp==1.28.0",
@@ -72,13 +72,14 @@ dependencies = [
    "pyyaml~=6.0.1",
    "readabilipy~=0.3.0",
    "redis[hiredis]~=7.3.0",
    "resend~=2.23.0",
    "sentry-sdk[flask]~=2.54.0",
    "resend~=2.26.0",
    "sentry-sdk[flask]~=2.55.0",
    "sqlalchemy~=2.0.29",
    "starlette==0.52.1",
    "starlette==1.0.0",
    "tiktoken~=0.12.0",
    "transformers~=5.3.0",
    "unstructured[docx,epub,md,ppt,pptx]~=0.21.5",
    "pypandoc~=1.13",
    "yarl~=1.23.0",
    "webvtt-py~=0.5.1",
    "sseclient-py~=1.9.0",
@@ -91,7 +92,7 @@ dependencies = [
    "apscheduler>=3.11.0",
    "weave>=0.52.16",
    "fastopenapi[flask]>=0.7.0",
    "bleach~=6.2.0",
    "bleach~=6.3.0",
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.
@@ -118,7 +119,7 @@ dev = [
    "ruff~=0.15.5",
    "pytest~=9.0.2",
    "pytest-benchmark~=5.2.3",
    "pytest-cov~=7.0.0",
    "pytest-cov~=7.1.0",
    "pytest-env~=1.6.0",
    "pytest-mock~=3.15.1",
    "testcontainers~=4.14.1",
@@ -385,7 +385,11 @@ class BillingService:
                # Redis returns bytes, decode to string and parse JSON
                json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value
                plan_dict = json.loads(json_str)
                # NOTE (hj24): New billing versions may return timestamp as str, and validate_python
                # in non-strict mode will coerce it to the expected int type.
                # To preserve compatibility, always keep non-strict mode here and avoid strict mode.
                subscription_plan = subscription_adapter.validate_python(plan_dict)
                # NOTE END
                tenant_plans[tenant_id] = subscription_plan
            except Exception:
                logger.exception(
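The NOTE in the hunk above relies on pydantic's lax (non-strict) mode: `validate_python` coerces a numeric string into an `int` field, which keeps older and newer billing payloads interchangeable. A quick demonstration:

```python
from pydantic import TypeAdapter
from typing_extensions import TypedDict


class PlanSketch(TypedDict):  # illustrative subset of a subscription payload
    plan: str
    expires_at: int


adapter = TypeAdapter(PlanSketch)
# Lax mode coerces "1735689600" -> 1735689600; strict=True would raise instead.
plan = adapter.validate_python({"plan": "team", "expires_at": "1735689600"})
assert plan["expires_at"] == 1735689600
```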
@@ -79,10 +79,11 @@ from services.entities.knowledge_entities.rag_pipeline_entities import (
    KnowledgeConfiguration,
    PipelineTemplateInfoEntity,
)
from services.errors.app import WorkflowHashNotEqualError
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader
from services.workflow_restore import apply_published_workflow_snapshot_to_draft

logger = logging.getLogger(__name__)
@@ -234,6 +235,21 @@ class RagPipelineService:

        return workflow

    def get_published_workflow_by_id(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
        """Fetch a published workflow snapshot by ID for restore operations."""
        workflow = (
            db.session.query(Workflow)
            .where(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.id == workflow_id,
            )
            .first()
        )
        if workflow and workflow.version == Workflow.VERSION_DRAFT:
            raise IsDraftWorkflowError("source workflow must be published")
        return workflow

    def get_all_published_workflow(
        self,
        *,
@@ -327,6 +343,42 @@ class RagPipelineService:
        # return draft workflow
        return workflow

    def restore_published_workflow_to_draft(
        self,
        *,
        pipeline: Pipeline,
        workflow_id: str,
        account: Account,
    ) -> Workflow:
        """Restore a published pipeline workflow snapshot into the draft workflow.

        Pipelines reuse the shared draft-restore field copy helper, but still own
        the pipeline-specific flush/link step that wires a newly created draft
        back onto ``pipeline.workflow_id``.
        """
        source_workflow = self.get_published_workflow_by_id(pipeline=pipeline, workflow_id=workflow_id)
        if not source_workflow:
            raise WorkflowNotFoundError("Workflow not found.")

        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        draft_workflow, is_new_draft = apply_published_workflow_snapshot_to_draft(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            source_workflow=source_workflow,
            draft_workflow=draft_workflow,
            account=account,
            updated_at_factory=lambda: datetime.now(UTC).replace(tzinfo=None),
        )

        if is_new_draft:
            db.session.add(draft_workflow)
            db.session.flush()
            pipeline.workflow_id = draft_workflow.id

        db.session.commit()

        return draft_workflow

    def publish_workflow(
        self,
        *,
@@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound

 from extensions.ext_database import db
 from models.dataset import Dataset
+from models.enums import TagType
 from models.model import App, Tag, TagBinding


@@ -83,7 +84,7 @@ class TagService:
             raise ValueError("Tag name already exists")
         tag = Tag(
             name=args["name"],
-            type=args["type"],
+            type=TagType(args["type"]),
             created_by=current_user.id,
             tenant_id=current_user.current_tenant_id,
         )

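Wrapping the raw string in the enum constructor moves validation to the service boundary: an unknown tag type now fails immediately instead of being persisted. A minimal sketch with an illustrative stand-in enum (the real members live in models.enums):

from enum import StrEnum

class TagType(StrEnum):  # stand-in; the real TagType lives in models.enums
    KNOWLEDGE = "knowledge"
    APP = "app"

TagType("knowledge")   # valid input passes through as a validated enum member
try:
    TagType("banana")  # invalid input now raises before any row is created
except ValueError:
    print("rejected at the service boundary")
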
@@ -1,10 +1,10 @@
 import json
 import logging
-from collections.abc import Mapping
 from typing import Any, cast

 from httpx import get
 from sqlalchemy import select
+from typing_extensions import TypedDict

 from core.entities.provider_entities import ProviderConfig
 from core.tools.__base.tool_runtime import ToolRuntime

@@ -28,9 +28,16 @@ from services.tools.tools_transform_service import ToolTransformService
 logger = logging.getLogger(__name__)


+class ApiSchemaParseResult(TypedDict):
+    schema_type: str
+    parameters_schema: list[dict[str, Any]]
+    credentials_schema: list[dict[str, Any]]
+    warning: dict[str, str]
+
+
 class ApiToolManageService:
     @staticmethod
-    def parser_api_schema(schema: str) -> Mapping[str, Any]:
+    def parser_api_schema(schema: str) -> ApiSchemaParseResult:
        """
        parse api schema to tool bundle
        """

@@ -71,7 +78,7 @@ class ApiToolManageService:
         ]

         return cast(
-            Mapping,
+            ApiSchemaParseResult,
             jsonable_encoder(
                 {
                     "schema_type": schema_type,

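Returning ApiSchemaParseResult instead of Mapping[str, Any] lets type checkers verify keys and value types at every call site. A short sketch of the difference, reusing the field names from the hunk above:

from typing import Any
from typing_extensions import TypedDict

class ApiSchemaParseResult(TypedDict):
    schema_type: str
    parameters_schema: list[dict[str, Any]]
    credentials_schema: list[dict[str, Any]]
    warning: dict[str, str]

def summarize(result: ApiSchemaParseResult) -> str:
    # mypy/pyright know "schema_type" exists and is a str; a misspelled
    # key such as result["schema_typo"] would be flagged statically.
    return result["schema_type"]
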
@@ -18,6 +18,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
 from core.mcp.auth.auth_flow import auth
 from core.mcp.auth_client import MCPClientWithAuthRetry
 from core.mcp.error import MCPAuthError, MCPError
+from core.mcp.types import Tool as MCPTool
 from core.tools.entities.api_entities import ToolProviderApiEntity
 from core.tools.utils.encryption import ProviderConfigEncrypter
 from models.tools import MCPToolProvider

@@ -681,7 +682,7 @@ class MCPToolManageService:
             raise ValueError(f"Failed to re-connect MCP server: {e}") from e

     def _build_tool_provider_response(
-        self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
+        self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list[MCPTool]
     ) -> ToolProviderApiEntity:
         """Build API response for tool provider."""
         user = db_provider.load_user()

@@ -703,7 +704,7 @@ class MCPToolManageService:
                 raise ValueError(f"MCP tool {server_url} already exists")
             if "unique_mcp_provider_server_identifier" in error_msg:
                 raise ValueError(f"MCP tool {server_identifier} already exists")
-            raise
+            raise error

     def _is_valid_url(self, url: str) -> bool:
         """Validate URL format."""

@@ -1,5 +1,7 @@
 import json
-from typing import Any, TypedDict
+from typing import Any
+
+from typing_extensions import TypedDict

 from core.app.app_config.entities import (
     DatasetEntity,

@@ -34,6 +36,17 @@ class _NodeType(TypedDict):
     data: dict[str, Any]


+class _EdgeType(TypedDict):
+    id: str
+    source: str
+    target: str
+
+
+class WorkflowGraph(TypedDict):
+    nodes: list[_NodeType]
+    edges: list[_EdgeType]
+
+
 class WorkflowConverter:
     """
     App Convert to Workflow Mode

@@ -107,7 +120,7 @@ class WorkflowConverter:
         app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config)

         # init workflow graph
-        graph: dict[str, Any] = {"nodes": [], "edges": []}
+        graph: WorkflowGraph = {"nodes": [], "edges": []}

         # Convert list:
         # - variables -> start

@@ -385,7 +398,7 @@ class WorkflowConverter:
         self,
         original_app_mode: AppMode,
         new_app_mode: AppMode,
-        graph: dict,
+        graph: WorkflowGraph,
         model_config: ModelConfigEntity,
         prompt_template: PromptTemplateEntity,
         file_upload: FileUploadConfig | None = None,

@@ -595,7 +608,7 @@ class WorkflowConverter:
             "data": {"title": "ANSWER", "type": BuiltinNodeTypes.ANSWER, "answer": "{{#llm.text#}}"},
         }

-    def _create_edge(self, source: str, target: str):
+    def _create_edge(self, source: str, target: str) -> _EdgeType:
        """
        Create Edge
        :param source: source node id

@@ -604,7 +617,7 @@ class WorkflowConverter:
        """
        return {"id": f"{source}-{target}", "source": source, "target": target}

-    def _append_node(self, graph: dict[str, Any], node: _NodeType):
+    def _append_node(self, graph: WorkflowGraph, node: _NodeType):
        """
        Append Node to Graph

@@ -5,6 +5,7 @@ from typing import Any

 from sqlalchemy import and_, func, or_, select
 from sqlalchemy.orm import Session
+from typing_extensions import TypedDict

 from dify_graph.enums import WorkflowExecutionStatus
 from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun

@@ -14,6 +15,10 @@ from services.plugin.plugin_service import PluginService
 from services.workflow.entities import TriggerMetadata


+class LogViewDetails(TypedDict):
+    trigger_metadata: dict[str, Any] | None
+
+
 # Since the workflow_app_log table has exceeded 100 million records, we use an additional details field to extend it
 class LogView:
     """Lightweight wrapper for WorkflowAppLog with computed details.

@@ -22,12 +27,12 @@ class LogView:
     - Proxies all other attributes to the underlying `WorkflowAppLog`
     """

-    def __init__(self, log: WorkflowAppLog, details: dict | None):
+    def __init__(self, log: WorkflowAppLog, details: LogViewDetails | None):
         self.log = log
         self.details_ = details

     @property
-    def details(self) -> dict | None:
+    def details(self) -> LogViewDetails | None:
         return self.details_

     def __getattr__(self, name):

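LogView relies on the fact that __getattr__ is only invoked after normal attribute lookup fails, so `details` resolves locally while everything else falls through to the wrapped row. A self-contained sketch of the pattern, with Row standing in for WorkflowAppLog:

class Row:  # stand-in for the WorkflowAppLog model
    def __init__(self):
        self.created_at = "2024-01-01"

class View:
    def __init__(self, row, details):
        self.row = row
        self.details_ = details

    @property
    def details(self):
        return self.details_

    def __getattr__(self, name):
        # only reached when the attribute is not found on View itself,
        # so details/details_ never hit this path
        return getattr(self.row, name)

v = View(Row(), {"trigger_metadata": None})
assert v.details == {"trigger_metadata": None}  # served locally
assert v.created_at == "2024-01-01"             # proxied to the wrapped row
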
@@ -35,7 +35,7 @@ from factories.variable_factory import build_segment, segment_to_variable
 from libs.datetime_utils import naive_utc_now
 from libs.uuid_utils import uuidv7
 from models import Account, App, Conversation
-from models.enums import DraftVariableType
+from models.enums import ConversationFromSource, DraftVariableType
 from models.workflow import Workflow, WorkflowDraftVariable, WorkflowDraftVariableFile, is_system_variable_editable
 from repositories.factory import DifyAPIRepositoryFactory
 from services.file_service import FileService

@@ -601,7 +601,7 @@ class WorkflowDraftVariableService:
             system_instruction_tokens=0,
             status="normal",
             invoke_from=InvokeFrom.DEBUGGER,
-            from_source="console",
+            from_source=ConversationFromSource.CONSOLE,
             from_end_user_id=None,
             from_account_id=account_id,
         )

@@ -0,0 +1,58 @@
+"""Shared helpers for restoring published workflow snapshots into drafts.
+
+Both app workflows and RAG pipeline workflows restore the same workflow fields
+from a published snapshot into a draft. Keeping that field-copy logic in one
+place prevents the two restore paths from drifting when we add or adjust draft
+state in the future. Restore stays within a tenant, so we can safely reuse the
+serialized workflow storage blobs without decrypting and re-encrypting secrets.
+"""
+
+from collections.abc import Callable
+from datetime import datetime
+
+from models import Account
+from models.workflow import Workflow, WorkflowType
+
+UpdatedAtFactory = Callable[[], datetime]
+
+
+def apply_published_workflow_snapshot_to_draft(
+    *,
+    tenant_id: str,
+    app_id: str,
+    source_workflow: Workflow,
+    draft_workflow: Workflow | None,
+    account: Account,
+    updated_at_factory: UpdatedAtFactory,
+) -> tuple[Workflow, bool]:
+    """Copy a published workflow snapshot into a draft workflow record.
+
+    The caller remains responsible for source lookup, validation, flushing, and
+    post-commit side effects. This helper only centralizes the shared draft
+    creation/update semantics used by both restore entry points. Features are
+    copied from the stored JSON payload so restore does not normalize and dirty
+    the published source row before the caller commits.
+    """
+    if not draft_workflow:
+        workflow_type = (
+            source_workflow.type.value if isinstance(source_workflow.type, WorkflowType) else source_workflow.type
+        )
+        draft_workflow = Workflow(
+            tenant_id=tenant_id,
+            app_id=app_id,
+            type=workflow_type,
+            version=Workflow.VERSION_DRAFT,
+            graph=source_workflow.graph,
+            features=source_workflow.serialized_features,
+            created_by=account.id,
+        )
+        draft_workflow.copy_serialized_variable_storage_from(source_workflow)
+        return draft_workflow, True
+
+    draft_workflow.graph = source_workflow.graph
+    draft_workflow.features = source_workflow.serialized_features
+    draft_workflow.updated_by = account.id
+    draft_workflow.updated_at = updated_at_factory()
+    draft_workflow.copy_serialized_variable_storage_from(source_workflow)
+
+    return draft_workflow, False

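A hedged sketch of how the two restore entry points in this commit drive the helper; `app`, `published_workflow`, `maybe_draft`, and `current_account` are placeholders for values the caller already holds:

from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from services.workflow_restore import apply_published_workflow_snapshot_to_draft

draft, is_new = apply_published_workflow_snapshot_to_draft(
    tenant_id=app.tenant_id,             # placeholder App row
    app_id=app.id,
    source_workflow=published_workflow,  # validated, non-draft snapshot
    draft_workflow=maybe_draft,          # None when no draft exists yet
    account=current_account,
    updated_at_factory=naive_utc_now,
)
if is_new:
    db.session.add(draft)  # only a brand-new draft needs to be registered
db.session.commit()        # the caller owns flush/commit and side effects
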
@@ -63,7 +63,12 @@ from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeEx
 from repositories.factory import DifyAPIRepositoryFactory
 from services.billing_service import BillingService
 from services.enterprise.plugin_manager_service import PluginCredentialType
-from services.errors.app import IsDraftWorkflowError, TriggerNodeLimitExceededError, WorkflowHashNotEqualError
+from services.errors.app import (
+    IsDraftWorkflowError,
+    TriggerNodeLimitExceededError,
+    WorkflowHashNotEqualError,
+    WorkflowNotFoundError,
+)
 from services.workflow.workflow_converter import WorkflowConverter

 from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError

@@ -75,6 +80,7 @@ from .human_input_delivery_test_service import (
     HumanInputDeliveryTestService,
 )
 from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService
+from .workflow_restore import apply_published_workflow_snapshot_to_draft


 class WorkflowService:

@@ -279,6 +285,43 @@ class WorkflowService:
         # return draft workflow
         return workflow

+    def restore_published_workflow_to_draft(
+        self,
+        *,
+        app_model: App,
+        workflow_id: str,
+        account: Account,
+    ) -> Workflow:
+        """Restore a published workflow snapshot into the draft workflow.
+
+        Secret environment variables are copied server-side from the selected
+        published workflow so the normal draft sync flow stays stateless.
+        """
+        source_workflow = self.get_published_workflow_by_id(app_model=app_model, workflow_id=workflow_id)
+        if not source_workflow:
+            raise WorkflowNotFoundError("Workflow not found.")
+
+        self.validate_features_structure(app_model=app_model, features=source_workflow.normalized_features_dict)
+        self.validate_graph_structure(graph=source_workflow.graph_dict)
+
+        draft_workflow = self.get_draft_workflow(app_model=app_model)
+        draft_workflow, is_new_draft = apply_published_workflow_snapshot_to_draft(
+            tenant_id=app_model.tenant_id,
+            app_id=app_model.id,
+            source_workflow=source_workflow,
+            draft_workflow=draft_workflow,
+            account=account,
+            updated_at_factory=naive_utc_now,
+        )
+
+        if is_new_draft:
+            db.session.add(draft_workflow)
+
+        db.session.commit()
+        app_draft_workflow_was_synced.send(app_model, synced_draft_workflow=draft_workflow)
+
+        return draft_workflow
+
     def publish_workflow(
         self,
         *,

@@ -13,6 +13,7 @@ from controllers.console.app import wraps
 from libs.datetime_utils import naive_utc_now
 from models import App, Tenant
 from models.account import Account, TenantAccountJoin, TenantAccountRole
+from models.enums import ConversationFromSource
 from models.model import AppMode
 from services.app_generate_service import AppGenerateService


@@ -154,7 +155,7 @@ class TestChatMessageApiPermissions:
             re_sign_file_url_answer="",
             answer_tokens=0,
             provider_response_latency=0.0,
-            from_source="console",
+            from_source=ConversationFromSource.CONSOLE,
             from_end_user_id=None,
             from_account_id=mock_account.id,
             feedbacks=[],

@@ -165,8 +165,9 @@ class DifyTestContainers:

         # Start Dify Sandbox container for code execution environment
         # Dify Sandbox provides a secure environment for executing user code
+        # Use pinned version 0.2.12 to match production docker-compose configuration
         logger.info("Initializing Dify Sandbox container...")
-        self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:latest").with_network(self.network)
+        self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:0.2.12").with_network(self.network)
         self.dify_sandbox.with_exposed_ports(8194)
         self.dify_sandbox.env = {
             "API_KEY": "test_api_key",

@@ -13,7 +13,7 @@ from libs.datetime_utils import naive_utc_now
 from libs.token import _real_cookie_name, generate_csrf_token
 from models import Account, DifySetup, Tenant, TenantAccountJoin
 from models.account import AccountStatus, TenantAccountRole
-from models.enums import CreatorUserRole
+from models.enums import ConversationFromSource, CreatorUserRole
 from models.model import App, AppMode, Conversation, Message
 from models.workflow import WorkflowRun
 from services.account_service import AccountService

@@ -75,7 +75,7 @@ def _create_conversation(db_session: Session, app_id: str, account_id: str) -> C
         inputs={},
         status="normal",
         mode=AppMode.CHAT,
-        from_source=CreatorUserRole.ACCOUNT,
+        from_source=ConversationFromSource.CONSOLE,
         from_account_id=account_id,
     )
     db_session.add(conversation)

@@ -124,7 +124,7 @@ def _create_message(
         answer_price_unit=0.001,
         currency="USD",
         status="normal",
-        from_source=CreatorUserRole.ACCOUNT,
+        from_source=ConversationFromSource.CONSOLE,
         from_account_id=account_id,
         workflow_run_id=workflow_run_id,
         inputs={"query": "Hello"},

@@ -7,6 +7,7 @@ from uuid import uuid4

 from dify_graph.nodes.human_input.entities import FormDefinition, UserAction
 from models.account import Account, Tenant, TenantAccountJoin
+from models.enums import ConversationFromSource, InvokeFrom
 from models.execution_extra_content import HumanInputContent
 from models.human_input import HumanInputForm, HumanInputFormStatus
 from models.model import App, Conversation, Message

@@ -78,8 +79,8 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture:
         introduction="",
         system_instruction="",
         status="normal",
-        invoke_from="console",
-        from_source="console",
+        invoke_from=InvokeFrom.EXPLORE,
+        from_source=ConversationFromSource.CONSOLE,
         from_account_id=account.id,
         from_end_user_id=None,
     )

@@ -101,7 +102,7 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture:
         answer_unit_price=Decimal("0.001"),
         provider_response_latency=0.5,
         currency="USD",
-        from_source="console",
+        from_source=ConversationFromSource.CONSOLE,
         from_account_id=account.id,
         workflow_run_id=workflow_run_id,
     )

@@ -2,6 +2,7 @@

 from __future__ import annotations

+import secrets
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta
 from unittest.mock import Mock

@@ -12,15 +13,26 @@ from sqlalchemy import Engine, delete, select
 from sqlalchemy.orm import Session, sessionmaker

 from dify_graph.entities import WorkflowExecution
-from dify_graph.entities.pause_reason import PauseReasonType
+from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType
 from dify_graph.enums import WorkflowExecutionStatus
+from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction
+from dify_graph.nodes.human_input.enums import DeliveryMethodType, FormInputType, HumanInputFormStatus
 from extensions.ext_storage import storage
 from libs.datetime_utils import naive_utc_now
 from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
+from models.human_input import (
+    BackstageRecipientPayload,
+    HumanInputDelivery,
+    HumanInputForm,
+    HumanInputFormRecipient,
+    RecipientType,
+)
 from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
 from repositories.entities.workflow_pause import WorkflowPauseEntity
 from repositories.sqlalchemy_api_workflow_run_repository import (
     DifyAPISQLAlchemyWorkflowRunRepository,
+    _build_human_input_required_reason,
     _PrivateWorkflowPauseEntity,
     _WorkflowRunError,
 )

@@ -90,6 +102,19 @@ def _cleanup_scope_data(session: Session, scope: _TestScope) -> None:
             WorkflowRun.app_id == scope.app_id,
         )
     )
+
+    form_ids_subquery = select(HumanInputForm.id).where(
+        HumanInputForm.tenant_id == scope.tenant_id,
+        HumanInputForm.app_id == scope.app_id,
+    )
+    session.execute(delete(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids_subquery)))
+    session.execute(delete(HumanInputDelivery).where(HumanInputDelivery.form_id.in_(form_ids_subquery)))
+    session.execute(
+        delete(HumanInputForm).where(
+            HumanInputForm.tenant_id == scope.tenant_id,
+            HumanInputForm.app_id == scope.app_id,
+        )
+    )
     session.commit()

     for state_key in scope.state_keys:

@@ -504,3 +529,200 @@ class TestDeleteWorkflowPause:

         with pytest.raises(_WorkflowRunError, match="WorkflowPause not found"):
             repository.delete_workflow_pause(pause_entity=pause_entity)
+
+
+class TestPrivateWorkflowPauseEntity:
+    """Integration tests for _PrivateWorkflowPauseEntity using real DB models."""
+
+    def test_properties(
+        self,
+        db_session_with_containers: Session,
+        test_scope: _TestScope,
+    ) -> None:
+        """Entity properties delegate to the persisted WorkflowPause model."""
+
+        workflow_run = _create_workflow_run(
+            db_session_with_containers,
+            test_scope,
+            status=WorkflowExecutionStatus.RUNNING,
+        )
+        pause = WorkflowPause(
+            id=str(uuid4()),
+            workflow_id=test_scope.workflow_id,
+            workflow_run_id=workflow_run.id,
+            state_object_key=f"workflow-state-{uuid4()}.json",
+        )
+        db_session_with_containers.add(pause)
+        db_session_with_containers.commit()
+        db_session_with_containers.refresh(pause)
+        test_scope.state_keys.add(pause.state_object_key)
+
+        entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[])
+
+        assert entity.id == pause.id
+        assert entity.workflow_execution_id == workflow_run.id
+        assert entity.resumed_at is None
+
+    def test_get_state(
+        self,
+        db_session_with_containers: Session,
+        test_scope: _TestScope,
+    ) -> None:
+        """get_state loads state data from storage using the persisted state_object_key."""
+
+        workflow_run = _create_workflow_run(
+            db_session_with_containers,
+            test_scope,
+            status=WorkflowExecutionStatus.RUNNING,
+        )
+        state_key = f"workflow-state-{uuid4()}.json"
+        pause = WorkflowPause(
+            id=str(uuid4()),
+            workflow_id=test_scope.workflow_id,
+            workflow_run_id=workflow_run.id,
+            state_object_key=state_key,
+        )
+        db_session_with_containers.add(pause)
+        db_session_with_containers.commit()
+        db_session_with_containers.refresh(pause)
+        test_scope.state_keys.add(state_key)
+
+        expected_state = b'{"test": "state"}'
+        storage.save(state_key, expected_state)
+
+        entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[])
+        result = entity.get_state()
+
+        assert result == expected_state
+
+    def test_get_state_caching(
+        self,
+        db_session_with_containers: Session,
+        test_scope: _TestScope,
+    ) -> None:
+        """get_state caches the result so storage is only accessed once."""
+
+        workflow_run = _create_workflow_run(
+            db_session_with_containers,
+            test_scope,
+            status=WorkflowExecutionStatus.RUNNING,
+        )
+        state_key = f"workflow-state-{uuid4()}.json"
+        pause = WorkflowPause(
+            id=str(uuid4()),
+            workflow_id=test_scope.workflow_id,
+            workflow_run_id=workflow_run.id,
+            state_object_key=state_key,
+        )
+        db_session_with_containers.add(pause)
+        db_session_with_containers.commit()
+        db_session_with_containers.refresh(pause)
+        test_scope.state_keys.add(state_key)
+
+        expected_state = b'{"test": "state"}'
+        storage.save(state_key, expected_state)
+
+        entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[])
+        result1 = entity.get_state()
+        # Delete from storage to prove second call uses cache
+        storage.delete(state_key)
+        test_scope.state_keys.discard(state_key)
+        result2 = entity.get_state()
+
+        assert result1 == expected_state
+        assert result2 == expected_state
+
+
+class TestBuildHumanInputRequiredReason:
+    """Integration tests for _build_human_input_required_reason using real DB models."""
+
+    def test_prefers_backstage_token_when_available(
+        self,
+        db_session_with_containers: Session,
+        test_scope: _TestScope,
+    ) -> None:
+        """Use backstage token when multiple recipient types may exist."""
+
+        expiration_time = naive_utc_now()
+        form_definition = FormDefinition(
+            form_content="content",
+            inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
+            user_actions=[UserAction(id="approve", title="Approve")],
+            rendered_content="rendered",
+            expiration_time=expiration_time,
+            default_values={"name": "Alice"},
+            node_title="Ask Name",
+            display_in_ui=True,
+        )
+
+        form_model = HumanInputForm(
+            tenant_id=test_scope.tenant_id,
+            app_id=test_scope.app_id,
+            workflow_run_id=str(uuid4()),
+            node_id="node-1",
+            form_definition=form_definition.model_dump_json(),
+            rendered_content="rendered",
+            status=HumanInputFormStatus.WAITING,
+            expiration_time=expiration_time,
+        )
+        db_session_with_containers.add(form_model)
+        db_session_with_containers.flush()
+
+        delivery = HumanInputDelivery(
+            form_id=form_model.id,
+            delivery_method_type=DeliveryMethodType.WEBAPP,
+            channel_payload="{}",
+        )
+        db_session_with_containers.add(delivery)
+        db_session_with_containers.flush()
+
+        access_token = secrets.token_urlsafe(8)
+        recipient = HumanInputFormRecipient(
+            form_id=form_model.id,
+            delivery_id=delivery.id,
+            recipient_type=RecipientType.BACKSTAGE,
+            recipient_payload=BackstageRecipientPayload().model_dump_json(),
+            access_token=access_token,
+        )
+        db_session_with_containers.add(recipient)
+        db_session_with_containers.flush()
+
+        # Create a pause so the reason has a valid pause_id
+        workflow_run = _create_workflow_run(
+            db_session_with_containers,
+            test_scope,
+            status=WorkflowExecutionStatus.RUNNING,
+        )
+        pause = WorkflowPause(
+            id=str(uuid4()),
+            workflow_id=test_scope.workflow_id,
+            workflow_run_id=workflow_run.id,
+            state_object_key=f"workflow-state-{uuid4()}.json",
+        )
+        db_session_with_containers.add(pause)
+        db_session_with_containers.flush()
+        test_scope.state_keys.add(pause.state_object_key)
+
+        reason_model = WorkflowPauseReason(
+            pause_id=pause.id,
+            type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
+            form_id=form_model.id,
+            node_id="node-1",
+            message="",
+        )
+        db_session_with_containers.add(reason_model)
+        db_session_with_containers.commit()
+
+        # Refresh to ensure we have DB-round-tripped objects
+        db_session_with_containers.refresh(form_model)
+        db_session_with_containers.refresh(reason_model)
+        db_session_with_containers.refresh(recipient)
+
+        reason = _build_human_input_required_reason(reason_model, form_model, [recipient])
+
+        assert isinstance(reason, HumanInputRequired)
+        assert reason.form_token == access_token
+        assert reason.node_title == "Ask Name"
+        assert reason.form_content == "content"
+        assert reason.inputs[0].output_variable_name == "name"
+        assert reason.actions[0].id == "approve"

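test_get_state_caching pins down a memoization contract: storage is read at most once per entity, so deleting the object afterwards does not affect later reads. A minimal sketch of an object satisfying that contract, assuming a storage backend with Dify's load_once-style API:

class CachedPauseState:
    """Illustrative stand-in for the caching inside _PrivateWorkflowPauseEntity."""

    def __init__(self, state_object_key: str, storage) -> None:
        self._key = state_object_key
        self._storage = storage
        self._cached: bytes | None = None

    def get_state(self) -> bytes:
        if self._cached is None:                       # first call hits storage...
            self._cached = self._storage.load_once(self._key)
        return self._cached                            # ...later calls reuse the bytes
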
@@ -0,0 +1,391 @@
+"""Integration tests for get_paginated_workflow_runs and get_workflow_runs_count using testcontainers."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from datetime import timedelta
+from uuid import uuid4
+
+import pytest
+from sqlalchemy import Engine, delete
+from sqlalchemy import exc as sa_exc
+from sqlalchemy.orm import Session, sessionmaker
+
+from dify_graph.entities import WorkflowExecution
+from dify_graph.enums import WorkflowExecutionStatus
+from libs.datetime_utils import naive_utc_now
+from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
+from models.workflow import WorkflowRun, WorkflowType
+from repositories.sqlalchemy_api_workflow_run_repository import DifyAPISQLAlchemyWorkflowRunRepository
+
+
+class _TestWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
+    """Concrete repository for tests where save() is not under test."""
+
+    def save(self, execution: WorkflowExecution) -> None:
+        return None
+
+
+@dataclass
+class _TestScope:
+    """Per-test data scope used to isolate DB rows."""
+
+    tenant_id: str = field(default_factory=lambda: str(uuid4()))
+    app_id: str = field(default_factory=lambda: str(uuid4()))
+    workflow_id: str = field(default_factory=lambda: str(uuid4()))
+    user_id: str = field(default_factory=lambda: str(uuid4()))
+
+
+def _create_workflow_run(
+    session: Session,
+    scope: _TestScope,
+    *,
+    status: WorkflowExecutionStatus,
+    triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING,
+    created_at_offset: timedelta | None = None,
+) -> WorkflowRun:
+    """Create and persist a workflow run bound to the current test scope."""
+    now = naive_utc_now()
+    workflow_run = WorkflowRun(
+        id=str(uuid4()),
+        tenant_id=scope.tenant_id,
+        app_id=scope.app_id,
+        workflow_id=scope.workflow_id,
+        type=WorkflowType.WORKFLOW,
+        triggered_from=triggered_from,
+        version="draft",
+        graph="{}",
+        inputs="{}",
+        status=status,
+        created_by_role=CreatorUserRole.ACCOUNT,
+        created_by=scope.user_id,
+        created_at=now + created_at_offset if created_at_offset is not None else now,
+    )
+    session.add(workflow_run)
+    session.commit()
+    return workflow_run
+
+
+def _cleanup_scope_data(session: Session, scope: _TestScope) -> None:
+    """Remove test-created DB rows for a test scope."""
+    session.execute(
+        delete(WorkflowRun).where(
+            WorkflowRun.tenant_id == scope.tenant_id,
+            WorkflowRun.app_id == scope.app_id,
+        )
+    )
+    session.commit()
+
+
+@pytest.fixture
+def repository(db_session_with_containers: Session) -> DifyAPISQLAlchemyWorkflowRunRepository:
+    """Build a repository backed by the testcontainers database engine."""
+    engine = db_session_with_containers.get_bind()
+    assert isinstance(engine, Engine)
+    return _TestWorkflowRunRepository(session_maker=sessionmaker(bind=engine, expire_on_commit=False))
+
+
+@pytest.fixture
+def test_scope(db_session_with_containers: Session) -> _TestScope:
+    """Provide an isolated scope and clean related data after each test."""
+    scope = _TestScope()
+    yield scope
+    _cleanup_scope_data(db_session_with_containers, scope)
+
+
+class TestGetPaginatedWorkflowRuns:
+    """Integration tests for get_paginated_workflow_runs."""
+
+    def test_returns_runs_without_status_filter(
+        self,
+        repository: DifyAPISQLAlchemyWorkflowRunRepository,
+        db_session_with_containers: Session,
+        test_scope: _TestScope,
+    ) -> None:
+        """Return all runs for the given tenant/app when no status filter is applied."""
+        for status in (
+            WorkflowExecutionStatus.SUCCEEDED,
+            WorkflowExecutionStatus.FAILED,
+            WorkflowExecutionStatus.RUNNING,
+        ):
+            _create_workflow_run(db_session_with_containers, test_scope, status=status)
+
+        result = repository.get_paginated_workflow_runs(
+            tenant_id=test_scope.tenant_id,
+            app_id=test_scope.app_id,
+            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
+            limit=20,
+            last_id=None,
+            status=None,
+        )
+
+        assert len(result.data) == 3
+        assert result.limit == 20
+        assert result.has_more is False
+
+    def test_filters_by_status(
+        self,
+        repository: DifyAPISQLAlchemyWorkflowRunRepository,
+        db_session_with_containers: Session,
+        test_scope: _TestScope,
+    ) -> None:
+        """Return only runs matching the requested status."""
+        _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
+        _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
+        _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED)
+
+        result = repository.get_paginated_workflow_runs(
+            tenant_id=test_scope.tenant_id,
+            app_id=test_scope.app_id,
+            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
+            limit=20,
+            last_id=None,
+            status="succeeded",
+        )
+
+        assert len(result.data) == 2
+        assert all(run.status == WorkflowExecutionStatus.SUCCEEDED for run in result.data)
+
+    def test_pagination_has_more(
+        self,
+        repository: DifyAPISQLAlchemyWorkflowRunRepository,
+        db_session_with_containers: Session,
+        test_scope: _TestScope,
+    ) -> None:
+        """Return has_more=True when more records exist beyond the limit."""
+        for i in range(5):
+            _create_workflow_run(
+                db_session_with_containers,
+                test_scope,
+                status=WorkflowExecutionStatus.SUCCEEDED,
+                created_at_offset=timedelta(seconds=i),
+            )
+
+        result = repository.get_paginated_workflow_runs(
+            tenant_id=test_scope.tenant_id,
+            app_id=test_scope.app_id,
+            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
+            limit=3,
+            last_id=None,
+            status=None,
+        )
+
+        assert len(result.data) == 3
+        assert result.has_more is True
+
+    def test_cursor_based_pagination(
+        self,
+        repository: DifyAPISQLAlchemyWorkflowRunRepository,
+        db_session_with_containers: Session,
+        test_scope: _TestScope,
+    ) -> None:
+        """Cursor-based pagination returns the next page of results."""
+        for i in range(5):
+            _create_workflow_run(
+                db_session_with_containers,
+                test_scope,
+                status=WorkflowExecutionStatus.SUCCEEDED,
+                created_at_offset=timedelta(seconds=i),
+            )
+
+        # First page
+        page1 = repository.get_paginated_workflow_runs(
+            tenant_id=test_scope.tenant_id,
+            app_id=test_scope.app_id,
+            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
+            limit=3,
+            last_id=None,
+            status=None,
+        )
+        assert len(page1.data) == 3
+        assert page1.has_more is True
+
+        # Second page using cursor
+        page2 = repository.get_paginated_workflow_runs(
+            tenant_id=test_scope.tenant_id,
+            app_id=test_scope.app_id,
+            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
+            limit=3,
+            last_id=page1.data[-1].id,
+            status=None,
+        )
+        assert len(page2.data) == 2
+        assert page2.has_more is False
+
+        # No overlap between pages
+        page1_ids = {r.id for r in page1.data}
+        page2_ids = {r.id for r in page2.data}
+        assert page1_ids.isdisjoint(page2_ids)
+
+    def test_invalid_last_id_raises(
+        self,
+        repository: DifyAPISQLAlchemyWorkflowRunRepository,
+        test_scope: _TestScope,
+    ) -> None:
+        """Raise ValueError when last_id refers to a non-existent run."""
+        with pytest.raises(ValueError, match="Last workflow run not exists"):
+            repository.get_paginated_workflow_runs(
+                tenant_id=test_scope.tenant_id,
+                app_id=test_scope.app_id,
+                triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
+                limit=20,
+                last_id=str(uuid4()),
+                status=None,
+            )
+
+    def test_tenant_isolation(
+        self,
+        repository: DifyAPISQLAlchemyWorkflowRunRepository,
+        db_session_with_containers: Session,
+        test_scope: _TestScope,
+    ) -> None:
+        """Runs from other tenants are not returned."""
+        _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
+
+        other_scope = _TestScope(app_id=test_scope.app_id)
+        try:
+            _create_workflow_run(db_session_with_containers, other_scope, status=WorkflowExecutionStatus.SUCCEEDED)
+
+            result = repository.get_paginated_workflow_runs(
+                tenant_id=test_scope.tenant_id,
+                app_id=test_scope.app_id,
+                triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
+                limit=20,
+                last_id=None,
+                status=None,
+            )
+
+            assert len(result.data) == 1
+            assert result.data[0].tenant_id == test_scope.tenant_id
+        finally:
+            _cleanup_scope_data(db_session_with_containers, other_scope)
+
+
+class TestGetWorkflowRunsCount:
+    """Integration tests for get_workflow_runs_count."""
+
+    def test_count_without_status_filter(
+        self,
+        repository: DifyAPISQLAlchemyWorkflowRunRepository,
+        db_session_with_containers: Session,
+        test_scope: _TestScope,
+    ) -> None:
+        """Count all runs grouped by status when no status filter is applied."""
+        for _ in range(3):
+            _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
+        for _ in range(2):
+            _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED)
+        _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.RUNNING)
+
+        result = repository.get_workflow_runs_count(
+            tenant_id=test_scope.tenant_id,
+            app_id=test_scope.app_id,
+            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
+            status=None,
+        )
+
+        assert result["total"] == 6
+        assert result["succeeded"] == 3
+        assert result["failed"] == 2
+        assert result["running"] == 1
+        assert result["stopped"] == 0
+        assert result["partial-succeeded"] == 0
+
+    def test_count_with_status_filter(
+        self,
+        repository: DifyAPISQLAlchemyWorkflowRunRepository,
+        db_session_with_containers: Session,
+        test_scope: _TestScope,
+    ) -> None:
+        """Count only runs matching the requested status."""
+        for _ in range(3):
+            _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
+        _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED)
+
+        result = repository.get_workflow_runs_count(
+            tenant_id=test_scope.tenant_id,
+            app_id=test_scope.app_id,
+            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
+            status="succeeded",
+        )
+
+        assert result["total"] == 3
+        assert result["succeeded"] == 3
+        assert result["failed"] == 0
+
+    def test_count_with_invalid_status_raises(
+        self,
+        repository: DifyAPISQLAlchemyWorkflowRunRepository,
+        db_session_with_containers: Session,
+        test_scope: _TestScope,
+    ) -> None:
+        """Invalid status raises StatementError because the column uses an enum type."""
+        _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
+
+        with pytest.raises(sa_exc.StatementError) as exc_info:
+            repository.get_workflow_runs_count(
+                tenant_id=test_scope.tenant_id,
+                app_id=test_scope.app_id,
+                triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
+                status="invalid_status",
+            )
+        assert isinstance(exc_info.value.orig, ValueError)
+
+    def test_count_with_time_range(
+        self,
+        repository: DifyAPISQLAlchemyWorkflowRunRepository,
+        db_session_with_containers: Session,
+        test_scope: _TestScope,
+    ) -> None:
+        """Time range filter excludes runs created outside the window."""
+        # Recent run (within 1 day)
+        _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
+        # Old run (8 days ago)
+        _create_workflow_run(
+            db_session_with_containers,
+            test_scope,
+            status=WorkflowExecutionStatus.SUCCEEDED,
+            created_at_offset=timedelta(days=-8),
+        )
+
+        result = repository.get_workflow_runs_count(
+            tenant_id=test_scope.tenant_id,
+            app_id=test_scope.app_id,
+            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
+            status=None,
+            time_range="7d",
+        )
+
+        assert result["total"] == 1
+        assert result["succeeded"] == 1
+
+    def test_count_with_status_and_time_range(
+        self,
+        repository: DifyAPISQLAlchemyWorkflowRunRepository,
+        db_session_with_containers: Session,
+        test_scope: _TestScope,
+    ) -> None:
+        """Both status and time_range filters apply together."""
+        # Recent succeeded
+        _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED)
+        # Recent failed
+        _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED)
+        # Old succeeded (outside time range)
+        _create_workflow_run(
+            db_session_with_containers,
+            test_scope,
+            status=WorkflowExecutionStatus.SUCCEEDED,
+            created_at_offset=timedelta(days=-8),
+        )
+
+        result = repository.get_workflow_runs_count(
+            tenant_id=test_scope.tenant_id,
+            app_id=test_scope.app_id,
+            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
+            status="succeeded",
+            time_range="7d",
+        )
+
+        assert result["total"] == 1
+        assert result["succeeded"] == 1
+        assert result["failed"] == 0

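A sketch of consuming the cursor API these tests exercise: feed the last id of each page back in as last_id until has_more is False. `repo` stands in for a constructed DifyAPISQLAlchemyWorkflowRunRepository; the keyword arguments mirror the calls in the tests above.

def iter_all_runs(repo, *, tenant_id: str, app_id: str, triggered_from, page_size: int = 100):
    last_id = None
    while True:
        page = repo.get_paginated_workflow_runs(
            tenant_id=tenant_id,
            app_id=app_id,
            triggered_from=triggered_from,
            limit=page_size,
            last_id=last_id,
            status=None,
        )
        yield from page.data
        if not page.has_more:
            break
        last_id = page.data[-1].id  # cursor for the next page
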
@@ -7,7 +7,7 @@ from sqlalchemy.orm import Session

 from core.plugin.impl.exc import PluginDaemonClientSideError
 from models import Account
-from models.enums import MessageFileBelongsTo
+from models.enums import ConversationFromSource, MessageFileBelongsTo
 from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought
 from services.account_service import AccountService, TenantService
 from services.agent_service import AgentService

@@ -165,7 +165,7 @@ class TestAgentService:
             inputs={},
             status="normal",
             mode="chat",
-            from_source="api",
+            from_source=ConversationFromSource.API,
         )
         db_session_with_containers.add(conversation)
         db_session_with_containers.commit()

@@ -204,7 +204,7 @@ class TestAgentService:
             answer_unit_price=0.001,
             provider_response_latency=1.5,
             currency="USD",
-            from_source="api",
+            from_source=ConversationFromSource.API,
         )
         db_session_with_containers.add(message)
         db_session_with_containers.commit()

@@ -406,7 +406,7 @@ class TestAgentService:
             inputs={},
             status="normal",
             mode="chat",
-            from_source="api",
+            from_source=ConversationFromSource.API,
         )
         db_session_with_containers.add(conversation)
         db_session_with_containers.commit()

@@ -445,7 +445,7 @@ class TestAgentService:
             answer_unit_price=0.001,
             provider_response_latency=1.5,
             currency="USD",
-            from_source="api",
+            from_source=ConversationFromSource.API,
         )
         db_session_with_containers.add(message)
         db_session_with_containers.commit()

@@ -478,7 +478,7 @@ class TestAgentService:
             inputs={},
             status="normal",
             mode="chat",
-            from_source="api",
+            from_source=ConversationFromSource.API,
         )
         db_session_with_containers.add(conversation)
         db_session_with_containers.commit()

@@ -517,7 +517,7 @@ class TestAgentService:
             answer_unit_price=0.001,
             provider_response_latency=1.5,
             currency="USD",
-            from_source="api",
+            from_source=ConversationFromSource.API,
         )
         db_session_with_containers.add(message)
         db_session_with_containers.commit()

@@ -624,7 +624,7 @@ class TestAgentService:
             inputs={},
             status="normal",
             mode="chat",
-            from_source="api",
+            from_source=ConversationFromSource.API,
             app_model_config_id=None,  # Explicitly set to None
         )
         db_session_with_containers.add(conversation)

@@ -647,7 +647,7 @@ class TestAgentService:
             answer_unit_price=0.001,
             provider_response_latency=1.5,
             currency="USD",
-            from_source="api",
+            from_source=ConversationFromSource.API,
         )
         db_session_with_containers.add(message)
         db_session_with_containers.commit()

@@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound

 from models import Account
+from models.enums import ConversationFromSource, InvokeFrom
 from models.model import MessageAnnotation
 from services.annotation_service import AppAnnotationService
 from services.app_service import AppService

@@ -136,8 +137,8 @@ class TestAnnotationService:
             system_instruction="",
             system_instruction_tokens=0,
             status="normal",
-            invoke_from="console",
-            from_source="console",
+            invoke_from=InvokeFrom.EXPLORE,
+            from_source=ConversationFromSource.CONSOLE,
             from_end_user_id=None,
             from_account_id=account.id,
         )

@@ -174,8 +175,8 @@ class TestAnnotationService:
             provider_response_latency=0,
             total_price=0,
             currency="USD",
-            invoke_from="console",
-            from_source="console",
+            invoke_from=InvokeFrom.EXPLORE,
+            from_source=ConversationFromSource.CONSOLE,
             from_end_user_id=None,
             from_account_id=account.id,
         )

@@ -721,7 +722,7 @@ class TestAnnotationService:
             query=f"Query {i}: {fake.sentence()}",
             user_id=account.id,
             message_id=fake.uuid4(),
-            from_source="console",
+            from_source=ConversationFromSource.CONSOLE,
             score=0.8 + (i * 0.1),
         )

@@ -772,7 +773,7 @@ class TestAnnotationService:
             query=query,
             user_id=account.id,
             message_id=message_id,
-            from_source="console",
+            from_source=ConversationFromSource.CONSOLE,
             score=score,
         )

@@ -0,0 +1,80 @@
+"""Testcontainers integration tests for AttachmentService."""
+
+import base64
+from datetime import UTC, datetime
+from unittest.mock import patch
+from uuid import uuid4
+
+import pytest
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+from werkzeug.exceptions import NotFound
+
+import services.attachment_service as attachment_service_module
+from extensions.ext_database import db
+from extensions.storage.storage_type import StorageType
+from models.enums import CreatorUserRole
+from models.model import UploadFile
+from services.attachment_service import AttachmentService
+
+
+class TestAttachmentService:
+    def _create_upload_file(self, db_session_with_containers, *, tenant_id: str | None = None) -> UploadFile:
+        upload_file = UploadFile(
+            tenant_id=tenant_id or str(uuid4()),
+            storage_type=StorageType.OPENDAL,
+            key=f"upload/{uuid4()}.txt",
+            name="test-file.txt",
+            size=100,
+            extension="txt",
+            mime_type="text/plain",
+            created_by_role=CreatorUserRole.ACCOUNT,
+            created_by=str(uuid4()),
+            created_at=datetime.now(UTC),
+            used=False,
+        )
+        db_session_with_containers.add(upload_file)
+        db_session_with_containers.commit()
+        return upload_file
+
+    def test_should_initialize_with_sessionmaker(self):
+        session_factory = sessionmaker()
+
+        service = AttachmentService(session_factory=session_factory)
+
+        assert service._session_maker is session_factory
+
+    def test_should_initialize_with_engine(self):
+        engine = create_engine("sqlite:///:memory:")
+
+        service = AttachmentService(session_factory=engine)
+        session = service._session_maker()
+        try:
+            assert session.bind == engine
+        finally:
+            session.close()
+            engine.dispose()
+
+    @pytest.mark.parametrize("invalid_session_factory", [None, "not-a-session-factory", 1])
+    def test_should_raise_assertion_error_for_invalid_session_factory(self, invalid_session_factory):
+        with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."):
+            AttachmentService(session_factory=invalid_session_factory)
+
+    def test_should_return_base64_when_file_exists(self, db_session_with_containers):
+        upload_file = self._create_upload_file(db_session_with_containers)
+        service = AttachmentService(session_factory=sessionmaker(bind=db.engine))
+
+        with patch.object(attachment_service_module.storage, "load_once", return_value=b"binary-content") as mock_load:
+            result = service.get_file_base64(upload_file.id)
+
+        assert result == base64.b64encode(b"binary-content").decode()
+        mock_load.assert_called_once_with(upload_file.key)
+
+    def test_should_raise_not_found_when_file_missing(self, db_session_with_containers):
+        service = AttachmentService(session_factory=sessionmaker(bind=db.engine))
+
+        with patch.object(attachment_service_module.storage, "load_once") as mock_load:
+            with pytest.raises(NotFound, match="File not found"):
+                service.get_file_base64(str(uuid4()))
+
+        mock_load.assert_not_called()

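The first three tests pin down the constructor contract: accept a sessionmaker as-is, wrap an Engine in one, and reject anything else with an AssertionError. A sketch of logic consistent with those tests, not necessarily the service's literal code:

from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker

def build_session_maker(session_factory) -> sessionmaker:
    # Illustrative reconstruction of what the tests imply AttachmentService does.
    if isinstance(session_factory, sessionmaker):
        return session_factory
    if isinstance(session_factory, Engine):
        return sessionmaker(bind=session_factory)
    raise AssertionError(f"session_factory must be a sessionmaker or an Engine, got {type(session_factory)}")
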
@@ -10,6 +10,7 @@ from sqlalchemy import select

 from core.app.entities.app_invoke_entities import InvokeFrom
 from models.account import Account, Tenant, TenantAccountJoin
+from models.enums import ConversationFromSource
 from models.model import App, Conversation, EndUser, Message, MessageAnnotation
 from services.annotation_service import AppAnnotationService
 from services.conversation_service import ConversationService

@@ -107,7 +108,7 @@ class ConversationServiceIntegrationTestDataFactory:
             system_instruction_tokens=0,
             status="normal",
             invoke_from=invoke_from.value,
-            from_source="api" if isinstance(user, EndUser) else "console",
+            from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE,
             from_end_user_id=user.id if isinstance(user, EndUser) else None,
             from_account_id=user.id if isinstance(user, Account) else None,
             dialogue_count=0,

@@ -154,7 +155,7 @@ class ConversationServiceIntegrationTestDataFactory:
             currency="USD",
             status="normal",
             invoke_from=InvokeFrom.WEB_APP.value,
-            from_source="api" if isinstance(user, EndUser) else "console",
+            from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE,
             from_end_user_id=user.id if isinstance(user, EndUser) else None,
             from_account_id=user.id if isinstance(user, Account) else None,
         )

@@ -0,0 +1,58 @@
+"""Testcontainers integration tests for ConversationVariableUpdater."""
+
+from uuid import uuid4
+
+import pytest
+from sqlalchemy.orm import sessionmaker
+
+from dify_graph.variables import StringVariable
+from extensions.ext_database import db
+from models.workflow import ConversationVariable
+from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater
+
+
+class TestConversationVariableUpdater:
+    def _create_conversation_variable(
+        self, db_session_with_containers, *, conversation_id: str, variable: StringVariable, app_id: str | None = None
+    ) -> ConversationVariable:
+        row = ConversationVariable(
+            id=variable.id,
+            conversation_id=conversation_id,
+            app_id=app_id or str(uuid4()),
+            data=variable.model_dump_json(),
+        )
+        db_session_with_containers.add(row)
+        db_session_with_containers.commit()
+        return row
+
+    def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers):
+        conversation_id = str(uuid4())
+        variable = StringVariable(id=str(uuid4()), name="topic", value="old value")
+        self._create_conversation_variable(
+            db_session_with_containers, conversation_id=conversation_id, variable=variable
+        )
+
+        updated_variable = StringVariable(id=variable.id, name="topic", value="new value")
+        updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
+
+        updater.update(conversation_id=conversation_id, variable=updated_variable)
+
+        db_session_with_containers.expire_all()
+        row = db_session_with_containers.get(ConversationVariable, (variable.id, conversation_id))
+        assert row is not None
+        assert row.data == updated_variable.model_dump_json()
+
+    def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers):
+        conversation_id = str(uuid4())
+        variable = StringVariable(id=str(uuid4()), name="topic", value="value")
+        updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
+
+        with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"):
+            updater.update(conversation_id=conversation_id, variable=variable)
+
+    def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers):
+        updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
+
+        result = updater.flush()
+
+        assert result is None

@ -0,0 +1,103 @@
"""Testcontainers integration tests for CreditPoolService."""

from uuid import uuid4

import pytest

from core.errors.error import QuotaExceededError
from models import TenantCreditPool
from services.credit_pool_service import CreditPoolService


class TestCreditPoolService:
    def _create_tenant_id(self) -> str:
        return str(uuid4())

    def test_create_default_pool(self, db_session_with_containers):
        tenant_id = self._create_tenant_id()

        pool = CreditPoolService.create_default_pool(tenant_id)

        assert isinstance(pool, TenantCreditPool)
        assert pool.tenant_id == tenant_id
        assert pool.pool_type == "trial"
        assert pool.quota_used == 0
        assert pool.quota_limit > 0

    def test_get_pool_returns_pool_when_exists(self, db_session_with_containers):
        tenant_id = self._create_tenant_id()
        CreditPoolService.create_default_pool(tenant_id)

        result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type="trial")

        assert result is not None
        assert result.tenant_id == tenant_id
        assert result.pool_type == "trial"

    def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers):
        result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type="trial")

        assert result is None

    def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers):
        result = CreditPoolService.check_credits_available(tenant_id=self._create_tenant_id(), credits_required=10)

        assert result is False

    def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers):
        tenant_id = self._create_tenant_id()
        CreditPoolService.create_default_pool(tenant_id)

        result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=10)

        assert result is True

    def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers):
        tenant_id = self._create_tenant_id()
        pool = CreditPoolService.create_default_pool(tenant_id)
        # Exhaust credits
        pool.quota_used = pool.quota_limit
        db_session_with_containers.commit()

        result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=1)

        assert result is False

    def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers):
        with pytest.raises(QuotaExceededError, match="Credit pool not found"):
            CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=10)

    def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers):
        tenant_id = self._create_tenant_id()
        pool = CreditPoolService.create_default_pool(tenant_id)
        pool.quota_used = pool.quota_limit
        db_session_with_containers.commit()

        with pytest.raises(QuotaExceededError, match="No credits remaining"):
            CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=10)

    def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers):
        tenant_id = self._create_tenant_id()
        CreditPoolService.create_default_pool(tenant_id)
        credits_required = 10

        result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=credits_required)

        assert result == credits_required
        db_session_with_containers.expire_all()
        pool = CreditPoolService.get_pool(tenant_id=tenant_id)
        assert pool.quota_used == credits_required

    def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers):
        tenant_id = self._create_tenant_id()
        pool = CreditPoolService.create_default_pool(tenant_id)
        remaining = 5
        pool.quota_used = pool.quota_limit - remaining
        db_session_with_containers.commit()

        result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200)

        assert result == remaining
        db_session_with_containers.expire_all()
        updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
        assert updated_pool.quota_used == pool.quota_limit
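
The deduction tests above encode a soft-cap policy: the service deducts min(required, remaining) rather than rejecting an oversized request, and it only raises once nothing is left. A minimal sketch of deduction logic consistent with those assertions (illustrative only; the real service's session handling and locking are not visible in this diff):

def check_and_deduct_credits(tenant_id: str, credits_required: int) -> int:
    pool = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type="trial")
    if pool is None:
        raise QuotaExceededError("Credit pool not found")
    remaining = pool.quota_limit - pool.quota_used
    if remaining <= 0:
        raise QuotaExceededError("No credits remaining")
    # Cap the deduction at whatever is left instead of failing the request.
    deducted = min(credits_required, remaining)
    pool.quota_used += deducted
    return deducted

Returning the amount actually deducted, rather than a boolean, is what lets test_check_and_deduct_credits_caps_at_remaining observe the cap directly.
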
@ -397,6 +397,68 @@ class TestDatasetPermissionServiceClearPartialMemberList:

class TestDatasetServiceCheckDatasetPermission:
    """Verify dataset access checks against persisted partial-member permissions."""

    def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers):
        """Test that users from different tenants cannot access dataset."""
        owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
        other_user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)

        dataset = DatasetPermissionTestDataFactory.create_dataset(
            tenant.id, owner.id, permission=DatasetPermissionEnum.ALL_TEAM
        )

        with pytest.raises(NoPermissionError):
            DatasetService.check_dataset_permission(dataset, other_user)

    def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers):
        """Test that tenant owners can access any dataset regardless of permission level."""
        owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
        creator, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
            role=TenantAccountRole.NORMAL, tenant=tenant
        )

        dataset = DatasetPermissionTestDataFactory.create_dataset(
            tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME
        )

        DatasetService.check_dataset_permission(dataset, owner)

    def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers):
        """Test ONLY_ME permission allows only the dataset creator to access."""
        creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)

        dataset = DatasetPermissionTestDataFactory.create_dataset(
            tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME
        )

        DatasetService.check_dataset_permission(dataset, creator)

    def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers):
        """Test ONLY_ME permission denies access to non-creators."""
        creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
        other, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
            role=TenantAccountRole.NORMAL, tenant=tenant
        )

        dataset = DatasetPermissionTestDataFactory.create_dataset(
            tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME
        )

        with pytest.raises(NoPermissionError):
            DatasetService.check_dataset_permission(dataset, other)

    def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers):
        """Test ALL_TEAM permission allows any team member to access the dataset."""
        creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
        member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
            role=TenantAccountRole.NORMAL, tenant=tenant
        )

        dataset = DatasetPermissionTestDataFactory.create_dataset(
            tenant.id, creator.id, permission=DatasetPermissionEnum.ALL_TEAM
        )

        DatasetService.check_dataset_permission(dataset, member)

    def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers):
        """
        Test that user with explicit permission can access partial_members dataset.

@ -443,6 +505,16 @@ class TestDatasetServiceCheckDatasetPermission:
        with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"):
            DatasetService.check_dataset_permission(dataset, user)

    def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers):
        """Test PARTIAL_TEAM permission allows creator to access without explicit permission."""
        creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)

        dataset = DatasetPermissionTestDataFactory.create_dataset(
            tenant.id, creator.id, permission=DatasetPermissionEnum.PARTIAL_TEAM
        )

        DatasetService.check_dataset_permission(dataset, creator)


class TestDatasetServiceCheckDatasetOperatorPermission:
    """Verify operator permission checks against persisted partial-member permissions."""
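
Taken together, these cases fix the precedence of the permission check: a tenant mismatch always fails, owners always pass, creators always pass, and only then does the dataset's own permission level come into play. A rough sketch of that branching, assuming an explicit-grant lookup for PARTIAL_TEAM (has_explicit_permission is a hypothetical helper; user attribute names are illustrative):

def check_dataset_permission(dataset, user):
    # Cross-tenant access is rejected before any permission level is consulted.
    if dataset.tenant_id != user.current_tenant_id:
        raise NoPermissionError("You do not have permission to access this dataset.")
    # Tenant owners bypass per-dataset permissions entirely.
    if user.current_role == TenantAccountRole.OWNER:
        return
    if dataset.created_by == user.id:
        return  # creators can always access their own dataset
    if dataset.permission == DatasetPermissionEnum.ONLY_ME:
        raise NoPermissionError("You do not have permission to access this dataset.")
    if dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
        if not has_explicit_permission(dataset.id, user.id):  # hypothetical lookup
            raise NoPermissionError("You do not have permission to access this dataset.")
    # ALL_TEAM: any remaining member of the same tenant may proceed.
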
@ -7,7 +7,7 @@ import pytest
from sqlalchemy.orm import Session

from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import FeedbackFromSource, FeedbackRating
from models.enums import ConversationFromSource, FeedbackFromSource, FeedbackRating
from models.model import (
    App,
    AppAnnotationHitHistory,

@ -94,7 +94,7 @@ class TestAppMessageExportServiceIntegration:
            name="conv",
            inputs={"seed": 1},
            status="normal",
            from_source="api",
            from_source=ConversationFromSource.API,
            from_end_user_id=str(uuid.uuid4()),
        )
        session.add(conversation)

@ -129,7 +129,7 @@ class TestAppMessageExportServiceIntegration:
            total_price=Decimal("0.003"),
            currency="USD",
            message_metadata=message_metadata,
            from_source="api",
            from_source=ConversationFromSource.API,
            from_end_user_id=conversation.from_end_user_id,
            created_at=created_at,
        )

@ -4,7 +4,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session

from models.enums import FeedbackRating
from models.enums import ConversationFromSource, FeedbackRating, InvokeFrom
from models.model import MessageFeedback
from services.app_service import AppService
from services.errors.message import (

@ -149,8 +149,8 @@ class TestMessageService:
            system_instruction="",
            system_instruction_tokens=0,
            status="normal",
            invoke_from="console",
            from_source="console",
            invoke_from=InvokeFrom.EXPLORE,
            from_source=ConversationFromSource.CONSOLE,
            from_end_user_id=None,
            from_account_id=account.id,
        )

@ -187,8 +187,8 @@ class TestMessageService:
            provider_response_latency=0,
            total_price=0,
            currency="USD",
            invoke_from="console",
            from_source="console",
            invoke_from=InvokeFrom.EXPLORE,
            from_source=ConversationFromSource.CONSOLE,
            from_end_user_id=None,
            from_account_id=account.id,
        )

@ -4,6 +4,7 @@ from decimal import Decimal

import pytest

from models.enums import ConversationFromSource
from models.model import Message
from services import message_service
from tests.test_containers_integration_tests.helpers.execution_extra_content import (

@ -36,7 +37,7 @@ def test_attach_message_extra_contents_assigns_serialized_payload(db_session_wit
        total_price=Decimal(0),
        currency="USD",
        status="normal",
        from_source="console",
        from_source=ConversationFromSource.CONSOLE,
        from_account_id=fixture.account.id,
    )
    db_session_with_containers.add(message_without_extra_content)

@ -11,7 +11,14 @@ from sqlalchemy.orm import Session
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import DataSourceType, FeedbackFromSource, FeedbackRating, MessageChainType, MessageFileBelongsTo
from models.enums import (
    ConversationFromSource,
    DataSourceType,
    FeedbackFromSource,
    FeedbackRating,
    MessageChainType,
    MessageFileBelongsTo,
)
from models.model import (
    App,
    AppAnnotationHitHistory,

@ -166,7 +173,7 @@ class TestMessagesCleanServiceIntegration:
            name="Test conversation",
            inputs={},
            status="normal",
            from_source=FeedbackFromSource.USER,
            from_source=ConversationFromSource.API,
            from_end_user_id=str(uuid.uuid4()),
        )
        db_session_with_containers.add(conversation)

@ -196,7 +203,7 @@ class TestMessagesCleanServiceIntegration:
            answer_unit_price=Decimal("0.002"),
            total_price=Decimal("0.003"),
            currency="USD",
            from_source=FeedbackFromSource.USER,
            from_source=ConversationFromSource.API,
            from_account_id=conversation.from_end_user_id,
            created_at=created_at,
        )

@ -4,6 +4,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session

from models.enums import ConversationFromSource
from models.model import EndUser, Message
from models.web import SavedMessage
from services.app_service import AppService

@ -132,11 +133,14 @@ class TestSavedMessageService:
        # Create a simple conversation first
        from models.model import Conversation

        is_account = hasattr(user, "current_tenant")
        from_source = ConversationFromSource.CONSOLE if is_account else ConversationFromSource.API

        conversation = Conversation(
            app_id=app.id,
            from_source="account" if hasattr(user, "current_tenant") else "end_user",
            from_end_user_id=user.id if not hasattr(user, "current_tenant") else None,
            from_account_id=user.id if hasattr(user, "current_tenant") else None,
            from_source=from_source,
            from_end_user_id=user.id if not is_account else None,
            from_account_id=user.id if is_account else None,
            name=fake.sentence(nb_words=3),
            inputs={},
            status="normal",

@ -150,9 +154,9 @@ class TestSavedMessageService:
        message = Message(
            app_id=app.id,
            conversation_id=conversation.id,
            from_source="account" if hasattr(user, "current_tenant") else "end_user",
            from_end_user_id=user.id if not hasattr(user, "current_tenant") else None,
            from_account_id=user.id if hasattr(user, "current_tenant") else None,
            from_source=from_source,
            from_end_user_id=user.id if not is_account else None,
            from_account_id=user.id if is_account else None,
            inputs={},
            query=fake.sentence(nb_words=5),
            message=fake.text(max_nb_chars=100),
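
Every hunk above makes the same mechanical change: bare literals such as "api" and "console", and mismatched enums such as FeedbackFromSource.USER, are replaced with ConversationFromSource members, so the fixtures now construct rows with the same typed values the models expect. To illustrate why this is a safe swap, here is an assumed str-backed definition (the member names follow the diff; the actual class body lives in models/enums.py and may differ):

from enum import StrEnum


class ConversationFromSource(StrEnum):
    CONSOLE = "console"
    API = "api"


# A str-backed enum compares equal to the legacy literals, so rows already
# persisted with plain strings keep matching, while call sites gain typo safety.
assert ConversationFromSource.API == "api"
assert ConversationFromSource("console") is ConversationFromSource.CONSOLE
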
@ -9,7 +9,7 @@ from werkzeug.exceptions import NotFound

from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset
from models.enums import DataSourceType
from models.enums import DataSourceType, TagType
from models.model import App, Tag, TagBinding
from services.tag_service import TagService

@ -547,7 +547,7 @@ class TestTagService:
        assert result is not None
        assert len(result) == 1
        assert result[0].name == "python_tag"
        assert result[0].type == "app"
        assert result[0].type == TagType.APP
        assert result[0].tenant_id == tenant.id

    def test_get_tag_by_tag_name_no_matches(

@ -638,7 +638,7 @@ class TestTagService:

        # Verify all tags are returned
        for tag in result:
            assert tag.type == "app"
            assert tag.type == TagType.APP
            assert tag.tenant_id == tenant.id
            assert tag.id in [t.id for t in tags]

@ -7,6 +7,7 @@ from sqlalchemy.orm import Session

from core.app.entities.app_invoke_entities import InvokeFrom
from models import Account
from models.enums import ConversationFromSource
from models.model import Conversation, EndUser
from models.web import PinnedConversation
from services.account_service import AccountService, TenantService

@ -145,7 +146,7 @@ class TestWebConversationService:
            system_instruction_tokens=50,
            status="normal",
            invoke_from=InvokeFrom.WEB_APP,
            from_source="console" if isinstance(user, Account) else "api",
            from_source=ConversationFromSource.CONSOLE if isinstance(user, Account) else ConversationFromSource.API,
            from_end_user_id=user.id if isinstance(user, EndUser) else None,
            from_account_id=user.id if isinstance(user, Account) else None,
            dialogue_count=0,

@ -7,7 +7,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session

from models.enums import CreatorUserRole
from models.enums import ConversationFromSource, CreatorUserRole
from models.model import (
    Message,
)

@ -165,7 +165,7 @@ class TestWorkflowRunService:
            inputs={},
            status="normal",
            mode="chat",
            from_source=CreatorUserRole.ACCOUNT,
            from_source=ConversationFromSource.CONSOLE,
            from_account_id=account.id,
        )
        db_session_with_containers.add(conversation)

@ -186,7 +186,7 @@ class TestWorkflowRunService:
        message.answer_price_unit = 0.001
        message.currency = "USD"
        message.status = "normal"
        message.from_source = CreatorUserRole.ACCOUNT
        message.from_source = ConversationFromSource.CONSOLE
        message.from_account_id = account.id
        message.workflow_run_id = workflow_run.id
        message.inputs = {"input": "test input"}

@ -802,6 +802,81 @@ class TestWorkflowService:
        with pytest.raises(ValueError, match="No valid workflow found"):
            workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account)

    def test_restore_published_workflow_to_draft_does_not_persist_normalized_source_features(
        self, db_session_with_containers: Session
    ):
        """Restore copies legacy feature JSON into draft without rewriting the source row."""
        fake = Faker()
        account = self._create_test_account(db_session_with_containers, fake)
        app = self._create_test_app(db_session_with_containers, fake)
        app.mode = AppMode.ADVANCED_CHAT

        legacy_features = {
            "file_upload": {
                "image": {
                    "enabled": True,
                    "number_limits": 6,
                    "transfer_methods": ["remote_url", "local_file"],
                }
            },
            "opening_statement": "",
            "retriever_resource": {"enabled": True},
            "sensitive_word_avoidance": {"enabled": False},
            "speech_to_text": {"enabled": False},
            "suggested_questions": [],
            "suggested_questions_after_answer": {"enabled": False},
            "text_to_speech": {"enabled": False, "language": "", "voice": ""},
        }
        published_workflow = Workflow(
            id=fake.uuid4(),
            tenant_id=app.tenant_id,
            app_id=app.id,
            type=WorkflowType.WORKFLOW,
            version="2026.03.19.001",
            graph=json.dumps({"nodes": [], "edges": []}),
            features=json.dumps(legacy_features),
            created_by=account.id,
            updated_by=account.id,
            environment_variables=[],
            conversation_variables=[],
        )
        draft_workflow = Workflow(
            id=fake.uuid4(),
            tenant_id=app.tenant_id,
            app_id=app.id,
            type=WorkflowType.WORKFLOW,
            version=Workflow.VERSION_DRAFT,
            graph=json.dumps({"nodes": [], "edges": []}),
            features=json.dumps({}),
            created_by=account.id,
            updated_by=account.id,
            environment_variables=[],
            conversation_variables=[],
        )
        db_session_with_containers.add(published_workflow)
        db_session_with_containers.add(draft_workflow)
        db_session_with_containers.commit()

        workflow_service = WorkflowService()

        restored_workflow = workflow_service.restore_published_workflow_to_draft(
            app_model=app,
            workflow_id=published_workflow.id,
            account=account,
        )

        db_session_with_containers.expire_all()
        refreshed_published_workflow = (
            db_session_with_containers.query(Workflow).filter_by(id=published_workflow.id).first()
        )
        refreshed_draft_workflow = db_session_with_containers.query(Workflow).filter_by(id=draft_workflow.id).first()

        assert restored_workflow.id == draft_workflow.id
        assert refreshed_published_workflow is not None
        assert refreshed_draft_workflow is not None
        assert refreshed_published_workflow.serialized_features == json.dumps(legacy_features)
        assert refreshed_draft_workflow.serialized_features == json.dumps(legacy_features)

    def test_get_default_block_configs(self, db_session_with_containers: Session):
        """
        Test retrieval of default block configurations for all node types.
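
The restore test asserts two things at once: the draft row ends up carrying the published workflow's feature JSON byte-for-byte, and the published row is left untouched even if the service normalizes legacy feature shapes internally. A minimal sketch of restore logic consistent with those assertions (hypothetical; get_workflow and get_draft_workflow stand in for lookups not shown in this diff):

def restore_published_workflow_to_draft(app_model, workflow_id, account):
    published = get_workflow(app_model.id, workflow_id)  # hypothetical lookup
    draft = get_draft_workflow(app_model.id)             # hypothetical lookup
    # Copy the stored JSON as-is; any normalization must happen on a copy,
    # so the published row's serialized_features is never rewritten.
    draft.graph = published.graph
    draft.features = published.features
    draft.updated_by = account.id
    return draft
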
@ -0,0 +1,158 @@
"""Testcontainers integration tests for WorkflowService.delete_workflow."""

import json
from uuid import uuid4

import pytest
from sqlalchemy.orm import Session, sessionmaker

from extensions.ext_database import db
from models.account import Account, Tenant, TenantAccountJoin
from models.model import App
from models.tools import WorkflowToolProvider
from models.workflow import Workflow
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService


class TestWorkflowDeletion:
    def _create_tenant_and_account(self, session: Session) -> tuple[Tenant, Account]:
        tenant = Tenant(name=f"Tenant {uuid4()}")
        session.add(tenant)
        session.flush()

        account = Account(
            name=f"Account {uuid4()}",
            email=f"wf_del_{uuid4()}@example.com",
            password="hashed",
            password_salt="salt",
            interface_language="en-US",
            timezone="UTC",
        )
        session.add(account)
        session.flush()

        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role="owner",
            current=True,
        )
        session.add(join)
        session.flush()
        return tenant, account

    def _create_app(self, session: Session, *, tenant: Tenant, account: Account, workflow_id: str | None = None) -> App:
        app = App(
            tenant_id=tenant.id,
            name=f"App {uuid4()}",
            description="",
            mode="workflow",
            icon_type="emoji",
            icon="bot",
            icon_background="#FFFFFF",
            enable_site=False,
            enable_api=True,
            api_rpm=100,
            api_rph=100,
            is_demo=False,
            is_public=False,
            is_universal=False,
            created_by=account.id,
            updated_by=account.id,
            workflow_id=workflow_id,
        )
        session.add(app)
        session.flush()
        return app

    def _create_workflow(
        self, session: Session, *, tenant: Tenant, app: App, account: Account, version: str = "1.0"
    ) -> Workflow:
        workflow = Workflow(
            id=str(uuid4()),
            tenant_id=tenant.id,
            app_id=app.id,
            type="workflow",
            version=version,
            graph=json.dumps({"nodes": [], "edges": []}),
            _features=json.dumps({}),
            created_by=account.id,
            updated_by=account.id,
        )
        session.add(workflow)
        session.flush()
        return workflow

    def _create_tool_provider(
        self, session: Session, *, tenant: Tenant, app: App, account: Account, version: str
    ) -> WorkflowToolProvider:
        provider = WorkflowToolProvider(
            name=f"tool-{uuid4()}",
            label=f"Tool {uuid4()}",
            icon="wrench",
            app_id=app.id,
            version=version,
            user_id=account.id,
            tenant_id=tenant.id,
            description="test tool provider",
        )
        session.add(provider)
        session.flush()
        return provider

    def test_delete_workflow_success(self, db_session_with_containers):
        tenant, account = self._create_tenant_and_account(db_session_with_containers)
        app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
        workflow = self._create_workflow(
            db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0"
        )
        db_session_with_containers.commit()
        workflow_id = workflow.id

        service = WorkflowService(sessionmaker(bind=db.engine))
        result = service.delete_workflow(
            session=db_session_with_containers, workflow_id=workflow_id, tenant_id=tenant.id
        )

        assert result is True
        db_session_with_containers.expire_all()
        assert db_session_with_containers.get(Workflow, workflow_id) is None

    def test_delete_draft_workflow_raises_error(self, db_session_with_containers):
        tenant, account = self._create_tenant_and_account(db_session_with_containers)
        app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
        workflow = self._create_workflow(
            db_session_with_containers, tenant=tenant, app=app, account=account, version="draft"
        )
        db_session_with_containers.commit()

        service = WorkflowService(sessionmaker(bind=db.engine))
        with pytest.raises(DraftWorkflowDeletionError):
            service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id)

    def test_delete_workflow_in_use_by_app_raises_error(self, db_session_with_containers):
        tenant, account = self._create_tenant_and_account(db_session_with_containers)
        app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
        workflow = self._create_workflow(
            db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0"
        )
        # Point app to this workflow
        app.workflow_id = workflow.id
        db_session_with_containers.commit()

        service = WorkflowService(sessionmaker(bind=db.engine))
        with pytest.raises(WorkflowInUseError, match="currently in use by app"):
            service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id)

    def test_delete_workflow_published_as_tool_raises_error(self, db_session_with_containers):
        tenant, account = self._create_tenant_and_account(db_session_with_containers)
        app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
        workflow = self._create_workflow(
            db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0"
        )
        self._create_tool_provider(db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0")
        db_session_with_containers.commit()

        service = WorkflowService(sessionmaker(bind=db.engine))
        with pytest.raises(WorkflowInUseError, match="published as a tool"):
            service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id)
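
These four cases define a guard chain for deletion: drafts can never be deleted, and a published version is protected while an app points at it or while it backs a workflow tool. A rough sketch of the guard order implied by the error messages (the two lookup helpers are hypothetical; the real service also scopes everything by tenant_id):

def delete_workflow(session, workflow_id: str, tenant_id: str) -> bool:
    workflow = session.get(Workflow, workflow_id)
    if workflow.version == "draft":
        raise DraftWorkflowDeletionError("Cannot delete a draft workflow version")
    if app_references_workflow(session, workflow_id):  # hypothetical lookup
        raise WorkflowInUseError("Workflow is currently in use by app")
    if tool_provider_exists(session, workflow.app_id, workflow.version):  # hypothetical lookup
        raise WorkflowInUseError("Workflow is published as a tool")
    session.delete(workflow)
    return True
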
@ -0,0 +1,320 @@
from unittest.mock import MagicMock, patch

import pytest
from flask import Flask, request
from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.local import LocalProxy

from controllers.console.app.error import (
    ProviderModelCurrentlyNotSupportError,
    ProviderNotInitializeError,
    ProviderQuotaExceededError,
)
from controllers.console.app.message import (
    ChatMessageListApi,
    ChatMessagesQuery,
    FeedbackExportQuery,
    MessageAnnotationCountApi,
    MessageApi,
    MessageFeedbackApi,
    MessageFeedbackExportApi,
    MessageFeedbackPayload,
    MessageSuggestedQuestionApi,
)
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from models import App, AppMode
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError


@pytest.fixture
def app():
    flask_app = Flask(__name__)
    flask_app.config["TESTING"] = True
    flask_app.config["RESTX_MASK_HEADER"] = "X-Fields"
    return flask_app


@pytest.fixture
def mock_account():
    from models.account import Account, AccountStatus

    account = MagicMock(spec=Account)
    account.id = "user_123"
    account.timezone = "UTC"
    account.status = AccountStatus.ACTIVE
    account.is_admin_or_owner = True
    account.current_tenant.current_role = "owner"
    account.has_edit_permission = True
    return account


@pytest.fixture
def mock_app_model():
    app_model = MagicMock(spec=App)
    app_model.id = "app_123"
    app_model.mode = AppMode.CHAT
    app_model.tenant_id = "tenant_123"
    return app_model


@pytest.fixture(autouse=True)
def mock_csrf():
    with patch("libs.login.check_csrf_token") as mock:
        yield mock


import contextlib


@contextlib.contextmanager
def setup_test_context(
    test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None, qs=None
):
    with (
        patch("extensions.ext_database.db") as mock_db,
        patch("controllers.console.app.wraps.db", mock_db),
        patch("controllers.console.wraps.db", mock_db),
        patch("controllers.console.app.message.db", mock_db),
        patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
        patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
        patch("controllers.console.app.message.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
    ):
        # Set up a generic query mock that usually returns mock_app_model when getting app
        app_query_mock = MagicMock()
        app_query_mock.filter.return_value.first.return_value = mock_app_model
        app_query_mock.filter.return_value.filter.return_value.first.return_value = mock_app_model
        app_query_mock.where.return_value.first.return_value = mock_app_model
        app_query_mock.where.return_value.where.return_value.first.return_value = mock_app_model

        data_query_mock = MagicMock()

        def query_side_effect(*args, **kwargs):
            if args and hasattr(args[0], "__name__") and args[0].__name__ == "App":
                return app_query_mock
            return data_query_mock

        mock_db.session.query.side_effect = query_side_effect
        mock_db.data_query = data_query_mock

        # Let the caller override the stat db query logic
        proxy_mock = LocalProxy(lambda: mock_account)

        query_string = "&".join([f"{k}={v}" for k, v in (qs or {}).items()])
        full_path = f"{route_path}?{query_string}" if qs else route_path

        with (
            patch("libs.login.current_user", proxy_mock),
            patch("flask_login.current_user", proxy_mock),
            patch("controllers.console.app.message.attach_message_extra_contents", return_value=None),
        ):
            with test_app.test_request_context(full_path, method=method, json=payload):
                request.view_args = {"app_id": "app_123"}

                if "suggested-questions" in route_path:
                    # simplistic extraction for message_id
                    parts = route_path.split("chat-messages/")
                    if len(parts) > 1:
                        request.view_args["message_id"] = parts[1].split("/")[0]
                elif "messages/" in route_path and "chat-messages" not in route_path:
                    parts = route_path.split("messages/")
                    if len(parts) > 1:
                        request.view_args["message_id"] = parts[1].split("/")[0]

                api_instance = endpoint_class()

                # Check if it has a dispatch_request or method
                if hasattr(api_instance, method.lower()):
                    yield api_instance, mock_db, request.view_args


class TestMessageValidators:
    def test_chat_messages_query_validators(self):
        # Test empty_to_none
        assert ChatMessagesQuery.empty_to_none("") is None
        assert ChatMessagesQuery.empty_to_none("val") == "val"

        # Test validate_uuid
        assert ChatMessagesQuery.validate_uuid(None) is None
        assert (
            ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000")
            == "123e4567-e89b-12d3-a456-426614174000"
        )

    def test_message_feedback_validators(self):
        assert (
            MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000")
            == "123e4567-e89b-12d3-a456-426614174000"
        )

    def test_feedback_export_validators(self):
        assert FeedbackExportQuery.parse_bool(None) is None
        assert FeedbackExportQuery.parse_bool(True) is True
        assert FeedbackExportQuery.parse_bool("1") is True
        assert FeedbackExportQuery.parse_bool("0") is False
        assert FeedbackExportQuery.parse_bool("off") is False

        with pytest.raises(ValueError):
            FeedbackExportQuery.parse_bool("invalid")


class TestMessageEndpoints:
    def test_chat_message_list_not_found(self, app, mock_account, mock_app_model):
        with setup_test_context(
            app,
            ChatMessageListApi,
            "/apps/app_123/chat-messages",
            "GET",
            mock_account,
            mock_app_model,
            qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000"},
        ) as (api, mock_db, v_args):
            mock_db.session.scalar.return_value = None

            with pytest.raises(NotFound):
                api.get(**v_args)

    def test_chat_message_list_success(self, app, mock_account, mock_app_model):
        with setup_test_context(
            app,
            ChatMessageListApi,
            "/apps/app_123/chat-messages",
            "GET",
            mock_account,
            mock_app_model,
            qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000", "limit": 1},
        ) as (api, mock_db, v_args):
            mock_conv = MagicMock()
            mock_conv.id = "123e4567-e89b-12d3-a456-426614174000"
            mock_msg = MagicMock()
            mock_msg.id = "msg_123"
            mock_msg.feedbacks = []
            mock_msg.annotation = None
            mock_msg.annotation_hit_history = None
            mock_msg.agent_thoughts = []
            mock_msg.message_files = []
            mock_msg.extra_contents = []
            mock_msg.message = {}
            mock_msg.message_metadata_dict = {}

            # scalar() is called twice: first for conversation lookup, second for has_more check
            mock_db.session.scalar.side_effect = [mock_conv, False]
            scalars_result = MagicMock()
            scalars_result.all.return_value = [mock_msg]
            mock_db.session.scalars.return_value = scalars_result

            resp = api.get(**v_args)
            assert resp["limit"] == 1
            assert resp["has_more"] is False
            assert len(resp["data"]) == 1

    def test_message_feedback_not_found(self, app, mock_account, mock_app_model):
        with setup_test_context(
            app,
            MessageFeedbackApi,
            "/apps/app_123/feedbacks",
            "POST",
            mock_account,
            mock_app_model,
            payload={"message_id": "123e4567-e89b-12d3-a456-426614174000"},
        ) as (api, mock_db, v_args):
            mock_db.session.scalar.return_value = None

            with pytest.raises(NotFound):
                api.post(**v_args)

    def test_message_feedback_success(self, app, mock_account, mock_app_model):
        payload = {"message_id": "123e4567-e89b-12d3-a456-426614174000", "rating": "like"}
        with setup_test_context(
            app, MessageFeedbackApi, "/apps/app_123/feedbacks", "POST", mock_account, mock_app_model, payload=payload
        ) as (api, mock_db, v_args):
            mock_msg = MagicMock()
            mock_msg.admin_feedback = None
            mock_db.session.scalar.return_value = mock_msg

            resp = api.post(**v_args)
            assert resp == {"result": "success"}

    def test_message_annotation_count(self, app, mock_account, mock_app_model):
        with setup_test_context(
            app, MessageAnnotationCountApi, "/apps/app_123/annotations/count", "GET", mock_account, mock_app_model
        ) as (api, mock_db, v_args):
            mock_db.session.scalar.return_value = 5

            resp = api.get(**v_args)
            assert resp == {"count": 5}

    @patch("controllers.console.app.message.MessageService")
    def test_message_suggested_questions_success(self, mock_msg_srv, app, mock_account, mock_app_model):
        mock_msg_srv.get_suggested_questions_after_answer.return_value = ["q1", "q2"]

        with setup_test_context(
            app,
            MessageSuggestedQuestionApi,
            "/apps/app_123/chat-messages/msg_123/suggested-questions",
            "GET",
            mock_account,
            mock_app_model,
        ) as (api, mock_db, v_args):
            resp = api.get(**v_args)
            assert resp == {"data": ["q1", "q2"]}

    @pytest.mark.parametrize(
        ("exc", "expected_exc"),
        [
            (MessageNotExistsError, NotFound),
            (ConversationNotExistsError, NotFound),
            (ProviderTokenNotInitError, ProviderNotInitializeError),
            (QuotaExceededError, ProviderQuotaExceededError),
            (ModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError),
            (SuggestedQuestionsAfterAnswerDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError),
            (Exception, InternalServerError),
        ],
    )
    @patch("controllers.console.app.message.MessageService")
    def test_message_suggested_questions_errors(
        self, mock_msg_srv, exc, expected_exc, app, mock_account, mock_app_model
    ):
        mock_msg_srv.get_suggested_questions_after_answer.side_effect = exc()

        with setup_test_context(
            app,
            MessageSuggestedQuestionApi,
            "/apps/app_123/chat-messages/msg_123/suggested-questions",
            "GET",
            mock_account,
            mock_app_model,
        ) as (api, mock_db, v_args):
            with pytest.raises(expected_exc):
                api.get(**v_args)

    @patch("services.feedback_service.FeedbackService.export_feedbacks")
    def test_message_feedback_export_success(self, mock_export, app, mock_account, mock_app_model):
        mock_export.return_value = {"exported": True}

        with setup_test_context(
            app, MessageFeedbackExportApi, "/apps/app_123/feedbacks/export", "GET", mock_account, mock_app_model
        ) as (api, mock_db, v_args):
            resp = api.get(**v_args)
            assert resp == {"exported": True}

    def test_message_api_get_success(self, app, mock_account, mock_app_model):
        with setup_test_context(
            app, MessageApi, "/apps/app_123/messages/msg_123", "GET", mock_account, mock_app_model
        ) as (api, mock_db, v_args):
            mock_msg = MagicMock()
            mock_msg.id = "msg_123"
            mock_msg.feedbacks = []
            mock_msg.annotation = None
            mock_msg.annotation_hit_history = None
            mock_msg.agent_thoughts = []
            mock_msg.message_files = []
            mock_msg.extra_contents = []
            mock_msg.message = {}
            mock_msg.message_metadata_dict = {}

            mock_db.session.scalar.return_value = mock_msg

            resp = api.get(**v_args)
            assert resp["id"] == "msg_123"
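
One detail of the harness above is easy to miss: flask_login.current_user is replaced with a werkzeug LocalProxy around the mock rather than with the mock itself, so every attribute access resolves through the proxy at call time, just as it would against the real request-bound object. The pattern in isolation:

from unittest.mock import MagicMock, patch

from werkzeug.local import LocalProxy

fake_user = MagicMock()
fake_user.id = "user_123"

# The proxy defers to the lambda on every access, mimicking flask_login's
# request-bound current_user.
proxy = LocalProxy(lambda: fake_user)
assert proxy.id == "user_123"

with patch("flask_login.current_user", proxy):
    import flask_login

    assert flask_login.current_user.id == "user_123"
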
@ -0,0 +1,275 @@
from decimal import Decimal
from unittest.mock import MagicMock, patch

import pytest
from flask import Flask, request
from werkzeug.local import LocalProxy

from controllers.console.app.statistic import (
    AverageResponseTimeStatistic,
    AverageSessionInteractionStatistic,
    DailyConversationStatistic,
    DailyMessageStatistic,
    DailyTerminalsStatistic,
    DailyTokenCostStatistic,
    TokensPerSecondStatistic,
    UserSatisfactionRateStatistic,
)
from models import App, AppMode


@pytest.fixture
def app():
    flask_app = Flask(__name__)
    flask_app.config["TESTING"] = True
    return flask_app


@pytest.fixture
def mock_account():
    from models.account import Account, AccountStatus

    account = MagicMock(spec=Account)
    account.id = "user_123"
    account.timezone = "UTC"
    account.status = AccountStatus.ACTIVE
    account.is_admin_or_owner = True
    account.current_tenant.current_role = "owner"
    account.has_edit_permission = True
    return account


@pytest.fixture
def mock_app_model():
    app_model = MagicMock(spec=App)
    app_model.id = "app_123"
    app_model.mode = AppMode.CHAT
    app_model.tenant_id = "tenant_123"
    return app_model


@pytest.fixture(autouse=True)
def mock_csrf():
    with patch("libs.login.check_csrf_token") as mock:
        yield mock


def setup_test_context(
    test_app, endpoint_class, route_path, mock_account, mock_app_model, mock_rs, mock_parse_ret=(None, None)
):
    with (
        patch("controllers.console.app.statistic.db") as mock_db_stat,
        patch("controllers.console.app.wraps.db") as mock_db_wraps,
        patch("controllers.console.wraps.db", mock_db_wraps),
        patch(
            "controllers.console.app.statistic.current_account_with_tenant", return_value=(mock_account, "tenant_123")
        ),
        patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
        patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
    ):
        mock_conn = MagicMock()
        mock_conn.execute.return_value = mock_rs

        mock_begin = MagicMock()
        mock_begin.__enter__.return_value = mock_conn
        mock_db_stat.engine.begin.return_value = mock_begin

        mock_query = MagicMock()
        mock_query.filter.return_value.first.return_value = mock_app_model
        mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model
        mock_query.where.return_value.first.return_value = mock_app_model
        mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
        mock_db_wraps.session.query.return_value = mock_query

        proxy_mock = LocalProxy(lambda: mock_account)

        with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
            with test_app.test_request_context(route_path, method="GET"):
                request.view_args = {"app_id": "app_123"}
                api_instance = endpoint_class()
                response = api_instance.get(app_id="app_123")
                return response


class TestStatisticEndpoints:
    def test_daily_message_statistic(self, app, mock_account, mock_app_model):
        mock_row = MagicMock()
        mock_row.date = "2023-01-01"
        mock_row.message_count = 10
        mock_row.interactions = Decimal(0)

        with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
            response = setup_test_context(
                app,
                DailyMessageStatistic,
                "/apps/app_123/statistics/daily-messages?start=2023-01-01 00:00&end=2023-01-02 00:00",
                mock_account,
                mock_app_model,
                [mock_row],
            )
        assert response.status_code == 200
        assert response.json["data"][0]["message_count"] == 10

    def test_daily_conversation_statistic(self, app, mock_account, mock_app_model):
        mock_row = MagicMock()
        mock_row.date = "2023-01-01"
        mock_row.conversation_count = 5
        mock_row.interactions = Decimal(0)

        with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
            response = setup_test_context(
                app,
                DailyConversationStatistic,
                "/apps/app_123/statistics/daily-conversations",
                mock_account,
                mock_app_model,
                [mock_row],
            )
        assert response.status_code == 200
        assert response.json["data"][0]["conversation_count"] == 5

    def test_daily_terminals_statistic(self, app, mock_account, mock_app_model):
        mock_row = MagicMock()
        mock_row.date = "2023-01-01"
        mock_row.terminal_count = 2
        mock_row.interactions = Decimal(0)

        with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
            response = setup_test_context(
                app,
                DailyTerminalsStatistic,
                "/apps/app_123/statistics/daily-end-users",
                mock_account,
                mock_app_model,
                [mock_row],
            )
        assert response.status_code == 200
        assert response.json["data"][0]["terminal_count"] == 2

    def test_daily_token_cost_statistic(self, app, mock_account, mock_app_model):
        mock_row = MagicMock()
        mock_row.date = "2023-01-01"
        mock_row.token_count = 100
        mock_row.total_price = Decimal("0.02")
        mock_row.interactions = Decimal(0)

        with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
            response = setup_test_context(
                app,
                DailyTokenCostStatistic,
                "/apps/app_123/statistics/token-costs",
                mock_account,
                mock_app_model,
                [mock_row],
            )
        assert response.status_code == 200
        assert response.json["data"][0]["token_count"] == 100
        assert response.json["data"][0]["total_price"] == "0.02"

    def test_average_session_interaction_statistic(self, app, mock_account, mock_app_model):
        mock_row = MagicMock()
        mock_row.date = "2023-01-01"
        mock_row.interactions = Decimal("3.523")

        with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
            response = setup_test_context(
                app,
                AverageSessionInteractionStatistic,
                "/apps/app_123/statistics/average-session-interactions",
                mock_account,
                mock_app_model,
                [mock_row],
            )
        assert response.status_code == 200
        assert response.json["data"][0]["interactions"] == 3.52

    def test_user_satisfaction_rate_statistic(self, app, mock_account, mock_app_model):
        mock_row = MagicMock()
        mock_row.date = "2023-01-01"
        mock_row.message_count = 100
        mock_row.feedback_count = 10
        mock_row.interactions = Decimal(0)

        with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
            response = setup_test_context(
                app,
                UserSatisfactionRateStatistic,
                "/apps/app_123/statistics/user-satisfaction-rate",
                mock_account,
                mock_app_model,
                [mock_row],
            )
        assert response.status_code == 200
        assert response.json["data"][0]["rate"] == 100.0

    def test_average_response_time_statistic(self, app, mock_account, mock_app_model):
        mock_app_model.mode = AppMode.COMPLETION
        mock_row = MagicMock()
        mock_row.date = "2023-01-01"
        mock_row.latency = 1.234
        mock_row.interactions = Decimal(0)

        with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
            response = setup_test_context(
                app,
                AverageResponseTimeStatistic,
                "/apps/app_123/statistics/average-response-time",
                mock_account,
                mock_app_model,
                [mock_row],
            )
        assert response.status_code == 200
        assert response.json["data"][0]["latency"] == 1234.0

    def test_tokens_per_second_statistic(self, app, mock_account, mock_app_model):
        mock_row = MagicMock()
        mock_row.date = "2023-01-01"
        mock_row.tokens_per_second = 15.5
        mock_row.interactions = Decimal(0)

        with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)):
            response = setup_test_context(
                app,
                TokensPerSecondStatistic,
                "/apps/app_123/statistics/tokens-per-second",
                mock_account,
                mock_app_model,
                [mock_row],
            )
        assert response.status_code == 200
        assert response.json["data"][0]["tps"] == 15.5

    @patch("controllers.console.app.statistic.parse_time_range")
    def test_invalid_time_range(self, mock_parse, app, mock_account, mock_app_model):
        mock_parse.side_effect = ValueError("Invalid time")

        from werkzeug.exceptions import BadRequest

        with pytest.raises(BadRequest):
            setup_test_context(
                app,
                DailyMessageStatistic,
                "/apps/app_123/statistics/daily-messages?start=invalid&end=invalid",
                mock_account,
                mock_app_model,
                [],
            )

    @patch("controllers.console.app.statistic.parse_time_range")
    def test_time_range_params_passed(self, mock_parse, app, mock_account, mock_app_model):
        import datetime

        start = datetime.datetime.now()
        end = datetime.datetime.now()
        mock_parse.return_value = (start, end)

        response = setup_test_context(
            app,
            DailyMessageStatistic,
            "/apps/app_123/statistics/daily-messages?start=something&end=something",
            mock_account,
            mock_app_model,
            [],
        )
        assert response.status_code == 200
        mock_parse.assert_called_once()
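
Two unit conversions in these assertions are easy to miss: average session interactions are rounded to two decimal places (3.523 becomes 3.52), and latency is reported in milliseconds (1.234 s becomes 1234.0). A tiny illustration of the conversions the endpoints are expected to apply, inferred from the assertions rather than from the controller source:

from decimal import Decimal

interactions = Decimal("3.523")
latency_seconds = 1.234

assert float(round(interactions, 2)) == 3.52  # two-decimal rounding
assert latency_seconds * 1000 == 1234.0       # seconds to milliseconds
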
@ -129,6 +129,136 @@ def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch)
        handler(api, app_model=SimpleNamespace(id="app"))


def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
    workflow = SimpleNamespace(
        unique_hash="restored-hash",
        updated_at=None,
        created_at=datetime(2024, 1, 1),
    )
    user = SimpleNamespace(id="account-1")

    monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1"))
    monkeypatch.setattr(
        workflow_module,
        "WorkflowService",
        lambda: SimpleNamespace(restore_published_workflow_to_draft=lambda **_kwargs: workflow),
    )

    api = workflow_module.DraftWorkflowRestoreApi()
    handler = _unwrap(api.post)

    with app.test_request_context(
        "/apps/app/workflows/published-workflow/restore",
        method="POST",
    ):
        response = handler(
            api,
            app_model=SimpleNamespace(id="app", tenant_id="tenant-1"),
            workflow_id="published-workflow",
        )

    assert response["result"] == "success"
    assert response["hash"] == "restored-hash"


def test_restore_published_workflow_to_draft_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
    user = SimpleNamespace(id="account-1")

    monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1"))
    monkeypatch.setattr(
        workflow_module,
        "WorkflowService",
        lambda: SimpleNamespace(
            restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw(
                workflow_module.WorkflowNotFoundError("Workflow not found")
            )
        ),
    )

    api = workflow_module.DraftWorkflowRestoreApi()
    handler = _unwrap(api.post)

    with app.test_request_context(
        "/apps/app/workflows/published-workflow/restore",
        method="POST",
    ):
        with pytest.raises(NotFound):
            handler(
                api,
                app_model=SimpleNamespace(id="app", tenant_id="tenant-1"),
                workflow_id="published-workflow",
            )


def test_restore_published_workflow_to_draft_returns_400_for_draft_source(app, monkeypatch: pytest.MonkeyPatch) -> None:
    user = SimpleNamespace(id="account-1")

    monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1"))
    monkeypatch.setattr(
        workflow_module,
        "WorkflowService",
        lambda: SimpleNamespace(
            restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw(
                workflow_module.IsDraftWorkflowError(
                    "Cannot use draft workflow version. Workflow ID: draft-workflow. "
                    "Please use a published workflow version or leave workflow_id empty."
                )
            )
        ),
    )

    api = workflow_module.DraftWorkflowRestoreApi()
    handler = _unwrap(api.post)

    with app.test_request_context(
        "/apps/app/workflows/draft-workflow/restore",
        method="POST",
    ):
        with pytest.raises(HTTPException) as exc:
            handler(
                api,
                app_model=SimpleNamespace(id="app", tenant_id="tenant-1"),
                workflow_id="draft-workflow",
            )

    assert exc.value.code == 400
    assert exc.value.description == workflow_module.RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE


def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure(
    app, monkeypatch: pytest.MonkeyPatch
) -> None:
    user = SimpleNamespace(id="account-1")

    monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1"))
    monkeypatch.setattr(
        workflow_module,
        "WorkflowService",
        lambda: SimpleNamespace(
            restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw(
                ValueError("invalid workflow graph")
            )
        ),
    )

    api = workflow_module.DraftWorkflowRestoreApi()
    handler = _unwrap(api.post)

    with app.test_request_context(
        "/apps/app/workflows/published-workflow/restore",
        method="POST",
    ):
        with pytest.raises(HTTPException) as exc:
            handler(
                api,
                app_model=SimpleNamespace(id="app", tenant_id="tenant-1"),
                workflow_id="published-workflow",
            )

    assert exc.value.code == 400
    assert exc.value.description == "invalid workflow graph"


def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setattr(
        workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None)
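
These tests fake a raising service method with lambda **_kwargs: (_ for _ in ()).throw(exc). A lambda body cannot contain a raise statement, so an empty generator expression is created and its throw() method is used to raise the exception at call time. The same trick in isolation:

class BoomError(Exception):
    pass


# raise_boom() raises BoomError when called, even though a lambda
# cannot contain a raise statement directly.
raise_boom = lambda **_kwargs: (_ for _ in ()).throw(BoomError("boom"))

try:
    raise_boom()
except BoomError as exc:
    assert str(exc) == "boom"
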
@ -0,0 +1,313 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, request
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.app.error import DraftWorkflowNotExist
|
||||
from controllers.console.app.workflow_draft_variable import (
|
||||
ConversationVariableCollectionApi,
|
||||
EnvironmentVariableCollectionApi,
|
||||
NodeVariableCollectionApi,
|
||||
SystemVariableCollectionApi,
|
||||
VariableApi,
|
||||
VariableResetApi,
|
||||
WorkflowVariableCollectionApi,
|
||||
)
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from models import App, AppMode
|
||||
from models.enums import DraftVariableType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
flask_app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
account = MagicMock(spec=Account)
|
||||
account.id = "user_123"
|
||||
account.timezone = "UTC"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.is_admin_or_owner = True
|
||||
account.current_tenant.current_role = "owner"
|
||||
account.has_edit_permission = True
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model():
|
||||
app_model = MagicMock(spec=App)
|
||||
app_model.id = "app_123"
|
||||
app_model.mode = AppMode.WORKFLOW
|
||||
app_model.tenant_id = "tenant_123"
|
||||
return app_model
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_csrf():
|
||||
with patch("libs.login.check_csrf_token") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
def setup_test_context(test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None):
|
||||
with (
|
||||
patch("controllers.console.app.wraps.db") as mock_db_wraps,
|
||||
patch("controllers.console.wraps.db", mock_db_wraps),
|
||||
patch("controllers.console.app.workflow_draft_variable.db"),
|
||||
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
):
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = mock_app_model
|
||||
mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model
|
||||
mock_query.where.return_value.first.return_value = mock_app_model
|
||||
mock_query.where.return_value.where.return_value.first.return_value = mock_app_model
|
||||
mock_db_wraps.session.query.return_value = mock_query
|
||||
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
|
||||
with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
|
||||
with test_app.test_request_context(route_path, method=method, json=payload):
|
||||
request.view_args = {"app_id": "app_123"}
|
||||
# extract node_id or variable_id from path manually since view_args overrides
|
||||
if "nodes/" in route_path:
|
||||
request.view_args["node_id"] = route_path.split("nodes/")[1].split("/")[0]
|
||||
if "variables/" in route_path:
|
||||
# simplistic extraction
|
||||
parts = route_path.split("variables/")
|
||||
if len(parts) > 1 and parts[1] and parts[1] != "reset":
|
||||
request.view_args["variable_id"] = parts[1].split("/")[0]
|
||||
|
||||
api_instance = endpoint_class()
|
||||
# we just call dispatch_request to avoid manual argument passing
|
||||
if hasattr(api_instance, method.lower()):
|
||||
func = getattr(api_instance, method.lower())
|
||||
return func(**request.view_args)
|
||||
|
||||
|
||||


class TestWorkflowDraftVariableEndpoints:
    @staticmethod
    def _mock_workflow_variable(variable_type: DraftVariableType = DraftVariableType.NODE) -> MagicMock:
        class DummyValueType:
            def exposed_type(self):
                return DraftVariableType.NODE

        mock_var = MagicMock()
        mock_var.app_id = "app_123"
        mock_var.id = "var_123"
        mock_var.name = "test_var"
        mock_var.description = ""
        mock_var.get_variable_type.return_value = variable_type
        mock_var.get_selector.return_value = []
        mock_var.value_type = DummyValueType()
        mock_var.edited = False
        mock_var.visible = True
        mock_var.file_id = None
        mock_var.variable_file = None
        mock_var.is_truncated.return_value = False
        mock_var.get_value.return_value.model_copy.return_value.value = "test_value"
        return mock_var

    @patch("controllers.console.app.workflow_draft_variable.WorkflowService")
    @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
    def test_workflow_variable_collection_get_success(
        self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model
    ):
        mock_wf_srv.return_value.is_workflow_exist.return_value = True
        from services.workflow_draft_variable_service import WorkflowDraftVariableList

        mock_draft_srv.return_value.list_variables_without_values.return_value = WorkflowDraftVariableList(
            variables=[], total=0
        )

        resp = setup_test_context(
            app,
            WorkflowVariableCollectionApi,
            "/apps/app_123/workflows/draft/variables?page=1&limit=20",
            "GET",
            mock_account,
            mock_app_model,
        )
        assert resp == {"items": [], "total": 0}
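
    # Note: stacked @patch decorators apply bottom-up, so the decorator closest
    # to the function supplies the first mock argument. In the test above:
    #
    #   @patch("...WorkflowService")               # injected second -> mock_wf_srv
    #   @patch("...WorkflowDraftVariableService")  # injected first  -> mock_draft_srv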
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
|
||||
def test_workflow_variable_collection_get_not_exist(self, mock_wf_srv, app, mock_account, mock_app_model):
|
||||
mock_wf_srv.return_value.is_workflow_exist.return_value = False
|
||||
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
setup_test_context(
|
||||
app,
|
||||
WorkflowVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_workflow_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
WorkflowVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/variables",
|
||||
"DELETE",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_node_variable_collection_get_success(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
mock_draft_srv.return_value.list_node_variables.return_value = WorkflowDraftVariableList(variables=[])
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
NodeVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/nodes/node_123/variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp == {"items": []}
|
||||
|
||||
def test_node_variable_collection_get_invalid_node_id(self, app, mock_account, mock_app_model):
|
||||
with pytest.raises(InvalidArgumentError):
|
||||
setup_test_context(
|
||||
app,
|
||||
NodeVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/nodes/sys/variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_node_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
NodeVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/nodes/node_123/variables",
|
||||
"DELETE",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_variable_api_get_success(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
|
||||
|
||||
resp = setup_test_context(
|
||||
app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model
|
||||
)
|
||||
assert resp["id"] == "var_123"
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_variable_api_get_not_found(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
mock_draft_srv.return_value.get_variable.return_value = None
|
||||
|
||||
with pytest.raises(NotFoundError):
|
||||
setup_test_context(
|
||||
app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model
|
||||
)
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_variable_api_patch_success(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
VariableApi,
|
||||
"/apps/app_123/workflows/draft/variables/var_123",
|
||||
"PATCH",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
payload={"name": "new_name"},
|
||||
)
|
||||
assert resp["id"] == "var_123"
|
||||
mock_draft_srv.return_value.update_variable.assert_called_once()
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_variable_api_delete_success(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
|
||||
|
||||
resp = setup_test_context(
|
||||
app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "DELETE", mock_account, mock_app_model
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
mock_draft_srv.return_value.delete_variable.assert_called_once()
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_variable_reset_api_put_success(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model):
|
||||
mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock()
|
||||
mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable()
|
||||
mock_draft_srv.return_value.reset_variable.return_value = None # means no content
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
VariableResetApi,
|
||||
"/apps/app_123/workflows/draft/variables/var_123/reset",
|
||||
"PUT",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_conversation_variable_collection_get(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model):
|
||||
mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock()
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
mock_draft_srv.return_value.list_conversation_variables.return_value = WorkflowDraftVariableList(variables=[])
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
ConversationVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/conversation-variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp == {"items": []}
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService")
|
||||
def test_system_variable_collection_get(self, mock_draft_srv, app, mock_account, mock_app_model):
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
mock_draft_srv.return_value.list_system_variables.return_value = WorkflowDraftVariableList(variables=[])
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
SystemVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/system-variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp == {"items": []}
|
||||
|
||||
@patch("controllers.console.app.workflow_draft_variable.WorkflowService")
|
||||
def test_environment_variable_collection_get(self, mock_wf_srv, app, mock_account, mock_app_model):
|
||||
mock_wf = MagicMock()
|
||||
mock_wf.environment_variables = []
|
||||
mock_wf_srv.return_value.get_draft_workflow.return_value = mock_wf
|
||||
|
||||
resp = setup_test_context(
|
||||
app,
|
||||
EnvironmentVariableCollectionApi,
|
||||
"/apps/app_123/workflows/draft/environment-variables",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
)
|
||||
assert resp == {"items": []}
|
||||
|
|
@ -0,0 +1,209 @@
from unittest.mock import MagicMock, patch

import pytest
from flask import Flask

from controllers.console.auth.data_source_bearer_auth import (
    ApiKeyAuthDataSource,
    ApiKeyAuthDataSourceBinding,
    ApiKeyAuthDataSourceBindingDelete,
)
from controllers.console.auth.error import ApiKeyAuthFailedError


class TestApiKeyAuthDataSource:
    @pytest.fixture
    def app(self):
        app = Flask(__name__)
        app.config["TESTING"] = True
        app.config["WTF_CSRF_ENABLED"] = False
        return app

    @patch("libs.login.check_csrf_token")
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list")
    def test_get_api_key_auth_data_source(self, mock_get_list, mock_db, mock_csrf, app):
        from models.account import Account, AccountStatus

        mock_account = MagicMock(spec=Account)
        mock_account.id = "user_123"
        mock_account.status = AccountStatus.ACTIVE
        mock_account.is_admin_or_owner = True
        mock_account.current_tenant.current_role = "owner"

        mock_binding = MagicMock()
        mock_binding.id = "bind_123"
        mock_binding.category = "api_key"
        mock_binding.provider = "custom_provider"
        mock_binding.disabled = False
        mock_binding.created_at.timestamp.return_value = 1620000000
        mock_binding.updated_at.timestamp.return_value = 1620000001

        mock_get_list.return_value = [mock_binding]

        with (
            patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
            patch(
                "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
                return_value=(mock_account, "tenant_123"),
            ),
        ):
            with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
                proxy_mock = MagicMock()
                proxy_mock._get_current_object.return_value = mock_account
                with patch("libs.login.current_user", proxy_mock):
                    api_instance = ApiKeyAuthDataSource()
                    response = api_instance.get()

                    assert "sources" in response
                    assert len(response["sources"]) == 1
                    assert response["sources"][0]["provider"] == "custom_provider"
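
    # Note: unlike the LocalProxy approach used elsewhere, these tests stand in a
    # plain MagicMock for the login proxy and stub `_get_current_object`, the hook
    # flask-login style proxies expose for retrieving the wrapped account:
    #
    #   proxy_mock = MagicMock()
    #   proxy_mock._get_current_object.return_value = mock_account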

    @patch("libs.login.check_csrf_token")
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list")
    def test_get_api_key_auth_data_source_empty(self, mock_get_list, mock_db, mock_csrf, app):
        from models.account import Account, AccountStatus

        mock_account = MagicMock(spec=Account)
        mock_account.id = "user_123"
        mock_account.status = AccountStatus.ACTIVE
        mock_account.is_admin_or_owner = True
        mock_account.current_tenant.current_role = "owner"

        mock_get_list.return_value = None

        with (
            patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
            patch(
                "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
                return_value=(mock_account, "tenant_123"),
            ),
        ):
            with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"):
                proxy_mock = MagicMock()
                proxy_mock._get_current_object.return_value = mock_account
                with patch("libs.login.current_user", proxy_mock):
                    api_instance = ApiKeyAuthDataSource()
                    response = api_instance.get()

                    assert "sources" in response
                    assert len(response["sources"]) == 0


class TestApiKeyAuthDataSourceBinding:
    @pytest.fixture
    def app(self):
        app = Flask(__name__)
        app.config["TESTING"] = True
        app.config["WTF_CSRF_ENABLED"] = False
        return app

    @patch("libs.login.check_csrf_token")
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth")
    @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args")
    def test_create_binding_successful(self, mock_validate, mock_create, mock_db, mock_csrf, app):
        from models.account import Account, AccountStatus

        mock_account = MagicMock(spec=Account)
        mock_account.id = "user_123"
        mock_account.status = AccountStatus.ACTIVE
        mock_account.is_admin_or_owner = True
        mock_account.current_tenant.current_role = "owner"

        with (
            patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
            patch(
                "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
                return_value=(mock_account, "tenant_123"),
            ),
        ):
            with app.test_request_context(
                "/console/api/api-key-auth/data-source/binding",
                method="POST",
                json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
            ):
                proxy_mock = MagicMock()
                proxy_mock._get_current_object.return_value = mock_account
                with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
                    api_instance = ApiKeyAuthDataSourceBinding()
                    response = api_instance.post()

                    assert response[0]["result"] == "success"
                    assert response[1] == 200
                    mock_validate.assert_called_once()
                    mock_create.assert_called_once()

    @patch("libs.login.check_csrf_token")
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth")
    @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args")
    def test_create_binding_failure(self, mock_validate, mock_create, mock_db, mock_csrf, app):
        from models.account import Account, AccountStatus

        mock_account = MagicMock(spec=Account)
        mock_account.id = "user_123"
        mock_account.status = AccountStatus.ACTIVE
        mock_account.is_admin_or_owner = True
        mock_account.current_tenant.current_role = "owner"

        mock_create.side_effect = ValueError("Invalid structure")

        with (
            patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
            patch(
                "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
                return_value=(mock_account, "tenant_123"),
            ),
        ):
            with app.test_request_context(
                "/console/api/api-key-auth/data-source/binding",
                method="POST",
                json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
            ):
                proxy_mock = MagicMock()
                proxy_mock._get_current_object.return_value = mock_account
                with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
                    api_instance = ApiKeyAuthDataSourceBinding()
                    with pytest.raises(ApiKeyAuthFailedError, match="Invalid structure"):
                        api_instance.post()


class TestApiKeyAuthDataSourceBindingDelete:
    @pytest.fixture
    def app(self):
        app = Flask(__name__)
        app.config["TESTING"] = True
        app.config["WTF_CSRF_ENABLED"] = False
        return app

    @patch("libs.login.check_csrf_token")
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth")
    def test_delete_binding_successful(self, mock_delete, mock_db, mock_csrf, app):
        from models.account import Account, AccountStatus

        mock_account = MagicMock(spec=Account)
        mock_account.id = "user_123"
        mock_account.status = AccountStatus.ACTIVE
        mock_account.is_admin_or_owner = True
        mock_account.current_tenant.current_role = "owner"

        with (
            patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
            patch(
                "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant",
                return_value=(mock_account, "tenant_123"),
            ),
        ):
            with app.test_request_context("/console/api/api-key-auth/data-source/binding_123", method="DELETE"):
                proxy_mock = MagicMock()
                proxy_mock._get_current_object.return_value = mock_account
                with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock):
                    api_instance = ApiKeyAuthDataSourceBindingDelete()
                    response = api_instance.delete("binding_123")

                    assert response[0]["result"] == "success"
                    assert response[1] == 204
                    mock_delete.assert_called_once_with("tenant_123", "binding_123")

@ -0,0 +1,192 @@
from unittest.mock import MagicMock, patch

import pytest
from flask import Flask
from werkzeug.local import LocalProxy

from controllers.console.auth.data_source_oauth import (
    OAuthDataSource,
    OAuthDataSourceBinding,
    OAuthDataSourceCallback,
    OAuthDataSourceSync,
)


class TestOAuthDataSource:
    @pytest.fixture
    def app(self):
        app = Flask(__name__)
        app.config["TESTING"] = True
        return app

    @patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
    @patch("flask_login.current_user")
    @patch("libs.login.current_user")
    @patch("libs.login.check_csrf_token")
    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None)
    def test_get_oauth_url_successful(
        self, mock_db, mock_csrf, mock_libs_user, mock_flask_user, mock_get_providers, app
    ):
        mock_oauth_provider = MagicMock()
        mock_oauth_provider.get_authorization_url.return_value = "http://oauth.provider/auth"
        mock_get_providers.return_value = {"notion": mock_oauth_provider}

        from models.account import Account, AccountStatus

        mock_account = MagicMock(spec=Account)
        mock_account.id = "user_123"
        mock_account.status = AccountStatus.ACTIVE
        mock_account.is_admin_or_owner = True
        mock_account.current_tenant.current_role = "owner"
        mock_libs_user.return_value = mock_account
        mock_flask_user.return_value = mock_account

        # also patch current_account_with_tenant
        with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
            with app.test_request_context("/console/api/oauth/data-source/notion", method="GET"):
                proxy_mock = LocalProxy(lambda: mock_account)
                with patch("libs.login.current_user", proxy_mock):
                    api_instance = OAuthDataSource()
                    response = api_instance.get("notion")

                    assert response[0]["data"] == "http://oauth.provider/auth"
                    assert response[1] == 200
                    mock_oauth_provider.get_authorization_url.assert_called_once()

    @patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
    @patch("flask_login.current_user")
    @patch("libs.login.check_csrf_token")
    @patch("controllers.console.wraps.db")
    def test_get_oauth_url_invalid_provider(self, mock_db, mock_csrf, mock_flask_user, mock_get_providers, app):
        mock_get_providers.return_value = {"notion": MagicMock()}

        from models.account import Account, AccountStatus

        mock_account = MagicMock(spec=Account)
        mock_account.id = "user_123"
        mock_account.status = AccountStatus.ACTIVE
        mock_account.is_admin_or_owner = True
        mock_account.current_tenant.current_role = "owner"

        with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
            with app.test_request_context("/console/api/oauth/data-source/unknown_provider", method="GET"):
                proxy_mock = LocalProxy(lambda: mock_account)
                with patch("libs.login.current_user", proxy_mock):
                    api_instance = OAuthDataSource()
                    response = api_instance.get("unknown_provider")

                    assert response[0]["error"] == "Invalid provider"
                    assert response[1] == 400


class TestOAuthDataSourceCallback:
    @pytest.fixture
    def app(self):
        app = Flask(__name__)
        app.config["TESTING"] = True
        return app

    @patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
    def test_oauth_callback_successful(self, mock_get_providers, app):
        provider_mock = MagicMock()
        mock_get_providers.return_value = {"notion": provider_mock}

        with app.test_request_context("/console/api/oauth/data-source/notion/callback?code=mock_code", method="GET"):
            api_instance = OAuthDataSourceCallback()
            response = api_instance.get("notion")

            assert response.status_code == 302
            assert "code=mock_code" in response.location
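
    # Note: as these assertions suggest, OAuthDataSourceCallback does not exchange
    # the code itself; it issues a 302 redirect back to the console and forwards
    # either the `code` or an `error` message in the query string.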

    @patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
    def test_oauth_callback_missing_code(self, mock_get_providers, app):
        provider_mock = MagicMock()
        mock_get_providers.return_value = {"notion": provider_mock}

        with app.test_request_context("/console/api/oauth/data-source/notion/callback", method="GET"):
            api_instance = OAuthDataSourceCallback()
            response = api_instance.get("notion")

            assert response.status_code == 302
            assert "error=Access denied" in response.location

    @patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
    def test_oauth_callback_invalid_provider(self, mock_get_providers, app):
        mock_get_providers.return_value = {"notion": MagicMock()}

        with app.test_request_context("/console/api/oauth/data-source/invalid/callback?code=mock_code", method="GET"):
            api_instance = OAuthDataSourceCallback()
            response = api_instance.get("invalid")

            assert response[0]["error"] == "Invalid provider"
            assert response[1] == 400


class TestOAuthDataSourceBinding:
    @pytest.fixture
    def app(self):
        app = Flask(__name__)
        app.config["TESTING"] = True
        return app

    @patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
    def test_get_binding_successful(self, mock_get_providers, app):
        mock_provider = MagicMock()
        mock_provider.get_access_token.return_value = None
        mock_get_providers.return_value = {"notion": mock_provider}

        with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=auth_code_123", method="GET"):
            api_instance = OAuthDataSourceBinding()
            response = api_instance.get("notion")

            assert response[0]["result"] == "success"
            assert response[1] == 200
            mock_provider.get_access_token.assert_called_once_with("auth_code_123")

    @patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
    def test_get_binding_missing_code(self, mock_get_providers, app):
        mock_get_providers.return_value = {"notion": MagicMock()}

        with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=", method="GET"):
            api_instance = OAuthDataSourceBinding()
            response = api_instance.get("notion")

            assert response[0]["error"] == "Invalid code"
            assert response[1] == 400


class TestOAuthDataSourceSync:
    @pytest.fixture
    def app(self):
        app = Flask(__name__)
        app.config["TESTING"] = True
        return app

    @patch("controllers.console.auth.data_source_oauth.get_oauth_providers")
    @patch("libs.login.check_csrf_token")
    @patch("controllers.console.wraps.db")
    def test_sync_successful(self, mock_db, mock_csrf, mock_get_providers, app):
        mock_provider = MagicMock()
        mock_provider.sync_data_source.return_value = None
        mock_get_providers.return_value = {"notion": mock_provider}

        from models.account import Account, AccountStatus

        mock_account = MagicMock(spec=Account)
        mock_account.id = "user_123"
        mock_account.status = AccountStatus.ACTIVE
        mock_account.is_admin_or_owner = True
        mock_account.current_tenant.current_role = "owner"

        with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())):
            with app.test_request_context("/console/api/oauth/data-source/notion/binding_123/sync", method="GET"):
                proxy_mock = LocalProxy(lambda: mock_account)
                with patch("libs.login.current_user", proxy_mock):
                    api_instance = OAuthDataSourceSync()
                    # The route pattern uses <uuid:binding_id>, so we just pass a string for unit testing
                    response = api_instance.get("notion", "binding_123")

                    assert response[0]["result"] == "success"
                    assert response[1] == 200
                    mock_provider.sync_data_source.assert_called_once_with("binding_123")

@ -0,0 +1,417 @@
from unittest.mock import MagicMock, patch

import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, NotFound

from controllers.console.auth.oauth_server import (
    OAuthServerAppApi,
    OAuthServerUserAccountApi,
    OAuthServerUserAuthorizeApi,
    OAuthServerUserTokenApi,
)


class TestOAuthServerAppApi:
    @pytest.fixture
    def app(self):
        app = Flask(__name__)
        app.config["TESTING"] = True
        return app

    @pytest.fixture
    def mock_oauth_provider_app(self):
        from models.model import OAuthProviderApp

        oauth_app = MagicMock(spec=OAuthProviderApp)
        oauth_app.client_id = "test_client_id"
        oauth_app.redirect_uris = ["http://localhost/callback"]
        oauth_app.app_icon = "icon_url"
        oauth_app.app_label = "Test App"
        oauth_app.scope = "read,write"
        return oauth_app

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_successful_post(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app

        with app.test_request_context(
            "/oauth/provider",
            method="POST",
            json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"},
        ):
            api_instance = OAuthServerAppApi()
            response = api_instance.post()

            assert response["app_icon"] == "icon_url"
            assert response["app_label"] == "Test App"
            assert response["scope"] == "read,write"

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app

        with app.test_request_context(
            "/oauth/provider",
            method="POST",
            json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"},
        ):
            api_instance = OAuthServerAppApi()
            with pytest.raises(BadRequest, match="redirect_uri is invalid"):
                api_instance.post()

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_invalid_client_id(self, mock_get_app, mock_db, app):
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = None

        with app.test_request_context(
            "/oauth/provider",
            method="POST",
            json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"},
        ):
            api_instance = OAuthServerAppApi()
            with pytest.raises(NotFound, match="client_id is invalid"):
                api_instance.post()


class TestOAuthServerUserAuthorizeApi:
    @pytest.fixture
    def app(self):
        app = Flask(__name__)
        app.config["TESTING"] = True
        return app

    @pytest.fixture
    def mock_oauth_provider_app(self):
        oauth_app = MagicMock()
        oauth_app.client_id = "test_client_id"
        return oauth_app

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    @patch("controllers.console.auth.oauth_server.current_account_with_tenant")
    @patch("controllers.console.wraps.current_account_with_tenant")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code")
    @patch("libs.login.check_csrf_token")
    def test_successful_authorize(
        self, mock_csrf, mock_sign, mock_wrap_current, mock_current, mock_get_app, mock_db, app, mock_oauth_provider_app
    ):
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app

        mock_account = MagicMock()
        mock_account.id = "user_123"
        from models.account import AccountStatus

        mock_account.status = AccountStatus.ACTIVE

        mock_current.return_value = (mock_account, MagicMock())
        mock_wrap_current.return_value = (mock_account, MagicMock())

        mock_sign.return_value = "auth_code_123"

        with app.test_request_context("/oauth/provider/authorize", method="POST", json={"client_id": "test_client_id"}):
            with patch("libs.login.current_user", mock_account):
                api_instance = OAuthServerUserAuthorizeApi()
                response = api_instance.post()

                assert response["code"] == "auth_code_123"
                mock_sign.assert_called_once_with("test_client_id", "user_123")


class TestOAuthServerUserTokenApi:
    @pytest.fixture
    def app(self):
        app = Flask(__name__)
        app.config["TESTING"] = True
        return app

    @pytest.fixture
    def mock_oauth_provider_app(self):
        from models.model import OAuthProviderApp

        oauth_app = MagicMock(spec=OAuthProviderApp)
        oauth_app.client_id = "test_client_id"
        oauth_app.client_secret = "test_secret"
        oauth_app.redirect_uris = ["http://localhost/callback"]
        return oauth_app

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token")
    def test_authorization_code_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app):
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app
        mock_sign.return_value = ("access_123", "refresh_123")

        with app.test_request_context(
            "/oauth/provider/token",
            method="POST",
            json={
                "client_id": "test_client_id",
                "grant_type": "authorization_code",
                "code": "auth_code",
                "client_secret": "test_secret",
                "redirect_uri": "http://localhost/callback",
            },
        ):
            api_instance = OAuthServerUserTokenApi()
            response = api_instance.post()

            assert response["access_token"] == "access_123"
            assert response["refresh_token"] == "refresh_123"
            assert response["token_type"] == "Bearer"
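
    # Note: the happy path above exercises a standard OAuth 2.0 token exchange.
    # A hypothetical request/response pair consistent with the assertions:
    #
    #   POST /oauth/provider/token
    #     {"grant_type": "authorization_code", "code": "auth_code", ...}
    #   -> {"access_token": "access_123", "refresh_token": "refresh_123",
    #       "token_type": "Bearer"}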

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_authorization_code_grant_missing_code(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app

        with app.test_request_context(
            "/oauth/provider/token",
            method="POST",
            json={
                "client_id": "test_client_id",
                "grant_type": "authorization_code",
                "client_secret": "test_secret",
                "redirect_uri": "http://localhost/callback",
            },
        ):
            api_instance = OAuthServerUserTokenApi()
            with pytest.raises(BadRequest, match="code is required"):
                api_instance.post()

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_authorization_code_grant_invalid_secret(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app

        with app.test_request_context(
            "/oauth/provider/token",
            method="POST",
            json={
                "client_id": "test_client_id",
                "grant_type": "authorization_code",
                "code": "auth_code",
                "client_secret": "invalid_secret",
                "redirect_uri": "http://localhost/callback",
            },
        ):
            api_instance = OAuthServerUserTokenApi()
            with pytest.raises(BadRequest, match="client_secret is invalid"):
                api_instance.post()

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_authorization_code_grant_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app

        with app.test_request_context(
            "/oauth/provider/token",
            method="POST",
            json={
                "client_id": "test_client_id",
                "grant_type": "authorization_code",
                "code": "auth_code",
                "client_secret": "test_secret",
                "redirect_uri": "http://invalid/callback",
            },
        ):
            api_instance = OAuthServerUserTokenApi()
            with pytest.raises(BadRequest, match="redirect_uri is invalid"):
                api_instance.post()

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token")
    def test_refresh_token_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app):
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app
        mock_sign.return_value = ("new_access", "new_refresh")

        with app.test_request_context(
            "/oauth/provider/token",
            method="POST",
            json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"},
        ):
            api_instance = OAuthServerUserTokenApi()
            response = api_instance.post()

            assert response["access_token"] == "new_access"
            assert response["refresh_token"] == "new_refresh"

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_refresh_token_grant_missing_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app

        with app.test_request_context(
            "/oauth/provider/token",
            method="POST",
            json={
                "client_id": "test_client_id",
                "grant_type": "refresh_token",
            },
        ):
            api_instance = OAuthServerUserTokenApi()
            with pytest.raises(BadRequest, match="refresh_token is required"):
                api_instance.post()

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_invalid_grant_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app

        with app.test_request_context(
            "/oauth/provider/token",
            method="POST",
            json={
                "client_id": "test_client_id",
                "grant_type": "invalid_grant",
            },
        ):
            api_instance = OAuthServerUserTokenApi()
            with pytest.raises(BadRequest, match="invalid grant_type"):
                api_instance.post()


class TestOAuthServerUserAccountApi:
    @pytest.fixture
    def app(self):
        app = Flask(__name__)
        app.config["TESTING"] = True
        return app

    @pytest.fixture
    def mock_oauth_provider_app(self):
        from models.model import OAuthProviderApp

        oauth_app = MagicMock(spec=OAuthProviderApp)
        oauth_app.client_id = "test_client_id"
        return oauth_app

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token")
    def test_successful_account_retrieval(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app):
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app

        mock_account = MagicMock()
        mock_account.name = "Test User"
        mock_account.email = "test@example.com"
        mock_account.avatar = "avatar_url"
        mock_account.interface_language = "en-US"
        mock_account.timezone = "UTC"
        mock_validate.return_value = mock_account

        with app.test_request_context(
            "/oauth/provider/account",
            method="POST",
            json={"client_id": "test_client_id"},
            headers={"Authorization": "Bearer valid_access_token"},
        ):
            api_instance = OAuthServerUserAccountApi()
            response = api_instance.post()

            assert response["name"] == "Test User"
            assert response["email"] == "test@example.com"
            assert response["avatar"] == "avatar_url"

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_missing_authorization_header(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app

        with app.test_request_context("/oauth/provider/account", method="POST", json={"client_id": "test_client_id"}):
            api_instance = OAuthServerUserAccountApi()
            response = api_instance.post()

            assert response.status_code == 401
            assert response.json["error"] == "Authorization header is required"

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_invalid_authorization_header_format(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app

        with app.test_request_context(
            "/oauth/provider/account",
            method="POST",
            json={"client_id": "test_client_id"},
            headers={"Authorization": "InvalidFormat"},
        ):
            api_instance = OAuthServerUserAccountApi()
            response = api_instance.post()

            assert response.status_code == 401
            assert response.json["error"] == "Invalid Authorization header format"

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_invalid_token_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app

        with app.test_request_context(
            "/oauth/provider/account",
            method="POST",
            json={"client_id": "test_client_id"},
            headers={"Authorization": "Basic something"},
        ):
            api_instance = OAuthServerUserAccountApi()
            response = api_instance.post()

            assert response.status_code == 401
            assert response.json["error"] == "token_type is invalid"

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    def test_missing_access_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app):
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app

        with app.test_request_context(
            "/oauth/provider/account",
            method="POST",
            json={"client_id": "test_client_id"},
            headers={"Authorization": "Bearer "},
        ):
            api_instance = OAuthServerUserAccountApi()
            response = api_instance.post()

            assert response.status_code == 401
            assert response.json["error"] == "Invalid Authorization header format"

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app")
    @patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token")
    def test_invalid_access_token(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app):
        mock_db.session.query.return_value.first.return_value = MagicMock()
        mock_get_app.return_value = mock_oauth_provider_app
        mock_validate.return_value = None

        with app.test_request_context(
            "/oauth/provider/account",
            method="POST",
            json={"client_id": "test_client_id"},
            headers={"Authorization": "Bearer invalid_token"},
        ):
            api_instance = OAuthServerUserAccountApi()
            response = api_instance.post()

            assert response.status_code == 401
            assert response.json["error"] == "access_token or client_id is invalid"

@ -2,7 +2,7 @@ from datetime import datetime
from unittest.mock import MagicMock, patch

import pytest
from werkzeug.exceptions import Forbidden, NotFound
from werkzeug.exceptions import Forbidden, HTTPException, NotFound

import services
from controllers.console import console_ns

@ -19,13 +19,14 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import (
    RagPipelineDraftNodeRunApi,
    RagPipelineDraftRunIterationNodeApi,
    RagPipelineDraftRunLoopNodeApi,
    RagPipelineDraftWorkflowRestoreApi,
    RagPipelineRecommendedPluginApi,
    RagPipelineTaskStopApi,
    RagPipelineTransformApi,
    RagPipelineWorkflowLastRunApi,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from services.errors.app import WorkflowHashNotEqualError
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError

@ -116,6 +117,86 @@ class TestDraftWorkflowApi:
        response, status = method(api, pipeline)
        assert status == 400

    def test_restore_published_workflow_to_draft_success(self, app):
        api = RagPipelineDraftWorkflowRestoreApi()
        method = unwrap(api.post)

        pipeline = MagicMock()
        user = MagicMock(id="account-1")
        workflow = MagicMock(unique_hash="restored-hash", updated_at=None, created_at=datetime(2024, 1, 1))

        service = MagicMock()
        service.restore_published_workflow_to_draft.return_value = workflow

        with (
            app.test_request_context("/", method="POST"),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
                return_value=(user, "t"),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
                return_value=service,
            ),
        ):
            result = method(api, pipeline, "published-workflow")

        assert result["result"] == "success"
        assert result["hash"] == "restored-hash"
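
    # Note: `unwrap(api.post)` (a helper assumed to strip the login/setup
    # decorators) lets each handler run directly against the mocked
    # RagPipelineService, so the tests only patch the module-level collaborators
    # they care about.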

    def test_restore_published_workflow_to_draft_not_found(self, app):
        api = RagPipelineDraftWorkflowRestoreApi()
        method = unwrap(api.post)

        pipeline = MagicMock()
        user = MagicMock(id="account-1")

        service = MagicMock()
        service.restore_published_workflow_to_draft.side_effect = WorkflowNotFoundError("Workflow not found")

        with (
            app.test_request_context("/", method="POST"),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
                return_value=(user, "t"),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
                return_value=service,
            ),
        ):
            with pytest.raises(NotFound):
                method(api, pipeline, "published-workflow")

    def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app):
        api = RagPipelineDraftWorkflowRestoreApi()
        method = unwrap(api.post)

        pipeline = MagicMock()
        user = MagicMock(id="account-1")

        service = MagicMock()
        service.restore_published_workflow_to_draft.side_effect = IsDraftWorkflowError(
            "source workflow must be published"
        )

        with (
            app.test_request_context("/", method="POST"),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
                return_value=(user, "t"),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
                return_value=service,
            ),
        ):
            with pytest.raises(HTTPException) as exc:
                method(api, pipeline, "draft-workflow")

        assert exc.value.code == 400
        assert exc.value.description == "source workflow must be published"


class TestDraftRunNodes:
    def test_iteration_node_success(self, app):

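# Note: the remaining hunks follow one pattern: test doubles that previously
# stubbed the legacy `db.session.query(...).where(...).first()` chain now stub
# the SQLAlchemy 2.0-style accessors (`session.scalar`, `session.scalars`,
# `session.get`) instead, presumably because the controllers moved to
# select()-based queries. A minimal before/after sketch, names illustrative:
#
#   session.query.return_value.where.return_value.first.return_value = obj  # before
#   session.scalar.return_value = obj                                       # after
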
@ -24,13 +24,8 @@ class TestBannerApi:
        banner.status = BannerStatus.ENABLED
        banner.created_at = datetime(2024, 1, 1)

        query = MagicMock()
        query.where.return_value = query
        query.order_by.return_value = query
        query.all.return_value = [banner]

        session = MagicMock()
        session.query.return_value = query
        session.scalars.return_value.all.return_value = [banner]

        with app.test_request_context("/?language=fr-FR"), patch.object(banner_module.db, "session", session):
            result = method(api)

@ -58,16 +53,14 @@ class TestBannerApi:
        banner.status = BannerStatus.ENABLED
        banner.created_at = None

        query = MagicMock()
        query.where.return_value = query
        query.order_by.return_value = query
        query.all.side_effect = [
        scalars_result = MagicMock()
        scalars_result.all.side_effect = [
            [],
            [banner],
        ]

        session = MagicMock()
        session.query.return_value = query
        session.scalars.return_value = scalars_result

        with app.test_request_context("/?language=es-ES"), patch.object(banner_module.db, "session", session):
            result = method(api)

@ -87,13 +80,8 @@ class TestBannerApi:
        api = banner_module.BannerApi()
        method = unwrap(api.get)

        query = MagicMock()
        query.where.return_value = query
        query.order_by.return_value = query
        query.all.return_value = []

        session = MagicMock()
        session.query.return_value = query
        session.scalars.return_value.all.return_value = []

        with app.test_request_context("/"), patch.object(banner_module.db, "session", session):
            result = method(api)

@ -260,11 +260,10 @@ class TestInstalledAppsCreateApi:
        app_entity.tenant_id = "t2"

        session = MagicMock()
        session.query.return_value.where.return_value.first.side_effect = [
            recommended,
            app_entity,
            None,
        ]
        # scalar() is called for recommended_app and installed_app lookups
        session.scalar.side_effect = [recommended, None]
        # get() is called for app PK lookup
        session.get.return_value = app_entity

        with (
            app.test_request_context("/", json={"app_id": "a1"}),

@ -282,7 +281,7 @@ class TestInstalledAppsCreateApi:
        method = unwrap(api.post)

        session = MagicMock()
        session.query.return_value.where.return_value.first.return_value = None
        session.scalar.return_value = None

        with (
            app.test_request_context("/", json={"app_id": "a1"}),

@ -300,10 +299,10 @@ class TestInstalledAppsCreateApi:
        app_entity = MagicMock(is_public=False)

        session = MagicMock()
        session.query.return_value.where.return_value.first.side_effect = [
            recommended,
            app_entity,
        ]
        # scalar() returns recommended_app
        session.scalar.return_value = recommended
        # get() returns the app entity
        session.get.return_value = app_entity

        with (
            app.test_request_context("/", json={"app_id": "a1"}),

@ -958,8 +958,8 @@ class TestTrialSitApi:
        app_model = MagicMock()
        app_model.id = "a1"

        with app.test_request_context("/"), patch.object(module.db.session, "query") as mock_query:
            mock_query.return_value.where.return_value.first.return_value = None
        with app.test_request_context("/"), patch.object(module.db.session, "scalar") as mock_scalar:
            mock_scalar.return_value = None
            with pytest.raises(Forbidden):
                method(api, app_model)

@ -973,8 +973,8 @@ class TestTrialSitApi:
        app_model.tenant = MagicMock()
        app_model.tenant.status = TenantStatus.ARCHIVE

        with app.test_request_context("/"), patch.object(module.db.session, "query") as mock_query:
            mock_query.return_value.where.return_value.first.return_value = site
        with app.test_request_context("/"), patch.object(module.db.session, "scalar") as mock_scalar:
            mock_scalar.return_value = site
            with pytest.raises(Forbidden):
                method(api, app_model)

@ -990,10 +990,10 @@ class TestTrialSitApi:

        with (
            app.test_request_context("/"),
            patch.object(module.db.session, "query") as mock_query,
            patch.object(module.db.session, "scalar") as mock_scalar,
            patch.object(module.SiteResponse, "model_validate") as mock_validate,
        ):
            mock_query.return_value.where.return_value.first.return_value = site
            mock_scalar.return_value = site
            mock_validate_result = MagicMock()
            mock_validate_result.model_dump.return_value = {"name": "test", "icon": "icon"}
            mock_validate.return_value = mock_validate_result

@ -34,9 +34,9 @@ def test_installed_app_required_not_found():
            "controllers.console.explore.wraps.current_account_with_tenant",
            return_value=(MagicMock(), "tenant-1"),
        ),
        patch("controllers.console.explore.wraps.db.session.query") as q,
        patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
    ):
        q.return_value.where.return_value.first.return_value = None
        scalar_mock.return_value = None

        with pytest.raises(NotFound):
            view("app-id")

@ -54,11 +54,11 @@ def test_installed_app_required_app_deleted():
            "controllers.console.explore.wraps.current_account_with_tenant",
            return_value=(MagicMock(), "tenant-1"),
        ),
        patch("controllers.console.explore.wraps.db.session.query") as q,
        patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
        patch("controllers.console.explore.wraps.db.session.delete"),
        patch("controllers.console.explore.wraps.db.session.commit"),
    ):
        q.return_value.where.return_value.first.return_value = installed_app
        scalar_mock.return_value = installed_app

        with pytest.raises(NotFound):
            view("app-id")

@ -76,9 +76,9 @@ def test_installed_app_required_success():
            "controllers.console.explore.wraps.current_account_with_tenant",
            return_value=(MagicMock(), "tenant-1"),
        ),
        patch("controllers.console.explore.wraps.db.session.query") as q,
        patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
    ):
        q.return_value.where.return_value.first.return_value = installed_app
        scalar_mock.return_value = installed_app

        result = view("app-id")
        assert result == installed_app

@ -149,9 +149,9 @@ def test_trial_app_required_not_allowed():
            "controllers.console.explore.wraps.current_account_with_tenant",
            return_value=(MagicMock(id="user-1"), None),
        ),
        patch("controllers.console.explore.wraps.db.session.query") as q,
        patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
    ):
        q.return_value.where.return_value.first.return_value = None
        scalar_mock.return_value = None

        with pytest.raises(TrialAppNotAllowed):
            view("app-id")

@ -170,9 +170,9 @@ def test_trial_app_required_limit_exceeded():
            "controllers.console.explore.wraps.current_account_with_tenant",
            return_value=(MagicMock(id="user-1"), None),
        ),
        patch("controllers.console.explore.wraps.db.session.query") as q,
        patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
    ):
        q.return_value.where.return_value.first.side_effect = [
        scalar_mock.side_effect = [
            trial_app,
            record,
        ]

@ -194,9 +194,9 @@ def test_trial_app_required_success():
            "controllers.console.explore.wraps.current_account_with_tenant",
            return_value=(MagicMock(id="user-1"), None),
        ),
        patch("controllers.console.explore.wraps.db.session.query") as q,
        patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock,
    ):
        q.return_value.where.return_value.first.side_effect = [
        scalar_mock.side_effect = [
            trial_app,
            record,
        ]

@@ -11,6 +11,7 @@ from controllers.console.tag.tags import (
    TagListApi,
    TagUpdateDeleteApi,
)
from models.enums import TagType


def unwrap(func):

@@ -52,7 +53,7 @@ def tag():
    tag = MagicMock()
    tag.id = "tag-1"
    tag.name = "test-tag"
    tag.type = "knowledge"
    tag.type = TagType.KNOWLEDGE
    return tag
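The TagType.KNOWLEDGE substitution works without touching any string comparisons because the enum is string-backed. A minimal sketch (the member list here is assumed, not copied from models.enums):

    from enum import StrEnum


    class TagType(StrEnum):  # hypothetical mirror of models.enums.TagType
        KNOWLEDGE = "knowledge"
        APP = "app"


    # A str-backed enum member compares and formats as its raw value, so the
    # swap keeps every existing comparison against "knowledge" passing:
    assert TagType.KNOWLEDGE == "knowledge"
    assert f"{TagType.KNOWLEDGE}" == "knowledge"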
@@ -114,7 +114,7 @@ class TestBaseApiKeyResource:

    def test_delete_key_not_found(self, tenant_context_admin, db_mock):
        resource = DummyApiKeyResource()
        db_mock.session.query.return_value.where.return_value.first.return_value = None
        db_mock.session.scalar.return_value = None

        with patch("controllers.console.apikey._get_resource"):
            with pytest.raises(Exception) as exc_info:

@@ -125,7 +125,7 @@ class TestBaseApiKeyResource:

    def test_delete_success(self, tenant_context_admin, db_mock):
        resource = DummyApiKeyResource()
        db_mock.session.query.return_value.where.return_value.first.return_value = MagicMock()
        db_mock.session.scalar.return_value = MagicMock()

        with (
            patch("controllers.console.apikey._get_resource"),
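The apikey tests now stub db_mock.session.scalar for the existence check. A hedged sketch of a lookup-then-delete flow in that style, with the mapped class passed in rather than naming Dify's real API-key model:

    from sqlalchemy import delete, select
    from sqlalchemy.orm import Session


    def delete_api_key(session: Session, model: type, key_id: str) -> bool:
        # Existence check via session.scalar(select(...)) -- the call the tests
        # stub through db_mock.session.scalar -- then a bulk DELETE statement.
        row = session.scalar(select(model).where(model.id == key_id))
        if row is None:
            return False  # maps to the not-found branch in test_delete_key_not_found
        session.execute(delete(model).where(model.id == key_id))
        session.commit()
        return True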
@@ -328,7 +328,7 @@ class TestSystemSetup:
    def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db):
        """Test NotInitValidateError when INIT_PASSWORD is set but setup not complete"""
        # Arrange
        mock_db.session.query.return_value.first.return_value = None  # No setup
        mock_db.session.scalar.return_value = None  # No setup
        mock_environ_get.return_value = "some_password"

        @setup_required

@@ -345,7 +345,7 @@ class TestSystemSetup:
    def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db):
        """Test NotSetupError when no INIT_PASSWORD and setup not complete"""
        # Arrange
        mock_db.session.query.return_value.first.return_value = None  # No setup
        mock_db.session.scalar.return_value = None  # No setup
        mock_environ_get.return_value = None  # No INIT_PASSWORD

        @setup_required
@@ -55,9 +55,9 @@ class TestAccountInitApi:
            patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
            patch("controllers.console.workspace.account.db.session.commit", return_value=None),
            patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"),
            patch("controllers.console.workspace.account.db.session.query") as query_mock,
            patch("controllers.console.workspace.account.db.session.scalar") as scalar_mock,
        ):
            query_mock.return_value.where.return_value.first.return_value = MagicMock(status="unused")
            scalar_mock.return_value = MagicMock(status="unused")
            resp = method(api)

            assert resp["result"] == "success"
@@ -207,10 +207,10 @@ class TestMemberCancelInviteApi:
        with (
            app.test_request_context("/"),
            patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
            patch("controllers.console.workspace.members.db.session.query") as q,
            patch("controllers.console.workspace.members.db.session.get") as get_mock,
            patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"),
        ):
            q.return_value.where.return_value.first.return_value = member
            get_mock.return_value = member
            result, status = method(api, member.id)

            assert status == 200

@@ -226,9 +226,9 @@ class TestMemberCancelInviteApi:
        with (
            app.test_request_context("/"),
            patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
            patch("controllers.console.workspace.members.db.session.query") as q,
            patch("controllers.console.workspace.members.db.session.get") as get_mock,
        ):
            q.return_value.where.return_value.first.return_value = None
            get_mock.return_value = None

            with pytest.raises(HTTPException):
                method(api, "x")

@@ -244,13 +244,13 @@ class TestMemberCancelInviteApi:
        with (
            app.test_request_context("/"),
            patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
            patch("controllers.console.workspace.members.db.session.query") as q,
            patch("controllers.console.workspace.members.db.session.get") as get_mock,
            patch(
                "controllers.console.workspace.members.TenantService.remove_member_from_tenant",
                side_effect=services.errors.account.CannotOperateSelfError("x"),
            ),
        ):
            q.return_value.where.return_value.first.return_value = member
            get_mock.return_value = member
            result, status = method(api, member.id)

            assert status == 400

@@ -266,13 +266,13 @@ class TestMemberCancelInviteApi:
        with (
            app.test_request_context("/"),
            patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
            patch("controllers.console.workspace.members.db.session.query") as q,
            patch("controllers.console.workspace.members.db.session.get") as get_mock,
            patch(
                "controllers.console.workspace.members.TenantService.remove_member_from_tenant",
                side_effect=services.errors.account.NoPermissionError("x"),
            ),
        ):
            q.return_value.where.return_value.first.return_value = member
            get_mock.return_value = member
            result, status = method(api, member.id)

            assert status == 403

@@ -288,13 +288,13 @@ class TestMemberCancelInviteApi:
        with (
            app.test_request_context("/"),
            patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
            patch("controllers.console.workspace.members.db.session.query") as q,
            patch("controllers.console.workspace.members.db.session.get") as get_mock,
            patch(
                "controllers.console.workspace.members.TenantService.remove_member_from_tenant",
                side_effect=services.errors.account.MemberNotInTenantError(),
            ),
        ):
            q.return_value.where.return_value.first.return_value = member
            get_mock.return_value = member
            result, status = method(api, member.id)

            assert status == 404
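These member tests stub db.session.get instead of db.session.query. Session.get is SQLAlchemy's primary-key lookup and returns None on a miss, which is what the not-found case above relies on; a one-liner sketch:

    from sqlalchemy.orm import Session


    def find_member(session: Session, account_model: type, member_id: str):
        # Session.get() fetches by primary key and returns the instance or
        # None, which is exactly what the stubbed get_mock.return_value models.
        return session.get(account_model, member_id)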
@@ -36,7 +36,115 @@ def unwrap(func):


class TestTenantListApi:
    def test_get_success(self, app):
    def test_get_success_saas_path(self, app):
        api = TenantListApi()
        method = unwrap(api.get)

        tenant1 = MagicMock(
            id="t1",
            name="Tenant 1",
            status="active",
            created_at=datetime.utcnow(),
        )
        tenant2 = MagicMock(
            id="t2",
            name="Tenant 2",
            status="active",
            created_at=datetime.utcnow(),
        )

        with (
            app.test_request_context("/workspaces"),
            patch(
                "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
            ),
            patch(
                "controllers.console.workspace.workspace.TenantService.get_join_tenants",
                return_value=[tenant1, tenant2],
            ),
            patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
            patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True),
            patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"),
            patch(
                "controllers.console.workspace.workspace.BillingService.get_plan_bulk",
                return_value={
                    "t1": {"plan": CloudPlan.TEAM, "expiration_date": 0},
                    "t2": {"plan": CloudPlan.PROFESSIONAL, "expiration_date": 0},
                },
            ) as get_plan_bulk_mock,
            patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
        ):
            result, status = method(api)

            assert status == 200
            assert len(result["workspaces"]) == 2
            assert result["workspaces"][0]["current"] is True
            assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
            assert result["workspaces"][1]["plan"] == CloudPlan.PROFESSIONAL
            get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
            get_features_mock.assert_not_called()

    def test_get_saas_path_partial_fallback_does_not_gate_plan_on_billing_enabled(self, app):
        """Bulk omits a tenant: resolve plan via subscription.plan only; billing.enabled is not used.

        billing.enabled is mocked False to prove the endpoint does not gate on it for this path
        (SaaS contract treats enabled as on; display follows subscription.plan).
        """
        api = TenantListApi()
        method = unwrap(api.get)

        tenant1 = MagicMock(
            id="t1",
            name="Tenant 1",
            status="active",
            created_at=datetime.utcnow(),
        )
        tenant2 = MagicMock(
            id="t2",
            name="Tenant 2",
            status="active",
            created_at=datetime.utcnow(),
        )

        features_t2 = MagicMock()
        features_t2.billing.enabled = False
        features_t2.billing.subscription.plan = CloudPlan.PROFESSIONAL

        with (
            app.test_request_context("/workspaces"),
            patch(
                "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
            ),
            patch(
                "controllers.console.workspace.workspace.TenantService.get_join_tenants",
                return_value=[tenant1, tenant2],
            ),
            patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
            patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True),
            patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"),
            patch(
                "controllers.console.workspace.workspace.BillingService.get_plan_bulk",
                return_value={"t1": {"plan": CloudPlan.TEAM, "expiration_date": 0}},
            ) as get_plan_bulk_mock,
            patch(
                "controllers.console.workspace.workspace.FeatureService.get_features",
                return_value=features_t2,
            ) as get_features_mock,
        ):
            result, status = method(api)

            assert status == 200
            assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
            assert result["workspaces"][1]["plan"] == CloudPlan.PROFESSIONAL
            get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
            get_features_mock.assert_called_once_with("t2")

    def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app):
        """Test fallback to FeatureService when bulk billing returns empty result.

        BillingService.get_plan_bulk catches exceptions internally and returns empty dict,
        so we simulate the real failure mode by returning empty dict for non-empty input.
        """
        api = TenantListApi()
        method = unwrap(api.get)
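Taken together, the three SaaS tests above pin down a plan-resolution contract: one bulk billing call, per-tenant feature fallback for ids the bulk response omits, and a logged full fallback when the bulk call returns nothing. A hypothetical reconstruction from those assertions (not the actual controller code):

    import logging

    logger = logging.getLogger(__name__)


    def resolve_plans(tenant_ids, get_plan_bulk, get_features):
        """Hypothetical reconstruction of the SaaS plan lookup the tests assert."""
        bulk = get_plan_bulk(tenant_ids)
        if tenant_ids and not bulk:
            # Bulk lookup failed outright: warn once and fall back per tenant.
            logger.warning("bulk plan lookup returned no data; using feature fallback")
            return {tid: get_features(tid).billing.subscription.plan for tid in tenant_ids}
        plans = {}
        for tid in tenant_ids:
            if tid in bulk:
                plans[tid] = bulk[tid]["plan"]
            else:
                # Partial fallback: follow subscription.plan and deliberately
                # ignore billing.enabled, per the SaaS contract above.
                plans[tid] = get_features(tid).billing.subscription.plan
        return plans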
@@ -54,27 +162,41 @@ class TestTenantListApi:
        )

        features = MagicMock()
        features.billing.enabled = True
        features.billing.subscription.plan = CloudPlan.SANDBOX
        features.billing.enabled = False
        features.billing.subscription.plan = CloudPlan.TEAM

        with (
            app.test_request_context("/workspaces"),
            patch(
                "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
                "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2")
            ),
            patch(
                "controllers.console.workspace.workspace.TenantService.get_join_tenants",
                return_value=[tenant1, tenant2],
            ),
            patch("controllers.console.workspace.workspace.FeatureService.get_features", return_value=features),
            patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
            patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True),
            patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"),
            patch(
                "controllers.console.workspace.workspace.BillingService.get_plan_bulk",
                return_value={},  # Simulates real failure: empty result for non-empty input
            ) as get_plan_bulk_mock,
            patch(
                "controllers.console.workspace.workspace.FeatureService.get_features",
                return_value=features,
            ) as get_features_mock,
            patch("controllers.console.workspace.workspace.logger.warning") as logger_warning_mock,
        ):
            result, status = method(api)

            assert status == 200
            assert len(result["workspaces"]) == 2
            assert result["workspaces"][0]["current"] is True
            assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
            assert result["workspaces"][1]["plan"] == CloudPlan.TEAM
            get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
            assert get_features_mock.call_count == 2
            logger_warning_mock.assert_called_once()

    def test_get_billing_disabled(self, app):
    def test_get_billing_disabled_community_path(self, app):
        api = TenantListApi()
        method = unwrap(api.get)
@@ -87,6 +209,7 @@ class TestTenantListApi:

        features = MagicMock()
        features.billing.enabled = False
        features.billing.subscription.plan = CloudPlan.SANDBOX

        with (
            app.test_request_context("/workspaces"),
@@ -98,15 +221,83 @@ class TestTenantListApi:
                "controllers.console.workspace.workspace.TenantService.get_join_tenants",
                return_value=[tenant],
            ),
            patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
            patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False),
            patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
            patch(
                "controllers.console.workspace.workspace.FeatureService.get_features",
                return_value=features,
            ),
            ) as get_features_mock,
        ):
            result, status = method(api)

            assert status == 200
            assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
            get_features_mock.assert_called_once_with("t1")

    def test_get_enterprise_only_skips_feature_service(self, app):
        api = TenantListApi()
        method = unwrap(api.get)

        tenant1 = MagicMock(
            id="t1",
            name="Tenant 1",
            status="active",
            created_at=datetime.utcnow(),
        )
        tenant2 = MagicMock(
            id="t2",
            name="Tenant 2",
            status="active",
            created_at=datetime.utcnow(),
        )

        with (
            app.test_request_context("/workspaces"),
            patch(
                "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2")
            ),
            patch(
                "controllers.console.workspace.workspace.TenantService.get_join_tenants",
                return_value=[tenant1, tenant2],
            ),
            patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", True),
            patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False),
            patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
            patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
        ):
            result, status = method(api)

            assert status == 200
            assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
            assert result["workspaces"][1]["plan"] == CloudPlan.SANDBOX
            assert result["workspaces"][0]["current"] is False
            assert result["workspaces"][1]["current"] is True
            get_features_mock.assert_not_called()

    def test_get_enterprise_only_with_empty_tenants(self, app):
        api = TenantListApi()
        method = unwrap(api.get)

        with (
            app.test_request_context("/workspaces"),
            patch(
                "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), None)
            ),
            patch(
                "controllers.console.workspace.workspace.TenantService.get_join_tenants",
                return_value=[],
            ),
            patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", True),
            patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False),
            patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
            patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
        ):
            result, status = method(api)

            assert status == 200
            assert result["workspaces"] == []
            get_features_mock.assert_not_called()


class TestWorkspaceListApi:
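The enterprise tests complete the picture: when ENTERPRISE_ENABLED is set, FeatureService is never consulted and every workspace reports the default plan. A sketch of the implied gating order (branch labels assumed, not real code):

    def plan_source(enterprise_enabled: bool, edition: str, billing_enabled: bool) -> str:
        # Gating order implied by the tests: enterprise installs never consult
        # FeatureService and show the default plan; the cloud/billing path uses
        # the bulk lookup; community installs resolve per tenant via features.
        if enterprise_enabled:
            return "default-plan"
        if edition == "CLOUD" and billing_enabled:
            return "billing-bulk"
        return "per-tenant-features"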
@@ -258,12 +449,12 @@ class TestSwitchWorkspaceApi:
                "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
            ),
            patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
            patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
            patch("controllers.console.workspace.workspace.db.session.get") as get_mock,
            patch(
                "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t2"}
            ),
        ):
            query_mock.return_value.get.return_value = tenant
            get_mock.return_value = tenant
            result = method(api)

            assert result["result"] == "success"

@@ -297,9 +488,9 @@ class TestSwitchWorkspaceApi:
                return_value=(MagicMock(), "t1"),
            ),
            patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
            patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
            patch("controllers.console.workspace.workspace.db.session.get") as get_mock,
        ):
            query_mock.return_value.get.return_value = None
            get_mock.return_value = None

            with pytest.raises(ValueError):
                method(api)
@@ -35,6 +35,7 @@ from controllers.service_api.dataset.dataset import (
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
from models.account import Account
from models.dataset import DatasetPermissionEnum
from models.enums import TagType
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.tag_service import TagService
@@ -277,7 +278,7 @@ class TestDatasetTagsApi:
        mock_tag = Mock()
        mock_tag.id = "tag_1"
        mock_tag.name = "Test Tag"
        mock_tag.type = "knowledge"
        mock_tag.type = TagType.KNOWLEDGE
        mock_tag.binding_count = "0"  # Required for Pydantic validation - must be string
        mock_tag_service.get_tags.return_value = [mock_tag]

@@ -316,7 +317,7 @@ class TestDatasetTagsApi:
        mock_tag = Mock()
        mock_tag.id = "new_tag_1"
        mock_tag.name = "New Tag"
        mock_tag.type = "knowledge"
        mock_tag.type = TagType.KNOWLEDGE
        mock_tag_service.save_tags.return_value = mock_tag
        mock_service_api_ns.payload = {"name": "New Tag"}

@@ -378,7 +379,7 @@ class TestDatasetTagsApi:
        mock_tag = Mock()
        mock_tag.id = "tag_1"
        mock_tag.name = "Updated Tag"
        mock_tag.type = "knowledge"
        mock_tag.type = TagType.KNOWLEDGE
        mock_tag.binding_count = "5"
        mock_tag_service.update_tags.return_value = mock_tag
        mock_tag_service.get_tag_binding_count.return_value = 5

@@ -866,7 +867,7 @@ class TestTagService:
        mock_tag = Mock()
        mock_tag.id = str(uuid.uuid4())
        mock_tag.name = "New Tag"
        mock_tag.type = "knowledge"
        mock_tag.type = TagType.KNOWLEDGE
        mock_save.return_value = mock_tag

        result = TagService.save_tags({"name": "New Tag", "type": "knowledge"})
@@ -11,6 +11,7 @@ from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.task_pipeline import message_cycle_manager
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from models.enums import ConversationFromSource
from models.model import AppMode, Conversation, Message


@@ -92,7 +93,7 @@ def test_init_generate_records_marks_existing_conversation():
        system_instruction_tokens=0,
        status="normal",
        invoke_from=InvokeFrom.WEB_APP.value,
        from_source="api",
        from_source=ConversationFromSource.API,
        from_end_user_id="user-id",
        from_account_id=None,
    )
@@ -0,0 +1,181 @@
from unittest.mock import MagicMock, patch

import pytest
from pydantic import ValidationError

from core.extension.api_based_extension_requestor import APIBasedExtensionPoint
from core.moderation.api.api import ApiModeration, ModerationInputParams, ModerationOutputParams
from core.moderation.base import ModerationAction, ModerationInputsResult, ModerationOutputsResult
from models.api_based_extension import APIBasedExtension


class TestApiModeration:
    @pytest.fixture
    def api_config(self):
        return {
            "inputs_config": {
                "enabled": True,
            },
            "outputs_config": {
                "enabled": True,
            },
            "api_based_extension_id": "test-extension-id",
        }

    @pytest.fixture
    def api_moderation(self, api_config):
        return ApiModeration(app_id="test-app-id", tenant_id="test-tenant-id", config=api_config)

    def test_moderation_input_params(self):
        params = ModerationInputParams(app_id="app-1", inputs={"key": "val"}, query="test query")
        assert params.app_id == "app-1"
        assert params.inputs == {"key": "val"}
        assert params.query == "test query"

        # Test defaults
        params_default = ModerationInputParams()
        assert params_default.app_id == ""
        assert params_default.inputs == {}
        assert params_default.query == ""

    def test_moderation_output_params(self):
        params = ModerationOutputParams(app_id="app-1", text="test text")
        assert params.app_id == "app-1"
        assert params.text == "test text"

        with pytest.raises(ValidationError):
            ModerationOutputParams()

    @patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
    def test_validate_config_success(self, mock_get_extension, api_config):
        mock_get_extension.return_value = MagicMock(spec=APIBasedExtension)
        ApiModeration.validate_config("test-tenant-id", api_config)
        mock_get_extension.assert_called_once_with("test-tenant-id", "test-extension-id")

    def test_validate_config_missing_extension_id(self):
        config = {
            "inputs_config": {"enabled": True},
            "outputs_config": {"enabled": True},
        }
        with pytest.raises(ValueError, match="api_based_extension_id is required"):
            ApiModeration.validate_config("test-tenant-id", config)

    @patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
    def test_validate_config_extension_not_found(self, mock_get_extension, api_config):
        mock_get_extension.return_value = None
        with pytest.raises(ValueError, match="API-based Extension not found"):
            ApiModeration.validate_config("test-tenant-id", api_config)

    @patch("core.moderation.api.api.ApiModeration._get_config_by_requestor")
    def test_moderation_for_inputs_enabled(self, mock_get_config, api_moderation):
        mock_get_config.return_value = {"flagged": True, "action": "direct_output", "preset_response": "Blocked by API"}

        result = api_moderation.moderation_for_inputs(inputs={"q": "a"}, query="hello")

        assert isinstance(result, ModerationInputsResult)
        assert result.flagged is True
        assert result.action == ModerationAction.DIRECT_OUTPUT
        assert result.preset_response == "Blocked by API"

        mock_get_config.assert_called_once_with(
            APIBasedExtensionPoint.APP_MODERATION_INPUT,
            {"app_id": "test-app-id", "inputs": {"q": "a"}, "query": "hello"},
        )

    def test_moderation_for_inputs_disabled(self):
        config = {
            "inputs_config": {"enabled": False},
            "outputs_config": {"enabled": True},
            "api_based_extension_id": "ext-id",
        }
        moderation = ApiModeration("app-id", "tenant-id", config)
        result = moderation.moderation_for_inputs(inputs={}, query="")

        assert result.flagged is False
        assert result.action == ModerationAction.DIRECT_OUTPUT
        assert result.preset_response == ""

    def test_moderation_for_inputs_no_config(self):
        moderation = ApiModeration("app-id", "tenant-id", None)
        with pytest.raises(ValueError, match="The config is not set"):
            moderation.moderation_for_inputs({}, "")

    @patch("core.moderation.api.api.ApiModeration._get_config_by_requestor")
    def test_moderation_for_outputs_enabled(self, mock_get_config, api_moderation):
        mock_get_config.return_value = {"flagged": False, "action": "direct_output", "preset_response": ""}

        result = api_moderation.moderation_for_outputs(text="hello world")

        assert isinstance(result, ModerationOutputsResult)
        assert result.flagged is False

        mock_get_config.assert_called_once_with(
            APIBasedExtensionPoint.APP_MODERATION_OUTPUT, {"app_id": "test-app-id", "text": "hello world"}
        )

    def test_moderation_for_outputs_disabled(self):
        config = {
            "inputs_config": {"enabled": True},
            "outputs_config": {"enabled": False},
            "api_based_extension_id": "ext-id",
        }
        moderation = ApiModeration("app-id", "tenant-id", config)
        result = moderation.moderation_for_outputs(text="test")

        assert result.flagged is False
        assert result.action == ModerationAction.DIRECT_OUTPUT

    def test_moderation_for_outputs_no_config(self):
        moderation = ApiModeration("app-id", "tenant-id", None)
        with pytest.raises(ValueError, match="The config is not set"):
            moderation.moderation_for_outputs("test")

    @patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
    @patch("core.moderation.api.api.decrypt_token")
    @patch("core.moderation.api.api.APIBasedExtensionRequestor")
    def test_get_config_by_requestor_success(self, mock_requestor_cls, mock_decrypt, mock_get_ext, api_moderation):
        mock_ext = MagicMock(spec=APIBasedExtension)
        mock_ext.api_endpoint = "http://api.test"
        mock_ext.api_key = "encrypted-key"
        mock_get_ext.return_value = mock_ext

        mock_decrypt.return_value = "decrypted-key"

        mock_requestor = MagicMock()
        mock_requestor.request.return_value = {"flagged": True}
        mock_requestor_cls.return_value = mock_requestor

        params = {"some": "params"}
        result = api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params)

        assert result == {"flagged": True}
        mock_get_ext.assert_called_once_with("test-tenant-id", "test-extension-id")
        mock_decrypt.assert_called_once_with("test-tenant-id", "encrypted-key")
        mock_requestor_cls.assert_called_once_with("http://api.test", "decrypted-key")
        mock_requestor.request.assert_called_once_with(APIBasedExtensionPoint.APP_MODERATION_INPUT, params)

    def test_get_config_by_requestor_no_config(self):
        moderation = ApiModeration("app-id", "tenant-id", None)
        with pytest.raises(ValueError, match="The config is not set"):
            moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {})

    @patch("core.moderation.api.api.ApiModeration._get_api_based_extension")
    def test_get_config_by_requestor_extension_not_found(self, mock_get_ext, api_moderation):
        mock_get_ext.return_value = None
        with pytest.raises(ValueError, match="API-based Extension not found"):
            api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {})

    @patch("core.moderation.api.api.db.session.scalar")
    def test_get_api_based_extension(self, mock_scalar):
        mock_ext = MagicMock(spec=APIBasedExtension)
        mock_scalar.return_value = mock_ext

        result = ApiModeration._get_api_based_extension("tenant-1", "ext-1")

        assert result == mock_ext
        mock_scalar.assert_called_once()
        # Verify the call has the correct filters
        args, kwargs = mock_scalar.call_args
        stmt = args[0]
        # We can't easily inspect the statement without complex sqlalchemy tricks,
        # but calling it is usually enough for unit tests if we mock the result.
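The requestor tests fix the wire contract for moderation extension points: a point name plus a flat JSON payload in, and a {flagged, action, preset_response} object out. A hypothetical extension-side handler consistent with those assertions (everything beyond the field names is assumed):

    def handle_app_moderation_input(payload: dict) -> dict:
        # payload, per the assertions: {"app_id": ..., "inputs": {...}, "query": ...}
        flagged = "badword" in payload.get("query", "")
        return {
            "flagged": flagged,
            "action": "direct_output",
            "preset_response": "Blocked by API" if flagged else "",
        }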
@@ -0,0 +1,207 @@
from unittest.mock import MagicMock, patch

import pytest

from core.app.app_config.entities import AppConfig, SensitiveWordAvoidanceEntity
from core.moderation.base import ModerationAction, ModerationError, ModerationInputsResult
from core.moderation.input_moderation import InputModeration
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager


class TestInputModeration:
    @pytest.fixture
    def app_config(self):
        config = MagicMock(spec=AppConfig)
        config.sensitive_word_avoidance = None
        return config

    @pytest.fixture
    def input_moderation(self):
        return InputModeration()

    def test_check_no_sensitive_word_avoidance(self, app_config, input_moderation):
        app_id = "test_app_id"
        tenant_id = "test_tenant_id"
        inputs = {"input_key": "input_value"}
        query = "test query"
        message_id = "test_message_id"

        flagged, final_inputs, final_query = input_moderation.check(
            app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
        )

        assert flagged is False
        assert final_inputs == inputs
        assert final_query == query

    @patch("core.moderation.input_moderation.ModerationFactory")
    def test_check_not_flagged(self, mock_factory_cls, app_config, input_moderation):
        app_id = "test_app_id"
        tenant_id = "test_tenant_id"
        inputs = {"input_key": "input_value"}
        query = "test query"
        message_id = "test_message_id"

        # Setup config
        sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
        sensitive_word_config.type = "keywords"
        sensitive_word_config.config = {"keywords": ["bad"]}
        app_config.sensitive_word_avoidance = sensitive_word_config

        # Setup factory mock
        mock_factory = mock_factory_cls.return_value
        mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
        mock_factory.moderation_for_inputs.return_value = mock_result

        flagged, final_inputs, final_query = input_moderation.check(
            app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
        )

        assert flagged is False
        assert final_inputs == inputs
        assert final_query == query
        mock_factory_cls.assert_called_once_with(
            name="keywords", app_id=app_id, tenant_id=tenant_id, config={"keywords": ["bad"]}
        )
        mock_factory.moderation_for_inputs.assert_called_once_with(dict(inputs), query)

    @patch("core.moderation.input_moderation.ModerationFactory")
    @patch("core.moderation.input_moderation.TraceTask")
    def test_check_with_trace_manager(self, mock_trace_task, mock_factory_cls, app_config, input_moderation):
        app_id = "test_app_id"
        tenant_id = "test_tenant_id"
        inputs = {"input_key": "input_value"}
        query = "test query"
        message_id = "test_message_id"
        trace_manager = MagicMock(spec=TraceQueueManager)

        # Setup config
        sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
        sensitive_word_config.type = "keywords"
        sensitive_word_config.config = {}
        app_config.sensitive_word_avoidance = sensitive_word_config

        # Setup factory mock
        mock_factory = mock_factory_cls.return_value
        mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
        mock_factory.moderation_for_inputs.return_value = mock_result

        input_moderation.check(
            app_id=app_id,
            tenant_id=tenant_id,
            app_config=app_config,
            inputs=inputs,
            query=query,
            message_id=message_id,
            trace_manager=trace_manager,
        )

        trace_manager.add_trace_task.assert_called_once_with(mock_trace_task.return_value)
        mock_trace_task.assert_called_once()
        call_kwargs = mock_trace_task.call_args.kwargs
        call_args = mock_trace_task.call_args.args
        assert call_args[0] == TraceTaskName.MODERATION_TRACE
        assert call_kwargs["message_id"] == message_id
        assert call_kwargs["moderation_result"] == mock_result
        assert call_kwargs["inputs"] == inputs
        assert "timer" in call_kwargs

    @patch("core.moderation.input_moderation.ModerationFactory")
    def test_check_flagged_direct_output(self, mock_factory_cls, app_config, input_moderation):
        app_id = "test_app_id"
        tenant_id = "test_tenant_id"
        inputs = {"input_key": "input_value"}
        query = "test query"
        message_id = "test_message_id"

        # Setup config
        sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
        sensitive_word_config.type = "keywords"
        sensitive_word_config.config = {}
        app_config.sensitive_word_avoidance = sensitive_word_config

        # Setup factory mock
        mock_factory = mock_factory_cls.return_value
        mock_result = ModerationInputsResult(
            flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="Blocked content"
        )
        mock_factory.moderation_for_inputs.return_value = mock_result

        with pytest.raises(ModerationError) as excinfo:
            input_moderation.check(
                app_id=app_id,
                tenant_id=tenant_id,
                app_config=app_config,
                inputs=inputs,
                query=query,
                message_id=message_id,
            )

        assert str(excinfo.value) == "Blocked content"

    @patch("core.moderation.input_moderation.ModerationFactory")
    def test_check_flagged_overridden(self, mock_factory_cls, app_config, input_moderation):
        app_id = "test_app_id"
        tenant_id = "test_tenant_id"
        inputs = {"input_key": "input_value"}
        query = "test query"
        message_id = "test_message_id"

        # Setup config
        sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
        sensitive_word_config.type = "keywords"
        sensitive_word_config.config = {}
        app_config.sensitive_word_avoidance = sensitive_word_config

        # Setup factory mock
        mock_factory = mock_factory_cls.return_value
        mock_result = ModerationInputsResult(
            flagged=True,
            action=ModerationAction.OVERRIDDEN,
            inputs={"input_key": "overridden_value"},
            query="overridden query",
        )
        mock_factory.moderation_for_inputs.return_value = mock_result

        flagged, final_inputs, final_query = input_moderation.check(
            app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id
        )

        assert flagged is True
        assert final_inputs == {"input_key": "overridden_value"}
        assert final_query == "overridden query"

    @patch("core.moderation.input_moderation.ModerationFactory")
    def test_check_flagged_other_action(self, mock_factory_cls, app_config, input_moderation):
        app_id = "test_app_id"
        tenant_id = "test_tenant_id"
        inputs = {"input_key": "input_value"}
        query = "test query"
        message_id = "test_message_id"

        # Setup config
        sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity)
        sensitive_word_config.type = "keywords"
        sensitive_word_config.config = {}
        app_config.sensitive_word_avoidance = sensitive_word_config

        # Setup factory mock
        mock_factory = mock_factory_cls.return_value
        mock_result = MagicMock()
        mock_result.flagged = True
        mock_result.action = "NONE"  # Some other action
        mock_factory.moderation_for_inputs.return_value = mock_result

        flagged, final_inputs, final_query = input_moderation.check(
            app_id=app_id,
            tenant_id=tenant_id,
            app_config=app_config,
            inputs=inputs,
            query=query,
            message_id=message_id,
        )

        assert flagged is True
        assert final_inputs == inputs
        assert final_query == query
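The five check() tests above encode one decision table. A condensed sketch of that table, reusing the result types the file already imports (assumed shape, not the real InputModeration internals):

    from core.moderation.base import ModerationAction, ModerationError


    def apply_moderation(result, inputs, query):
        # Decision table encoded by the tests above:
        #   not flagged            -> pass inputs/query through unchanged
        #   flagged DIRECT_OUTPUT  -> raise ModerationError(preset_response)
        #   flagged OVERRIDDEN     -> adopt the moderated inputs and query
        #   flagged anything else  -> keep originals but report flagged=True
        if not result.flagged:
            return False, inputs, query
        if result.action == ModerationAction.DIRECT_OUTPUT:
            raise ModerationError(result.preset_response)
        if result.action == ModerationAction.OVERRIDDEN:
            return True, result.inputs, result.query
        return True, inputs, query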
@@ -0,0 +1,234 @@
from unittest.mock import MagicMock, patch

import pytest
from flask import Flask

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import QueueMessageReplaceEvent
from core.moderation.base import ModerationAction, ModerationOutputsResult
from core.moderation.output_moderation import ModerationRule, OutputModeration


class TestOutputModeration:
    @pytest.fixture
    def mock_queue_manager(self):
        return MagicMock(spec=AppQueueManager)

    @pytest.fixture
    def moderation_rule(self):
        return ModerationRule(type="keywords", config={"keywords": "badword"})

    @pytest.fixture
    def output_moderation(self, mock_queue_manager, moderation_rule):
        return OutputModeration(
            tenant_id="test_tenant", app_id="test_app", rule=moderation_rule, queue_manager=mock_queue_manager
        )

    def test_should_direct_output(self, output_moderation):
        assert output_moderation.should_direct_output() is False
        output_moderation.final_output = "blocked"
        assert output_moderation.should_direct_output() is True

    def test_get_final_output(self, output_moderation):
        assert output_moderation.get_final_output() == ""
        output_moderation.final_output = "blocked"
        assert output_moderation.get_final_output() == "blocked"

    def test_append_new_token(self, output_moderation):
        with patch.object(OutputModeration, "start_thread") as mock_start:
            output_moderation.append_new_token("hello")
            assert output_moderation.buffer == "hello"
            mock_start.assert_called_once()

            output_moderation.thread = MagicMock()
            output_moderation.append_new_token(" world")
            assert output_moderation.buffer == "hello world"
            assert mock_start.call_count == 1

    def test_moderation_completion_no_flag(self, output_moderation):
        with patch.object(OutputModeration, "moderation") as mock_moderation:
            mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)

            output, flagged = output_moderation.moderation_completion("safe content")

            assert output == "safe content"
            assert flagged is False
            assert output_moderation.is_final_chunk is True

    def test_moderation_completion_flagged_direct_output(self, output_moderation, mock_queue_manager):
        with patch.object(OutputModeration, "moderation") as mock_moderation:
            mock_moderation.return_value = ModerationOutputsResult(
                flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset"
            )

            output, flagged = output_moderation.moderation_completion("badword content", public_event=True)

            assert output == "preset"
            assert flagged is True
            mock_queue_manager.publish.assert_called_once()
            args, _ = mock_queue_manager.publish.call_args
            assert isinstance(args[0], QueueMessageReplaceEvent)
            assert args[0].text == "preset"
            assert args[1] == PublishFrom.TASK_PIPELINE

    def test_moderation_completion_flagged_overridden(self, output_moderation, mock_queue_manager):
        with patch.object(OutputModeration, "moderation") as mock_moderation:
            mock_moderation.return_value = ModerationOutputsResult(
                flagged=True, action=ModerationAction.OVERRIDDEN, text="masked content"
            )

            output, flagged = output_moderation.moderation_completion("badword content", public_event=True)

            assert output == "masked content"
            assert flagged is True
            mock_queue_manager.publish.assert_called_once()
            args, _ = mock_queue_manager.publish.call_args
            assert args[0].text == "masked content"

    def test_start_thread(self, output_moderation):
        mock_app = MagicMock(spec=Flask)
        with patch("core.moderation.output_moderation.current_app") as mock_current_app:
            mock_current_app._get_current_object.return_value = mock_app
            with patch("threading.Thread") as mock_thread_class:
                mock_thread_instance = MagicMock()
                mock_thread_class.return_value = mock_thread_instance

                thread = output_moderation.start_thread()

                assert thread == mock_thread_instance
                mock_thread_class.assert_called_once()
                mock_thread_instance.start.assert_called_once()

    def test_stop_thread(self, output_moderation):
        mock_thread = MagicMock()
        mock_thread.is_alive.return_value = True
        output_moderation.thread = mock_thread

        output_moderation.stop_thread()
        assert output_moderation.thread_running is False

        output_moderation.thread_running = True
        mock_thread.is_alive.return_value = False
        output_moderation.stop_thread()
        assert output_moderation.thread_running is True

    @patch("core.moderation.output_moderation.ModerationFactory")
    def test_moderation_success(self, mock_factory_class, output_moderation):
        mock_factory = mock_factory_class.return_value
        mock_result = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
        mock_factory.moderation_for_outputs.return_value = mock_result

        result = output_moderation.moderation("tenant", "app", "buffer")

        assert result == mock_result
        mock_factory_class.assert_called_once_with(
            name="keywords", app_id="app", tenant_id="tenant", config={"keywords": "badword"}
        )

    @patch("core.moderation.output_moderation.ModerationFactory")
    def test_moderation_exception(self, mock_factory_class, output_moderation):
        mock_factory_class.side_effect = Exception("error")

        result = output_moderation.moderation("tenant", "app", "buffer")
        assert result is None

    def test_worker_loop_and_exit(self, output_moderation, mock_queue_manager):
        mock_app = MagicMock(spec=Flask)

        # Test exit on thread_running=False
        output_moderation.thread_running = False
        output_moderation.worker(mock_app, 10)
        # Should exit immediately

    def test_worker_no_flag(self, output_moderation):
        mock_app = MagicMock(spec=Flask)

        with patch.object(OutputModeration, "moderation") as mock_moderation:
            mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)

            output_moderation.buffer = "safe"
            output_moderation.is_final_chunk = True

            # To avoid infinite loop, we'll set thread_running to False after one iteration
            def side_effect(*args, **kwargs):
                output_moderation.thread_running = False
                return mock_moderation.return_value

            mock_moderation.side_effect = side_effect

            output_moderation.worker(mock_app, 10)

            assert mock_moderation.called

    def test_worker_flagged_direct_output(self, output_moderation, mock_queue_manager):
        mock_app = MagicMock(spec=Flask)

        with patch.object(OutputModeration, "moderation") as mock_moderation:
            mock_moderation.return_value = ModerationOutputsResult(
                flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset"
            )

            output_moderation.buffer = "badword"
            output_moderation.is_final_chunk = True

            output_moderation.worker(mock_app, 10)

            assert output_moderation.final_output == "preset"
            mock_queue_manager.publish.assert_called_once()
            # It breaks on DIRECT_OUTPUT

    def test_worker_flagged_overridden(self, output_moderation, mock_queue_manager):
        mock_app = MagicMock(spec=Flask)

        with patch.object(OutputModeration, "moderation") as mock_moderation:
            # Use side_effect to change thread_running on second call
            def side_effect(*args, **kwargs):
                if mock_moderation.call_count > 1:
                    output_moderation.thread_running = False
                    return None
                return ModerationOutputsResult(flagged=True, action=ModerationAction.OVERRIDDEN, text="masked")

            mock_moderation.side_effect = side_effect

            output_moderation.buffer = "badword"
            output_moderation.is_final_chunk = True

            output_moderation.worker(mock_app, 10)

            mock_queue_manager.publish.assert_called_once()
            args, _ = mock_queue_manager.publish.call_args
            assert args[0].text == "masked"

    def test_worker_chunk_too_small(self, output_moderation):
        mock_app = MagicMock(spec=Flask)
        with patch("time.sleep") as mock_sleep:
            # chunk_length < buffer_size and not is_final_chunk
            output_moderation.buffer = "123"  # length 3
            output_moderation.is_final_chunk = False

            def sleep_side_effect(seconds):
                output_moderation.thread_running = False

            mock_sleep.side_effect = sleep_side_effect

            output_moderation.worker(mock_app, 10)  # buffer_size 10

            mock_sleep.assert_called_once_with(1)

    def test_worker_empty_not_flagged(self, output_moderation, mock_queue_manager):
        mock_app = MagicMock(spec=Flask)
        with patch.object(OutputModeration, "moderation") as mock_moderation:
            # Return None (exception or no rule)
            mock_moderation.return_value = None

            def side_effect(*args, **kwargs):
                output_moderation.thread_running = False

            mock_moderation.side_effect = side_effect

            output_moderation.buffer = "something"
            output_moderation.is_final_chunk = True

            output_moderation.worker(mock_app, 10)

            mock_queue_manager.publish.assert_not_called()
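The worker tests describe a polling loop over a shared token buffer. A paraphrase of that behavior as the tests pin it down (assumed structure; publish_replace is a hypothetical helper standing in for the real queue publish):

    import time

    from core.moderation.base import ModerationAction


    def worker_loop(self, buffer_size: int) -> None:
        # Assumed shape of OutputModeration.worker, distilled from the tests:
        while self.thread_running:
            chunk = self.buffer
            if len(chunk) < buffer_size and not self.is_final_chunk:
                time.sleep(1)  # chunk too small: wait for more tokens
                continue
            result = self.moderation(self.tenant_id, self.app_id, chunk)
            if result is None or not result.flagged:
                if self.is_final_chunk:
                    break  # nothing to replace; loop ends with the stream
                continue
            if result.action == ModerationAction.DIRECT_OUTPUT:
                self.final_output = result.preset_response
                self.publish_replace(result.preset_response)  # hypothetical helper
                break  # direct output terminates the worker
            self.publish_replace(result.text)  # OVERRIDDEN: stream the masked text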
@@ -0,0 +1,160 @@
from unittest.mock import patch

import httpx
import pytest
from qdrant_client.http import models as rest
from qdrant_client.http.exceptions import UnexpectedResponse

from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import (
    TidbOnQdrantConfig,
    TidbOnQdrantVector,
)


class TestTidbOnQdrantVectorDeleteByIds:
    """Unit tests for TidbOnQdrantVector.delete_by_ids method."""

    @pytest.fixture
    def vector_instance(self):
        """Create a TidbOnQdrantVector instance for testing."""
        config = TidbOnQdrantConfig(
            endpoint="http://localhost:6333",
            api_key="test_api_key",
        )

        with patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector.qdrant_client.QdrantClient"):
            vector = TidbOnQdrantVector(
                collection_name="test_collection",
                group_id="test_group",
                config=config,
            )
        return vector

    def test_delete_by_ids_with_multiple_ids(self, vector_instance):
        """Test batch deletion with multiple document IDs."""
        ids = ["doc1", "doc2", "doc3"]

        vector_instance.delete_by_ids(ids)

        # Verify that delete was called once with MatchAny filter
        vector_instance._client.delete.assert_called_once()
        call_args = vector_instance._client.delete.call_args

        # Check collection name
        assert call_args[1]["collection_name"] == "test_collection"

        # Verify filter uses MatchAny with all IDs
        filter_selector = call_args[1]["points_selector"]
        filter_obj = filter_selector.filter
        assert len(filter_obj.must) == 1

        field_condition = filter_obj.must[0]
        assert field_condition.key == "metadata.doc_id"
        assert isinstance(field_condition.match, rest.MatchAny)
        assert set(field_condition.match.any) == {"doc1", "doc2", "doc3"}

    def test_delete_by_ids_with_single_id(self, vector_instance):
        """Test deletion with a single document ID."""
        ids = ["doc1"]

        vector_instance.delete_by_ids(ids)

        # Verify that delete was called once
        vector_instance._client.delete.assert_called_once()
        call_args = vector_instance._client.delete.call_args

        # Verify filter uses MatchAny with single ID
        filter_selector = call_args[1]["points_selector"]
        filter_obj = filter_selector.filter
        field_condition = filter_obj.must[0]
        assert isinstance(field_condition.match, rest.MatchAny)
        assert field_condition.match.any == ["doc1"]

    def test_delete_by_ids_with_empty_list(self, vector_instance):
        """Test deletion with empty ID list returns early without API call."""
        vector_instance.delete_by_ids([])

        # Verify that delete was NOT called
        vector_instance._client.delete.assert_not_called()

    def test_delete_by_ids_with_404_error(self, vector_instance):
        """Test that 404 errors (collection not found) are handled gracefully."""
        ids = ["doc1", "doc2"]

        # Mock a 404 error
        error = UnexpectedResponse(
            status_code=404,
            reason_phrase="Not Found",
            content=b"Collection not found",
            headers=httpx.Headers(),
        )
        vector_instance._client.delete.side_effect = error

        # Should not raise an exception
        vector_instance.delete_by_ids(ids)

        # Verify delete was called
        vector_instance._client.delete.assert_called_once()

    def test_delete_by_ids_with_unexpected_error(self, vector_instance):
        """Test that non-404 errors are re-raised."""
        ids = ["doc1", "doc2"]

        # Mock a 500 error
        error = UnexpectedResponse(
            status_code=500,
            reason_phrase="Internal Server Error",
            content=b"Server error",
            headers=httpx.Headers(),
        )
        vector_instance._client.delete.side_effect = error

        # Should re-raise the exception
        with pytest.raises(UnexpectedResponse) as exc_info:
            vector_instance.delete_by_ids(ids)

        assert exc_info.value.status_code == 500

    def test_delete_by_ids_with_large_batch(self, vector_instance):
        """Test deletion with a large batch of IDs."""
        # Create 1000 IDs
        ids = [f"doc_{i}" for i in range(1000)]

        vector_instance.delete_by_ids(ids)

        # Verify single delete call with all IDs
        vector_instance._client.delete.assert_called_once()
        call_args = vector_instance._client.delete.call_args

        filter_selector = call_args[1]["points_selector"]
        filter_obj = filter_selector.filter
        field_condition = filter_obj.must[0]

        # Verify all 1000 IDs are in the batch
        assert len(field_condition.match.any) == 1000
        assert "doc_0" in field_condition.match.any
        assert "doc_999" in field_condition.match.any

    def test_delete_by_ids_filter_structure(self, vector_instance):
        """Test that the filter structure is correctly constructed."""
        ids = ["doc1", "doc2"]

        vector_instance.delete_by_ids(ids)

        call_args = vector_instance._client.delete.call_args
        filter_selector = call_args[1]["points_selector"]
        filter_obj = filter_selector.filter

        # Verify Filter structure
        assert isinstance(filter_obj, rest.Filter)
        assert filter_obj.must is not None
        assert len(filter_obj.must) == 1

        # Verify FieldCondition structure
        field_condition = filter_obj.must[0]
        assert isinstance(field_condition, rest.FieldCondition)
        assert field_condition.key == "metadata.doc_id"

        # Verify MatchAny structure
        assert isinstance(field_condition.match, rest.MatchAny)
        assert field_condition.match.any == ids
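The assertions above all inspect a single batched call. A sketch of the delete shape they expect, built with the public qdrant-client models (the real method additionally swallows 404s for missing collections):

    from qdrant_client import QdrantClient
    from qdrant_client.http import models as rest


    def delete_points_by_doc_ids(client: QdrantClient, collection: str, ids: list[str]) -> None:
        if not ids:
            return  # matches the early return asserted for an empty list
        # One batched delete with a MatchAny filter instead of one request per id.
        client.delete(
            collection_name=collection,
            points_selector=rest.FilterSelector(
                filter=rest.Filter(
                    must=[
                        rest.FieldCondition(
                            key="metadata.doc_id",
                            match=rest.MatchAny(any=ids),
                        )
                    ]
                )
            ),
        )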
@ -0,0 +1,677 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.repositories.human_input_repository import (
|
||||
HumanInputFormRecord,
|
||||
HumanInputFormRepositoryImpl,
|
||||
HumanInputFormSubmissionRepository,
|
||||
_HumanInputFormEntityImpl,
|
||||
_HumanInputFormRecipientEntityImpl,
|
||||
_InvalidTimeoutStatusError,
|
||||
_WorkspaceMemberInfo,
|
||||
)
|
||||
from dify_graph.nodes.human_input.entities import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
HumanInputNodeData,
|
||||
MemberRecipient,
|
||||
UserAction,
|
||||
WebAppDeliveryMethod,
|
||||
)
|
||||
from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
from dify_graph.repositories.human_input_form_repository import FormCreateParams, FormNotFoundError
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.human_input import HumanInputFormRecipient, RecipientType
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _stub_select(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class _FakeSelect:
|
||||
def join(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
|
||||
return self
|
||||
|
||||
def where(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
|
||||
return self
|
||||
|
||||
def options(self, *_args: Any, **_kwargs: Any) -> _FakeSelect:
|
||||
return self
|
||||
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.select", lambda *_args, **_kwargs: _FakeSelect())
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.selectinload", lambda *_args, **_kwargs: "_loader")
|
||||
|
||||
|
||||
def _make_form_definition_json(*, include_expiration_time: bool) -> str:
|
||||
payload: dict[str, Any] = {
|
||||
"form_content": "hi",
|
||||
"inputs": [],
|
||||
"user_actions": [{"id": "submit", "title": "Submit"}],
|
||||
"rendered_content": "<p>hi</p>",
|
||||
}
|
||||
if include_expiration_time:
|
||||
payload["expiration_time"] = naive_utc_now()
|
||||
return json.dumps(payload, default=str)


@dataclasses.dataclass
class _DummyForm:
    id: str
    workflow_run_id: str | None
    node_id: str
    tenant_id: str
    app_id: str
    form_definition: str
    rendered_content: str
    expiration_time: datetime
    form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
    created_at: datetime = dataclasses.field(default_factory=naive_utc_now)
    selected_action_id: str | None = None
    submitted_data: str | None = None
    submitted_at: datetime | None = None
    submission_user_id: str | None = None
    submission_end_user_id: str | None = None
    completed_by_recipient_id: str | None = None
    status: HumanInputFormStatus = HumanInputFormStatus.WAITING


@dataclasses.dataclass
class _DummyRecipient:
    id: str
    form_id: str
    recipient_type: RecipientType
    access_token: str | None


class _FakeScalarResult:
    def __init__(self, obj: Any):
        self._obj = obj

    def first(self) -> Any:
        if isinstance(self._obj, list):
            return self._obj[0] if self._obj else None
        return self._obj

    def all(self) -> list[Any]:
        if self._obj is None:
            return []
        if isinstance(self._obj, list):
            return list(self._obj)
        return [self._obj]


class _FakeExecuteResult:
    def __init__(self, rows: Sequence[tuple[Any, ...]]):
        self._rows = list(rows)

    def all(self) -> list[tuple[Any, ...]]:
        return list(self._rows)


class _FakeSession:
    def __init__(
        self,
        *,
        scalars_result: Any = None,
        scalars_results: list[Any] | None = None,
        forms: dict[str, _DummyForm] | None = None,
        recipients: dict[str, _DummyRecipient] | None = None,
        execute_rows: Sequence[tuple[Any, ...]] = (),
    ):
        if scalars_results is not None:
            self._scalars_queue = list(scalars_results)
        else:
            self._scalars_queue = [scalars_result]
        self._forms = forms or {}
        self._recipients = recipients or {}
        self._execute_rows = list(execute_rows)
        self.added: list[Any] = []

    def scalars(self, _query: Any) -> _FakeScalarResult:
        if self._scalars_queue:
            value = self._scalars_queue.pop(0)
        else:
            value = None
        return _FakeScalarResult(value)

    def execute(self, _stmt: Any) -> _FakeExecuteResult:
        return _FakeExecuteResult(self._execute_rows)

    def get(self, model_cls: Any, obj_id: str) -> Any:
        name = getattr(model_cls, "__name__", "")
        if name == "HumanInputForm":
            return self._forms.get(obj_id)
        if name == "HumanInputFormRecipient":
            return self._recipients.get(obj_id)
        return None

    def add(self, obj: Any) -> None:
        self.added.append(obj)

    def add_all(self, objs: Sequence[Any]) -> None:
        self.added.extend(list(objs))

    def flush(self) -> None:
        # Simulate DB default population for attributes referenced in entity wrappers.
        for obj in self.added:
            if hasattr(obj, "id") and obj.id in (None, ""):
                obj.id = f"gen-{len(str(self.added))}"
            if isinstance(obj, HumanInputFormRecipient) and obj.access_token is None:
                if obj.recipient_type == RecipientType.CONSOLE:
                    obj.access_token = "token-console"
                elif obj.recipient_type == RecipientType.BACKSTAGE:
                    obj.access_token = "token-backstage"
                else:
                    obj.access_token = "token-webapp"

    def refresh(self, _obj: Any) -> None:
        return None

    def begin(self) -> _FakeSession:
        return self

    def __enter__(self) -> _FakeSession:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        return None


class _SessionFactoryStub:
    def __init__(self, session: _FakeSession):
        self._session = session

    def create_session(self) -> _FakeSession:
        return self._session


def _patch_session_factory(monkeypatch: pytest.MonkeyPatch, session: _FakeSession) -> None:
    monkeypatch.setattr("core.repositories.human_input_repository.session_factory", _SessionFactoryStub(session))


def test_recipient_entity_token_raises_when_missing() -> None:
    recipient = SimpleNamespace(id="r1", access_token=None)
    entity = _HumanInputFormRecipientEntityImpl(recipient)  # type: ignore[arg-type]
    with pytest.raises(AssertionError, match="access_token should not be None"):
        _ = entity.token


def test_recipient_entity_id_and_token_success() -> None:
    recipient = SimpleNamespace(id="r1", access_token="tok")
    entity = _HumanInputFormRecipientEntityImpl(recipient)  # type: ignore[arg-type]
    assert entity.id == "r1"
    assert entity.token == "tok"


def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> None:
    form = _DummyForm(
        id="f1",
        workflow_run_id="run",
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
    )
    console = _DummyRecipient(id="c1", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="ctok")
    webapp = _DummyRecipient(
        id="w1", form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP, access_token="wtok"
    )

    entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp, console])  # type: ignore[arg-type]
    assert entity.web_app_token == "ctok"

    entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp])  # type: ignore[arg-type]
    assert entity.web_app_token == "wtok"

    entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[])  # type: ignore[arg-type]
    assert entity.web_app_token is None


def test_form_entity_submitted_data_parsed() -> None:
    form = _DummyForm(
        id="f1",
        workflow_run_id="run",
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
        submitted_data='{"a": 1}',
        submitted_at=naive_utc_now(),
    )
    entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[])  # type: ignore[arg-type]
    assert entity.submitted is True
    assert entity.submitted_data == {"a": 1}
    assert entity.rendered_content == "<p>x</p>"
    assert entity.selected_action_id is None
    assert entity.status == HumanInputFormStatus.WAITING


def test_form_record_from_models_injects_expiration_time_when_missing() -> None:
    expiration = naive_utc_now()
    form = _DummyForm(
        id="f1",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=False),
        rendered_content="<p>x</p>",
        expiration_time=expiration,
        submitted_data='{"k": "v"}',
    )
    record = HumanInputFormRecord.from_models(form, None)  # type: ignore[arg-type]
    assert record.definition.expiration_time == expiration
    assert record.submitted_data == {"k": "v"}
    assert record.submitted is False


def test_create_email_recipients_from_resolved_dedupes_and_skips_blank(monkeypatch: pytest.MonkeyPatch) -> None:
    created: list[SimpleNamespace] = []

    def fake_new(cls, form_id: str, delivery_id: str, payload: Any):  # type: ignore[no-untyped-def]
        recipient = SimpleNamespace(
            id=f"{payload.TYPE}-{len(created)}",
            form_id=form_id,
            delivery_id=delivery_id,
            recipient_type=payload.TYPE,
            recipient_payload=payload.model_dump_json(),
            access_token="tok",
        )
        created.append(recipient)
        return recipient

    monkeypatch.setattr("core.repositories.human_input_repository.HumanInputFormRecipient.new", classmethod(fake_new))

    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    recipients = repo._create_email_recipients_from_resolved(  # type: ignore[attr-defined]
        form_id="f",
        delivery_id="d",
        members=[
            _WorkspaceMemberInfo(user_id="u1", email=""),
            _WorkspaceMemberInfo(user_id="u2", email="a@example.com"),
            _WorkspaceMemberInfo(user_id="u3", email="a@example.com"),
        ],
        external_emails=["", "a@example.com", "b@example.com", "b@example.com"],
    )
    assert [r.recipient_type for r in recipients] == [RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL]
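    # Blank addresses are skipped and duplicate emails are collapsed across the
    # member and external lists, so only one member (u2) and one external
    # (b@example.com) recipient survive.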


def test_query_workspace_members_by_ids_empty_returns_empty() -> None:
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    assert repo._query_workspace_members_by_ids(session=MagicMock(), restrict_to_user_ids=["", ""]) == []


def test_query_workspace_members_by_ids_maps_rows() -> None:
    session = _FakeSession(execute_rows=[("u1", "a@example.com"), ("u2", "b@example.com")])
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    rows = repo._query_workspace_members_by_ids(session=session, restrict_to_user_ids=["u1", "u2"])
    assert rows == [
        _WorkspaceMemberInfo(user_id="u1", email="a@example.com"),
        _WorkspaceMemberInfo(user_id="u2", email="b@example.com"),
    ]


def test_query_all_workspace_members_maps_rows() -> None:
    session = _FakeSession(execute_rows=[("u1", "a@example.com")])
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    rows = repo._query_all_workspace_members(session=session)
    assert rows == [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")]


def test_repository_init_sets_tenant_id() -> None:
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    assert repo._tenant_id == "tenant"


def test_delivery_method_to_model_webapp_creates_delivery_and_recipient(monkeypatch: pytest.MonkeyPatch) -> None:
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1")
    result = repo._delivery_method_to_model(
        session=MagicMock(), form_id="form-1", delivery_method=WebAppDeliveryMethod()
    )
    assert result.delivery.id == "del-1"
    assert result.delivery.form_id == "form-1"
    assert len(result.recipients) == 1
    assert result.recipients[0].recipient_type == RecipientType.STANDALONE_WEB_APP


def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch: pytest.MonkeyPatch) -> None:
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1")
    called: dict[str, Any] = {}

    def fake_build(*, session: Any, form_id: str, delivery_id: str, recipients_config: Any) -> list[Any]:
        called.update(
            {"session": session, "form_id": form_id, "delivery_id": delivery_id, "recipients_config": recipients_config}
        )
        return ["r"]

    monkeypatch.setattr(repo, "_build_email_recipients", fake_build)

    method = EmailDeliveryMethod(
        config=EmailDeliveryConfig(
            recipients=EmailRecipients(
                whole_workspace=False,
                items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
            ),
            subject="s",
            body="b",
        )
    )
    result = repo._delivery_method_to_model(session="sess", form_id="form-1", delivery_method=method)
    assert result.recipients == ["r"]
    assert called["delivery_id"] == "del-1"


def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None:
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    monkeypatch.setattr(
        repo,
        "_query_all_workspace_members",
        lambda *, session: [_WorkspaceMemberInfo(user_id="u", email="a@example.com")],
    )
    monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"])
    recipients = repo._build_email_recipients(
        session=MagicMock(),
        form_id="f",
        delivery_id="d",
        recipients_config=EmailRecipients(whole_workspace=True, items=[ExternalRecipient(email="e@example.com")]),
    )
    assert recipients == ["ok"]


def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None:
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")

    def fake_query(*, session: Any, restrict_to_user_ids: Sequence[str]) -> list[_WorkspaceMemberInfo]:
        assert restrict_to_user_ids == ["u1"]
        return [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")]

    monkeypatch.setattr(repo, "_query_workspace_members_by_ids", fake_query)
    monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"])
    recipients = repo._build_email_recipients(
        session=MagicMock(),
        form_id="f",
        delivery_id="d",
        recipients_config=EmailRecipients(
            whole_workspace=False,
            items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")],
        ),
    )
    assert recipients == ["ok"]


def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    _patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None]))
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    assert repo.get_form("run", "node") is None

    form = _DummyForm(
        id="f1",
        workflow_run_id="run",
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
    )
    recipient = _DummyRecipient(
        id="r1",
        form_id=form.id,
        recipient_type=RecipientType.STANDALONE_WEB_APP,
        access_token="tok",
    )
    session = _FakeSession(scalars_results=[form, [recipient]])
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")
    entity = repo.get_form("run", "node")
    assert entity is not None
    assert entity.id == "f1"
    assert entity.recipients[0].id == "r1"
    assert entity.recipients[0].token == "tok"


def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.MonkeyPatch) -> None:
    fixed_now = datetime(2024, 1, 1, 0, 0, 0)
    monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)

    ids = iter(["form-id", "del-web", "del-console", "del-backstage"])
    monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: next(ids))

    session = _FakeSession()
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormRepositoryImpl(tenant_id="tenant")

    form_config = HumanInputNodeData(
        title="Title",
        delivery_methods=[],
        form_content="hello",
        inputs=[],
        user_actions=[UserAction(id="submit", title="Submit")],
    )
    params = FormCreateParams(
        app_id="app",
        workflow_execution_id="run",
        node_id="node",
        form_config=form_config,
        rendered_content="<p>hello</p>",
        delivery_methods=[WebAppDeliveryMethod()],
        display_in_ui=True,
        resolved_default_values={},
        form_kind=HumanInputFormKind.RUNTIME,
        console_recipient_required=True,
        console_creator_account_id="acc-1",
        backstage_recipient_required=True,
    )

    entity = repo.create_form(params)
    assert entity.id == "form-id"
    assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout)
    # Console token should take precedence when console recipient is present.
    assert entity.web_app_token == "token-console"
    assert len(entity.recipients) == 3
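    # The three recipients come from the web-app delivery method plus the
    # console and backstage recipients requested via the FormCreateParams flags.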


def test_submission_get_by_token_returns_none_when_missing_or_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=None))
    repo = HumanInputFormSubmissionRepository()
    assert repo.get_by_token("tok") is None

    recipient = SimpleNamespace(form=None)
    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
    repo = HumanInputFormSubmissionRepository()
    assert repo.get_by_token("tok") is None


def test_submission_repository_init_no_args() -> None:
    repo = HumanInputFormSubmissionRepository()
    assert isinstance(repo, HumanInputFormSubmissionRepository)


def test_submission_get_by_token_and_get_by_form_id_success_paths(monkeypatch: pytest.MonkeyPatch) -> None:
    form = _DummyForm(
        id="f1",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
    )
    recipient = SimpleNamespace(
        id="r1",
        form_id=form.id,
        recipient_type=RecipientType.STANDALONE_WEB_APP,
        access_token="tok",
        form=form,
    )

    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
    repo = HumanInputFormSubmissionRepository()
    record = repo.get_by_token("tok")
    assert record is not None
    assert record.access_token == "tok"

    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient))
    repo = HumanInputFormSubmissionRepository()
    record = repo.get_by_form_id_and_recipient_type(form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP)
    assert record is not None
    assert record.recipient_id == "r1"


def test_submission_get_by_form_id_returns_none_on_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    _patch_session_factory(monkeypatch, _FakeSession(scalars_result=None))
    repo = HumanInputFormSubmissionRepository()
    assert repo.get_by_form_id_and_recipient_type(form_id="f", recipient_type=RecipientType.CONSOLE) is None


def test_mark_submitted_updates_and_raises_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    fixed_now = datetime(2024, 1, 1, 0, 0, 0)
    monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)

    missing_session = _FakeSession(forms={})
    _patch_session_factory(monkeypatch, missing_session)
    repo = HumanInputFormSubmissionRepository()
    with pytest.raises(FormNotFoundError, match="form not found"):
        repo.mark_submitted(
            form_id="missing",
            recipient_id=None,
            selected_action_id="a",
            form_data={},
            submission_user_id=None,
            submission_end_user_id=None,
        )

    form = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=fixed_now,
    )
    recipient = _DummyRecipient(id="r", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="tok")
    session = _FakeSession(forms={form.id: form}, recipients={recipient.id: recipient})
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormSubmissionRepository()
    record = repo.mark_submitted(
        form_id=form.id,
        recipient_id=recipient.id,
        selected_action_id="approve",
        form_data={"k": "v"},
        submission_user_id="u",
        submission_end_user_id="eu",
    )
    assert form.status == HumanInputFormStatus.SUBMITTED
    assert form.submitted_at == fixed_now
    assert record.submitted_data == {"k": "v"}


def test_mark_timeout_invalid_status_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    form = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
    )
    session = _FakeSession(forms={form.id: form})
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormSubmissionRepository()
    with pytest.raises(_InvalidTimeoutStatusError, match="invalid timeout status"):
        repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.SUBMITTED)  # type: ignore[arg-type]


def test_mark_timeout_already_timed_out_returns_record(monkeypatch: pytest.MonkeyPatch) -> None:
    form = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
        status=HumanInputFormStatus.TIMEOUT,
    )
    session = _FakeSession(forms={form.id: form})
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormSubmissionRepository()
    record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.TIMEOUT, reason="r")
    assert record.status == HumanInputFormStatus.TIMEOUT


def test_mark_timeout_submitted_raises_form_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
    form = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
        status=HumanInputFormStatus.SUBMITTED,
    )
    session = _FakeSession(forms={form.id: form})
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormSubmissionRepository()
    with pytest.raises(FormNotFoundError, match="form already submitted"):
        repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED)


def test_mark_timeout_updates_fields(monkeypatch: pytest.MonkeyPatch) -> None:
    form = _DummyForm(
        id="f",
        workflow_run_id=None,
        node_id="node",
        tenant_id="tenant",
        app_id="app",
        form_definition=_make_form_definition_json(include_expiration_time=True),
        rendered_content="<p>x</p>",
        expiration_time=naive_utc_now(),
        selected_action_id="a",
        submitted_data="{}",
        submission_user_id="u",
        submission_end_user_id="eu",
        completed_by_recipient_id="r",
        status=HumanInputFormStatus.WAITING,
    )
    session = _FakeSession(forms={form.id: form})
    _patch_session_factory(monkeypatch, session)
    repo = HumanInputFormSubmissionRepository()
    record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED)
    assert form.status == HumanInputFormStatus.EXPIRED
    assert form.selected_action_id is None
    assert form.submitted_data is None
    assert form.submission_user_id is None
    assert form.submission_end_user_id is None
    assert form.completed_by_recipient_id is None
    assert record.status == HumanInputFormStatus.EXPIRED


def test_mark_timeout_raises_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    _patch_session_factory(monkeypatch, _FakeSession(forms={}))
    repo = HumanInputFormSubmissionRepository()
    with pytest.raises(FormNotFoundError, match="form not found"):
        repo.mark_timeout(form_id="missing", timeout_status=HumanInputFormStatus.TIMEOUT)
@@ -1,84 +1,291 @@
from datetime import UTC, datetime
from unittest.mock import MagicMock
from uuid import uuid4

import pytest
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from models import Account, CreatorUserRole, EndUser, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom


@pytest.fixture
def mock_session_factory():
    """Mock SQLAlchemy session factory."""
    session_factory = MagicMock(spec=sessionmaker)
    session = MagicMock()
    session.get.return_value = None
    session_factory.return_value.__enter__.return_value = session
    return session_factory


@pytest.fixture
def mock_engine():
    """Mock SQLAlchemy Engine."""
    return MagicMock(spec=Engine)


@pytest.fixture
def mock_account():
    """Mock Account user."""
    account = MagicMock(spec=Account)
    account.id = str(uuid4())
    account.current_tenant_id = str(uuid4())
    return account


@pytest.fixture
def mock_end_user():
    """Mock EndUser."""
    user = MagicMock(spec=EndUser)
    user.id = str(uuid4())
    user.tenant_id = str(uuid4())
    return user


@pytest.fixture
def sample_workflow_execution():
    """Sample WorkflowExecution for testing."""
    return WorkflowExecution(
        id_=str(uuid4()),
        workflow_id=str(uuid4()),
        workflow_type=WorkflowType.WORKFLOW,
        workflow_version="1.0",
        graph={"nodes": [], "edges": []},
        inputs={"input1": "value1"},
        outputs={"output1": "result1"},
        status=WorkflowExecutionStatus.SUCCEEDED,
        error_message="",
        total_tokens=100,
        total_steps=5,
        exceptions_count=0,
        started_at=datetime.now(UTC),
        finished_at=datetime.now(UTC),
    )


class TestSQLAlchemyWorkflowExecutionRepository:
    def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
        app_id = "test_app_id"
        triggered_from = WorkflowRunTriggeredFrom.APP_RUN

        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory, user=mock_account, app_id=app_id, triggered_from=triggered_from
        )

        assert repo._session_factory == mock_session_factory
        assert repo._tenant_id == mock_account.current_tenant_id
        assert repo._app_id == app_id
        assert repo._triggered_from == triggered_from
        assert repo._creator_user_id == mock_account.id
        assert repo._creator_user_role == CreatorUserRole.ACCOUNT

    def test_init_with_engine(self, mock_engine, mock_account):
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_engine,
            user=mock_account,
            app_id="test_app_id",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        assert isinstance(repo._session_factory, sessionmaker)
        assert repo._session_factory.kw["bind"] == mock_engine

    def test_init_invalid_session_factory(self, mock_account):
        with pytest.raises(ValueError, match="Invalid session_factory type"):
            SQLAlchemyWorkflowExecutionRepository(
                session_factory="invalid", user=mock_account, app_id=None, triggered_from=None
            )

    def test_init_no_tenant_id(self, mock_session_factory):
        user = MagicMock(spec=Account)
        user.current_tenant_id = None

        with pytest.raises(ValueError, match="User must have a tenant_id"):
            SQLAlchemyWorkflowExecutionRepository(
                session_factory=mock_session_factory, user=user, app_id=None, triggered_from=None
            )

    def test_init_with_end_user(self, mock_session_factory, mock_end_user):
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory, user=mock_end_user, app_id=None, triggered_from=None
        )
        assert repo._tenant_id == mock_end_user.tenant_id
        assert repo._creator_user_role == CreatorUserRole.END_USER

    def test_to_domain_model(self, mock_session_factory, mock_account):
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None
        )

        db_model = MagicMock(spec=WorkflowRun)
        db_model.id = str(uuid4())
        db_model.workflow_id = str(uuid4())
        db_model.type = "workflow"
        db_model.version = "1.0"
        db_model.inputs_dict = {"in": "val"}
        db_model.outputs_dict = {"out": "val"}
        db_model.graph_dict = {"nodes": []}
        db_model.status = "succeeded"
        db_model.error = "some error"
        db_model.total_tokens = 50
        db_model.total_steps = 3
        db_model.exceptions_count = 1
        db_model.created_at = datetime.now(UTC)
        db_model.finished_at = datetime.now(UTC)

        domain_model = repo._to_domain_model(db_model)

        assert domain_model.id_ == db_model.id
        assert domain_model.workflow_id == db_model.workflow_id
        assert domain_model.status == WorkflowExecutionStatus.SUCCEEDED
        assert domain_model.inputs == db_model.inputs_dict
        assert domain_model.error_message == "some error"

    def test_to_db_model(self, mock_session_factory, mock_account, sample_workflow_execution):
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test_app",
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
        )

        # Make elapsed time deterministic to avoid flaky tests
        sample_workflow_execution.started_at = datetime(2023, 1, 1, 0, 0, 0, tzinfo=UTC)
        sample_workflow_execution.finished_at = datetime(2023, 1, 1, 0, 0, 10, tzinfo=UTC)

        db_model = repo._to_db_model(sample_workflow_execution)

        assert db_model.id == sample_workflow_execution.id_
        assert db_model.tenant_id == repo._tenant_id
        assert db_model.app_id == "test_app"
        assert db_model.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING
        assert db_model.status == sample_workflow_execution.status.value
        assert db_model.total_tokens == sample_workflow_execution.total_tokens
        assert db_model.elapsed_time == 10.0

    def test_to_db_model_edge_cases(self, mock_session_factory, mock_account, sample_workflow_execution):
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test_app",
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
        )
        # Test with empty/None fields
        sample_workflow_execution.graph = None
        sample_workflow_execution.inputs = None
        sample_workflow_execution.outputs = None
        sample_workflow_execution.error_message = None
        sample_workflow_execution.finished_at = None

        db_model = repo._to_db_model(sample_workflow_execution)

        assert db_model.graph is None
        assert db_model.inputs is None
        assert db_model.outputs is None
        assert db_model.error is None
        assert db_model.elapsed_time == 0

    def test_to_db_model_app_id_none(self, mock_session_factory, mock_account, sample_workflow_execution):
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id=None,
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        db_model = repo._to_db_model(sample_workflow_execution)
        assert not hasattr(db_model, "app_id") or db_model.app_id is None
        assert db_model.tenant_id == repo._tenant_id

    def test_to_db_model_missing_context(self, mock_session_factory, mock_account, sample_workflow_execution):
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None
        )

        # Test triggered_from missing
        with pytest.raises(ValueError, match="triggered_from is required"):
            repo._to_db_model(sample_workflow_execution)

        repo._triggered_from = WorkflowRunTriggeredFrom.APP_RUN
        repo._creator_user_id = None
        with pytest.raises(ValueError, match="created_by is required"):
            repo._to_db_model(sample_workflow_execution)

        repo._creator_user_id = "some_id"
        repo._creator_user_role = None
        with pytest.raises(ValueError, match="created_by_role is required"):
            repo._to_db_model(sample_workflow_execution)

    def test_save(self, mock_session_factory, mock_account, sample_workflow_execution):
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test_app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        repo.save(sample_workflow_execution)

        session = mock_session_factory.return_value.__enter__.return_value
        session.merge.assert_called_once()
        session.commit.assert_called_once()

        # Check cache
        assert sample_workflow_execution.id_ in repo._execution_cache
        cached_model = repo._execution_cache[sample_workflow_execution.id_]
        assert cached_model.id == sample_workflow_execution.id_

    def test_save_uses_execution_started_at_when_record_does_not_exist(
        self, mock_session_factory, mock_account, sample_workflow_execution
    ):
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test_app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        started_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC)
        sample_workflow_execution.started_at = started_at

        session = mock_session_factory.return_value.__enter__.return_value
        session.get.return_value = None

        repo.save(sample_workflow_execution)

        saved_model = session.merge.call_args.args[0]
        assert saved_model.created_at == started_at
        session.commit.assert_called_once()

    def test_save_preserves_existing_created_at_when_record_already_exists(
        self, mock_session_factory, mock_account, sample_workflow_execution
    ):
        repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test_app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        execution_id = sample_workflow_execution.id_
        existing_created_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC)

        existing_run = WorkflowRun()
        existing_run.id = execution_id
        existing_run.tenant_id = repo._tenant_id
        existing_run.created_at = existing_created_at

        session = mock_session_factory.return_value.__enter__.return_value
        session.get.return_value = existing_run

        sample_workflow_execution.started_at = datetime(2026, 1, 1, 12, 30, 0, tzinfo=UTC)

        repo.save(sample_workflow_execution)

        saved_model = session.merge.call_args.args[0]
        assert saved_model.created_at == existing_created_at
        session.commit.assert_called_once()
@@ -0,0 +1,772 @@
from __future__ import annotations

import json
import logging
from collections.abc import Mapping
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock, Mock

import psycopg2.errors
import pytest
from sqlalchemy import Engine, create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import sessionmaker

from configs import dify_config
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
    SQLAlchemyWorkflowNodeExecutionRepository,
    _deterministic_json_dump,
    _filter_by_offload_type,
    _find_first,
    _replace_or_append_offload,
)
from dify_graph.entities import WorkflowNodeExecution
from dify_graph.enums import (
    BuiltinNodeTypes,
    WorkflowNodeExecutionMetadataKey,
    WorkflowNodeExecutionStatus,
)
from dify_graph.repositories.workflow_node_execution_repository import OrderConfig
from models import Account, EndUser
from models.enums import ExecutionOffLoadType
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom


def _mock_account(*, tenant_id: str = "tenant", user_id: str = "user") -> Account:
    user = Mock(spec=Account)
    user.id = user_id
    user.current_tenant_id = tenant_id
    return user


def _mock_end_user(*, tenant_id: str = "tenant", user_id: str = "user") -> EndUser:
    user = Mock(spec=EndUser)
    user.id = user_id
    user.tenant_id = tenant_id
    return user


def _execution(
    *,
    execution_id: str = "exec-id",
    node_execution_id: str = "node-exec-id",
    workflow_run_id: str = "run-id",
    status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.SUCCEEDED,
    inputs: Mapping[str, Any] | None = None,
    outputs: Mapping[str, Any] | None = None,
    process_data: Mapping[str, Any] | None = None,
    metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None,
) -> WorkflowNodeExecution:
    return WorkflowNodeExecution(
        id=execution_id,
        node_execution_id=node_execution_id,
        workflow_id="workflow-id",
        workflow_execution_id=workflow_run_id,
        index=1,
        predecessor_node_id=None,
        node_id="node-id",
        node_type=BuiltinNodeTypes.LLM,
        title="Title",
        inputs=inputs,
        outputs=outputs,
        process_data=process_data,
        status=status,
        error=None,
        elapsed_time=1.0,
        metadata=metadata,
        created_at=datetime.now(UTC),
        finished_at=None,
    )


class _SessionCtx:
    def __init__(self, session: Any):
        self._session = session

    def __enter__(self) -> Any:
        return self._session

    def __exit__(self, exc_type, exc, tb) -> None:
        return None


def _session_factory(session: Any) -> sessionmaker:
    factory = Mock(spec=sessionmaker)
    factory.return_value = _SessionCtx(session)
    return factory


def test_init_accepts_engine_and_sessionmaker_and_sets_role(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )

    engine: Engine = create_engine("sqlite:///:memory:")
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=engine,
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    assert isinstance(repo._session_factory, sessionmaker)

    sm = Mock(spec=sessionmaker)
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=sm,
        user=_mock_end_user(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
    )
    assert repo._creator_user_role.value == "end_user"


def test_init_rejects_invalid_session_factory_type(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    with pytest.raises(ValueError, match="Invalid session_factory type"):
        SQLAlchemyWorkflowNodeExecutionRepository(  # type: ignore[arg-type]
            session_factory=object(),
            user=_mock_account(),
            app_id=None,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )


def test_init_requires_tenant_id(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    user = _mock_account()
    user.current_tenant_id = None
    with pytest.raises(ValueError, match="User must have a tenant_id"):
        SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=Mock(spec=sessionmaker),
            user=user,
            app_id=None,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )


def test_create_truncator_uses_config(monkeypatch: pytest.MonkeyPatch) -> None:
    created: dict[str, Any] = {}

    class FakeTruncator:
        def __init__(self, *, max_size_bytes: int, array_element_limit: int, string_length_limit: int):
            created.update(
                {
                    "max_size_bytes": max_size_bytes,
                    "array_element_limit": array_element_limit,
                    "string_length_limit": string_length_limit,
                }
            )

    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.VariableTruncator",
        FakeTruncator,
    )
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )

    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    _ = repo._create_truncator()
    assert created["max_size_bytes"] == dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE


def test_helpers_find_first_and_replace_or_append_and_filter() -> None:
    assert _deterministic_json_dump({"b": 1, "a": 2}) == '{"a": 2, "b": 1}'
    assert _find_first([], lambda _: True) is None
    assert _find_first([1, 2, 3], lambda x: x > 1) == 2

    off1 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
    off2 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
    assert _find_first([off1, off2], _filter_by_offload_type(ExecutionOffLoadType.OUTPUTS)) is off2

    replaced = _replace_or_append_offload([off1, off2], WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS))
    assert len(replaced) == 2
    assert [o.type_ for o in replaced] == [ExecutionOffLoadType.OUTPUTS, ExecutionOffLoadType.INPUTS]
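    # Note the order: replacing an existing offload entry removes the stale one
    # and appends the replacement, so the INPUTS entry moves to the end.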


def test_to_db_model_requires_constructor_context(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    execution = _execution(inputs={"b": 1, "a": 2}, metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 1})

    # Happy path: deterministic json dump should be sorted
    db_model = repo._to_db_model(execution)
    assert json.loads(db_model.inputs or "{}") == {"a": 2, "b": 1}
    assert json.loads(db_model.execution_metadata or "{}")["total_tokens"] == 1

    repo._triggered_from = None
    with pytest.raises(ValueError, match="triggered_from is required"):
        repo._to_db_model(execution)


def test_to_db_model_requires_creator_user_id_and_role(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    execution = _execution()
    db_model = repo._to_db_model(execution)
    assert db_model.app_id == "app"

    repo._creator_user_id = None
    with pytest.raises(ValueError, match="created_by is required"):
        repo._to_db_model(execution)

    repo._creator_user_id = "user"
    repo._creator_user_role = None
    with pytest.raises(ValueError, match="created_by_role is required"):
        repo._to_db_model(execution)


def test_is_duplicate_key_error_and_regenerate_id(
    monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    unique = Mock(spec=psycopg2.errors.UniqueViolation)
    duplicate_error = IntegrityError("dup", params=None, orig=unique)
    assert repo._is_duplicate_key_error(duplicate_error) is True
    assert repo._is_duplicate_key_error(IntegrityError("other", params=None, orig=None)) is False

    execution = _execution(execution_id="old-id")
    db_model = WorkflowNodeExecutionModel()
    db_model.id = "old-id"
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id")
    caplog.set_level(logging.WARNING)
    repo._regenerate_id_on_duplicate(execution, db_model)
    assert execution.id == "new-id"
    assert db_model.id == "new-id"
    assert any("Duplicate key conflict" in r.message for r in caplog.records)


def test_persist_to_database_updates_existing_and_inserts_new(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    session = MagicMock()
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    db_model = WorkflowNodeExecutionModel()
    db_model.id = "id1"
    db_model.node_execution_id = "node1"
    db_model.foo = "bar"  # type: ignore[attr-defined]
    db_model.__dict__["_private"] = "x"

    existing = SimpleNamespace()
    session.get.return_value = existing
    repo._persist_to_database(db_model)
    assert existing.foo == "bar"
    session.add.assert_not_called()
    assert repo._node_execution_cache["node1"] is db_model

    session.reset_mock()
    session.get.return_value = None
    repo._node_execution_cache.clear()
    repo._persist_to_database(db_model)
    session.add.assert_called_once_with(db_model)
    assert repo._node_execution_cache["node1"] is db_model
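    # Either way the persisted model is written through to the in-memory cache,
    # keyed by node_execution_id, presumably so later lookups can skip the
    # database round trip.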


def test_truncate_and_upload_returns_none_when_no_values_or_not_truncated(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    assert repo._truncate_and_upload(None, "e", ExecutionOffLoadType.INPUTS) is None

    class FakeTruncator:
        def truncate_variable_mapping(self, value: Any):  # type: ignore[no-untyped-def]
            return value, False

    monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator())
    assert repo._truncate_and_upload({"a": 1}, "e", ExecutionOffLoadType.INPUTS) is None


def test_truncate_and_upload_uploads_and_builds_offload(monkeypatch: pytest.MonkeyPatch) -> None:
    uploaded: dict[str, Any] = {}

    class FakeFileService:
        def upload_file(self, *, filename: str, content: bytes, mimetype: str, user: Any):  # type: ignore[no-untyped-def]
            uploaded.update({"filename": filename, "content": content, "mimetype": mimetype, "user": user})
            return SimpleNamespace(id="file-id", key="file-key")

    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", lambda *_: FakeFileService()
    )
    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "offload-id")

    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    class FakeTruncator:
        def truncate_variable_mapping(self, value: Any):  # type: ignore[no-untyped-def]
            return {"truncated": True}, True

    monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator())

    result = repo._truncate_and_upload({"a": 1}, "exec", ExecutionOffLoadType.INPUTS)
    assert result is not None
    assert result.truncated_value == {"truncated": True}
    assert uploaded["filename"].startswith("node_execution_exec_inputs.json")
    assert result.offload.file_id == "file-id"
    assert result.offload.type_ == ExecutionOffLoadType.INPUTS
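    # When truncation fires, the full mapping is uploaded as a JSON file and
    # the returned offload row records the file id; only the truncated value
    # stays inline on the execution record.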


def test_to_domain_model_loads_offloaded_files(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    db_model = WorkflowNodeExecutionModel()
    db_model.id = "id"
    db_model.node_execution_id = "node-exec"
    db_model.workflow_id = "wf"
    db_model.workflow_run_id = "run"
    db_model.index = 1
    db_model.predecessor_node_id = None
    db_model.node_id = "node"
    db_model.node_type = BuiltinNodeTypes.LLM
    db_model.title = "t"
    db_model.inputs = json.dumps({"trunc": "i"})
    db_model.process_data = json.dumps({"trunc": "p"})
    db_model.outputs = json.dumps({"trunc": "o"})
    db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
    db_model.error = None
    db_model.elapsed_time = 0.1
    db_model.execution_metadata = json.dumps({"total_tokens": 3})
    db_model.created_at = datetime.now(UTC)
    db_model.finished_at = None

    off_in = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
    off_out = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
    off_proc = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA)
    off_in.file = SimpleNamespace(key="k-in")
    off_out.file = SimpleNamespace(key="k-out")
    off_proc.file = SimpleNamespace(key="k-proc")
    db_model.offload_data = [off_out, off_in, off_proc]

    def fake_load(key: str) -> bytes:
        return json.dumps({"full": key}).encode()

    monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.storage.load", fake_load)

    domain = repo._to_domain_model(db_model)
    assert domain.inputs == {"full": "k-in"}
    assert domain.outputs == {"full": "k-out"}
    assert domain.process_data == {"full": "k-proc"}
    assert domain.get_truncated_inputs() == {"trunc": "i"}
    assert domain.get_truncated_outputs() == {"trunc": "o"}
    assert domain.get_truncated_process_data() == {"trunc": "p"}


def test_to_domain_model_returns_early_when_no_offload_data(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    db_model = WorkflowNodeExecutionModel()
    db_model.id = "id"
    db_model.node_execution_id = "node-exec"
    db_model.workflow_id = "wf"
    db_model.workflow_run_id = "run"
    db_model.index = 1
    db_model.predecessor_node_id = None
    db_model.node_id = "node"
    db_model.node_type = BuiltinNodeTypes.LLM
    db_model.title = "t"
    db_model.inputs = json.dumps({"i": 1})
    db_model.process_data = json.dumps({"p": 2})
    db_model.outputs = json.dumps({"o": 3})
    db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
    db_model.error = None
    db_model.elapsed_time = 0.1
    db_model.execution_metadata = "{}"
    db_model.created_at = datetime.now(UTC)
    db_model.finished_at = None
    db_model.offload_data = []

    domain = repo._to_domain_model(db_model)
    assert domain.inputs == {"i": 1}
    assert domain.outputs == {"o": 3}


def test_json_encode_uses_runtime_converter(monkeypatch: pytest.MonkeyPatch) -> None:
    class FakeConverter:
        def to_json_encodable(self, values: Mapping[str, Any]) -> Mapping[str, Any]:
            return {"wrapped": values["a"]}

    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowRuntimeTypeConverter",
        FakeConverter,
    )
    assert SQLAlchemyWorkflowNodeExecutionRepository._json_encode({"a": 1}) == '{"wrapped": 1}'


def test_save_execution_data_handles_existing_db_model_and_truncation(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    session = MagicMock()
    session.execute.return_value.scalars.return_value.first.return_value = SimpleNamespace(
        id="id",
        offload_data=[WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)],
        inputs=None,
        outputs=None,
        process_data=None,
    )
    session.merge = Mock()
    session.flush = Mock()
    session.begin.return_value.__enter__ = Mock(return_value=session)
    session.begin.return_value.__exit__ = Mock(return_value=None)

    repo = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=_session_factory(session),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )

    execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3})

    trunc_result = SimpleNamespace(
        truncated_value={"trunc": True},
        offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS, file_id="f1"),
    )
    monkeypatch.setattr(
        repo, "_truncate_and_upload", lambda values, *_args, **_kwargs: trunc_result if values == {"a": 1} else None
    )
    monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True))

    repo.save_execution_data(execution)
    # Inputs should be truncated, outputs/process_data encoded directly
    db_model = session.merge.call_args.args[0]
    assert json.loads(db_model.inputs) == {"trunc": True}
    assert json.loads(db_model.outputs) == {"b": 2}
    assert json.loads(db_model.process_data) == {"c": 3}
    assert any(off.type_ == ExecutionOffLoadType.INPUTS for off in db_model.offload_data)
    assert execution.get_truncated_inputs() == {"trunc": True}


def test_save_execution_data_truncates_outputs_and_process_data(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    existing = SimpleNamespace(
        id="id",
        offload_data=[],
        inputs=None,
        outputs=None,
        process_data=None,
    )
    session = MagicMock()
    session.execute.return_value.scalars.return_value.first.return_value = existing
    session.merge = Mock()
    session.flush = Mock()
    session.begin.return_value.__enter__ = Mock(return_value=session)
|
||||
session.begin.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3})
|
||||
|
||||
def trunc(values: Mapping[str, Any], *_args: Any, **_kwargs: Any) -> Any:
|
||||
if values == {"b": 2}:
|
||||
return SimpleNamespace(
|
||||
truncated_value={"b": "trunc"},
|
||||
offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS, file_id="f2"),
|
||||
)
|
||||
if values == {"c": 3}:
|
||||
return SimpleNamespace(
|
||||
truncated_value={"c": "trunc"},
|
||||
offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA, file_id="f3"),
|
||||
)
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(repo, "_truncate_and_upload", trunc)
|
||||
monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True))
|
||||
|
||||
repo.save_execution_data(execution)
|
||||
db_model = session.merge.call_args.args[0]
|
||||
assert json.loads(db_model.outputs) == {"b": "trunc"}
|
||||
assert json.loads(db_model.process_data) == {"c": "trunc"}
|
||||
assert execution.get_truncated_outputs() == {"b": "trunc"}
|
||||
assert execution.get_truncated_process_data() == {"c": "trunc"}
|
||||
|
||||
|
||||
def test_save_execution_data_handles_missing_db_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalars.return_value.first.return_value = None
|
||||
session.merge = Mock()
|
||||
session.flush = Mock()
|
||||
session.begin.return_value.__enter__ = Mock(return_value=session)
|
||||
session.begin.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
execution = _execution(inputs={"a": 1})
|
||||
fake_db_model = SimpleNamespace(id=execution.id, offload_data=[], inputs=None, outputs=None, process_data=None)
|
||||
monkeypatch.setattr(repo, "_to_db_model", lambda *_: fake_db_model)
|
||||
monkeypatch.setattr(repo, "_truncate_and_upload", lambda *_args, **_kwargs: None)
|
||||
monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values))
|
||||
|
||||
repo.save_execution_data(execution)
|
||||
merged = session.merge.call_args.args[0]
|
||||
assert merged.inputs == '{"a": 1}'
|
||||
|
||||
|
||||
def test_save_retries_duplicate_and_logs_non_duplicate(
|
||||
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
execution = _execution(execution_id="id")
|
||||
unique = Mock(spec=psycopg2.errors.UniqueViolation)
|
||||
duplicate_error = IntegrityError("dup", params=None, orig=unique)
|
||||
other_error = IntegrityError("other", params=None, orig=None)
|
||||
|
||||
calls = {"n": 0}
|
||||
|
||||
def persist(_db_model: Any) -> None:
|
||||
calls["n"] += 1
|
||||
if calls["n"] == 1:
|
||||
raise duplicate_error
|
||||
|
||||
monkeypatch.setattr(repo, "_persist_to_database", persist)
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id")
|
||||
repo.save(execution)
|
||||
assert execution.id == "new-id"
|
||||
assert repo._node_execution_cache[execution.node_execution_id] is not None
|
||||
|
||||
caplog.set_level(logging.ERROR)
|
||||
monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(other_error))
|
||||
with pytest.raises(IntegrityError):
|
||||
repo.save(_execution(execution_id="id2", node_execution_id="node2"))
|
||||
assert any("Non-duplicate key integrity error" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_save_logs_and_reraises_on_unexpected_error(
|
||||
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
caplog.set_level(logging.ERROR)
|
||||
monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(RuntimeError("boom")))
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
repo.save(_execution(execution_id="id3", node_execution_id="node3"))
|
||||
assert any("Failed to save workflow node execution" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_get_db_models_by_workflow_run_orders_and_caches(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
|
||||
class FakeStmt:
|
||||
def __init__(self) -> None:
|
||||
self.where_calls = 0
|
||||
self.order_by_args: tuple[Any, ...] | None = None
|
||||
|
||||
def where(self, *_args: Any) -> FakeStmt:
|
||||
self.where_calls += 1
|
||||
return self
|
||||
|
||||
def order_by(self, *args: Any) -> FakeStmt:
|
||||
self.order_by_args = args
|
||||
return self
|
||||
|
||||
stmt = FakeStmt()
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files",
|
||||
lambda _q: stmt,
|
||||
)
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select")
|
||||
|
||||
model1 = SimpleNamespace(node_execution_id="n1")
|
||||
model2 = SimpleNamespace(node_execution_id=None)
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = [model1, model2]
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
order = OrderConfig(order_by=["index", "missing"], order_direction="desc")
|
||||
db_models = repo.get_db_models_by_workflow_run("run", order)
|
||||
assert db_models == [model1, model2]
|
||||
assert repo._node_execution_cache["n1"] is model1
|
||||
assert stmt.order_by_args is not None
|
||||
|
||||
|
||||
def test_get_db_models_by_workflow_run_uses_asc_order(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
|
||||
class FakeStmt:
|
||||
def where(self, *_args: Any) -> FakeStmt:
|
||||
return self
|
||||
|
||||
def order_by(self, *args: Any) -> FakeStmt:
|
||||
self.args = args # type: ignore[attr-defined]
|
||||
return self
|
||||
|
||||
stmt = FakeStmt()
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files",
|
||||
lambda _q: stmt,
|
||||
)
|
||||
monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select")
|
||||
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=_session_factory(session),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
repo.get_db_models_by_workflow_run("run", OrderConfig(order_by=["index"], order_direction="asc"))
|
||||
|
||||
|
||||
def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
db_models = [SimpleNamespace(id="db1"), SimpleNamespace(id="db2")]
|
||||
monkeypatch.setattr(repo, "get_db_models_by_workflow_run", lambda *_args, **_kwargs: db_models)
|
||||
monkeypatch.setattr(repo, "_to_domain_model", lambda m: f"domain:{m.id}")
|
||||
|
||||
class FakeExecutor:
|
||||
def __enter__(self) -> FakeExecutor:
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None:
|
||||
return None
|
||||
|
||||
def map(self, func, items, timeout: int): # type: ignore[no-untyped-def]
|
||||
assert timeout == 30
|
||||
return list(map(func, items))
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.ThreadPoolExecutor",
|
||||
lambda max_workers: FakeExecutor(),
|
||||
)
|
||||
|
||||
result = repo.get_by_workflow_run("run", order_config=None)
|
||||
assert result == ["domain:db1", "domain:db2"]
|
||||
|
|
@@ -0,0 +1,137 @@
import json
from unittest.mock import patch

from core.schemas.registry import SchemaRegistry


class TestSchemaRegistry:
    def test_initialization(self, tmp_path):
        base_dir = tmp_path / "schemas"
        base_dir.mkdir()
        registry = SchemaRegistry(str(base_dir))
        assert registry.base_dir == base_dir
        assert registry.versions == {}
        assert registry.metadata == {}

    def test_default_registry_singleton(self):
        registry1 = SchemaRegistry.default_registry()
        registry2 = SchemaRegistry.default_registry()
        assert registry1 is registry2
        assert isinstance(registry1, SchemaRegistry)

    def test_load_all_versions_non_existent_dir(self, tmp_path):
        base_dir = tmp_path / "non_existent"
        registry = SchemaRegistry(str(base_dir))
        registry.load_all_versions()
        assert registry.versions == {}

    def test_load_all_versions_filtering(self, tmp_path):
        base_dir = tmp_path / "schemas"
        base_dir.mkdir()
        (base_dir / "not_a_version_dir").mkdir()
        (base_dir / "v1").mkdir()
        (base_dir / "some_file.txt").write_text("content")

        registry = SchemaRegistry(str(base_dir))
        with patch.object(registry, "_load_version_dir") as mock_load:
            registry.load_all_versions()
            mock_load.assert_called_once()
            assert mock_load.call_args[0][0] == "v1"

    def test_load_version_dir_filtering(self, tmp_path):
        version_dir = tmp_path / "v1"
        version_dir.mkdir()
        (version_dir / "schema1.json").write_text("{}")
        (version_dir / "not_a_schema.txt").write_text("content")

        registry = SchemaRegistry(str(tmp_path))
        with patch.object(registry, "_load_schema") as mock_load:
            registry._load_version_dir("v1", version_dir)
            mock_load.assert_called_once()
            assert mock_load.call_args[0][1] == "schema1"

    def test_load_version_dir_non_existent(self, tmp_path):
        version_dir = tmp_path / "non_existent"
        registry = SchemaRegistry(str(tmp_path))
        registry._load_version_dir("v1", version_dir)
        assert "v1" not in registry.versions

    def test_load_schema_success(self, tmp_path):
        schema_path = tmp_path / "test.json"
        schema_content = {"title": "Test Schema", "description": "A test schema"}
        schema_path.write_text(json.dumps(schema_content))

        registry = SchemaRegistry(str(tmp_path))
        registry.versions["v1"] = {}
        registry._load_schema("v1", "test", schema_path)

        assert registry.versions["v1"]["test"] == schema_content
        uri = "https://dify.ai/schemas/v1/test.json"
        assert registry.metadata[uri]["title"] == "Test Schema"
        assert registry.metadata[uri]["version"] == "v1"

    def test_load_schema_invalid_json(self, tmp_path, caplog):
        schema_path = tmp_path / "invalid.json"
        schema_path.write_text("invalid json")

        registry = SchemaRegistry(str(tmp_path))
        registry.versions["v1"] = {}
        registry._load_schema("v1", "invalid", schema_path)

        assert "Failed to load schema v1/invalid" in caplog.text

    def test_load_schema_os_error(self, tmp_path, caplog):
        schema_path = tmp_path / "error.json"
        schema_path.write_text("{}")

        registry = SchemaRegistry(str(tmp_path))
        registry.versions["v1"] = {}

        with patch("builtins.open", side_effect=OSError("Read error")):
            registry._load_schema("v1", "error", schema_path)

        assert "Failed to load schema v1/error" in caplog.text

    def test_get_schema(self):
        registry = SchemaRegistry("/tmp")
        registry.versions = {"v1": {"test": {"type": "object"}}}

        # Valid URI
        assert registry.get_schema("https://dify.ai/schemas/v1/test.json") == {"type": "object"}

        # Invalid URI
        assert registry.get_schema("invalid-uri") is None

        # Missing version
        assert registry.get_schema("https://dify.ai/schemas/v2/test.json") is None

    def test_list_versions(self):
        registry = SchemaRegistry("/tmp")
        registry.versions = {"v2": {}, "v1": {}}
        assert registry.list_versions() == ["v1", "v2"]

    def test_list_schemas(self):
        registry = SchemaRegistry("/tmp")
        registry.versions = {"v1": {"b": {}, "a": {}}}

        assert registry.list_schemas("v1") == ["a", "b"]
        assert registry.list_schemas("v2") == []

    def test_get_all_schemas_for_version(self):
        registry = SchemaRegistry("/tmp")
        registry.versions = {"v1": {"test": {"title": "Test Label"}}}

        results = registry.get_all_schemas_for_version("v1")
        assert len(results) == 1
        assert results[0]["name"] == "test"
        assert results[0]["label"] == "Test Label"
        assert results[0]["schema"] == {"title": "Test Label"}

        # Default label if title missing
        registry.versions["v1"]["no_title"] = {}
        results = registry.get_all_schemas_for_version("v1")
        item = next(r for r in results if r["name"] == "no_title")
        assert item["label"] == "no_title"

        # Empty if version missing
        assert registry.get_all_schemas_for_version("v2") == []
@@ -0,0 +1,80 @@
from unittest.mock import MagicMock, patch

from core.schemas.registry import SchemaRegistry
from core.schemas.schema_manager import SchemaManager


def test_init_with_provided_registry():
    mock_registry = MagicMock(spec=SchemaRegistry)
    manager = SchemaManager(registry=mock_registry)
    assert manager.registry == mock_registry


@patch("core.schemas.schema_manager.SchemaRegistry.default_registry")
def test_init_with_default_registry(mock_default_registry):
    mock_registry = MagicMock(spec=SchemaRegistry)
    mock_default_registry.return_value = mock_registry

    manager = SchemaManager()

    mock_default_registry.assert_called_once()
    assert manager.registry == mock_registry


def test_get_all_schema_definitions():
    mock_registry = MagicMock(spec=SchemaRegistry)
    expected_definitions = [{"name": "schema1", "schema": {}}, {"name": "schema2", "schema": {}}]
    mock_registry.get_all_schemas_for_version.return_value = expected_definitions

    manager = SchemaManager(registry=mock_registry)
    result = manager.get_all_schema_definitions(version="v2")

    mock_registry.get_all_schemas_for_version.assert_called_once_with("v2")
    assert result == expected_definitions


def test_get_schema_by_name_success():
    mock_registry = MagicMock(spec=SchemaRegistry)
    mock_schema = {"type": "object"}
    mock_registry.get_schema.return_value = mock_schema

    manager = SchemaManager(registry=mock_registry)
    result = manager.get_schema_by_name("my_schema", version="v1")

    expected_uri = "https://dify.ai/schemas/v1/my_schema.json"
    mock_registry.get_schema.assert_called_once_with(expected_uri)
    assert result == {"name": "my_schema", "schema": mock_schema}


def test_get_schema_by_name_not_found():
    mock_registry = MagicMock(spec=SchemaRegistry)
    mock_registry.get_schema.return_value = None

    manager = SchemaManager(registry=mock_registry)
    result = manager.get_schema_by_name("non_existent", version="v1")

    assert result is None


def test_list_available_schemas():
    mock_registry = MagicMock(spec=SchemaRegistry)
    expected_schemas = ["schema1", "schema2"]
    mock_registry.list_schemas.return_value = expected_schemas

    manager = SchemaManager(registry=mock_registry)
    result = manager.list_available_schemas(version="v1")

    mock_registry.list_schemas.assert_called_once_with("v1")
    assert result == expected_schemas


def test_list_available_versions():
    mock_registry = MagicMock(spec=SchemaRegistry)
    expected_versions = ["v1", "v2"]
    mock_registry.list_versions.return_value = expected_versions

    manager = SchemaManager(registry=mock_registry)
    result = manager.list_available_versions()

    mock_registry.list_versions.assert_called_once()
    assert result == expected_versions
@@ -95,13 +95,11 @@ class TestGitHubOAuth(BaseOAuthTest):
                 ],
                 "primary@example.com",
             ),
-            # User with no emails - fallback to noreply
-            ({"id": 12345, "login": "testuser", "name": "Test User"}, [], "12345+testuser@users.noreply.github.com"),
-            # User with only secondary email - fallback to noreply
+            # User with private email (null email and name from API)
             (
-                {"id": 12345, "login": "testuser", "name": "Test User"},
-                [{"email": "secondary@example.com", "primary": False}],
-                "12345+testuser@users.noreply.github.com",
+                {"id": 12345, "login": "testuser", "name": None, "email": None},
+                [{"email": "primary@example.com", "primary": True}],
+                "primary@example.com",
             ),
         ],
     )
@@ -118,9 +116,54 @@ class TestGitHubOAuth(BaseOAuthTest):
         user_info = oauth.get_user_info("test_token")

         assert user_info.id == str(user_data["id"])
-        assert user_info.name == user_data["name"]
+        assert user_info.name == (user_data["name"] or "")
         assert user_info.email == expected_email

+    @pytest.mark.parametrize(
+        ("user_data", "email_data"),
+        [
+            # User with no emails
+            ({"id": 12345, "login": "testuser", "name": "Test User"}, []),
+            # User with only secondary email
+            (
+                {"id": 12345, "login": "testuser", "name": "Test User"},
+                [{"email": "secondary@example.com", "primary": False}],
+            ),
+            # User with private email and no primary in emails endpoint
+            (
+                {"id": 12345, "login": "testuser", "name": None, "email": None},
+                [],
+            ),
+        ],
+    )
+    @patch("httpx.get", autospec=True)
+    def test_should_raise_error_when_no_primary_email(self, mock_get, oauth, user_data, email_data):
+        user_response = MagicMock()
+        user_response.json.return_value = user_data
+
+        email_response = MagicMock()
+        email_response.json.return_value = email_data
+
+        mock_get.side_effect = [user_response, email_response]
+
+        with pytest.raises(ValueError, match="Keep my email addresses private"):
+            oauth.get_user_info("test_token")
+
+    @patch("httpx.get", autospec=True)
+    def test_should_raise_error_when_email_endpoint_fails(self, mock_get, oauth):
+        user_response = MagicMock()
+        user_response.json.return_value = {"id": 12345, "login": "testuser", "name": "Test User"}
+
+        email_response = MagicMock()
+        email_response.raise_for_status.side_effect = httpx.HTTPStatusError(
+            "Forbidden", request=MagicMock(), response=MagicMock()
+        )
+
+        mock_get.side_effect = [user_response, email_response]
+
+        with pytest.raises(ValueError, match="Keep my email addresses private"):
+            oauth.get_user_info("test_token")
+
     @patch("httpx.get", autospec=True)
     def test_should_handle_network_errors(self, mock_get, oauth):
         mock_get.side_effect = httpx.RequestError("Network error")
@@ -16,6 +16,7 @@ from uuid import uuid4

 import pytest

+from models.enums import ConversationFromSource
 from models.model import (
     App,
     AppAnnotationHitHistory,
@@ -324,7 +325,7 @@ class TestConversationModel:
             mode=AppMode.CHAT,
             name="Test Conversation",
             status="normal",
-            from_source="api",
+            from_source=ConversationFromSource.API,
             from_end_user_id=from_end_user_id,
         )

@@ -345,7 +346,7 @@ class TestConversationModel:
             mode=AppMode.CHAT,
             name="Test Conversation",
             status="normal",
-            from_source="api",
+            from_source=ConversationFromSource.API,
             from_end_user_id=str(uuid4()),
         )
         conversation._inputs = inputs
@@ -364,7 +365,7 @@ class TestConversationModel:
             mode=AppMode.CHAT,
             name="Test Conversation",
             status="normal",
-            from_source="api",
+            from_source=ConversationFromSource.API,
             from_end_user_id=str(uuid4()),
         )
         inputs = {"query": "Hello", "context": "test"}
@@ -383,7 +384,7 @@ class TestConversationModel:
             mode=AppMode.CHAT,
             name="Test Conversation",
             status="normal",
-            from_source="api",
+            from_source=ConversationFromSource.API,
             from_end_user_id=str(uuid4()),
             summary="Test summary",
         )
@@ -402,7 +403,7 @@ class TestConversationModel:
             mode=AppMode.CHAT,
             name="Test Conversation",
             status="normal",
-            from_source="api",
+            from_source=ConversationFromSource.API,
             from_end_user_id=str(uuid4()),
             summary=None,
         )
@@ -425,7 +426,7 @@ class TestConversationModel:
             mode=AppMode.CHAT,
             name="Test Conversation",
             status="normal",
-            from_source="api",
+            from_source=ConversationFromSource.API,
             from_end_user_id=str(uuid4()),
             override_model_configs='{"model": "gpt-4"}',
         )
@@ -446,7 +447,7 @@ class TestConversationModel:
             mode=AppMode.CHAT,
             name="Test Conversation",
             status="normal",
-            from_source="api",
+            from_source=ConversationFromSource.API,
             from_end_user_id=from_end_user_id,
             dialogue_count=5,
         )
@@ -487,7 +488,7 @@ class TestMessageModel:
             message_unit_price=Decimal("0.0001"),
             answer_unit_price=Decimal("0.0002"),
             currency="USD",
-            from_source="api",
+            from_source=ConversationFromSource.API,
         )

         # Assert
@@ -511,7 +512,7 @@ class TestMessageModel:
             message_unit_price=Decimal("0.0001"),
             answer_unit_price=Decimal("0.0002"),
             currency="USD",
-            from_source="api",
+            from_source=ConversationFromSource.API,
         )
         message._inputs = inputs

@@ -533,7 +534,7 @@ class TestMessageModel:
             message_unit_price=Decimal("0.0001"),
             answer_unit_price=Decimal("0.0002"),
             currency="USD",
-            from_source="api",
+            from_source=ConversationFromSource.API,
         )
         inputs = {"query": "Hello", "context": "test"}

@@ -555,7 +556,7 @@ class TestMessageModel:
             message_unit_price=Decimal("0.0001"),
             answer_unit_price=Decimal("0.0002"),
             currency="USD",
-            from_source="api",
+            from_source=ConversationFromSource.API,
             override_model_configs='{"model": "gpt-4"}',
         )

@@ -578,7 +579,7 @@ class TestMessageModel:
             message_unit_price=Decimal("0.0001"),
             answer_unit_price=Decimal("0.0002"),
             currency="USD",
-            from_source="api",
+            from_source=ConversationFromSource.API,
             message_metadata=json.dumps(metadata),
         )

@@ -600,7 +601,7 @@ class TestMessageModel:
             message_unit_price=Decimal("0.0001"),
             answer_unit_price=Decimal("0.0002"),
             currency="USD",
-            from_source="api",
+            from_source=ConversationFromSource.API,
             message_metadata=None,
         )

@@ -627,7 +628,7 @@ class TestMessageModel:
             answer_unit_price=Decimal("0.0002"),
             total_price=Decimal("0.0003"),
             currency="USD",
-            from_source="api",
+            from_source=ConversationFromSource.API,
             status="normal",
         )
         message.id = str(uuid4())
@@ -988,7 +989,7 @@ class TestModelIntegration:
             mode=AppMode.CHAT,
             name="Test Conversation",
             status="normal",
-            from_source="api",
+            from_source=ConversationFromSource.API,
             from_end_user_id=str(uuid4()),
         )
         conversation.id = conversation_id
@@ -1003,7 +1004,7 @@ class TestModelIntegration:
             message_unit_price=Decimal("0.0001"),
             answer_unit_price=Decimal("0.0002"),
             currency="USD",
-            from_source="api",
+            from_source=ConversationFromSource.API,
         )
         message.id = message_id

@@ -1064,7 +1065,7 @@ class TestModelIntegration:
             message_unit_price=Decimal("0.0001"),
             answer_unit_price=Decimal("0.0002"),
             currency="USD",
-            from_source="api",
+            from_source=ConversationFromSource.API,
         )
         message.id = message_id

@@ -1158,7 +1159,7 @@ class TestConversationStatusCount:
             mode=AppMode.CHAT,
             name="Test Conversation",
             status="normal",
-            from_source="api",
+            from_source=ConversationFromSource.API,
         )
         conversation.id = str(uuid4())

@@ -1183,7 +1184,7 @@ class TestConversationStatusCount:
             mode=AppMode.CHAT,
             name="Test Conversation",
             status="normal",
-            from_source="api",
+            from_source=ConversationFromSource.API,
         )
         conversation.id = conversation_id

@@ -1215,7 +1216,7 @@ class TestConversationStatusCount:
             mode=AppMode.CHAT,
             name="Test Conversation",
             status="normal",
-            from_source="api",
+            from_source=ConversationFromSource.API,
         )
         conversation.id = conversation_id

@@ -1307,7 +1308,7 @@ class TestConversationStatusCount:
             mode=AppMode.CHAT,
             name="Test Conversation",
             status="normal",
-            from_source="api",
+            from_source=ConversationFromSource.API,
         )
         conversation.id = conversation_id

@@ -1361,7 +1362,7 @@ class TestConversationStatusCount:
             mode=AppMode.CHAT,
             name="Test Conversation",
             status="normal",
-            from_source="api",
+            from_source=ConversationFromSource.API,
         )
         conversation.id = conversation_id

@@ -1418,7 +1419,7 @@ class TestConversationStatusCount:
             mode=AppMode.CHAT,
             name="Test Conversation",
             status="normal",
-            from_source="api",
+            from_source=ConversationFromSource.API,
         )
         conversation.id = conversation_id

@@ -12,7 +12,7 @@ This test suite covers:
 import json
 from uuid import uuid4

-from core.tools.entities.tool_entities import ApiProviderSchemaType
+from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolProviderType
 from models.tools import (
     ApiToolProvider,
     BuiltinToolProvider,
@@ -631,7 +631,7 @@ class TestToolLabelBinding:
         """Test creating a tool label binding."""
         # Arrange
         tool_id = "google.search"
-        tool_type = "builtin"
+        tool_type = ToolProviderType.BUILT_IN
         label_name = "search"

         # Act
@@ -655,7 +655,7 @@ class TestToolLabelBinding:
         # Act
         label_binding = ToolLabelBinding(
             tool_id=tool_id,
-            tool_type="builtin",
+            tool_type=ToolProviderType.BUILT_IN,
             label_name=label_name,
         )

@@ -667,7 +667,7 @@ class TestToolLabelBinding:
         """Test multiple labels can be bound to the same tool."""
         # Arrange
         tool_id = "google.search"
-        tool_type = "builtin"
+        tool_type = ToolProviderType.BUILT_IN

         # Act
         binding1 = ToolLabelBinding(
@@ -688,7 +688,7 @@ class TestToolLabelBinding:
     def test_tool_label_binding_different_tool_types(self):
         """Test label bindings for different tool types."""
         # Arrange
-        tool_types = ["builtin", "api", "workflow"]
+        tool_types = [ToolProviderType.BUILT_IN, ToolProviderType.API, ToolProviderType.WORKFLOW]

         # Act & Assert
         for tool_type in tool_types:
@@ -951,12 +951,12 @@ class TestToolProviderRelationships:
         # Act
         binding1 = ToolLabelBinding(
             tool_id=tool_id,
-            tool_type="builtin",
+            tool_type=ToolProviderType.BUILT_IN,
             label_name="search",
         )
         binding2 = ToolLabelBinding(
             tool_id=tool_id,
-            tool_type="builtin",
+            tool_type=ToolProviderType.BUILT_IN,
             label_name="web",
         )

@@ -4,12 +4,18 @@ from unittest import mock
 from uuid import uuid4

+from constants import HIDDEN_VALUE
 from core.helper import encrypter
 from dify_graph.file.enums import FileTransferMethod, FileType
 from dify_graph.file.models import File
 from dify_graph.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
 from dify_graph.variables.segments import IntegerSegment, Segment
 from factories.variable_factory import build_segment
-from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
+from models.workflow import (
+    Workflow,
+    WorkflowDraftVariable,
+    WorkflowNodeExecutionModel,
+    is_system_variable_editable,
+)


 def test_environment_variables():
@@ -144,6 +150,36 @@ def test_to_dict():
     assert workflow_dict["environment_variables"][1]["value"] == "text"


+def test_normalize_environment_variable_mappings_converts_full_mask_to_hidden_value():
+    normalized = Workflow.normalize_environment_variable_mappings(
+        [
+            {
+                "id": str(uuid4()),
+                "name": "secret",
+                "value": encrypter.full_mask_token(),
+                "value_type": "secret",
+            }
+        ]
+    )
+
+    assert normalized[0]["value"] == HIDDEN_VALUE
+
+
+def test_normalize_environment_variable_mappings_keeps_hidden_value():
+    normalized = Workflow.normalize_environment_variable_mappings(
+        [
+            {
+                "id": str(uuid4()),
+                "name": "secret",
+                "value": HIDDEN_VALUE,
+                "value_type": "secret",
+            }
+        ]
+    )
+
+    assert normalized[0]["value"] == HIDDEN_VALUE
+
+
 class TestWorkflowNodeExecution:
     def test_execution_metadata_dict(self):
         node_exec = WorkflowNodeExecutionModel()
@@ -1,135 +0,0 @@
"""Unit tests for non-SQL helper logic in workflow run repository."""

import secrets
from datetime import UTC, datetime
from unittest.mock import Mock, patch

import pytest

from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType
from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction
from dify_graph.nodes.human_input.enums import FormInputType, HumanInputFormStatus
from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType
from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowPauseReason
from repositories.sqlalchemy_api_workflow_run_repository import (
    _build_human_input_required_reason,
    _PrivateWorkflowPauseEntity,
)


@pytest.fixture
def sample_workflow_pause() -> Mock:
    """Create a sample WorkflowPause model."""
    pause = Mock(spec=WorkflowPauseModel)
    pause.id = "pause-123"
    pause.workflow_id = "workflow-123"
    pause.workflow_run_id = "workflow-run-123"
    pause.state_object_key = "workflow-state-123.json"
    pause.resumed_at = None
    pause.created_at = datetime.now(UTC)
    return pause


class TestPrivateWorkflowPauseEntity:
    """Test _PrivateWorkflowPauseEntity class."""

    def test_properties(self, sample_workflow_pause: Mock) -> None:
        """Test entity properties."""
        # Arrange
        entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])

        # Assert
        assert entity.id == sample_workflow_pause.id
        assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id
        assert entity.resumed_at == sample_workflow_pause.resumed_at

    def test_get_state(self, sample_workflow_pause: Mock) -> None:
        """Test getting state from storage."""
        # Arrange
        entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
        expected_state = b'{"test": "state"}'

        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
            mock_storage.load.return_value = expected_state

            # Act
            result = entity.get_state()

            # Assert
            assert result == expected_state
            mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key)

    def test_get_state_caching(self, sample_workflow_pause: Mock) -> None:
        """Test state caching in get_state method."""
        # Arrange
        entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
        expected_state = b'{"test": "state"}'

        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
            mock_storage.load.return_value = expected_state

            # Act
            result1 = entity.get_state()
            result2 = entity.get_state()

            # Assert
            assert result1 == expected_state
            assert result2 == expected_state
            mock_storage.load.assert_called_once()


class TestBuildHumanInputRequiredReason:
    """Test helper that builds HumanInputRequired pause reasons."""

    def test_prefers_backstage_token_when_available(self) -> None:
        """Use backstage token when multiple recipient types may exist."""
        # Arrange
        expiration_time = datetime.now(UTC)
        form_definition = FormDefinition(
            form_content="content",
            inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
            user_actions=[UserAction(id="approve", title="Approve")],
            rendered_content="rendered",
            expiration_time=expiration_time,
            default_values={"name": "Alice"},
            node_title="Ask Name",
            display_in_ui=True,
        )
        form_model = HumanInputForm(
            id="form-1",
            tenant_id="tenant-1",
            app_id="app-1",
            workflow_run_id="run-1",
            node_id="node-1",
            form_definition=form_definition.model_dump_json(),
            rendered_content="rendered",
            status=HumanInputFormStatus.WAITING,
            expiration_time=expiration_time,
        )
        reason_model = WorkflowPauseReason(
            pause_id="pause-1",
            type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
            form_id="form-1",
            node_id="node-1",
            message="",
        )
        access_token = secrets.token_urlsafe(8)
        backstage_recipient = HumanInputFormRecipient(
            form_id="form-1",
            delivery_id="delivery-1",
            recipient_type=RecipientType.BACKSTAGE,
            recipient_payload=BackstageRecipientPayload().model_dump_json(),
            access_token=access_token,
        )

        # Act
        reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient])

        # Assert
        assert isinstance(reason, HumanInputRequired)
        assert reason.form_token == access_token
        assert reason.node_title == "Ask Name"
        assert reason.form_content == "content"
        assert reason.inputs[0].output_variable_name == "name"
        assert reason.actions[0].id == "approve"
@@ -1,251 +0,0 @@
"""Unit tests for workflow run repository with status filter."""

import uuid
from unittest.mock import MagicMock

import pytest
from sqlalchemy.orm import sessionmaker

from models import WorkflowRun, WorkflowRunTriggeredFrom
from repositories.sqlalchemy_api_workflow_run_repository import DifyAPISQLAlchemyWorkflowRunRepository


class TestDifyAPISQLAlchemyWorkflowRunRepository:
    """Test workflow run repository with status filtering."""

    @pytest.fixture
    def mock_session_maker(self):
        """Create a mock session maker."""
        return MagicMock(spec=sessionmaker)

    @pytest.fixture
    def repository(self, mock_session_maker):
        """Create repository instance with mock session."""
        return DifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker)

    def test_get_paginated_workflow_runs_without_status(self, repository, mock_session_maker):
        """Test getting paginated workflow runs without status filter."""
        # Arrange
        tenant_id = str(uuid.uuid4())
        app_id = str(uuid.uuid4())
        mock_session = MagicMock()
        mock_session_maker.return_value.__enter__.return_value = mock_session

        mock_runs = [MagicMock(spec=WorkflowRun) for _ in range(3)]
        mock_session.scalars.return_value.all.return_value = mock_runs

        # Act
        result = repository.get_paginated_workflow_runs(
            tenant_id=tenant_id,
            app_id=app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            limit=20,
            last_id=None,
            status=None,
        )

        # Assert
        assert len(result.data) == 3
        assert result.limit == 20
        assert result.has_more is False

    def test_get_paginated_workflow_runs_with_status_filter(self, repository, mock_session_maker):
        """Test getting paginated workflow runs with status filter."""
        # Arrange
        tenant_id = str(uuid.uuid4())
        app_id = str(uuid.uuid4())
        mock_session = MagicMock()
        mock_session_maker.return_value.__enter__.return_value = mock_session

        mock_runs = [MagicMock(spec=WorkflowRun, status="succeeded") for _ in range(2)]
        mock_session.scalars.return_value.all.return_value = mock_runs

        # Act
        result = repository.get_paginated_workflow_runs(
            tenant_id=tenant_id,
            app_id=app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            limit=20,
            last_id=None,
            status="succeeded",
        )

        # Assert
        assert len(result.data) == 2
        assert all(run.status == "succeeded" for run in result.data)

    def test_get_workflow_runs_count_without_status(self, repository, mock_session_maker):
        """Test getting workflow runs count without status filter."""
        # Arrange
        tenant_id = str(uuid.uuid4())
        app_id = str(uuid.uuid4())
        mock_session = MagicMock()
        mock_session_maker.return_value.__enter__.return_value = mock_session

        # Mock the GROUP BY query results
        mock_results = [
            ("succeeded", 5),
            ("failed", 2),
            ("running", 1),
        ]
        mock_session.execute.return_value.all.return_value = mock_results

        # Act
        result = repository.get_workflow_runs_count(
            tenant_id=tenant_id,
            app_id=app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            status=None,
        )

        # Assert
        assert result["total"] == 8
        assert result["succeeded"] == 5
        assert result["failed"] == 2
        assert result["running"] == 1
        assert result["stopped"] == 0
        assert result["partial-succeeded"] == 0

    def test_get_workflow_runs_count_with_status_filter(self, repository, mock_session_maker):
        """Test getting workflow runs count with status filter."""
        # Arrange
        tenant_id = str(uuid.uuid4())
        app_id = str(uuid.uuid4())
        mock_session = MagicMock()
        mock_session_maker.return_value.__enter__.return_value = mock_session

        # Mock the count query for succeeded status
        mock_session.scalar.return_value = 5

        # Act
        result = repository.get_workflow_runs_count(
            tenant_id=tenant_id,
            app_id=app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            status="succeeded",
        )

        # Assert
        assert result["total"] == 5
        assert result["succeeded"] == 5
        assert result["running"] == 0
        assert result["failed"] == 0
        assert result["stopped"] == 0
        assert result["partial-succeeded"] == 0

    def test_get_workflow_runs_count_with_invalid_status(self, repository, mock_session_maker):
        """Test that invalid status is still counted in total but not in any specific status."""
        # Arrange
        tenant_id = str(uuid.uuid4())
        app_id = str(uuid.uuid4())
        mock_session = MagicMock()
        mock_session_maker.return_value.__enter__.return_value = mock_session

        # Mock count query returning 0 for invalid status
        mock_session.scalar.return_value = 0

        # Act
        result = repository.get_workflow_runs_count(
            tenant_id=tenant_id,
            app_id=app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            status="invalid_status",
        )

        # Assert
        assert result["total"] == 0
        assert all(result[status] == 0 for status in ["running", "succeeded", "failed", "stopped", "partial-succeeded"])

    def test_get_workflow_runs_count_with_time_range(self, repository, mock_session_maker):
        """Test getting workflow runs count with time range filter verifies SQL query construction."""
        # Arrange
        tenant_id = str(uuid.uuid4())
        app_id = str(uuid.uuid4())
        mock_session = MagicMock()
        mock_session_maker.return_value.__enter__.return_value = mock_session

        # Mock the GROUP BY query results
        mock_results = [
            ("succeeded", 3),
            ("running", 2),
        ]
        mock_session.execute.return_value.all.return_value = mock_results

        # Act
        result = repository.get_workflow_runs_count(
            tenant_id=tenant_id,
            app_id=app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            status=None,
            time_range="1d",
        )

        # Assert results
        assert result["total"] == 5
        assert result["succeeded"] == 3
        assert result["running"] == 2
        assert result["failed"] == 0

        # Verify that execute was called (which means GROUP BY query was used)
        assert mock_session.execute.called, "execute should have been called for GROUP BY query"

        # Verify SQL query includes time filter by checking the statement
        call_args = mock_session.execute.call_args
        assert call_args is not None, "execute should have been called with a statement"

        # The first argument should be the SQL statement
        stmt = call_args[0][0]
        # Convert to string to inspect the query
        query_str = str(stmt.compile(compile_kwargs={"literal_binds": True}))

        # Verify the query includes created_at filter
        # The query should have a WHERE clause with created_at comparison
        assert "created_at" in query_str.lower() or "workflow_runs.created_at" in query_str.lower(), (
            "Query should include created_at filter for time range"
        )

    def test_get_workflow_runs_count_with_status_and_time_range(self, repository, mock_session_maker):
        """Test getting workflow runs count with both status and time range filters verifies SQL query."""
        # Arrange
        tenant_id = str(uuid.uuid4())
        app_id = str(uuid.uuid4())
        mock_session = MagicMock()
        mock_session_maker.return_value.__enter__.return_value = mock_session

        # Mock the count query for running status within time range
        mock_session.scalar.return_value = 2

        # Act
        result = repository.get_workflow_runs_count(
            tenant_id=tenant_id,
            app_id=app_id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            status="running",
            time_range="1d",
        )

        # Assert results
        assert result["total"] == 2
        assert result["running"] == 2
        assert result["succeeded"] == 0
        assert result["failed"] == 0

        # Verify that scalar was called (which means COUNT query was used)
        assert mock_session.scalar.called, "scalar should have been called for count query"

        # Verify SQL query includes both status and time filter
        call_args = mock_session.scalar.call_args
        assert call_args is not None, "scalar should have been called with a statement"

        # The first argument should be the SQL statement
        stmt = call_args[0][0]
        # Convert to string to inspect the query
        query_str = str(stmt.compile(compile_kwargs={"literal_binds": True}))

        # Verify the query includes both filters
        assert "created_at" in query_str.lower() or "workflow_runs.created_at" in query_str.lower(), (
            "Query should include created_at filter for time range"
        )
        assert "status" in query_str.lower() or "workflow_runs.status" in query_str.lower(), (
            "Query should include status filter"
        )