Compare commits

...

18 Commits

Author SHA1 Message Date
Mahmoud Hamdy 47a53d0bea
Merge db26c2d88f into 8b634a9bee 2026-03-24 06:31:23 +02:00
tmimmanuel 8b634a9bee
refactor: use EnumText for ApiToolProvider.schema_type_str and Docume… (#33983) 2026-03-24 13:27:50 +09:00
BitToby ecd3a964c1
refactor(api): type auth service credentials with TypedDict (#33867) 2026-03-24 13:22:17 +09:00
yyh 0589fa423b
fix(sdk): patch flatted vulnerability in nodejs client lockfile (#33996) 2026-03-24 11:24:31 +08:00
Stephen Zhou 27c4faad4f
ci: update actions version, fix cache (#33950) 2026-03-24 10:52:27 +08:00
wangxiaolei fbd558762d
fix: fix chunk not display in indexed document (#33942) 2026-03-24 10:36:48 +08:00
yyh 075b8bf1ae
fix(web): update account settings header (#33965) 2026-03-24 10:04:08 +08:00
Desel72 49a1fae555
test: migrate password reset tests to testcontainers (#33974)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-24 06:04:34 +09:00
tmimmanuel cc17c8e883
refactor: use EnumText for TidbAuthBinding.status and MessageFile.type (#33975) 2026-03-24 05:38:29 +09:00
tmimmanuel 5d2cb3cd80
refactor: use EnumText for DocumentSegment.type (#33979) 2026-03-24 05:37:51 +09:00
Desel72 f2c71f3668
test: migrate oauth server service tests to testcontainers (#33958) 2026-03-24 03:15:22 +09:00
Desel72 0492ed7034
test: migrate api tools manage service tests to testcontainers (#33956)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-24 02:54:33 +09:00
Renzo dd4f504b39
refactor: select in remaining console app controllers (#33969)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-24 02:53:05 +09:00
tmimmanuel 75c3ef82d9
refactor: use EnumText for TenantCreditPool.pool_type (#33959) 2026-03-24 02:51:10 +09:00
Desel72 8ca1ebb96d
test: migrate workflow tools manage service tests to testcontainers (#33955)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-24 02:50:10 +09:00
Desel72 3f086b97b6
test: remove mock tests superseded by testcontainers (#33957) 2026-03-24 02:46:54 +09:00
tmimmanuel 4a2e9633db
refactor: use EnumText for ApiToken.type (#33961) 2026-03-24 02:46:06 +09:00
tmimmanuel 20fc69ae7f
refactor: use EnumText for WorkflowAppLog.created_from and WorkflowArchiveLog columns (#33954) 2026-03-24 02:44:46 +09:00
107 changed files with 1027 additions and 2546 deletions

View File

@ -4,10 +4,9 @@ runs:
using: composite
steps:
- name: Setup Vite+
uses: voidzero-dev/setup-vp@4a524139920f87f9f7080d3b8545acac019e1852 # v1.0.0
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
with:
node-version-file: web/.nvmrc
working-directory: web
node-version-file: .nvmrc
cache: true
cache-dependency-path: web/pnpm-lock.yaml
run-install: |
cwd: ./web
run-install: true

View File

@ -84,20 +84,20 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
uses: ./.github/actions/setup-web
- name: Restore ESLint cache
if: steps.changed-files.outputs.any_changed == 'true'
id: eslint-cache-restore
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
with:
path: web/.eslintcache
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
restore-keys: |
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: |
vp run lint:ci
# pnpm run lint:report
# continue-on-error: true
# - name: Annotate Code
# if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request'
# uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae
# with:
# eslint-report: web/eslint_report.json
# github-token: ${{ secrets.GITHUB_TOKEN }}
run: vp run lint:ci
- name: Web tsslint
if: steps.changed-files.outputs.any_changed == 'true'
@ -114,6 +114,13 @@ jobs:
working-directory: ./web
run: vp run knip
- name: Save ESLint cache
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
with:
path: web/.eslintcache
key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
superlinter:
name: SuperLinter
runs-on: ubuntu-latest

View File

@ -120,7 +120,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.detect_changes.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@6062f3709600659be5e47fcddf2cf76993c235c2 # v1.0.76
uses: anthropics/claude-code-action@ff9acae5886d41a99ed4ec14b7dc147d55834722 # v1.0.77
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -10,6 +10,7 @@ from configs import dify_config
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
@ -269,7 +270,7 @@ def migrate_knowledge_vector_database():
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == "hierarchical_model":
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []

View File

@ -9,6 +9,7 @@ from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.enums import ApiTokenType
from models.model import ApiToken, App
from services.api_token_service import ApiTokenCache
@ -47,7 +48,7 @@ def _get_resource(resource_id, tenant_id, resource_model):
class BaseApiKeyListResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required]
resource_type: str | None = None
resource_type: ApiTokenType | None = None
resource_model: type | None = None
resource_id_field: str | None = None
token_prefix: str | None = None
@ -91,6 +92,7 @@ class BaseApiKeyListResource(Resource):
)
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
assert self.resource_type is not None, "resource_type must be set"
api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_tenant_id
@ -104,7 +106,7 @@ class BaseApiKeyListResource(Resource):
class BaseApiKeyResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required]
resource_type: str | None = None
resource_type: ApiTokenType | None = None
resource_model: type | None = None
resource_id_field: str | None = None
@ -159,7 +161,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for an app"""
return super().post(resource_id)
resource_type = "app"
resource_type = ApiTokenType.APP
resource_model = App
resource_id_field = "app_id"
token_prefix = "app-"
@ -175,7 +177,7 @@ class AppApiKeyResource(BaseApiKeyResource):
"""Delete an API key for an app"""
return super().delete(resource_id, api_key_id)
resource_type = "app"
resource_type = ApiTokenType.APP
resource_model = App
resource_id_field = "app_id"
@ -199,7 +201,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for a dataset"""
return super().post(resource_id)
resource_type = "dataset"
resource_type = ApiTokenType.DATASET
resource_model = Dataset
resource_id_field = "dataset_id"
token_prefix = "ds-"
@ -215,6 +217,6 @@ class DatasetApiKeyResource(BaseApiKeyResource):
"""Delete an API key for a dataset"""
return super().delete(resource_id, api_key_id)
resource_type = "dataset"
resource_type = ApiTokenType.DATASET
resource_model = Dataset
resource_id_field = "dataset_id"

View File

@ -458,9 +458,7 @@ class ChatConversationApi(Resource):
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
subquery = (
db.session.query(
Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")
)
sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id"))
.outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
.subquery()
)
@ -595,10 +593,8 @@ class ChatConversationDetailApi(Resource):
def _get_conversation(app_model, conversation_id):
current_user, _ = current_account_with_tenant()
conversation = (
db.session.query(Conversation)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
conversation = db.session.scalar(
sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1)
)
if not conversation:

View File

@ -168,7 +168,7 @@ class InstructionGenerateApi(Resource):
try:
# Generate from nothing for a workflow node
if (args.current in (code_template, "")) and args.node_id != "":
app = db.session.query(App).where(App.id == args.flow_id).first()
app = db.session.get(App, args.flow_id)
if not app:
return {"error": f"app {args.flow_id} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app)

View File

@ -2,6 +2,7 @@ import json
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
@ -47,7 +48,7 @@ class AppMCPServerController(Resource):
@get_app_model
@marshal_with(app_server_model)
def get(self, app_model):
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
return server
@console_ns.doc("create_app_mcp_server")
@ -98,7 +99,7 @@ class AppMCPServerController(Resource):
@edit_permission_required
def put(self, app_model):
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
server = db.session.get(AppMCPServer, payload.id)
if not server:
raise NotFound()
@ -135,11 +136,10 @@ class AppMCPServerRefreshController(Resource):
@edit_permission_required
def get(self, server_id):
_, current_tenant_id = current_account_with_tenant()
server = (
db.session.query(AppMCPServer)
.where(AppMCPServer.id == server_id)
.where(AppMCPServer.tenant_id == current_tenant_id)
.first()
server = db.session.scalar(
select(AppMCPServer)
.where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id)
.limit(1)
)
if not server:
raise NotFound()

View File

@ -69,9 +69,7 @@ class ModelConfigResource(Resource):
if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
# get original app model config
original_app_model_config = (
db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first()
)
original_app_model_config = db.session.get(AppModelConfig, app_model.app_model_config_id)
if original_app_model_config is None:
raise ValueError("Original app model config not found")
agent_mode = original_app_model_config.agent_mode_dict

View File

@ -2,6 +2,7 @@ from typing import Literal
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from constants.languages import supported_language
@ -75,7 +76,7 @@ class AppSite(Resource):
def post(self, app_model):
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
current_user, _ = current_account_with_tenant()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if not site:
raise NotFound
@ -124,7 +125,7 @@ class AppSiteAccessTokenReset(Resource):
@marshal_with(app_site_model)
def post(self, app_model):
current_user, _ = current_account_with_tenant()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if not site:
raise NotFound

View File

@ -2,6 +2,8 @@ from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar, Union
from sqlalchemy import select
from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db
from libs.login import current_account_with_tenant
@ -15,16 +17,14 @@ R1 = TypeVar("R1")
def _load_app_model(app_id: str) -> App | None:
_, current_tenant_id = current_account_with_tenant()
app_model = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app_model = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
return app_model
def _load_app_model_with_trial(app_id: str) -> App | None:
app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
app_model = db.session.scalar(select(App).where(App.id == app_id, App.status == "normal").limit(1))
return app_model

View File

@ -54,7 +54,7 @@ from fields.document_fields import document_status_fields
from libs.login import current_account_with_tenant, login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermission, DatasetPermissionEnum
from models.enums import SegmentStatus
from models.enums import ApiTokenType, SegmentStatus
from models.provider_ids import ModelProviderID
from services.api_token_service import ApiTokenCache
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@ -777,7 +777,7 @@ class DatasetIndexingStatusApi(Resource):
class DatasetApiKeyApi(Resource):
max_keys = 10
token_prefix = "dataset-"
resource_type = "dataset"
resource_type = ApiTokenType.DATASET
@console_ns.doc("get_dataset_api_keys")
@console_ns.doc(description="Get dataset API keys")
@ -826,7 +826,7 @@ class DatasetApiKeyApi(Resource):
@console_ns.route("/datasets/api-keys/<uuid:api_key_id>")
class DatasetApiDeleteApi(Resource):
resource_type = "dataset"
resource_type = ApiTokenType.DATASET
@console_ns.doc("delete_dataset_api_key")
@console_ns.doc(description="Delete dataset API key")

View File

@ -705,7 +705,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
app_id=self._application_generate_entity.app_config.app_id,
workflow_id=self._workflow.id,
workflow_run_id=workflow_run_id,
created_from=created_from.value,
created_from=created_from,
created_by_role=self._created_by_role,
created_by=self._user_id,
)

View File

@ -918,11 +918,11 @@ class ProviderManager:
trail_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type=ProviderQuotaType.TRIAL.value,
pool_type=ProviderQuotaType.TRIAL,
)
paid_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type=ProviderQuotaType.PAID.value,
pool_type=ProviderQuotaType.PAID,
)
else:
trail_pool = None

View File

@ -33,6 +33,7 @@ from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, TidbAuthBinding
from models.enums import TidbAuthBindingStatus
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
@ -452,7 +453,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
password=new_cluster["password"],
tenant_id=dataset.tenant_id,
active=True,
status="ACTIVE",
status=TidbAuthBindingStatus.ACTIVE,
)
db.session.add(new_tidb_auth_binding)
db.session.commit()

View File

@ -9,6 +9,7 @@ from configs import dify_config
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import TidbAuthBinding
from models.enums import TidbAuthBindingStatus
class TidbService:
@ -170,7 +171,7 @@ class TidbService:
userPrefix = item["userPrefix"]
if state == "ACTIVE" and len(userPrefix) > 0:
cluster_info = tidb_serverless_list_map[item["clusterId"]]
cluster_info.status = "ACTIVE"
cluster_info.status = TidbAuthBindingStatus.ACTIVE
cluster_info.account = f"{userPrefix}.root"
db.session.add(cluster_info)
db.session.commit()

View File

@ -43,7 +43,9 @@ from .enums import (
IndexingStatus,
ProcessRuleMode,
SegmentStatus,
SegmentType,
SummaryStatus,
TidbAuthBindingStatus,
)
from .model import App, Tag, TagBinding, UploadFile
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
@ -494,7 +496,9 @@ class Document(Base):
)
doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True)
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
doc_form: Mapped[IndexStructureType] = mapped_column(
EnumText(IndexStructureType, length=255), nullable=False, server_default=sa.text("'text_model'")
)
doc_language = mapped_column(String(255), nullable=True)
need_summary: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
@ -998,7 +1002,9 @@ class ChildChunk(Base):
# indexing fields
index_node_id = mapped_column(String(255), nullable=True)
index_node_hash = mapped_column(String(255), nullable=True)
type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'"))
type: Mapped[SegmentType] = mapped_column(
EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'")
)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
@ -1239,7 +1245,9 @@ class TidbAuthBinding(TypeBase):
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'"))
status: Mapped[TidbAuthBindingStatus] = mapped_column(
EnumText(TidbAuthBindingStatus, length=255), nullable=False, server_default=sa.text("'CREATING'")
)
account: Mapped[str] = mapped_column(String(255), nullable=False)
password: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(

View File

@ -222,6 +222,13 @@ class DatasetMetadataType(StrEnum):
TIME = "time"
class SegmentType(StrEnum):
"""Document segment type"""
AUTOMATIC = "automatic"
CUSTOMIZED = "customized"
class SegmentStatus(StrEnum):
"""Document segment status"""
@ -323,3 +330,10 @@ class ProviderQuotaType(StrEnum):
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class ApiTokenType(StrEnum):
"""API Token type"""
APP = "app"
DATASET = "dataset"

View File

@ -21,7 +21,7 @@ from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.tools.signature import sign_tool_file
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from dify_graph.file import helpers as file_helpers
from extensions.storage.storage_type import StorageType
from libs.helper import generate_string # type: ignore[import-not-found]
@ -31,6 +31,7 @@ from .account import Account, Tenant
from .base import Base, TypeBase, gen_uuidv4_string
from .engine import db
from .enums import (
ApiTokenType,
AppMCPServerStatus,
AppStatus,
BannerStatus,
@ -43,6 +44,7 @@ from .enums import (
MessageChainType,
MessageFileBelongsTo,
MessageStatus,
ProviderQuotaType,
TagType,
)
from .provider_ids import GenericProviderID
@ -1783,7 +1785,7 @@ class MessageFile(TypeBase):
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[FileType] = mapped_column(EnumText(FileType, length=255), nullable=False)
transfer_method: Mapped[FileTransferMethod] = mapped_column(
EnumText(FileTransferMethod, length=255), nullable=False
)
@ -2095,7 +2097,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field.
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
app_id = mapped_column(StringUUID, nullable=True)
tenant_id = mapped_column(StringUUID, nullable=True)
type = mapped_column(String(16), nullable=False)
type: Mapped[ApiTokenType] = mapped_column(EnumText(ApiTokenType, length=16), nullable=False)
token: Mapped[str] = mapped_column(String(255), nullable=False)
last_used_at = mapped_column(sa.DateTime, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@ -2490,7 +2492,9 @@ class TenantCreditPool(TypeBase):
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
pool_type: Mapped[ProviderQuotaType] = mapped_column(
EnumText(ProviderQuotaType, length=40), nullable=False, default=ProviderQuotaType.TRIAL, server_default="trial"
)
quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
created_at: Mapped[datetime] = mapped_column(

View File

@ -145,7 +145,9 @@ class ApiToolProvider(TypeBase):
icon: Mapped[str] = mapped_column(String(255), nullable=False)
# original schema
schema: Mapped[str] = mapped_column(LongText, nullable=False)
schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False)
schema_type_str: Mapped[ApiProviderSchemaType] = mapped_column(
EnumText(ApiProviderSchemaType, length=40), nullable=False
)
# who created this tool
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# tenant id

View File

@ -1221,7 +1221,9 @@ class WorkflowAppLog(TypeBase):
app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_run_id: Mapped[str] = mapped_column(StringUUID)
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
created_from: Mapped[WorkflowAppLogCreatedFrom] = mapped_column(
EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=False
)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
@ -1301,10 +1303,14 @@ class WorkflowArchiveLog(TypeBase):
log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
log_created_from: Mapped[WorkflowAppLogCreatedFrom | None] = mapped_column(
EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=True
)
run_version: Mapped[str] = mapped_column(String(255), nullable=False)
run_status: Mapped[str] = mapped_column(String(255), nullable=False)
run_status: Mapped[WorkflowExecutionStatus] = mapped_column(
EnumText(WorkflowExecutionStatus, length=255), nullable=False
)
run_triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column(
EnumText(WorkflowRunTriggeredFrom, length=255), nullable=False
)

View File

@ -8,6 +8,7 @@ from configs import dify_config
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from extensions.ext_database import db
from models.dataset import TidbAuthBinding
from models.enums import TidbAuthBindingStatus
@app.celery.task(queue="dataset")
@ -57,7 +58,7 @@ def create_clusters(batch_size):
account=new_cluster["account"],
password=new_cluster["password"],
active=False,
status="CREATING",
status=TidbAuthBindingStatus.CREATING,
)
db.session.add(tidb_auth_binding)
db.session.commit()

View File

@ -9,6 +9,7 @@ from configs import dify_config
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from extensions.ext_database import db
from models.dataset import TidbAuthBinding
from models.enums import TidbAuthBindingStatus
@app.celery.task(queue="dataset")
@ -18,7 +19,10 @@ def update_tidb_serverless_status_task():
try:
# check the number of idle tidb serverless
tidb_serverless_list = db.session.scalars(
select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
select(TidbAuthBinding).where(
TidbAuthBinding.active == False,
TidbAuthBinding.status == TidbAuthBindingStatus.CREATING,
)
).all()
if len(tidb_serverless_list) == 0:
return

View File

@ -1,8 +1,16 @@
from abc import ABC, abstractmethod
from typing import Any
from typing_extensions import TypedDict
class AuthCredentials(TypedDict):
auth_type: str
config: dict[str, Any]
class ApiKeyAuthBase(ABC):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
self.credentials = credentials
@abstractmethod

View File

@ -1,9 +1,9 @@
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
from services.auth.auth_type import AuthType
class ApiKeyAuthFactory:
def __init__(self, provider: str, credentials: dict):
def __init__(self, provider: str, credentials: AuthCredentials):
auth_factory = self.get_apikey_auth_factory(provider)
self.auth = auth_factory(credentials)

View File

@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class FirecrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@ -3,11 +3,11 @@ from urllib.parse import urljoin
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class WatercrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "x-api-key":

View File

@ -7,6 +7,7 @@ from configs import dify_config
from core.errors.error import QuotaExceededError
from extensions.ext_database import db
from models import TenantCreditPool
from models.enums import ProviderQuotaType
logger = logging.getLogger(__name__)
@ -16,7 +17,10 @@ class CreditPoolService:
def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
"""create default credit pool for new tenant"""
credit_pool = TenantCreditPool(
tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial"
tenant_id=tenant_id,
quota_limit=dify_config.HOSTED_POOL_CREDITS,
quota_used=0,
pool_type=ProviderQuotaType.TRIAL,
)
db.session.add(credit_pool)
db.session.commit()

View File

@ -58,6 +58,7 @@ from models.enums import (
IndexingStatus,
ProcessRuleMode,
SegmentStatus,
SegmentType,
)
from models.model import UploadFile
from models.provider_ids import ModelProviderID
@ -1439,7 +1440,7 @@ class DocumentService:
.filter(
Document.id.in_(document_id_list),
Document.dataset_id == dataset_id,
Document.doc_form != "qa_model", # Skip qa_model documents
Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
)
.update({Document.need_summary: need_summary}, synchronize_session=False)
)
@ -2039,7 +2040,7 @@ class DocumentService:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_form = IndexStructureType(knowledge_config.doc_form)
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
@ -2639,7 +2640,7 @@ class DocumentService:
document.splitting_completed_at = None
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = document_data.doc_form
document.doc_form = IndexStructureType(document_data.doc_form)
db.session.add(document)
db.session.commit()
# update document segment
@ -3100,7 +3101,7 @@ class DocumentService:
class SegmentService:
@classmethod
def segment_create_args_validate(cls, args: dict, document: Document):
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
if "answer" not in args or not args["answer"]:
raise ValueError("Answer is required")
if not args["answer"].strip():
@ -3157,7 +3158,7 @@ class SegmentService:
completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment_document.word_count += len(args["answer"])
segment_document.answer = args["answer"]
@ -3231,7 +3232,7 @@ class SegmentService:
tokens = 0
if dataset.indexing_technique == "high_quality" and embedding_model:
# calc embedding use tokens
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
tokens = embedding_model.get_text_embedding_num_tokens(
texts=[content + segment_item["answer"]]
)[0]
@ -3254,7 +3255,7 @@ class SegmentService:
completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment_document.answer = segment_item["answer"]
segment_document.word_count += len(segment_item["answer"])
increment_word_count += segment_document.word_count
@ -3321,7 +3322,7 @@ class SegmentService:
content = args.content or segment.content
if segment.content == content:
segment.word_count = len(content)
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment.answer = args.answer
segment.word_count += len(args.answer) if args.answer else 0
word_count_change = segment.word_count - word_count_change
@ -3418,7 +3419,7 @@ class SegmentService:
)
# calc embedding use tokens
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment.answer = args.answer
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0] # type: ignore
else:
@ -3435,7 +3436,7 @@ class SegmentService:
segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment.answer = args.answer
segment.word_count += len(args.answer) if args.answer else 0
word_count_change = segment.word_count - word_count_change
@ -3786,7 +3787,7 @@ class SegmentService:
child_chunk.word_count = len(child_chunk.content)
child_chunk.updated_by = current_user.id
child_chunk.updated_at = naive_utc_now()
child_chunk.type = "customized"
child_chunk.type = SegmentType.CUSTOMIZED
update_child_chunks.append(child_chunk)
else:
new_child_chunks_args.append(child_chunk_update_args)
@ -3845,7 +3846,7 @@ class SegmentService:
child_chunk.word_count = len(content)
child_chunk.updated_by = current_user.id
child_chunk.updated_at = naive_utc_now()
child_chunk.type = "customized"
child_chunk.type = SegmentType.CUSTOMIZED
db.session.add(child_chunk)
VectorService.update_child_chunk_vector([], [child_chunk], [], dataset)
db.session.commit()

View File

@ -9,6 +9,7 @@ from flask_login import current_user
from constants import DOCUMENT_EXTENSIONS
from core.plugin.impl.plugin import PluginInstaller
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from factories import variable_factory
@ -79,9 +80,9 @@ class RagPipelineTransformService:
pipeline = self._create_pipeline(pipeline_yaml)
# save chunk structure to dataset
if doc_form == "hierarchical_model":
if doc_form == IndexStructureType.PARENT_CHILD_INDEX:
dataset.chunk_structure = "hierarchical_model"
elif doc_form == "text_model":
elif doc_form == IndexStructureType.PARAGRAPH_INDEX:
dataset.chunk_structure = "text_model"
else:
raise ValueError("Unsupported doc form")
@ -101,7 +102,7 @@ class RagPipelineTransformService:
def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None):
pipeline_yaml = {}
if doc_form == "text_model":
if doc_form == IndexStructureType.PARAGRAPH_INDEX:
match datasource_type:
case DataSourceType.UPLOAD_FILE:
if indexing_technique == "high_quality":
@ -132,7 +133,7 @@ class RagPipelineTransformService:
pipeline_yaml = yaml.safe_load(f)
case _:
raise ValueError("Unsupported datasource type")
elif doc_form == "hierarchical_model":
elif doc_form == IndexStructureType.PARENT_CHILD_INDEX:
match datasource_type:
case DataSourceType.UPLOAD_FILE:
# get graph from transform.file-parentchild.yml

View File

@ -11,6 +11,7 @@ from sqlalchemy import func
from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.rag.index_processor.constant.index_type import IndexStructureType
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@ -109,7 +110,7 @@ def batch_create_segment_to_index_task(
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if document_config["doc_form"] == "qa_model":
if document_config["doc_form"] == IndexStructureType.QA_INDEX:
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row.iloc[0]}
@ -159,7 +160,7 @@ def batch_create_segment_to_index_task(
status="completed",
completed_at=naive_utc_now(),
)
if document_config["doc_form"] == "qa_model":
if document_config["doc_form"] == IndexStructureType.QA_INDEX:
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count

View File

@ -10,6 +10,7 @@ from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from libs.datetime_utils import naive_utc_now
@ -150,7 +151,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
)
if (
document.indexing_status == IndexingStatus.COMPLETED
and document.doc_form != "qa_model"
and document.doc_form != IndexStructureType.QA_INDEX
and document.need_summary is True
):
try:

View File

@ -9,6 +9,7 @@ from celery import shared_task
from sqlalchemy import or_, select
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
from models.dataset import Document as DatasetDocument
from services.summary_index_service import SummaryIndexService
@ -106,7 +107,7 @@ def regenerate_summary_index_task(
),
DatasetDocument.enabled == True, # Document must be enabled
DatasetDocument.archived == False, # Document must not be archived
DatasetDocument.doc_form != "qa_model", # Skip qa_model documents
DatasetDocument.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
)
.order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc())
.all()
@ -209,7 +210,7 @@ def regenerate_summary_index_task(
for dataset_document in dataset_documents:
# Skip qa_model documents
if dataset_document.doc_form == "qa_model":
if dataset_document.doc_form == IndexStructureType.QA_INDEX:
continue
try:

View File

@ -179,7 +179,7 @@ def _record_trigger_failure_log(
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value,
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=created_by_role,
created_by=created_by,
)

View File

@ -13,6 +13,7 @@ from unittest.mock import patch
import pytest
from extensions.ext_redis import redis_client
from models.enums import ApiTokenType
from models.model import ApiToken
from services.api_token_service import ApiTokenCache, CachedApiToken
@ -279,7 +280,7 @@ class TestEndToEndCacheFlow:
test_token = ApiToken()
test_token.id = "test-e2e-id"
test_token.token = test_token_value
test_token.type = test_scope
test_token.type = ApiTokenType.APP
test_token.app_id = "test-app"
test_token.tenant_id = "test-tenant"
test_token.last_used_at = None

View File

@ -1,17 +1,10 @@
"""
Test suite for password reset authentication flows.
"""Testcontainers integration tests for password reset authentication flows."""
This module tests the password reset mechanism including:
- Password reset email sending
- Verification code validation
- Password reset with token
- Rate limiting and security checks
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.error import (
EmailCodeError,
@ -28,31 +21,12 @@ from controllers.console.auth.forgot_password import (
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
@pytest.fixture(autouse=True)
def _mock_forgot_password_session():
with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls:
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_session_cls.return_value.__exit__.return_value = None
yield mock_session
@pytest.fixture(autouse=True)
def _mock_forgot_password_db():
with patch("controllers.console.auth.forgot_password.db") as mock_db:
mock_db.engine = MagicMock()
yield mock_db
class TestForgotPasswordSendEmailApi:
"""Test cases for sending password reset emails."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
def app(self, flask_app_with_containers):
return flask_app_with_containers
@pytest.fixture
def mock_account(self):
@ -62,7 +36,6 @@ class TestForgotPasswordSendEmailApi:
account.name = "Test User"
return account
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
@ -73,20 +46,10 @@ class TestForgotPasswordSendEmailApi:
mock_send_email,
mock_get_account,
mock_is_ip_limit,
mock_wraps_db,
app,
mock_account,
):
"""
Test successful password reset email sending.
Verifies that:
- Email is sent to valid account
- Reset token is generated and returned
- IP rate limiting is checked
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_account.return_value = mock_account
mock_send_email.return_value = "reset_token_123"
@ -104,9 +67,8 @@ class TestForgotPasswordSendEmailApi:
assert response["data"] == "reset_token_123"
mock_send_email.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, mock_db, app):
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app):
"""
Test password reset email blocked by IP rate limit.
@ -115,7 +77,6 @@ class TestForgotPasswordSendEmailApi:
- No email is sent when rate limited
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = True
# Act & Assert
@ -133,7 +94,6 @@ class TestForgotPasswordSendEmailApi:
(None, "en-US"), # Defaults to en-US when not provided
],
)
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
@ -144,7 +104,6 @@ class TestForgotPasswordSendEmailApi:
mock_send_email,
mock_get_account,
mock_is_ip_limit,
mock_wraps_db,
app,
mock_account,
language_input,
@ -158,7 +117,6 @@ class TestForgotPasswordSendEmailApi:
- Unsupported languages default to en-US
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_account.return_value = mock_account
mock_send_email.return_value = "token"
@ -180,13 +138,9 @@ class TestForgotPasswordCheckApi:
"""Test cases for verifying password reset codes."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
def app(self, flask_app_with_containers):
return flask_app_with_containers
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@ -199,7 +153,6 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_get_data,
mock_is_rate_limit,
mock_db,
app,
):
"""
@ -212,7 +165,6 @@ class TestForgotPasswordCheckApi:
- Rate limit is reset on success
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_generate_token.return_value = (None, "new_token")
@ -236,7 +188,6 @@ class TestForgotPasswordCheckApi:
)
mock_reset_rate_limit.assert_called_once_with("test@example.com")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@ -249,10 +200,8 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_get_data,
mock_is_rate_limit,
mock_db,
app,
):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"}
mock_generate_token.return_value = (None, "fresh-token")
@ -271,9 +220,8 @@ class TestForgotPasswordCheckApi:
mock_revoke_token.assert_called_once_with("upper_token")
mock_reset_rate_limit.assert_called_once_with("user@example.com")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app):
def test_verify_code_rate_limited(self, mock_is_rate_limit, app):
"""
Test code verification blocked by rate limit.
@ -282,7 +230,6 @@ class TestForgotPasswordCheckApi:
- Prevents brute force attacks on verification codes
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = True
# Act & Assert
@ -295,10 +242,9 @@ class TestForgotPasswordCheckApi:
with pytest.raises(EmailPasswordResetLimitError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, mock_db, app):
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app):
"""
Test code verification with invalid token.
@ -306,7 +252,6 @@ class TestForgotPasswordCheckApi:
- InvalidTokenError is raised for invalid/expired tokens
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = None
@ -320,10 +265,9 @@ class TestForgotPasswordCheckApi:
with pytest.raises(InvalidTokenError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, mock_db, app):
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app):
"""
Test code verification with mismatched email.
@ -332,7 +276,6 @@ class TestForgotPasswordCheckApi:
- Prevents token abuse
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}
@ -346,11 +289,10 @@ class TestForgotPasswordCheckApi:
with pytest.raises(InvalidEmailError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit")
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, mock_db, app):
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app):
"""
Test code verification with incorrect code.
@ -359,7 +301,6 @@ class TestForgotPasswordCheckApi:
- Rate limit counter is incremented
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
@ -380,11 +321,8 @@ class TestForgotPasswordResetApi:
"""Test cases for resetting password with verified token."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
def app(self, flask_app_with_containers):
return flask_app_with_containers
@pytest.fixture
def mock_account(self):
@ -394,7 +332,6 @@ class TestForgotPasswordResetApi:
account.name = "Test User"
return account
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@ -405,7 +342,6 @@ class TestForgotPasswordResetApi:
mock_get_account,
mock_revoke_token,
mock_get_data,
mock_wraps_db,
app,
mock_account,
):
@ -418,7 +354,6 @@ class TestForgotPasswordResetApi:
- Success response is returned
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
mock_get_account.return_value = mock_account
mock_get_tenants.return_value = [MagicMock()]
@ -436,9 +371,8 @@ class TestForgotPasswordResetApi:
assert response["result"] == "success"
mock_revoke_token.assert_called_once_with("valid_token")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_mismatch(self, mock_get_data, mock_db, app):
def test_reset_password_mismatch(self, mock_get_data, app):
"""
Test password reset with mismatched passwords.
@ -447,7 +381,6 @@ class TestForgotPasswordResetApi:
- No password update occurs
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
# Act & Assert
@ -460,9 +393,8 @@ class TestForgotPasswordResetApi:
with pytest.raises(PasswordMismatchError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_invalid_token(self, mock_get_data, mock_db, app):
def test_reset_password_invalid_token(self, mock_get_data, app):
"""
Test password reset with invalid token.
@ -470,7 +402,6 @@ class TestForgotPasswordResetApi:
- InvalidTokenError is raised for invalid/expired tokens
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = None
# Act & Assert
@ -483,9 +414,8 @@ class TestForgotPasswordResetApi:
with pytest.raises(InvalidTokenError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_wrong_phase(self, mock_get_data, mock_db, app):
def test_reset_password_wrong_phase(self, mock_get_data, app):
"""
Test password reset with token not in reset phase.
@ -494,7 +424,6 @@ class TestForgotPasswordResetApi:
- Prevents use of verification-phase tokens for reset
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "phase": "verify"}
# Act & Assert
@ -507,13 +436,10 @@ class TestForgotPasswordResetApi:
with pytest.raises(InvalidTokenError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
def test_reset_password_account_not_found(
self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app
):
def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app):
"""
Test password reset for non-existent account.
@ -521,7 +447,6 @@ class TestForgotPasswordResetApi:
- AccountNotFound is raised when account doesn't exist
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"}
mock_get_account.return_value = None

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
from models.dataset import Dataset, Document
@ -55,7 +56,7 @@ class TestGetAvailableDatasetsIntegration:
name=f"Document {i}",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -112,7 +113,7 @@ class TestGetAvailableDatasetsIntegration:
created_from=DocumentCreatedFrom.WEB,
name=f"Archived Document {i}",
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=True, # Archived
@ -165,7 +166,7 @@ class TestGetAvailableDatasetsIntegration:
created_from=DocumentCreatedFrom.WEB,
name=f"Disabled Document {i}",
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
indexing_status=IndexingStatus.COMPLETED,
enabled=False, # Disabled
archived=False,
@ -218,7 +219,7 @@ class TestGetAvailableDatasetsIntegration:
created_from=DocumentCreatedFrom.WEB,
name=f"Document {status}",
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
indexing_status=status, # Not completed
enabled=True,
archived=False,
@ -336,7 +337,7 @@ class TestGetAvailableDatasetsIntegration:
created_from=DocumentCreatedFrom.WEB,
name=f"Document for {dataset.name}",
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
@ -416,7 +417,7 @@ class TestGetAvailableDatasetsIntegration:
created_from=DocumentCreatedFrom.WEB,
name=f"Document {i}",
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
@ -476,7 +477,7 @@ class TestKnowledgeRetrievalIntegration:
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db_session_with_containers.add(document)
db_session_with_containers.commit()

View File

@ -27,7 +27,7 @@ from models.human_input import (
HumanInputFormRecipient,
RecipientType,
)
from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.sqlalchemy_api_workflow_run_repository import (
DifyAPISQLAlchemyWorkflowRunRepository,
@ -218,7 +218,7 @@ class TestDeleteRunsWithRelated:
app_id=test_scope.app_id,
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=test_scope.user_id,
)
@ -278,7 +278,7 @@ class TestCountRunsWithRelated:
app_id=test_scope.app_id,
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=test_scope.user_id,
)

View File

@ -13,6 +13,7 @@ from uuid import uuid4
import pytest
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from models import Account
from models.dataset import Dataset, Document
@ -91,7 +92,7 @@ class DocumentStatusTestDataFactory:
name=name,
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
document.id = document_id
document.indexing_status = indexing_status

View File

@ -6,6 +6,7 @@ import pytest
from core.errors.error import QuotaExceededError
from models import TenantCreditPool
from models.enums import ProviderQuotaType
from services.credit_pool_service import CreditPoolService
@ -20,7 +21,7 @@ class TestCreditPoolService:
assert isinstance(pool, TenantCreditPool)
assert pool.tenant_id == tenant_id
assert pool.pool_type == "trial"
assert pool.pool_type == ProviderQuotaType.TRIAL
assert pool.quota_used == 0
assert pool.quota_limit > 0
@ -28,14 +29,14 @@ class TestCreditPoolService:
tenant_id = self._create_tenant_id()
CreditPoolService.create_default_pool(tenant_id)
result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type="trial")
result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL)
assert result is not None
assert result.tenant_id == tenant_id
assert result.pool_type == "trial"
assert result.pool_type == ProviderQuotaType.TRIAL
def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers):
result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type="trial")
result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL)
assert result is None

View File

@ -11,6 +11,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -106,7 +107,7 @@ class DatasetServiceIntegrationDataFactory:
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
indexing_status=IndexingStatus.COMPLETED,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()

View File

@ -13,6 +13,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.dataset_service import DocumentService
@ -79,7 +80,7 @@ class DocumentBatchUpdateIntegrationDataFactory:
name=name,
created_from=DocumentCreatedFrom.WEB,
created_by=created_by or str(uuid4()),
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
document.id = document_id or str(uuid4())
document.enabled = enabled

View File

@ -3,6 +3,7 @@
from unittest.mock import patch
from uuid import uuid4
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom
@ -78,7 +79,7 @@ class DatasetDeleteIntegrationDataFactory:
tenant_id: str,
dataset_id: str,
created_by: str,
doc_form: str = "text_model",
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
) -> Document:
"""Persist a document so dataset.doc_form resolves through the real document path."""
document = Document(
@ -119,7 +120,7 @@ class TestDatasetServiceDeleteDataset:
tenant_id=tenant.id,
dataset_id=dataset.id,
created_by=owner.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
# Act

View File

@ -3,6 +3,7 @@ from uuid import uuid4
from sqlalchemy import select
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from services.dataset_service import DocumentService
@ -42,7 +43,7 @@ def _create_document(
name=f"doc-{uuid4()}",
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
document.id = str(uuid4())
document.indexing_status = indexing_status
@ -142,3 +143,11 @@ def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_c
rows = db_session_with_containers.scalars(filtered).all()
assert {row.id for row in rows} == {doc1.id, doc2.id}
def test_normalize_display_status_alias_mapping():
    """normalize_display_status maps aliases to canonical values and rejects unknowns."""
    # (raw input, expected normalized value) pairs:
    # case-insensitive alias, alternate alias, canonical passthrough, invalid input.
    cases = [
        ("ACTIVE", "available"),
        ("enabled", "available"),
        ("archived", "archived"),
        ("unknown", None),
    ]
    for raw_status, expected in cases:
        assert DocumentService.normalize_display_status(raw_status) == expected

View File

@ -7,6 +7,7 @@ from uuid import uuid4
import pytest
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from models import Account
from models.dataset import Dataset, Document
@ -69,7 +70,7 @@ def make_document(
name=name,
created_from=DocumentCreatedFrom.WEB,
created_by=str(uuid4()),
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
doc.id = document_id
doc.indexing_status = "completed"

View File

@ -8,6 +8,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from dify_graph.file.enums import FileType
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -253,7 +254,7 @@ class TestMessagesCleanServiceIntegration:
# MessageFile
file = MessageFile(
message_id=message.id,
type="image",
type=FileType.IMAGE,
transfer_method="local_file",
url="http://example.com/test.jpg",
belongs_to=MessageFileBelongsTo.USER,

View File

@ -5,6 +5,7 @@ from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexStructureType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document
from models.enums import DatasetMetadataType, DataSourceType, DocumentCreatedFrom
@ -139,7 +140,7 @@ class TestMetadataService:
name=fake.file_name(),
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
)

View File

@ -0,0 +1,174 @@
"""Testcontainers integration tests for OAuthServerService."""
from __future__ import annotations
import uuid
from typing import cast
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from werkzeug.exceptions import BadRequest
from models.model import OAuthProviderApp
from services.oauth_server import (
OAUTH_ACCESS_TOKEN_EXPIRES_IN,
OAUTH_ACCESS_TOKEN_REDIS_KEY,
OAUTH_AUTHORIZATION_CODE_REDIS_KEY,
OAUTH_REFRESH_TOKEN_EXPIRES_IN,
OAUTH_REFRESH_TOKEN_REDIS_KEY,
OAuthGrantType,
OAuthServerService,
)
class TestOAuthServerServiceGetProviderApp:
    """DB-backed tests for get_oauth_provider_app."""

    def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp:
        """Persist a minimal OAuthProviderApp row for the given client_id and return it."""
        provider_app = OAuthProviderApp(
            app_icon="icon.png",
            client_id=client_id,
            client_secret=str(uuid4()),
            app_label={"en-US": "Test OAuth App"},
            redirect_uris=["https://example.com/callback"],
            scope="read",
        )
        db_session_with_containers.add(provider_app)
        db_session_with_containers.commit()
        return provider_app

    def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers):
        """A persisted provider app is found by its client_id."""
        client_id = f"client-{uuid4()}"
        persisted = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id)

        fetched = OAuthServerService.get_oauth_provider_app(client_id)

        assert fetched is not None
        assert fetched.client_id == client_id
        assert fetched.id == persisted.id

    def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers):
        """Looking up an unknown client_id yields None rather than raising."""
        fetched = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}")
        assert fetched is None
class TestOAuthServerServiceTokenOperations:
    """Redis-backed tests for token sign/validate operations.

    Redis access is mocked via the ``mock_redis`` fixture, so these tests
    verify the service's key construction, TTLs, and control flow rather
    than real Redis behavior.
    """

    @pytest.fixture
    def mock_redis(self):
        # Replace the module-level redis client used by OAuthServerService.
        with patch("services.oauth_server.redis_client") as mock:
            yield mock

    def test_sign_authorization_code_stores_and_returns_code(self, mock_redis):
        # Pin uuid4 so the generated authorization code is predictable.
        deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111")
        with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
            code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1")
        # The returned code is the stringified UUID.
        assert code == str(deterministic_uuid)
        # The code is cached under the per-client key with a 600-second TTL.
        mock_redis.set.assert_called_once_with(
            OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=code),
            "user-1",
            ex=600,
        )

    def test_sign_access_token_raises_bad_request_for_invalid_code(self, mock_redis):
        # A missing Redis entry means the authorization code is unknown.
        mock_redis.get.return_value = None
        with pytest.raises(BadRequest, match="invalid code"):
            OAuthServerService.sign_oauth_access_token(
                grant_type=OAuthGrantType.AUTHORIZATION_CODE,
                code="bad-code",
                client_id="client-1",
            )

    def test_sign_access_token_issues_tokens_for_valid_code(self, mock_redis):
        # Two UUIDs are consumed: first the access token, then the refresh token.
        token_uuids = [
            uuid.UUID("00000000-0000-0000-0000-000000000201"),
            uuid.UUID("00000000-0000-0000-0000-000000000202"),
        ]
        with patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids):
            mock_redis.get.return_value = b"user-1"
            access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
                grant_type=OAuthGrantType.AUTHORIZATION_CODE,
                code="code-1",
                client_id="client-1",
            )
        assert access_token == str(token_uuids[0])
        assert refresh_token == str(token_uuids[1])
        # The one-time authorization code is consumed (deleted) on use.
        code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1")
        mock_redis.delete.assert_called_once_with(code_key)
        # Both tokens are cached with their respective expirations.
        mock_redis.set.assert_any_call(
            OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
            b"user-1",
            ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
        )
        mock_redis.set.assert_any_call(
            OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token),
            b"user-1",
            ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
        )

    def test_sign_access_token_raises_bad_request_for_invalid_refresh_token(self, mock_redis):
        # An unknown refresh token is rejected with BadRequest.
        mock_redis.get.return_value = None
        with pytest.raises(BadRequest, match="invalid refresh token"):
            OAuthServerService.sign_oauth_access_token(
                grant_type=OAuthGrantType.REFRESH_TOKEN,
                refresh_token="stale-token",
                client_id="client-1",
            )

    def test_sign_access_token_issues_new_token_for_valid_refresh(self, mock_redis):
        deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301")
        with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
            mock_redis.get.return_value = b"user-1"
            access_token, returned_refresh = OAuthServerService.sign_oauth_access_token(
                grant_type=OAuthGrantType.REFRESH_TOKEN,
                refresh_token="refresh-1",
                client_id="client-1",
            )
        # A fresh access token is issued; the refresh token is returned unchanged.
        assert access_token == str(deterministic_uuid)
        assert returned_refresh == "refresh-1"

    def test_sign_access_token_returns_none_for_unknown_grant_type(self, mock_redis):
        # cast() only satisfies typing; the runtime value is an invalid grant type.
        grant_type = cast(OAuthGrantType, "invalid-grant-type")
        result = OAuthServerService.sign_oauth_access_token(grant_type=grant_type, client_id="client-1")
        assert result is None

    def test_sign_refresh_token_stores_with_expected_expiry(self, mock_redis):
        deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401")
        with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
            refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2")
        assert refresh_token == str(deterministic_uuid)
        # Stored under the per-client refresh-token key with the refresh expiry.
        mock_redis.set.assert_called_once_with(
            OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token),
            "user-2",
            ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
        )

    def test_validate_access_token_returns_none_when_not_found(self, mock_redis):
        # Unknown access tokens validate to None, not an exception.
        mock_redis.get.return_value = None
        result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token")
        assert result is None

    def test_validate_access_token_loads_user_when_exists(self, mock_redis):
        # The stored user id (bytes from Redis) is passed to AccountService.load_user
        # as the decoded string "user-88", per the assertion below.
        mock_redis.get.return_value = b"user-88"
        expected_user = MagicMock()
        with patch("services.oauth_server.AccountService.load_user", return_value=expected_user) as mock_load:
            result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
        assert result is expected_user
        mock_load.assert_called_once_with("user-88")

View File

@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
from dify_graph.entities.workflow_execution import WorkflowExecutionStatus
from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun
from models.enums import CreatorUserRole
from models.workflow import WorkflowAppLogCreatedFrom
from services.account_service import AccountService, TenantService
# Delay import of AppService to avoid circular dependency
@ -221,7 +222,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -357,7 +358,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run_1.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -399,7 +400,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run_2.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -441,7 +442,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run_4.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -521,7 +522,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -627,7 +628,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -732,7 +733,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -860,7 +861,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -902,7 +903,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="web-app",
created_from=WorkflowAppLogCreatedFrom.WEB_APP,
created_by_role=CreatorUserRole.END_USER,
created_by=end_user.id,
)
@ -1037,7 +1038,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -1125,7 +1126,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -1279,7 +1280,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -1379,7 +1380,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -1481,7 +1482,7 @@ class TestWorkflowAppService:
app_id=app.id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)

View File

@ -536,3 +536,151 @@ class TestApiToolManageService:
# Verify mock interactions
mock_external_service_dependencies["encrypter"].assert_called_once()
mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()
def test_delete_api_tool_provider_success(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test successful deletion of an API tool provider."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
schema = self._create_test_openapi_schema()
provider_name = fake.unique.word()
ApiToolManageService.create_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
icon={"content": "🔧", "background": "#FFF"},
credentials={"auth_type": "none"},
schema_type=ApiProviderSchemaType.OPENAPI,
schema=schema,
privacy_policy="",
custom_disclaimer="",
labels=[],
)
provider = (
db_session_with_containers.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
.first()
)
assert provider is not None
result = ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, provider_name)
assert result == {"result": "success"}
deleted = (
db_session_with_containers.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
.first()
)
assert deleted is None
def test_delete_api_tool_provider_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test deletion raises ValueError when provider not found."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="you have not added provider"):
ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent")
def test_update_api_tool_provider_not_found(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test update raises ValueError when original provider not found."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="does not exists"):
ApiToolManageService.update_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name="new-name",
original_provider="nonexistent",
icon={},
credentials={"auth_type": "none"},
_schema_type=ApiProviderSchemaType.OPENAPI,
schema=self._create_test_openapi_schema(),
privacy_policy=None,
custom_disclaimer="",
labels=[],
)
def test_update_api_tool_provider_missing_auth_type(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test update raises ValueError when auth_type is missing from credentials."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
schema = self._create_test_openapi_schema()
provider_name = fake.unique.word()
ApiToolManageService.create_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
icon={"content": "🔧", "background": "#FFF"},
credentials={"auth_type": "none"},
schema_type=ApiProviderSchemaType.OPENAPI,
schema=schema,
privacy_policy="",
custom_disclaimer="",
labels=[],
)
with pytest.raises(ValueError, match="auth_type is required"):
ApiToolManageService.update_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
original_provider=provider_name,
icon={},
credentials={},
_schema_type=ApiProviderSchemaType.OPENAPI,
schema=schema,
privacy_policy=None,
custom_disclaimer="",
labels=[],
)
def test_list_api_tool_provider_tools_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test listing tools raises ValueError when provider not found."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="you have not added provider"):
ApiToolManageService.list_api_tool_provider_tools(account.id, tenant.id, "nonexistent")
def test_test_api_tool_preview_invalid_schema_type(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test preview raises ValueError for invalid schema type."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="invalid schema type"):
ApiToolManageService.test_api_tool_preview(
tenant_id=tenant.id,
provider_name="provider-a",
tool_name="tool-a",
credentials={"auth_type": "none"},
parameters={},
schema_type="bad-schema-type",
schema="schema",
)

View File

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolProviderType
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.plugin.plugin_service import PluginService
from services.tools.tools_transform_service import ToolTransformService
@ -52,7 +52,7 @@ class TestToolTransformService:
user_id="test_user_id",
credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
tools_str="[]",
)
elif provider_type == "builtin":
@ -659,7 +659,7 @@ class TestToolTransformService:
user_id=fake.uuid4(),
credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
tools_str="[]",
)
@ -695,7 +695,7 @@ class TestToolTransformService:
user_id=fake.uuid4(),
credentials_str='{"auth_type": "api_key_query", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
tools_str="[]",
)
@ -731,7 +731,7 @@ class TestToolTransformService:
user_id=fake.uuid4(),
credentials_str='{"auth_type": "api_key", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
tools_str="[]",
)

View File

@ -1043,3 +1043,112 @@ class TestWorkflowToolManageService:
# After the fix, this should always be 0
# For now, we document that the record may exist, demonstrating the bug
# assert tool_count == 0 # Expected after fix
def test_delete_workflow_tool_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test successful deletion of a workflow tool."""
fake = Faker()
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
tool_name = fake.unique.word()
WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=tool_name,
label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200),
parameters=self._create_test_workflow_tool_parameters(),
)
tool = (
db_session_with_containers.query(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.name == tool_name)
.first()
)
assert tool is not None
result = WorkflowToolManageService.delete_workflow_tool(account.id, account.current_tenant.id, tool.id)
assert result == {"result": "success"}
deleted = (
db_session_with_containers.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool.id).first()
)
assert deleted is None
def test_list_tenant_workflow_tools_empty(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test listing workflow tools when none exist returns empty list."""
fake = Faker()
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
result = WorkflowToolManageService.list_tenant_workflow_tools(account.id, account.current_tenant.id)
assert result == []
def test_get_workflow_tool_by_tool_id_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test that get_workflow_tool_by_tool_id raises ValueError when tool not found."""
fake = Faker()
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="Tool not found"):
WorkflowToolManageService.get_workflow_tool_by_tool_id(account.id, account.current_tenant.id, fake.uuid4())
def test_get_workflow_tool_by_app_id_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test that get_workflow_tool_by_app_id raises ValueError when tool not found."""
fake = Faker()
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="Tool not found"):
WorkflowToolManageService.get_workflow_tool_by_app_id(account.id, account.current_tenant.id, fake.uuid4())
def test_list_single_workflow_tools_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test that list_single_workflow_tools raises ValueError when tool not found."""
fake = Faker()
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="not found"):
WorkflowToolManageService.list_single_workflow_tools(account.id, account.current_tenant.id, fake.uuid4())
def test_create_workflow_tool_with_labels(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test that labels are forwarded to ToolLabelManager when provided."""
fake = Faker()
app, account, workflow = self._create_test_app_and_account(
db_session_with_containers, mock_external_service_dependencies
)
result = WorkflowToolManageService.create_workflow_tool(
user_id=account.id,
tenant_id=account.current_tenant.id,
workflow_app_id=app.id,
name=fake.unique.word(),
label=fake.word(),
icon={"type": "emoji", "emoji": "🔧"},
description=fake.text(max_nb_chars=200),
parameters=self._create_test_workflow_tool_parameters(),
labels=["label-1", "label-2"],
)
assert result == {"result": "success"}
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once()

View File

@ -13,6 +13,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -152,7 +153,7 @@ class TestBatchCleanDocumentTask:
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
indexing_status=IndexingStatus.COMPLETED,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db_session_with_containers.add(document)
@ -392,7 +393,12 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Execute the task with non-existent dataset
batch_clean_document_task(document_ids=[document_id], dataset_id=dataset_id, doc_form="text_model", file_ids=[])
batch_clean_document_task(
document_ids=[document_id],
dataset_id=dataset_id,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
file_ids=[],
)
# Verify that no index processing occurred
mock_external_service_dependencies["index_processor"].clean.assert_not_called()
@ -525,7 +531,11 @@ class TestBatchCleanDocumentTask:
account = self._create_test_account(db_session_with_containers)
# Test different doc_form types
doc_forms = ["text_model", "qa_model", "hierarchical_model"]
doc_forms = [
IndexStructureType.PARAGRAPH_INDEX,
IndexStructureType.QA_INDEX,
IndexStructureType.PARENT_CHILD_INDEX,
]
for doc_form in doc_forms:
dataset = self._create_test_dataset(db_session_with_containers, account)

View File

@ -19,6 +19,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -179,7 +180,7 @@ class TestBatchCreateSegmentToIndexTask:
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
word_count=0,
)
@ -221,17 +222,17 @@ class TestBatchCreateSegmentToIndexTask:
return upload_file
def _create_test_csv_content(self, content_type="text_model"):
def _create_test_csv_content(self, content_type=IndexStructureType.PARAGRAPH_INDEX):
"""
Helper method to create test CSV content.
Args:
content_type: Type of content to create ("text_model" or "qa_model")
content_type: Type of content to create (IndexStructureType.PARAGRAPH_INDEX or IndexStructureType.QA_INDEX)
Returns:
str: CSV content as string
"""
if content_type == "qa_model":
if content_type == IndexStructureType.QA_INDEX:
csv_content = "content,answer\n"
csv_content += "This is the first segment content,This is the first answer\n"
csv_content += "This is the second segment content,This is the second answer\n"
@ -264,7 +265,7 @@ class TestBatchCreateSegmentToIndexTask:
upload_file = self._create_test_upload_file(db_session_with_containers, account, tenant)
# Create CSV content
csv_content = self._create_test_csv_content("text_model")
csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX)
# Mock storage to return our CSV content
mock_storage = mock_external_service_dependencies["storage"]
@ -451,7 +452,7 @@ class TestBatchCreateSegmentToIndexTask:
indexing_status=IndexingStatus.COMPLETED,
enabled=False, # Document is disabled
archived=False,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
word_count=0,
),
# Archived document
@ -467,7 +468,7 @@ class TestBatchCreateSegmentToIndexTask:
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=True, # Document is archived
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
word_count=0,
),
# Document with incomplete indexing
@ -483,7 +484,7 @@ class TestBatchCreateSegmentToIndexTask:
indexing_status=IndexingStatus.INDEXING, # Not completed
enabled=True,
archived=False,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
word_count=0,
),
]
@ -655,7 +656,7 @@ class TestBatchCreateSegmentToIndexTask:
db_session_with_containers.commit()
# Create CSV content
csv_content = self._create_test_csv_content("text_model")
csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX)
# Mock storage to return our CSV content
mock_storage = mock_external_service_dependencies["storage"]

View File

@ -18,6 +18,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
@ -192,7 +193,7 @@ class TestCleanDatasetTask:
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
doc_form="paragraph_index",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
word_count=100,
created_at=datetime.now(),
updated_at=datetime.now(),

View File

@ -12,6 +12,7 @@ from unittest.mock import Mock, patch
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from services.account_service import AccountService, TenantService
@ -114,7 +115,7 @@ class TestCleanNotionDocumentTask:
name=f"Notion Page {i}",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model", # Set doc_form to ensure dataset.doc_form works
doc_form=IndexStructureType.PARAGRAPH_INDEX, # Set doc_form to ensure dataset.doc_form works
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
)
@ -261,7 +262,7 @@ class TestCleanNotionDocumentTask:
# Test different index types
# Note: Only testing text_model to avoid dependency on external services
index_types = ["text_model"]
index_types = [IndexStructureType.PARAGRAPH_INDEX]
for index_type in index_types:
# Create dataset (doc_form will be set via document creation)

View File

@ -12,6 +12,7 @@ from uuid import uuid4
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -141,7 +142,7 @@ class TestCreateSegmentToIndexTask:
enabled=True,
archived=False,
indexing_status=IndexingStatus.COMPLETED,
doc_form="qa_model",
doc_form=IndexStructureType.QA_INDEX,
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
@ -301,7 +302,7 @@ class TestCreateSegmentToIndexTask:
enabled=True,
archived=False,
indexing_status=IndexingStatus.COMPLETED,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
@ -552,7 +553,11 @@ class TestCreateSegmentToIndexTask:
- Processing completes successfully for different forms
"""
# Arrange: Test different doc_forms
doc_forms = ["qa_model", "text_model", "web_model"]
doc_forms = [
IndexStructureType.QA_INDEX,
IndexStructureType.PARAGRAPH_INDEX,
IndexStructureType.PARAGRAPH_INDEX,
]
for doc_form in doc_forms:
# Create fresh test data for each form

View File

@ -12,6 +12,7 @@ from unittest.mock import ANY, Mock, patch
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from services.account_service import AccountService, TenantService
@ -107,7 +108,7 @@ class TestDealDatasetVectorIndexTask:
name="Document for doc_form",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -167,7 +168,7 @@ class TestDealDatasetVectorIndexTask:
name="Document for doc_form",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -187,7 +188,7 @@ class TestDealDatasetVectorIndexTask:
name="Test Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -268,7 +269,7 @@ class TestDealDatasetVectorIndexTask:
name="Document for doc_form",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="parent_child_index",
doc_form=IndexStructureType.PARENT_CHILD_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -288,7 +289,7 @@ class TestDealDatasetVectorIndexTask:
name="Test Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="parent_child_index",
doc_form=IndexStructureType.PARENT_CHILD_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -416,7 +417,7 @@ class TestDealDatasetVectorIndexTask:
name="Test Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -505,7 +506,7 @@ class TestDealDatasetVectorIndexTask:
name="Document for doc_form",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -525,7 +526,7 @@ class TestDealDatasetVectorIndexTask:
name="Test Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -601,7 +602,7 @@ class TestDealDatasetVectorIndexTask:
name="Test Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="qa_index",
doc_form=IndexStructureType.QA_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -638,7 +639,7 @@ class TestDealDatasetVectorIndexTask:
assert updated_document.indexing_status == IndexingStatus.COMPLETED
# Verify index processor was initialized with custom index type
mock_index_processor_factory.assert_called_once_with("qa_index")
mock_index_processor_factory.assert_called_once_with(IndexStructureType.QA_INDEX)
mock_factory = mock_index_processor_factory.return_value
mock_processor = mock_factory.init_index_processor.return_value
mock_processor.load.assert_called_once()
@ -677,7 +678,7 @@ class TestDealDatasetVectorIndexTask:
name="Test Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -714,7 +715,7 @@ class TestDealDatasetVectorIndexTask:
assert updated_document.indexing_status == IndexingStatus.COMPLETED
# Verify index processor was initialized with the document's index type
mock_index_processor_factory.assert_called_once_with("text_model")
mock_index_processor_factory.assert_called_once_with(IndexStructureType.PARAGRAPH_INDEX)
mock_factory = mock_index_processor_factory.return_value
mock_processor = mock_factory.init_index_processor.return_value
mock_processor.load.assert_called_once()
@ -753,7 +754,7 @@ class TestDealDatasetVectorIndexTask:
name="Document for doc_form",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -775,7 +776,7 @@ class TestDealDatasetVectorIndexTask:
name=f"Test Document {i}",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -856,7 +857,7 @@ class TestDealDatasetVectorIndexTask:
name="Document for doc_form",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -876,7 +877,7 @@ class TestDealDatasetVectorIndexTask:
name="Test Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -953,7 +954,7 @@ class TestDealDatasetVectorIndexTask:
name="Document for doc_form",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -973,7 +974,7 @@ class TestDealDatasetVectorIndexTask:
name="Enabled Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -992,7 +993,7 @@ class TestDealDatasetVectorIndexTask:
name="Disabled Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=False, # This document should be skipped
@ -1074,7 +1075,7 @@ class TestDealDatasetVectorIndexTask:
name="Document for doc_form",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -1094,7 +1095,7 @@ class TestDealDatasetVectorIndexTask:
name="Active Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -1113,7 +1114,7 @@ class TestDealDatasetVectorIndexTask:
name="Archived Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -1195,7 +1196,7 @@ class TestDealDatasetVectorIndexTask:
name="Document for doc_form",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -1215,7 +1216,7 @@ class TestDealDatasetVectorIndexTask:
name="Completed Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@ -1234,7 +1235,7 @@ class TestDealDatasetVectorIndexTask:
name="Incomplete Document",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.INDEXING, # This document should be skipped
enabled=True,

View File

@ -15,6 +15,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -113,7 +114,7 @@ class TestDisableSegmentFromIndexTask:
dataset: Dataset,
tenant: Tenant,
account: Account,
doc_form: str = "text_model",
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
) -> Document:
"""
Helper method to create a test document.
@ -476,7 +477,11 @@ class TestDisableSegmentFromIndexTask:
- Index processor clean method is called correctly
"""
# Test different document forms
doc_forms = ["text_model", "qa_model", "table_model"]
doc_forms = [
IndexStructureType.PARAGRAPH_INDEX,
IndexStructureType.QA_INDEX,
IndexStructureType.PARENT_CHILD_INDEX,
]
for doc_form in doc_forms:
# Arrange: Create test data for each form

View File

@ -11,6 +11,7 @@ from unittest.mock import MagicMock, patch
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from models import Account, Dataset, DocumentSegment
from models import Document as DatasetDocument
from models.dataset import DatasetProcessRule
@ -153,7 +154,7 @@ class TestDisableSegmentsFromIndexTask:
document.indexing_status = "completed"
document.enabled = True
document.archived = False
document.doc_form = "text_model" # Use text_model form for testing
document.doc_form = IndexStructureType.PARAGRAPH_INDEX # Use text_model form for testing
document.doc_language = "en"
db_session_with_containers.add(document)
db_session_with_containers.commit()
@ -500,7 +501,11 @@ class TestDisableSegmentsFromIndexTask:
segment_ids = [segment.id for segment in segments]
# Test different document forms
doc_forms = ["text_model", "qa_model", "hierarchical_model"]
doc_forms = [
IndexStructureType.PARAGRAPH_INDEX,
IndexStructureType.QA_INDEX,
IndexStructureType.PARENT_CHILD_INDEX,
]
for doc_form in doc_forms:
# Update document form

View File

@ -14,6 +14,7 @@ from uuid import uuid4
import pytest
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.constant.index_type import IndexStructureType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
@ -85,7 +86,7 @@ class DocumentIndexingSyncTaskTestDataFactory:
created_by=created_by,
indexing_status=indexing_status,
enabled=True,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
)
db_session_with_containers.add(document)

View File

@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexStructureType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
@ -80,7 +81,7 @@ class TestDocumentIndexingUpdateTask:
created_by=account.id,
indexing_status=IndexingStatus.WAITING,
enabled=True,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db_session_with_containers.add(document)
db_session_with_containers.commit()

View File

@ -4,6 +4,7 @@ import pytest
from faker import Faker
from core.indexing_runner import DocumentIsPausedError
from core.rag.index_processor.constant.index_type import IndexStructureType
from enums.cloud_plan import CloudPlan
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -130,7 +131,7 @@ class TestDuplicateDocumentIndexingTasks:
created_by=account.id,
indexing_status=IndexingStatus.WAITING,
enabled=True,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db_session_with_containers.add(document)
documents.append(document)
@ -265,7 +266,7 @@ class TestDuplicateDocumentIndexingTasks:
created_by=account.id,
indexing_status=IndexingStatus.WAITING,
enabled=True,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db_session_with_containers.add(document)
documents.append(document)
@ -524,7 +525,7 @@ class TestDuplicateDocumentIndexingTasks:
created_by=dataset.created_by,
indexing_status=IndexingStatus.WAITING,
enabled=True,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db_session_with_containers.add(document)
extra_documents.append(document)

View File

@ -281,12 +281,10 @@ class TestSiteEndpoints:
method = _unwrap(api.post)
site = MagicMock()
query = MagicMock()
query.where.return_value.first.return_value = site
monkeypatch.setattr(
site_module.db,
"session",
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None),
)
monkeypatch.setattr(
site_module,
@ -305,12 +303,10 @@ class TestSiteEndpoints:
method = _unwrap(api.post)
site = MagicMock()
query = MagicMock()
query.where.return_value.first.return_value = site
monkeypatch.setattr(
site_module.db,
"session",
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None),
)
monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code")
monkeypatch.setattr(

View File

@ -82,12 +82,8 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p
def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None:
conversation = SimpleNamespace(id="c1", app_id="app-1")
query = MagicMock()
query.where.return_value = query
query.first.return_value = conversation
session = MagicMock()
session.query.return_value = query
session.scalar.return_value = conversation
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
monkeypatch.setattr(conversation_module.db, "session", session)
@ -101,12 +97,8 @@ def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> No
def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
query = MagicMock()
query.where.return_value = query
query.first.return_value = None
session = MagicMock()
session.query.return_value = query
session.scalar.return_value = None
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
monkeypatch.setattr(conversation_module.db, "session", session)

View File

@ -24,7 +24,7 @@ def test_get_conversation_mark_read_keeps_updated_at_unchanged():
),
patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session,
):
mock_session.query.return_value.where.return_value.first.return_value = conversation
mock_session.scalar.return_value = conversation
_get_conversation(app_model, "conversation-id")

View File

@ -73,8 +73,7 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: None))
with app.test_request_context(
"/console/api/instruction-generate",
@ -99,8 +98,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
app_model = SimpleNamespace(id="app-1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
_install_workflow_service(monkeypatch, workflow=None)
with app.test_request_context(
@ -126,8 +124,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
app_model = SimpleNamespace(id="app-1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
workflow = SimpleNamespace(graph_dict={"nodes": []})
_install_workflow_service(monkeypatch, workflow=workflow)
@ -155,8 +152,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) ->
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
app_model = SimpleNamespace(id="app-1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
workflow = SimpleNamespace(
graph_dict={

View File

@ -92,10 +92,7 @@ def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatc
)
session = MagicMock()
query = MagicMock()
query.where.return_value = query
query.first.return_value = original_config
session.query.return_value = query
session.get.return_value = original_config
monkeypatch.setattr(model_config_module.db, "session", session)
monkeypatch.setattr(

View File

@ -11,10 +11,8 @@ from models.model import AppMode
def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model))
@wraps_module.get_app_model
def handler(app_model):
@ -25,10 +23,8 @@ def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None:
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model))
@wraps_module.get_app_model(mode=[AppMode.COMPLETION])
def handler(app_model):

View File

@ -11,6 +11,7 @@ from controllers.console.datasets.data_source import (
DataSourceNotionDocumentSyncApi,
DataSourceNotionListApi,
)
from core.rag.index_processor.constant.index_type import IndexStructureType
def unwrap(func):
@ -343,7 +344,7 @@ class TestDataSourceNotionApi:
}
],
"process_rule": {"rules": {}},
"doc_form": "text_model",
"doc_form": IndexStructureType.PARAGRAPH_INDEX,
"doc_language": "English",
}

View File

@ -28,6 +28,7 @@ from controllers.console.datasets.datasets import (
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.provider_manager import ProviderManager
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from models.enums import CreatorUserRole
from models.model import ApiToken, UploadFile
@ -1146,7 +1147,7 @@ class TestDatasetIndexingEstimateApi:
},
"process_rule": {"chunk_size": 100},
"indexing_technique": "high_quality",
"doc_form": "text_model",
"doc_form": IndexStructureType.PARAGRAPH_INDEX,
"doc_language": "English",
"dataset_id": None,
}

View File

@ -30,6 +30,7 @@ from controllers.console.datasets.error import (
InvalidActionError,
InvalidMetadataError,
)
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.enums import DataSourceType, IndexingStatus
@ -66,7 +67,7 @@ def document():
indexing_status=IndexingStatus.INDEXING,
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info_dict={"upload_file_id": "file-1"},
doc_form="text",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
archived=False,
is_paused=False,
dataset_process_rule=None,
@ -765,8 +766,8 @@ class TestDocumentGenerateSummaryApi:
summary_index_setting={"enable": True},
)
doc1 = MagicMock(id="doc-1", doc_form="qa_model")
doc2 = MagicMock(id="doc-2", doc_form="text")
doc1 = MagicMock(id="doc-1", doc_form=IndexStructureType.QA_INDEX)
doc2 = MagicMock(id="doc-2", doc_form=IndexStructureType.PARAGRAPH_INDEX)
payload = {"document_list": ["doc-1", "doc-2"]}
@ -822,7 +823,7 @@ class TestDocumentIndexingEstimateApi:
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info_dict={"upload_file_id": "file-1"},
tenant_id="tenant-1",
doc_form="text",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
dataset_process_rule=None,
)
@ -849,7 +850,7 @@ class TestDocumentIndexingEstimateApi:
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info_dict={"upload_file_id": "file-1"},
tenant_id="tenant-1",
doc_form="text",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
dataset_process_rule=None,
)
@ -973,7 +974,7 @@ class TestDocumentBatchIndexingEstimateApi:
"mode": "single",
"only_main_content": True,
},
doc_form="text",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
with (
@ -1001,7 +1002,7 @@ class TestDocumentBatchIndexingEstimateApi:
"notion_page_id": "p1",
"type": "page",
},
doc_form="text",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
with (
@ -1024,7 +1025,7 @@ class TestDocumentBatchIndexingEstimateApi:
indexing_status=IndexingStatus.INDEXING,
data_source_type="unknown",
data_source_info_dict={},
doc_form="text",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
with app.test_request_context("/"), patch.object(api, "get_batch_documents", return_value=[document]):
@ -1353,7 +1354,7 @@ class TestDocumentIndexingEdgeCases:
data_source_type=DataSourceType.UPLOAD_FILE,
data_source_info_dict={"upload_file_id": "file-1"},
tenant_id="tenant-1",
doc_form="text",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
dataset_process_rule=None,
)

View File

@ -24,6 +24,7 @@ from controllers.console.datasets.error import (
InvalidActionError,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile
@ -366,7 +367,7 @@ class TestDatasetDocumentSegmentAddApi:
dataset.indexing_technique = "economy"
document = MagicMock()
document.doc_form = "text"
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
segment = MagicMock()
segment.id = "seg-1"
@ -505,7 +506,7 @@ class TestDatasetDocumentSegmentUpdateApi:
dataset.indexing_technique = "economy"
document = MagicMock()
document.doc_form = "text"
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
segment = MagicMock()

View File

@ -8,6 +8,7 @@ from controllers.console.apikey import (
BaseApiKeyResource,
_get_resource,
)
from models.enums import ApiTokenType
@pytest.fixture
@ -45,14 +46,14 @@ def bypass_permissions():
class DummyApiKeyListResource(BaseApiKeyListResource):
resource_type = "app"
resource_type = ApiTokenType.APP
resource_model = MagicMock()
resource_id_field = "app_id"
token_prefix = "app-"
class DummyApiKeyResource(BaseApiKeyResource):
resource_type = "app"
resource_type = ApiTokenType.APP
resource_model = MagicMock()
resource_id_field = "app_id"

View File

@ -12,6 +12,7 @@ from unittest.mock import Mock
import pytest
from flask import Flask
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.account import TenantStatus
from models.model import App, AppMode, EndUser
from tests.unit_tests.conftest import setup_mock_tenant_account_query
@ -175,7 +176,7 @@ def mock_document():
document.name = "test_document.txt"
document.indexing_status = "completed"
document.enabled = True
document.doc_form = "text_model"
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
return document

View File

@ -31,6 +31,7 @@ from controllers.service_api.dataset.segment import (
SegmentCreatePayload,
SegmentListQuery,
)
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import ChildChunk, Dataset, Document, DocumentSegment
from models.enums import IndexingStatus
from services.dataset_service import DocumentService, SegmentService
@ -788,7 +789,7 @@ class TestSegmentApiGet:
# Arrange
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_doc_svc.get_document.return_value = Mock(doc_form="text_model")
mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
mock_seg_svc.get_segments.return_value = ([mock_segment], 1)
mock_marshal.return_value = [{"id": mock_segment.id}]
@ -903,7 +904,7 @@ class TestSegmentApiPost:
mock_doc = Mock()
mock_doc.indexing_status = "completed"
mock_doc.enabled = True
mock_doc.doc_form = "text_model"
mock_doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
mock_doc_svc.get_document.return_value = mock_doc
mock_seg_svc.segment_create_args_validate.return_value = None
@ -1091,7 +1092,7 @@ class TestDatasetSegmentApiDelete:
mock_doc = Mock()
mock_doc.indexing_status = "completed"
mock_doc.enabled = True
mock_doc.doc_form = "text_model"
mock_doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
mock_doc_svc.get_document.return_value = mock_doc
mock_seg_svc.get_segment_by_id.return_value = None # Segment not found
@ -1371,7 +1372,7 @@ class TestDatasetSegmentApiGetSingle:
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_dataset_svc.check_dataset_model_setting.return_value = None
mock_doc = Mock(doc_form="text_model")
mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
mock_doc_svc.get_document.return_value = mock_doc
mock_seg_svc.get_segment_by_id.return_value = mock_segment
mock_marshal.return_value = {"id": mock_segment.id}
@ -1390,7 +1391,7 @@ class TestDatasetSegmentApiGetSingle:
assert status == 200
assert "data" in response
assert response["doc_form"] == "text_model"
assert response["doc_form"] == IndexStructureType.PARAGRAPH_INDEX
@patch("controllers.service_api.dataset.segment.current_account_with_tenant")
@patch("controllers.service_api.dataset.segment.db")

View File

@ -35,6 +35,7 @@ from controllers.service_api.dataset.document import (
InvalidMetadataError,
)
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.enums import IndexingStatus
from services.dataset_service import DocumentService
from services.entities.knowledge_entities.knowledge_entities import ProcessRule, RetrievalModel
@ -52,7 +53,7 @@ class TestDocumentTextCreatePayload:
def test_payload_with_defaults(self):
"""Test payload default values."""
payload = DocumentTextCreatePayload(name="Doc", text="Content")
assert payload.doc_form == "text_model"
assert payload.doc_form == IndexStructureType.PARAGRAPH_INDEX
assert payload.doc_language == "English"
assert payload.process_rule is None
assert payload.indexing_technique is None
@ -62,14 +63,14 @@ class TestDocumentTextCreatePayload:
payload = DocumentTextCreatePayload(
name="Full Document",
text="Complete document content here",
doc_form="qa_model",
doc_form=IndexStructureType.QA_INDEX,
doc_language="Chinese",
indexing_technique="high_quality",
embedding_model="text-embedding-ada-002",
embedding_model_provider="openai",
)
assert payload.name == "Full Document"
assert payload.doc_form == "qa_model"
assert payload.doc_form == IndexStructureType.QA_INDEX
assert payload.doc_language == "Chinese"
assert payload.indexing_technique == "high_quality"
assert payload.embedding_model == "text-embedding-ada-002"
@ -147,8 +148,8 @@ class TestDocumentTextUpdate:
def test_payload_with_doc_form_update(self):
"""Test payload with doc_form update."""
payload = DocumentTextUpdate(doc_form="qa_model")
assert payload.doc_form == "qa_model"
payload = DocumentTextUpdate(doc_form=IndexStructureType.QA_INDEX)
assert payload.doc_form == IndexStructureType.QA_INDEX
def test_payload_with_language_update(self):
"""Test payload with doc_language update."""
@ -158,7 +159,7 @@ class TestDocumentTextUpdate:
def test_payload_default_values(self):
"""Test payload default values."""
payload = DocumentTextUpdate()
assert payload.doc_form == "text_model"
assert payload.doc_form == IndexStructureType.PARAGRAPH_INDEX
assert payload.doc_language == "English"
@ -272,14 +273,24 @@ class TestDocumentDocForm:
def test_text_model_form(self):
"""Test text_model form."""
doc_form = "text_model"
valid_forms = ["text_model", "qa_model", "hierarchical_model", "parent_child_model"]
doc_form = IndexStructureType.PARAGRAPH_INDEX
valid_forms = [
IndexStructureType.PARAGRAPH_INDEX,
IndexStructureType.QA_INDEX,
IndexStructureType.PARENT_CHILD_INDEX,
"parent_child_model",
]
assert doc_form in valid_forms
def test_qa_model_form(self):
"""Test qa_model form."""
doc_form = "qa_model"
valid_forms = ["text_model", "qa_model", "hierarchical_model", "parent_child_model"]
doc_form = IndexStructureType.QA_INDEX
valid_forms = [
IndexStructureType.PARAGRAPH_INDEX,
IndexStructureType.QA_INDEX,
IndexStructureType.PARENT_CHILD_INDEX,
"parent_child_model",
]
assert doc_form in valid_forms
@ -504,7 +515,7 @@ class TestDocumentApiGet:
doc.name = "test_document.txt"
doc.indexing_status = "completed"
doc.enabled = True
doc.doc_form = "text_model"
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
doc.doc_language = "English"
doc.doc_type = "book"
doc.doc_metadata_details = {"source": "upload"}

View File

@ -21,7 +21,7 @@ from sqlalchemy.orm import Session
from core.app.entities.task_entities import MessageEndStreamResponse
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from dify_graph.file.enums import FileTransferMethod
from dify_graph.file.enums import FileTransferMethod, FileType
from models.model import MessageFile, UploadFile
@ -51,7 +51,7 @@ class TestMessageEndStreamResponseFiles:
message_file.transfer_method = FileTransferMethod.LOCAL_FILE
message_file.upload_file_id = str(uuid.uuid4())
message_file.url = None
message_file.type = "image"
message_file.type = FileType.IMAGE
return message_file
@pytest.fixture
@ -63,7 +63,7 @@ class TestMessageEndStreamResponseFiles:
message_file.transfer_method = FileTransferMethod.REMOTE_URL
message_file.upload_file_id = None
message_file.url = "https://example.com/image.jpg"
message_file.type = "image"
message_file.type = FileType.IMAGE
return message_file
@pytest.fixture
@ -75,7 +75,7 @@ class TestMessageEndStreamResponseFiles:
message_file.transfer_method = FileTransferMethod.TOOL_FILE
message_file.upload_file_id = None
message_file.url = "tool_file_123.png"
message_file.type = "image"
message_file.type = FileType.IMAGE
return message_file
@pytest.fixture

View File

@ -4800,8 +4800,8 @@ class TestInternalHooksCoverage:
dataset_docs = [
SimpleNamespace(id="doc-a", doc_form=IndexStructureType.PARENT_CHILD_INDEX),
SimpleNamespace(id="doc-b", doc_form=IndexStructureType.PARENT_CHILD_INDEX),
SimpleNamespace(id="doc-c", doc_form="qa_model"),
SimpleNamespace(id="doc-d", doc_form="qa_model"),
SimpleNamespace(id="doc-c", doc_form=IndexStructureType.QA_INDEX),
SimpleNamespace(id="doc-d", doc_form=IndexStructureType.QA_INDEX),
]
child_chunks = [SimpleNamespace(index_node_id="idx-a", segment_id="seg-a")]
segments = [SimpleNamespace(index_node_id="idx-c", id="seg-c")]

View File

@ -238,7 +238,7 @@ class TestApiToolProviderValidation:
name=provider_name,
icon='{"type": "emoji", "value": "🔧"}',
schema=schema,
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
description="Custom API for testing",
tools_str=json.dumps(tools),
credentials_str=json.dumps(credentials),
@ -249,7 +249,7 @@ class TestApiToolProviderValidation:
assert api_provider.user_id == user_id
assert api_provider.name == provider_name
assert api_provider.schema == schema
assert api_provider.schema_type_str == "openapi"
assert api_provider.schema_type_str == ApiProviderSchemaType.OPENAPI
assert api_provider.description == "Custom API for testing"
def test_api_tool_provider_schema_type_property(self):
@ -261,7 +261,7 @@ class TestApiToolProviderValidation:
name="Test API",
icon="{}",
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
description="Test",
tools_str="[]",
credentials_str="{}",
@ -314,7 +314,7 @@ class TestApiToolProviderValidation:
name="Weather API",
icon="{}",
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
description="Weather API",
tools_str=json.dumps(tools_data),
credentials_str="{}",
@ -343,7 +343,7 @@ class TestApiToolProviderValidation:
name="Secure API",
icon="{}",
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
description="Secure API",
tools_str="[]",
credentials_str=json.dumps(credentials_data),
@ -369,7 +369,7 @@ class TestApiToolProviderValidation:
name="Privacy API",
icon="{}",
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
description="API with privacy policy",
tools_str="[]",
credentials_str="{}",
@ -391,7 +391,7 @@ class TestApiToolProviderValidation:
name="Disclaimer API",
icon="{}",
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
description="API with disclaimer",
tools_str="[]",
credentials_str="{}",
@ -410,7 +410,7 @@ class TestApiToolProviderValidation:
name="Default API",
icon="{}",
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
description="API",
tools_str="[]",
credentials_str="{}",
@ -432,7 +432,7 @@ class TestApiToolProviderValidation:
name=provider_name,
icon="{}",
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
description="Unique API",
tools_str="[]",
credentials_str="{}",
@ -454,7 +454,7 @@ class TestApiToolProviderValidation:
name="Public API",
icon="{}",
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
description="Public API with no auth",
tools_str="[]",
credentials_str=json.dumps(credentials),
@ -479,7 +479,7 @@ class TestApiToolProviderValidation:
name="Query Auth API",
icon="{}",
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
description="API with query auth",
tools_str="[]",
credentials_str=json.dumps(credentials),
@ -741,7 +741,7 @@ class TestCredentialStorage:
name="Test API",
icon="{}",
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
description="Test",
tools_str="[]",
credentials_str=json.dumps(credentials),
@ -788,7 +788,7 @@ class TestCredentialStorage:
name="Update Test",
icon="{}",
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
description="Test",
tools_str="[]",
credentials_str=json.dumps(original_credentials),
@ -897,7 +897,7 @@ class TestToolProviderRelationships:
name="User API",
icon="{}",
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
description="Test",
tools_str="[]",
credentials_str="{}",
@ -931,7 +931,7 @@ class TestToolProviderRelationships:
name="Custom API 1",
icon="{}",
schema="{}",
schema_type_str="openapi",
schema_type_str=ApiProviderSchemaType.OPENAPI,
description="Test",
tools_str="[]",
credentials_str="{}",

View File

@ -13,13 +13,13 @@ class ConcreteApiKeyAuth(ApiKeyAuthBase):
class TestApiKeyAuthBase:
def test_should_store_credentials_on_init(self):
"""Test that credentials are properly stored during initialization"""
credentials = {"api_key": "test_key", "auth_type": "bearer"}
credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}}
auth = ConcreteApiKeyAuth(credentials)
assert auth.credentials == credentials
def test_should_not_instantiate_abstract_class(self):
"""Test that ApiKeyAuthBase cannot be instantiated directly"""
credentials = {"api_key": "test_key"}
credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}}
with pytest.raises(TypeError) as exc_info:
ApiKeyAuthBase(credentials)
@ -29,7 +29,7 @@ class TestApiKeyAuthBase:
def test_should_allow_subclass_implementation(self):
"""Test that subclasses can properly implement the abstract method"""
credentials = {"api_key": "test_key", "auth_type": "bearer"}
credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}}
auth = ConcreteApiKeyAuth(credentials)
# Should not raise any exception

View File

@ -58,7 +58,7 @@ class TestApiKeyAuthFactory:
mock_get_factory.return_value = mock_auth_class
# Act
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"})
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "test_key"}})
result = factory.validate_credentials()
# Assert
@ -75,7 +75,7 @@ class TestApiKeyAuthFactory:
mock_get_factory.return_value = mock_auth_class
# Act & Assert
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"})
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "test_key"}})
with pytest.raises(Exception) as exc_info:
factory.validate_credentials()
assert str(exc_info.value) == "Authentication error"

View File

@ -111,6 +111,7 @@ from unittest.mock import Mock, patch
import pytest
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.rag.index_processor.constant.index_type import IndexStructureType
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.dataset import Dataset, DatasetProcessRule, Document
from services.dataset_service import DatasetService, DocumentService
@ -188,7 +189,7 @@ class DocumentValidationTestDataFactory:
def create_knowledge_config_mock(
data_source: DataSource | None = None,
process_rule: ProcessRule | None = None,
doc_form: str = "text_model",
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
indexing_technique: str = "high_quality",
**kwargs,
) -> Mock:
@ -326,8 +327,8 @@ class TestDatasetServiceCheckDocForm:
- Validation logic works correctly
"""
# Arrange
dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="text_model")
doc_form = "text_model"
dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
doc_form = IndexStructureType.PARAGRAPH_INDEX
# Act (should not raise)
DatasetService.check_doc_form(dataset, doc_form)
@ -349,7 +350,7 @@ class TestDatasetServiceCheckDocForm:
"""
# Arrange
dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=None)
doc_form = "text_model"
doc_form = IndexStructureType.PARAGRAPH_INDEX
# Act (should not raise)
DatasetService.check_doc_form(dataset, doc_form)
@ -370,8 +371,8 @@ class TestDatasetServiceCheckDocForm:
- Error type is correct
"""
# Arrange
dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="text_model")
doc_form = "table_model" # Different form
dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
doc_form = IndexStructureType.PARENT_CHILD_INDEX # Different form
# Act & Assert
with pytest.raises(ValueError, match="doc_form is different from the dataset doc_form"):
@ -390,7 +391,7 @@ class TestDatasetServiceCheckDocForm:
"""
# Arrange
dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="knowledge_card")
doc_form = "text_model" # Different form
doc_form = IndexStructureType.PARAGRAPH_INDEX # Different form
# Act & Assert
with pytest.raises(ValueError, match="doc_form is different from the dataset doc_form"):

View File

@ -2,8 +2,10 @@ from unittest.mock import MagicMock, Mock, patch
import pytest
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.account import Account
from models.dataset import ChildChunk, Dataset, Document, DocumentSegment
from models.enums import SegmentType
from services.dataset_service import SegmentService
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
@ -77,7 +79,7 @@ class SegmentTestDataFactory:
chunk.word_count = word_count
chunk.index_node_id = f"node-{chunk_id}"
chunk.index_node_hash = "hash-123"
chunk.type = "automatic"
chunk.type = SegmentType.AUTOMATIC
chunk.created_by = "user-123"
chunk.updated_by = None
chunk.updated_at = None
@ -90,7 +92,7 @@ class SegmentTestDataFactory:
document_id: str = "doc-123",
dataset_id: str = "dataset-123",
tenant_id: str = "tenant-123",
doc_form: str = "text_model",
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
word_count: int = 100,
**kwargs,
) -> Mock:
@ -209,7 +211,7 @@ class TestSegmentServiceCreateSegment:
def test_create_segment_with_qa_model(self, mock_db_session, mock_current_user):
"""Test creation of segment with QA model (requires answer)."""
# Arrange
document = SegmentTestDataFactory.create_document_mock(doc_form="qa_model", word_count=100)
document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100)
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
args = {"content": "What is AI?", "answer": "AI is Artificial Intelligence", "keywords": ["ai"]}
@ -428,7 +430,7 @@ class TestSegmentServiceUpdateSegment:
"""Test update segment with QA model (includes answer)."""
# Arrange
segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10)
document = SegmentTestDataFactory.create_document_mock(doc_form="qa_model", word_count=100)
document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100)
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
args = SegmentUpdateArgs(content="Updated question", answer="Updated answer", keywords=["qa"])

View File

@ -4,6 +4,7 @@ from unittest.mock import Mock, create_autospec
import pytest
from redis.exceptions import LockNotOwnedError
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.account import Account
from models.dataset import Dataset, Document
from services.dataset_service import DocumentService, SegmentService
@ -76,7 +77,7 @@ def test_save_document_with_dataset_id_ignores_lock_not_owned(
info_list = types.SimpleNamespace(data_source_type="upload_file")
data_source = types.SimpleNamespace(info_list=info_list)
knowledge_config = types.SimpleNamespace(
doc_form="qa_model",
doc_form=IndexStructureType.QA_INDEX,
original_document_id=None, # go into "new document" branch
data_source=data_source,
indexing_technique="high_quality",
@ -131,7 +132,7 @@ def test_add_segment_ignores_lock_not_owned(
document.id = "doc-1"
document.dataset_id = dataset.id
document.word_count = 0
document.doc_form = "qa_model"
document.doc_form = IndexStructureType.QA_INDEX
# Minimal args required by add_segment
args = {
@ -174,4 +175,4 @@ def test_multi_create_segment_ignores_lock_not_owned(
document.id = "doc-1"
document.dataset_id = dataset.id
document.word_count = 0
document.doc_form = "qa_model"
document.doc_form = IndexStructureType.QA_INDEX

View File

@ -1,8 +0,0 @@
from services.dataset_service import DocumentService
def test_normalize_display_status_alias_mapping():
assert DocumentService.normalize_display_status("ACTIVE") == "available"
assert DocumentService.normalize_display_status("enabled") == "available"
assert DocumentService.normalize_display_status("archived") == "archived"
assert DocumentService.normalize_display_status("unknown") is None

View File

@ -1,224 +0,0 @@
from __future__ import annotations
import uuid
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from werkzeug.exceptions import BadRequest
from services.oauth_server import (
OAUTH_ACCESS_TOKEN_EXPIRES_IN,
OAUTH_ACCESS_TOKEN_REDIS_KEY,
OAUTH_AUTHORIZATION_CODE_REDIS_KEY,
OAUTH_REFRESH_TOKEN_EXPIRES_IN,
OAUTH_REFRESH_TOKEN_REDIS_KEY,
OAuthGrantType,
OAuthServerService,
)
@pytest.fixture
def mock_redis_client(mocker: MockerFixture) -> MagicMock:
return mocker.patch("services.oauth_server.redis_client")
@pytest.fixture
def mock_session(mocker: MockerFixture) -> MagicMock:
"""Mock the OAuth server Session context manager."""
mocker.patch("services.oauth_server.db", SimpleNamespace(engine=object()))
session = MagicMock()
session_cm = MagicMock()
session_cm.__enter__.return_value = session
mocker.patch("services.oauth_server.Session", return_value=session_cm)
return session
def test_get_oauth_provider_app_should_return_app_when_record_exists(mock_session: MagicMock) -> None:
# Arrange
mock_execute_result = MagicMock()
expected_app = MagicMock()
mock_execute_result.scalar_one_or_none.return_value = expected_app
mock_session.execute.return_value = mock_execute_result
# Act
result = OAuthServerService.get_oauth_provider_app("client-1")
# Assert
assert result is expected_app
mock_session.execute.assert_called_once()
mock_execute_result.scalar_one_or_none.assert_called_once()
def test_sign_oauth_authorization_code_should_store_code_and_return_value(
mocker: MockerFixture, mock_redis_client: MagicMock
) -> None:
# Arrange
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111")
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
# Act
code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1")
# Assert
expected_code = str(deterministic_uuid)
assert code == expected_code
mock_redis_client.set.assert_called_once_with(
OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=expected_code),
"user-1",
ex=600,
)
def test_sign_oauth_access_token_should_raise_bad_request_when_authorization_code_is_invalid(
mock_redis_client: MagicMock,
) -> None:
# Arrange
mock_redis_client.get.return_value = None
# Act + Assert
with pytest.raises(BadRequest, match="invalid code"):
OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
code="bad-code",
client_id="client-1",
)
def test_sign_oauth_access_token_should_issue_access_and_refresh_token_when_authorization_code_is_valid(
mocker: MockerFixture, mock_redis_client: MagicMock
) -> None:
# Arrange
token_uuids = [
uuid.UUID("00000000-0000-0000-0000-000000000201"),
uuid.UUID("00000000-0000-0000-0000-000000000202"),
]
mocker.patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids)
mock_redis_client.get.return_value = b"user-1"
code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1")
# Act
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
code="code-1",
client_id="client-1",
)
# Assert
assert access_token == str(token_uuids[0])
assert refresh_token == str(token_uuids[1])
mock_redis_client.delete.assert_called_once_with(code_key)
mock_redis_client.set.assert_any_call(
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
b"user-1",
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
)
mock_redis_client.set.assert_any_call(
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token),
b"user-1",
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
)
def test_sign_oauth_access_token_should_raise_bad_request_when_refresh_token_is_invalid(
mock_redis_client: MagicMock,
) -> None:
# Arrange
mock_redis_client.get.return_value = None
# Act + Assert
with pytest.raises(BadRequest, match="invalid refresh token"):
OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.REFRESH_TOKEN,
refresh_token="stale-token",
client_id="client-1",
)
def test_sign_oauth_access_token_should_issue_new_access_token_when_refresh_token_is_valid(
mocker: MockerFixture, mock_redis_client: MagicMock
) -> None:
# Arrange
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301")
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
mock_redis_client.get.return_value = b"user-1"
# Act
access_token, returned_refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.REFRESH_TOKEN,
refresh_token="refresh-1",
client_id="client-1",
)
# Assert
assert access_token == str(deterministic_uuid)
assert returned_refresh_token == "refresh-1"
mock_redis_client.set.assert_called_once_with(
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
b"user-1",
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
)
def test_sign_oauth_access_token_with_unknown_grant_type_should_return_none() -> None:
# Arrange
grant_type = cast(OAuthGrantType, "invalid-grant-type")
# Act
result = OAuthServerService.sign_oauth_access_token(
grant_type=grant_type,
client_id="client-1",
)
# Assert
assert result is None
def test_sign_oauth_refresh_token_should_store_token_with_expected_expiry(
mocker: MockerFixture, mock_redis_client: MagicMock
) -> None:
# Arrange
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401")
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
# Act
refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2")
# Assert
assert refresh_token == str(deterministic_uuid)
mock_redis_client.set.assert_called_once_with(
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token),
"user-2",
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
)
def test_validate_oauth_access_token_should_return_none_when_token_not_found(
mock_redis_client: MagicMock,
) -> None:
# Arrange
mock_redis_client.get.return_value = None
# Act
result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token")
# Assert
assert result is None
def test_validate_oauth_access_token_should_load_user_when_token_exists(
mocker: MockerFixture, mock_redis_client: MagicMock
) -> None:
# Arrange
mock_redis_client.get.return_value = b"user-88"
expected_user = MagicMock()
mock_load_user = mocker.patch("services.oauth_server.AccountService.load_user", return_value=expected_user)
# Act
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
# Assert
assert result is expected_user
mock_load_user.assert_called_once_with("user-88")

View File

@ -11,6 +11,7 @@ from unittest.mock import MagicMock
import pytest
import services.summary_index_service as summary_module
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.enums import SegmentStatus, SummaryStatus
from services.summary_index_service import SummaryIndexService
@ -48,7 +49,7 @@ def _segment(*, has_document: bool = True) -> MagicMock:
if has_document:
doc = MagicMock(name="document")
doc.doc_language = "en"
doc.doc_form = "text_model"
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
segment.document = doc
else:
segment.document = None
@ -623,13 +624,13 @@ def test_generate_summaries_for_document_skip_conditions(monkeypatch: pytest.Mon
dataset = _dataset(indexing_technique="economy")
document = MagicMock(spec=summary_module.DatasetDocument)
document.id = "doc-1"
document.doc_form = "text_model"
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == []
dataset = _dataset()
assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": False}) == []
document.doc_form = "qa_model"
document.doc_form = IndexStructureType.QA_INDEX
assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == []
@ -637,7 +638,7 @@ def test_generate_summaries_for_document_runs_and_handles_errors(monkeypatch: py
dataset = _dataset()
document = MagicMock(spec=summary_module.DatasetDocument)
document.id = "doc-1"
document.doc_form = "text_model"
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
seg1 = _segment()
seg2 = _segment()
@ -673,7 +674,7 @@ def test_generate_summaries_for_document_no_segments_returns_empty(monkeypatch:
dataset = _dataset()
document = MagicMock(spec=summary_module.DatasetDocument)
document.id = "doc-1"
document.doc_form = "text_model"
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
session = MagicMock()
query = MagicMock()
@ -696,7 +697,7 @@ def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chu
dataset = _dataset()
document = MagicMock(spec=summary_module.DatasetDocument)
document.id = "doc-1"
document.doc_form = "text_model"
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
seg = _segment()
session = MagicMock()
@ -935,7 +936,7 @@ def test_update_summary_for_segment_skip_conditions() -> None:
SummaryIndexService.update_summary_for_segment(_segment(), _dataset(indexing_technique="economy"), "x") is None
)
seg = _segment(has_document=True)
seg.document.doc_form = "qa_model"
seg.document.doc_form = IndexStructureType.QA_INDEX
assert SummaryIndexService.update_summary_for_segment(seg, _dataset(), "x") is None

View File

@ -9,6 +9,7 @@ from unittest.mock import MagicMock
import pytest
import services.vector_service as vector_service_module
from core.rag.index_processor.constant.index_type import IndexStructureType
from services.vector_service import VectorService
@ -32,7 +33,7 @@ class _ParentDocStub:
def _make_dataset(
*,
indexing_technique: str = "high_quality",
doc_form: str = "text_model",
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
tenant_id: str = "tenant-1",
dataset_id: str = "dataset-1",
is_multimodal: bool = False,
@ -106,7 +107,7 @@ def test_create_segments_vector_regular_indexing_loads_documents_and_keywords(mo
factory_instance.init_index_processor.return_value = index_processor
monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance))
VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model")
VectorService.create_segments_vector([["k1"]], [segment], dataset, IndexStructureType.PARAGRAPH_INDEX)
index_processor.load.assert_called_once()
args, kwargs = index_processor.load.call_args
@ -131,7 +132,7 @@ def test_create_segments_vector_regular_indexing_loads_multimodal_documents(monk
factory_instance.init_index_processor.return_value = index_processor
monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance))
VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model")
VectorService.create_segments_vector([["k1"]], [segment], dataset, IndexStructureType.PARAGRAPH_INDEX)
assert index_processor.load.call_count == 2
first_args, first_kwargs = index_processor.load.call_args_list[0]
@ -153,7 +154,7 @@ def test_create_segments_vector_with_no_segments_does_not_load(monkeypatch: pyte
factory_instance.init_index_processor.return_value = index_processor
monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance))
VectorService.create_segments_vector(None, [], dataset, "text_model")
VectorService.create_segments_vector(None, [], dataset, IndexStructureType.PARAGRAPH_INDEX)
index_processor.load.assert_not_called()
@ -392,7 +393,7 @@ def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkey
def test_generate_child_chunks_regenerate_cleans_then_saves_children(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(doc_form="text_model", tenant_id="tenant-1", dataset_id="dataset-1")
dataset = _make_dataset(doc_form=IndexStructureType.PARAGRAPH_INDEX, tenant_id="tenant-1", dataset_id="dataset-1")
segment = _make_segment(segment_id="seg-1")
dataset_document = MagicMock()
@ -439,7 +440,7 @@ def test_generate_child_chunks_regenerate_cleans_then_saves_children(monkeypatch
def test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _make_dataset(doc_form="text_model")
dataset = _make_dataset(doc_form=IndexStructureType.PARAGRAPH_INDEX)
segment = _make_segment()
dataset_document = MagicMock()
dataset_document.doc_language = "en"

View File

@ -1,259 +0,0 @@
from __future__ import annotations
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.app.entities.app_invoke_entities import InvokeFrom
from models import Account
from models.model import App, EndUser
from services.web_conversation_service import WebConversationService
@pytest.fixture
def app_model() -> App:
return cast(App, SimpleNamespace(id="app-1"))
def _account(**kwargs: Any) -> Account:
return cast(Account, SimpleNamespace(**kwargs))
def _end_user(**kwargs: Any) -> EndUser:
return cast(EndUser, SimpleNamespace(**kwargs))
def test_pagination_by_last_id_should_raise_error_when_user_is_none(
app_model: App,
mocker: MockerFixture,
) -> None:
# Arrange
session = MagicMock()
mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
# Act + Assert
with pytest.raises(ValueError, match="User is required"):
WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=None,
last_id=None,
limit=20,
invoke_from=InvokeFrom.WEB_APP,
)
def test_pagination_by_last_id_should_forward_without_pin_filter_when_pinned_is_none(
app_model: App,
mocker: MockerFixture,
) -> None:
# Arrange
session = MagicMock()
fake_user = _account(id="user-1")
mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
mock_pagination.return_value = MagicMock()
# Act
WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=fake_user,
last_id="conv-9",
limit=10,
invoke_from=InvokeFrom.WEB_APP,
pinned=None,
)
# Assert
call_kwargs = mock_pagination.call_args.kwargs
assert call_kwargs["include_ids"] is None
assert call_kwargs["exclude_ids"] is None
assert call_kwargs["last_id"] == "conv-9"
assert call_kwargs["sort_by"] == "-updated_at"
def test_pagination_by_last_id_should_include_only_pinned_ids_when_pinned_true(
app_model: App,
mocker: MockerFixture,
) -> None:
# Arrange
session = MagicMock()
fake_account_cls = type("FakeAccount", (), {})
fake_user = cast(Account, fake_account_cls())
fake_user.id = "account-1"
mocker.patch("services.web_conversation_service.Account", fake_account_cls)
mocker.patch("services.web_conversation_service.EndUser", type("FakeEndUser", (), {}))
session.scalars.return_value.all.return_value = ["conv-1", "conv-2"]
mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
mock_pagination.return_value = MagicMock()
# Act
WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=fake_user,
last_id=None,
limit=20,
invoke_from=InvokeFrom.WEB_APP,
pinned=True,
)
# Assert
call_kwargs = mock_pagination.call_args.kwargs
assert call_kwargs["include_ids"] == ["conv-1", "conv-2"]
assert call_kwargs["exclude_ids"] is None
def test_pagination_by_last_id_should_exclude_pinned_ids_when_pinned_false(
app_model: App,
mocker: MockerFixture,
) -> None:
# Arrange
session = MagicMock()
fake_end_user_cls = type("FakeEndUser", (), {})
fake_user = cast(EndUser, fake_end_user_cls())
fake_user.id = "end-user-1"
mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {}))
mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls)
session.scalars.return_value.all.return_value = ["conv-3"]
mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
mock_pagination.return_value = MagicMock()
# Act
WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=fake_user,
last_id=None,
limit=20,
invoke_from=InvokeFrom.WEB_APP,
pinned=False,
)
# Assert
call_kwargs = mock_pagination.call_args.kwargs
assert call_kwargs["include_ids"] is None
assert call_kwargs["exclude_ids"] == ["conv-3"]
def test_pin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None:
# Arrange
mock_db = mocker.patch("services.web_conversation_service.db")
mocker.patch("services.web_conversation_service.ConversationService.get_conversation")
# Act
WebConversationService.pin(app_model, "conv-1", None)
# Assert
mock_db.session.add.assert_not_called()
mock_db.session.commit.assert_not_called()
def test_pin_should_return_early_when_conversation_is_already_pinned(
app_model: App,
mocker: MockerFixture,
) -> None:
# Arrange
fake_account_cls = type("FakeAccount", (), {})
fake_user = cast(Account, fake_account_cls())
fake_user.id = "account-1"
mocker.patch("services.web_conversation_service.Account", fake_account_cls)
mock_db = mocker.patch("services.web_conversation_service.db")
mock_db.session.query.return_value.where.return_value.first.return_value = object()
mock_get_conversation = mocker.patch("services.web_conversation_service.ConversationService.get_conversation")
# Act
WebConversationService.pin(app_model, "conv-1", fake_user)
# Assert
mock_get_conversation.assert_not_called()
mock_db.session.add.assert_not_called()
mock_db.session.commit.assert_not_called()
def test_pin_should_create_pinned_conversation_when_not_already_pinned(
app_model: App,
mocker: MockerFixture,
) -> None:
# Arrange
fake_account_cls = type("FakeAccount", (), {})
fake_user = cast(Account, fake_account_cls())
fake_user.id = "account-2"
mocker.patch("services.web_conversation_service.Account", fake_account_cls)
mock_db = mocker.patch("services.web_conversation_service.db")
mock_db.session.query.return_value.where.return_value.first.return_value = None
mock_conversation = SimpleNamespace(id="conv-2")
mock_get_conversation = mocker.patch(
"services.web_conversation_service.ConversationService.get_conversation",
return_value=mock_conversation,
)
# Act
WebConversationService.pin(app_model, "conv-2", fake_user)
# Assert
mock_get_conversation.assert_called_once_with(app_model=app_model, conversation_id="conv-2", user=fake_user)
added_obj = mock_db.session.add.call_args.args[0]
assert added_obj.app_id == "app-1"
assert added_obj.conversation_id == "conv-2"
assert added_obj.created_by_role == "account"
assert added_obj.created_by == "account-2"
mock_db.session.commit.assert_called_once()
def test_unpin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None:
# Arrange
mock_db = mocker.patch("services.web_conversation_service.db")
# Act
WebConversationService.unpin(app_model, "conv-1", None)
# Assert
mock_db.session.delete.assert_not_called()
mock_db.session.commit.assert_not_called()
def test_unpin_should_return_early_when_conversation_is_not_pinned(
app_model: App,
mocker: MockerFixture,
) -> None:
# Arrange
fake_end_user_cls = type("FakeEndUser", (), {})
fake_user = cast(EndUser, fake_end_user_cls())
fake_user.id = "end-user-3"
mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {}))
mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls)
mock_db = mocker.patch("services.web_conversation_service.db")
mock_db.session.query.return_value.where.return_value.first.return_value = None
# Act
WebConversationService.unpin(app_model, "conv-7", fake_user)
# Assert
mock_db.session.delete.assert_not_called()
mock_db.session.commit.assert_not_called()
def test_unpin_should_delete_pinned_conversation_when_exists(
app_model: App,
mocker: MockerFixture,
) -> None:
# Arrange
fake_end_user_cls = type("FakeEndUser", (), {})
fake_user = cast(EndUser, fake_end_user_cls())
fake_user.id = "end-user-4"
mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {}))
mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls)
mock_db = mocker.patch("services.web_conversation_service.db")
pinned_obj = SimpleNamespace(id="pin-1")
mock_db.session.query.return_value.where.return_value.first.return_value = pinned_obj
# Act
WebConversationService.unpin(app_model, "conv-8", fake_user)
# Assert
mock_db.session.delete.assert_called_once_with(pinned_obj)
mock_db.session.commit.assert_called_once()

View File

@ -1,643 +0,0 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.tools.entities.tool_entities import ApiProviderSchemaType
from services.tools.api_tools_manage_service import ApiToolManageService
@pytest.fixture
def mock_db(mocker: MockerFixture) -> MagicMock:
# Arrange
mocked_db = mocker.patch("services.tools.api_tools_manage_service.db")
mocked_db.session = MagicMock()
return mocked_db
def _tool_bundle(operation_id: str = "tool-1") -> SimpleNamespace:
return SimpleNamespace(operation_id=operation_id)
def test_parser_api_schema_should_return_schema_payload_when_schema_is_valid(mocker: MockerFixture) -> None:
# Arrange
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI.value),
)
# Act
result = ApiToolManageService.parser_api_schema("valid-schema")
# Assert
assert result["schema_type"] == ApiProviderSchemaType.OPENAPI.value
assert len(result["credentials_schema"]) == 3
assert "warning" in result
def test_parser_api_schema_should_raise_value_error_when_parser_raises(mocker: MockerFixture) -> None:
# Arrange
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
side_effect=RuntimeError("bad schema"),
)
# Act + Assert
with pytest.raises(ValueError, match="invalid schema: invalid schema: bad schema"):
ApiToolManageService.parser_api_schema("invalid")
def test_convert_schema_to_tool_bundles_should_return_tool_bundles_when_valid(mocker: MockerFixture) -> None:
# Arrange
expected = ([_tool_bundle("a"), _tool_bundle("b")], ApiProviderSchemaType.SWAGGER)
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
return_value=expected,
)
extra_info: dict[str, str] = {}
# Act
result = ApiToolManageService.convert_schema_to_tool_bundles("schema", extra_info=extra_info)
# Assert
assert result == expected
def test_convert_schema_to_tool_bundles_should_raise_value_error_when_parser_fails(mocker: MockerFixture) -> None:
# Arrange
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
side_effect=ValueError("parse failed"),
)
# Act + Assert
with pytest.raises(ValueError, match="invalid schema: parse failed"):
ApiToolManageService.convert_schema_to_tool_bundles("schema")
def test_create_api_tool_provider_should_raise_error_when_provider_already_exists(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_db.session.query.return_value.where.return_value.first.return_value = object()
# Act + Assert
with pytest.raises(ValueError, match="provider provider-a already exists"):
ApiToolManageService.create_api_tool_provider(
user_id="user-1",
tenant_id="tenant-1",
provider_name=" provider-a ",
icon={"emoji": "X"},
credentials={"auth_type": "none"},
schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
privacy_policy="privacy",
custom_disclaimer="custom",
labels=[],
)
def test_create_api_tool_provider_should_raise_error_when_tool_count_exceeds_limit(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_db.session.query.return_value.where.return_value.first.return_value = None
many_tools = [_tool_bundle(str(i)) for i in range(101)]
mocker.patch.object(
ApiToolManageService,
"convert_schema_to_tool_bundles",
return_value=(many_tools, ApiProviderSchemaType.OPENAPI),
)
# Act + Assert
with pytest.raises(ValueError, match="the number of apis should be less than 100"):
ApiToolManageService.create_api_tool_provider(
user_id="user-1",
tenant_id="tenant-1",
provider_name="provider-a",
icon={"emoji": "X"},
credentials={"auth_type": "none"},
schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
privacy_policy="privacy",
custom_disclaimer="custom",
labels=[],
)
def test_create_api_tool_provider_should_raise_error_when_auth_type_is_missing(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Credentials without an 'auth_type' key are rejected at creation time."""
    # Arrange: no duplicate provider; schema parses to one tool.
    mock_db.session.query.return_value.where.return_value.first.return_value = None
    mocker.patch.object(
        ApiToolManageService,
        "convert_schema_to_tool_bundles",
        return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
    )
    # Act + Assert: empty credentials dict triggers the auth_type check.
    with pytest.raises(ValueError, match="auth_type is required"):
        ApiToolManageService.create_api_tool_provider(
            user_id="user-1",
            tenant_id="tenant-1",
            provider_name="provider-a",
            icon={"emoji": "X"},
            credentials={},
            schema_type=ApiProviderSchemaType.OPENAPI,
            schema="schema",
            privacy_policy="privacy",
            custom_disclaimer="custom",
            labels=[],
        )
def test_create_api_tool_provider_should_create_provider_when_input_is_valid(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Happy path: the provider is added, committed, and a success dict returned."""
    # Arrange: no duplicate, a parsable schema, and stubbed controller/encrypter.
    mock_db.session.query.return_value.where.return_value.first.return_value = None
    mocker.patch.object(
        ApiToolManageService,
        "convert_schema_to_tool_bundles",
        return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
    )
    mock_controller = MagicMock()
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
        return_value=mock_controller,
    )
    mock_encrypter = MagicMock()
    mock_encrypter.encrypt.return_value = {"auth_type": "none"}
    mocker.patch(
        "services.tools.api_tools_manage_service.create_tool_provider_encrypter",
        return_value=(mock_encrypter, MagicMock()),
    )
    mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels")
    # Act
    result = ApiToolManageService.create_api_tool_provider(
        user_id="user-1",
        tenant_id="tenant-1",
        provider_name="provider-a",
        icon={"emoji": "X"},
        credentials={"auth_type": "none"},
        schema_type=ApiProviderSchemaType.OPENAPI,
        schema="schema",
        privacy_policy="privacy",
        custom_disclaimer="custom",
        labels=["news"],
    )
    # Assert: tools were loaded on the controller and the row was persisted.
    assert result == {"result": "success"}
    mock_controller.load_bundled_tools.assert_called_once()
    mock_db.session.add.assert_called_once()
    mock_db.session.commit.assert_called_once()
def test_get_api_tool_provider_remote_schema_should_return_schema_when_response_is_valid(
    mocker: MockerFixture,
) -> None:
    """A 200 response yields the raw schema text in the returned payload.

    NOTE(review): parser_api_schema is stubbed but its return value does not
    appear in the asserted payload — presumably it is called only for
    validation; confirm against the service.
    """
    # Arrange
    mocker.patch(
        "services.tools.api_tools_manage_service.get",
        return_value=SimpleNamespace(status_code=200, text="schema-content"),
    )
    mocker.patch.object(ApiToolManageService, "parser_api_schema", return_value={"ok": True})
    # Act
    result = ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema")
    # Assert
    assert result == {"schema": "schema-content"}
@pytest.mark.parametrize("status_code", [400, 404, 500])
def test_get_api_tool_provider_remote_schema_should_raise_error_when_remote_fetch_is_invalid(
    status_code: int,
    mocker: MockerFixture,
) -> None:
    """Non-2xx responses raise ValueError and are logged via logger.exception."""
    # Arrange: remote fetch returns an error status.
    mocker.patch(
        "services.tools.api_tools_manage_service.get",
        return_value=SimpleNamespace(status_code=status_code, text="schema-content"),
    )
    mock_logger = mocker.patch("services.tools.api_tools_manage_service.logger")
    # Act + Assert
    with pytest.raises(ValueError, match="invalid schema, please check the url you provided"):
        ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema")
    mock_logger.exception.assert_called_once()
def test_list_api_tool_provider_tools_should_raise_error_when_provider_not_found(
    mock_db: MagicMock,
) -> None:
    """Listing tools for an unknown provider raises ValueError naming the provider."""
    # Arrange: the provider lookup returns nothing.
    mock_db.session.query.return_value.where.return_value.first.return_value = None
    # Act + Assert
    with pytest.raises(ValueError, match="you have not added provider provider-a"):
        ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a")
def test_list_api_tool_provider_tools_should_return_converted_tools_when_provider_exists(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Each tool of an existing provider is converted to its API entity form."""
    # Arrange: provider with two tools; conversion is stubbed per-tool.
    provider = SimpleNamespace(tools=[_tool_bundle("tool-a"), _tool_bundle("tool-b")])
    mock_db.session.query.return_value.where.return_value.first.return_value = provider
    controller = MagicMock()
    mocker.patch(
        "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller",
        return_value=controller,
    )
    mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["search"])
    mock_convert = mocker.patch(
        "services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity",
        side_effect=[{"name": "tool-a"}, {"name": "tool-b"}],
    )
    # Act
    result = ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a")
    # Assert: one conversion per tool, in order.
    assert result == [{"name": "tool-a"}, {"name": "tool-b"}]
    assert mock_convert.call_count == 2
def test_update_api_tool_provider_should_raise_error_when_original_provider_not_found(
    mock_db: MagicMock,
) -> None:
    """Updating a provider that does not exist raises ValueError."""
    # Arrange: lookup by the original provider name finds nothing.
    mock_db.session.query.return_value.where.return_value.first.return_value = None
    # Act + Assert
    with pytest.raises(ValueError, match="api provider provider-a does not exists"):
        ApiToolManageService.update_api_tool_provider(
            user_id="user-1",
            tenant_id="tenant-1",
            provider_name="provider-a",
            original_provider="provider-a",
            icon={},
            credentials={"auth_type": "none"},
            _schema_type=ApiProviderSchemaType.OPENAPI,
            schema="schema",
            privacy_policy=None,
            custom_disclaimer="custom",
            labels=[],
        )
def test_update_api_tool_provider_should_raise_error_when_auth_type_missing(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """An update with credentials lacking 'auth_type' is rejected."""
    # Arrange: the existing provider is found; schema parses fine.
    provider = SimpleNamespace(credentials={}, name="old")
    mock_db.session.query.return_value.where.return_value.first.return_value = provider
    mocker.patch.object(
        ApiToolManageService,
        "convert_schema_to_tool_bundles",
        return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
    )
    # Act + Assert
    with pytest.raises(ValueError, match="auth_type is required"):
        ApiToolManageService.update_api_tool_provider(
            user_id="user-1",
            tenant_id="tenant-1",
            provider_name="provider-a",
            original_provider="provider-a",
            icon={},
            credentials={},
            _schema_type=ApiProviderSchemaType.OPENAPI,
            schema="schema",
            privacy_policy=None,
            custom_disclaimer="custom",
            labels=[],
        )
def test_update_api_tool_provider_should_update_provider_and_preserve_masked_credentials(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Happy-path update: fields are rewritten and the credential cache is invalidated.

    The caller submits a masked credential value ("***"); the encrypter stub
    decrypts the stored credentials and reports the same mask, so the service
    presumably substitutes the previously stored plaintext before re-encrypting
    — confirm the exact substitution logic against the service implementation.
    """
    # Arrange: existing provider row with all string columns the service mutates.
    provider = SimpleNamespace(
        credentials={"auth_type": "none", "api_key_value": "encrypted-old"},
        name="old",
        icon="",
        schema="",
        description="",
        schema_type_str="",
        tools_str="",
        privacy_policy="",
        custom_disclaimer="",
        credentials_str="",
    )
    mock_db.session.query.return_value.where.return_value.first.return_value = provider
    mocker.patch.object(
        ApiToolManageService,
        "convert_schema_to_tool_bundles",
        return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
    )
    controller = MagicMock()
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
        return_value=controller,
    )
    cache = MagicMock()
    encrypter = MagicMock()
    encrypter.decrypt.return_value = {"auth_type": "none", "api_key_value": "plain-old"}
    encrypter.mask_plugin_credentials.return_value = {"api_key_value": "***"}
    encrypter.encrypt.return_value = {"auth_type": "none", "api_key_value": "encrypted-new"}
    mocker.patch(
        "services.tools.api_tools_manage_service.create_tool_provider_encrypter",
        return_value=(encrypter, cache),
    )
    mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels")
    # Act
    result = ApiToolManageService.update_api_tool_provider(
        user_id="user-1",
        tenant_id="tenant-1",
        provider_name="provider-new",
        original_provider="provider-old",
        icon={"emoji": "E"},
        credentials={"auth_type": "none", "api_key_value": "***"},
        _schema_type=ApiProviderSchemaType.OPENAPI,
        schema="schema",
        privacy_policy="privacy",
        custom_disclaimer="custom",
        labels=["news"],
    )
    # Assert: row mutated in place, new credentials serialized, cache cleared.
    assert result == {"result": "success"}
    assert provider.name == "provider-new"
    assert provider.privacy_policy == "privacy"
    assert provider.credentials_str != ""
    cache.delete.assert_called_once()
    mock_db.session.commit.assert_called_once()
def test_delete_api_tool_provider_should_raise_error_when_provider_missing(mock_db: MagicMock) -> None:
    """Deleting an unknown provider raises ValueError naming the provider."""
    # Arrange: the provider lookup chain yields nothing.
    lookup = mock_db.session.query.return_value.where.return_value
    lookup.first.return_value = None
    # Act + Assert
    with pytest.raises(ValueError, match="you have not added provider provider-a"):
        ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a")
def test_delete_api_tool_provider_should_delete_provider_when_exists(mock_db: MagicMock) -> None:
    """An existing provider is deleted through the session and committed."""
    # Arrange: the lookup returns a sentinel row object.
    sentinel_provider = object()
    lookup = mock_db.session.query.return_value.where.return_value
    lookup.first.return_value = sentinel_provider
    # Act
    outcome = ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a")
    # Assert: exactly that row was deleted and the transaction committed.
    assert outcome == {"result": "success"}
    mock_db.session.delete.assert_called_once_with(sentinel_provider)
    mock_db.session.commit.assert_called_once()
def test_get_api_tool_provider_should_delegate_to_tool_manager(mocker: MockerFixture) -> None:
    """get_api_tool_provider forwards to ToolManager.user_get_api_provider unchanged."""
    # Arrange: stub the delegate and remember its canned payload.
    provider_payload = {"provider": "value"}
    delegate = mocker.patch(
        "services.tools.api_tools_manage_service.ToolManager.user_get_api_provider",
        return_value=provider_payload,
    )
    # Act
    outcome = ApiToolManageService.get_api_tool_provider("user-1", "tenant-1", "provider-a")
    # Assert: payload passed through untouched; delegate called with kwargs.
    assert outcome == provider_payload
    delegate.assert_called_once_with(provider="provider-a", tenant_id="tenant-1")
def test_test_api_tool_preview_should_raise_error_for_invalid_schema_type() -> None:
    """An unrecognized schema_type string is rejected before any parsing happens."""
    # Arrange
    schema_type = "bad-schema-type"
    # Act + Assert
    with pytest.raises(ValueError, match="invalid schema type"):
        ApiToolManageService.test_api_tool_preview(
            tenant_id="tenant-1",
            provider_name="provider-a",
            tool_name="tool-a",
            credentials={"auth_type": "none"},
            parameters={},
            schema_type=schema_type,  # type: ignore[arg-type]
            schema="schema",
        )
def test_test_api_tool_preview_should_raise_error_when_schema_parser_fails(mocker: MockerFixture) -> None:
    """Parser exceptions during preview are wrapped as 'invalid schema' ValueErrors."""
    # Arrange: the schema parser raises an arbitrary runtime error.
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
        side_effect=RuntimeError("invalid"),
    )
    # Act + Assert
    with pytest.raises(ValueError, match="invalid schema"):
        ApiToolManageService.test_api_tool_preview(
            tenant_id="tenant-1",
            provider_name="provider-a",
            tool_name="tool-a",
            credentials={"auth_type": "none"},
            parameters={},
            schema_type=ApiProviderSchemaType.OPENAPI,
            schema="schema",
        )
def test_test_api_tool_preview_should_raise_error_when_tool_name_is_invalid(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Requesting a tool name absent from the parsed bundle list raises ValueError."""
    # Arrange: schema contains only "tool-a"; the caller will ask for "tool-b".
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
        return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
    )
    mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id")
    # Act + Assert
    with pytest.raises(ValueError, match="invalid tool name tool-b"):
        ApiToolManageService.test_api_tool_preview(
            tenant_id="tenant-1",
            provider_name="provider-a",
            tool_name="tool-b",
            credentials={"auth_type": "none"},
            parameters={},
            schema_type=ApiProviderSchemaType.OPENAPI,
            schema="schema",
        )
def test_test_api_tool_preview_should_raise_error_when_auth_type_missing(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Preview with credentials lacking 'auth_type' is rejected."""
    # Arrange: a valid tool bundle and an existing provider row.
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
        return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
    )
    mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id")
    # Act + Assert
    with pytest.raises(ValueError, match="auth_type is required"):
        ApiToolManageService.test_api_tool_preview(
            tenant_id="tenant-1",
            provider_name="provider-a",
            tool_name="tool-a",
            credentials={},
            parameters={},
            schema_type=ApiProviderSchemaType.OPENAPI,
            schema="schema",
        )
def test_test_api_tool_preview_should_return_error_payload_when_tool_validation_raises(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """validate_credentials failures are returned as {'error': ...} rather than raised."""
    # Arrange: provider row exists; the tool's credential validation raises.
    db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"})
    mock_db.session.query.return_value.where.return_value.first.return_value = db_provider
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
        return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
    )
    provider_controller = MagicMock()
    tool_obj = MagicMock()
    # fork_tool_runtime returning the same mock keeps the call chain on one object.
    tool_obj.fork_tool_runtime.return_value = tool_obj
    tool_obj.validate_credentials.side_effect = ValueError("validation failed")
    provider_controller.get_tool.return_value = tool_obj
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
        return_value=provider_controller,
    )
    mock_encrypter = MagicMock()
    mock_encrypter.decrypt.return_value = {"auth_type": "none"}
    mock_encrypter.mask_plugin_credentials.return_value = {}
    mocker.patch(
        "services.tools.api_tools_manage_service.create_tool_provider_encrypter",
        return_value=(mock_encrypter, MagicMock()),
    )
    # Act
    result = ApiToolManageService.test_api_tool_preview(
        tenant_id="tenant-1",
        provider_name="provider-a",
        tool_name="tool-a",
        credentials={"auth_type": "none"},
        parameters={},
        schema_type=ApiProviderSchemaType.OPENAPI,
        schema="schema",
    )
    # Assert: the exception text is surfaced in the error payload.
    assert result == {"error": "validation failed"}
def test_test_api_tool_preview_should_return_result_payload_when_validation_succeeds(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Successful credential validation yields {'result': <validation output>}."""
    # Arrange: same setup as the failure case, but validation succeeds.
    db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"})
    mock_db.session.query.return_value.where.return_value.first.return_value = db_provider
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
        return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
    )
    provider_controller = MagicMock()
    tool_obj = MagicMock()
    tool_obj.fork_tool_runtime.return_value = tool_obj
    tool_obj.validate_credentials.return_value = {"ok": True}
    provider_controller.get_tool.return_value = tool_obj
    mocker.patch(
        "services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
        return_value=provider_controller,
    )
    mock_encrypter = MagicMock()
    mock_encrypter.decrypt.return_value = {"auth_type": "none"}
    mock_encrypter.mask_plugin_credentials.return_value = {}
    mocker.patch(
        "services.tools.api_tools_manage_service.create_tool_provider_encrypter",
        return_value=(mock_encrypter, MagicMock()),
    )
    # Act
    result = ApiToolManageService.test_api_tool_preview(
        tenant_id="tenant-1",
        provider_name="provider-a",
        tool_name="tool-a",
        credentials={"auth_type": "none"},
        parameters={"x": "1"},
        schema_type=ApiProviderSchemaType.OPENAPI,
        schema="schema",
    )
    # Assert
    assert result == {"result": {"ok": True}}
def test_list_api_tools_should_return_all_user_providers_with_converted_tools(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """list_api_tools converts each provider's tools and attaches them to the user provider."""
    # Arrange: two providers with one and two tools respectively.
    provider_one = SimpleNamespace(name="p1")
    provider_two = SimpleNamespace(name="p2")
    mock_db.session.scalars.return_value.all.return_value = [provider_one, provider_two]
    controller_one = MagicMock()
    controller_one.get_tools.return_value = ["tool-a"]
    controller_two = MagicMock()
    controller_two.get_tools.return_value = ["tool-b", "tool-c"]
    user_provider_one = SimpleNamespace(labels=[], tools=[])
    user_provider_two = SimpleNamespace(labels=[], tools=[])
    mocker.patch(
        "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller",
        side_effect=[controller_one, controller_two],
    )
    mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["news"])
    mocker.patch(
        "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_user_provider",
        side_effect=[user_provider_one, user_provider_two],
    )
    mocker.patch("services.tools.api_tools_manage_service.ToolTransformService.repack_provider")
    mock_convert = mocker.patch(
        "services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity",
        side_effect=[{"name": "tool-a"}, {"name": "tool-b"}, {"name": "tool-c"}],
    )
    # Act
    result = ApiToolManageService.list_api_tools("tenant-1")
    # Assert: one result per provider; tools mutated onto the user providers.
    assert len(result) == 2
    assert user_provider_one.tools == [{"name": "tool-a"}]
    assert user_provider_two.tools == [{"name": "tool-b"}, {"name": "tool-c"}]
    assert mock_convert.call_count == 3

View File

@ -1,955 +0,0 @@
"""
Unit tests for services.tools.workflow_tools_manage_service
Covers WorkflowToolManageService: create, update, list, delete, get, list_single.
"""
import json
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from models.model import App
from models.tools import WorkflowToolProvider
from services.tools import workflow_tools_manage_service
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
# ---------------------------------------------------------------------------
# Shared helpers / fake infrastructure
# ---------------------------------------------------------------------------
class DummyWorkflow:
    """Minimal in-memory stand-in for a Workflow, exposing ``graph_dict`` and ``version``."""

    def __init__(self, graph_dict: dict, version: str = "1.0.0") -> None:
        # The graph is stored privately and surfaced read-only via the property.
        self._graph = graph_dict
        self.version = version

    @property
    def graph_dict(self) -> dict:
        """The graph mapping supplied at construction time."""
        return self._graph
class FakeQuery:
    """Chainable query stand-in that always resolves to one fixed result."""

    def __init__(self, result: object) -> None:
        # Whatever first() should yield, decided at construction.
        self._fixed_result = result

    def where(self, *args: object, **kwargs: object) -> "FakeQuery":
        # Filtering is a no-op; return self so calls chain like SQLAlchemy.
        return self

    def first(self) -> object:
        return self._fixed_result

    def delete(self) -> int:
        # Pretend exactly one row was removed.
        return 1
class DummySession:
    """Minimal SQLAlchemy session substitute that records adds and commits."""

    def __init__(self) -> None:
        # Objects handed to add(), in call order.
        self.added: list[WorkflowToolProvider] = []
        # Set to True once commit() runs.
        self.committed: bool = False

    def __enter__(self) -> "DummySession":
        return self

    def __exit__(self, exc_type: object, exc: object, tb: object) -> bool:
        # Never suppress exceptions raised inside the with-block.
        return False

    def begin(self) -> "DummySession":
        # Mimic Session.begin() by acting as its own transaction context.
        return self

    def add(self, obj: WorkflowToolProvider) -> None:
        self.added.append(obj)

    def commit(self) -> None:
        self.committed = True
def _build_parameters() -> list[WorkflowToolParameterConfiguration]:
    """Return a single LLM-form parameter configuration used by the happy-path tests."""
    return [
        WorkflowToolParameterConfiguration(name="input", description="input", form=ToolParameter.ToolParameterForm.LLM),
    ]
def _build_fake_db(
    *,
    existing_tool: WorkflowToolProvider | None = None,
    app: object | None = None,
    tool_by_id: WorkflowToolProvider | None = None,
) -> tuple[MagicMock, DummySession]:
    """
    Build a fake db object plus a DummySession for Session context-manager.
    query(WorkflowToolProvider) returns existing_tool on first call,
    then tool_by_id on subsequent calls (or None if not provided).
    query(App) returns app.
    """
    # Counter is a dict so the nested closure can mutate it without `nonlocal`.
    call_counts: dict[str, int] = {"wftp": 0}

    def query(model: type) -> FakeQuery:
        if model is WorkflowToolProvider:
            call_counts["wftp"] += 1
            # First lookup: duplicate-name check; later lookups: fetch by id.
            if call_counts["wftp"] == 1:
                return FakeQuery(existing_tool)
            return FakeQuery(tool_by_id)
        if model is App:
            return FakeQuery(app)
        return FakeQuery(None)

    fake_db = MagicMock()
    fake_db.session = SimpleNamespace(query=query, commit=MagicMock())
    dummy_session = DummySession()
    return fake_db, dummy_session
# ---------------------------------------------------------------------------
# TestCreateWorkflowTool
# ---------------------------------------------------------------------------
class TestCreateWorkflowTool:
    """Tests for WorkflowToolManageService.create_workflow_tool."""

    def test_should_raise_when_human_input_nodes_present(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Human-input nodes must be rejected before any provider is created."""
        # Arrange: workflow graph contains one human-input node.
        workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "n1", "data": {"type": "human-input"}}]})
        app = SimpleNamespace(workflow=workflow)
        fake_session = SimpleNamespace(query=lambda m: FakeQuery(None) if m is WorkflowToolProvider else FakeQuery(app))
        monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session)
        mock_from_db = MagicMock()
        monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db)
        # Act + Assert
        with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info:
            WorkflowToolManageService.create_workflow_tool(
                user_id="user-id",
                tenant_id="tenant-id",
                workflow_app_id="app-id",
                name="tool_name",
                label="Tool",
                icon={"type": "emoji", "emoji": "🔧"},
                description="desc",
                parameters=_build_parameters(),
            )
        assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
        # The rejection happens before the controller is ever constructed.
        mock_from_db.assert_not_called()

    def test_should_raise_when_duplicate_name_or_app_id(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Existing provider with same name or app_id raises ValueError."""
        # Arrange: every query finds an existing provider.
        existing = MagicMock(spec=WorkflowToolProvider)
        monkeypatch.setattr(
            workflow_tools_manage_service.db,
            "session",
            SimpleNamespace(query=lambda m: FakeQuery(existing)),
        )
        # Act + Assert
        with pytest.raises(ValueError, match="already exists"):
            WorkflowToolManageService.create_workflow_tool(
                user_id="u",
                tenant_id="t",
                workflow_app_id="app-1",
                name="dup",
                label="Dup",
                icon={},
                description="",
                parameters=[],
            )

    def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """ValueError when the referenced App does not exist."""
        # Arrange
        call_count = {"n": 0}

        def query(m: type) -> FakeQuery:
            call_count["n"] += 1
            if m is WorkflowToolProvider:
                return FakeQuery(None)
            return FakeQuery(None)  # App returns None

        monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query))
        # Act + Assert
        with pytest.raises(ValueError, match="not found"):
            WorkflowToolManageService.create_workflow_tool(
                user_id="u",
                tenant_id="t",
                workflow_app_id="missing-app",
                name="n",
                label="L",
                icon={},
                description="",
                parameters=[],
            )

    def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """ValueError when the App has no attached Workflow."""
        # Arrange: App exists but its workflow attribute is None.
        app_no_workflow = SimpleNamespace(workflow=None)

        def query(m: type) -> FakeQuery:
            if m is WorkflowToolProvider:
                return FakeQuery(None)
            return FakeQuery(app_no_workflow)

        monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query))
        # Act + Assert
        with pytest.raises(ValueError, match="Workflow not found"):
            WorkflowToolManageService.create_workflow_tool(
                user_id="u",
                tenant_id="t",
                workflow_app_id="app-id",
                name="n",
                label="L",
                icon={},
                description="",
                parameters=[],
            )

    def test_should_raise_when_from_db_fails(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Exceptions from WorkflowToolProviderController.from_db are wrapped as ValueError."""
        # Arrange
        workflow = DummyWorkflow(graph_dict={"nodes": []})
        app = SimpleNamespace(workflow=workflow)

        def query(m: type) -> FakeQuery:
            if m is WorkflowToolProvider:
                return FakeQuery(None)
            return FakeQuery(app)

        fake_db = MagicMock()
        fake_db.session = SimpleNamespace(query=query)
        monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db)
        dummy_session = DummySession()
        # The service opens its own Session; substitute the dummy for any args.
        monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session)
        monkeypatch.setattr(
            workflow_tools_manage_service.WorkflowToolProviderController,
            "from_db",
            MagicMock(side_effect=RuntimeError("bad config")),
        )
        # Act + Assert: the original message survives inside the ValueError.
        with pytest.raises(ValueError, match="bad config"):
            WorkflowToolManageService.create_workflow_tool(
                user_id="u",
                tenant_id="t",
                workflow_app_id="app-id",
                name="n",
                label="L",
                icon={},
                description="",
                parameters=[],
            )

    def test_should_succeed_and_persist_provider(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Happy path: provider is added to session and success dict is returned."""
        # Arrange
        workflow = DummyWorkflow(graph_dict={"nodes": []}, version="2.0.0")
        app = SimpleNamespace(workflow=workflow)

        def query(m: type) -> FakeQuery:
            if m is WorkflowToolProvider:
                return FakeQuery(None)
            return FakeQuery(app)

        fake_db = MagicMock()
        fake_db.session = SimpleNamespace(query=query)
        monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db)
        dummy_session = DummySession()
        monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session)
        monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock())
        icon = {"type": "emoji", "emoji": "🔧"}
        # Act
        result = WorkflowToolManageService.create_workflow_tool(
            user_id="user-id",
            tenant_id="tenant-id",
            workflow_app_id="app-id",
            name="tool_name",
            label="Tool",
            icon=icon,
            description="desc",
            parameters=_build_parameters(),
        )
        # Assert: exactly one provider row, with icon serialized as JSON and
        # version copied from the workflow.
        assert result == {"result": "success"}
        assert len(dummy_session.added) == 1
        created: WorkflowToolProvider = dummy_session.added[0]
        assert created.name == "tool_name"
        assert created.label == "Tool"
        assert created.icon == json.dumps(icon)
        assert created.version == "2.0.0"

    def test_should_call_label_manager_when_labels_provided(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Labels are forwarded to ToolLabelManager when provided."""
        # Arrange
        workflow = DummyWorkflow(graph_dict={"nodes": []})
        app = SimpleNamespace(workflow=workflow)

        def query(m: type) -> FakeQuery:
            if m is WorkflowToolProvider:
                return FakeQuery(None)
            return FakeQuery(app)

        fake_db = MagicMock()
        fake_db.session = SimpleNamespace(query=query)
        monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db)
        dummy_session = DummySession()
        monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session)
        monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock())
        mock_label_mgr = MagicMock()
        monkeypatch.setattr(workflow_tools_manage_service.ToolLabelManager, "update_tool_labels", mock_label_mgr)
        mock_to_ctrl = MagicMock()
        monkeypatch.setattr(
            workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_controller", mock_to_ctrl
        )
        # Act
        WorkflowToolManageService.create_workflow_tool(
            user_id="u",
            tenant_id="t",
            workflow_app_id="app-id",
            name="n",
            label="L",
            icon={},
            description="",
            parameters=[],
            labels=["tag1", "tag2"],
        )
        # Assert
        mock_label_mgr.assert_called_once()
# ---------------------------------------------------------------------------
# TestUpdateWorkflowTool
# ---------------------------------------------------------------------------
class TestUpdateWorkflowTool:
    """Tests for WorkflowToolManageService.update_workflow_tool."""

    def _make_provider(self) -> WorkflowToolProvider:
        # Provider stub with just the attributes the service reads.
        p = MagicMock(spec=WorkflowToolProvider)
        p.app_id = "app-id"
        p.tenant_id = "tenant-id"
        return p

    def test_should_raise_when_name_duplicated(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """If another tool with the given name already exists, raise ValueError."""
        # Arrange: every lookup finds an existing (conflicting) provider.
        existing = MagicMock(spec=WorkflowToolProvider)

        def query(m: type) -> FakeQuery:
            return FakeQuery(existing)

        monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query))
        # Act + Assert
        with pytest.raises(ValueError, match="already exists"):
            WorkflowToolManageService.update_workflow_tool(
                user_id="u",
                tenant_id="t",
                workflow_tool_id="tool-1",
                name="dup",
                label="L",
                icon={},
                description="",
                parameters=[],
            )

    def test_should_raise_when_tool_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """ValueError when the workflow tool to update does not exist."""
        # Arrange
        call_count = {"n": 0}

        def query(m: type) -> FakeQuery:
            call_count["n"] += 1
            # 1st call: name uniqueness check → None (no duplicate)
            # 2nd call: fetch tool by id → None (not found)
            return FakeQuery(None)

        monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query))
        # Act + Assert
        with pytest.raises(ValueError, match="not found"):
            WorkflowToolManageService.update_workflow_tool(
                user_id="u",
                tenant_id="t",
                workflow_tool_id="missing",
                name="n",
                label="L",
                icon={},
                description="",
                parameters=[],
            )

    def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """ValueError when the tool's referenced App has been removed."""
        # Arrange
        provider = self._make_provider()
        call_count = {"n": 0}

        def query(m: type) -> FakeQuery:
            call_count["n"] += 1
            if m is WorkflowToolProvider:
                # 1st: duplicate name check (None), 2nd: fetch provider
                return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider)
            return FakeQuery(None)  # App not found

        monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query))
        # Act + Assert
        with pytest.raises(ValueError, match="not found"):
            WorkflowToolManageService.update_workflow_tool(
                user_id="u",
                tenant_id="t",
                workflow_tool_id="tool-1",
                name="n",
                label="L",
                icon={},
                description="",
                parameters=[],
            )

    def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """ValueError when the App exists but has no Workflow."""
        # Arrange
        provider = self._make_provider()
        app_no_wf = SimpleNamespace(workflow=None)
        call_count = {"n": 0}

        def query(m: type) -> FakeQuery:
            call_count["n"] += 1
            if m is WorkflowToolProvider:
                return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider)
            return FakeQuery(app_no_wf)

        monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query))
        # Act + Assert
        with pytest.raises(ValueError, match="Workflow not found"):
            WorkflowToolManageService.update_workflow_tool(
                user_id="u",
                tenant_id="t",
                workflow_tool_id="tool-1",
                name="n",
                label="L",
                icon={},
                description="",
                parameters=[],
            )

    def test_should_raise_when_from_db_fails(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Exceptions from from_db are re-raised as ValueError."""
        # Arrange
        provider = self._make_provider()
        workflow = DummyWorkflow(graph_dict={"nodes": []})
        app = SimpleNamespace(workflow=workflow)
        call_count = {"n": 0}

        def query(m: type) -> FakeQuery:
            call_count["n"] += 1
            if m is WorkflowToolProvider:
                return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider)
            return FakeQuery(app)

        monkeypatch.setattr(
            workflow_tools_manage_service.db,
            "session",
            SimpleNamespace(query=query, commit=MagicMock()),
        )
        monkeypatch.setattr(
            workflow_tools_manage_service.WorkflowToolProviderController,
            "from_db",
            MagicMock(side_effect=RuntimeError("from_db error")),
        )
        # Act + Assert: the original message survives inside the ValueError.
        with pytest.raises(ValueError, match="from_db error"):
            WorkflowToolManageService.update_workflow_tool(
                user_id="u",
                tenant_id="t",
                workflow_tool_id="tool-1",
                name="n",
                label="L",
                icon={},
                description="",
                parameters=[],
            )

    def test_should_succeed_and_call_commit(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Happy path: provider fields are updated and session committed."""
        # Arrange
        provider = self._make_provider()
        workflow = DummyWorkflow(graph_dict={"nodes": []}, version="3.0.0")
        app = SimpleNamespace(workflow=workflow)
        call_count = {"n": 0}

        def query(m: type) -> FakeQuery:
            call_count["n"] += 1
            if m is WorkflowToolProvider:
                return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider)
            return FakeQuery(app)

        mock_commit = MagicMock()
        monkeypatch.setattr(
            workflow_tools_manage_service.db,
            "session",
            SimpleNamespace(query=query, commit=mock_commit),
        )
        monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock())
        icon = {"type": "emoji", "emoji": "🛠"}
        # Act
        result = WorkflowToolManageService.update_workflow_tool(
            user_id="u",
            tenant_id="t",
            workflow_tool_id="tool-1",
            name="new_name",
            label="New Label",
            icon=icon,
            description="new desc",
            parameters=_build_parameters(),
        )
        # Assert: fields mutated in place, icon stored as JSON, version refreshed.
        assert result == {"result": "success"}
        mock_commit.assert_called_once()
        assert provider.name == "new_name"
        assert provider.label == "New Label"
        assert provider.icon == json.dumps(icon)
        assert provider.version == "3.0.0"

    def test_should_call_label_manager_when_labels_provided(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Labels are forwarded to ToolLabelManager during update."""
        # Arrange
        provider = self._make_provider()
        workflow = DummyWorkflow(graph_dict={"nodes": []})
        app = SimpleNamespace(workflow=workflow)
        call_count = {"n": 0}

        def query(m: type) -> FakeQuery:
            call_count["n"] += 1
            if m is WorkflowToolProvider:
                return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider)
            return FakeQuery(app)

        monkeypatch.setattr(
            workflow_tools_manage_service.db,
            "session",
            SimpleNamespace(query=query, commit=MagicMock()),
        )
        monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock())
        mock_label_mgr = MagicMock()
        monkeypatch.setattr(workflow_tools_manage_service.ToolLabelManager, "update_tool_labels", mock_label_mgr)
        monkeypatch.setattr(
            workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_controller", MagicMock()
        )
        # Act
        WorkflowToolManageService.update_workflow_tool(
            user_id="u",
            tenant_id="t",
            workflow_tool_id="tool-1",
            name="n",
            label="L",
            icon={},
            description="",
            parameters=[],
            labels=["a"],
        )
        # Assert
        mock_label_mgr.assert_called_once()
# ---------------------------------------------------------------------------
# TestListTenantWorkflowTools
# ---------------------------------------------------------------------------
class TestListTenantWorkflowTools:
    """Tests for WorkflowToolManageService.list_tenant_workflow_tools."""

    def test_should_return_empty_list_when_no_tools(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """An empty database yields an empty result list."""
        # Arrange
        stub_db = MagicMock()
        stub_db.session.scalars.return_value.all.return_value = []
        monkeypatch.setattr(workflow_tools_manage_service, "db", stub_db)

        # Act
        listed = WorkflowToolManageService.list_tenant_workflow_tools("u", "t")

        # Assert
        assert listed == []

    def test_should_skip_broken_providers_and_log(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Providers that fail to load are logged and skipped."""
        # Arrange
        ok_provider = MagicMock(spec=WorkflowToolProvider)
        ok_provider.id = "good-id"
        ok_provider.app_id = "app-good"
        broken_provider = MagicMock(spec=WorkflowToolProvider)
        broken_provider.id = "bad-id"
        broken_provider.app_id = "app-bad"
        stub_db = MagicMock()
        stub_db.session.scalars.return_value.all.return_value = [ok_provider, broken_provider]
        monkeypatch.setattr(workflow_tools_manage_service, "db", stub_db)

        ok_ctrl = MagicMock()
        ok_ctrl.provider_id = "good-id"
        ok_ctrl.get_tools = MagicMock(return_value=[MagicMock()])

        def to_controller(provider: WorkflowToolProvider) -> MagicMock:
            # Simulate a corrupt row: loading the bad provider raises.
            if provider is broken_provider:
                raise RuntimeError("broken provider")
            return ok_ctrl

        transform = workflow_tools_manage_service.ToolTransformService
        monkeypatch.setattr(transform, "workflow_provider_to_controller", to_controller)
        monkeypatch.setattr(
            workflow_tools_manage_service.ToolLabelManager, "get_tools_labels", MagicMock(return_value={})
        )
        to_user_stub = MagicMock()
        to_user_stub.return_value.tools = []
        monkeypatch.setattr(transform, "workflow_provider_to_user_provider", to_user_stub)
        monkeypatch.setattr(transform, "repack_provider", MagicMock())
        monkeypatch.setattr(transform, "convert_tool_entity_to_api_entity", MagicMock())

        # Act
        listed = WorkflowToolManageService.list_tenant_workflow_tools("u", "t")

        # Assert - only the good provider contributed
        assert len(listed) == 1

    def test_should_return_tools_for_all_providers(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """All successfully loaded providers appear in the result."""
        # Arrange
        row = MagicMock(spec=WorkflowToolProvider)
        row.id = "p-1"
        row.app_id = "app-1"
        stub_db = MagicMock()
        stub_db.session.scalars.return_value.all.return_value = [row]
        monkeypatch.setattr(workflow_tools_manage_service, "db", stub_db)

        controller = MagicMock()
        controller.provider_id = "p-1"
        controller.get_tools.return_value = [MagicMock()]
        transform = workflow_tools_manage_service.ToolTransformService
        monkeypatch.setattr(
            transform, "workflow_provider_to_controller", MagicMock(return_value=controller)
        )
        monkeypatch.setattr(
            workflow_tools_manage_service.ToolLabelManager,
            "get_tools_labels",
            MagicMock(return_value={"p-1": []}),
        )
        listed_provider = MagicMock()
        listed_provider.tools = []
        monkeypatch.setattr(
            transform, "workflow_provider_to_user_provider", MagicMock(return_value=listed_provider)
        )
        monkeypatch.setattr(transform, "repack_provider", MagicMock())
        monkeypatch.setattr(transform, "convert_tool_entity_to_api_entity", MagicMock())

        # Act
        listed = WorkflowToolManageService.list_tenant_workflow_tools("u", "t")

        # Assert
        assert len(listed) == 1
        assert listed[0] is listed_provider
# ---------------------------------------------------------------------------
# TestDeleteWorkflowTool
# ---------------------------------------------------------------------------
class TestDeleteWorkflowTool:
    """Tests for WorkflowToolManageService.delete_workflow_tool."""

    def test_should_delete_and_commit(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """delete_workflow_tool queries, deletes, commits, and returns success."""
        # Arrange
        query_stub = MagicMock()
        query_stub.where.return_value.delete.return_value = 1
        commit_spy = MagicMock()
        monkeypatch.setattr(
            workflow_tools_manage_service.db,
            "session",
            SimpleNamespace(query=lambda m: query_stub, commit=commit_spy),
        )

        # Act
        outcome = WorkflowToolManageService.delete_workflow_tool("u", "t", "tool-1")

        # Assert
        assert outcome == {"result": "success"}
        commit_spy.assert_called_once()
# ---------------------------------------------------------------------------
# TestGetWorkflowToolByToolId / ByAppId
# ---------------------------------------------------------------------------
class TestGetWorkflowToolByToolIdAndAppId:
    """Tests for get_workflow_tool_by_tool_id and get_workflow_tool_by_app_id."""

    @staticmethod
    def _install_empty_session(monkeypatch: pytest.MonkeyPatch) -> None:
        # Every DB query comes back empty so both lookups must fail.
        monkeypatch.setattr(
            workflow_tools_manage_service.db,
            "session",
            SimpleNamespace(query=lambda m: FakeQuery(None)),
        )

    def test_get_by_tool_id_should_raise_when_db_tool_is_none(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Raises ValueError when no WorkflowToolProvider found by tool id."""
        # Arrange
        self._install_empty_session(monkeypatch)

        # Act + Assert
        with pytest.raises(ValueError, match="Tool not found"):
            WorkflowToolManageService.get_workflow_tool_by_tool_id("u", "t", "missing")

    def test_get_by_app_id_should_raise_when_db_tool_is_none(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Raises ValueError when no WorkflowToolProvider found by app id."""
        # Arrange
        self._install_empty_session(monkeypatch)

        # Act + Assert
        with pytest.raises(ValueError, match="Tool not found"):
            WorkflowToolManageService.get_workflow_tool_by_app_id("u", "t", "missing-app")
# ---------------------------------------------------------------------------
# TestGetWorkflowTool (private _get_workflow_tool)
# ---------------------------------------------------------------------------
class TestGetWorkflowTool:
    """Tests for the internal _get_workflow_tool helper."""

    def test_should_raise_when_db_tool_none(self) -> None:
        """_get_workflow_tool raises ValueError when db_tool is None."""
        with pytest.raises(ValueError, match="Tool not found"):
            WorkflowToolManageService._get_workflow_tool("t", None)

    def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """ValueError when the corresponding App row is missing."""
        # Arrange
        tool_row = MagicMock(spec=WorkflowToolProvider)
        tool_row.app_id = "app-1"
        tool_row.tenant_id = "t"
        monkeypatch.setattr(
            workflow_tools_manage_service.db,
            "session",
            SimpleNamespace(query=lambda m: FakeQuery(None)),
        )

        # Act + Assert
        with pytest.raises(ValueError, match="not found"):
            WorkflowToolManageService._get_workflow_tool("t", tool_row)

    def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """ValueError when App has no attached Workflow."""
        # Arrange
        tool_row = MagicMock(spec=WorkflowToolProvider)
        tool_row.app_id = "app-1"
        tool_row.tenant_id = "t"
        workflowless_app = SimpleNamespace(workflow=None)
        monkeypatch.setattr(
            workflow_tools_manage_service.db,
            "session",
            SimpleNamespace(query=lambda m: FakeQuery(workflowless_app)),
        )

        # Act + Assert
        with pytest.raises(ValueError, match="Workflow not found"):
            WorkflowToolManageService._get_workflow_tool("t", tool_row)

    def test_should_raise_when_no_workflow_tools(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """ValueError when the controller returns no WorkflowTool instances."""
        # Arrange
        tool_row = MagicMock(spec=WorkflowToolProvider)
        tool_row.app_id = "app-1"
        tool_row.tenant_id = "t"
        tool_row.id = "tool-1"
        backing_app = SimpleNamespace(workflow=DummyWorkflow(graph_dict={"nodes": []}))
        monkeypatch.setattr(
            workflow_tools_manage_service.db,
            "session",
            SimpleNamespace(query=lambda m: FakeQuery(backing_app)),
        )
        empty_controller = MagicMock()
        empty_controller.get_tools.return_value = []
        monkeypatch.setattr(
            workflow_tools_manage_service.ToolTransformService,
            "workflow_provider_to_controller",
            MagicMock(return_value=empty_controller),
        )

        # Act + Assert
        with pytest.raises(ValueError, match="not found"):
            WorkflowToolManageService._get_workflow_tool("t", tool_row)

    def test_should_return_dict_on_success(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Happy path: returns a dict with name, label, icon, synced, etc."""
        # Arrange
        tool_row = MagicMock(spec=WorkflowToolProvider)
        tool_row.app_id = "app-1"
        tool_row.tenant_id = "t"
        tool_row.id = "tool-1"
        tool_row.name = "my_tool"
        tool_row.label = "My Tool"
        tool_row.icon = json.dumps({"emoji": "🔧"})
        tool_row.description = "some desc"
        tool_row.privacy_policy = ""
        tool_row.version = "1.0"
        tool_row.parameter_configurations = []
        backing_app = SimpleNamespace(workflow=DummyWorkflow(graph_dict={"nodes": []}, version="1.0"))
        monkeypatch.setattr(
            workflow_tools_manage_service.db,
            "session",
            SimpleNamespace(query=lambda m: FakeQuery(backing_app)),
        )
        loaded_tool = MagicMock()
        loaded_tool.entity.output_schema = {"type": "object"}
        controller = MagicMock()
        controller.get_tools.return_value = [loaded_tool]
        transform = workflow_tools_manage_service.ToolTransformService
        monkeypatch.setattr(
            transform, "workflow_provider_to_controller", MagicMock(return_value=controller)
        )
        monkeypatch.setattr(
            transform, "convert_tool_entity_to_api_entity", MagicMock(return_value={"tool": "api_entity"})
        )
        monkeypatch.setattr(
            workflow_tools_manage_service.ToolLabelManager, "get_tool_labels", MagicMock(return_value=[])
        )

        # Act
        payload = WorkflowToolManageService._get_workflow_tool("t", tool_row)

        # Assert
        assert payload["name"] == "my_tool"
        assert payload["label"] == "My Tool"
        assert payload["synced"] is True
        assert "icon" in payload
        assert "output_schema" in payload
# ---------------------------------------------------------------------------
# TestListSingleWorkflowTools
# ---------------------------------------------------------------------------
class TestListSingleWorkflowTools:
    """Tests for WorkflowToolManageService.list_single_workflow_tools."""

    def test_should_raise_when_tool_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """ValueError when the specified tool does not exist in DB."""
        # Arrange
        monkeypatch.setattr(
            workflow_tools_manage_service.db,
            "session",
            SimpleNamespace(query=lambda m: FakeQuery(None)),
        )

        # Act + Assert
        with pytest.raises(ValueError, match="not found"):
            WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1")

    def test_should_raise_when_no_workflow_tools(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """ValueError when the controller yields no tools for the provider."""
        # Arrange
        tool_row = MagicMock(spec=WorkflowToolProvider)
        tool_row.id = "tool-1"
        tool_row.tenant_id = "t"
        monkeypatch.setattr(
            workflow_tools_manage_service.db,
            "session",
            SimpleNamespace(query=lambda m: FakeQuery(tool_row)),
        )
        empty_controller = MagicMock()
        empty_controller.get_tools.return_value = []
        monkeypatch.setattr(
            workflow_tools_manage_service.ToolTransformService,
            "workflow_provider_to_controller",
            MagicMock(return_value=empty_controller),
        )

        # Act + Assert
        with pytest.raises(ValueError, match="not found"):
            WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1")

    def test_should_return_api_entity_list(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Happy path: returns list with one ToolApiEntity."""
        # Arrange
        tool_row = MagicMock(spec=WorkflowToolProvider)
        tool_row.id = "tool-1"
        tool_row.tenant_id = "t"
        monkeypatch.setattr(
            workflow_tools_manage_service.db,
            "session",
            SimpleNamespace(query=lambda m: FakeQuery(tool_row)),
        )
        controller = MagicMock()
        controller.get_tools.return_value = [MagicMock()]
        transform = workflow_tools_manage_service.ToolTransformService
        monkeypatch.setattr(
            transform, "workflow_provider_to_controller", MagicMock(return_value=controller)
        )
        converted = MagicMock()
        monkeypatch.setattr(
            transform, "convert_tool_entity_to_api_entity", MagicMock(return_value=converted)
        )
        monkeypatch.setattr(
            workflow_tools_manage_service.ToolLabelManager, "get_tool_labels", MagicMock(return_value=[])
        )

        # Act
        listed = WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1")

        # Assert
        assert listed == [converted]

View File

@ -121,6 +121,7 @@ import pytest
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import Document
from models.dataset import ChildChunk, Dataset, DatasetDocument, DatasetProcessRule, DocumentSegment
from services.vector_service import VectorService
@ -151,7 +152,7 @@ class VectorServiceTestDataFactory:
def create_dataset_mock(
dataset_id: str = "dataset-123",
tenant_id: str = "tenant-123",
doc_form: str = "text_model",
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
indexing_technique: str = "high_quality",
embedding_model_provider: str = "openai",
embedding_model: str = "text-embedding-ada-002",
@ -493,7 +494,7 @@ class TestVectorService:
"""
# Arrange
dataset = VectorServiceTestDataFactory.create_dataset_mock(
doc_form="text_model", indexing_technique="high_quality"
doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_technique="high_quality"
)
segment = VectorServiceTestDataFactory.create_document_segment_mock()
@ -505,7 +506,7 @@ class TestVectorService:
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor
# Act
VectorService.create_segments_vector(keywords_list, [segment], dataset, "text_model")
VectorService.create_segments_vector(keywords_list, [segment], dataset, IndexStructureType.PARAGRAPH_INDEX)
# Assert
mock_index_processor.load.assert_called_once()
@ -649,7 +650,7 @@ class TestVectorService:
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor
# Act
VectorService.create_segments_vector(None, [], dataset, "text_model")
VectorService.create_segments_vector(None, [], dataset, IndexStructureType.PARAGRAPH_INDEX)
# Assert
mock_index_processor.load.assert_not_called()

View File

@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch
import pytest
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.enums import DataSourceType
from tasks.clean_dataset_task import clean_dataset_task
@ -186,7 +187,7 @@ class TestErrorHandling:
indexing_technique="high_quality",
index_struct='{"type": "paragraph"}',
collection_binding_id=collection_binding_id,
doc_form="paragraph_index",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
# Assert
@ -231,7 +232,7 @@ class TestPipelineAndWorkflowDeletion:
indexing_technique="high_quality",
index_struct='{"type": "paragraph"}',
collection_binding_id=collection_binding_id,
doc_form="paragraph_index",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
pipeline_id=pipeline_id,
)
@ -267,7 +268,7 @@ class TestPipelineAndWorkflowDeletion:
indexing_technique="high_quality",
index_struct='{"type": "paragraph"}',
collection_binding_id=collection_binding_id,
doc_form="paragraph_index",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
pipeline_id=None,
)
@ -323,7 +324,7 @@ class TestSegmentAttachmentCleanup:
indexing_technique="high_quality",
index_struct='{"type": "paragraph"}',
collection_binding_id=collection_binding_id,
doc_form="paragraph_index",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
# Assert
@ -368,7 +369,7 @@ class TestSegmentAttachmentCleanup:
indexing_technique="high_quality",
index_struct='{"type": "paragraph"}',
collection_binding_id=collection_binding_id,
doc_form="paragraph_index",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
# Assert - storage delete was attempted
@ -410,7 +411,7 @@ class TestEdgeCases:
indexing_technique="high_quality",
index_struct='{"type": "paragraph"}',
collection_binding_id=collection_binding_id,
doc_form="paragraph_index",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
# Assert
@ -454,7 +455,7 @@ class TestIndexProcessorParameters:
indexing_technique=indexing_technique,
index_struct=index_struct,
collection_binding_id=collection_binding_id,
doc_form="paragraph_index",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
# Assert

View File

@ -15,6 +15,7 @@ from unittest.mock import MagicMock, Mock, patch
import pytest
from core.indexing_runner import DocumentIsPausedError
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
@ -222,7 +223,7 @@ def mock_documents(document_ids, dataset_id):
doc.stopped_at = None
doc.processing_started_at = None
# optional attribute used in some code paths
doc.doc_form = "text_model"
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
documents.append(doc)
return documents

View File

@ -11,6 +11,7 @@ from unittest.mock import MagicMock, Mock, patch
import pytest
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, Document
from tasks.document_indexing_sync_task import document_indexing_sync_task
@ -62,7 +63,7 @@ def mock_document(document_id, dataset_id, notion_workspace_id, notion_page_id,
document.tenant_id = str(uuid.uuid4())
document.data_source_type = "notion_import"
document.indexing_status = "completed"
document.doc_form = "text_model"
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
document.data_source_info_dict = {
"notion_workspace_id": notion_workspace_id,
"notion_page_id": notion_page_id,

View File

@ -69,6 +69,7 @@
},
"pnpm": {
"overrides": {
"flatted@<=3.4.1": "3.4.2",
"rollup@>=4.0.0,<4.59.0": "4.59.0"
}
}

View File

@ -5,6 +5,7 @@ settings:
excludeLinksFromLockfile: false
overrides:
flatted@<=3.4.1: 3.4.2
rollup@>=4.0.0,<4.59.0: 4.59.0
importers:
@ -324,66 +325,79 @@ packages:
resolution: {integrity: sha512-t4ONHboXi/3E0rT6OZl1pKbl2Vgxf9vJfWgmUoCEVQVxhW6Cw/c8I6hbbu7DAvgp82RKiH7TpLwxnJeKv2pbsw==}
cpu: [arm]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-arm-musleabihf@4.59.0':
resolution: {integrity: sha512-CikFT7aYPA2ufMD086cVORBYGHffBo4K8MQ4uPS/ZnY54GKj36i196u8U+aDVT2LX4eSMbyHtyOh7D7Zvk2VvA==}
cpu: [arm]
os: [linux]
libc: [musl]
'@rollup/rollup-linux-arm64-gnu@4.59.0':
resolution: {integrity: sha512-jYgUGk5aLd1nUb1CtQ8E+t5JhLc9x5WdBKew9ZgAXg7DBk0ZHErLHdXM24rfX+bKrFe+Xp5YuJo54I5HFjGDAA==}
cpu: [arm64]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-arm64-musl@4.59.0':
resolution: {integrity: sha512-peZRVEdnFWZ5Bh2KeumKG9ty7aCXzzEsHShOZEFiCQlDEepP1dpUl/SrUNXNg13UmZl+gzVDPsiCwnV1uI0RUA==}
cpu: [arm64]
os: [linux]
libc: [musl]
'@rollup/rollup-linux-loong64-gnu@4.59.0':
resolution: {integrity: sha512-gbUSW/97f7+r4gHy3Jlup8zDG190AuodsWnNiXErp9mT90iCy9NKKU0Xwx5k8VlRAIV2uU9CsMnEFg/xXaOfXg==}
cpu: [loong64]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-loong64-musl@4.59.0':
resolution: {integrity: sha512-yTRONe79E+o0FWFijasoTjtzG9EBedFXJMl888NBEDCDV9I2wGbFFfJQQe63OijbFCUZqxpHz1GzpbtSFikJ4Q==}
cpu: [loong64]
os: [linux]
libc: [musl]
'@rollup/rollup-linux-ppc64-gnu@4.59.0':
resolution: {integrity: sha512-sw1o3tfyk12k3OEpRddF68a1unZ5VCN7zoTNtSn2KndUE+ea3m3ROOKRCZxEpmT9nsGnogpFP9x6mnLTCaoLkA==}
cpu: [ppc64]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-ppc64-musl@4.59.0':
resolution: {integrity: sha512-+2kLtQ4xT3AiIxkzFVFXfsmlZiG5FXYW7ZyIIvGA7Bdeuh9Z0aN4hVyXS/G1E9bTP/vqszNIN/pUKCk/BTHsKA==}
cpu: [ppc64]
os: [linux]
libc: [musl]
'@rollup/rollup-linux-riscv64-gnu@4.59.0':
resolution: {integrity: sha512-NDYMpsXYJJaj+I7UdwIuHHNxXZ/b/N2hR15NyH3m2qAtb/hHPA4g4SuuvrdxetTdndfj9b1WOmy73kcPRoERUg==}
cpu: [riscv64]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-riscv64-musl@4.59.0':
resolution: {integrity: sha512-nLckB8WOqHIf1bhymk+oHxvM9D3tyPndZH8i8+35p/1YiVoVswPid2yLzgX7ZJP0KQvnkhM4H6QZ5m0LzbyIAg==}
cpu: [riscv64]
os: [linux]
libc: [musl]
'@rollup/rollup-linux-s390x-gnu@4.59.0':
resolution: {integrity: sha512-oF87Ie3uAIvORFBpwnCvUzdeYUqi2wY6jRFWJAy1qus/udHFYIkplYRW+wo+GRUP4sKzYdmE1Y3+rY5Gc4ZO+w==}
cpu: [s390x]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-x64-gnu@4.59.0':
resolution: {integrity: sha512-3AHmtQq/ppNuUspKAlvA8HtLybkDflkMuLK4DPo77DfthRb71V84/c4MlWJXixZz4uruIH4uaa07IqoAkG64fg==}
cpu: [x64]
os: [linux]
libc: [glibc]
'@rollup/rollup-linux-x64-musl@4.59.0':
resolution: {integrity: sha512-2UdiwS/9cTAx7qIUZB/fWtToJwvt0Vbo0zmnYt7ED35KPg13Q0ym1g442THLC7VyI6JfYTP4PiSOWyoMdV2/xg==}
cpu: [x64]
os: [linux]
libc: [musl]
'@rollup/rollup-openbsd-x64@4.59.0':
resolution: {integrity: sha512-M3bLRAVk6GOwFlPTIxVBSYKUaqfLrn8l0psKinkCFxl4lQvOSz8ZrKDz2gxcBwHFpci0B6rttydI4IpS4IS/jQ==}
@ -741,8 +755,8 @@ packages:
resolution: {integrity: sha512-f7ccFPK3SXFHpx15UIGyRJ/FJQctuKZ0zVuN3frBo4HnK3cay9VEW0R6yPYFHC0AgqhukPzKjq22t5DmAyqGyw==}
engines: {node: '>=16'}
flatted@3.4.1:
resolution: {integrity: sha512-IxfVbRFVlV8V/yRaGzk0UVIcsKKHMSfYw66T/u4nTwlWteQePsxe//LjudR1AMX4tZW3WFCh3Zqa/sjlqpbURQ==}
flatted@3.4.2:
resolution: {integrity: sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==}
follow-redirects@1.15.11:
resolution: {integrity: sha512-deG2P0JfjrTxl50XGCDyfI97ZGVCxIpfKYmfyrQ54n5FO/0gfIES8C/Psl6kWVDolizcaaxZJnTS0QSMxvnsBQ==}
@ -1836,10 +1850,10 @@ snapshots:
flat-cache@4.0.1:
dependencies:
flatted: 3.4.1
flatted: 3.4.2
keyv: 4.5.4
flatted@3.4.1: {}
flatted@3.4.2: {}
follow-redirects@1.15.11: {}

Some files were not shown because too many files have changed in this diff Show More