Merge branch 'main' of https://github.com/langgenius/dify into chore/core.rag.cleaner-data_post_processor-datasource_unit-tests

This commit is contained in:
Rajat Agarwal 2026-03-24 16:14:19 +05:30
commit 7d23da5432
473 changed files with 32259 additions and 14820 deletions

13
.gemini/config.yaml Normal file
View File

@ -0,0 +1,13 @@
have_fun: false
memory_config:
disabled: false
code_review:
disable: true
comment_severity_threshold: MEDIUM
max_review_comments: -1
pull_request_opened:
help: false
summary: false
code_review: false
include_drafts: false
ignore_patterns: []

View File

@ -4,10 +4,9 @@ runs:
using: composite
steps:
- name: Setup Vite+
uses: voidzero-dev/setup-vp@4a524139920f87f9f7080d3b8545acac019e1852 # v1.0.0
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
with:
node-version-file: web/.nvmrc
working-directory: web
node-version-file: .nvmrc
cache: true
cache-dependency-path: web/pnpm-lock.yaml
run-install: |
cwd: ./web
run-install: true

View File

@ -84,20 +84,20 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
uses: ./.github/actions/setup-web
- name: Restore ESLint cache
if: steps.changed-files.outputs.any_changed == 'true'
id: eslint-cache-restore
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
with:
path: web/.eslintcache
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
restore-keys: |
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: |
vp run lint:ci
# pnpm run lint:report
# continue-on-error: true
# - name: Annotate Code
# if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request'
# uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae
# with:
# eslint-report: web/eslint_report.json
# github-token: ${{ secrets.GITHUB_TOKEN }}
run: vp run lint:ci
- name: Web tsslint
if: steps.changed-files.outputs.any_changed == 'true'
@ -114,6 +114,13 @@ jobs:
working-directory: ./web
run: vp run knip
- name: Save ESLint cache
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
with:
path: web/.eslintcache
key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
superlinter:
name: SuperLinter
runs-on: ubuntu-latest

View File

@ -120,7 +120,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.detect_changes.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@6062f3709600659be5e47fcddf2cf76993c235c2 # v1.0.76
uses: anthropics/claude-code-action@ff9acae5886d41a99ed4ec14b7dc147d55834722 # v1.0.77
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -353,6 +353,9 @@ BAIDU_VECTOR_DB_SHARD=1
BAIDU_VECTOR_DB_REPLICAS=3
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT=500
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO=0.05
BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS=300
# Upstash configuration
UPSTASH_VECTOR_URL=your-server-url

View File

@ -10,6 +10,7 @@ from configs import dify_config
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
@ -269,7 +270,7 @@ def migrate_knowledge_vector_database():
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == "hierarchical_model":
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []

View File

@ -51,3 +51,18 @@ class BaiduVectorDBConfig(BaseSettings):
description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)",
default="COARSE_MODE",
)
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT: int = Field(
description="Auto build row count increment threshold (default is 500)",
default=500,
)
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO: float = Field(
description="Auto build row count increment ratio threshold (default is 0.05)",
default=0.05,
)
BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS: int = Field(
description="Timeout in seconds for rebuilding the index in Baidu Vector Database (default is 3600 seconds)",
default=300,
)

View File

@ -9,6 +9,7 @@ from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.enums import ApiTokenType
from models.model import ApiToken, App
from services.api_token_service import ApiTokenCache
@ -47,7 +48,7 @@ def _get_resource(resource_id, tenant_id, resource_model):
class BaseApiKeyListResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required]
resource_type: str | None = None
resource_type: ApiTokenType | None = None
resource_model: type | None = None
resource_id_field: str | None = None
token_prefix: str | None = None
@ -91,6 +92,7 @@ class BaseApiKeyListResource(Resource):
)
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
assert self.resource_type is not None, "resource_type must be set"
api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_tenant_id
@ -104,7 +106,7 @@ class BaseApiKeyListResource(Resource):
class BaseApiKeyResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required]
resource_type: str | None = None
resource_type: ApiTokenType | None = None
resource_model: type | None = None
resource_id_field: str | None = None
@ -159,7 +161,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for an app"""
return super().post(resource_id)
resource_type = "app"
resource_type = ApiTokenType.APP
resource_model = App
resource_id_field = "app_id"
token_prefix = "app-"
@ -175,7 +177,7 @@ class AppApiKeyResource(BaseApiKeyResource):
"""Delete an API key for an app"""
return super().delete(resource_id, api_key_id)
resource_type = "app"
resource_type = ApiTokenType.APP
resource_model = App
resource_id_field = "app_id"
@ -199,7 +201,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for a dataset"""
return super().post(resource_id)
resource_type = "dataset"
resource_type = ApiTokenType.DATASET
resource_model = Dataset
resource_id_field = "dataset_id"
token_prefix = "ds-"
@ -215,6 +217,6 @@ class DatasetApiKeyResource(BaseApiKeyResource):
"""Delete an API key for a dataset"""
return super().delete(resource_id, api_key_id)
resource_type = "dataset"
resource_type = ApiTokenType.DATASET
resource_model = Dataset
resource_id_field = "dataset_id"

View File

@ -95,7 +95,7 @@ class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
icon_type: str | None = Field(default=None, description="Icon type")
icon_type: IconType | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@ -103,7 +103,7 @@ class CreateAppPayload(BaseModel):
class UpdateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
icon_type: str | None = Field(default=None, description="Icon type")
icon_type: IconType | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
@ -113,7 +113,7 @@ class UpdateAppPayload(BaseModel):
class CopyAppPayload(BaseModel):
name: str | None = Field(default=None, description="Name for the copied app")
description: str | None = Field(default=None, description="Description for the copied app", max_length=400)
icon_type: str | None = Field(default=None, description="Icon type")
icon_type: IconType | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@ -594,7 +594,7 @@ class AppApi(Resource):
args_dict: AppService.ArgsDict = {
"name": args.name,
"description": args.description or "",
"icon_type": args.icon_type or "",
"icon_type": args.icon_type,
"icon": args.icon or "",
"icon_background": args.icon_background or "",
"use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,

View File

@ -458,9 +458,7 @@ class ChatConversationApi(Resource):
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
subquery = (
db.session.query(
Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")
)
sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id"))
.outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
.subquery()
)
@ -595,10 +593,8 @@ class ChatConversationDetailApi(Resource):
def _get_conversation(app_model, conversation_id):
current_user, _ = current_account_with_tenant()
conversation = (
db.session.query(Conversation)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
conversation = db.session.scalar(
sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1)
)
if not conversation:

View File

@ -168,7 +168,7 @@ class InstructionGenerateApi(Resource):
try:
# Generate from nothing for a workflow node
if (args.current in (code_template, "")) and args.node_id != "":
app = db.session.query(App).where(App.id == args.flow_id).first()
app = db.session.get(App, args.flow_id)
if not app:
return {"error": f"app {args.flow_id} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app)

View File

@ -2,6 +2,7 @@ import json
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
@ -47,7 +48,7 @@ class AppMCPServerController(Resource):
@get_app_model
@marshal_with(app_server_model)
def get(self, app_model):
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
return server
@console_ns.doc("create_app_mcp_server")
@ -98,7 +99,7 @@ class AppMCPServerController(Resource):
@edit_permission_required
def put(self, app_model):
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
server = db.session.get(AppMCPServer, payload.id)
if not server:
raise NotFound()
@ -135,11 +136,10 @@ class AppMCPServerRefreshController(Resource):
@edit_permission_required
def get(self, server_id):
_, current_tenant_id = current_account_with_tenant()
server = (
db.session.query(AppMCPServer)
.where(AppMCPServer.id == server_id)
.where(AppMCPServer.tenant_id == current_tenant_id)
.first()
server = db.session.scalar(
select(AppMCPServer)
.where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id)
.limit(1)
)
if not server:
raise NotFound()

View File

@ -4,7 +4,7 @@ from typing import Literal
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, select
from sqlalchemy import exists, func, select
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
@ -244,27 +244,25 @@ class ChatMessageListApi(Resource):
def get(self, app_model):
args = ChatMessagesQuery.model_validate(request.args.to_dict())
conversation = (
db.session.query(Conversation)
conversation = db.session.scalar(
select(Conversation)
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
.first()
.limit(1)
)
if not conversation:
raise NotFound("Conversation Not Exists.")
if args.first_id:
first_message = (
db.session.query(Message)
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
.first()
first_message = db.session.scalar(
select(Message).where(Message.conversation_id == conversation.id, Message.id == args.first_id).limit(1)
)
if not first_message:
raise NotFound("First message not found")
history_messages = (
db.session.query(Message)
history_messages = db.session.scalars(
select(Message)
.where(
Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at,
@ -272,16 +270,14 @@ class ChatMessageListApi(Resource):
)
.order_by(Message.created_at.desc())
.limit(args.limit)
.all()
)
).all()
else:
history_messages = (
db.session.query(Message)
history_messages = db.session.scalars(
select(Message)
.where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(args.limit)
.all()
)
).all()
# Initialize has_more based on whether we have a full page
if len(history_messages) == args.limit:
@ -326,7 +322,9 @@ class MessageFeedbackApi(Resource):
message_id = str(args.message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
message = db.session.scalar(
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
)
if not message:
raise NotFound("Message Not Exists.")
@ -375,7 +373,9 @@ class MessageAnnotationCountApi(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()
count = db.session.scalar(
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id)
)
return {"count": count}
@ -479,7 +479,9 @@ class MessageApi(Resource):
def get(self, app_model, message_id: str):
message_id = str(message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
message = db.session.scalar(
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
)
if not message:
raise NotFound("Message Not Exists.")

View File

@ -69,9 +69,7 @@ class ModelConfigResource(Resource):
if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
# get original app model config
original_app_model_config = (
db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first()
)
original_app_model_config = db.session.get(AppModelConfig, app_model.app_model_config_id)
if original_app_model_config is None:
raise ValueError("Original app model config not found")
agent_mode = original_app_model_config.agent_mode_dict

View File

@ -2,6 +2,7 @@ from typing import Literal
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from constants.languages import supported_language
@ -75,7 +76,7 @@ class AppSite(Resource):
def post(self, app_model):
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
current_user, _ = current_account_with_tenant()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if not site:
raise NotFound
@ -124,7 +125,7 @@ class AppSiteAccessTokenReset(Resource):
@marshal_with(app_site_model)
def post(self, app_model):
current_user, _ = current_account_with_tenant()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if not site:
raise NotFound

View File

@ -2,6 +2,8 @@ from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar, Union
from sqlalchemy import select
from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db
from libs.login import current_account_with_tenant
@ -15,16 +17,14 @@ R1 = TypeVar("R1")
def _load_app_model(app_id: str) -> App | None:
_, current_tenant_id = current_account_with_tenant()
app_model = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app_model = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
return app_model
def _load_app_model_with_trial(app_id: str) -> App | None:
app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
app_model = db.session.scalar(select(App).where(App.id == app_id, App.status == "normal").limit(1))
return app_model

View File

@ -1,4 +1,5 @@
import logging
import urllib.parse
import httpx
from flask import current_app, redirect, request
@ -112,6 +113,9 @@ class OAuthCallback(Resource):
error_text = e.response.text
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
return {"error": "OAuth process failed"}, 400
except ValueError as e:
logger.warning("OAuth error with %s", provider, exc_info=True)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={urllib.parse.quote(str(e))}")
if invite_token and RegisterService.is_valid_invite_token(invite_token):
invitation = RegisterService.get_invitation_by_token(token=invite_token)

View File

@ -54,7 +54,7 @@ from fields.document_fields import document_status_fields
from libs.login import current_account_with_tenant, login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermission, DatasetPermissionEnum
from models.enums import SegmentStatus
from models.enums import ApiTokenType, SegmentStatus
from models.provider_ids import ModelProviderID
from services.api_token_service import ApiTokenCache
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@ -777,7 +777,7 @@ class DatasetIndexingStatusApi(Resource):
class DatasetApiKeyApi(Resource):
max_keys = 10
token_prefix = "dataset-"
resource_type = "dataset"
resource_type = ApiTokenType.DATASET
@console_ns.doc("get_dataset_api_keys")
@console_ns.doc(description="Get dataset API keys")
@ -826,7 +826,7 @@ class DatasetApiKeyApi(Resource):
@console_ns.route("/datasets/api-keys/<uuid:api_key_id>")
class DatasetApiDeleteApi(Resource):
resource_type = "dataset"
resource_type = ApiTokenType.DATASET
@console_ns.doc("delete_dataset_api_key")
@console_ns.doc(description="Delete dataset API key")

View File

@ -705,7 +705,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
app_id=self._application_generate_entity.app_config.app_id,
workflow_id=self._workflow.id,
workflow_run_id=workflow_run_id,
created_from=created_from.value,
created_from=created_from,
created_by_role=self._created_by_role,
created_by=self._user_id,
)

View File

@ -19,6 +19,7 @@ class RateLimit:
_REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
_instance_dict: dict[str, "RateLimit"] = {}
max_active_requests: int
def __new__(cls, client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict:
@ -27,7 +28,13 @@ class RateLimit:
return cls._instance_dict[client_id]
def __init__(self, client_id: str, max_active_requests: int):
flush_cache = hasattr(self, "max_active_requests") and self.max_active_requests != max_active_requests
self.max_active_requests = max_active_requests
# Only flush here if this instance has already been fully initialized,
# i.e. the Redis key attributes exist. Otherwise, rely on the flush at
# the end of initialization below.
if flush_cache and hasattr(self, "active_requests_key") and hasattr(self, "max_active_requests_key"):
self.flush_cache(use_local_value=True)
# must be called after max_active_requests is set
if self.disabled():
return
@ -41,8 +48,6 @@ class RateLimit:
self.flush_cache(use_local_value=True)
def flush_cache(self, use_local_value=False):
if self.disabled():
return
self.last_recalculate_time = time.time()
# flush max active requests
if use_local_value or not redis_client.exists(self.max_active_requests_key):
@ -50,7 +55,8 @@ class RateLimit:
else:
self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8"))
redis_client.expire(self.max_active_requests_key, timedelta(days=1))
if self.disabled():
return
# flush max active requests (in-transit request list)
if not redis_client.exists(self.active_requests_key):
return

View File

@ -181,10 +181,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
arize_phoenix_config: ArizeConfig | PhoenixConfig,
):
super().__init__(arize_phoenix_config)
import logging
logging.basicConfig()
logging.getLogger().setLevel(logging.DEBUG)
self.arize_phoenix_config = arize_phoenix_config
self.tracer, self.processor = setup_tracer(arize_phoenix_config)
self.project = arize_phoenix_config.project

View File

@ -918,11 +918,11 @@ class ProviderManager:
trail_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type=ProviderQuotaType.TRIAL.value,
pool_type=ProviderQuotaType.TRIAL,
)
paid_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type=ProviderQuotaType.PAID.value,
pool_type=ProviderQuotaType.PAID,
)
else:
trail_pool = None

View File

@ -13,6 +13,7 @@ from pymochow.exception import ServerError # type: ignore
from pymochow.model.database import Database
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore
from pymochow.model.schema import (
AutoBuildRowCountIncrement,
Field,
FilteringIndex,
HNSWParams,
@ -51,6 +52,9 @@ class BaiduConfig(BaseModel):
replicas: int = 3
inverted_index_analyzer: str = "DEFAULT_ANALYZER"
inverted_index_parser_mode: str = "COARSE_MODE"
auto_build_row_count_increment: int = 500
auto_build_row_count_increment_ratio: float = 0.05
rebuild_index_timeout_in_seconds: int = 300
@model_validator(mode="before")
@classmethod
@ -107,18 +111,6 @@ class BaiduVector(BaseVector):
rows.append(row)
table.upsert(rows=rows)
# rebuild vector index after upsert finished
table.rebuild_index(self.vector_index)
timeout = 3600 # 1 hour timeout
start_time = time.time()
while True:
time.sleep(1)
index = table.describe_index(self.vector_index)
if index.state == IndexState.NORMAL:
break
if time.time() - start_time > timeout:
raise TimeoutError(f"Index rebuild timeout after {timeout} seconds")
def text_exists(self, id: str) -> bool:
res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id})
if res and res.code == 0:
@ -232,8 +224,14 @@ class BaiduVector(BaseVector):
return self._client.database(self._client_config.database)
def _table_existed(self) -> bool:
tables = self._db.list_table()
return any(table.table_name == self._collection_name for table in tables)
try:
table = self._db.table(self._collection_name)
except ServerError as e:
if e.code == ServerErrCode.TABLE_NOT_EXIST:
return False
else:
raise
return True
def _create_table(self, dimension: int):
# Try to grab distributed lock and create table
@ -287,6 +285,11 @@ class BaiduVector(BaseVector):
field=VDBField.VECTOR,
metric_type=metric_type,
params=HNSWParams(m=16, efconstruction=200),
auto_build=True,
auto_build_index_policy=AutoBuildRowCountIncrement(
row_count_increment=self._client_config.auto_build_row_count_increment,
row_count_increment_ratio=self._client_config.auto_build_row_count_increment_ratio,
),
)
)
@ -335,7 +338,7 @@ class BaiduVector(BaseVector):
)
# Wait for table created
timeout = 300 # 5 minutes timeout
timeout = self._client_config.rebuild_index_timeout_in_seconds # default 5 minutes timeout
start_time = time.time()
while True:
time.sleep(1)
@ -345,6 +348,20 @@ class BaiduVector(BaseVector):
if time.time() - start_time > timeout:
raise TimeoutError(f"Table creation timeout after {timeout} seconds")
redis_client.set(table_exist_cache_key, 1, ex=3600)
# rebuild vector index immediately after table created, make sure index is ready
table.rebuild_index(self.vector_index)
timeout = 3600 # 1 hour timeout
self._wait_for_index_ready(table, timeout)
def _wait_for_index_ready(self, table, timeout: int = 3600):
start_time = time.time()
while True:
time.sleep(1)
index = table.describe_index(self.vector_index)
if index.state == IndexState.NORMAL:
break
if time.time() - start_time > timeout:
raise TimeoutError(f"Index rebuild timeout after {timeout} seconds")
class BaiduVectorFactory(AbstractVectorFactory):
@ -369,5 +386,8 @@ class BaiduVectorFactory(AbstractVectorFactory):
replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
inverted_index_analyzer=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER,
inverted_index_parser_mode=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE,
auto_build_row_count_increment=dify_config.BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT,
auto_build_row_count_increment_ratio=dify_config.BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO,
rebuild_index_timeout_in_seconds=dify_config.BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS,
),
)

View File

@ -33,6 +33,7 @@ from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, TidbAuthBinding
from models.enums import TidbAuthBindingStatus
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
@ -452,7 +453,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
password=new_cluster["password"],
tenant_id=dataset.tenant_id,
active=True,
status="ACTIVE",
status=TidbAuthBindingStatus.ACTIVE,
)
db.session.add(new_tidb_auth_binding)
db.session.commit()

View File

@ -9,6 +9,7 @@ from configs import dify_config
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import TidbAuthBinding
from models.enums import TidbAuthBindingStatus
class TidbService:
@ -170,7 +171,7 @@ class TidbService:
userPrefix = item["userPrefix"]
if state == "ACTIVE" and len(userPrefix) > 0:
cluster_info = tidb_serverless_list_map[item["clusterId"]]
cluster_info.status = "ACTIVE"
cluster_info.status = TidbAuthBindingStatus.ACTIVE
cluster_info.account = f"{userPrefix}.root"
db.session.add(cluster_info)
db.session.commit()

View File

@ -95,15 +95,11 @@ class FirecrawlApp:
if response.status_code == 200:
crawl_status_response = response.json()
if crawl_status_response.get("status") == "completed":
total = crawl_status_response.get("total", 0)
if total == 0:
# Normalize to avoid None bypassing the zero-guard when the API returns null.
total = crawl_status_response.get("total") or 0
if total <= 0:
raise Exception("Failed to check crawl status. Error: No page found")
data = crawl_status_response.get("data", [])
url_data_list: list[FirecrawlDocumentData] = []
for item in data:
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
url_data = self._extract_common_fields(item)
url_data_list.append(url_data)
url_data_list = self._collect_all_crawl_pages(crawl_status_response, headers)
if url_data_list:
file_key = "website_files/" + job_id + ".txt"
try:
@ -120,6 +116,36 @@ class FirecrawlApp:
self._handle_error(response, "check crawl status")
raise RuntimeError("unreachable: _handle_error always raises")
def _collect_all_crawl_pages(
self, first_page: dict[str, Any], headers: dict[str, str]
) -> list[FirecrawlDocumentData]:
"""Collect all crawl result pages by following pagination links.
Raises an exception if any paginated request fails, to avoid returning
partial data that is inconsistent with the reported total.
The number of pages processed is capped at ``total`` (the
server-reported page count) to guard against infinite loops caused by
a misbehaving server that keeps returning a ``next`` URL.
"""
total: int = first_page.get("total") or 0
url_data_list: list[FirecrawlDocumentData] = []
current_page = first_page
pages_processed = 0
while True:
for item in current_page.get("data", []):
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
url_data_list.append(self._extract_common_fields(item))
next_url: str | None = current_page.get("next")
pages_processed += 1
if not next_url or pages_processed >= total:
break
response = self._get_request(next_url, headers)
if response.status_code != 200:
self._handle_error(response, "fetch next crawl page")
current_page = response.json()
return url_data_list
def _format_crawl_status_response(
self,
status: str,

View File

@ -50,7 +50,7 @@ class BuiltinTool(Tool):
return ModelInvocationUtils.invoke(
user_id=user_id,
tenant_id=self.runtime.tenant_id or "",
tool_type="builtin",
tool_type=ToolProviderType.BUILT_IN,
tool_name=self.entity.identity.name,
prompt_messages=prompt_messages,
)

View File

@ -38,7 +38,7 @@ class ToolLabelManager:
db.session.add(
ToolLabelBinding(
tool_id=provider_id,
tool_type=controller.provider_type.value,
tool_type=controller.provider_type,
label_name=label,
)
)
@ -58,7 +58,7 @@ class ToolLabelManager:
raise ValueError("Unsupported tool type")
stmt = select(ToolLabelBinding.label_name).where(
ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value,
ToolLabelBinding.tool_type == controller.provider_type,
)
labels = db.session.scalars(stmt).all()

View File

@ -9,6 +9,7 @@ from decimal import Decimal
from typing import cast
from core.model_manager import ModelManager
from core.tools.entities.tool_entities import ToolProviderType
from dify_graph.model_runtime.entities.llm_entities import LLMResult
from dify_graph.model_runtime.entities.message_entities import PromptMessage
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
@ -78,7 +79,7 @@ class ModelInvocationUtils:
@staticmethod
def invoke(
user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
user_id: str, tenant_id: str, tool_type: ToolProviderType, tool_name: str, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
invoke model with parameters in user's own context

View File

@ -1,6 +1,9 @@
from __future__ import annotations
from collections.abc import Sequence
import json
import logging
import re
from collections.abc import Mapping, Sequence
from typing import Any, cast
from core.model_manager import ModelInstance
@ -36,6 +39,11 @@ from .exc import (
)
from .protocols import TemplateRenderer
logger = logging.getLogger(__name__)
VARIABLE_PATTERN = re.compile(r"\{\{#[^#]+#\}\}")
MAX_RESOLVED_VALUE_LENGTH = 1024
def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
@ -475,3 +483,61 @@ def _append_file_prompts(
prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
def _coerce_resolved_value(raw: str) -> int | float | bool | str:
"""Try to restore the original type from a resolved template string.
Variable references are always resolved to text, but completion params may
expect numeric or boolean values (e.g. a variable that holds "0.7" mapped to
the ``temperature`` parameter). This helper attempts a JSON parse so that
``"0.7"`` ``0.7``, ``"true"`` ``True``, etc. Plain strings that are not
valid JSON literals are returned as-is.
"""
stripped = raw.strip()
if not stripped:
return raw
try:
parsed: object = json.loads(stripped)
except (json.JSONDecodeError, ValueError):
return raw
if isinstance(parsed, (int, float, bool)):
return parsed
return raw
def resolve_completion_params_variables(
completion_params: Mapping[str, Any],
variable_pool: VariablePool,
) -> dict[str, Any]:
"""Resolve variable references (``{{#node_id.var#}}``) in string-typed completion params.
Security notes:
- Resolved values are length-capped to ``MAX_RESOLVED_VALUE_LENGTH`` to
prevent denial-of-service through excessively large variable payloads.
- This follows the same ``VariablePool.convert_template`` pattern used across
Dify (Answer Node, HTTP Request Node, Agent Node, etc.). The downstream
model plugin receives these values as structured JSON key-value pairs they
are never concatenated into raw HTTP headers or SQL queries.
- Numeric/boolean coercion is applied so that variables holding ``"0.7"`` are
restored to their native type rather than sent as a bare string.
"""
resolved: dict[str, Any] = {}
for key, value in completion_params.items():
if isinstance(value, str) and VARIABLE_PATTERN.search(value):
segment_group = variable_pool.convert_template(value)
text = segment_group.text
if len(text) > MAX_RESOLVED_VALUE_LENGTH:
logger.warning(
"Resolved value for param '%s' truncated from %d to %d chars",
key,
len(text),
MAX_RESOLVED_VALUE_LENGTH,
)
text = text[:MAX_RESOLVED_VALUE_LENGTH]
resolved[key] = _coerce_resolved_value(text)
else:
resolved[key] = value
return resolved

View File

@ -202,6 +202,10 @@ class LLMNode(Node[LLMNodeData]):
# fetch model config
model_instance = self._model_instance
# Resolve variable references in string-typed completion params
model_instance.parameters = llm_utils.resolve_completion_params_variables(
model_instance.parameters, variable_pool
)
model_name = model_instance.model_name
model_provider = model_instance.provider
model_stop = model_instance.stop

View File

@ -164,6 +164,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
)
model_instance = self._model_instance
# Resolve variable references in string-typed completion params
model_instance.parameters = llm_utils.resolve_completion_params_variables(
model_instance.parameters, variable_pool
)
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
raise InvalidModelTypeError("Model is not a Large Language Model")

View File

@ -114,6 +114,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
variables = {"query": query}
# fetch model instance
model_instance = self._model_instance
# Resolve variable references in string-typed completion params
model_instance.parameters = llm_utils.resolve_completion_params_variables(
model_instance.parameters, variable_pool
)
memory = self._memory
# fetch instruction
node_data.instruction = node_data.instruction or ""

View File

@ -18,15 +18,23 @@ if TYPE_CHECKING:
from models.model import EndUser
def _resolve_current_user() -> EndUser | Account | None:
"""
Resolve the current user proxy to its underlying user object.
This keeps unit tests working when they patch `current_user` directly
instead of bootstrapping a full Flask-Login manager.
"""
user_proxy = current_user
get_current_object = getattr(user_proxy, "_get_current_object", None)
return get_current_object() if callable(get_current_object) else user_proxy # type: ignore
def current_account_with_tenant():
"""
Resolve the underlying account for the current user proxy and ensure tenant context exists.
Allows tests to supply plain Account mocks without the LocalProxy helper.
"""
user_proxy = current_user
get_current_object = getattr(user_proxy, "_get_current_object", None)
user = get_current_object() if callable(get_current_object) else user_proxy # type: ignore
user = _resolve_current_user()
if not isinstance(user, Account):
raise ValueError("current_user must be an Account instance")
@ -79,9 +87,10 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
return current_app.ensure_sync(func)(*args, **kwargs)
user = _get_user()
user = _resolve_current_user()
if user is None or not user.is_authenticated:
return current_app.login_manager.unauthorized() # type: ignore
g._login_user = user
# we put csrf validation here for less conflicts
# TODO: maybe find a better place for it.
check_csrf_token(request, user.id)

View File

@ -1,16 +1,19 @@
import logging
import sys
import urllib.parse
from dataclasses import dataclass
from typing import NotRequired
import httpx
from pydantic import TypeAdapter
from pydantic import TypeAdapter, ValidationError
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
logger = logging.getLogger(__name__)
JsonObject = dict[str, object]
JsonObjectList = list[JsonObject]
@ -30,8 +33,8 @@ class GitHubEmailRecord(TypedDict, total=False):
class GitHubRawUserInfo(TypedDict):
id: int | str
login: str
name: NotRequired[str]
email: NotRequired[str]
name: NotRequired[str | None]
email: NotRequired[str | None]
class GoogleRawUserInfo(TypedDict):
@ -127,9 +130,14 @@ class GitHubOAuth(OAuth):
response.raise_for_status()
user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response))
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
primary_email = next((email for email in email_info if email.get("primary") is True), None)
try:
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
email_response.raise_for_status()
email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
primary_email = next((email for email in email_info if email.get("primary") is True), None)
except (httpx.HTTPStatusError, ValidationError):
logger.warning("Failed to retrieve email from GitHub /user/emails endpoint", exc_info=True)
primary_email = None
return {**user_info, "email": primary_email.get("email", "") if primary_email else ""}
@ -137,8 +145,11 @@ class GitHubOAuth(OAuth):
payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
email = payload.get("email")
if not email:
email = f"{payload['id']}+{payload['login']}@users.noreply.github.com"
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email)
raise ValueError(
'Dify currently not supports the "Keep my email addresses private" feature,'
" please disable it and login again"
)
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name") or ""), email=email)
class GoogleOAuth(OAuth):

View File

@ -43,7 +43,9 @@ from .enums import (
IndexingStatus,
ProcessRuleMode,
SegmentStatus,
SegmentType,
SummaryStatus,
TidbAuthBindingStatus,
)
from .model import App, Tag, TagBinding, UploadFile
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
@ -494,7 +496,9 @@ class Document(Base):
)
doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True)
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
doc_form: Mapped[IndexStructureType] = mapped_column(
EnumText(IndexStructureType, length=255), nullable=False, server_default=sa.text("'text_model'")
)
doc_language = mapped_column(String(255), nullable=True)
need_summary: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
@ -998,7 +1002,9 @@ class ChildChunk(Base):
# indexing fields
index_node_id = mapped_column(String(255), nullable=True)
index_node_hash = mapped_column(String(255), nullable=True)
type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'"))
type: Mapped[SegmentType] = mapped_column(
EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'")
)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
@ -1239,7 +1245,9 @@ class TidbAuthBinding(TypeBase):
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'"))
status: Mapped[TidbAuthBindingStatus] = mapped_column(
EnumText(TidbAuthBindingStatus, length=255), nullable=False, server_default=sa.text("'CREATING'")
)
account: Mapped[str] = mapped_column(String(255), nullable=False)
password: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(

View File

@ -222,6 +222,13 @@ class DatasetMetadataType(StrEnum):
TIME = "time"
class SegmentType(StrEnum):
"""Document segment type"""
AUTOMATIC = "automatic"
CUSTOMIZED = "customized"
class SegmentStatus(StrEnum):
"""Document segment status"""
@ -323,3 +330,10 @@ class ProviderQuotaType(StrEnum):
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class ApiTokenType(StrEnum):
"""API Token type"""
APP = "app"
DATASET = "dataset"

View File

@ -66,8 +66,8 @@ class HumanInputContent(ExecutionExtraContent):
form_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
@classmethod
def new(cls, form_id: str, message_id: str | None) -> "HumanInputContent":
return cls(form_id=form_id, message_id=message_id)
def new(cls, *, workflow_run_id: str, form_id: str, message_id: str | None) -> "HumanInputContent":
return cls(workflow_run_id=workflow_run_id, form_id=form_id, message_id=message_id)
form: Mapped["HumanInputForm"] = relationship(
"HumanInputForm",

View File

@ -21,7 +21,7 @@ from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.tools.signature import sign_tool_file
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from dify_graph.file import helpers as file_helpers
from extensions.storage.storage_type import StorageType
from libs.helper import generate_string # type: ignore[import-not-found]
@ -31,6 +31,7 @@ from .account import Account, Tenant
from .base import Base, TypeBase, gen_uuidv4_string
from .engine import db
from .enums import (
ApiTokenType,
AppMCPServerStatus,
AppStatus,
BannerStatus,
@ -43,6 +44,8 @@ from .enums import (
MessageChainType,
MessageFileBelongsTo,
MessageStatus,
ProviderQuotaType,
TagType,
)
from .provider_ids import GenericProviderID
from .types import EnumText, LongText, StringUUID
@ -1782,7 +1785,7 @@ class MessageFile(TypeBase):
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[FileType] = mapped_column(EnumText(FileType, length=255), nullable=False)
transfer_method: Mapped[FileTransferMethod] = mapped_column(
EnumText(FileTransferMethod, length=255), nullable=False
)
@ -2094,7 +2097,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field.
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
app_id = mapped_column(StringUUID, nullable=True)
tenant_id = mapped_column(StringUUID, nullable=True)
type = mapped_column(String(16), nullable=False)
type: Mapped[ApiTokenType] = mapped_column(EnumText(ApiTokenType, length=16), nullable=False)
token: Mapped[str] = mapped_column(String(255), nullable=False)
last_used_at = mapped_column(sa.DateTime, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@ -2404,7 +2407,7 @@ class Tag(TypeBase):
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
type: Mapped[str] = mapped_column(String(16), nullable=False)
type: Mapped[TagType] = mapped_column(EnumText(TagType, length=16), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
@ -2489,7 +2492,9 @@ class TenantCreditPool(TypeBase):
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
pool_type: Mapped[ProviderQuotaType] = mapped_column(
EnumText(ProviderQuotaType, length=40), nullable=False, default=ProviderQuotaType.TRIAL, server_default="trial"
)
quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
created_at: Mapped[datetime] = mapped_column(

View File

@ -13,12 +13,16 @@ from sqlalchemy.orm import Mapped, mapped_column
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
from core.tools.entities.tool_entities import (
ApiProviderSchemaType,
ToolProviderType,
WorkflowToolParameterConfiguration,
)
from .base import TypeBase
from .engine import db
from .model import Account, App, Tenant
from .types import LongText, StringUUID
from .types import EnumText, LongText, StringUUID
if TYPE_CHECKING:
from core.entities.mcp_provider import MCPProviderEntity
@ -141,7 +145,9 @@ class ApiToolProvider(TypeBase):
icon: Mapped[str] = mapped_column(String(255), nullable=False)
# original schema
schema: Mapped[str] = mapped_column(LongText, nullable=False)
schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False)
schema_type_str: Mapped[ApiProviderSchemaType] = mapped_column(
EnumText(ApiProviderSchemaType, length=40), nullable=False
)
# who created this tool
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# tenant id
@ -208,7 +214,7 @@ class ToolLabelBinding(TypeBase):
# tool id
tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
# tool type
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
# label name
label_name: Mapped[str] = mapped_column(String(40), nullable=False)
@ -386,7 +392,7 @@ class ToolModelInvoke(TypeBase):
# provider
provider: Mapped[str] = mapped_column(String(255), nullable=False)
# type
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
# tool name
tool_name: Mapped[str] = mapped_column(String(128), nullable=False)
# invoke parameters

View File

@ -1221,7 +1221,9 @@ class WorkflowAppLog(TypeBase):
app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_run_id: Mapped[str] = mapped_column(StringUUID)
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
created_from: Mapped[WorkflowAppLogCreatedFrom] = mapped_column(
EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=False
)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
@ -1301,10 +1303,14 @@ class WorkflowArchiveLog(TypeBase):
log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
log_created_from: Mapped[WorkflowAppLogCreatedFrom | None] = mapped_column(
EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=True
)
run_version: Mapped[str] = mapped_column(String(255), nullable=False)
run_status: Mapped[str] = mapped_column(String(255), nullable=False)
run_status: Mapped[WorkflowExecutionStatus] = mapped_column(
EnumText(WorkflowExecutionStatus, length=255), nullable=False
)
run_triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column(
EnumText(WorkflowRunTriggeredFrom, length=255), nullable=False
)

View File

@ -8,7 +8,7 @@ dependencies = [
"arize-phoenix-otel~=0.15.0",
"azure-identity==1.25.3",
"beautifulsoup4==4.14.3",
"boto3==1.42.68",
"boto3==1.42.73",
"bs4~=0.0.1",
"cachetools~=5.3.0",
"celery~=5.6.2",
@ -23,7 +23,7 @@ dependencies = [
"gevent~=25.9.1",
"gmpy2~=2.3.0",
"google-api-core>=2.19.1",
"google-api-python-client==2.192.0",
"google-api-python-client==2.193.0",
"google-auth>=2.47.0",
"google-auth-httplib2==0.3.0",
"google-cloud-aiplatform>=1.123.0",
@ -40,7 +40,7 @@ dependencies = [
"numpy~=1.26.4",
"openpyxl~=3.1.5",
"opik~=1.10.37",
"litellm==1.82.2", # Pinned to avoid madoka dependency issue
"litellm==1.82.6", # Pinned to avoid madoka dependency issue
"opentelemetry-api==1.28.0",
"opentelemetry-distro==0.49b0",
"opentelemetry-exporter-otlp==1.28.0",
@ -72,10 +72,10 @@ dependencies = [
"pyyaml~=6.0.1",
"readabilipy~=0.3.0",
"redis[hiredis]~=7.3.0",
"resend~=2.23.0",
"sentry-sdk[flask]~=2.54.0",
"resend~=2.26.0",
"sentry-sdk[flask]~=2.55.0",
"sqlalchemy~=2.0.29",
"starlette==0.52.1",
"starlette==1.0.0",
"tiktoken~=0.12.0",
"transformers~=5.3.0",
"unstructured[docx,epub,md,ppt,pptx]~=0.21.5",
@ -92,7 +92,7 @@ dependencies = [
"apscheduler>=3.11.0",
"weave>=0.52.16",
"fastopenapi[flask]>=0.7.0",
"bleach~=6.2.0",
"bleach~=6.3.0",
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.
@ -119,7 +119,7 @@ dev = [
"ruff~=0.15.5",
"pytest~=9.0.2",
"pytest-benchmark~=5.2.3",
"pytest-cov~=7.0.0",
"pytest-cov~=7.1.0",
"pytest-env~=1.6.0",
"pytest-mock~=3.15.1",
"testcontainers~=4.14.1",
@ -203,7 +203,7 @@ tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"]
# Required by vector store clients
############################################################
vdb = [
"alibabacloud_gpdb20160503~=3.8.0",
"alibabacloud_gpdb20160503~=5.1.0",
"alibabacloud_tea_openapi~=0.4.3",
"chromadb==0.5.20",
"clickhouse-connect~=0.14.1",

View File

@ -8,6 +8,7 @@ from configs import dify_config
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from extensions.ext_database import db
from models.dataset import TidbAuthBinding
from models.enums import TidbAuthBindingStatus
@app.celery.task(queue="dataset")
@ -57,7 +58,7 @@ def create_clusters(batch_size):
account=new_cluster["account"],
password=new_cluster["password"],
active=False,
status="CREATING",
status=TidbAuthBindingStatus.CREATING,
)
db.session.add(tidb_auth_binding)
db.session.commit()

View File

@ -9,6 +9,7 @@ from configs import dify_config
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from extensions.ext_database import db
from models.dataset import TidbAuthBinding
from models.enums import TidbAuthBindingStatus
@app.celery.task(queue="dataset")
@ -18,7 +19,10 @@ def update_tidb_serverless_status_task():
try:
# check the number of idle tidb serverless
tidb_serverless_list = db.session.scalars(
select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
select(TidbAuthBinding).where(
TidbAuthBinding.active == False,
TidbAuthBinding.status == TidbAuthBindingStatus.CREATING,
)
).all()
if len(tidb_serverless_list) == 0:
return

View File

@ -241,7 +241,7 @@ class AppService:
class ArgsDict(TypedDict):
name: str
description: str
icon_type: str
icon_type: IconType | str | None
icon: str
icon_background: str
use_icon_as_answer_icon: bool
@ -257,7 +257,13 @@ class AppService:
assert current_user is not None
app.name = args["name"]
app.description = args["description"]
app.icon_type = IconType(args["icon_type"]) if args["icon_type"] else None
icon_type = args.get("icon_type")
if icon_type is None:
resolved_icon_type = app.icon_type
else:
resolved_icon_type = IconType(icon_type)
app.icon_type = resolved_icon_type
app.icon = args["icon"]
app.icon_background = args["icon_background"]
app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False)

View File

@ -1,8 +1,16 @@
from abc import ABC, abstractmethod
from typing import Any
from typing_extensions import TypedDict
class AuthCredentials(TypedDict):
auth_type: str
config: dict[str, Any]
class ApiKeyAuthBase(ABC):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
self.credentials = credentials
@abstractmethod

View File

@ -1,9 +1,9 @@
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
from services.auth.auth_type import AuthType
class ApiKeyAuthFactory:
def __init__(self, provider: str, credentials: dict):
def __init__(self, provider: str, credentials: AuthCredentials):
auth_factory = self.get_apikey_auth_factory(provider)
self.auth = auth_factory(credentials)

View File

@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class FirecrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@ -3,11 +3,11 @@ from urllib.parse import urljoin
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class WatercrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "x-api-key":

View File

@ -7,6 +7,7 @@ from configs import dify_config
from core.errors.error import QuotaExceededError
from extensions.ext_database import db
from models import TenantCreditPool
from models.enums import ProviderQuotaType
logger = logging.getLogger(__name__)
@ -16,7 +17,10 @@ class CreditPoolService:
def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
"""create default credit pool for new tenant"""
credit_pool = TenantCreditPool(
tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial"
tenant_id=tenant_id,
quota_limit=dify_config.HOSTED_POOL_CREDITS,
quota_used=0,
pool_type=ProviderQuotaType.TRIAL,
)
db.session.add(credit_pool)
db.session.commit()

View File

@ -58,6 +58,7 @@ from models.enums import (
IndexingStatus,
ProcessRuleMode,
SegmentStatus,
SegmentType,
)
from models.model import UploadFile
from models.provider_ids import ModelProviderID
@ -1439,7 +1440,7 @@ class DocumentService:
.filter(
Document.id.in_(document_id_list),
Document.dataset_id == dataset_id,
Document.doc_form != "qa_model", # Skip qa_model documents
Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
)
.update({Document.need_summary: need_summary}, synchronize_session=False)
)
@ -2039,7 +2040,7 @@ class DocumentService:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_form = IndexStructureType(knowledge_config.doc_form)
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
@ -2639,7 +2640,7 @@ class DocumentService:
document.splitting_completed_at = None
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = document_data.doc_form
document.doc_form = IndexStructureType(document_data.doc_form)
db.session.add(document)
db.session.commit()
# update document segment
@ -3100,7 +3101,7 @@ class DocumentService:
class SegmentService:
@classmethod
def segment_create_args_validate(cls, args: dict, document: Document):
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
if "answer" not in args or not args["answer"]:
raise ValueError("Answer is required")
if not args["answer"].strip():
@ -3157,7 +3158,7 @@ class SegmentService:
completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment_document.word_count += len(args["answer"])
segment_document.answer = args["answer"]
@ -3231,7 +3232,7 @@ class SegmentService:
tokens = 0
if dataset.indexing_technique == "high_quality" and embedding_model:
# calc embedding use tokens
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
tokens = embedding_model.get_text_embedding_num_tokens(
texts=[content + segment_item["answer"]]
)[0]
@ -3254,7 +3255,7 @@ class SegmentService:
completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment_document.answer = segment_item["answer"]
segment_document.word_count += len(segment_item["answer"])
increment_word_count += segment_document.word_count
@ -3321,7 +3322,7 @@ class SegmentService:
content = args.content or segment.content
if segment.content == content:
segment.word_count = len(content)
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment.answer = args.answer
segment.word_count += len(args.answer) if args.answer else 0
word_count_change = segment.word_count - word_count_change
@ -3418,7 +3419,7 @@ class SegmentService:
)
# calc embedding use tokens
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment.answer = args.answer
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0] # type: ignore
else:
@ -3435,7 +3436,7 @@ class SegmentService:
segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment.answer = args.answer
segment.word_count += len(args.answer) if args.answer else 0
word_count_change = segment.word_count - word_count_change
@ -3786,7 +3787,7 @@ class SegmentService:
child_chunk.word_count = len(child_chunk.content)
child_chunk.updated_by = current_user.id
child_chunk.updated_at = naive_utc_now()
child_chunk.type = "customized"
child_chunk.type = SegmentType.CUSTOMIZED
update_child_chunks.append(child_chunk)
else:
new_child_chunks_args.append(child_chunk_update_args)
@ -3845,7 +3846,7 @@ class SegmentService:
child_chunk.word_count = len(content)
child_chunk.updated_by = current_user.id
child_chunk.updated_at = naive_utc_now()
child_chunk.type = "customized"
child_chunk.type = SegmentType.CUSTOMIZED
db.session.add(child_chunk)
VectorService.update_child_chunk_vector([], [child_chunk], [], dataset)
db.session.commit()

View File

@ -9,6 +9,7 @@ from flask_login import current_user
from constants import DOCUMENT_EXTENSIONS
from core.plugin.impl.plugin import PluginInstaller
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from factories import variable_factory
@ -79,9 +80,9 @@ class RagPipelineTransformService:
pipeline = self._create_pipeline(pipeline_yaml)
# save chunk structure to dataset
if doc_form == "hierarchical_model":
if doc_form == IndexStructureType.PARENT_CHILD_INDEX:
dataset.chunk_structure = "hierarchical_model"
elif doc_form == "text_model":
elif doc_form == IndexStructureType.PARAGRAPH_INDEX:
dataset.chunk_structure = "text_model"
else:
raise ValueError("Unsupported doc form")
@ -101,7 +102,7 @@ class RagPipelineTransformService:
def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None):
pipeline_yaml = {}
if doc_form == "text_model":
if doc_form == IndexStructureType.PARAGRAPH_INDEX:
match datasource_type:
case DataSourceType.UPLOAD_FILE:
if indexing_technique == "high_quality":
@ -132,7 +133,7 @@ class RagPipelineTransformService:
pipeline_yaml = yaml.safe_load(f)
case _:
raise ValueError("Unsupported datasource type")
elif doc_form == "hierarchical_model":
elif doc_form == IndexStructureType.PARENT_CHILD_INDEX:
match datasource_type:
case DataSourceType.UPLOAD_FILE:
# get graph from transform.file-parentchild.yml

View File

@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from models.dataset import Dataset
from models.enums import TagType
from models.model import App, Tag, TagBinding
@ -83,7 +84,7 @@ class TagService:
raise ValueError("Tag name already exists")
tag = Tag(
name=args["name"],
type=args["type"],
type=TagType(args["type"]),
created_by=current_user.id,
tenant_id=current_user.current_tenant_id,
)

View File

@ -11,6 +11,7 @@ from sqlalchemy import func
from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.rag.index_processor.constant.index_type import IndexStructureType
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@ -109,7 +110,7 @@ def batch_create_segment_to_index_task(
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if document_config["doc_form"] == "qa_model":
if document_config["doc_form"] == IndexStructureType.QA_INDEX:
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row.iloc[0]}
@ -159,7 +160,7 @@ def batch_create_segment_to_index_task(
status="completed",
completed_at=naive_utc_now(),
)
if document_config["doc_form"] == "qa_model":
if document_config["doc_form"] == IndexStructureType.QA_INDEX:
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count

View File

@ -10,6 +10,7 @@ from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from libs.datetime_utils import naive_utc_now
@ -150,7 +151,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
)
if (
document.indexing_status == IndexingStatus.COMPLETED
and document.doc_form != "qa_model"
and document.doc_form != IndexStructureType.QA_INDEX
and document.need_summary is True
):
try:

View File

@ -9,6 +9,7 @@ from celery import shared_task
from sqlalchemy import or_, select
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
from models.dataset import Document as DatasetDocument
from services.summary_index_service import SummaryIndexService
@ -106,7 +107,7 @@ def regenerate_summary_index_task(
),
DatasetDocument.enabled == True, # Document must be enabled
DatasetDocument.archived == False, # Document must not be archived
DatasetDocument.doc_form != "qa_model", # Skip qa_model documents
DatasetDocument.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
)
.order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc())
.all()
@ -209,7 +210,7 @@ def regenerate_summary_index_task(
for dataset_document in dataset_documents:
# Skip qa_model documents
if dataset_document.doc_form == "qa_model":
if dataset_document.doc_form == IndexStructureType.QA_INDEX:
continue
try:

View File

@ -179,7 +179,7 @@ def _record_trigger_failure_log(
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value,
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=created_by_role,
created_by=created_by,
)

View File

@ -13,6 +13,7 @@ from unittest.mock import patch
import pytest
from extensions.ext_redis import redis_client
from models.enums import ApiTokenType
from models.model import ApiToken
from services.api_token_service import ApiTokenCache, CachedApiToken
@ -279,7 +280,7 @@ class TestEndToEndCacheFlow:
test_token = ApiToken()
test_token.id = "test-e2e-id"
test_token.token = test_token_value
test_token.type = test_scope
test_token.type = ApiTokenType.APP
test_token.app_id = "test-app"
test_token.tenant_id = "test-tenant"
test_token.last_used_at = None

View File

@ -0,0 +1,342 @@
"""Authenticated controller integration tests for console message APIs."""
from datetime import timedelta
from decimal import Decimal
from unittest.mock import patch
from uuid import uuid4
import pytest
from flask.testing import FlaskClient
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console.app.message import ChatMessagesQuery, FeedbackExportQuery, MessageFeedbackPayload
from controllers.console.app.message import attach_message_extra_contents as _attach_message_extra_contents
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from libs.datetime_utils import naive_utc_now
from models.enums import ConversationFromSource, FeedbackRating
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from tests.test_containers_integration_tests.controllers.console.helpers import (
authenticate_console_client,
create_console_account_and_tenant,
create_console_app,
)
def _create_conversation(db_session: Session, app_id: str, account_id: str, mode: AppMode) -> Conversation:
conversation = Conversation(
app_id=app_id,
app_model_config_id=None,
model_provider=None,
model_id="",
override_model_configs=None,
mode=mode,
name="Test Conversation",
inputs={},
introduction="",
system_instruction="",
system_instruction_tokens=0,
status="normal",
from_source=ConversationFromSource.CONSOLE,
from_account_id=account_id,
)
db_session.add(conversation)
db_session.commit()
return conversation
def _create_message(
db_session: Session,
app_id: str,
conversation_id: str,
account_id: str,
*,
created_at_offset_seconds: int = 0,
) -> Message:
created_at = naive_utc_now() + timedelta(seconds=created_at_offset_seconds)
message = Message(
app_id=app_id,
model_provider=None,
model_id="",
override_model_configs=None,
conversation_id=conversation_id,
inputs={},
query="Hello",
message={"type": "text", "content": "Hello"},
message_tokens=1,
message_unit_price=Decimal("0.0001"),
message_price_unit=Decimal("0.001"),
answer="Hi there",
answer_tokens=1,
answer_unit_price=Decimal("0.0001"),
answer_price_unit=Decimal("0.001"),
parent_message_id=None,
provider_response_latency=0,
total_price=Decimal("0.0002"),
currency="USD",
from_source=ConversationFromSource.CONSOLE,
from_account_id=account_id,
created_at=created_at,
updated_at=created_at,
app_mode=AppMode.CHAT,
)
db_session.add(message)
db_session.commit()
return message
class TestMessageValidators:
def test_chat_messages_query_validators(self) -> None:
assert ChatMessagesQuery.empty_to_none("") is None
assert ChatMessagesQuery.empty_to_none("val") == "val"
assert ChatMessagesQuery.validate_uuid(None) is None
assert (
ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000")
== "123e4567-e89b-12d3-a456-426614174000"
)
def test_message_feedback_validators(self) -> None:
assert (
MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000")
== "123e4567-e89b-12d3-a456-426614174000"
)
def test_feedback_export_validators(self) -> None:
assert FeedbackExportQuery.parse_bool(None) is None
assert FeedbackExportQuery.parse_bool(True) is True
assert FeedbackExportQuery.parse_bool("1") is True
assert FeedbackExportQuery.parse_bool("0") is False
assert FeedbackExportQuery.parse_bool("off") is False
with pytest.raises(ValueError):
FeedbackExportQuery.parse_bool("invalid")
def test_chat_message_list_not_found(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/chat-messages",
query_string={"conversation_id": str(uuid4())},
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 404
payload = response.get_json()
assert payload is not None
assert payload["code"] == "not_found"
def test_chat_message_list_success(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode)
_create_message(db_session_with_containers, app.id, conversation.id, account.id, created_at_offset_seconds=0)
second = _create_message(
db_session_with_containers,
app.id,
conversation.id,
account.id,
created_at_offset_seconds=1,
)
with patch(
"controllers.console.app.message.attach_message_extra_contents",
side_effect=_attach_message_extra_contents,
):
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/chat-messages",
query_string={"conversation_id": conversation.id, "limit": 1},
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert payload["limit"] == 1
assert payload["has_more"] is True
assert len(payload["data"]) == 1
assert payload["data"][0]["id"] == second.id
def test_message_feedback_not_found(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
response = test_client_with_containers.post(
f"/console/api/apps/{app.id}/feedbacks",
json={"message_id": str(uuid4()), "rating": "like"},
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 404
payload = response.get_json()
assert payload is not None
assert payload["code"] == "not_found"
def test_message_feedback_success(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode)
message = _create_message(db_session_with_containers, app.id, conversation.id, account.id)
response = test_client_with_containers.post(
f"/console/api/apps/{app.id}/feedbacks",
json={"message_id": message.id, "rating": "like"},
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json() == {"result": "success"}
feedback = db_session_with_containers.scalar(
select(MessageFeedback).where(MessageFeedback.message_id == message.id)
)
assert feedback is not None
assert feedback.rating == FeedbackRating.LIKE
assert feedback.from_account_id == account.id
def test_message_annotation_count(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode)
message = _create_message(db_session_with_containers, app.id, conversation.id, account.id)
db_session_with_containers.add(
MessageAnnotation(
app_id=app.id,
conversation_id=conversation.id,
message_id=message.id,
question="Q",
content="A",
account_id=account.id,
)
)
db_session_with_containers.commit()
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/annotations/count",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json() == {"count": 1}
def test_message_suggested_questions_success(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
message_id = str(uuid4())
with patch(
"controllers.console.app.message.MessageService.get_suggested_questions_after_answer",
return_value=["q1", "q2"],
):
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/chat-messages/{message_id}/suggested-questions",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json() == {"data": ["q1", "q2"]}
@pytest.mark.parametrize(
("exc", "expected_status", "expected_code"),
[
(MessageNotExistsError(), 404, "not_found"),
(ConversationNotExistsError(), 404, "not_found"),
(ProviderTokenNotInitError(), 400, "provider_not_initialize"),
(QuotaExceededError(), 400, "provider_quota_exceeded"),
(ModelCurrentlyNotSupportError(), 400, "model_currently_not_support"),
(SuggestedQuestionsAfterAnswerDisabledError(), 403, "app_suggested_questions_after_answer_disabled"),
(Exception(), 500, "internal_server_error"),
],
)
def test_message_suggested_questions_errors(
exc: Exception,
expected_status: int,
expected_code: str,
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
message_id = str(uuid4())
with patch(
"controllers.console.app.message.MessageService.get_suggested_questions_after_answer",
side_effect=exc,
):
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/chat-messages/{message_id}/suggested-questions",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == expected_status
payload = response.get_json()
assert payload is not None
assert payload["code"] == expected_code
def test_message_feedback_export_success(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
with patch("services.feedback_service.FeedbackService.export_feedbacks", return_value={"exported": True}):
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/feedbacks/export",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json() == {"exported": True}
def test_message_api_get_success(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode)
message = _create_message(db_session_with_containers, app.id, conversation.id, account.id)
with patch(
"controllers.console.app.message.attach_message_extra_contents",
side_effect=_attach_message_extra_contents,
):
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/messages/{message.id}",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert payload["id"] == message.id

View File

@ -0,0 +1,334 @@
"""Controller integration tests for console statistic routes."""
from datetime import timedelta
from decimal import Decimal
from unittest.mock import patch
from uuid import uuid4
from flask.testing import FlaskClient
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from libs.datetime_utils import naive_utc_now
from models.enums import ConversationFromSource, FeedbackFromSource, FeedbackRating
from models.model import AppMode, Conversation, Message, MessageFeedback
from tests.test_containers_integration_tests.controllers.console.helpers import (
authenticate_console_client,
create_console_account_and_tenant,
create_console_app,
)
def _create_conversation(
db_session: Session,
app_id: str,
account_id: str,
*,
mode: AppMode,
created_at_offset_days: int = 0,
) -> Conversation:
created_at = naive_utc_now() + timedelta(days=created_at_offset_days)
conversation = Conversation(
app_id=app_id,
app_model_config_id=None,
model_provider=None,
model_id="",
override_model_configs=None,
mode=mode,
name="Stats Conversation",
inputs={},
introduction="",
system_instruction="",
system_instruction_tokens=0,
status="normal",
from_source=ConversationFromSource.CONSOLE,
from_account_id=account_id,
created_at=created_at,
updated_at=created_at,
)
db_session.add(conversation)
db_session.commit()
return conversation
def _create_message(
db_session: Session,
app_id: str,
conversation_id: str,
*,
from_account_id: str | None,
from_end_user_id: str | None = None,
message_tokens: int = 1,
answer_tokens: int = 1,
total_price: Decimal = Decimal("0.01"),
provider_response_latency: float = 1.0,
created_at_offset_days: int = 0,
) -> Message:
created_at = naive_utc_now() + timedelta(days=created_at_offset_days)
message = Message(
app_id=app_id,
model_provider=None,
model_id="",
override_model_configs=None,
conversation_id=conversation_id,
inputs={},
query="Hello",
message={"type": "text", "content": "Hello"},
message_tokens=message_tokens,
message_unit_price=Decimal("0.001"),
message_price_unit=Decimal("0.001"),
answer="Hi there",
answer_tokens=answer_tokens,
answer_unit_price=Decimal("0.001"),
answer_price_unit=Decimal("0.001"),
parent_message_id=None,
provider_response_latency=provider_response_latency,
total_price=total_price,
currency="USD",
invoke_from=InvokeFrom.EXPLORE,
from_source=ConversationFromSource.CONSOLE,
from_end_user_id=from_end_user_id,
from_account_id=from_account_id,
created_at=created_at,
updated_at=created_at,
app_mode=AppMode.CHAT,
)
db_session.add(message)
db_session.commit()
return message
def _create_like_feedback(
db_session: Session,
app_id: str,
conversation_id: str,
message_id: str,
account_id: str,
) -> None:
db_session.add(
MessageFeedback(
app_id=app_id,
conversation_id=conversation_id,
message_id=message_id,
rating=FeedbackRating.LIKE,
from_source=FeedbackFromSource.ADMIN,
from_account_id=account_id,
)
)
db_session.commit()
def test_daily_message_statistic(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/statistics/daily-messages",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json()["data"][0]["message_count"] == 1
def test_daily_conversation_statistic(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/statistics/daily-conversations",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json()["data"][0]["conversation_count"] == 1
def test_daily_terminals_statistic(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
_create_message(
db_session_with_containers,
app.id,
conversation.id,
from_account_id=None,
from_end_user_id=str(uuid4()),
)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/statistics/daily-end-users",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json()["data"][0]["terminal_count"] == 1
def test_daily_token_cost_statistic(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
_create_message(
db_session_with_containers,
app.id,
conversation.id,
from_account_id=account.id,
message_tokens=40,
answer_tokens=60,
total_price=Decimal("0.02"),
)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/statistics/token-costs",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
payload = response.get_json()
assert payload["data"][0]["token_count"] == 100
assert Decimal(payload["data"][0]["total_price"]) == Decimal("0.02")
def test_average_session_interaction_statistic(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/statistics/average-session-interactions",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json()["data"][0]["interactions"] == 2.0
def test_user_satisfaction_rate_statistic(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
first = _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
for _ in range(9):
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
_create_like_feedback(db_session_with_containers, app.id, conversation.id, first.id, account.id)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/statistics/user-satisfaction-rate",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json()["data"][0]["rate"] == 100.0
def test_average_response_time_statistic(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.COMPLETION)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
_create_message(
db_session_with_containers,
app.id,
conversation.id,
from_account_id=account.id,
provider_response_latency=1.234,
)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/statistics/average-response-time",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json()["data"][0]["latency"] == 1234.0
def test_tokens_per_second_statistic(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
_create_message(
db_session_with_containers,
app.id,
conversation.id,
from_account_id=account.id,
answer_tokens=31,
provider_response_latency=2.0,
)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/statistics/tokens-per-second",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json()["data"][0]["tps"] == 15.5
def test_invalid_time_range(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
with patch("controllers.console.app.statistic.parse_time_range", side_effect=ValueError("Invalid time")):
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/statistics/daily-messages?start=invalid&end=invalid",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 400
assert response.get_json()["message"] == "Invalid time"
def test_time_range_params_passed(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
import datetime
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
start = datetime.datetime.now()
end = datetime.datetime.now()
with patch("controllers.console.app.statistic.parse_time_range", return_value=(start, end)) as mock_parse:
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/statistics/daily-messages?start=something&end=something",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
mock_parse.assert_called_once_with("something", "something", "UTC")

View File

@ -0,0 +1,415 @@
"""Authenticated controller integration tests for workflow draft variable APIs."""
import uuid
from flask.testing import FlaskClient
from sqlalchemy import select
from sqlalchemy.orm import Session
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
from dify_graph.variables.segments import StringSegment
from factories.variable_factory import segment_to_variable
from models import Workflow
from models.model import AppMode
from models.workflow import WorkflowDraftVariable
from tests.test_containers_integration_tests.controllers.console.helpers import (
authenticate_console_client,
create_console_account_and_tenant,
create_console_app,
)
def _create_draft_workflow(
db_session: Session,
app_id: str,
tenant_id: str,
account_id: str,
*,
environment_variables: list | None = None,
conversation_variables: list | None = None,
) -> Workflow:
workflow = Workflow.new(
tenant_id=tenant_id,
app_id=app_id,
type="workflow",
version=Workflow.VERSION_DRAFT,
graph='{"nodes": [], "edges": []}',
features="{}",
created_by=account_id,
environment_variables=environment_variables or [],
conversation_variables=conversation_variables or [],
rag_pipeline_variables=[],
)
db_session.add(workflow)
db_session.commit()
return workflow
def _create_node_variable(
db_session: Session,
app_id: str,
user_id: str,
*,
node_id: str = "node_1",
name: str = "test_var",
) -> WorkflowDraftVariable:
variable = WorkflowDraftVariable.new_node_variable(
app_id=app_id,
user_id=user_id,
node_id=node_id,
name=name,
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
visible=True,
editable=True,
)
db_session.add(variable)
db_session.commit()
return variable
def _create_system_variable(
db_session: Session, app_id: str, user_id: str, name: str = "query"
) -> WorkflowDraftVariable:
variable = WorkflowDraftVariable.new_sys_variable(
app_id=app_id,
user_id=user_id,
name=name,
value=StringSegment(value="system-value"),
node_execution_id=str(uuid.uuid4()),
editable=True,
)
db_session.add(variable)
db_session.commit()
return variable
def _build_environment_variable(name: str, value: str):
return segment_to_variable(
segment=StringSegment(value=value),
selector=[ENVIRONMENT_VARIABLE_NODE_ID, name],
name=name,
description=f"Environment variable {name}",
)
def _build_conversation_variable(name: str, value: str):
return segment_to_variable(
segment=StringSegment(value=value),
selector=[CONVERSATION_VARIABLE_NODE_ID, name],
name=name,
description=f"Conversation variable {name}",
)
def test_workflow_variable_collection_get_success(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
_create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/workflows/draft/variables?page=1&limit=20",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json() == {"items": [], "total": 0}
def test_workflow_variable_collection_get_not_exist(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/workflows/draft/variables",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 404
payload = response.get_json()
assert payload is not None
assert payload["code"] == "draft_workflow_not_exist"
def test_workflow_variable_collection_delete(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
_create_node_variable(db_session_with_containers, app.id, account.id)
_create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_2", name="other_var")
response = test_client_with_containers.delete(
f"/console/api/apps/{app.id}/workflows/draft/variables",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 204
remaining = db_session_with_containers.scalars(
select(WorkflowDraftVariable).where(
WorkflowDraftVariable.app_id == app.id,
WorkflowDraftVariable.user_id == account.id,
)
).all()
assert remaining == []
def test_node_variable_collection_get_success(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
node_variable = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_123")
_create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_456", name="other")
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/workflows/draft/nodes/node_123/variables",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert [item["id"] for item in payload["items"]] == [node_variable.id]
def test_node_variable_collection_get_invalid_node_id(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/workflows/draft/nodes/sys/variables",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 400
payload = response.get_json()
assert payload is not None
assert payload["code"] == "invalid_param"
def test_node_variable_collection_delete(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
target = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_123")
untouched = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_456")
target_id = target.id
untouched_id = untouched.id
response = test_client_with_containers.delete(
f"/console/api/apps/{app.id}/workflows/draft/nodes/node_123/variables",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 204
assert (
db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == target_id))
is None
)
assert (
db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == untouched_id))
is not None
)
def test_variable_api_get_success(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
_create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id)
variable = _create_node_variable(db_session_with_containers, app.id, account.id)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert payload["id"] == variable.id
assert payload["name"] == "test_var"
def test_variable_api_get_not_found(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
_create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/workflows/draft/variables/{uuid.uuid4()}",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 404
payload = response.get_json()
assert payload is not None
assert payload["code"] == "not_found"
def test_variable_api_patch_success(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
_create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id)
variable = _create_node_variable(db_session_with_containers, app.id, account.id)
response = test_client_with_containers.patch(
f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}",
headers=authenticate_console_client(test_client_with_containers, account),
json={"name": "renamed_var"},
)
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert payload["id"] == variable.id
assert payload["name"] == "renamed_var"
refreshed = db_session_with_containers.scalar(
select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id)
)
assert refreshed is not None
assert refreshed.name == "renamed_var"
def test_variable_api_delete_success(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
_create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id)
variable = _create_node_variable(db_session_with_containers, app.id, account.id)
response = test_client_with_containers.delete(
f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 204
assert (
db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id))
is None
)
def test_variable_reset_api_put_success_returns_no_content_without_execution(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
_create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id)
variable = _create_node_variable(db_session_with_containers, app.id, account.id)
response = test_client_with_containers.put(
f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}/reset",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 204
assert (
db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id))
is None
)
def test_conversation_variable_collection_get(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
_create_draft_workflow(
db_session_with_containers,
app.id,
tenant.id,
account.id,
conversation_variables=[_build_conversation_variable("session_name", "Alice")],
)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/workflows/draft/conversation-variables",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert [item["name"] for item in payload["items"]] == ["session_name"]
created = db_session_with_containers.scalars(
select(WorkflowDraftVariable).where(
WorkflowDraftVariable.app_id == app.id,
WorkflowDraftVariable.user_id == account.id,
WorkflowDraftVariable.node_id == CONVERSATION_VARIABLE_NODE_ID,
)
).all()
assert len(created) == 1
def test_system_variable_collection_get(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
variable = _create_system_variable(db_session_with_containers, app.id, account.id)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/workflows/draft/system-variables",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert [item["id"] for item in payload["items"]] == [variable.id]
def test_environment_variable_collection_get(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
_create_draft_workflow(
db_session_with_containers,
app.id,
tenant.id,
account.id,
environment_variables=[_build_environment_variable("api_key", "secret-value")],
)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/workflows/draft/environment-variables",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert payload["items"][0]["name"] == "api_key"
assert payload["items"][0]["value"] == "secret-value"

View File

@ -0,0 +1,131 @@
"""Controller integration tests for API key data source auth routes."""
import json
from unittest.mock import patch
from flask.testing import FlaskClient
from sqlalchemy import select
from sqlalchemy.orm import Session
from models.source import DataSourceApiKeyAuthBinding
from tests.test_containers_integration_tests.controllers.console.helpers import (
authenticate_console_client,
create_console_account_and_tenant,
)
def test_get_api_key_auth_data_source(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
binding = DataSourceApiKeyAuthBinding(
tenant_id=tenant.id,
category="api_key",
provider="custom_provider",
credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}),
disabled=False,
)
db_session_with_containers.add(binding)
db_session_with_containers.commit()
response = test_client_with_containers.get(
"/console/api/api-key-auth/data-source",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert len(payload["sources"]) == 1
assert payload["sources"][0]["provider"] == "custom_provider"
def test_get_api_key_auth_data_source_empty(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
response = test_client_with_containers.get(
"/console/api/api-key-auth/data-source",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json() == {"sources": []}
def test_create_binding_successful(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
with (
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"),
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth"),
):
response = test_client_with_containers.post(
"/console/api/api-key-auth/data-source/binding",
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json() == {"result": "success"}
def test_create_binding_failure(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
with (
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"),
patch(
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth",
side_effect=ValueError("Invalid structure"),
),
):
response = test_client_with_containers.post(
"/console/api/api-key-auth/data-source/binding",
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 500
payload = response.get_json()
assert payload is not None
assert payload["code"] == "auth_failed"
assert payload["message"] == "Invalid structure"
def test_delete_binding_successful(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
binding = DataSourceApiKeyAuthBinding(
tenant_id=tenant.id,
category="api_key",
provider="custom_provider",
credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}),
disabled=False,
)
db_session_with_containers.add(binding)
db_session_with_containers.commit()
response = test_client_with_containers.delete(
f"/console/api/api-key-auth/data-source/{binding.id}",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 204
assert (
db_session_with_containers.scalar(
select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.id == binding.id)
)
is None
)

View File

@ -0,0 +1,120 @@
"""Controller integration tests for console OAuth data source routes."""
from unittest.mock import MagicMock, patch
from flask.testing import FlaskClient
from sqlalchemy.orm import Session
from models.source import DataSourceOauthBinding
from tests.test_containers_integration_tests.controllers.console.helpers import (
authenticate_console_client,
create_console_account_and_tenant,
)
def test_get_oauth_url_successful(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
provider = MagicMock()
provider.get_authorization_url.return_value = "http://oauth.provider/auth"
with (
patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}),
patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None),
):
response = test_client_with_containers.get(
"/console/api/oauth/data-source/notion",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert tenant.id == account.current_tenant_id
assert response.status_code == 200
assert response.get_json() == {"data": "http://oauth.provider/auth"}
provider.get_authorization_url.assert_called_once()
def test_get_oauth_url_invalid_provider(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}):
response = test_client_with_containers.get(
"/console/api/oauth/data-source/unknown_provider",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 400
assert response.get_json() == {"error": "Invalid provider"}
def test_oauth_callback_successful(test_client_with_containers: FlaskClient) -> None:
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}):
response = test_client_with_containers.get("/console/api/oauth/data-source/callback/notion?code=mock_code")
assert response.status_code == 302
assert "code=mock_code" in response.location
def test_oauth_callback_missing_code(test_client_with_containers: FlaskClient) -> None:
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}):
response = test_client_with_containers.get("/console/api/oauth/data-source/callback/notion")
assert response.status_code == 302
assert "error=Access%20denied" in response.location
def test_oauth_callback_invalid_provider(test_client_with_containers: FlaskClient) -> None:
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}):
response = test_client_with_containers.get("/console/api/oauth/data-source/callback/invalid?code=mock_code")
assert response.status_code == 400
assert response.get_json() == {"error": "Invalid provider"}
def test_get_binding_successful(test_client_with_containers: FlaskClient) -> None:
provider = MagicMock()
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}):
response = test_client_with_containers.get("/console/api/oauth/data-source/binding/notion?code=auth_code_123")
assert response.status_code == 200
assert response.get_json() == {"result": "success"}
provider.get_access_token.assert_called_once_with("auth_code_123")
def test_get_binding_missing_code(test_client_with_containers: FlaskClient) -> None:
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}):
response = test_client_with_containers.get("/console/api/oauth/data-source/binding/notion?code=")
assert response.status_code == 400
assert response.get_json() == {"error": "Invalid code"}
def test_sync_successful(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
binding = DataSourceOauthBinding(
tenant_id=tenant.id,
access_token="test-access-token",
provider="notion",
source_info={"workspace_name": "Workspace", "workspace_icon": None, "workspace_id": tenant.id, "pages": []},
disabled=False,
)
db_session_with_containers.add(binding)
db_session_with_containers.commit()
provider = MagicMock()
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}):
response = test_client_with_containers.get(
f"/console/api/oauth/data-source/notion/{binding.id}/sync",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json() == {"result": "success"}
provider.sync_data_source.assert_called_once_with(binding.id)

View File

@ -0,0 +1,365 @@
"""Controller integration tests for console OAuth server routes."""
from unittest.mock import patch
from flask.testing import FlaskClient
from sqlalchemy.orm import Session
from models.model import OAuthProviderApp
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN
from tests.test_containers_integration_tests.controllers.console.helpers import (
authenticate_console_client,
create_console_account_and_tenant,
ensure_dify_setup,
)
def _build_oauth_provider_app() -> OAuthProviderApp:
return OAuthProviderApp(
app_icon="icon_url",
client_id="test_client_id",
client_secret="test_secret",
app_label={"en-US": "Test App"},
redirect_uris=["http://localhost/callback"],
scope="read,write",
)
def test_oauth_provider_successful_post(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider",
json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"},
)
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert payload["app_icon"] == "icon_url"
assert payload["app_label"] == {"en-US": "Test App"}
assert payload["scope"] == "read,write"
def test_oauth_provider_invalid_redirect_uri(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider",
json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"},
)
assert response.status_code == 400
payload = response.get_json()
assert payload is not None
assert "redirect_uri is invalid" in payload["message"]
def test_oauth_provider_invalid_client_id(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
response = test_client_with_containers.post(
"/console/api/oauth/provider",
json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"},
)
assert response.status_code == 404
payload = response.get_json()
assert payload is not None
assert "client_id is invalid" in payload["message"]
def test_oauth_authorize_successful(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
with (
patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
),
patch(
"controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code",
return_value="auth_code_123",
) as mock_sign,
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/authorize",
json={"client_id": "test_client_id"},
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json() == {"code": "auth_code_123"}
mock_sign.assert_called_once_with("test_client_id", account.id)
def test_oauth_token_authorization_code_grant(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with (
patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
),
patch(
"controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token",
return_value=("access_123", "refresh_123"),
),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/token",
json={
"client_id": "test_client_id",
"grant_type": "authorization_code",
"code": "auth_code",
"client_secret": "test_secret",
"redirect_uri": "http://localhost/callback",
},
)
assert response.status_code == 200
assert response.get_json() == {
"access_token": "access_123",
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": "refresh_123",
}
def test_oauth_token_authorization_code_grant_missing_code(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/token",
json={
"client_id": "test_client_id",
"grant_type": "authorization_code",
"client_secret": "test_secret",
"redirect_uri": "http://localhost/callback",
},
)
assert response.status_code == 400
assert response.get_json()["message"] == "code is required"
def test_oauth_token_authorization_code_grant_invalid_secret(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/token",
json={
"client_id": "test_client_id",
"grant_type": "authorization_code",
"code": "auth_code",
"client_secret": "invalid_secret",
"redirect_uri": "http://localhost/callback",
},
)
assert response.status_code == 400
assert response.get_json()["message"] == "client_secret is invalid"
def test_oauth_token_authorization_code_grant_invalid_redirect_uri(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/token",
json={
"client_id": "test_client_id",
"grant_type": "authorization_code",
"code": "auth_code",
"client_secret": "test_secret",
"redirect_uri": "http://invalid/callback",
},
)
assert response.status_code == 400
assert response.get_json()["message"] == "redirect_uri is invalid"
def test_oauth_token_refresh_token_grant(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with (
patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
),
patch(
"controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token",
return_value=("new_access", "new_refresh"),
),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/token",
json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"},
)
assert response.status_code == 200
assert response.get_json() == {
"access_token": "new_access",
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": "new_refresh",
}
def test_oauth_token_refresh_token_grant_missing_token(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/token",
json={"client_id": "test_client_id", "grant_type": "refresh_token"},
)
assert response.status_code == 400
assert response.get_json()["message"] == "refresh_token is required"
def test_oauth_token_invalid_grant_type(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/token",
json={"client_id": "test_client_id", "grant_type": "invalid_grant"},
)
assert response.status_code == 400
assert response.get_json()["message"] == "invalid grant_type"
def test_oauth_account_successful_retrieval(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
account.avatar = "avatar_url"
db_session_with_containers.commit()
with (
patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
),
patch(
"controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token",
return_value=account,
),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/account",
json={"client_id": "test_client_id"},
headers={"Authorization": "Bearer valid_access_token"},
)
assert response.status_code == 200
assert response.get_json() == {
"name": "Test User",
"email": account.email,
"avatar": "avatar_url",
"interface_language": "en-US",
"timezone": "UTC",
}
def test_oauth_account_missing_authorization_header(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/account",
json={"client_id": "test_client_id"},
)
assert response.status_code == 401
assert response.get_json() == {"error": "Authorization header is required"}
def test_oauth_account_invalid_authorization_header_format(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/account",
json={"client_id": "test_client_id"},
headers={"Authorization": "InvalidFormat"},
)
assert response.status_code == 401
assert response.get_json() == {"error": "Invalid Authorization header format"}

View File

@ -1,17 +1,10 @@
"""
Test suite for password reset authentication flows.
"""Testcontainers integration tests for password reset authentication flows."""
This module tests the password reset mechanism including:
- Password reset email sending
- Verification code validation
- Password reset with token
- Rate limiting and security checks
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.error import (
EmailCodeError,
@ -28,31 +21,12 @@ from controllers.console.auth.forgot_password import (
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
@pytest.fixture(autouse=True)
def _mock_forgot_password_session():
with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls:
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_session_cls.return_value.__exit__.return_value = None
yield mock_session
@pytest.fixture(autouse=True)
def _mock_forgot_password_db():
with patch("controllers.console.auth.forgot_password.db") as mock_db:
mock_db.engine = MagicMock()
yield mock_db
class TestForgotPasswordSendEmailApi:
"""Test cases for sending password reset emails."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
def app(self, flask_app_with_containers):
return flask_app_with_containers
@pytest.fixture
def mock_account(self):
@ -62,7 +36,6 @@ class TestForgotPasswordSendEmailApi:
account.name = "Test User"
return account
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
@ -73,20 +46,10 @@ class TestForgotPasswordSendEmailApi:
mock_send_email,
mock_get_account,
mock_is_ip_limit,
mock_wraps_db,
app,
mock_account,
):
"""
Test successful password reset email sending.
Verifies that:
- Email is sent to valid account
- Reset token is generated and returned
- IP rate limiting is checked
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_account.return_value = mock_account
mock_send_email.return_value = "reset_token_123"
@ -104,9 +67,8 @@ class TestForgotPasswordSendEmailApi:
assert response["data"] == "reset_token_123"
mock_send_email.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, mock_db, app):
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app):
"""
Test password reset email blocked by IP rate limit.
@ -115,7 +77,6 @@ class TestForgotPasswordSendEmailApi:
- No email is sent when rate limited
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = True
# Act & Assert
@ -133,7 +94,6 @@ class TestForgotPasswordSendEmailApi:
(None, "en-US"), # Defaults to en-US when not provided
],
)
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
@ -144,7 +104,6 @@ class TestForgotPasswordSendEmailApi:
mock_send_email,
mock_get_account,
mock_is_ip_limit,
mock_wraps_db,
app,
mock_account,
language_input,
@ -158,7 +117,6 @@ class TestForgotPasswordSendEmailApi:
- Unsupported languages default to en-US
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_account.return_value = mock_account
mock_send_email.return_value = "token"
@ -180,13 +138,9 @@ class TestForgotPasswordCheckApi:
"""Test cases for verifying password reset codes."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
def app(self, flask_app_with_containers):
return flask_app_with_containers
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@ -199,7 +153,6 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_get_data,
mock_is_rate_limit,
mock_db,
app,
):
"""
@ -212,7 +165,6 @@ class TestForgotPasswordCheckApi:
- Rate limit is reset on success
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_generate_token.return_value = (None, "new_token")
@ -236,7 +188,6 @@ class TestForgotPasswordCheckApi:
)
mock_reset_rate_limit.assert_called_once_with("test@example.com")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@ -249,10 +200,8 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_get_data,
mock_is_rate_limit,
mock_db,
app,
):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"}
mock_generate_token.return_value = (None, "fresh-token")
@ -271,9 +220,8 @@ class TestForgotPasswordCheckApi:
mock_revoke_token.assert_called_once_with("upper_token")
mock_reset_rate_limit.assert_called_once_with("user@example.com")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app):
def test_verify_code_rate_limited(self, mock_is_rate_limit, app):
"""
Test code verification blocked by rate limit.
@ -282,7 +230,6 @@ class TestForgotPasswordCheckApi:
- Prevents brute force attacks on verification codes
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = True
# Act & Assert
@ -295,10 +242,9 @@ class TestForgotPasswordCheckApi:
with pytest.raises(EmailPasswordResetLimitError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, mock_db, app):
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app):
"""
Test code verification with invalid token.
@ -306,7 +252,6 @@ class TestForgotPasswordCheckApi:
- InvalidTokenError is raised for invalid/expired tokens
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = None
@ -320,10 +265,9 @@ class TestForgotPasswordCheckApi:
with pytest.raises(InvalidTokenError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, mock_db, app):
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app):
"""
Test code verification with mismatched email.
@ -332,7 +276,6 @@ class TestForgotPasswordCheckApi:
- Prevents token abuse
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}
@ -346,11 +289,10 @@ class TestForgotPasswordCheckApi:
with pytest.raises(InvalidEmailError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit")
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, mock_db, app):
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app):
"""
Test code verification with incorrect code.
@ -359,7 +301,6 @@ class TestForgotPasswordCheckApi:
- Rate limit counter is incremented
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
@ -380,11 +321,8 @@ class TestForgotPasswordResetApi:
"""Test cases for resetting password with verified token."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
def app(self, flask_app_with_containers):
return flask_app_with_containers
@pytest.fixture
def mock_account(self):
@ -394,7 +332,6 @@ class TestForgotPasswordResetApi:
account.name = "Test User"
return account
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@ -405,7 +342,6 @@ class TestForgotPasswordResetApi:
mock_get_account,
mock_revoke_token,
mock_get_data,
mock_wraps_db,
app,
mock_account,
):
@ -418,7 +354,6 @@ class TestForgotPasswordResetApi:
- Success response is returned
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
mock_get_account.return_value = mock_account
mock_get_tenants.return_value = [MagicMock()]
@ -436,9 +371,8 @@ class TestForgotPasswordResetApi:
assert response["result"] == "success"
mock_revoke_token.assert_called_once_with("valid_token")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_mismatch(self, mock_get_data, mock_db, app):
def test_reset_password_mismatch(self, mock_get_data, app):
"""
Test password reset with mismatched passwords.
@ -447,7 +381,6 @@ class TestForgotPasswordResetApi:
- No password update occurs
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
# Act & Assert
@ -460,9 +393,8 @@ class TestForgotPasswordResetApi:
with pytest.raises(PasswordMismatchError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_invalid_token(self, mock_get_data, mock_db, app):
def test_reset_password_invalid_token(self, mock_get_data, app):
"""
Test password reset with invalid token.
@ -470,7 +402,6 @@ class TestForgotPasswordResetApi:
- InvalidTokenError is raised for invalid/expired tokens
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = None
# Act & Assert
@ -483,9 +414,8 @@ class TestForgotPasswordResetApi:
with pytest.raises(InvalidTokenError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_wrong_phase(self, mock_get_data, mock_db, app):
def test_reset_password_wrong_phase(self, mock_get_data, app):
"""
Test password reset with token not in reset phase.
@ -494,7 +424,6 @@ class TestForgotPasswordResetApi:
- Prevents use of verification-phase tokens for reset
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "phase": "verify"}
# Act & Assert
@ -507,13 +436,10 @@ class TestForgotPasswordResetApi:
with pytest.raises(InvalidTokenError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
def test_reset_password_account_not_found(
self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app
):
def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app):
"""
Test password reset for non-existent account.
@ -521,7 +447,6 @@ class TestForgotPasswordResetApi:
- AccountNotFound is raised when account doesn't exist
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"}
mock_get_account.return_value = None

View File

@ -0,0 +1,85 @@
"""Shared helpers for authenticated console controller integration tests."""
import uuid
from flask.testing import FlaskClient
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from constants import HEADER_NAME_CSRF_TOKEN
from libs.datetime_utils import naive_utc_now
from libs.token import _real_cookie_name, generate_csrf_token
from models import Account, DifySetup, Tenant, TenantAccountJoin
from models.account import AccountStatus, TenantAccountRole
from models.model import App, AppMode
from services.account_service import AccountService
def ensure_dify_setup(db_session: Session) -> None:
"""Create a setup marker once so setup-protected console routes can be exercised."""
if db_session.scalar(select(DifySetup).limit(1)) is not None:
return
db_session.add(DifySetup(version=dify_config.project.version))
db_session.commit()
def create_console_account_and_tenant(db_session: Session) -> tuple[Account, Tenant]:
"""Create an initialized owner account with a current tenant."""
account = Account(
email=f"test-{uuid.uuid4()}@example.com",
name="Test User",
interface_language="en-US",
status=AccountStatus.ACTIVE,
)
account.initialized_at = naive_utc_now()
db_session.add(account)
db_session.commit()
tenant = Tenant(name="Test Tenant", status="normal")
db_session.add(tenant)
db_session.commit()
db_session.add(
TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
)
db_session.commit()
account.set_tenant_id(tenant.id)
account.timezone = "UTC"
db_session.commit()
ensure_dify_setup(db_session)
return account, tenant
def create_console_app(db_session: Session, tenant_id: str, account_id: str, mode: AppMode) -> App:
"""Create a minimal app row that can be loaded by get_app_model."""
app = App(
tenant_id=tenant_id,
name="Test App",
mode=mode,
enable_site=True,
enable_api=True,
created_by=account_id,
)
db_session.add(app)
db_session.commit()
return app
def authenticate_console_client(test_client: FlaskClient, account: Account) -> dict[str, str]:
"""Attach console auth cookies/headers for endpoints guarded by login_required."""
access_token = AccountService.get_account_jwt_token(account)
csrf_token = generate_csrf_token(account.id)
test_client.set_cookie(_real_cookie_name("csrf_token"), csrf_token, domain="localhost")
return {
"Authorization": f"Bearer {access_token}",
HEADER_NAME_CSRF_TOKEN: csrf_token,
}

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
from models.dataset import Dataset, Document
@ -55,7 +56,7 @@ class TestGetAvailableDatasetsIntegration:
name=f"Document {i}",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -112,7 +113,7 @@ class TestGetAvailableDatasetsIntegration:
created_from=DocumentCreatedFrom.WEB,
name=f"Archived Document {i}",
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=True, # Archived
@ -165,7 +166,7 @@ class TestGetAvailableDatasetsIntegration:
created_from=DocumentCreatedFrom.WEB,
name=f"Disabled Document {i}",
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
indexing_status=IndexingStatus.COMPLETED,
enabled=False, # Disabled
archived=False,
@ -218,7 +219,7 @@ class TestGetAvailableDatasetsIntegration:
created_from=DocumentCreatedFrom.WEB,
name=f"Document {status}",
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
indexing_status=status, # Not completed
enabled=True,
archived=False,
@ -336,7 +337,7 @@ class TestGetAvailableDatasetsIntegration:
created_from=DocumentCreatedFrom.WEB,
name=f"Document for {dataset.name}",
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
@ -416,7 +417,7 @@ class TestGetAvailableDatasetsIntegration:
created_from=DocumentCreatedFrom.WEB,
name=f"Document {i}",
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
@ -476,7 +477,7 @@ class TestKnowledgeRetrievalIntegration:
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db_session_with_containers.add(document)
db_session_with_containers.commit()

View File

@ -1,27 +0,0 @@
from __future__ import annotations
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository
from tests.test_containers_integration_tests.helpers.execution_extra_content import (
create_human_input_message_fixture,
)
def test_get_by_message_ids_returns_human_input_content(db_session_with_containers):
fixture = create_human_input_message_fixture(db_session_with_containers)
repository = SQLAlchemyExecutionExtraContentRepository(
session_maker=sessionmaker(bind=db.engine, expire_on_commit=False)
)
results = repository.get_by_message_ids([fixture.message.id])
assert len(results) == 1
assert len(results[0]) == 1
content = results[0][0]
assert content.submitted is True
assert content.form_submission_data is not None
assert content.form_submission_data.action_id == fixture.action_id
assert content.form_submission_data.action_text == fixture.action_text
assert content.form_submission_data.rendered_content == fixture.form.rendered_content

View File

@ -27,7 +27,7 @@ from models.human_input import (
HumanInputFormRecipient,
RecipientType,
)
from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.sqlalchemy_api_workflow_run_repository import (
DifyAPISQLAlchemyWorkflowRunRepository,
@ -218,7 +218,7 @@ class TestDeleteRunsWithRelated:
app_id=test_scope.app_id,
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=test_scope.user_id,
)
@ -278,7 +278,7 @@ class TestCountRunsWithRelated:
app_id=test_scope.app_id,
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=test_scope.user_id,
)

View File

@ -0,0 +1,407 @@
"""Integration tests for SQLAlchemyExecutionExtraContentRepository using Testcontainers.
Part of #32454 — replaces the mock-based unit tests with real database interactions.
"""
from __future__ import annotations
from collections.abc import Generator
from dataclasses import dataclass
from datetime import datetime, timedelta
from decimal import Decimal
from uuid import uuid4
import pytest
from sqlalchemy import Engine, delete, select
from sqlalchemy.orm import Session, sessionmaker
from dify_graph.nodes.human_input.entities import FormDefinition, UserAction
from dify_graph.nodes.human_input.enums import HumanInputFormStatus
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import ConversationFromSource, InvokeFrom
from models.execution_extra_content import ExecutionExtraContent, HumanInputContent
from models.human_input import (
ConsoleRecipientPayload,
HumanInputDelivery,
HumanInputForm,
HumanInputFormRecipient,
RecipientType,
)
from models.model import App, Conversation, Message
from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository
@dataclass
class _TestScope:
"""Per-test data scope used to isolate DB rows.
IDs are populated after flushing the base entities to the database.
"""
tenant_id: str = ""
app_id: str = ""
user_id: str = ""
def _cleanup_scope_data(session: Session, scope: _TestScope) -> None:
"""Remove test-created DB rows for a test scope."""
form_ids_subquery = select(HumanInputForm.id).where(
HumanInputForm.tenant_id == scope.tenant_id,
)
session.execute(delete(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids_subquery)))
session.execute(delete(HumanInputDelivery).where(HumanInputDelivery.form_id.in_(form_ids_subquery)))
session.execute(
delete(ExecutionExtraContent).where(
ExecutionExtraContent.workflow_run_id.in_(
select(HumanInputForm.workflow_run_id).where(HumanInputForm.tenant_id == scope.tenant_id)
)
)
)
session.execute(delete(HumanInputForm).where(HumanInputForm.tenant_id == scope.tenant_id))
session.execute(delete(Message).where(Message.app_id == scope.app_id))
session.execute(delete(Conversation).where(Conversation.app_id == scope.app_id))
session.execute(delete(App).where(App.id == scope.app_id))
session.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == scope.tenant_id))
session.execute(delete(Account).where(Account.id == scope.user_id))
session.execute(delete(Tenant).where(Tenant.id == scope.tenant_id))
session.commit()
def _seed_base_entities(session: Session, scope: _TestScope) -> None:
"""Create the base tenant, account, and app needed by tests."""
tenant = Tenant(name="Test Tenant")
session.add(tenant)
session.flush()
scope.tenant_id = tenant.id
account = Account(
name="Test Account",
email=f"test_{uuid4()}@example.com",
password="hashed-password",
password_salt="salt",
interface_language="en-US",
timezone="UTC",
)
session.add(account)
session.flush()
scope.user_id = account.id
tenant_join = TenantAccountJoin(
tenant_id=scope.tenant_id,
account_id=scope.user_id,
role=TenantAccountRole.OWNER,
current=True,
)
session.add(tenant_join)
app = App(
tenant_id=scope.tenant_id,
name="Test App",
description="",
mode="chat",
icon_type="emoji",
icon="bot",
icon_background="#FFFFFF",
enable_site=False,
enable_api=True,
api_rpm=100,
api_rph=100,
is_demo=False,
is_public=False,
is_universal=False,
created_by=scope.user_id,
updated_by=scope.user_id,
)
session.add(app)
session.flush()
scope.app_id = app.id
def _create_conversation(session: Session, scope: _TestScope) -> Conversation:
conversation = Conversation(
app_id=scope.app_id,
mode="chat",
name="Test Conversation",
summary="",
introduction="",
system_instruction="",
status="normal",
invoke_from=InvokeFrom.EXPLORE,
from_source=ConversationFromSource.CONSOLE,
from_account_id=scope.user_id,
from_end_user_id=None,
)
conversation.inputs = {}
session.add(conversation)
session.flush()
return conversation
def _create_message(
session: Session,
scope: _TestScope,
conversation_id: str,
workflow_run_id: str,
) -> Message:
message = Message(
app_id=scope.app_id,
conversation_id=conversation_id,
inputs={},
query="test query",
message={"messages": []},
answer="test answer",
message_tokens=50,
message_unit_price=Decimal("0.001"),
answer_tokens=80,
answer_unit_price=Decimal("0.001"),
provider_response_latency=0.5,
currency="USD",
from_source=ConversationFromSource.CONSOLE,
from_account_id=scope.user_id,
workflow_run_id=workflow_run_id,
)
session.add(message)
session.flush()
return message
def _create_submitted_form(
session: Session,
scope: _TestScope,
*,
workflow_run_id: str,
action_id: str = "approve",
action_title: str = "Approve",
node_title: str = "Approval",
) -> HumanInputForm:
expiration_time = datetime.utcnow() + timedelta(days=1)
form_definition = FormDefinition(
form_content="content",
inputs=[],
user_actions=[UserAction(id=action_id, title=action_title)],
rendered_content="rendered",
expiration_time=expiration_time,
node_title=node_title,
display_in_ui=True,
)
form = HumanInputForm(
tenant_id=scope.tenant_id,
app_id=scope.app_id,
workflow_run_id=workflow_run_id,
node_id="node-id",
form_definition=form_definition.model_dump_json(),
rendered_content=f"Rendered {action_title}",
status=HumanInputFormStatus.SUBMITTED,
expiration_time=expiration_time,
selected_action_id=action_id,
)
session.add(form)
session.flush()
return form
def _create_waiting_form(
session: Session,
scope: _TestScope,
*,
workflow_run_id: str,
default_values: dict | None = None,
) -> HumanInputForm:
expiration_time = datetime.utcnow() + timedelta(days=1)
form_definition = FormDefinition(
form_content="content",
inputs=[],
user_actions=[UserAction(id="approve", title="Approve")],
rendered_content="rendered",
expiration_time=expiration_time,
default_values=default_values or {"name": "John"},
node_title="Approval",
display_in_ui=True,
)
form = HumanInputForm(
tenant_id=scope.tenant_id,
app_id=scope.app_id,
workflow_run_id=workflow_run_id,
node_id="node-id",
form_definition=form_definition.model_dump_json(),
rendered_content="Rendered block",
status=HumanInputFormStatus.WAITING,
expiration_time=expiration_time,
)
session.add(form)
session.flush()
return form
def _create_human_input_content(
session: Session,
*,
workflow_run_id: str,
message_id: str,
form_id: str,
) -> HumanInputContent:
content = HumanInputContent.new(
workflow_run_id=workflow_run_id,
message_id=message_id,
form_id=form_id,
)
session.add(content)
return content
def _create_recipient(
session: Session,
*,
form_id: str,
delivery_id: str,
recipient_type: RecipientType = RecipientType.CONSOLE,
access_token: str = "token-1",
) -> HumanInputFormRecipient:
payload = ConsoleRecipientPayload(account_id=None)
recipient = HumanInputFormRecipient(
form_id=form_id,
delivery_id=delivery_id,
recipient_type=recipient_type,
recipient_payload=payload.model_dump_json(),
access_token=access_token,
)
session.add(recipient)
return recipient
def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery:
from dify_graph.nodes.human_input.enums import DeliveryMethodType
from models.human_input import ConsoleDeliveryPayload
delivery = HumanInputDelivery(
form_id=form_id,
delivery_method_type=DeliveryMethodType.WEBAPP,
channel_payload=ConsoleDeliveryPayload().model_dump_json(),
)
session.add(delivery)
session.flush()
return delivery
@pytest.fixture
def repository(db_session_with_containers: Session) -> SQLAlchemyExecutionExtraContentRepository:
"""Build a repository backed by the testcontainers database engine."""
engine = db_session_with_containers.get_bind()
assert isinstance(engine, Engine)
return SQLAlchemyExecutionExtraContentRepository(sessionmaker(bind=engine, expire_on_commit=False))
@pytest.fixture
def test_scope(db_session_with_containers: Session) -> Generator[_TestScope]:
"""Provide an isolated scope and clean related data after each test."""
scope = _TestScope()
_seed_base_entities(db_session_with_containers, scope)
db_session_with_containers.commit()
yield scope
_cleanup_scope_data(db_session_with_containers, scope)
class TestGetByMessageIds:
"""Tests for SQLAlchemyExecutionExtraContentRepository.get_by_message_ids."""
def test_groups_contents_by_message(
self,
db_session_with_containers: Session,
repository: SQLAlchemyExecutionExtraContentRepository,
test_scope: _TestScope,
) -> None:
"""Submitted forms are correctly mapped and grouped by message ID."""
workflow_run_id = str(uuid4())
conversation = _create_conversation(db_session_with_containers, test_scope)
msg1 = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id)
msg2 = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id)
form = _create_submitted_form(
db_session_with_containers,
test_scope,
workflow_run_id=workflow_run_id,
action_id="approve",
action_title="Approve",
)
_create_human_input_content(
db_session_with_containers,
workflow_run_id=workflow_run_id,
message_id=msg1.id,
form_id=form.id,
)
db_session_with_containers.commit()
result = repository.get_by_message_ids([msg1.id, msg2.id])
assert len(result) == 2
# msg1 has one submitted content
assert len(result[0]) == 1
content = result[0][0]
assert content.submitted is True
assert content.workflow_run_id == workflow_run_id
assert content.form_submission_data is not None
assert content.form_submission_data.action_id == "approve"
assert content.form_submission_data.action_text == "Approve"
assert content.form_submission_data.rendered_content == "Rendered Approve"
assert content.form_submission_data.node_id == "node-id"
assert content.form_submission_data.node_title == "Approval"
# msg2 has no content
assert result[1] == []
def test_returns_unsubmitted_form_definition(
self,
db_session_with_containers: Session,
repository: SQLAlchemyExecutionExtraContentRepository,
test_scope: _TestScope,
) -> None:
"""Waiting forms return full form_definition with resolved token and defaults."""
workflow_run_id = str(uuid4())
conversation = _create_conversation(db_session_with_containers, test_scope)
msg = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id)
form = _create_waiting_form(
db_session_with_containers,
test_scope,
workflow_run_id=workflow_run_id,
default_values={"name": "John"},
)
delivery = _create_delivery(db_session_with_containers, form_id=form.id)
_create_recipient(
db_session_with_containers,
form_id=form.id,
delivery_id=delivery.id,
access_token="token-1",
)
_create_human_input_content(
db_session_with_containers,
workflow_run_id=workflow_run_id,
message_id=msg.id,
form_id=form.id,
)
db_session_with_containers.commit()
result = repository.get_by_message_ids([msg.id])
assert len(result) == 1
assert len(result[0]) == 1
domain_content = result[0][0]
assert domain_content.submitted is False
assert domain_content.workflow_run_id == workflow_run_id
assert domain_content.form_definition is not None
form_def = domain_content.form_definition
assert form_def.form_id == form.id
assert form_def.node_id == "node-id"
assert form_def.node_title == "Approval"
assert form_def.form_content == "Rendered block"
assert form_def.display_in_ui is True
assert form_def.form_token == "token-1"
assert form_def.resolved_default_values == {"name": "John"}
assert form_def.expiration_time == int(form.expiration_time.timestamp())
def test_empty_message_ids_returns_empty_list(
self,
repository: SQLAlchemyExecutionExtraContentRepository,
) -> None:
"""Passing no message IDs returns an empty list without hitting the DB."""
result = repository.get_by_message_ids([])
assert result == []

View File

@ -13,6 +13,7 @@ from uuid import uuid4
import pytest
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from models import Account
from models.dataset import Dataset, Document
@ -91,7 +92,7 @@ class DocumentStatusTestDataFactory:
name=name,
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
document.id = document_id
document.indexing_status = indexing_status

View File

@ -525,3 +525,147 @@ class TestAPIBasedExtensionService:
# Try to get extension with wrong tenant ID
with pytest.raises(ValueError, match="API based extension is not found"):
APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id)
def test_save_extension_api_key_exactly_four_chars_rejected(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""API key with exactly 4 characters should be rejected (boundary)."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key="1234",
)
with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
APIBasedExtensionService.save(extension_data)
def test_save_extension_api_key_exactly_five_chars_accepted(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""API key with exactly 5 characters should be accepted (boundary)."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key="12345",
)
saved = APIBasedExtensionService.save(extension_data)
assert saved.id is not None
def test_save_extension_requestor_constructor_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Exception raised by requestor constructor is wrapped in ValueError."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
mock_external_service_dependencies["requestor"].side_effect = RuntimeError("bad config")
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
with pytest.raises(ValueError, match="connection error: bad config"):
APIBasedExtensionService.save(extension_data)
def test_save_extension_network_exception(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Network exceptions during ping are wrapped in ValueError."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
mock_external_service_dependencies["requestor_instance"].request.side_effect = ConnectionError(
"network failure"
)
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
with pytest.raises(ValueError, match="connection error: network failure"):
APIBasedExtensionService.save(extension_data)
def test_save_extension_update_duplicate_name_rejected(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Updating an existing extension to use another extension's name should fail."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
ext1 = APIBasedExtensionService.save(
APIBasedExtension(
tenant_id=tenant.id,
name="Extension Alpha",
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
)
ext2 = APIBasedExtensionService.save(
APIBasedExtension(
tenant_id=tenant.id,
name="Extension Beta",
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
)
# Try to rename ext2 to ext1's name
ext2.name = "Extension Alpha"
with pytest.raises(ValueError, match="name must be unique, it is already existed"):
APIBasedExtensionService.save(ext2)
def test_get_all_returns_empty_for_different_tenant(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Extensions from one tenant should not be visible to another."""
fake = Faker()
_, tenant1 = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
_, tenant2 = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant1 is not None
APIBasedExtensionService.save(
APIBasedExtension(
tenant_id=tenant1.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
)
assert tenant2 is not None
result = APIBasedExtensionService.get_all_by_tenant_id(tenant2.id)
assert result == []

View File

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from constants.model_template import default_app_templates
from models import Account
from models.model import App, Site
from models.model import App, IconType, Site
from services.account_service import AccountService, TenantService
from tests.test_containers_integration_tests.helpers import generate_valid_password
@ -463,6 +463,109 @@ class TestAppService:
assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by
def test_update_app_should_preserve_icon_type_when_omitted(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test update_app keeps the persisted icon_type when the update payload omits it.
"""
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(
tenant.id,
{
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🎯",
"icon_background": "#45B7D1",
},
account,
)
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app(
app,
{
"name": "Updated App Name",
"description": "Updated app description",
"icon_type": None,
"icon": "🔄",
"icon_background": "#FF8C42",
"use_icon_as_answer_icon": True,
},
)
assert updated_app.icon_type == IconType.EMOJI
def test_update_app_should_reject_empty_icon_type(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test update_app rejects an explicit empty icon_type.
"""
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(
tenant.id,
{
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🎯",
"icon_background": "#45B7D1",
},
account,
)
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
with pytest.raises(ValueError):
app_service.update_app(
app,
{
"name": "Updated App Name",
"description": "Updated app description",
"icon_type": "",
"icon": "🔄",
"icon_background": "#FF8C42",
"use_icon_as_answer_icon": True,
},
)
def test_update_app_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful app name update.

View File

@ -0,0 +1,80 @@
"""Testcontainers integration tests for AttachmentService."""
import base64
from datetime import UTC, datetime
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
import services.attachment_service as attachment_service_module
from extensions.ext_database import db
from extensions.storage.storage_type import StorageType
from models.enums import CreatorUserRole
from models.model import UploadFile
from services.attachment_service import AttachmentService
class TestAttachmentService:
def _create_upload_file(self, db_session_with_containers, *, tenant_id: str | None = None) -> UploadFile:
upload_file = UploadFile(
tenant_id=tenant_id or str(uuid4()),
storage_type=StorageType.OPENDAL,
key=f"upload/{uuid4()}.txt",
name="test-file.txt",
size=100,
extension="txt",
mime_type="text/plain",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
created_at=datetime.now(UTC),
used=False,
)
db_session_with_containers.add(upload_file)
db_session_with_containers.commit()
return upload_file
def test_should_initialize_with_sessionmaker(self):
session_factory = sessionmaker()
service = AttachmentService(session_factory=session_factory)
assert service._session_maker is session_factory
def test_should_initialize_with_engine(self):
engine = create_engine("sqlite:///:memory:")
service = AttachmentService(session_factory=engine)
session = service._session_maker()
try:
assert session.bind == engine
finally:
session.close()
engine.dispose()
@pytest.mark.parametrize("invalid_session_factory", [None, "not-a-session-factory", 1])
def test_should_raise_assertion_error_for_invalid_session_factory(self, invalid_session_factory):
with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."):
AttachmentService(session_factory=invalid_session_factory)
def test_should_return_base64_when_file_exists(self, db_session_with_containers):
upload_file = self._create_upload_file(db_session_with_containers)
service = AttachmentService(session_factory=sessionmaker(bind=db.engine))
with patch.object(attachment_service_module.storage, "load_once", return_value=b"binary-content") as mock_load:
result = service.get_file_base64(upload_file.id)
assert result == base64.b64encode(b"binary-content").decode()
mock_load.assert_called_once_with(upload_file.key)
def test_should_raise_not_found_when_file_missing(self, db_session_with_containers):
service = AttachmentService(session_factory=sessionmaker(bind=db.engine))
with patch.object(attachment_service_module.storage, "load_once") as mock_load:
with pytest.raises(NotFound, match="File not found"):
service.get_file_base64(str(uuid4()))
mock_load.assert_not_called()

View File

@ -0,0 +1,58 @@
"""Testcontainers integration tests for ConversationVariableUpdater."""
from uuid import uuid4
import pytest
from sqlalchemy.orm import sessionmaker
from dify_graph.variables import StringVariable
from extensions.ext_database import db
from models.workflow import ConversationVariable
from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater
class TestConversationVariableUpdater:
def _create_conversation_variable(
self, db_session_with_containers, *, conversation_id: str, variable: StringVariable, app_id: str | None = None
) -> ConversationVariable:
row = ConversationVariable(
id=variable.id,
conversation_id=conversation_id,
app_id=app_id or str(uuid4()),
data=variable.model_dump_json(),
)
db_session_with_containers.add(row)
db_session_with_containers.commit()
return row
def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers):
conversation_id = str(uuid4())
variable = StringVariable(id=str(uuid4()), name="topic", value="old value")
self._create_conversation_variable(
db_session_with_containers, conversation_id=conversation_id, variable=variable
)
updated_variable = StringVariable(id=variable.id, name="topic", value="new value")
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
updater.update(conversation_id=conversation_id, variable=updated_variable)
db_session_with_containers.expire_all()
row = db_session_with_containers.get(ConversationVariable, (variable.id, conversation_id))
assert row is not None
assert row.data == updated_variable.model_dump_json()
def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers):
conversation_id = str(uuid4())
variable = StringVariable(id=str(uuid4()), name="topic", value="value")
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"):
updater.update(conversation_id=conversation_id, variable=variable)
def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers):
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
result = updater.flush()
assert result is None

View File

@ -0,0 +1,104 @@
"""Testcontainers integration tests for CreditPoolService."""
from uuid import uuid4
import pytest
from core.errors.error import QuotaExceededError
from models import TenantCreditPool
from models.enums import ProviderQuotaType
from services.credit_pool_service import CreditPoolService
class TestCreditPoolService:
def _create_tenant_id(self) -> str:
return str(uuid4())
def test_create_default_pool(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
assert isinstance(pool, TenantCreditPool)
assert pool.tenant_id == tenant_id
assert pool.pool_type == ProviderQuotaType.TRIAL
assert pool.quota_used == 0
assert pool.quota_limit > 0
def test_get_pool_returns_pool_when_exists(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
CreditPoolService.create_default_pool(tenant_id)
result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL)
assert result is not None
assert result.tenant_id == tenant_id
assert result.pool_type == ProviderQuotaType.TRIAL
def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers):
result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL)
assert result is None
def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers):
result = CreditPoolService.check_credits_available(tenant_id=self._create_tenant_id(), credits_required=10)
assert result is False
def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
CreditPoolService.create_default_pool(tenant_id)
result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=10)
assert result is True
def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
# Exhaust credits
pool.quota_used = pool.quota_limit
db_session_with_containers.commit()
result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=1)
assert result is False
def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers):
with pytest.raises(QuotaExceededError, match="Credit pool not found"):
CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=10)
def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
pool.quota_used = pool.quota_limit
db_session_with_containers.commit()
with pytest.raises(QuotaExceededError, match="No credits remaining"):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=10)
def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
CreditPoolService.create_default_pool(tenant_id)
credits_required = 10
result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=credits_required)
assert result == credits_required
db_session_with_containers.expire_all()
pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert pool.quota_used == credits_required
def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
remaining = 5
pool.quota_used = pool.quota_limit - remaining
db_session_with_containers.commit()
result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200)
assert result == remaining
db_session_with_containers.expire_all()
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert updated_pool.quota_used == pool.quota_limit

View File

@ -397,6 +397,68 @@ class TestDatasetPermissionServiceClearPartialMemberList:
class TestDatasetServiceCheckDatasetPermission:
"""Verify dataset access checks against persisted partial-member permissions."""
def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers):
"""Test that users from different tenants cannot access dataset."""
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
other_user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)
dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id, owner.id, permission=DatasetPermissionEnum.ALL_TEAM
)
with pytest.raises(NoPermissionError):
DatasetService.check_dataset_permission(dataset, other_user)
def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers):
"""Test that tenant owners can access any dataset regardless of permission level."""
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
creator, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL, tenant=tenant
)
dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME
)
DatasetService.check_dataset_permission(dataset, owner)
def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers):
"""Test ONLY_ME permission allows only the dataset creator to access."""
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)
dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME
)
DatasetService.check_dataset_permission(dataset, creator)
def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers):
"""Test ONLY_ME permission denies access to non-creators."""
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
other, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL, tenant=tenant
)
dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME
)
with pytest.raises(NoPermissionError):
DatasetService.check_dataset_permission(dataset, other)
def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers):
"""Test ALL_TEAM permission allows any team member to access the dataset."""
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL, tenant=tenant
)
dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id, creator.id, permission=DatasetPermissionEnum.ALL_TEAM
)
DatasetService.check_dataset_permission(dataset, member)
def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers):
"""
Test that user with explicit permission can access partial_members dataset.
@ -443,6 +505,16 @@ class TestDatasetServiceCheckDatasetPermission:
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"):
DatasetService.check_dataset_permission(dataset, user)
def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers):
"""Test PARTIAL_TEAM permission allows creator to access without explicit permission."""
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)
dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id, creator.id, permission=DatasetPermissionEnum.PARTIAL_TEAM
)
DatasetService.check_dataset_permission(dataset, creator)
class TestDatasetServiceCheckDatasetOperatorPermission:
"""Verify operator permission checks against persisted partial-member permissions."""

View File

@ -11,6 +11,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -106,7 +107,7 @@ class DatasetServiceIntegrationDataFactory:
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
indexing_status=IndexingStatus.COMPLETED,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()

View File

@ -13,6 +13,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.dataset_service import DocumentService
@ -79,7 +80,7 @@ class DocumentBatchUpdateIntegrationDataFactory:
name=name,
created_from=DocumentCreatedFrom.WEB,
created_by=created_by or str(uuid4()),
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
document.id = document_id or str(uuid4())
document.enabled = enabled
@ -694,3 +695,19 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{doc1.id}_indexing", 600, 1)
patched_dependencies["add_task"].delay.assert_called_once_with(doc1.id)
def test_batch_update_invalid_action_raises_value_error(
self, db_session_with_containers: Session, patched_dependencies
):
"""Test that an invalid action raises ValueError."""
factory = DocumentBatchUpdateIntegrationDataFactory
dataset = factory.create_dataset(db_session_with_containers)
doc = factory.create_document(db_session_with_containers, dataset)
user = UserDouble(id=str(uuid4()))
patched_dependencies["redis_client"].get.return_value = None
with pytest.raises(ValueError, match="Invalid action"):
DocumentService.batch_update_document_status(
dataset=dataset, document_ids=[doc.id], action="invalid_action", user=user
)

View File

@ -0,0 +1,60 @@
"""Testcontainers integration tests for DatasetService.create_empty_rag_pipeline_dataset."""
from __future__ import annotations
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from models.account import Account, Tenant, TenantAccountJoin
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
class TestDatasetServiceCreateRagPipelineDataset:
def _create_tenant_and_account(self, db_session_with_containers) -> tuple[Tenant, Account]:
tenant = Tenant(name=f"Tenant {uuid4()}")
db_session_with_containers.add(tenant)
db_session_with_containers.flush()
account = Account(
name=f"Account {uuid4()}",
email=f"ds_create_{uuid4()}@example.com",
password="hashed",
password_salt="salt",
interface_language="en-US",
timezone="UTC",
)
db_session_with_containers.add(account)
db_session_with_containers.flush()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role="owner",
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
return tenant, account
def _build_entity(self, name: str = "Test Dataset") -> RagPipelineDatasetCreateEntity:
icon_info = IconInfo(icon="\U0001f4d9", icon_background="#FFF4ED", icon_type="emoji")
return RagPipelineDatasetCreateEntity(
name=name,
description="",
icon_info=icon_info,
permission="only_me",
)
def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers):
tenant, _ = self._create_tenant_and_account(db_session_with_containers)
mock_user = Mock(id=None)
with patch("services.dataset_service.current_user", mock_user):
with pytest.raises(ValueError, match="Current user or current user id not found"):
DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=tenant.id,
rag_pipeline_dataset_create_entity=self._build_entity(),
)

View File

@ -3,6 +3,7 @@
from unittest.mock import patch
from uuid import uuid4
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom
@ -78,7 +79,7 @@ class DatasetDeleteIntegrationDataFactory:
tenant_id: str,
dataset_id: str,
created_by: str,
doc_form: str = "text_model",
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
) -> Document:
"""Persist a document so dataset.doc_form resolves through the real document path."""
document = Document(
@ -119,7 +120,7 @@ class TestDatasetServiceDeleteDataset:
tenant_id=tenant.id,
dataset_id=dataset.id,
created_by=owner.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
# Act

View File

@ -3,6 +3,7 @@ from uuid import uuid4
from sqlalchemy import select
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.dataset_service import DocumentService
@ -42,7 +43,7 @@ def _create_document(
name=f"doc-{uuid4()}",
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
document.id = str(uuid4())
document.indexing_status = indexing_status
@ -142,3 +143,11 @@ def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_c
rows = db_session_with_containers.scalars(filtered).all()
assert {row.id for row in rows} == {doc1.id, doc2.id}
def test_normalize_display_status_alias_mapping():
"""Test that normalize_display_status maps aliases correctly."""
assert DocumentService.normalize_display_status("ACTIVE") == "available"
assert DocumentService.normalize_display_status("enabled") == "available"
assert DocumentService.normalize_display_status("archived") == "archived"
assert DocumentService.normalize_display_status("unknown") is None

View File

@ -7,6 +7,7 @@ from uuid import uuid4
import pytest
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from models import Account
from models.dataset import Dataset, Document
@ -69,7 +70,7 @@ def make_document(
name=name,
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
doc.id = document_id
doc.indexing_status = "completed"

View File

@ -414,3 +414,144 @@ class TestEndUserServiceGetEndUserById:
)
assert result is None
class TestEndUserServiceCreateBatch:
"""Integration tests for EndUserService.create_end_user_batch."""
@pytest.fixture
def factory(self):
return TestEndUserServiceFactory()
def _create_multiple_apps(self, db_session_with_containers, factory, count: int = 3):
"""Create multiple apps under the same tenant."""
first_app = factory.create_app_and_account(db_session_with_containers)
tenant_id = first_app.tenant_id
apps = [first_app]
for _ in range(count - 1):
app = App(
tenant_id=tenant_id,
name=f"App {uuid4()}",
description="",
mode="chat",
icon_type="emoji",
icon="bot",
icon_background="#FFFFFF",
enable_site=False,
enable_api=True,
api_rpm=100,
api_rph=100,
is_demo=False,
is_public=False,
is_universal=False,
created_by=first_app.created_by,
updated_by=first_app.updated_by,
)
db_session_with_containers.add(app)
db_session_with_containers.commit()
all_apps = db_session_with_containers.query(App).filter(App.tenant_id == tenant_id).all()
return tenant_id, all_apps
def test_create_batch_empty_app_ids(self, db_session_with_containers):
result = EndUserService.create_end_user_batch(
type=InvokeFrom.SERVICE_API, tenant_id=str(uuid4()), app_ids=[], user_id="user-1"
)
assert result == {}
def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3)
app_ids = [a.id for a in apps]
user_id = f"user-{uuid4()}"
result = EndUserService.create_end_user_batch(
type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id
)
assert len(result) == 3
for app_id in app_ids:
assert app_id in result
assert result[app_id].session_id == user_id
assert result[app_id].type == InvokeFrom.SERVICE_API
def test_create_batch_default_session_id(self, db_session_with_containers, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
app_ids = [a.id for a in apps]
result = EndUserService.create_end_user_batch(
type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=""
)
assert len(result) == 2
for end_user in result.values():
assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
assert end_user._is_anonymous is True
def test_create_batch_deduplicate_app_ids(self, db_session_with_containers, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
app_ids = [apps[0].id, apps[1].id, apps[0].id, apps[1].id]
user_id = f"user-{uuid4()}"
result = EndUserService.create_end_user_batch(
type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id
)
assert len(result) == 2
def test_create_batch_returns_existing_users(self, db_session_with_containers, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
app_ids = [a.id for a in apps]
user_id = f"user-{uuid4()}"
# Create batch first time
first_result = EndUserService.create_end_user_batch(
type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id
)
# Create batch second time — should return existing users
second_result = EndUserService.create_end_user_batch(
type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id
)
assert len(second_result) == 2
for app_id in app_ids:
assert first_result[app_id].id == second_result[app_id].id
def test_create_batch_partial_existing_users(self, db_session_with_containers, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3)
user_id = f"user-{uuid4()}"
# Create for first 2 apps
first_result = EndUserService.create_end_user_batch(
type=InvokeFrom.SERVICE_API,
tenant_id=tenant_id,
app_ids=[apps[0].id, apps[1].id],
user_id=user_id,
)
# Create for all 3 apps — should reuse first 2, create 3rd
all_result = EndUserService.create_end_user_batch(
type=InvokeFrom.SERVICE_API,
tenant_id=tenant_id,
app_ids=[a.id for a in apps],
user_id=user_id,
)
assert len(all_result) == 3
assert all_result[apps[0].id].id == first_result[apps[0].id].id
assert all_result[apps[1].id].id == first_result[apps[1].id].id
assert all_result[apps[2].id].session_id == user_id
@pytest.mark.parametrize(
"invoke_type",
[InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP, InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER],
)
def test_create_batch_all_invoke_types(self, db_session_with_containers, invoke_type, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=1)
user_id = f"user-{uuid4()}"
result = EndUserService.create_end_user_batch(
type=invoke_type, tenant_id=tenant_id, app_ids=[apps[0].id], user_id=user_id
)
assert len(result) == 1
assert result[apps[0].id].type == invoke_type

View File

@ -0,0 +1,96 @@
"""
Testcontainers integration tests for FileService helpers.
Covers:
- ZIP tempfile building (sanitization + deduplication + content writes)
- tenant-scoped batch lookup behavior (get_upload_files_by_ids)
"""
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any
from uuid import uuid4
from zipfile import ZipFile
import pytest
import services.file_service as file_service_module
from extensions.storage.storage_type import StorageType
from models.enums import CreatorUserRole
from models.model import UploadFile
from services.file_service import FileService
def _create_upload_file(db_session, *, tenant_id: str, key: str, name: str) -> UploadFile:
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type=StorageType.OPENDAL,
key=key,
name=name,
size=100,
extension="txt",
mime_type="text/plain",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
created_at=datetime.now(UTC),
used=False,
)
db_session.add(upload_file)
db_session.commit()
return upload_file
def test_build_upload_files_zip_tempfile_sanitizes_and_dedupes_names(monkeypatch: pytest.MonkeyPatch) -> None:
"""Ensure ZIP entry names are safe and unique while preserving extensions."""
upload_files: list[Any] = [
SimpleNamespace(name="a/b.txt", key="k1"),
SimpleNamespace(name="c/b.txt", key="k2"),
SimpleNamespace(name="../b.txt", key="k3"),
]
data_by_key: dict[str, list[bytes]] = {"k1": [b"one"], "k2": [b"two"], "k3": [b"three"]}
def _load(key: str, stream: bool = True) -> list[bytes]:
assert stream is True
return data_by_key[key]
monkeypatch.setattr(file_service_module.storage, "load", _load)
with FileService.build_upload_files_zip_tempfile(upload_files=upload_files) as tmp:
with ZipFile(tmp, mode="r") as zf:
assert zf.namelist() == ["b.txt", "b (1).txt", "b (2).txt"]
assert zf.read("b.txt") == b"one"
assert zf.read("b (1).txt") == b"two"
assert zf.read("b (2).txt") == b"three"
def test_get_upload_files_by_ids_returns_empty_when_no_ids(db_session_with_containers) -> None:
"""Ensure empty input returns an empty mapping without hitting the database."""
assert FileService.get_upload_files_by_ids(str(uuid4()), []) == {}
def test_get_upload_files_by_ids_returns_id_keyed_mapping(db_session_with_containers) -> None:
"""Ensure batch lookup returns a dict keyed by stringified UploadFile ids."""
tenant_id = str(uuid4())
file1 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k1", name="file1.txt")
file2 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k2", name="file2.txt")
result = FileService.get_upload_files_by_ids(tenant_id, [file1.id, file1.id, file2.id])
assert set(result.keys()) == {file1.id, file2.id}
assert result[file1.id].id == file1.id
assert result[file2.id].id == file2.id
def test_get_upload_files_by_ids_filters_by_tenant(db_session_with_containers) -> None:
"""Ensure files from other tenants are not returned."""
tenant_a = str(uuid4())
tenant_b = str(uuid4())
file_a = _create_upload_file(db_session_with_containers, tenant_id=tenant_a, key="ka", name="a.txt")
_create_upload_file(db_session_with_containers, tenant_id=tenant_b, key="kb", name="b.txt")
result = FileService.get_upload_files_by_ids(tenant_a, [file_a.id])
assert set(result.keys()) == {file_a.id}

View File

@ -8,6 +8,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from dify_graph.file.enums import FileType
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -253,7 +254,7 @@ class TestMessagesCleanServiceIntegration:
# MessageFile
file = MessageFile(
message_id=message.id,
type="image",
type=FileType.IMAGE,
transfer_method="local_file",
url="http://example.com/test.jpg",
belongs_to=MessageFileBelongsTo.USER,

View File

@ -5,6 +5,7 @@ from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexStructureType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document
from models.enums import DatasetMetadataType, DataSourceType, DocumentCreatedFrom
@ -139,7 +140,7 @@ class TestMetadataService:
name=fake.file_name(),
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
)

View File

@ -0,0 +1,174 @@
"""Testcontainers integration tests for OAuthServerService."""
from __future__ import annotations
import uuid
from typing import cast
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from werkzeug.exceptions import BadRequest
from models.model import OAuthProviderApp
from services.oauth_server import (
OAUTH_ACCESS_TOKEN_EXPIRES_IN,
OAUTH_ACCESS_TOKEN_REDIS_KEY,
OAUTH_AUTHORIZATION_CODE_REDIS_KEY,
OAUTH_REFRESH_TOKEN_EXPIRES_IN,
OAUTH_REFRESH_TOKEN_REDIS_KEY,
OAuthGrantType,
OAuthServerService,
)
class TestOAuthServerServiceGetProviderApp:
"""DB-backed tests for get_oauth_provider_app."""
def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp:
app = OAuthProviderApp(
app_icon="icon.png",
client_id=client_id,
client_secret=str(uuid4()),
app_label={"en-US": "Test OAuth App"},
redirect_uris=["https://example.com/callback"],
scope="read",
)
db_session_with_containers.add(app)
db_session_with_containers.commit()
return app
def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers):
client_id = f"client-{uuid4()}"
created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id)
result = OAuthServerService.get_oauth_provider_app(client_id)
assert result is not None
assert result.client_id == client_id
assert result.id == created.id
def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers):
result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}")
assert result is None
class TestOAuthServerServiceTokenOperations:
"""Redis-backed tests for token sign/validate operations."""
@pytest.fixture
def mock_redis(self):
with patch("services.oauth_server.redis_client") as mock:
yield mock
def test_sign_authorization_code_stores_and_returns_code(self, mock_redis):
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111")
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1")
assert code == str(deterministic_uuid)
mock_redis.set.assert_called_once_with(
OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=code),
"user-1",
ex=600,
)
def test_sign_access_token_raises_bad_request_for_invalid_code(self, mock_redis):
mock_redis.get.return_value = None
with pytest.raises(BadRequest, match="invalid code"):
OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
code="bad-code",
client_id="client-1",
)
def test_sign_access_token_issues_tokens_for_valid_code(self, mock_redis):
token_uuids = [
uuid.UUID("00000000-0000-0000-0000-000000000201"),
uuid.UUID("00000000-0000-0000-0000-000000000202"),
]
with patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids):
mock_redis.get.return_value = b"user-1"
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
code="code-1",
client_id="client-1",
)
assert access_token == str(token_uuids[0])
assert refresh_token == str(token_uuids[1])
code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1")
mock_redis.delete.assert_called_once_with(code_key)
mock_redis.set.assert_any_call(
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
b"user-1",
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
)
mock_redis.set.assert_any_call(
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token),
b"user-1",
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
)
def test_sign_access_token_raises_bad_request_for_invalid_refresh_token(self, mock_redis):
mock_redis.get.return_value = None
with pytest.raises(BadRequest, match="invalid refresh token"):
OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.REFRESH_TOKEN,
refresh_token="stale-token",
client_id="client-1",
)
def test_sign_access_token_issues_new_token_for_valid_refresh(self, mock_redis):
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301")
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
mock_redis.get.return_value = b"user-1"
access_token, returned_refresh = OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.REFRESH_TOKEN,
refresh_token="refresh-1",
client_id="client-1",
)
assert access_token == str(deterministic_uuid)
assert returned_refresh == "refresh-1"
def test_sign_access_token_returns_none_for_unknown_grant_type(self, mock_redis):
grant_type = cast(OAuthGrantType, "invalid-grant-type")
result = OAuthServerService.sign_oauth_access_token(grant_type=grant_type, client_id="client-1")
assert result is None
def test_sign_refresh_token_stores_with_expected_expiry(self, mock_redis):
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401")
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2")
assert refresh_token == str(deterministic_uuid)
mock_redis.set.assert_called_once_with(
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token),
"user-2",
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
)
def test_validate_access_token_returns_none_when_not_found(self, mock_redis):
mock_redis.get.return_value = None
result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token")
assert result is None
def test_validate_access_token_loads_user_when_exists(self, mock_redis):
mock_redis.get.return_value = b"user-88"
expected_user = MagicMock()
with patch("services.oauth_server.AccountService.load_user", return_value=expected_user) as mock_load:
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
assert result is expected_user
mock_load.assert_called_once_with("user-88")

View File

@ -396,11 +396,6 @@ class TestSavedMessageService:
assert "User is required" in str(exc_info.value)
# Verify no database operations were performed
saved_messages = db_session_with_containers.query(SavedMessage).all()
assert len(saved_messages) == 0
def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test error handling when saving message with no user.
@ -497,124 +492,140 @@ class TestSavedMessageService:
# The message should still exist, only the saved_message should be deleted
assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None
def test_pagination_by_last_id_error_no_user(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when no user is provided.
This test verifies:
- Proper error handling for missing user
- ValueError is raised when user is None
- No database operations are performed
"""
# Arrange: Create test data
fake = Faker()
def test_save_for_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""Test saving a message for an EndUser."""
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
end_user = self._create_test_end_user(db_session_with_containers, app)
message = self._create_test_message(db_session_with_containers, app, end_user)
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError) as exc_info:
SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10)
mock_external_service_dependencies["message_service"].get_message.return_value = message
assert "User is required" in str(exc_info.value)
SavedMessageService.save(app_model=app, user=end_user, message_id=message.id)
# Verify no database operations were performed for this specific test
# Note: We don't check total count as other tests may have created data
# Instead, we verify that the error was properly raised
pass
def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test error handling when saving message with no user.
This test verifies:
- Method returns early when user is None
- No database operations are performed
- No exceptions are raised
"""
# Arrange: Create test data
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
message = self._create_test_message(db_session_with_containers, app, account)
# Act: Execute the method under test with None user
result = SavedMessageService.save(app_model=app, user=None, message_id=message.id)
# Assert: Verify the expected outcomes
assert result is None
# Verify no saved message was created
saved_message = (
saved = (
db_session_with_containers.query(SavedMessage)
.where(
SavedMessage.app_id == app.id,
SavedMessage.message_id == message.id,
)
.where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id)
.first()
)
assert saved is not None
assert saved.created_by == end_user.id
assert saved.created_by_role == "end_user"
assert saved_message is None
def test_delete_success_existing_message(
def test_save_duplicate_is_idempotent(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful deletion of an existing saved message.
This test verifies:
- Proper deletion of existing saved message
- Correct database state after deletion
- No errors during deletion process
"""
# Arrange: Create test data
fake = Faker()
"""Test that saving an already-saved message does not create a duplicate."""
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
message = self._create_test_message(db_session_with_containers, app, account)
# Create a saved message first
saved_message = SavedMessage(
app_id=app.id,
message_id=message.id,
created_by_role="account",
created_by=account.id,
)
mock_external_service_dependencies["message_service"].get_message.return_value = message
db_session_with_containers.add(saved_message)
# Save once
SavedMessageService.save(app_model=app, user=account, message_id=message.id)
# Save again
SavedMessageService.save(app_model=app, user=account, message_id=message.id)
count = (
db_session_with_containers.query(SavedMessage)
.where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id)
.count()
)
assert count == 1
def test_delete_without_user_does_nothing(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test that deleting without a user is a no-op."""
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
message = self._create_test_message(db_session_with_containers, app, account)
# Pre-create a saved message
saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="account", created_by=account.id)
db_session_with_containers.add(saved)
db_session_with_containers.commit()
# Verify saved message exists
SavedMessageService.delete(app_model=app, user=None, message_id=message.id)
# Should still exist
assert (
db_session_with_containers.query(SavedMessage)
.where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id)
.first()
is not None
)
def test_delete_non_existent_does_nothing(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test that deleting a non-existent saved message is a no-op."""
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Should not raise — use a valid UUID that doesn't exist in DB
from uuid import uuid4
SavedMessageService.delete(app_model=app, user=account, message_id=str(uuid4()))
def test_delete_for_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""Test deleting a saved message for an EndUser."""
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
end_user = self._create_test_end_user(db_session_with_containers, app)
message = self._create_test_message(db_session_with_containers, app, end_user)
saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id)
db_session_with_containers.add(saved)
db_session_with_containers.commit()
SavedMessageService.delete(app_model=app, user=end_user, message_id=message.id)
assert (
db_session_with_containers.query(SavedMessage)
.where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id)
.first()
is None
)
def test_delete_only_affects_own_saved_messages(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test that delete only removes the requesting user's saved message."""
app, account1 = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
end_user = self._create_test_end_user(db_session_with_containers, app)
message = self._create_test_message(db_session_with_containers, app, account1)
# Both users save the same message
saved_account = SavedMessage(
app_id=app.id, message_id=message.id, created_by_role="account", created_by=account1.id
)
saved_end_user = SavedMessage(
app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id
)
db_session_with_containers.add_all([saved_account, saved_end_user])
db_session_with_containers.commit()
# Delete only account1's saved message
SavedMessageService.delete(app_model=app, user=account1, message_id=message.id)
# Account's saved message should be gone
assert (
db_session_with_containers.query(SavedMessage)
.where(
SavedMessage.app_id == app.id,
SavedMessage.message_id == message.id,
SavedMessage.created_by_role == "account",
SavedMessage.created_by == account.id,
SavedMessage.created_by == account1.id,
)
.first()
is not None
is None
)
# Act: Execute the method under test
SavedMessageService.delete(app_model=app, user=account, message_id=message.id)
# Assert: Verify the expected outcomes
# Check if saved message was deleted from database
deleted_saved_message = (
# End user's saved message should still exist
assert (
db_session_with_containers.query(SavedMessage)
.where(
SavedMessage.app_id == app.id,
SavedMessage.message_id == message.id,
SavedMessage.created_by_role == "account",
SavedMessage.created_by == account.id,
SavedMessage.created_by == end_user.id,
)
.first()
is not None
)
assert deleted_saved_message is None
# Verify database state
db_session_with_containers.commit()
# The message should still exist, only the saved_message should be deleted
assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None

View File

@ -9,7 +9,7 @@ from werkzeug.exceptions import NotFound
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset
from models.enums import DataSourceType
from models.enums import DataSourceType, TagType
from models.model import App, Tag, TagBinding
from services.tag_service import TagService
@ -547,7 +547,7 @@ class TestTagService:
assert result is not None
assert len(result) == 1
assert result[0].name == "python_tag"
assert result[0].type == "app"
assert result[0].type == TagType.APP
assert result[0].tenant_id == tenant.id
def test_get_tag_by_tag_name_no_matches(
@ -638,7 +638,7 @@ class TestTagService:
# Verify all tags are returned
for tag in result:
assert tag.type == "app"
assert tag.type == TagType.APP
assert tag.tenant_id == tenant.id
assert tag.id in [t.id for t in tags]

View File

@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
from dify_graph.entities.workflow_execution import WorkflowExecutionStatus
from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun
from models.enums import CreatorUserRole
from models.workflow import WorkflowAppLogCreatedFrom
from services.account_service import AccountService, TenantService
# Delay import of AppService to avoid circular dependency
@ -221,7 +222,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -357,7 +358,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run_1.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -399,7 +400,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run_2.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -441,7 +442,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run_4.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -521,7 +522,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -627,7 +628,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -732,7 +733,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -860,7 +861,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -902,7 +903,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="web-app",
created_from=WorkflowAppLogCreatedFrom.WEB_APP,
created_by_role=CreatorUserRole.END_USER,
created_by=end_user.id,
)
@ -1037,7 +1038,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -1125,7 +1126,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -1279,7 +1280,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -1379,7 +1380,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -1481,7 +1482,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)

View File

@ -536,3 +536,151 @@ class TestApiToolManageService:
# Verify mock interactions
mock_external_service_dependencies["encrypter"].assert_called_once()
mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()
def test_delete_api_tool_provider_success(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test successful deletion of an API tool provider."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
schema = self._create_test_openapi_schema()
provider_name = fake.unique.word()
ApiToolManageService.create_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
icon={"content": "🔧", "background": "#FFF"},
credentials={"auth_type": "none"},
schema_type=ApiProviderSchemaType.OPENAPI,
schema=schema,
privacy_policy="",
custom_disclaimer="",
labels=[],
)
provider = (
db_session_with_containers.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
.first()
)
assert provider is not None
result = ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, provider_name)
assert result == {"result": "success"}
deleted = (
db_session_with_containers.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
.first()
)
assert deleted is None
def test_delete_api_tool_provider_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test deletion raises ValueError when provider not found."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="you have not added provider"):
ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent")
def test_update_api_tool_provider_not_found(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test update raises ValueError when original provider not found."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="does not exists"):
ApiToolManageService.update_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name="new-name",
original_provider="nonexistent",
icon={},
credentials={"auth_type": "none"},
_schema_type=ApiProviderSchemaType.OPENAPI,
schema=self._create_test_openapi_schema(),
privacy_policy=None,
custom_disclaimer="",
labels=[],
)
def test_update_api_tool_provider_missing_auth_type(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test update raises ValueError when auth_type is missing from credentials."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
schema = self._create_test_openapi_schema()
provider_name = fake.unique.word()
ApiToolManageService.create_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
icon={"content": "🔧", "background": "#FFF"},
credentials={"auth_type": "none"},
schema_type=ApiProviderSchemaType.OPENAPI,
schema=schema,
privacy_policy="",
custom_disclaimer="",
labels=[],
)
with pytest.raises(ValueError, match="auth_type is required"):
ApiToolManageService.update_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
original_provider=provider_name,
icon={},
credentials={},
_schema_type=ApiProviderSchemaType.OPENAPI,
schema=schema,
privacy_policy=None,
custom_disclaimer="",
labels=[],
)
def test_list_api_tool_provider_tools_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test listing tools raises ValueError when provider not found."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="you have not added provider"):
ApiToolManageService.list_api_tool_provider_tools(account.id, tenant.id, "nonexistent")
def test_test_api_tool_preview_invalid_schema_type(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test preview raises ValueError for invalid schema type."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="invalid schema type"):
ApiToolManageService.test_api_tool_preview(
tenant_id=tenant.id,
provider_name="provider-a",
tool_name="tool-a",
credentials={"auth_type": "none"},
parameters={},
schema_type="bad-schema-type",
schema="schema",
)

View File

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolProviderType
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.plugin.plugin_service import PluginService
from services.tools.tools_transform_service import ToolTransformService
@ -52,7 +52,7 @@ class TestToolTransformService:
user_id="test_user_id",
credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
tools_str="[]",
)
elif provider_type == "builtin":
@ -659,7 +659,7 @@ class TestToolTransformService:
user_id=fake.uuid4(),
credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
tools_str="[]",
)
@ -695,7 +695,7 @@ class TestToolTransformService:
user_id=fake.uuid4(),
credentials_str='{"auth_type": "api_key_query", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
tools_str="[]",
)
@ -731,7 +731,7 @@ class TestToolTransformService:
user_id=fake.uuid4(),
credentials_str='{"auth_type": "api_key", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
tools_str="[]",
)

View File

@ -1043,3 +1043,112 @@ class TestWorkflowToolManageService:
# After the fix, this should always be 0
# For now, we document that the record may exist, demonstrating the bug
# assert tool_count == 0 # Expected after fix
def test_delete_workflow_tool_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test successful deletion of a workflow tool."""
fake = Faker()
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
tool_name = fake.unique.word()
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=tool_name,
label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200),
parameters=self._create_test_workflow_tool_parameters(),
)
tool = (
db_session_with_containers.query(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.name == tool_name)
.first()
)
assert tool is not None
result = WorkflowToolManageService.delete_workflow_tool(account.id, account.current_tenant.id, tool.id)
assert result == {"result": "success"}
deleted = (
db_session_with_containers.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool.id).first()
)
assert deleted is None
def test_list_tenant_workflow_tools_empty(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test listing workflow tools when none exist returns empty list."""
fake = Faker()
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
result = WorkflowToolManageService.list_tenant_workflow_tools(account.id, account.current_tenant.id)
assert result == []
def test_get_workflow_tool_by_tool_id_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test that get_workflow_tool_by_tool_id raises ValueError when tool not found."""
fake = Faker()
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="Tool not found"):
WorkflowToolManageService.get_workflow_tool_by_tool_id(account.id, account.current_tenant.id, fake.uuid4())
def test_get_workflow_tool_by_app_id_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test that get_workflow_tool_by_app_id raises ValueError when tool not found."""
fake = Faker()
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="Tool not found"):
WorkflowToolManageService.get_workflow_tool_by_app_id(account.id, account.current_tenant.id, fake.uuid4())
def test_list_single_workflow_tools_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test that list_single_workflow_tools raises ValueError when tool not found."""
fake = Faker()
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="not found"):
WorkflowToolManageService.list_single_workflow_tools(account.id, account.current_tenant.id, fake.uuid4())
def test_create_workflow_tool_with_labels(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test that labels are forwarded to ToolLabelManager when provided."""
fake = Faker()
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
result = WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=fake.unique.word(),
label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200),
parameters=self._create_test_workflow_tool_parameters(),
labels=["label-1", "label-2"],
)
assert result == {"result": "success"}
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once()

View File

@ -0,0 +1,158 @@
"""Testcontainers integration tests for WorkflowService.delete_workflow."""
import json
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session, sessionmaker
from extensions.ext_database import db
from models.account import Account, Tenant, TenantAccountJoin
from models.model import App
from models.tools import WorkflowToolProvider
from models.workflow import Workflow
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
class TestWorkflowDeletion:
def _create_tenant_and_account(self, session: Session) -> tuple[Tenant, Account]:
tenant = Tenant(name=f"Tenant {uuid4()}")
session.add(tenant)
session.flush()
account = Account(
name=f"Account {uuid4()}",
email=f"wf_del_{uuid4()}@example.com",
password="hashed",
password_salt="salt",
interface_language="en-US",
timezone="UTC",
)
session.add(account)
session.flush()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role="owner",
current=True,
)
session.add(join)
session.flush()
return tenant, account
def _create_app(self, session: Session, *, tenant: Tenant, account: Account, workflow_id: str | None = None) -> App:
app = App(
tenant_id=tenant.id,
name=f"App {uuid4()}",
description="",
mode="workflow",
icon_type="emoji",
icon="bot",
icon_background="#FFFFFF",
enable_site=False,
enable_api=True,
api_rpm=100,
api_rph=100,
is_demo=False,
is_public=False,
is_universal=False,
created_by=account.id,
updated_by=account.id,
workflow_id=workflow_id,
)
session.add(app)
session.flush()
return app
def _create_workflow(
self, session: Session, *, tenant: Tenant, app: App, account: Account, version: str = "1.0"
) -> Workflow:
workflow = Workflow(
id=str(uuid4()),
tenant_id=tenant.id,
app_id=app.id,
type="workflow",
version=version,
graph=json.dumps({"nodes": [], "edges": []}),
_features=json.dumps({}),
created_by=account.id,
updated_by=account.id,
)
session.add(workflow)
session.flush()
return workflow
def _create_tool_provider(
self, session: Session, *, tenant: Tenant, app: App, account: Account, version: str
) -> WorkflowToolProvider:
provider = WorkflowToolProvider(
name=f"tool-{uuid4()}",
label=f"Tool {uuid4()}",
icon="wrench",
app_id=app.id,
version=version,
user_id=account.id,
tenant_id=tenant.id,
description="test tool provider",
)
session.add(provider)
session.flush()
return provider
def test_delete_workflow_success(self, db_session_with_containers):
tenant, account = self._create_tenant_and_account(db_session_with_containers)
app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
workflow = self._create_workflow(
db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0"
)
db_session_with_containers.commit()
workflow_id = workflow.id
service = WorkflowService(sessionmaker(bind=db.engine))
result = service.delete_workflow(
session=db_session_with_containers, workflow_id=workflow_id, tenant_id=tenant.id
)
assert result is True
db_session_with_containers.expire_all()
assert db_session_with_containers.get(Workflow, workflow_id) is None
def test_delete_draft_workflow_raises_error(self, db_session_with_containers):
tenant, account = self._create_tenant_and_account(db_session_with_containers)
app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
workflow = self._create_workflow(
db_session_with_containers, tenant=tenant, app=app, account=account, version="draft"
)
db_session_with_containers.commit()
service = WorkflowService(sessionmaker(bind=db.engine))
with pytest.raises(DraftWorkflowDeletionError):
service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id)
def test_delete_workflow_in_use_by_app_raises_error(self, db_session_with_containers):
tenant, account = self._create_tenant_and_account(db_session_with_containers)
app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
workflow = self._create_workflow(
db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0"
)
# Point app to this workflow
app.workflow_id = workflow.id
db_session_with_containers.commit()
service = WorkflowService(sessionmaker(bind=db.engine))
with pytest.raises(WorkflowInUseError, match="currently in use by app"):
service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id)
def test_delete_workflow_published_as_tool_raises_error(self, db_session_with_containers):
tenant, account = self._create_tenant_and_account(db_session_with_containers)
app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
workflow = self._create_workflow(
db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0"
)
self._create_tool_provider(db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0")
db_session_with_containers.commit()
service = WorkflowService(sessionmaker(bind=db.engine))
with pytest.raises(WorkflowInUseError, match="published as a tool"):
service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id)

View File

@ -13,6 +13,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -152,7 +153,7 @@ class TestBatchCleanDocumentTask:
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status=IndexingStatus.COMPLETED,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db_session_with_containers.add(document)
@ -392,7 +393,12 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Execute the task with non-existent dataset
batch_clean_document_task(document_ids=[document_id], dataset_id=dataset_id, doc_form="text_model", file_ids=[])
batch_clean_document_task(
document_ids=[document_id],
dataset_id=dataset_id,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
file_ids=[],
)
# Verify that no index processing occurred
mock_external_service_dependencies["index_processor"].clean.assert_not_called()
@ -525,7 +531,11 @@ class TestBatchCleanDocumentTask:
account = self._create_test_account(db_session_with_containers)
# Test different doc_form types
doc_forms = ["text_model", "qa_model", "hierarchical_model"]
doc_forms = [
IndexStructureType.PARAGRAPH_INDEX,
IndexStructureType.QA_INDEX,
IndexStructureType.PARENT_CHILD_INDEX,
]
for doc_form in doc_forms:
dataset = self._create_test_dataset(db_session_with_containers, account)

View File

@ -19,6 +19,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -179,7 +180,7 @@ class TestBatchCreateSegmentToIndexTask:
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
word_count=0,
)
@ -221,17 +222,17 @@ class TestBatchCreateSegmentToIndexTask:
return upload_file
def _create_test_csv_content(self, content_type="text_model"):
def _create_test_csv_content(self, content_type=IndexStructureType.PARAGRAPH_INDEX):
"""
Helper method to create test CSV content.
Args:
content_type: Type of content to create ("text_model" or "qa_model")
content_type: Type of content to create (IndexStructureType.PARAGRAPH_INDEX or IndexStructureType.QA_INDEX)
Returns:
str: CSV content as string
"""
if content_type == "qa_model":
if content_type == IndexStructureType.QA_INDEX:
csv_content = "content,answer\n"
csv_content += "This is the first segment content,This is the first answer\n"
csv_content += "This is the second segment content,This is the second answer\n"
@ -264,7 +265,7 @@ class TestBatchCreateSegmentToIndexTask:
upload_file = self._create_test_upload_file(db_session_with_containers, account, tenant)
# Create CSV content
csv_content = self._create_test_csv_content("text_model")
csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX)
# Mock storage to return our CSV content
mock_storage = mock_external_service_dependencies["storage"]
@ -451,7 +452,7 @@ class TestBatchCreateSegmentToIndexTask:
indexing_status=IndexingStatus.COMPLETED,
enabled=False, # Document is disabled
archived=False,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
word_count=0,
),
# Archived document
@ -467,7 +468,7 @@ class TestBatchCreateSegmentToIndexTask:
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=True, # Document is archived
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
word_count=0,
),
# Document with incomplete indexing
@ -483,7 +484,7 @@ class TestBatchCreateSegmentToIndexTask:
indexing_status=IndexingStatus.INDEXING, # Not completed
enabled=True,
archived=False,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
word_count=0,
),
]
@ -655,7 +656,7 @@ class TestBatchCreateSegmentToIndexTask:
db_session_with_containers.commit()
# Create CSV content
csv_content = self._create_test_csv_content("text_model")
csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX)
# Mock storage to return our CSV content
mock_storage = mock_external_service_dependencies["storage"]

View File

@ -18,6 +18,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
@ -192,7 +193,7 @@ class TestCleanDatasetTask:
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
doc_form="paragraph_index",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
word_count=100,
created_at=datetime.now(),
updated_at=datetime.now(),

Some files were not shown because too many files have changed in this diff Show More