mirror of https://github.com/langgenius/dify.git
Compare commits
18 Commits
9554980423
...
47a53d0bea
| Author | SHA1 | Date |
|---|---|---|
|
|
47a53d0bea | |
|
|
8b634a9bee | |
|
|
ecd3a964c1 | |
|
|
0589fa423b | |
|
|
27c4faad4f | |
|
|
fbd558762d | |
|
|
075b8bf1ae | |
|
|
49a1fae555 | |
|
|
cc17c8e883 | |
|
|
5d2cb3cd80 | |
|
|
f2c71f3668 | |
|
|
0492ed7034 | |
|
|
dd4f504b39 | |
|
|
75c3ef82d9 | |
|
|
8ca1ebb96d | |
|
|
3f086b97b6 | |
|
|
4a2e9633db | |
|
|
20fc69ae7f |
|
|
@ -4,10 +4,9 @@ runs:
|
|||
using: composite
|
||||
steps:
|
||||
- name: Setup Vite+
|
||||
uses: voidzero-dev/setup-vp@4a524139920f87f9f7080d3b8545acac019e1852 # v1.0.0
|
||||
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
|
||||
with:
|
||||
node-version-file: web/.nvmrc
|
||||
working-directory: web
|
||||
node-version-file: .nvmrc
|
||||
cache: true
|
||||
cache-dependency-path: web/pnpm-lock.yaml
|
||||
run-install: |
|
||||
cwd: ./web
|
||||
run-install: true
|
||||
|
|
|
|||
|
|
@ -84,20 +84,20 @@ jobs:
|
|||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Restore ESLint cache
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
id: eslint-cache-restore
|
||||
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
|
||||
with:
|
||||
path: web/.eslintcache
|
||||
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
|
||||
|
||||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: |
|
||||
vp run lint:ci
|
||||
# pnpm run lint:report
|
||||
# continue-on-error: true
|
||||
|
||||
# - name: Annotate Code
|
||||
# if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request'
|
||||
# uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae
|
||||
# with:
|
||||
# eslint-report: web/eslint_report.json
|
||||
# github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: vp run lint:ci
|
||||
|
||||
- name: Web tsslint
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
|
|
@ -114,6 +114,13 @@ jobs:
|
|||
working-directory: ./web
|
||||
run: vp run knip
|
||||
|
||||
- name: Save ESLint cache
|
||||
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
|
||||
uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
|
||||
with:
|
||||
path: web/.eslintcache
|
||||
key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
|
||||
|
||||
superlinter:
|
||||
name: SuperLinter
|
||||
runs-on: ubuntu-latest
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ jobs:
|
|||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.detect_changes.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@6062f3709600659be5e47fcddf2cf76993c235c2 # v1.0.76
|
||||
uses: anthropics/claude-code-action@ff9acae5886d41a99ed4ec14b7dc147d55834722 # v1.0.77
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from configs import dify_config
|
|||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.models.document import ChildDocument, Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
|
||||
|
|
@ -269,7 +270,7 @@ def migrate_knowledge_vector_database():
|
|||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
if dataset_document.doc_form == "hierarchical_model":
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
child_chunks = segment.get_child_chunks()
|
||||
if child_chunks:
|
||||
child_documents = []
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from extensions.ext_database import db
|
|||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Dataset
|
||||
from models.enums import ApiTokenType
|
||||
from models.model import ApiToken, App
|
||||
from services.api_token_service import ApiTokenCache
|
||||
|
||||
|
|
@ -47,7 +48,7 @@ def _get_resource(resource_id, tenant_id, resource_model):
|
|||
class BaseApiKeyListResource(Resource):
|
||||
method_decorators = [account_initialization_required, login_required, setup_required]
|
||||
|
||||
resource_type: str | None = None
|
||||
resource_type: ApiTokenType | None = None
|
||||
resource_model: type | None = None
|
||||
resource_id_field: str | None = None
|
||||
token_prefix: str | None = None
|
||||
|
|
@ -91,6 +92,7 @@ class BaseApiKeyListResource(Resource):
|
|||
)
|
||||
|
||||
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
||||
assert self.resource_type is not None, "resource_type must be set"
|
||||
api_token = ApiToken()
|
||||
setattr(api_token, self.resource_id_field, resource_id)
|
||||
api_token.tenant_id = current_tenant_id
|
||||
|
|
@ -104,7 +106,7 @@ class BaseApiKeyListResource(Resource):
|
|||
class BaseApiKeyResource(Resource):
|
||||
method_decorators = [account_initialization_required, login_required, setup_required]
|
||||
|
||||
resource_type: str | None = None
|
||||
resource_type: ApiTokenType | None = None
|
||||
resource_model: type | None = None
|
||||
resource_id_field: str | None = None
|
||||
|
||||
|
|
@ -159,7 +161,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
|||
"""Create a new API key for an app"""
|
||||
return super().post(resource_id)
|
||||
|
||||
resource_type = "app"
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
token_prefix = "app-"
|
||||
|
|
@ -175,7 +177,7 @@ class AppApiKeyResource(BaseApiKeyResource):
|
|||
"""Delete an API key for an app"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
resource_type = "app"
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
|
||||
|
|
@ -199,7 +201,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
|||
"""Create a new API key for a dataset"""
|
||||
return super().post(resource_id)
|
||||
|
||||
resource_type = "dataset"
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_model = Dataset
|
||||
resource_id_field = "dataset_id"
|
||||
token_prefix = "ds-"
|
||||
|
|
@ -215,6 +217,6 @@ class DatasetApiKeyResource(BaseApiKeyResource):
|
|||
"""Delete an API key for a dataset"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
resource_type = "dataset"
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_model = Dataset
|
||||
resource_id_field = "dataset_id"
|
||||
|
|
|
|||
|
|
@ -458,9 +458,7 @@ class ChatConversationApi(Resource):
|
|||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
subquery = (
|
||||
db.session.query(
|
||||
Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")
|
||||
)
|
||||
sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id"))
|
||||
.outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
|
||||
.subquery()
|
||||
)
|
||||
|
|
@ -595,10 +593,8 @@ class ChatConversationDetailApi(Resource):
|
|||
|
||||
def _get_conversation(app_model, conversation_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
|
||||
.first()
|
||||
conversation = db.session.scalar(
|
||||
sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1)
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ class InstructionGenerateApi(Resource):
|
|||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args.current in (code_template, "")) and args.node_id != "":
|
||||
app = db.session.query(App).where(App.id == args.flow_id).first()
|
||||
app = db.session.get(App, args.flow_id)
|
||||
if not app:
|
||||
return {"error": f"app {args.flow_id} not found"}, 400
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import json
|
|||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -47,7 +48,7 @@ class AppMCPServerController(Resource):
|
|||
@get_app_model
|
||||
@marshal_with(app_server_model)
|
||||
def get(self, app_model):
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
|
||||
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
|
||||
return server
|
||||
|
||||
@console_ns.doc("create_app_mcp_server")
|
||||
|
|
@ -98,7 +99,7 @@ class AppMCPServerController(Resource):
|
|||
@edit_permission_required
|
||||
def put(self, app_model):
|
||||
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
|
||||
server = db.session.get(AppMCPServer, payload.id)
|
||||
if not server:
|
||||
raise NotFound()
|
||||
|
||||
|
|
@ -135,11 +136,10 @@ class AppMCPServerRefreshController(Resource):
|
|||
@edit_permission_required
|
||||
def get(self, server_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
server = (
|
||||
db.session.query(AppMCPServer)
|
||||
.where(AppMCPServer.id == server_id)
|
||||
.where(AppMCPServer.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
server = db.session.scalar(
|
||||
select(AppMCPServer)
|
||||
.where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not server:
|
||||
raise NotFound()
|
||||
|
|
|
|||
|
|
@ -69,9 +69,7 @@ class ModelConfigResource(Resource):
|
|||
|
||||
if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
|
||||
# get original app model config
|
||||
original_app_model_config = (
|
||||
db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first()
|
||||
)
|
||||
original_app_model_config = db.session.get(AppModelConfig, app_model.app_model_config_id)
|
||||
if original_app_model_config is None:
|
||||
raise ValueError("Original app model config not found")
|
||||
agent_mode = original_app_model_config.agent_mode_dict
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from typing import Literal
|
|||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from constants.languages import supported_language
|
||||
|
|
@ -75,7 +76,7 @@ class AppSite(Resource):
|
|||
def post(self, app_model):
|
||||
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
if not site:
|
||||
raise NotFound
|
||||
|
||||
|
|
@ -124,7 +125,7 @@ class AppSiteAccessTokenReset(Resource):
|
|||
@marshal_with(app_site_model)
|
||||
def post(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
|
||||
if not site:
|
||||
raise NotFound
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ from collections.abc import Callable
|
|||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
|
|
@ -15,16 +17,14 @@ R1 = TypeVar("R1")
|
|||
|
||||
def _load_app_model(app_id: str) -> App | None:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
app_model = db.session.scalar(
|
||||
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
|
||||
)
|
||||
return app_model
|
||||
|
||||
|
||||
def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||
app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
|
||||
app_model = db.session.scalar(select(App).where(App.id == app_id, App.status == "normal").limit(1))
|
||||
return app_model
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ from fields.document_fields import document_status_fields
|
|||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermission, DatasetPermissionEnum
|
||||
from models.enums import SegmentStatus
|
||||
from models.enums import ApiTokenType, SegmentStatus
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.api_token_service import ApiTokenCache
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
|
@ -777,7 +777,7 @@ class DatasetIndexingStatusApi(Resource):
|
|||
class DatasetApiKeyApi(Resource):
|
||||
max_keys = 10
|
||||
token_prefix = "dataset-"
|
||||
resource_type = "dataset"
|
||||
resource_type = ApiTokenType.DATASET
|
||||
|
||||
@console_ns.doc("get_dataset_api_keys")
|
||||
@console_ns.doc(description="Get dataset API keys")
|
||||
|
|
@ -826,7 +826,7 @@ class DatasetApiKeyApi(Resource):
|
|||
|
||||
@console_ns.route("/datasets/api-keys/<uuid:api_key_id>")
|
||||
class DatasetApiDeleteApi(Resource):
|
||||
resource_type = "dataset"
|
||||
resource_type = ApiTokenType.DATASET
|
||||
|
||||
@console_ns.doc("delete_dataset_api_key")
|
||||
@console_ns.doc(description="Delete dataset API key")
|
||||
|
|
|
|||
|
|
@ -705,7 +705,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
workflow_id=self._workflow.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
created_from=created_from.value,
|
||||
created_from=created_from,
|
||||
created_by_role=self._created_by_role,
|
||||
created_by=self._user_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -918,11 +918,11 @@ class ProviderManager:
|
|||
|
||||
trail_pool = CreditPoolService.get_pool(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=ProviderQuotaType.TRIAL.value,
|
||||
pool_type=ProviderQuotaType.TRIAL,
|
||||
)
|
||||
paid_pool = CreditPoolService.get_pool(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=ProviderQuotaType.PAID.value,
|
||||
pool_type=ProviderQuotaType.PAID,
|
||||
)
|
||||
else:
|
||||
trail_pool = None
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ from core.rag.models.document import Document
|
|||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset, TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client import grpc # noqa
|
||||
|
|
@ -452,7 +453,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
|||
password=new_cluster["password"],
|
||||
tenant_id=dataset.tenant_id,
|
||||
active=True,
|
||||
status="ACTIVE",
|
||||
status=TidbAuthBindingStatus.ACTIVE,
|
||||
)
|
||||
db.session.add(new_tidb_auth_binding)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from configs import dify_config
|
|||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
|
||||
class TidbService:
|
||||
|
|
@ -170,7 +171,7 @@ class TidbService:
|
|||
userPrefix = item["userPrefix"]
|
||||
if state == "ACTIVE" and len(userPrefix) > 0:
|
||||
cluster_info = tidb_serverless_list_map[item["clusterId"]]
|
||||
cluster_info.status = "ACTIVE"
|
||||
cluster_info.status = TidbAuthBindingStatus.ACTIVE
|
||||
cluster_info.account = f"{userPrefix}.root"
|
||||
db.session.add(cluster_info)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -43,7 +43,9 @@ from .enums import (
|
|||
IndexingStatus,
|
||||
ProcessRuleMode,
|
||||
SegmentStatus,
|
||||
SegmentType,
|
||||
SummaryStatus,
|
||||
TidbAuthBindingStatus,
|
||||
)
|
||||
from .model import App, Tag, TagBinding, UploadFile
|
||||
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
|
||||
|
|
@ -494,7 +496,9 @@ class Document(Base):
|
|||
)
|
||||
doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True)
|
||||
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
|
||||
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
|
||||
doc_form: Mapped[IndexStructureType] = mapped_column(
|
||||
EnumText(IndexStructureType, length=255), nullable=False, server_default=sa.text("'text_model'")
|
||||
)
|
||||
doc_language = mapped_column(String(255), nullable=True)
|
||||
need_summary: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
|
||||
|
|
@ -998,7 +1002,9 @@ class ChildChunk(Base):
|
|||
# indexing fields
|
||||
index_node_id = mapped_column(String(255), nullable=True)
|
||||
index_node_hash = mapped_column(String(255), nullable=True)
|
||||
type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'"))
|
||||
type: Mapped[SegmentType] = mapped_column(
|
||||
EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'")
|
||||
)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
|
|
@ -1239,7 +1245,9 @@ class TidbAuthBinding(TypeBase):
|
|||
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'"))
|
||||
status: Mapped[TidbAuthBindingStatus] = mapped_column(
|
||||
EnumText(TidbAuthBindingStatus, length=255), nullable=False, server_default=sa.text("'CREATING'")
|
||||
)
|
||||
account: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
|
|
|||
|
|
@ -222,6 +222,13 @@ class DatasetMetadataType(StrEnum):
|
|||
TIME = "time"
|
||||
|
||||
|
||||
class SegmentType(StrEnum):
|
||||
"""Document segment type"""
|
||||
|
||||
AUTOMATIC = "automatic"
|
||||
CUSTOMIZED = "customized"
|
||||
|
||||
|
||||
class SegmentStatus(StrEnum):
|
||||
"""Document segment status"""
|
||||
|
||||
|
|
@ -323,3 +330,10 @@ class ProviderQuotaType(StrEnum):
|
|||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class ApiTokenType(StrEnum):
|
||||
"""API Token type"""
|
||||
|
||||
APP = "app"
|
||||
DATASET = "dataset"
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from configs import dify_config
|
|||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
from core.tools.signature import sign_tool_file
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.helper import generate_string # type: ignore[import-not-found]
|
||||
|
|
@ -31,6 +31,7 @@ from .account import Account, Tenant
|
|||
from .base import Base, TypeBase, gen_uuidv4_string
|
||||
from .engine import db
|
||||
from .enums import (
|
||||
ApiTokenType,
|
||||
AppMCPServerStatus,
|
||||
AppStatus,
|
||||
BannerStatus,
|
||||
|
|
@ -43,6 +44,7 @@ from .enums import (
|
|||
MessageChainType,
|
||||
MessageFileBelongsTo,
|
||||
MessageStatus,
|
||||
ProviderQuotaType,
|
||||
TagType,
|
||||
)
|
||||
from .provider_ids import GenericProviderID
|
||||
|
|
@ -1783,7 +1785,7 @@ class MessageFile(TypeBase):
|
|||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||
)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
type: Mapped[FileType] = mapped_column(EnumText(FileType, length=255), nullable=False)
|
||||
transfer_method: Mapped[FileTransferMethod] = mapped_column(
|
||||
EnumText(FileTransferMethod, length=255), nullable=False
|
||||
)
|
||||
|
|
@ -2095,7 +2097,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field.
|
|||
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
app_id = mapped_column(StringUUID, nullable=True)
|
||||
tenant_id = mapped_column(StringUUID, nullable=True)
|
||||
type = mapped_column(String(16), nullable=False)
|
||||
type: Mapped[ApiTokenType] = mapped_column(EnumText(ApiTokenType, length=16), nullable=False)
|
||||
token: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
last_used_at = mapped_column(sa.DateTime, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
@ -2490,7 +2492,9 @@ class TenantCreditPool(TypeBase):
|
|||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||
)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
|
||||
pool_type: Mapped[ProviderQuotaType] = mapped_column(
|
||||
EnumText(ProviderQuotaType, length=40), nullable=False, default=ProviderQuotaType.TRIAL, server_default="trial"
|
||||
)
|
||||
quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
|
|
|||
|
|
@ -145,7 +145,9 @@ class ApiToolProvider(TypeBase):
|
|||
icon: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# original schema
|
||||
schema: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
schema_type_str: Mapped[ApiProviderSchemaType] = mapped_column(
|
||||
EnumText(ApiProviderSchemaType, length=40), nullable=False
|
||||
)
|
||||
# who created this tool
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# tenant id
|
||||
|
|
|
|||
|
|
@ -1221,7 +1221,9 @@ class WorkflowAppLog(TypeBase):
|
|||
app_id: Mapped[str] = mapped_column(StringUUID)
|
||||
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
workflow_run_id: Mapped[str] = mapped_column(StringUUID)
|
||||
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_from: Mapped[WorkflowAppLogCreatedFrom] = mapped_column(
|
||||
EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=False
|
||||
)
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
|
@ -1301,10 +1303,14 @@ class WorkflowArchiveLog(TypeBase):
|
|||
|
||||
log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
log_created_from: Mapped[WorkflowAppLogCreatedFrom | None] = mapped_column(
|
||||
EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=True
|
||||
)
|
||||
|
||||
run_version: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
run_status: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
run_status: Mapped[WorkflowExecutionStatus] = mapped_column(
|
||||
EnumText(WorkflowExecutionStatus, length=255), nullable=False
|
||||
)
|
||||
run_triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column(
|
||||
EnumText(WorkflowRunTriggeredFrom, length=255), nullable=False
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from configs import dify_config
|
|||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
|
||||
@app.celery.task(queue="dataset")
|
||||
|
|
@ -57,7 +58,7 @@ def create_clusters(batch_size):
|
|||
account=new_cluster["account"],
|
||||
password=new_cluster["password"],
|
||||
active=False,
|
||||
status="CREATING",
|
||||
status=TidbAuthBindingStatus.CREATING,
|
||||
)
|
||||
db.session.add(tidb_auth_binding)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from configs import dify_config
|
|||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
|
||||
@app.celery.task(queue="dataset")
|
||||
|
|
@ -18,7 +19,10 @@ def update_tidb_serverless_status_task():
|
|||
try:
|
||||
# check the number of idle tidb serverless
|
||||
tidb_serverless_list = db.session.scalars(
|
||||
select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
|
||||
select(TidbAuthBinding).where(
|
||||
TidbAuthBinding.active == False,
|
||||
TidbAuthBinding.status == TidbAuthBindingStatus.CREATING,
|
||||
)
|
||||
).all()
|
||||
if len(tidb_serverless_list) == 0:
|
||||
return
|
||||
|
|
|
|||
|
|
@ -1,8 +1,16 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class AuthCredentials(TypedDict):
|
||||
auth_type: str
|
||||
config: dict[str, Any]
|
||||
|
||||
|
||||
class ApiKeyAuthBase(ABC):
|
||||
def __init__(self, credentials: dict):
|
||||
def __init__(self, credentials: AuthCredentials):
|
||||
self.credentials = credentials
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
|
||||
from services.auth.auth_type import AuthType
|
||||
|
||||
|
||||
class ApiKeyAuthFactory:
|
||||
def __init__(self, provider: str, credentials: dict):
|
||||
def __init__(self, provider: str, credentials: AuthCredentials):
|
||||
auth_factory = self.get_apikey_auth_factory(provider)
|
||||
self.auth = auth_factory(credentials)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ import json
|
|||
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
|
||||
|
||||
|
||||
class FirecrawlAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
def __init__(self, credentials: AuthCredentials):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ import json
|
|||
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
|
||||
|
||||
|
||||
class JinaAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
def __init__(self, credentials: AuthCredentials):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ import json
|
|||
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
|
||||
|
||||
|
||||
class JinaAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
def __init__(self, credentials: AuthCredentials):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
|
|
|
|||
|
|
@ -3,11 +3,11 @@ from urllib.parse import urljoin
|
|||
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
|
||||
|
||||
|
||||
class WatercrawlAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
def __init__(self, credentials: AuthCredentials):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "x-api-key":
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from configs import dify_config
|
|||
from core.errors.error import QuotaExceededError
|
||||
from extensions.ext_database import db
|
||||
from models import TenantCreditPool
|
||||
from models.enums import ProviderQuotaType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -16,7 +17,10 @@ class CreditPoolService:
|
|||
def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
|
||||
"""create default credit pool for new tenant"""
|
||||
credit_pool = TenantCreditPool(
|
||||
tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial"
|
||||
tenant_id=tenant_id,
|
||||
quota_limit=dify_config.HOSTED_POOL_CREDITS,
|
||||
quota_used=0,
|
||||
pool_type=ProviderQuotaType.TRIAL,
|
||||
)
|
||||
db.session.add(credit_pool)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ from models.enums import (
|
|||
IndexingStatus,
|
||||
ProcessRuleMode,
|
||||
SegmentStatus,
|
||||
SegmentType,
|
||||
)
|
||||
from models.model import UploadFile
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
|
@ -1439,7 +1440,7 @@ class DocumentService:
|
|||
.filter(
|
||||
Document.id.in_(document_id_list),
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.doc_form != "qa_model", # Skip qa_model documents
|
||||
Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
|
||||
)
|
||||
.update({Document.need_summary: need_summary}, synchronize_session=False)
|
||||
)
|
||||
|
|
@ -2039,7 +2040,7 @@ class DocumentService:
|
|||
document.dataset_process_rule_id = dataset_process_rule.id
|
||||
document.updated_at = naive_utc_now()
|
||||
document.created_from = created_from
|
||||
document.doc_form = knowledge_config.doc_form
|
||||
document.doc_form = IndexStructureType(knowledge_config.doc_form)
|
||||
document.doc_language = knowledge_config.doc_language
|
||||
document.data_source_info = json.dumps(data_source_info)
|
||||
document.batch = batch
|
||||
|
|
@ -2639,7 +2640,7 @@ class DocumentService:
|
|||
document.splitting_completed_at = None
|
||||
document.updated_at = naive_utc_now()
|
||||
document.created_from = created_from
|
||||
document.doc_form = document_data.doc_form
|
||||
document.doc_form = IndexStructureType(document_data.doc_form)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
# update document segment
|
||||
|
|
@ -3100,7 +3101,7 @@ class DocumentService:
|
|||
class SegmentService:
|
||||
@classmethod
|
||||
def segment_create_args_validate(cls, args: dict, document: Document):
|
||||
if document.doc_form == "qa_model":
|
||||
if document.doc_form == IndexStructureType.QA_INDEX:
|
||||
if "answer" not in args or not args["answer"]:
|
||||
raise ValueError("Answer is required")
|
||||
if not args["answer"].strip():
|
||||
|
|
@ -3157,7 +3158,7 @@ class SegmentService:
|
|||
completed_at=naive_utc_now(),
|
||||
created_by=current_user.id,
|
||||
)
|
||||
if document.doc_form == "qa_model":
|
||||
if document.doc_form == IndexStructureType.QA_INDEX:
|
||||
segment_document.word_count += len(args["answer"])
|
||||
segment_document.answer = args["answer"]
|
||||
|
||||
|
|
@ -3231,7 +3232,7 @@ class SegmentService:
|
|||
tokens = 0
|
||||
if dataset.indexing_technique == "high_quality" and embedding_model:
|
||||
# calc embedding use tokens
|
||||
if document.doc_form == "qa_model":
|
||||
if document.doc_form == IndexStructureType.QA_INDEX:
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(
|
||||
texts=[content + segment_item["answer"]]
|
||||
)[0]
|
||||
|
|
@ -3254,7 +3255,7 @@ class SegmentService:
|
|||
completed_at=naive_utc_now(),
|
||||
created_by=current_user.id,
|
||||
)
|
||||
if document.doc_form == "qa_model":
|
||||
if document.doc_form == IndexStructureType.QA_INDEX:
|
||||
segment_document.answer = segment_item["answer"]
|
||||
segment_document.word_count += len(segment_item["answer"])
|
||||
increment_word_count += segment_document.word_count
|
||||
|
|
@ -3321,7 +3322,7 @@ class SegmentService:
|
|||
content = args.content or segment.content
|
||||
if segment.content == content:
|
||||
segment.word_count = len(content)
|
||||
if document.doc_form == "qa_model":
|
||||
if document.doc_form == IndexStructureType.QA_INDEX:
|
||||
segment.answer = args.answer
|
||||
segment.word_count += len(args.answer) if args.answer else 0
|
||||
word_count_change = segment.word_count - word_count_change
|
||||
|
|
@ -3418,7 +3419,7 @@ class SegmentService:
|
|||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
if document.doc_form == "qa_model":
|
||||
if document.doc_form == IndexStructureType.QA_INDEX:
|
||||
segment.answer = args.answer
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0] # type: ignore
|
||||
else:
|
||||
|
|
@ -3435,7 +3436,7 @@ class SegmentService:
|
|||
segment.enabled = True
|
||||
segment.disabled_at = None
|
||||
segment.disabled_by = None
|
||||
if document.doc_form == "qa_model":
|
||||
if document.doc_form == IndexStructureType.QA_INDEX:
|
||||
segment.answer = args.answer
|
||||
segment.word_count += len(args.answer) if args.answer else 0
|
||||
word_count_change = segment.word_count - word_count_change
|
||||
|
|
@ -3786,7 +3787,7 @@ class SegmentService:
|
|||
child_chunk.word_count = len(child_chunk.content)
|
||||
child_chunk.updated_by = current_user.id
|
||||
child_chunk.updated_at = naive_utc_now()
|
||||
child_chunk.type = "customized"
|
||||
child_chunk.type = SegmentType.CUSTOMIZED
|
||||
update_child_chunks.append(child_chunk)
|
||||
else:
|
||||
new_child_chunks_args.append(child_chunk_update_args)
|
||||
|
|
@ -3845,7 +3846,7 @@ class SegmentService:
|
|||
child_chunk.word_count = len(content)
|
||||
child_chunk.updated_by = current_user.id
|
||||
child_chunk.updated_at = naive_utc_now()
|
||||
child_chunk.type = "customized"
|
||||
child_chunk.type = SegmentType.CUSTOMIZED
|
||||
db.session.add(child_chunk)
|
||||
VectorService.update_child_chunk_vector([], [child_chunk], [], dataset)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from flask_login import current_user
|
|||
|
||||
from constants import DOCUMENT_EXTENSIONS
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
|
|
@ -79,9 +80,9 @@ class RagPipelineTransformService:
|
|||
pipeline = self._create_pipeline(pipeline_yaml)
|
||||
|
||||
# save chunk structure to dataset
|
||||
if doc_form == "hierarchical_model":
|
||||
if doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
dataset.chunk_structure = "hierarchical_model"
|
||||
elif doc_form == "text_model":
|
||||
elif doc_form == IndexStructureType.PARAGRAPH_INDEX:
|
||||
dataset.chunk_structure = "text_model"
|
||||
else:
|
||||
raise ValueError("Unsupported doc form")
|
||||
|
|
@ -101,7 +102,7 @@ class RagPipelineTransformService:
|
|||
|
||||
def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None):
|
||||
pipeline_yaml = {}
|
||||
if doc_form == "text_model":
|
||||
if doc_form == IndexStructureType.PARAGRAPH_INDEX:
|
||||
match datasource_type:
|
||||
case DataSourceType.UPLOAD_FILE:
|
||||
if indexing_technique == "high_quality":
|
||||
|
|
@ -132,7 +133,7 @@ class RagPipelineTransformService:
|
|||
pipeline_yaml = yaml.safe_load(f)
|
||||
case _:
|
||||
raise ValueError("Unsupported datasource type")
|
||||
elif doc_form == "hierarchical_model":
|
||||
elif doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
match datasource_type:
|
||||
case DataSourceType.UPLOAD_FILE:
|
||||
# get graph from transform.file-parentchild.yml
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from sqlalchemy import func
|
|||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
|
|
@ -109,7 +110,7 @@ def batch_create_segment_to_index_task(
|
|||
df = pd.read_csv(file_path)
|
||||
content = []
|
||||
for _, row in df.iterrows():
|
||||
if document_config["doc_form"] == "qa_model":
|
||||
if document_config["doc_form"] == IndexStructureType.QA_INDEX:
|
||||
data = {"content": row.iloc[0], "answer": row.iloc[1]}
|
||||
else:
|
||||
data = {"content": row.iloc[0]}
|
||||
|
|
@ -159,7 +160,7 @@ def batch_create_segment_to_index_task(
|
|||
status="completed",
|
||||
completed_at=naive_utc_now(),
|
||||
)
|
||||
if document_config["doc_form"] == "qa_model":
|
||||
if document_config["doc_form"] == IndexStructureType.QA_INDEX:
|
||||
segment_document.answer = segment["answer"]
|
||||
segment_document.word_count += len(segment["answer"])
|
||||
word_count_change += segment_document.word_count
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from configs import dify_config
|
|||
from core.db.session_factory import session_factory
|
||||
from core.entities.document_task import DocumentTask
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
|
@ -150,7 +151,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
|||
)
|
||||
if (
|
||||
document.indexing_status == IndexingStatus.COMPLETED
|
||||
and document.doc_form != "qa_model"
|
||||
and document.doc_form != IndexStructureType.QA_INDEX
|
||||
and document.need_summary is True
|
||||
):
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from celery import shared_task
|
|||
from sqlalchemy import or_, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
|
@ -106,7 +107,7 @@ def regenerate_summary_index_task(
|
|||
),
|
||||
DatasetDocument.enabled == True, # Document must be enabled
|
||||
DatasetDocument.archived == False, # Document must not be archived
|
||||
DatasetDocument.doc_form != "qa_model", # Skip qa_model documents
|
||||
DatasetDocument.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
|
||||
)
|
||||
.order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc())
|
||||
.all()
|
||||
|
|
@ -209,7 +210,7 @@ def regenerate_summary_index_task(
|
|||
|
||||
for dataset_document in dataset_documents:
|
||||
# Skip qa_model documents
|
||||
if dataset_document.doc_form == "qa_model":
|
||||
if dataset_document.doc_form == IndexStructureType.QA_INDEX:
|
||||
continue
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -179,7 +179,7 @@ def _record_trigger_failure_log(
|
|||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value,
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=created_by_role,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from unittest.mock import patch
|
|||
import pytest
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.enums import ApiTokenType
|
||||
from models.model import ApiToken
|
||||
from services.api_token_service import ApiTokenCache, CachedApiToken
|
||||
|
||||
|
|
@ -279,7 +280,7 @@ class TestEndToEndCacheFlow:
|
|||
test_token = ApiToken()
|
||||
test_token.id = "test-e2e-id"
|
||||
test_token.token = test_token_value
|
||||
test_token.type = test_scope
|
||||
test_token.type = ApiTokenType.APP
|
||||
test_token.app_id = "test-app"
|
||||
test_token.tenant_id = "test-tenant"
|
||||
test_token.last_used_at = None
|
||||
|
|
|
|||
|
|
@ -1,17 +1,10 @@
|
|||
"""
|
||||
Test suite for password reset authentication flows.
|
||||
"""Testcontainers integration tests for password reset authentication flows."""
|
||||
|
||||
This module tests the password reset mechanism including:
|
||||
- Password reset email sending
|
||||
- Verification code validation
|
||||
- Password reset with token
|
||||
- Rate limiting and security checks
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
|
|
@ -28,31 +21,12 @@ from controllers.console.auth.forgot_password import (
|
|||
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_forgot_password_session():
|
||||
with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls:
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
mock_session_cls.return_value.__exit__.return_value = None
|
||||
yield mock_session
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_forgot_password_db():
|
||||
with patch("controllers.console.auth.forgot_password.db") as mock_db:
|
||||
mock_db.engine = MagicMock()
|
||||
yield mock_db
|
||||
|
||||
|
||||
class TestForgotPasswordSendEmailApi:
|
||||
"""Test cases for sending password reset emails."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
|
|
@ -62,7 +36,6 @@ class TestForgotPasswordSendEmailApi:
|
|||
account.name = "Test User"
|
||||
return account
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
|
|
@ -73,20 +46,10 @@ class TestForgotPasswordSendEmailApi:
|
|||
mock_send_email,
|
||||
mock_get_account,
|
||||
mock_is_ip_limit,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
):
|
||||
"""
|
||||
Test successful password reset email sending.
|
||||
|
||||
Verifies that:
|
||||
- Email is sent to valid account
|
||||
- Reset token is generated and returned
|
||||
- IP rate limiting is checked
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_email.return_value = "reset_token_123"
|
||||
|
|
@ -104,9 +67,8 @@ class TestForgotPasswordSendEmailApi:
|
|||
assert response["data"] == "reset_token_123"
|
||||
mock_send_email.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, mock_db, app):
|
||||
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app):
|
||||
"""
|
||||
Test password reset email blocked by IP rate limit.
|
||||
|
||||
|
|
@ -115,7 +77,6 @@ class TestForgotPasswordSendEmailApi:
|
|||
- No email is sent when rate limited
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
|
|
@ -133,7 +94,6 @@ class TestForgotPasswordSendEmailApi:
|
|||
(None, "en-US"), # Defaults to en-US when not provided
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
|
|
@ -144,7 +104,6 @@ class TestForgotPasswordSendEmailApi:
|
|||
mock_send_email,
|
||||
mock_get_account,
|
||||
mock_is_ip_limit,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
language_input,
|
||||
|
|
@ -158,7 +117,6 @@ class TestForgotPasswordSendEmailApi:
|
|||
- Unsupported languages default to en-US
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_email.return_value = "token"
|
||||
|
|
@ -180,13 +138,9 @@ class TestForgotPasswordCheckApi:
|
|||
"""Test cases for verifying password reset codes."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
|
|
@ -199,7 +153,6 @@ class TestForgotPasswordCheckApi:
|
|||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_is_rate_limit,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
"""
|
||||
|
|
@ -212,7 +165,6 @@ class TestForgotPasswordCheckApi:
|
|||
- Rate limit is reset on success
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_generate_token.return_value = (None, "new_token")
|
||||
|
|
@ -236,7 +188,6 @@ class TestForgotPasswordCheckApi:
|
|||
)
|
||||
mock_reset_rate_limit.assert_called_once_with("test@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
|
|
@ -249,10 +200,8 @@ class TestForgotPasswordCheckApi:
|
|||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_is_rate_limit,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"}
|
||||
mock_generate_token.return_value = (None, "fresh-token")
|
||||
|
|
@ -271,9 +220,8 @@ class TestForgotPasswordCheckApi:
|
|||
mock_revoke_token.assert_called_once_with("upper_token")
|
||||
mock_reset_rate_limit.assert_called_once_with("user@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app):
|
||||
def test_verify_code_rate_limited(self, mock_is_rate_limit, app):
|
||||
"""
|
||||
Test code verification blocked by rate limit.
|
||||
|
||||
|
|
@ -282,7 +230,6 @@ class TestForgotPasswordCheckApi:
|
|||
- Prevents brute force attacks on verification codes
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
|
|
@ -295,10 +242,9 @@ class TestForgotPasswordCheckApi:
|
|||
with pytest.raises(EmailPasswordResetLimitError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, mock_db, app):
|
||||
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app):
|
||||
"""
|
||||
Test code verification with invalid token.
|
||||
|
||||
|
|
@ -306,7 +252,6 @@ class TestForgotPasswordCheckApi:
|
|||
- InvalidTokenError is raised for invalid/expired tokens
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = None
|
||||
|
||||
|
|
@ -320,10 +265,9 @@ class TestForgotPasswordCheckApi:
|
|||
with pytest.raises(InvalidTokenError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, mock_db, app):
|
||||
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app):
|
||||
"""
|
||||
Test code verification with mismatched email.
|
||||
|
||||
|
|
@ -332,7 +276,6 @@ class TestForgotPasswordCheckApi:
|
|||
- Prevents token abuse
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}
|
||||
|
||||
|
|
@ -346,11 +289,10 @@ class TestForgotPasswordCheckApi:
|
|||
with pytest.raises(InvalidEmailError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit")
|
||||
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, mock_db, app):
|
||||
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app):
|
||||
"""
|
||||
Test code verification with incorrect code.
|
||||
|
||||
|
|
@ -359,7 +301,6 @@ class TestForgotPasswordCheckApi:
|
|||
- Rate limit counter is incremented
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
|
||||
|
|
@ -380,11 +321,8 @@ class TestForgotPasswordResetApi:
|
|||
"""Test cases for resetting password with verified token."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
|
|
@ -394,7 +332,6 @@ class TestForgotPasswordResetApi:
|
|||
account.name = "Test User"
|
||||
return account
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
|
|
@ -405,7 +342,6 @@ class TestForgotPasswordResetApi:
|
|||
mock_get_account,
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
):
|
||||
|
|
@ -418,7 +354,6 @@ class TestForgotPasswordResetApi:
|
|||
- Success response is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_get_tenants.return_value = [MagicMock()]
|
||||
|
|
@ -436,9 +371,8 @@ class TestForgotPasswordResetApi:
|
|||
assert response["result"] == "success"
|
||||
mock_revoke_token.assert_called_once_with("valid_token")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_mismatch(self, mock_get_data, mock_db, app):
|
||||
def test_reset_password_mismatch(self, mock_get_data, app):
|
||||
"""
|
||||
Test password reset with mismatched passwords.
|
||||
|
||||
|
|
@ -447,7 +381,6 @@ class TestForgotPasswordResetApi:
|
|||
- No password update occurs
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
|
||||
|
||||
# Act & Assert
|
||||
|
|
@ -460,9 +393,8 @@ class TestForgotPasswordResetApi:
|
|||
with pytest.raises(PasswordMismatchError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_invalid_token(self, mock_get_data, mock_db, app):
|
||||
def test_reset_password_invalid_token(self, mock_get_data, app):
|
||||
"""
|
||||
Test password reset with invalid token.
|
||||
|
||||
|
|
@ -470,7 +402,6 @@ class TestForgotPasswordResetApi:
|
|||
- InvalidTokenError is raised for invalid/expired tokens
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
|
|
@ -483,9 +414,8 @@ class TestForgotPasswordResetApi:
|
|||
with pytest.raises(InvalidTokenError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_wrong_phase(self, mock_get_data, mock_db, app):
|
||||
def test_reset_password_wrong_phase(self, mock_get_data, app):
|
||||
"""
|
||||
Test password reset with token not in reset phase.
|
||||
|
||||
|
|
@ -494,7 +424,6 @@ class TestForgotPasswordResetApi:
|
|||
- Prevents use of verification-phase tokens for reset
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "verify"}
|
||||
|
||||
# Act & Assert
|
||||
|
|
@ -507,13 +436,10 @@ class TestForgotPasswordResetApi:
|
|||
with pytest.raises(InvalidTokenError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
def test_reset_password_account_not_found(
|
||||
self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app
|
||||
):
|
||||
def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app):
|
||||
"""
|
||||
Test password reset for non-existent account.
|
||||
|
||||
|
|
@ -521,7 +447,6 @@ class TestForgotPasswordResetApi:
|
|||
- AccountNotFound is raised when account doesn't exist
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"}
|
||||
mock_get_account.return_value = None
|
||||
|
||||
|
|
@ -4,6 +4,7 @@ from unittest.mock import patch
|
|||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
|
||||
from models.dataset import Dataset, Document
|
||||
|
|
@ -55,7 +56,7 @@ class TestGetAvailableDatasetsIntegration:
|
|||
name=f"Document {i}",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -112,7 +113,7 @@ class TestGetAvailableDatasetsIntegration:
|
|||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=f"Archived Document {i}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=True, # Archived
|
||||
|
|
@ -165,7 +166,7 @@ class TestGetAvailableDatasetsIntegration:
|
|||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=f"Disabled Document {i}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=False, # Disabled
|
||||
archived=False,
|
||||
|
|
@ -218,7 +219,7 @@ class TestGetAvailableDatasetsIntegration:
|
|||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=f"Document {status}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
indexing_status=status, # Not completed
|
||||
enabled=True,
|
||||
archived=False,
|
||||
|
|
@ -336,7 +337,7 @@ class TestGetAvailableDatasetsIntegration:
|
|||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=f"Document for {dataset.name}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
|
|
@ -416,7 +417,7 @@ class TestGetAvailableDatasetsIntegration:
|
|||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=f"Document {i}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
|
|
@ -476,7 +477,7 @@ class TestKnowledgeRetrievalIntegration:
|
|||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ from models.human_input import (
|
|||
HumanInputFormRecipient,
|
||||
RecipientType,
|
||||
)
|
||||
from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
|
||||
from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import (
|
||||
DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
|
|
@ -218,7 +218,7 @@ class TestDeleteRunsWithRelated:
|
|||
app_id=test_scope.app_id,
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=test_scope.user_id,
|
||||
)
|
||||
|
|
@ -278,7 +278,7 @@ class TestCountRunsWithRelated:
|
|||
app_id=test_scope.app_id,
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=test_scope.user_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from uuid import uuid4
|
|||
|
||||
import pytest
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account
|
||||
from models.dataset import Dataset, Document
|
||||
|
|
@ -91,7 +92,7 @@ class DocumentStatusTestDataFactory:
|
|||
name=name,
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
document.id = document_id
|
||||
document.indexing_status = indexing_status
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
|
||||
from core.errors.error import QuotaExceededError
|
||||
from models import TenantCreditPool
|
||||
from models.enums import ProviderQuotaType
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
|
||||
|
|
@ -20,7 +21,7 @@ class TestCreditPoolService:
|
|||
|
||||
assert isinstance(pool, TenantCreditPool)
|
||||
assert pool.tenant_id == tenant_id
|
||||
assert pool.pool_type == "trial"
|
||||
assert pool.pool_type == ProviderQuotaType.TRIAL
|
||||
assert pool.quota_used == 0
|
||||
assert pool.quota_limit > 0
|
||||
|
||||
|
|
@ -28,14 +29,14 @@ class TestCreditPoolService:
|
|||
tenant_id = self._create_tenant_id()
|
||||
CreditPoolService.create_default_pool(tenant_id)
|
||||
|
||||
result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type="trial")
|
||||
result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL)
|
||||
|
||||
assert result is not None
|
||||
assert result.tenant_id == tenant_id
|
||||
assert result.pool_type == "trial"
|
||||
assert result.pool_type == ProviderQuotaType.TRIAL
|
||||
|
||||
def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers):
|
||||
result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type="trial")
|
||||
result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from uuid import uuid4
|
|||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
|
|
@ -106,7 +107,7 @@ class DatasetServiceIntegrationDataFactory:
|
|||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by,
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.flush()
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from uuid import uuid4
|
|||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
|
||||
from services.dataset_service import DocumentService
|
||||
|
|
@ -79,7 +80,7 @@ class DocumentBatchUpdateIntegrationDataFactory:
|
|||
name=name,
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by or str(uuid4()),
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
document.id = document_id or str(uuid4())
|
||||
document.enabled = enabled
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom
|
||||
|
|
@ -78,7 +79,7 @@ class DatasetDeleteIntegrationDataFactory:
|
|||
tenant_id: str,
|
||||
dataset_id: str,
|
||||
created_by: str,
|
||||
doc_form: str = "text_model",
|
||||
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
|
||||
) -> Document:
|
||||
"""Persist a document so dataset.doc_form resolves through the real document path."""
|
||||
document = Document(
|
||||
|
|
@ -119,7 +120,7 @@ class TestDatasetServiceDeleteDataset:
|
|||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
created_by=owner.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
|
||||
# Act
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from uuid import uuid4
|
|||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
|
||||
from services.dataset_service import DocumentService
|
||||
|
|
@ -42,7 +43,7 @@ def _create_document(
|
|||
name=f"doc-{uuid4()}",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=str(uuid4()),
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
document.id = str(uuid4())
|
||||
document.indexing_status = indexing_status
|
||||
|
|
@ -142,3 +143,11 @@ def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_c
|
|||
|
||||
rows = db_session_with_containers.scalars(filtered).all()
|
||||
assert {row.id for row in rows} == {doc1.id, doc2.id}
|
||||
|
||||
|
||||
def test_normalize_display_status_alias_mapping():
|
||||
"""Test that normalize_display_status maps aliases correctly."""
|
||||
assert DocumentService.normalize_display_status("ACTIVE") == "available"
|
||||
assert DocumentService.normalize_display_status("enabled") == "available"
|
||||
assert DocumentService.normalize_display_status("archived") == "archived"
|
||||
assert DocumentService.normalize_display_status("unknown") is None
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from uuid import uuid4
|
|||
|
||||
import pytest
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account
|
||||
from models.dataset import Dataset, Document
|
||||
|
|
@ -69,7 +70,7 @@ def make_document(
|
|||
name=name,
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=str(uuid4()),
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
doc.id = document_id
|
||||
doc.indexing_status = "completed"
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import pytest
|
|||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dify_graph.file.enums import FileType
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
|
|
@ -253,7 +254,7 @@ class TestMessagesCleanServiceIntegration:
|
|||
# MessageFile
|
||||
file = MessageFile(
|
||||
message_id=message.id,
|
||||
type="image",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method="local_file",
|
||||
url="http://example.com/test.jpg",
|
||||
belongs_to=MessageFileBelongsTo.USER,
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from faker import Faker
|
|||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document
|
||||
from models.enums import DatasetMetadataType, DataSourceType, DocumentCreatedFrom
|
||||
|
|
@ -139,7 +140,7 @@ class TestMetadataService:
|
|||
name=fake.file_name(),
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,174 @@
|
|||
"""Testcontainers integration tests for OAuthServerService."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import (
|
||||
OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY,
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY,
|
||||
OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY,
|
||||
OAuthGrantType,
|
||||
OAuthServerService,
|
||||
)
|
||||
|
||||
|
||||
class TestOAuthServerServiceGetProviderApp:
|
||||
"""DB-backed tests for get_oauth_provider_app."""
|
||||
|
||||
def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp:
|
||||
app = OAuthProviderApp(
|
||||
app_icon="icon.png",
|
||||
client_id=client_id,
|
||||
client_secret=str(uuid4()),
|
||||
app_label={"en-US": "Test OAuth App"},
|
||||
redirect_uris=["https://example.com/callback"],
|
||||
scope="read",
|
||||
)
|
||||
db_session_with_containers.add(app)
|
||||
db_session_with_containers.commit()
|
||||
return app
|
||||
|
||||
def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers):
|
||||
client_id = f"client-{uuid4()}"
|
||||
created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id)
|
||||
|
||||
result = OAuthServerService.get_oauth_provider_app(client_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.client_id == client_id
|
||||
assert result.id == created.id
|
||||
|
||||
def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers):
|
||||
result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestOAuthServerServiceTokenOperations:
|
||||
"""Redis-backed tests for token sign/validate operations."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis(self):
|
||||
with patch("services.oauth_server.redis_client") as mock:
|
||||
yield mock
|
||||
|
||||
def test_sign_authorization_code_stores_and_returns_code(self, mock_redis):
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111")
|
||||
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
|
||||
code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1")
|
||||
|
||||
assert code == str(deterministic_uuid)
|
||||
mock_redis.set.assert_called_once_with(
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=code),
|
||||
"user-1",
|
||||
ex=600,
|
||||
)
|
||||
|
||||
def test_sign_access_token_raises_bad_request_for_invalid_code(self, mock_redis):
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
with pytest.raises(BadRequest, match="invalid code"):
|
||||
OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
code="bad-code",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
def test_sign_access_token_issues_tokens_for_valid_code(self, mock_redis):
|
||||
token_uuids = [
|
||||
uuid.UUID("00000000-0000-0000-0000-000000000201"),
|
||||
uuid.UUID("00000000-0000-0000-0000-000000000202"),
|
||||
]
|
||||
with patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids):
|
||||
mock_redis.get.return_value = b"user-1"
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
code="code-1",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
assert access_token == str(token_uuids[0])
|
||||
assert refresh_token == str(token_uuids[1])
|
||||
code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1")
|
||||
mock_redis.delete.assert_called_once_with(code_key)
|
||||
mock_redis.set.assert_any_call(
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
mock_redis.set.assert_any_call(
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
def test_sign_access_token_raises_bad_request_for_invalid_refresh_token(self, mock_redis):
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
with pytest.raises(BadRequest, match="invalid refresh token"):
|
||||
OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||
refresh_token="stale-token",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
def test_sign_access_token_issues_new_token_for_valid_refresh(self, mock_redis):
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301")
|
||||
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
|
||||
mock_redis.get.return_value = b"user-1"
|
||||
|
||||
access_token, returned_refresh = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||
refresh_token="refresh-1",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
assert access_token == str(deterministic_uuid)
|
||||
assert returned_refresh == "refresh-1"
|
||||
|
||||
def test_sign_access_token_returns_none_for_unknown_grant_type(self, mock_redis):
|
||||
grant_type = cast(OAuthGrantType, "invalid-grant-type")
|
||||
|
||||
result = OAuthServerService.sign_oauth_access_token(grant_type=grant_type, client_id="client-1")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_sign_refresh_token_stores_with_expected_expiry(self, mock_redis):
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401")
|
||||
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
|
||||
refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2")
|
||||
|
||||
assert refresh_token == str(deterministic_uuid)
|
||||
mock_redis.set.assert_called_once_with(
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token),
|
||||
"user-2",
|
||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
def test_validate_access_token_returns_none_when_not_found(self, mock_redis):
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_validate_access_token_loads_user_when_exists(self, mock_redis):
|
||||
mock_redis.get.return_value = b"user-88"
|
||||
expected_user = MagicMock()
|
||||
|
||||
with patch("services.oauth_server.AccountService.load_user", return_value=expected_user) as mock_load:
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
|
||||
|
||||
assert result is expected_user
|
||||
mock_load.assert_called_once_with("user-88")
|
||||
|
|
@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
|||
from dify_graph.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import WorkflowAppLogCreatedFrom
|
||||
from services.account_service import AccountService, TenantService
|
||||
|
||||
# Delay import of AppService to avoid circular dependency
|
||||
|
|
@ -221,7 +222,7 @@ class TestWorkflowAppService:
|
|||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
|
|
@ -357,7 +358,7 @@ class TestWorkflowAppService:
|
|||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_1.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
|
|
@ -399,7 +400,7 @@ class TestWorkflowAppService:
|
|||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_2.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
|
|
@ -441,7 +442,7 @@ class TestWorkflowAppService:
|
|||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_4.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
|
|
@ -521,7 +522,7 @@ class TestWorkflowAppService:
|
|||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
|
|
@ -627,7 +628,7 @@ class TestWorkflowAppService:
|
|||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
|
|
@ -732,7 +733,7 @@ class TestWorkflowAppService:
|
|||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
|
|
@ -860,7 +861,7 @@ class TestWorkflowAppService:
|
|||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
|
|
@ -902,7 +903,7 @@ class TestWorkflowAppService:
|
|||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="web-app",
|
||||
created_from=WorkflowAppLogCreatedFrom.WEB_APP,
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by=end_user.id,
|
||||
)
|
||||
|
|
@ -1037,7 +1038,7 @@ class TestWorkflowAppService:
|
|||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
|
|
@ -1125,7 +1126,7 @@ class TestWorkflowAppService:
|
|||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
|
|
@ -1279,7 +1280,7 @@ class TestWorkflowAppService:
|
|||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
|
|
@ -1379,7 +1380,7 @@ class TestWorkflowAppService:
|
|||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
|
|
@ -1481,7 +1482,7 @@ class TestWorkflowAppService:
|
|||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -536,3 +536,151 @@ class TestApiToolManageService:
|
|||
# Verify mock interactions
|
||||
mock_external_service_dependencies["encrypter"].assert_called_once()
|
||||
mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()
|
||||
|
||||
def test_delete_api_tool_provider_success(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test successful deletion of an API tool provider."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
schema = self._create_test_openapi_schema()
|
||||
provider_name = fake.unique.word()
|
||||
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
icon={"content": "🔧", "background": "#FFF"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=schema,
|
||||
privacy_policy="",
|
||||
custom_disclaimer="",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
provider = (
|
||||
db_session_with_containers.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
|
||||
.first()
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
result = ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, provider_name)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
deleted = (
|
||||
db_session_with_containers.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
|
||||
.first()
|
||||
)
|
||||
assert deleted is None
|
||||
|
||||
def test_delete_api_tool_provider_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test deletion raises ValueError when provider not found."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="you have not added provider"):
|
||||
ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent")
|
||||
|
||||
def test_update_api_tool_provider_not_found(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test update raises ValueError when original provider not found."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="does not exists"):
|
||||
ApiToolManageService.update_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name="new-name",
|
||||
original_provider="nonexistent",
|
||||
icon={},
|
||||
credentials={"auth_type": "none"},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=self._create_test_openapi_schema(),
|
||||
privacy_policy=None,
|
||||
custom_disclaimer="",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
def test_update_api_tool_provider_missing_auth_type(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test update raises ValueError when auth_type is missing from credentials."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
schema = self._create_test_openapi_schema()
|
||||
provider_name = fake.unique.word()
|
||||
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
icon={"content": "🔧", "background": "#FFF"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=schema,
|
||||
privacy_policy="",
|
||||
custom_disclaimer="",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
ApiToolManageService.update_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
original_provider=provider_name,
|
||||
icon={},
|
||||
credentials={},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=schema,
|
||||
privacy_policy=None,
|
||||
custom_disclaimer="",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
def test_list_api_tool_provider_tools_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test listing tools raises ValueError when provider not found."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="you have not added provider"):
|
||||
ApiToolManageService.list_api_tool_provider_tools(account.id, tenant.id, "nonexistent")
|
||||
|
||||
def test_test_api_tool_preview_invalid_schema_type(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test preview raises ValueError for invalid schema type."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="invalid schema type"):
|
||||
ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id=tenant.id,
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={},
|
||||
schema_type="bad-schema-type",
|
||||
schema="schema",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
|||
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolProviderType
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from services.plugin.plugin_service import PluginService
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
|
@ -52,7 +52,7 @@ class TestToolTransformService:
|
|||
user_id="test_user_id",
|
||||
credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}',
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
tools_str="[]",
|
||||
)
|
||||
elif provider_type == "builtin":
|
||||
|
|
@ -659,7 +659,7 @@ class TestToolTransformService:
|
|||
user_id=fake.uuid4(),
|
||||
credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}',
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
tools_str="[]",
|
||||
)
|
||||
|
||||
|
|
@ -695,7 +695,7 @@ class TestToolTransformService:
|
|||
user_id=fake.uuid4(),
|
||||
credentials_str='{"auth_type": "api_key_query", "api_key": "test_key"}',
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
tools_str="[]",
|
||||
)
|
||||
|
||||
|
|
@ -731,7 +731,7 @@ class TestToolTransformService:
|
|||
user_id=fake.uuid4(),
|
||||
credentials_str='{"auth_type": "api_key", "api_key": "test_key"}',
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
tools_str="[]",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1043,3 +1043,112 @@ class TestWorkflowToolManageService:
|
|||
# After the fix, this should always be 0
|
||||
# For now, we document that the record may exist, demonstrating the bug
|
||||
# assert tool_count == 0 # Expected after fix
|
||||
|
||||
def test_delete_workflow_tool_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test successful deletion of a workflow tool."""
|
||||
fake = Faker()
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
tool_name = fake.unique.word()
|
||||
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id,
|
||||
name=tool_name,
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "🔧"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=self._create_test_workflow_tool_parameters(),
|
||||
)
|
||||
|
||||
tool = (
|
||||
db_session_with_containers.query(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.name == tool_name)
|
||||
.first()
|
||||
)
|
||||
assert tool is not None
|
||||
|
||||
result = WorkflowToolManageService.delete_workflow_tool(account.id, account.current_tenant.id, tool.id)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
deleted = (
|
||||
db_session_with_containers.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool.id).first()
|
||||
)
|
||||
assert deleted is None
|
||||
|
||||
def test_list_tenant_workflow_tools_empty(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test listing workflow tools when none exist returns empty list."""
|
||||
fake = Faker()
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
result = WorkflowToolManageService.list_tenant_workflow_tools(account.id, account.current_tenant.id)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_get_workflow_tool_by_tool_id_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test that get_workflow_tool_by_tool_id raises ValueError when tool not found."""
|
||||
fake = Faker()
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Tool not found"):
|
||||
WorkflowToolManageService.get_workflow_tool_by_tool_id(account.id, account.current_tenant.id, fake.uuid4())
|
||||
|
||||
def test_get_workflow_tool_by_app_id_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test that get_workflow_tool_by_app_id raises ValueError when tool not found."""
|
||||
fake = Faker()
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Tool not found"):
|
||||
WorkflowToolManageService.get_workflow_tool_by_app_id(account.id, account.current_tenant.id, fake.uuid4())
|
||||
|
||||
def test_list_single_workflow_tools_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test that list_single_workflow_tools raises ValueError when tool not found."""
|
||||
fake = Faker()
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
WorkflowToolManageService.list_single_workflow_tools(account.id, account.current_tenant.id, fake.uuid4())
|
||||
|
||||
def test_create_workflow_tool_with_labels(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test that labels are forwarded to ToolLabelManager when provided."""
|
||||
fake = Faker()
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
result = WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id,
|
||||
name=fake.unique.word(),
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "🔧"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=self._create_test_workflow_tool_parameters(),
|
||||
labels=["label-1", "label-2"],
|
||||
)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import pytest
|
|||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
|
|
@ -152,7 +153,7 @@ class TestBatchCleanDocumentTask:
|
|||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
|
||||
db_session_with_containers.add(document)
|
||||
|
|
@ -392,7 +393,12 @@ class TestBatchCleanDocumentTask:
|
|||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the task with non-existent dataset
|
||||
batch_clean_document_task(document_ids=[document_id], dataset_id=dataset_id, doc_form="text_model", file_ids=[])
|
||||
batch_clean_document_task(
|
||||
document_ids=[document_id],
|
||||
dataset_id=dataset_id,
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
file_ids=[],
|
||||
)
|
||||
|
||||
# Verify that no index processing occurred
|
||||
mock_external_service_dependencies["index_processor"].clean.assert_not_called()
|
||||
|
|
@ -525,7 +531,11 @@ class TestBatchCleanDocumentTask:
|
|||
account = self._create_test_account(db_session_with_containers)
|
||||
|
||||
# Test different doc_form types
|
||||
doc_forms = ["text_model", "qa_model", "hierarchical_model"]
|
||||
doc_forms = [
|
||||
IndexStructureType.PARAGRAPH_INDEX,
|
||||
IndexStructureType.QA_INDEX,
|
||||
IndexStructureType.PARENT_CHILD_INDEX,
|
||||
]
|
||||
|
||||
for doc_form in doc_forms:
|
||||
dataset = self._create_test_dataset(db_session_with_containers, account)
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import pytest
|
|||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
|
@ -179,7 +180,7 @@ class TestBatchCreateSegmentToIndexTask:
|
|||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
word_count=0,
|
||||
)
|
||||
|
||||
|
|
@ -221,17 +222,17 @@ class TestBatchCreateSegmentToIndexTask:
|
|||
|
||||
return upload_file
|
||||
|
||||
def _create_test_csv_content(self, content_type="text_model"):
|
||||
def _create_test_csv_content(self, content_type=IndexStructureType.PARAGRAPH_INDEX):
|
||||
"""
|
||||
Helper method to create test CSV content.
|
||||
|
||||
Args:
|
||||
content_type: Type of content to create ("text_model" or "qa_model")
|
||||
content_type: Type of content to create (IndexStructureType.PARAGRAPH_INDEX or IndexStructureType.QA_INDEX)
|
||||
|
||||
Returns:
|
||||
str: CSV content as string
|
||||
"""
|
||||
if content_type == "qa_model":
|
||||
if content_type == IndexStructureType.QA_INDEX:
|
||||
csv_content = "content,answer\n"
|
||||
csv_content += "This is the first segment content,This is the first answer\n"
|
||||
csv_content += "This is the second segment content,This is the second answer\n"
|
||||
|
|
@ -264,7 +265,7 @@ class TestBatchCreateSegmentToIndexTask:
|
|||
upload_file = self._create_test_upload_file(db_session_with_containers, account, tenant)
|
||||
|
||||
# Create CSV content
|
||||
csv_content = self._create_test_csv_content("text_model")
|
||||
csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX)
|
||||
|
||||
# Mock storage to return our CSV content
|
||||
mock_storage = mock_external_service_dependencies["storage"]
|
||||
|
|
@ -451,7 +452,7 @@ class TestBatchCreateSegmentToIndexTask:
|
|||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=False, # Document is disabled
|
||||
archived=False,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
word_count=0,
|
||||
),
|
||||
# Archived document
|
||||
|
|
@ -467,7 +468,7 @@ class TestBatchCreateSegmentToIndexTask:
|
|||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=True, # Document is archived
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
word_count=0,
|
||||
),
|
||||
# Document with incomplete indexing
|
||||
|
|
@ -483,7 +484,7 @@ class TestBatchCreateSegmentToIndexTask:
|
|||
indexing_status=IndexingStatus.INDEXING, # Not completed
|
||||
enabled=True,
|
||||
archived=False,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
word_count=0,
|
||||
),
|
||||
]
|
||||
|
|
@ -655,7 +656,7 @@ class TestBatchCreateSegmentToIndexTask:
|
|||
db_session_with_containers.commit()
|
||||
|
||||
# Create CSV content
|
||||
csv_content = self._create_test_csv_content("text_model")
|
||||
csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX)
|
||||
|
||||
# Mock storage to return our CSV content
|
||||
mock_storage = mock_external_service_dependencies["storage"]
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import pytest
|
|||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import (
|
||||
|
|
@ -192,7 +193,7 @@ class TestCleanDatasetTask:
|
|||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
doc_form="paragraph_index",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
word_count=100,
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from unittest.mock import Mock, patch
|
|||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
|
||||
from services.account_service import AccountService, TenantService
|
||||
|
|
@ -114,7 +115,7 @@ class TestCleanNotionDocumentTask:
|
|||
name=f"Notion Page {i}",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model", # Set doc_form to ensure dataset.doc_form works
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX, # Set doc_form to ensure dataset.doc_form works
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
)
|
||||
|
|
@ -261,7 +262,7 @@ class TestCleanNotionDocumentTask:
|
|||
|
||||
# Test different index types
|
||||
# Note: Only testing text_model to avoid dependency on external services
|
||||
index_types = ["text_model"]
|
||||
index_types = [IndexStructureType.PARAGRAPH_INDEX]
|
||||
|
||||
for index_type in index_types:
|
||||
# Create dataset (doc_form will be set via document creation)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from uuid import uuid4
|
|||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
|
@ -141,7 +142,7 @@ class TestCreateSegmentToIndexTask:
|
|||
enabled=True,
|
||||
archived=False,
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
doc_form="qa_model",
|
||||
doc_form=IndexStructureType.QA_INDEX,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
|
@ -301,7 +302,7 @@ class TestCreateSegmentToIndexTask:
|
|||
enabled=True,
|
||||
archived=False,
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
|
@ -552,7 +553,11 @@ class TestCreateSegmentToIndexTask:
|
|||
- Processing completes successfully for different forms
|
||||
"""
|
||||
# Arrange: Test different doc_forms
|
||||
doc_forms = ["qa_model", "text_model", "web_model"]
|
||||
doc_forms = [
|
||||
IndexStructureType.QA_INDEX,
|
||||
IndexStructureType.PARAGRAPH_INDEX,
|
||||
IndexStructureType.PARAGRAPH_INDEX,
|
||||
]
|
||||
|
||||
for doc_form in doc_forms:
|
||||
# Create fresh test data for each form
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from unittest.mock import ANY, Mock, patch
|
|||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
|
||||
from services.account_service import AccountService, TenantService
|
||||
|
|
@ -107,7 +108,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Document for doc_form",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -167,7 +168,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Document for doc_form",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -187,7 +188,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Test Document",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -268,7 +269,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Document for doc_form",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="parent_child_index",
|
||||
doc_form=IndexStructureType.PARENT_CHILD_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -288,7 +289,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Test Document",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="parent_child_index",
|
||||
doc_form=IndexStructureType.PARENT_CHILD_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -416,7 +417,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Test Document",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -505,7 +506,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Document for doc_form",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -525,7 +526,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Test Document",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -601,7 +602,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Test Document",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="qa_index",
|
||||
doc_form=IndexStructureType.QA_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -638,7 +639,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
assert updated_document.indexing_status == IndexingStatus.COMPLETED
|
||||
|
||||
# Verify index processor was initialized with custom index type
|
||||
mock_index_processor_factory.assert_called_once_with("qa_index")
|
||||
mock_index_processor_factory.assert_called_once_with(IndexStructureType.QA_INDEX)
|
||||
mock_factory = mock_index_processor_factory.return_value
|
||||
mock_processor = mock_factory.init_index_processor.return_value
|
||||
mock_processor.load.assert_called_once()
|
||||
|
|
@ -677,7 +678,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Test Document",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -714,7 +715,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
assert updated_document.indexing_status == IndexingStatus.COMPLETED
|
||||
|
||||
# Verify index processor was initialized with the document's index type
|
||||
mock_index_processor_factory.assert_called_once_with("text_model")
|
||||
mock_index_processor_factory.assert_called_once_with(IndexStructureType.PARAGRAPH_INDEX)
|
||||
mock_factory = mock_index_processor_factory.return_value
|
||||
mock_processor = mock_factory.init_index_processor.return_value
|
||||
mock_processor.load.assert_called_once()
|
||||
|
|
@ -753,7 +754,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Document for doc_form",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -775,7 +776,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name=f"Test Document {i}",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -856,7 +857,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Document for doc_form",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -876,7 +877,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Test Document",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -953,7 +954,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Document for doc_form",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -973,7 +974,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Enabled Document",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -992,7 +993,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Disabled Document",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=False, # This document should be skipped
|
||||
|
|
@ -1074,7 +1075,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Document for doc_form",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -1094,7 +1095,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Active Document",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -1113,7 +1114,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Archived Document",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -1195,7 +1196,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Document for doc_form",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -1215,7 +1216,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Completed Document",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -1234,7 +1235,7 @@ class TestDealDatasetVectorIndexTask:
|
|||
name="Incomplete Document",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.INDEXING, # This document should be skipped
|
||||
enabled=True,
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ import pytest
|
|||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
|
@ -113,7 +114,7 @@ class TestDisableSegmentFromIndexTask:
|
|||
dataset: Dataset,
|
||||
tenant: Tenant,
|
||||
account: Account,
|
||||
doc_form: str = "text_model",
|
||||
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
|
||||
) -> Document:
|
||||
"""
|
||||
Helper method to create a test document.
|
||||
|
|
@ -476,7 +477,11 @@ class TestDisableSegmentFromIndexTask:
|
|||
- Index processor clean method is called correctly
|
||||
"""
|
||||
# Test different document forms
|
||||
doc_forms = ["text_model", "qa_model", "table_model"]
|
||||
doc_forms = [
|
||||
IndexStructureType.PARAGRAPH_INDEX,
|
||||
IndexStructureType.QA_INDEX,
|
||||
IndexStructureType.PARENT_CHILD_INDEX,
|
||||
]
|
||||
|
||||
for doc_form in doc_forms:
|
||||
# Arrange: Create test data for each form
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from unittest.mock import MagicMock, patch
|
|||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models import Account, Dataset, DocumentSegment
|
||||
from models import Document as DatasetDocument
|
||||
from models.dataset import DatasetProcessRule
|
||||
|
|
@ -153,7 +154,7 @@ class TestDisableSegmentsFromIndexTask:
|
|||
document.indexing_status = "completed"
|
||||
document.enabled = True
|
||||
document.archived = False
|
||||
document.doc_form = "text_model" # Use text_model form for testing
|
||||
document.doc_form = IndexStructureType.PARAGRAPH_INDEX # Use text_model form for testing
|
||||
document.doc_language = "en"
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
|
@ -500,7 +501,11 @@ class TestDisableSegmentsFromIndexTask:
|
|||
segment_ids = [segment.id for segment in segments]
|
||||
|
||||
# Test different document forms
|
||||
doc_forms = ["text_model", "qa_model", "hierarchical_model"]
|
||||
doc_forms = [
|
||||
IndexStructureType.PARAGRAPH_INDEX,
|
||||
IndexStructureType.QA_INDEX,
|
||||
IndexStructureType.PARENT_CHILD_INDEX,
|
||||
]
|
||||
|
||||
for doc_form in doc_forms:
|
||||
# Update document form
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from uuid import uuid4
|
|||
import pytest
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
|
||||
|
|
@ -85,7 +86,7 @@ class DocumentIndexingSyncTaskTestDataFactory:
|
|||
created_by=created_by,
|
||||
indexing_status=indexing_status,
|
||||
enabled=True,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
|
|||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
|
||||
|
|
@ -80,7 +81,7 @@ class TestDocumentIndexingUpdateTask:
|
|||
created_by=account.id,
|
||||
indexing_status=IndexingStatus.WAITING,
|
||||
enabled=True,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import pytest
|
|||
from faker import Faker
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
|
@ -130,7 +131,7 @@ class TestDuplicateDocumentIndexingTasks:
|
|||
created_by=account.id,
|
||||
indexing_status=IndexingStatus.WAITING,
|
||||
enabled=True,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
documents.append(document)
|
||||
|
|
@ -265,7 +266,7 @@ class TestDuplicateDocumentIndexingTasks:
|
|||
created_by=account.id,
|
||||
indexing_status=IndexingStatus.WAITING,
|
||||
enabled=True,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
documents.append(document)
|
||||
|
|
@ -524,7 +525,7 @@ class TestDuplicateDocumentIndexingTasks:
|
|||
created_by=dataset.created_by,
|
||||
indexing_status=IndexingStatus.WAITING,
|
||||
enabled=True,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
extra_documents.append(document)
|
||||
|
|
|
|||
|
|
@ -281,12 +281,10 @@ class TestSiteEndpoints:
|
|||
method = _unwrap(api.post)
|
||||
|
||||
site = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = site
|
||||
monkeypatch.setattr(
|
||||
site_module.db,
|
||||
"session",
|
||||
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
|
||||
MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
site_module,
|
||||
|
|
@ -305,12 +303,10 @@ class TestSiteEndpoints:
|
|||
method = _unwrap(api.post)
|
||||
|
||||
site = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = site
|
||||
monkeypatch.setattr(
|
||||
site_module.db,
|
||||
"session",
|
||||
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
|
||||
MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None),
|
||||
)
|
||||
monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code")
|
||||
monkeypatch.setattr(
|
||||
|
|
|
|||
|
|
@ -82,12 +82,8 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p
|
|||
def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
conversation = SimpleNamespace(id="c1", app_id="app-1")
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = conversation
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = conversation
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(conversation_module.db, "session", session)
|
||||
|
|
@ -101,12 +97,8 @@ def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> No
|
|||
|
||||
|
||||
def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = None
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = None
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(conversation_module.db, "session", session)
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ def test_get_conversation_mark_read_keeps_updated_at_unchanged():
|
|||
),
|
||||
patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session,
|
||||
):
|
||||
mock_session.query.return_value.where.return_value.first.return_value = conversation
|
||||
mock_session.scalar.return_value = conversation
|
||||
|
||||
_get_conversation(app_model, "conversation-id")
|
||||
|
||||
|
|
|
|||
|
|
@ -73,8 +73,7 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch
|
|||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: None))
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
|
|
@ -99,8 +98,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey
|
|||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
|
||||
_install_workflow_service(monkeypatch, workflow=None)
|
||||
|
||||
with app.test_request_context(
|
||||
|
|
@ -126,8 +124,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch)
|
|||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
|
||||
|
||||
workflow = SimpleNamespace(graph_dict={"nodes": []})
|
||||
_install_workflow_service(monkeypatch, workflow=workflow)
|
||||
|
|
@ -155,8 +152,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) ->
|
|||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
|
||||
|
||||
workflow = SimpleNamespace(
|
||||
graph_dict={
|
||||
|
|
|
|||
|
|
@ -92,10 +92,7 @@ def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatc
|
|||
)
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = original_config
|
||||
session.query.return_value = query
|
||||
session.get.return_value = original_config
|
||||
monkeypatch.setattr(model_config_module.db, "session", session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
|
|
|
|||
|
|
@ -11,10 +11,8 @@ from models.model import AppMode
|
|||
|
||||
def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
|
||||
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model))
|
||||
|
||||
@wraps_module.get_app_model
|
||||
def handler(app_model):
|
||||
|
|
@ -25,10 +23,8 @@ def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
|
||||
def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
|
||||
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model))
|
||||
|
||||
@wraps_module.get_app_model(mode=[AppMode.COMPLETION])
|
||||
def handler(app_model):
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from controllers.console.datasets.data_source import (
|
|||
DataSourceNotionDocumentSyncApi,
|
||||
DataSourceNotionListApi,
|
||||
)
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
|
|
@ -343,7 +344,7 @@ class TestDataSourceNotionApi:
|
|||
}
|
||||
],
|
||||
"process_rule": {"rules": {}},
|
||||
"doc_form": "text_model",
|
||||
"doc_form": IndexStructureType.PARAGRAPH_INDEX,
|
||||
"doc_language": "English",
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from controllers.console.datasets.datasets import (
|
|||
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import ApiToken, UploadFile
|
||||
|
|
@ -1146,7 +1147,7 @@ class TestDatasetIndexingEstimateApi:
|
|||
},
|
||||
"process_rule": {"chunk_size": 100},
|
||||
"indexing_technique": "high_quality",
|
||||
"doc_form": "text_model",
|
||||
"doc_form": IndexStructureType.PARAGRAPH_INDEX,
|
||||
"doc_language": "English",
|
||||
"dataset_id": None,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from controllers.console.datasets.error import (
|
|||
InvalidActionError,
|
||||
InvalidMetadataError,
|
||||
)
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.enums import DataSourceType, IndexingStatus
|
||||
|
||||
|
||||
|
|
@ -66,7 +67,7 @@ def document():
|
|||
indexing_status=IndexingStatus.INDEXING,
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
data_source_info_dict={"upload_file_id": "file-1"},
|
||||
doc_form="text",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
archived=False,
|
||||
is_paused=False,
|
||||
dataset_process_rule=None,
|
||||
|
|
@ -765,8 +766,8 @@ class TestDocumentGenerateSummaryApi:
|
|||
summary_index_setting={"enable": True},
|
||||
)
|
||||
|
||||
doc1 = MagicMock(id="doc-1", doc_form="qa_model")
|
||||
doc2 = MagicMock(id="doc-2", doc_form="text")
|
||||
doc1 = MagicMock(id="doc-1", doc_form=IndexStructureType.QA_INDEX)
|
||||
doc2 = MagicMock(id="doc-2", doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||
|
||||
payload = {"document_list": ["doc-1", "doc-2"]}
|
||||
|
||||
|
|
@ -822,7 +823,7 @@ class TestDocumentIndexingEstimateApi:
|
|||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
data_source_info_dict={"upload_file_id": "file-1"},
|
||||
tenant_id="tenant-1",
|
||||
doc_form="text",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
dataset_process_rule=None,
|
||||
)
|
||||
|
||||
|
|
@ -849,7 +850,7 @@ class TestDocumentIndexingEstimateApi:
|
|||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
data_source_info_dict={"upload_file_id": "file-1"},
|
||||
tenant_id="tenant-1",
|
||||
doc_form="text",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
dataset_process_rule=None,
|
||||
)
|
||||
|
||||
|
|
@ -973,7 +974,7 @@ class TestDocumentBatchIndexingEstimateApi:
|
|||
"mode": "single",
|
||||
"only_main_content": True,
|
||||
},
|
||||
doc_form="text",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
|
||||
with (
|
||||
|
|
@ -1001,7 +1002,7 @@ class TestDocumentBatchIndexingEstimateApi:
|
|||
"notion_page_id": "p1",
|
||||
"type": "page",
|
||||
},
|
||||
doc_form="text",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
|
||||
with (
|
||||
|
|
@ -1024,7 +1025,7 @@ class TestDocumentBatchIndexingEstimateApi:
|
|||
indexing_status=IndexingStatus.INDEXING,
|
||||
data_source_type="unknown",
|
||||
data_source_info_dict={},
|
||||
doc_form="text",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
|
||||
with app.test_request_context("/"), patch.object(api, "get_batch_documents", return_value=[document]):
|
||||
|
|
@ -1353,7 +1354,7 @@ class TestDocumentIndexingEdgeCases:
|
|||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
data_source_info_dict={"upload_file_id": "file-1"},
|
||||
tenant_id="tenant-1",
|
||||
doc_form="text",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
dataset_process_rule=None,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from controllers.console.datasets.error import (
|
|||
InvalidActionError,
|
||||
)
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.dataset import ChildChunk, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
|
||||
|
|
@ -366,7 +367,7 @@ class TestDatasetDocumentSegmentAddApi:
|
|||
dataset.indexing_technique = "economy"
|
||||
|
||||
document = MagicMock()
|
||||
document.doc_form = "text"
|
||||
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
|
||||
segment = MagicMock()
|
||||
segment.id = "seg-1"
|
||||
|
|
@ -505,7 +506,7 @@ class TestDatasetDocumentSegmentUpdateApi:
|
|||
dataset.indexing_technique = "economy"
|
||||
|
||||
document = MagicMock()
|
||||
document.doc_form = "text"
|
||||
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
|
||||
segment = MagicMock()
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from controllers.console.apikey import (
|
|||
BaseApiKeyResource,
|
||||
_get_resource,
|
||||
)
|
||||
from models.enums import ApiTokenType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -45,14 +46,14 @@ def bypass_permissions():
|
|||
|
||||
|
||||
class DummyApiKeyListResource(BaseApiKeyListResource):
|
||||
resource_type = "app"
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = MagicMock()
|
||||
resource_id_field = "app_id"
|
||||
token_prefix = "app-"
|
||||
|
||||
|
||||
class DummyApiKeyResource(BaseApiKeyResource):
|
||||
resource_type = "app"
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = MagicMock()
|
||||
resource_id_field = "app_id"
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from unittest.mock import Mock
|
|||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, AppMode, EndUser
|
||||
from tests.unit_tests.conftest import setup_mock_tenant_account_query
|
||||
|
|
@ -175,7 +176,7 @@ def mock_document():
|
|||
document.name = "test_document.txt"
|
||||
document.indexing_status = "completed"
|
||||
document.enabled = True
|
||||
document.doc_form = "text_model"
|
||||
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
return document
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ from controllers.service_api.dataset.segment import (
|
|||
SegmentCreatePayload,
|
||||
SegmentListQuery,
|
||||
)
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.dataset import ChildChunk, Dataset, Document, DocumentSegment
|
||||
from models.enums import IndexingStatus
|
||||
from services.dataset_service import DocumentService, SegmentService
|
||||
|
|
@ -788,7 +789,7 @@ class TestSegmentApiGet:
|
|||
# Arrange
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = Mock(doc_form="text_model")
|
||||
mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||
mock_seg_svc.get_segments.return_value = ([mock_segment], 1)
|
||||
mock_marshal.return_value = [{"id": mock_segment.id}]
|
||||
|
||||
|
|
@ -903,7 +904,7 @@ class TestSegmentApiPost:
|
|||
mock_doc = Mock()
|
||||
mock_doc.indexing_status = "completed"
|
||||
mock_doc.enabled = True
|
||||
mock_doc.doc_form = "text_model"
|
||||
mock_doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
mock_doc_svc.get_document.return_value = mock_doc
|
||||
|
||||
mock_seg_svc.segment_create_args_validate.return_value = None
|
||||
|
|
@ -1091,7 +1092,7 @@ class TestDatasetSegmentApiDelete:
|
|||
mock_doc = Mock()
|
||||
mock_doc.indexing_status = "completed"
|
||||
mock_doc.enabled = True
|
||||
mock_doc.doc_form = "text_model"
|
||||
mock_doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
mock_doc_svc.get_document.return_value = mock_doc
|
||||
|
||||
mock_seg_svc.get_segment_by_id.return_value = None # Segment not found
|
||||
|
|
@ -1371,7 +1372,7 @@ class TestDatasetSegmentApiGetSingle:
|
|||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_model_setting.return_value = None
|
||||
mock_doc = Mock(doc_form="text_model")
|
||||
mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||
mock_doc_svc.get_document.return_value = mock_doc
|
||||
mock_seg_svc.get_segment_by_id.return_value = mock_segment
|
||||
mock_marshal.return_value = {"id": mock_segment.id}
|
||||
|
|
@ -1390,7 +1391,7 @@ class TestDatasetSegmentApiGetSingle:
|
|||
|
||||
assert status == 200
|
||||
assert "data" in response
|
||||
assert response["doc_form"] == "text_model"
|
||||
assert response["doc_form"] == IndexStructureType.PARAGRAPH_INDEX
|
||||
|
||||
@patch("controllers.service_api.dataset.segment.current_account_with_tenant")
|
||||
@patch("controllers.service_api.dataset.segment.db")
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ from controllers.service_api.dataset.document import (
|
|||
InvalidMetadataError,
|
||||
)
|
||||
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.enums import IndexingStatus
|
||||
from services.dataset_service import DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import ProcessRule, RetrievalModel
|
||||
|
|
@ -52,7 +53,7 @@ class TestDocumentTextCreatePayload:
|
|||
def test_payload_with_defaults(self):
|
||||
"""Test payload default values."""
|
||||
payload = DocumentTextCreatePayload(name="Doc", text="Content")
|
||||
assert payload.doc_form == "text_model"
|
||||
assert payload.doc_form == IndexStructureType.PARAGRAPH_INDEX
|
||||
assert payload.doc_language == "English"
|
||||
assert payload.process_rule is None
|
||||
assert payload.indexing_technique is None
|
||||
|
|
@ -62,14 +63,14 @@ class TestDocumentTextCreatePayload:
|
|||
payload = DocumentTextCreatePayload(
|
||||
name="Full Document",
|
||||
text="Complete document content here",
|
||||
doc_form="qa_model",
|
||||
doc_form=IndexStructureType.QA_INDEX,
|
||||
doc_language="Chinese",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
embedding_model_provider="openai",
|
||||
)
|
||||
assert payload.name == "Full Document"
|
||||
assert payload.doc_form == "qa_model"
|
||||
assert payload.doc_form == IndexStructureType.QA_INDEX
|
||||
assert payload.doc_language == "Chinese"
|
||||
assert payload.indexing_technique == "high_quality"
|
||||
assert payload.embedding_model == "text-embedding-ada-002"
|
||||
|
|
@ -147,8 +148,8 @@ class TestDocumentTextUpdate:
|
|||
|
||||
def test_payload_with_doc_form_update(self):
|
||||
"""Test payload with doc_form update."""
|
||||
payload = DocumentTextUpdate(doc_form="qa_model")
|
||||
assert payload.doc_form == "qa_model"
|
||||
payload = DocumentTextUpdate(doc_form=IndexStructureType.QA_INDEX)
|
||||
assert payload.doc_form == IndexStructureType.QA_INDEX
|
||||
|
||||
def test_payload_with_language_update(self):
|
||||
"""Test payload with doc_language update."""
|
||||
|
|
@ -158,7 +159,7 @@ class TestDocumentTextUpdate:
|
|||
def test_payload_default_values(self):
|
||||
"""Test payload default values."""
|
||||
payload = DocumentTextUpdate()
|
||||
assert payload.doc_form == "text_model"
|
||||
assert payload.doc_form == IndexStructureType.PARAGRAPH_INDEX
|
||||
assert payload.doc_language == "English"
|
||||
|
||||
|
||||
|
|
@ -272,14 +273,24 @@ class TestDocumentDocForm:
|
|||
|
||||
def test_text_model_form(self):
|
||||
"""Test text_model form."""
|
||||
doc_form = "text_model"
|
||||
valid_forms = ["text_model", "qa_model", "hierarchical_model", "parent_child_model"]
|
||||
doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
valid_forms = [
|
||||
IndexStructureType.PARAGRAPH_INDEX,
|
||||
IndexStructureType.QA_INDEX,
|
||||
IndexStructureType.PARENT_CHILD_INDEX,
|
||||
"parent_child_model",
|
||||
]
|
||||
assert doc_form in valid_forms
|
||||
|
||||
def test_qa_model_form(self):
|
||||
"""Test qa_model form."""
|
||||
doc_form = "qa_model"
|
||||
valid_forms = ["text_model", "qa_model", "hierarchical_model", "parent_child_model"]
|
||||
doc_form = IndexStructureType.QA_INDEX
|
||||
valid_forms = [
|
||||
IndexStructureType.PARAGRAPH_INDEX,
|
||||
IndexStructureType.QA_INDEX,
|
||||
IndexStructureType.PARENT_CHILD_INDEX,
|
||||
"parent_child_model",
|
||||
]
|
||||
assert doc_form in valid_forms
|
||||
|
||||
|
||||
|
|
@ -504,7 +515,7 @@ class TestDocumentApiGet:
|
|||
doc.name = "test_document.txt"
|
||||
doc.indexing_status = "completed"
|
||||
doc.enabled = True
|
||||
doc.doc_form = "text_model"
|
||||
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
doc.doc_language = "English"
|
||||
doc.doc_type = "book"
|
||||
doc.doc_metadata_details = {"source": "upload"}
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from sqlalchemy.orm import Session
|
|||
|
||||
from core.app.entities.task_entities import MessageEndStreamResponse
|
||||
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||
from dify_graph.file.enums import FileTransferMethod
|
||||
from dify_graph.file.enums import FileTransferMethod, FileType
|
||||
from models.model import MessageFile, UploadFile
|
||||
|
||||
|
||||
|
|
@ -51,7 +51,7 @@ class TestMessageEndStreamResponseFiles:
|
|||
message_file.transfer_method = FileTransferMethod.LOCAL_FILE
|
||||
message_file.upload_file_id = str(uuid.uuid4())
|
||||
message_file.url = None
|
||||
message_file.type = "image"
|
||||
message_file.type = FileType.IMAGE
|
||||
return message_file
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -63,7 +63,7 @@ class TestMessageEndStreamResponseFiles:
|
|||
message_file.transfer_method = FileTransferMethod.REMOTE_URL
|
||||
message_file.upload_file_id = None
|
||||
message_file.url = "https://example.com/image.jpg"
|
||||
message_file.type = "image"
|
||||
message_file.type = FileType.IMAGE
|
||||
return message_file
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -75,7 +75,7 @@ class TestMessageEndStreamResponseFiles:
|
|||
message_file.transfer_method = FileTransferMethod.TOOL_FILE
|
||||
message_file.upload_file_id = None
|
||||
message_file.url = "tool_file_123.png"
|
||||
message_file.type = "image"
|
||||
message_file.type = FileType.IMAGE
|
||||
return message_file
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -4800,8 +4800,8 @@ class TestInternalHooksCoverage:
|
|||
dataset_docs = [
|
||||
SimpleNamespace(id="doc-a", doc_form=IndexStructureType.PARENT_CHILD_INDEX),
|
||||
SimpleNamespace(id="doc-b", doc_form=IndexStructureType.PARENT_CHILD_INDEX),
|
||||
SimpleNamespace(id="doc-c", doc_form="qa_model"),
|
||||
SimpleNamespace(id="doc-d", doc_form="qa_model"),
|
||||
SimpleNamespace(id="doc-c", doc_form=IndexStructureType.QA_INDEX),
|
||||
SimpleNamespace(id="doc-d", doc_form=IndexStructureType.QA_INDEX),
|
||||
]
|
||||
child_chunks = [SimpleNamespace(index_node_id="idx-a", segment_id="seg-a")]
|
||||
segments = [SimpleNamespace(index_node_id="idx-c", id="seg-c")]
|
||||
|
|
|
|||
|
|
@ -238,7 +238,7 @@ class TestApiToolProviderValidation:
|
|||
name=provider_name,
|
||||
icon='{"type": "emoji", "value": "🔧"}',
|
||||
schema=schema,
|
||||
schema_type_str="openapi",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
description="Custom API for testing",
|
||||
tools_str=json.dumps(tools),
|
||||
credentials_str=json.dumps(credentials),
|
||||
|
|
@ -249,7 +249,7 @@ class TestApiToolProviderValidation:
|
|||
assert api_provider.user_id == user_id
|
||||
assert api_provider.name == provider_name
|
||||
assert api_provider.schema == schema
|
||||
assert api_provider.schema_type_str == "openapi"
|
||||
assert api_provider.schema_type_str == ApiProviderSchemaType.OPENAPI
|
||||
assert api_provider.description == "Custom API for testing"
|
||||
|
||||
def test_api_tool_provider_schema_type_property(self):
|
||||
|
|
@ -261,7 +261,7 @@ class TestApiToolProviderValidation:
|
|||
name="Test API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
description="Test",
|
||||
tools_str="[]",
|
||||
credentials_str="{}",
|
||||
|
|
@ -314,7 +314,7 @@ class TestApiToolProviderValidation:
|
|||
name="Weather API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
description="Weather API",
|
||||
tools_str=json.dumps(tools_data),
|
||||
credentials_str="{}",
|
||||
|
|
@ -343,7 +343,7 @@ class TestApiToolProviderValidation:
|
|||
name="Secure API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
description="Secure API",
|
||||
tools_str="[]",
|
||||
credentials_str=json.dumps(credentials_data),
|
||||
|
|
@ -369,7 +369,7 @@ class TestApiToolProviderValidation:
|
|||
name="Privacy API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
description="API with privacy policy",
|
||||
tools_str="[]",
|
||||
credentials_str="{}",
|
||||
|
|
@ -391,7 +391,7 @@ class TestApiToolProviderValidation:
|
|||
name="Disclaimer API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
description="API with disclaimer",
|
||||
tools_str="[]",
|
||||
credentials_str="{}",
|
||||
|
|
@ -410,7 +410,7 @@ class TestApiToolProviderValidation:
|
|||
name="Default API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
description="API",
|
||||
tools_str="[]",
|
||||
credentials_str="{}",
|
||||
|
|
@ -432,7 +432,7 @@ class TestApiToolProviderValidation:
|
|||
name=provider_name,
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
description="Unique API",
|
||||
tools_str="[]",
|
||||
credentials_str="{}",
|
||||
|
|
@ -454,7 +454,7 @@ class TestApiToolProviderValidation:
|
|||
name="Public API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
description="Public API with no auth",
|
||||
tools_str="[]",
|
||||
credentials_str=json.dumps(credentials),
|
||||
|
|
@ -479,7 +479,7 @@ class TestApiToolProviderValidation:
|
|||
name="Query Auth API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
description="API with query auth",
|
||||
tools_str="[]",
|
||||
credentials_str=json.dumps(credentials),
|
||||
|
|
@ -741,7 +741,7 @@ class TestCredentialStorage:
|
|||
name="Test API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
description="Test",
|
||||
tools_str="[]",
|
||||
credentials_str=json.dumps(credentials),
|
||||
|
|
@ -788,7 +788,7 @@ class TestCredentialStorage:
|
|||
name="Update Test",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
description="Test",
|
||||
tools_str="[]",
|
||||
credentials_str=json.dumps(original_credentials),
|
||||
|
|
@ -897,7 +897,7 @@ class TestToolProviderRelationships:
|
|||
name="User API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
description="Test",
|
||||
tools_str="[]",
|
||||
credentials_str="{}",
|
||||
|
|
@ -931,7 +931,7 @@ class TestToolProviderRelationships:
|
|||
name="Custom API 1",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI,
|
||||
description="Test",
|
||||
tools_str="[]",
|
||||
credentials_str="{}",
|
||||
|
|
|
|||
|
|
@ -13,13 +13,13 @@ class ConcreteApiKeyAuth(ApiKeyAuthBase):
|
|||
class TestApiKeyAuthBase:
|
||||
def test_should_store_credentials_on_init(self):
|
||||
"""Test that credentials are properly stored during initialization"""
|
||||
credentials = {"api_key": "test_key", "auth_type": "bearer"}
|
||||
credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}}
|
||||
auth = ConcreteApiKeyAuth(credentials)
|
||||
assert auth.credentials == credentials
|
||||
|
||||
def test_should_not_instantiate_abstract_class(self):
|
||||
"""Test that ApiKeyAuthBase cannot be instantiated directly"""
|
||||
credentials = {"api_key": "test_key"}
|
||||
credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}}
|
||||
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
ApiKeyAuthBase(credentials)
|
||||
|
|
@ -29,7 +29,7 @@ class TestApiKeyAuthBase:
|
|||
|
||||
def test_should_allow_subclass_implementation(self):
|
||||
"""Test that subclasses can properly implement the abstract method"""
|
||||
credentials = {"api_key": "test_key", "auth_type": "bearer"}
|
||||
credentials = {"auth_type": "bearer", "config": {"api_key": "test_key"}}
|
||||
auth = ConcreteApiKeyAuth(credentials)
|
||||
|
||||
# Should not raise any exception
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ class TestApiKeyAuthFactory:
|
|||
mock_get_factory.return_value = mock_auth_class
|
||||
|
||||
# Act
|
||||
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"})
|
||||
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "test_key"}})
|
||||
result = factory.validate_credentials()
|
||||
|
||||
# Assert
|
||||
|
|
@ -75,7 +75,7 @@ class TestApiKeyAuthFactory:
|
|||
mock_get_factory.return_value = mock_auth_class
|
||||
|
||||
# Act & Assert
|
||||
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"})
|
||||
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "test_key"}})
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
factory.validate_credentials()
|
||||
assert str(exc_info.value) == "Authentication error"
|
||||
|
|
|
|||
|
|
@ -111,6 +111,7 @@ from unittest.mock import Mock, patch
|
|||
import pytest
|
||||
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from models.dataset import Dataset, DatasetProcessRule, Document
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
|
|
@ -188,7 +189,7 @@ class DocumentValidationTestDataFactory:
|
|||
def create_knowledge_config_mock(
|
||||
data_source: DataSource | None = None,
|
||||
process_rule: ProcessRule | None = None,
|
||||
doc_form: str = "text_model",
|
||||
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
|
||||
indexing_technique: str = "high_quality",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
|
|
@ -326,8 +327,8 @@ class TestDatasetServiceCheckDocForm:
|
|||
- Validation logic works correctly
|
||||
"""
|
||||
# Arrange
|
||||
dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="text_model")
|
||||
doc_form = "text_model"
|
||||
dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||
doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
|
||||
# Act (should not raise)
|
||||
DatasetService.check_doc_form(dataset, doc_form)
|
||||
|
|
@ -349,7 +350,7 @@ class TestDatasetServiceCheckDocForm:
|
|||
"""
|
||||
# Arrange
|
||||
dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=None)
|
||||
doc_form = "text_model"
|
||||
doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
|
||||
# Act (should not raise)
|
||||
DatasetService.check_doc_form(dataset, doc_form)
|
||||
|
|
@ -370,8 +371,8 @@ class TestDatasetServiceCheckDocForm:
|
|||
- Error type is correct
|
||||
"""
|
||||
# Arrange
|
||||
dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="text_model")
|
||||
doc_form = "table_model" # Different form
|
||||
dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||
doc_form = IndexStructureType.PARENT_CHILD_INDEX # Different form
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="doc_form is different from the dataset doc_form"):
|
||||
|
|
@ -390,7 +391,7 @@ class TestDatasetServiceCheckDocForm:
|
|||
"""
|
||||
# Arrange
|
||||
dataset = DocumentValidationTestDataFactory.create_dataset_mock(doc_form="knowledge_card")
|
||||
doc_form = "text_model" # Different form
|
||||
doc_form = IndexStructureType.PARAGRAPH_INDEX # Different form
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="doc_form is different from the dataset doc_form"):
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@ from unittest.mock import MagicMock, Mock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.account import Account
|
||||
from models.dataset import ChildChunk, Dataset, Document, DocumentSegment
|
||||
from models.enums import SegmentType
|
||||
from services.dataset_service import SegmentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
|
||||
|
|
@ -77,7 +79,7 @@ class SegmentTestDataFactory:
|
|||
chunk.word_count = word_count
|
||||
chunk.index_node_id = f"node-{chunk_id}"
|
||||
chunk.index_node_hash = "hash-123"
|
||||
chunk.type = "automatic"
|
||||
chunk.type = SegmentType.AUTOMATIC
|
||||
chunk.created_by = "user-123"
|
||||
chunk.updated_by = None
|
||||
chunk.updated_at = None
|
||||
|
|
@ -90,7 +92,7 @@ class SegmentTestDataFactory:
|
|||
document_id: str = "doc-123",
|
||||
dataset_id: str = "dataset-123",
|
||||
tenant_id: str = "tenant-123",
|
||||
doc_form: str = "text_model",
|
||||
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
|
||||
word_count: int = 100,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
|
|
@ -209,7 +211,7 @@ class TestSegmentServiceCreateSegment:
|
|||
def test_create_segment_with_qa_model(self, mock_db_session, mock_current_user):
|
||||
"""Test creation of segment with QA model (requires answer)."""
|
||||
# Arrange
|
||||
document = SegmentTestDataFactory.create_document_mock(doc_form="qa_model", word_count=100)
|
||||
document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100)
|
||||
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
|
||||
args = {"content": "What is AI?", "answer": "AI is Artificial Intelligence", "keywords": ["ai"]}
|
||||
|
||||
|
|
@ -428,7 +430,7 @@ class TestSegmentServiceUpdateSegment:
|
|||
"""Test update segment with QA model (includes answer)."""
|
||||
# Arrange
|
||||
segment = SegmentTestDataFactory.create_segment_mock(enabled=True, word_count=10)
|
||||
document = SegmentTestDataFactory.create_document_mock(doc_form="qa_model", word_count=100)
|
||||
document = SegmentTestDataFactory.create_document_mock(doc_form=IndexStructureType.QA_INDEX, word_count=100)
|
||||
dataset = SegmentTestDataFactory.create_dataset_mock(indexing_technique="economy")
|
||||
args = SegmentUpdateArgs(content="Updated question", answer="Updated answer", keywords=["qa"])
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from unittest.mock import Mock, create_autospec
|
|||
import pytest
|
||||
from redis.exceptions import LockNotOwnedError
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, Document
|
||||
from services.dataset_service import DocumentService, SegmentService
|
||||
|
|
@ -76,7 +77,7 @@ def test_save_document_with_dataset_id_ignores_lock_not_owned(
|
|||
info_list = types.SimpleNamespace(data_source_type="upload_file")
|
||||
data_source = types.SimpleNamespace(info_list=info_list)
|
||||
knowledge_config = types.SimpleNamespace(
|
||||
doc_form="qa_model",
|
||||
doc_form=IndexStructureType.QA_INDEX,
|
||||
original_document_id=None, # go into "new document" branch
|
||||
data_source=data_source,
|
||||
indexing_technique="high_quality",
|
||||
|
|
@ -131,7 +132,7 @@ def test_add_segment_ignores_lock_not_owned(
|
|||
document.id = "doc-1"
|
||||
document.dataset_id = dataset.id
|
||||
document.word_count = 0
|
||||
document.doc_form = "qa_model"
|
||||
document.doc_form = IndexStructureType.QA_INDEX
|
||||
|
||||
# Minimal args required by add_segment
|
||||
args = {
|
||||
|
|
@ -174,4 +175,4 @@ def test_multi_create_segment_ignores_lock_not_owned(
|
|||
document.id = "doc-1"
|
||||
document.dataset_id = dataset.id
|
||||
document.word_count = 0
|
||||
document.doc_form = "qa_model"
|
||||
document.doc_form = IndexStructureType.QA_INDEX
|
||||
|
|
|
|||
|
|
@ -1,8 +0,0 @@
|
|||
from services.dataset_service import DocumentService
|
||||
|
||||
|
||||
def test_normalize_display_status_alias_mapping():
|
||||
assert DocumentService.normalize_display_status("ACTIVE") == "available"
|
||||
assert DocumentService.normalize_display_status("enabled") == "available"
|
||||
assert DocumentService.normalize_display_status("archived") == "archived"
|
||||
assert DocumentService.normalize_display_status("unknown") is None
|
||||
|
|
@ -1,224 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from services.oauth_server import (
|
||||
OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY,
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY,
|
||||
OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY,
|
||||
OAuthGrantType,
|
||||
OAuthServerService,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(mocker: MockerFixture) -> MagicMock:
|
||||
return mocker.patch("services.oauth_server.redis_client")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(mocker: MockerFixture) -> MagicMock:
|
||||
"""Mock the OAuth server Session context manager."""
|
||||
mocker.patch("services.oauth_server.db", SimpleNamespace(engine=object()))
|
||||
session = MagicMock()
|
||||
session_cm = MagicMock()
|
||||
session_cm.__enter__.return_value = session
|
||||
mocker.patch("services.oauth_server.Session", return_value=session_cm)
|
||||
return session
|
||||
|
||||
|
||||
def test_get_oauth_provider_app_should_return_app_when_record_exists(mock_session: MagicMock) -> None:
|
||||
# Arrange
|
||||
mock_execute_result = MagicMock()
|
||||
expected_app = MagicMock()
|
||||
mock_execute_result.scalar_one_or_none.return_value = expected_app
|
||||
mock_session.execute.return_value = mock_execute_result
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.get_oauth_provider_app("client-1")
|
||||
|
||||
# Assert
|
||||
assert result is expected_app
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_execute_result.scalar_one_or_none.assert_called_once()
|
||||
|
||||
|
||||
def test_sign_oauth_authorization_code_should_store_code_and_return_value(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111")
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
||||
|
||||
# Act
|
||||
code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1")
|
||||
|
||||
# Assert
|
||||
expected_code = str(deterministic_uuid)
|
||||
assert code == expected_code
|
||||
mock_redis_client.set.assert_called_once_with(
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=expected_code),
|
||||
"user-1",
|
||||
ex=600,
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_raise_bad_request_when_authorization_code_is_invalid(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(BadRequest, match="invalid code"):
|
||||
OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
code="bad-code",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_issue_access_and_refresh_token_when_authorization_code_is_valid(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
token_uuids = [
|
||||
uuid.UUID("00000000-0000-0000-0000-000000000201"),
|
||||
uuid.UUID("00000000-0000-0000-0000-000000000202"),
|
||||
]
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids)
|
||||
mock_redis_client.get.return_value = b"user-1"
|
||||
code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1")
|
||||
|
||||
# Act
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
code="code-1",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert access_token == str(token_uuids[0])
|
||||
assert refresh_token == str(token_uuids[1])
|
||||
mock_redis_client.delete.assert_called_once_with(code_key)
|
||||
mock_redis_client.set.assert_any_call(
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
mock_redis_client.set.assert_any_call(
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_raise_bad_request_when_refresh_token_is_invalid(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(BadRequest, match="invalid refresh token"):
|
||||
OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||
refresh_token="stale-token",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_issue_new_access_token_when_refresh_token_is_valid(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301")
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
||||
mock_redis_client.get.return_value = b"user-1"
|
||||
|
||||
# Act
|
||||
access_token, returned_refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||
refresh_token="refresh-1",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert access_token == str(deterministic_uuid)
|
||||
assert returned_refresh_token == "refresh-1"
|
||||
mock_redis_client.set.assert_called_once_with(
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_with_unknown_grant_type_should_return_none() -> None:
|
||||
# Arrange
|
||||
grant_type = cast(OAuthGrantType, "invalid-grant-type")
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=grant_type,
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_sign_oauth_refresh_token_should_store_token_with_expected_expiry(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401")
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
||||
|
||||
# Act
|
||||
refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2")
|
||||
|
||||
# Assert
|
||||
assert refresh_token == str(deterministic_uuid)
|
||||
mock_redis_client.set.assert_called_once_with(
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token),
|
||||
"user-2",
|
||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_oauth_access_token_should_return_none_when_token_not_found(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_validate_oauth_access_token_should_load_user_when_token_exists(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = b"user-88"
|
||||
expected_user = MagicMock()
|
||||
mock_load_user = mocker.patch("services.oauth_server.AccountService.load_user", return_value=expected_user)
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
|
||||
|
||||
# Assert
|
||||
assert result is expected_user
|
||||
mock_load_user.assert_called_once_with("user-88")
|
||||
|
|
@ -11,6 +11,7 @@ from unittest.mock import MagicMock
|
|||
import pytest
|
||||
|
||||
import services.summary_index_service as summary_module
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.enums import SegmentStatus, SummaryStatus
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
|
|
@ -48,7 +49,7 @@ def _segment(*, has_document: bool = True) -> MagicMock:
|
|||
if has_document:
|
||||
doc = MagicMock(name="document")
|
||||
doc.doc_language = "en"
|
||||
doc.doc_form = "text_model"
|
||||
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
segment.document = doc
|
||||
else:
|
||||
segment.document = None
|
||||
|
|
@ -623,13 +624,13 @@ def test_generate_summaries_for_document_skip_conditions(monkeypatch: pytest.Mon
|
|||
dataset = _dataset(indexing_technique="economy")
|
||||
document = MagicMock(spec=summary_module.DatasetDocument)
|
||||
document.id = "doc-1"
|
||||
document.doc_form = "text_model"
|
||||
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == []
|
||||
|
||||
dataset = _dataset()
|
||||
assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": False}) == []
|
||||
|
||||
document.doc_form = "qa_model"
|
||||
document.doc_form = IndexStructureType.QA_INDEX
|
||||
assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == []
|
||||
|
||||
|
||||
|
|
@ -637,7 +638,7 @@ def test_generate_summaries_for_document_runs_and_handles_errors(monkeypatch: py
|
|||
dataset = _dataset()
|
||||
document = MagicMock(spec=summary_module.DatasetDocument)
|
||||
document.id = "doc-1"
|
||||
document.doc_form = "text_model"
|
||||
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
|
||||
seg1 = _segment()
|
||||
seg2 = _segment()
|
||||
|
|
@ -673,7 +674,7 @@ def test_generate_summaries_for_document_no_segments_returns_empty(monkeypatch:
|
|||
dataset = _dataset()
|
||||
document = MagicMock(spec=summary_module.DatasetDocument)
|
||||
document.id = "doc-1"
|
||||
document.doc_form = "text_model"
|
||||
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
|
|
@ -696,7 +697,7 @@ def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chu
|
|||
dataset = _dataset()
|
||||
document = MagicMock(spec=summary_module.DatasetDocument)
|
||||
document.id = "doc-1"
|
||||
document.doc_form = "text_model"
|
||||
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
seg = _segment()
|
||||
|
||||
session = MagicMock()
|
||||
|
|
@ -935,7 +936,7 @@ def test_update_summary_for_segment_skip_conditions() -> None:
|
|||
SummaryIndexService.update_summary_for_segment(_segment(), _dataset(indexing_technique="economy"), "x") is None
|
||||
)
|
||||
seg = _segment(has_document=True)
|
||||
seg.document.doc_form = "qa_model"
|
||||
seg.document.doc_form = IndexStructureType.QA_INDEX
|
||||
assert SummaryIndexService.update_summary_for_segment(seg, _dataset(), "x") is None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from unittest.mock import MagicMock
|
|||
import pytest
|
||||
|
||||
import services.vector_service as vector_service_module
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from services.vector_service import VectorService
|
||||
|
||||
|
||||
|
|
@ -32,7 +33,7 @@ class _ParentDocStub:
|
|||
def _make_dataset(
|
||||
*,
|
||||
indexing_technique: str = "high_quality",
|
||||
doc_form: str = "text_model",
|
||||
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
|
||||
tenant_id: str = "tenant-1",
|
||||
dataset_id: str = "dataset-1",
|
||||
is_multimodal: bool = False,
|
||||
|
|
@ -106,7 +107,7 @@ def test_create_segments_vector_regular_indexing_loads_documents_and_keywords(mo
|
|||
factory_instance.init_index_processor.return_value = index_processor
|
||||
monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance))
|
||||
|
||||
VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model")
|
||||
VectorService.create_segments_vector([["k1"]], [segment], dataset, IndexStructureType.PARAGRAPH_INDEX)
|
||||
|
||||
index_processor.load.assert_called_once()
|
||||
args, kwargs = index_processor.load.call_args
|
||||
|
|
@ -131,7 +132,7 @@ def test_create_segments_vector_regular_indexing_loads_multimodal_documents(monk
|
|||
factory_instance.init_index_processor.return_value = index_processor
|
||||
monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance))
|
||||
|
||||
VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model")
|
||||
VectorService.create_segments_vector([["k1"]], [segment], dataset, IndexStructureType.PARAGRAPH_INDEX)
|
||||
|
||||
assert index_processor.load.call_count == 2
|
||||
first_args, first_kwargs = index_processor.load.call_args_list[0]
|
||||
|
|
@ -153,7 +154,7 @@ def test_create_segments_vector_with_no_segments_does_not_load(monkeypatch: pyte
|
|||
factory_instance.init_index_processor.return_value = index_processor
|
||||
monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance))
|
||||
|
||||
VectorService.create_segments_vector(None, [], dataset, "text_model")
|
||||
VectorService.create_segments_vector(None, [], dataset, IndexStructureType.PARAGRAPH_INDEX)
|
||||
index_processor.load.assert_not_called()
|
||||
|
||||
|
||||
|
|
@ -392,7 +393,7 @@ def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkey
|
|||
|
||||
|
||||
def test_generate_child_chunks_regenerate_cleans_then_saves_children(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset(doc_form="text_model", tenant_id="tenant-1", dataset_id="dataset-1")
|
||||
dataset = _make_dataset(doc_form=IndexStructureType.PARAGRAPH_INDEX, tenant_id="tenant-1", dataset_id="dataset-1")
|
||||
segment = _make_segment(segment_id="seg-1")
|
||||
|
||||
dataset_document = MagicMock()
|
||||
|
|
@ -439,7 +440,7 @@ def test_generate_child_chunks_regenerate_cleans_then_saves_children(monkeypatch
|
|||
|
||||
|
||||
def test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _make_dataset(doc_form="text_model")
|
||||
dataset = _make_dataset(doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||
segment = _make_segment()
|
||||
dataset_document = MagicMock()
|
||||
dataset_document.doc_language = "en"
|
||||
|
|
|
|||
|
|
@ -1,259 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models import Account
|
||||
from models.model import App, EndUser
|
||||
from services.web_conversation_service import WebConversationService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_model() -> App:
|
||||
return cast(App, SimpleNamespace(id="app-1"))
|
||||
|
||||
|
||||
def _account(**kwargs: Any) -> Account:
|
||||
return cast(Account, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def _end_user(**kwargs: Any) -> EndUser:
|
||||
return cast(EndUser, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def test_pagination_by_last_id_should_raise_error_when_user_is_none(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="User is required"):
|
||||
WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=None,
|
||||
last_id=None,
|
||||
limit=20,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
|
||||
def test_pagination_by_last_id_should_forward_without_pin_filter_when_pinned_is_none(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
fake_user = _account(id="user-1")
|
||||
mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
|
||||
mock_pagination.return_value = MagicMock()
|
||||
|
||||
# Act
|
||||
WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=fake_user,
|
||||
last_id="conv-9",
|
||||
limit=10,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
pinned=None,
|
||||
)
|
||||
|
||||
# Assert
|
||||
call_kwargs = mock_pagination.call_args.kwargs
|
||||
assert call_kwargs["include_ids"] is None
|
||||
assert call_kwargs["exclude_ids"] is None
|
||||
assert call_kwargs["last_id"] == "conv-9"
|
||||
assert call_kwargs["sort_by"] == "-updated_at"
|
||||
|
||||
|
||||
def test_pagination_by_last_id_should_include_only_pinned_ids_when_pinned_true(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
fake_account_cls = type("FakeAccount", (), {})
|
||||
fake_user = cast(Account, fake_account_cls())
|
||||
fake_user.id = "account-1"
|
||||
mocker.patch("services.web_conversation_service.Account", fake_account_cls)
|
||||
mocker.patch("services.web_conversation_service.EndUser", type("FakeEndUser", (), {}))
|
||||
session.scalars.return_value.all.return_value = ["conv-1", "conv-2"]
|
||||
mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
|
||||
mock_pagination.return_value = MagicMock()
|
||||
|
||||
# Act
|
||||
WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=fake_user,
|
||||
last_id=None,
|
||||
limit=20,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
pinned=True,
|
||||
)
|
||||
|
||||
# Assert
|
||||
call_kwargs = mock_pagination.call_args.kwargs
|
||||
assert call_kwargs["include_ids"] == ["conv-1", "conv-2"]
|
||||
assert call_kwargs["exclude_ids"] is None
|
||||
|
||||
|
||||
def test_pagination_by_last_id_should_exclude_pinned_ids_when_pinned_false(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
fake_end_user_cls = type("FakeEndUser", (), {})
|
||||
fake_user = cast(EndUser, fake_end_user_cls())
|
||||
fake_user.id = "end-user-1"
|
||||
mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {}))
|
||||
mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls)
|
||||
session.scalars.return_value.all.return_value = ["conv-3"]
|
||||
mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id")
|
||||
mock_pagination.return_value = MagicMock()
|
||||
|
||||
# Act
|
||||
WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=fake_user,
|
||||
last_id=None,
|
||||
limit=20,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
pinned=False,
|
||||
)
|
||||
|
||||
# Assert
|
||||
call_kwargs = mock_pagination.call_args.kwargs
|
||||
assert call_kwargs["include_ids"] is None
|
||||
assert call_kwargs["exclude_ids"] == ["conv-3"]
|
||||
|
||||
|
||||
def test_pin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mock_db = mocker.patch("services.web_conversation_service.db")
|
||||
mocker.patch("services.web_conversation_service.ConversationService.get_conversation")
|
||||
|
||||
# Act
|
||||
WebConversationService.pin(app_model, "conv-1", None)
|
||||
|
||||
# Assert
|
||||
mock_db.session.add.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_pin_should_return_early_when_conversation_is_already_pinned(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
fake_account_cls = type("FakeAccount", (), {})
|
||||
fake_user = cast(Account, fake_account_cls())
|
||||
fake_user.id = "account-1"
|
||||
mocker.patch("services.web_conversation_service.Account", fake_account_cls)
|
||||
mock_db = mocker.patch("services.web_conversation_service.db")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = object()
|
||||
mock_get_conversation = mocker.patch("services.web_conversation_service.ConversationService.get_conversation")
|
||||
|
||||
# Act
|
||||
WebConversationService.pin(app_model, "conv-1", fake_user)
|
||||
|
||||
# Assert
|
||||
mock_get_conversation.assert_not_called()
|
||||
mock_db.session.add.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_pin_should_create_pinned_conversation_when_not_already_pinned(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
fake_account_cls = type("FakeAccount", (), {})
|
||||
fake_user = cast(Account, fake_account_cls())
|
||||
fake_user.id = "account-2"
|
||||
mocker.patch("services.web_conversation_service.Account", fake_account_cls)
|
||||
mock_db = mocker.patch("services.web_conversation_service.db")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_conversation = SimpleNamespace(id="conv-2")
|
||||
mock_get_conversation = mocker.patch(
|
||||
"services.web_conversation_service.ConversationService.get_conversation",
|
||||
return_value=mock_conversation,
|
||||
)
|
||||
|
||||
# Act
|
||||
WebConversationService.pin(app_model, "conv-2", fake_user)
|
||||
|
||||
# Assert
|
||||
mock_get_conversation.assert_called_once_with(app_model=app_model, conversation_id="conv-2", user=fake_user)
|
||||
added_obj = mock_db.session.add.call_args.args[0]
|
||||
assert added_obj.app_id == "app-1"
|
||||
assert added_obj.conversation_id == "conv-2"
|
||||
assert added_obj.created_by_role == "account"
|
||||
assert added_obj.created_by == "account-2"
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_unpin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mock_db = mocker.patch("services.web_conversation_service.db")
|
||||
|
||||
# Act
|
||||
WebConversationService.unpin(app_model, "conv-1", None)
|
||||
|
||||
# Assert
|
||||
mock_db.session.delete.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_unpin_should_return_early_when_conversation_is_not_pinned(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
fake_end_user_cls = type("FakeEndUser", (), {})
|
||||
fake_user = cast(EndUser, fake_end_user_cls())
|
||||
fake_user.id = "end-user-3"
|
||||
mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {}))
|
||||
mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls)
|
||||
mock_db = mocker.patch("services.web_conversation_service.db")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
WebConversationService.unpin(app_model, "conv-7", fake_user)
|
||||
|
||||
# Assert
|
||||
mock_db.session.delete.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_unpin_should_delete_pinned_conversation_when_exists(
|
||||
app_model: App,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
fake_end_user_cls = type("FakeEndUser", (), {})
|
||||
fake_user = cast(EndUser, fake_end_user_cls())
|
||||
fake_user.id = "end-user-4"
|
||||
mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {}))
|
||||
mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls)
|
||||
mock_db = mocker.patch("services.web_conversation_service.db")
|
||||
pinned_obj = SimpleNamespace(id="pin-1")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = pinned_obj
|
||||
|
||||
# Act
|
||||
WebConversationService.unpin(app_model, "conv-8", fake_user)
|
||||
|
||||
# Assert
|
||||
mock_db.session.delete.assert_called_once_with(pinned_obj)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
|
@ -1,643 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(mocker: MockerFixture) -> MagicMock:
|
||||
# Arrange
|
||||
mocked_db = mocker.patch("services.tools.api_tools_manage_service.db")
|
||||
mocked_db.session = MagicMock()
|
||||
return mocked_db
|
||||
|
||||
|
||||
def _tool_bundle(operation_id: str = "tool-1") -> SimpleNamespace:
|
||||
return SimpleNamespace(operation_id=operation_id)
|
||||
|
||||
|
||||
def test_parser_api_schema_should_return_schema_payload_when_schema_is_valid(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI.value),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.parser_api_schema("valid-schema")
|
||||
|
||||
# Assert
|
||||
assert result["schema_type"] == ApiProviderSchemaType.OPENAPI.value
|
||||
assert len(result["credentials_schema"]) == 3
|
||||
assert "warning" in result
|
||||
|
||||
|
||||
def test_parser_api_schema_should_raise_value_error_when_parser_raises(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
side_effect=RuntimeError("bad schema"),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid schema: invalid schema: bad schema"):
|
||||
ApiToolManageService.parser_api_schema("invalid")
|
||||
|
||||
|
||||
def test_convert_schema_to_tool_bundles_should_return_tool_bundles_when_valid(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
expected = ([_tool_bundle("a"), _tool_bundle("b")], ApiProviderSchemaType.SWAGGER)
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=expected,
|
||||
)
|
||||
extra_info: dict[str, str] = {}
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.convert_schema_to_tool_bundles("schema", extra_info=extra_info)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_convert_schema_to_tool_bundles_should_raise_value_error_when_parser_fails(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
side_effect=ValueError("parse failed"),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid schema: parse failed"):
|
||||
ApiToolManageService.convert_schema_to_tool_bundles("schema")
|
||||
|
||||
|
||||
def test_create_api_tool_provider_should_raise_error_when_provider_already_exists(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = object()
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="provider provider-a already exists"):
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name=" provider-a ",
|
||||
icon={"emoji": "X"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="custom",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
|
||||
def test_create_api_tool_provider_should_raise_error_when_tool_count_exceeds_limit(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
many_tools = [_tool_bundle(str(i)) for i in range(101)]
|
||||
mocker.patch.object(
|
||||
ApiToolManageService,
|
||||
"convert_schema_to_tool_bundles",
|
||||
return_value=(many_tools, ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="the number of apis should be less than 100"):
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
icon={"emoji": "X"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="custom",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
|
||||
def test_create_api_tool_provider_should_raise_error_when_auth_type_is_missing(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mocker.patch.object(
|
||||
ApiToolManageService,
|
||||
"convert_schema_to_tool_bundles",
|
||||
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
icon={"emoji": "X"},
|
||||
credentials={},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="custom",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
|
||||
def test_create_api_tool_provider_should_create_provider_when_input_is_valid(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mocker.patch.object(
|
||||
ApiToolManageService,
|
||||
"convert_schema_to_tool_bundles",
|
||||
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
mock_controller = MagicMock()
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
|
||||
return_value=mock_controller,
|
||||
)
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.encrypt.return_value = {"auth_type": "none"}
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
|
||||
return_value=(mock_encrypter, MagicMock()),
|
||||
)
|
||||
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels")
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.create_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
icon={"emoji": "X"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="custom",
|
||||
labels=["news"],
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
mock_controller.load_bundled_tools.assert_called_once()
|
||||
mock_db.session.add.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_api_tool_provider_remote_schema_should_return_schema_when_response_is_valid(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.get",
|
||||
return_value=SimpleNamespace(status_code=200, text="schema-content"),
|
||||
)
|
||||
mocker.patch.object(ApiToolManageService, "parser_api_schema", return_value={"ok": True})
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema")
|
||||
|
||||
# Assert
|
||||
assert result == {"schema": "schema-content"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("status_code", [400, 404, 500])
|
||||
def test_get_api_tool_provider_remote_schema_should_raise_error_when_remote_fetch_is_invalid(
|
||||
status_code: int,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.get",
|
||||
return_value=SimpleNamespace(status_code=status_code, text="schema-content"),
|
||||
)
|
||||
mock_logger = mocker.patch("services.tools.api_tools_manage_service.logger")
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid schema, please check the url you provided"):
|
||||
ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema")
|
||||
mock_logger.exception.assert_called_once()
|
||||
|
||||
|
||||
def test_list_api_tool_provider_tools_should_raise_error_when_provider_not_found(
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="you have not added provider provider-a"):
|
||||
ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a")
|
||||
|
||||
|
||||
def test_list_api_tool_provider_tools_should_return_converted_tools_when_provider_exists(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider = SimpleNamespace(tools=[_tool_bundle("tool-a"), _tool_bundle("tool-b")])
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = provider
|
||||
controller = MagicMock()
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller",
|
||||
return_value=controller,
|
||||
)
|
||||
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["search"])
|
||||
mock_convert = mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity",
|
||||
side_effect=[{"name": "tool-a"}, {"name": "tool-b"}],
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a")
|
||||
|
||||
# Assert
|
||||
assert result == [{"name": "tool-a"}, {"name": "tool-b"}]
|
||||
assert mock_convert.call_count == 2
|
||||
|
||||
|
||||
def test_update_api_tool_provider_should_raise_error_when_original_provider_not_found(
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="api provider provider-a does not exists"):
|
||||
ApiToolManageService.update_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
original_provider="provider-a",
|
||||
icon={},
|
||||
credentials={"auth_type": "none"},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy=None,
|
||||
custom_disclaimer="custom",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
|
||||
def test_update_api_tool_provider_should_raise_error_when_auth_type_missing(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider = SimpleNamespace(credentials={}, name="old")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = provider
|
||||
mocker.patch.object(
|
||||
ApiToolManageService,
|
||||
"convert_schema_to_tool_bundles",
|
||||
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
ApiToolManageService.update_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
original_provider="provider-a",
|
||||
icon={},
|
||||
credentials={},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy=None,
|
||||
custom_disclaimer="custom",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
|
||||
def test_update_api_tool_provider_should_update_provider_and_preserve_masked_credentials(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider = SimpleNamespace(
|
||||
credentials={"auth_type": "none", "api_key_value": "encrypted-old"},
|
||||
name="old",
|
||||
icon="",
|
||||
schema="",
|
||||
description="",
|
||||
schema_type_str="",
|
||||
tools_str="",
|
||||
privacy_policy="",
|
||||
custom_disclaimer="",
|
||||
credentials_str="",
|
||||
)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = provider
|
||||
mocker.patch.object(
|
||||
ApiToolManageService,
|
||||
"convert_schema_to_tool_bundles",
|
||||
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
controller = MagicMock()
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
|
||||
return_value=controller,
|
||||
)
|
||||
cache = MagicMock()
|
||||
encrypter = MagicMock()
|
||||
encrypter.decrypt.return_value = {"auth_type": "none", "api_key_value": "plain-old"}
|
||||
encrypter.mask_plugin_credentials.return_value = {"api_key_value": "***"}
|
||||
encrypter.encrypt.return_value = {"auth_type": "none", "api_key_value": "encrypted-new"}
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
|
||||
return_value=(encrypter, cache),
|
||||
)
|
||||
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels")
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.update_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-new",
|
||||
original_provider="provider-old",
|
||||
icon={"emoji": "E"},
|
||||
credentials={"auth_type": "none", "api_key_value": "***"},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="custom",
|
||||
labels=["news"],
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
assert provider.name == "provider-new"
|
||||
assert provider.privacy_policy == "privacy"
|
||||
assert provider.credentials_str != ""
|
||||
cache.delete.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_delete_api_tool_provider_should_raise_error_when_provider_missing(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="you have not added provider provider-a"):
|
||||
ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a")
|
||||
|
||||
|
||||
def test_delete_api_tool_provider_should_delete_provider_when_exists(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
provider = object()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = provider
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a")
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
mock_db.session.delete.assert_called_once_with(provider)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_api_tool_provider_should_delegate_to_tool_manager(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
expected = {"provider": "value"}
|
||||
mock_get = mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolManager.user_get_api_provider",
|
||||
return_value=expected,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.get_api_tool_provider("user-1", "tenant-1", "provider-a")
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
mock_get.assert_called_once_with(provider="provider-a", tenant_id="tenant-1")
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_raise_error_for_invalid_schema_type() -> None:
|
||||
# Arrange
|
||||
schema_type = "bad-schema-type"
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid schema type"):
|
||||
ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={},
|
||||
schema_type=schema_type, # type: ignore[arg-type]
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_raise_error_when_schema_parser_fails(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
side_effect=RuntimeError("invalid"),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid schema"):
|
||||
ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_raise_error_when_tool_name_is_invalid(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id")
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid tool name tool-b"):
|
||||
ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-b",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_raise_error_when_auth_type_missing(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id")
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={},
|
||||
parameters={},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_return_error_payload_when_tool_validation_raises(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"})
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
provider_controller = MagicMock()
|
||||
tool_obj = MagicMock()
|
||||
tool_obj.fork_tool_runtime.return_value = tool_obj
|
||||
tool_obj.validate_credentials.side_effect = ValueError("validation failed")
|
||||
provider_controller.get_tool.return_value = tool_obj
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
|
||||
return_value=provider_controller,
|
||||
)
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.decrypt.return_value = {"auth_type": "none"}
|
||||
mock_encrypter.mask_plugin_credentials.return_value = {}
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
|
||||
return_value=(mock_encrypter, MagicMock()),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"error": "validation failed"}
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_return_result_payload_when_validation_succeeds(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"})
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
provider_controller = MagicMock()
|
||||
tool_obj = MagicMock()
|
||||
tool_obj.fork_tool_runtime.return_value = tool_obj
|
||||
tool_obj.validate_credentials.return_value = {"ok": True}
|
||||
provider_controller.get_tool.return_value = tool_obj
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
|
||||
return_value=provider_controller,
|
||||
)
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.decrypt.return_value = {"auth_type": "none"}
|
||||
mock_encrypter.mask_plugin_credentials.return_value = {}
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
|
||||
return_value=(mock_encrypter, MagicMock()),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={"x": "1"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"result": {"ok": True}}
|
||||
|
||||
|
||||
def test_list_api_tools_should_return_all_user_providers_with_converted_tools(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_one = SimpleNamespace(name="p1")
|
||||
provider_two = SimpleNamespace(name="p2")
|
||||
mock_db.session.scalars.return_value.all.return_value = [provider_one, provider_two]
|
||||
|
||||
controller_one = MagicMock()
|
||||
controller_one.get_tools.return_value = ["tool-a"]
|
||||
controller_two = MagicMock()
|
||||
controller_two.get_tools.return_value = ["tool-b", "tool-c"]
|
||||
|
||||
user_provider_one = SimpleNamespace(labels=[], tools=[])
|
||||
user_provider_two = SimpleNamespace(labels=[], tools=[])
|
||||
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller",
|
||||
side_effect=[controller_one, controller_two],
|
||||
)
|
||||
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["news"])
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_user_provider",
|
||||
side_effect=[user_provider_one, user_provider_two],
|
||||
)
|
||||
mocker.patch("services.tools.api_tools_manage_service.ToolTransformService.repack_provider")
|
||||
mock_convert = mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity",
|
||||
side_effect=[{"name": "tool-a"}, {"name": "tool-b"}, {"name": "tool-c"}],
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.list_api_tools("tenant-1")
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert user_provider_one.tools == [{"name": "tool-a"}]
|
||||
assert user_provider_two.tools == [{"name": "tool-b"}, {"name": "tool-c"}]
|
||||
assert mock_convert.call_count == 3
|
||||
|
|
@ -1,955 +0,0 @@
|
|||
"""
|
||||
Unit tests for services.tools.workflow_tools_manage_service
|
||||
|
||||
Covers WorkflowToolManageService: create, update, list, delete, get, list_single.
|
||||
"""
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration
|
||||
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
|
||||
from models.model import App
|
||||
from models.tools import WorkflowToolProvider
|
||||
from services.tools import workflow_tools_manage_service
|
||||
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared helpers / fake infrastructure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DummyWorkflow:
|
||||
"""Minimal in-memory Workflow substitute."""
|
||||
|
||||
def __init__(self, graph_dict: dict, version: str = "1.0.0") -> None:
|
||||
self._graph_dict = graph_dict
|
||||
self.version = version
|
||||
|
||||
@property
|
||||
def graph_dict(self) -> dict:
|
||||
return self._graph_dict
|
||||
|
||||
|
||||
class FakeQuery:
|
||||
"""Chainable query object that always returns a fixed result."""
|
||||
|
||||
def __init__(self, result: object) -> None:
|
||||
self._result = result
|
||||
|
||||
def where(self, *args: object, **kwargs: object) -> "FakeQuery":
|
||||
return self
|
||||
|
||||
def first(self) -> object:
|
||||
return self._result
|
||||
|
||||
def delete(self) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
class DummySession:
|
||||
"""Minimal SQLAlchemy session substitute."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.added: list[WorkflowToolProvider] = []
|
||||
self.committed: bool = False
|
||||
|
||||
def __enter__(self) -> "DummySession":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: object, exc: object, tb: object) -> bool:
|
||||
return False
|
||||
|
||||
def add(self, obj: WorkflowToolProvider) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
def begin(self) -> "DummySession":
|
||||
return self
|
||||
|
||||
def commit(self) -> None:
|
||||
self.committed = True
|
||||
|
||||
|
||||
def _build_parameters() -> list[WorkflowToolParameterConfiguration]:
|
||||
return [
|
||||
WorkflowToolParameterConfiguration(name="input", description="input", form=ToolParameter.ToolParameterForm.LLM),
|
||||
]
|
||||
|
||||
|
||||
def _build_fake_db(
|
||||
*,
|
||||
existing_tool: WorkflowToolProvider | None = None,
|
||||
app: object | None = None,
|
||||
tool_by_id: WorkflowToolProvider | None = None,
|
||||
) -> tuple[MagicMock, DummySession]:
|
||||
"""
|
||||
Build a fake db object plus a DummySession for Session context-manager.
|
||||
|
||||
query(WorkflowToolProvider) returns existing_tool on first call,
|
||||
then tool_by_id on subsequent calls (or None if not provided).
|
||||
query(App) returns app.
|
||||
"""
|
||||
call_counts: dict[str, int] = {"wftp": 0}
|
||||
|
||||
def query(model: type) -> FakeQuery:
|
||||
if model is WorkflowToolProvider:
|
||||
call_counts["wftp"] += 1
|
||||
if call_counts["wftp"] == 1:
|
||||
return FakeQuery(existing_tool)
|
||||
return FakeQuery(tool_by_id)
|
||||
if model is App:
|
||||
return FakeQuery(app)
|
||||
return FakeQuery(None)
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.session = SimpleNamespace(query=query, commit=MagicMock())
|
||||
dummy_session = DummySession()
|
||||
return fake_db, dummy_session
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestCreateWorkflowTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCreateWorkflowTool:
|
||||
"""Tests for WorkflowToolManageService.create_workflow_tool."""
|
||||
|
||||
def test_should_raise_when_human_input_nodes_present(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Human-input nodes must be rejected before any provider is created."""
|
||||
# Arrange
|
||||
workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "n1", "data": {"type": "human-input"}}]})
|
||||
app = SimpleNamespace(workflow=workflow)
|
||||
fake_session = SimpleNamespace(query=lambda m: FakeQuery(None) if m is WorkflowToolProvider else FakeQuery(app))
|
||||
monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session)
|
||||
mock_from_db = MagicMock()
|
||||
monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info:
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id="user-id",
|
||||
tenant_id="tenant-id",
|
||||
workflow_app_id="app-id",
|
||||
name="tool_name",
|
||||
label="Tool",
|
||||
icon={"type": "emoji", "emoji": "🔧"},
|
||||
description="desc",
|
||||
parameters=_build_parameters(),
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
|
||||
mock_from_db.assert_not_called()
|
||||
|
||||
def test_should_raise_when_duplicate_name_or_app_id(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Existing provider with same name or app_id raises ValueError."""
|
||||
# Arrange
|
||||
existing = MagicMock(spec=WorkflowToolProvider)
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.db,
|
||||
"session",
|
||||
SimpleNamespace(query=lambda m: FakeQuery(existing)),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id="u",
|
||||
tenant_id="t",
|
||||
workflow_app_id="app-1",
|
||||
name="dup",
|
||||
label="Dup",
|
||||
icon={},
|
||||
description="",
|
||||
parameters=[],
|
||||
)
|
||||
|
||||
def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""ValueError when the referenced App does not exist."""
|
||||
# Arrange
|
||||
call_count = {"n": 0}
|
||||
|
||||
def query(m: type) -> FakeQuery:
|
||||
call_count["n"] += 1
|
||||
if m is WorkflowToolProvider:
|
||||
return FakeQuery(None)
|
||||
return FakeQuery(None) # App returns None
|
||||
|
||||
monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query))
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id="u",
|
||||
tenant_id="t",
|
||||
workflow_app_id="missing-app",
|
||||
name="n",
|
||||
label="L",
|
||||
icon={},
|
||||
description="",
|
||||
parameters=[],
|
||||
)
|
||||
|
||||
def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""ValueError when the App has no attached Workflow."""
|
||||
# Arrange
|
||||
app_no_workflow = SimpleNamespace(workflow=None)
|
||||
|
||||
def query(m: type) -> FakeQuery:
|
||||
if m is WorkflowToolProvider:
|
||||
return FakeQuery(None)
|
||||
return FakeQuery(app_no_workflow)
|
||||
|
||||
monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query))
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id="u",
|
||||
tenant_id="t",
|
||||
workflow_app_id="app-id",
|
||||
name="n",
|
||||
label="L",
|
||||
icon={},
|
||||
description="",
|
||||
parameters=[],
|
||||
)
|
||||
|
||||
def test_should_raise_when_from_db_fails(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Exceptions from WorkflowToolProviderController.from_db are wrapped as ValueError."""
|
||||
# Arrange
|
||||
workflow = DummyWorkflow(graph_dict={"nodes": []})
|
||||
app = SimpleNamespace(workflow=workflow)
|
||||
|
||||
def query(m: type) -> FakeQuery:
|
||||
if m is WorkflowToolProvider:
|
||||
return FakeQuery(None)
|
||||
return FakeQuery(app)
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.session = SimpleNamespace(query=query)
|
||||
monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db)
|
||||
dummy_session = DummySession()
|
||||
monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session)
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.WorkflowToolProviderController,
|
||||
"from_db",
|
||||
MagicMock(side_effect=RuntimeError("bad config")),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="bad config"):
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id="u",
|
||||
tenant_id="t",
|
||||
workflow_app_id="app-id",
|
||||
name="n",
|
||||
label="L",
|
||||
icon={},
|
||||
description="",
|
||||
parameters=[],
|
||||
)
|
||||
|
||||
def test_should_succeed_and_persist_provider(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Happy path: provider is added to session and success dict is returned."""
|
||||
# Arrange
|
||||
workflow = DummyWorkflow(graph_dict={"nodes": []}, version="2.0.0")
|
||||
app = SimpleNamespace(workflow=workflow)
|
||||
|
||||
def query(m: type) -> FakeQuery:
|
||||
if m is WorkflowToolProvider:
|
||||
return FakeQuery(None)
|
||||
return FakeQuery(app)
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.session = SimpleNamespace(query=query)
|
||||
monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db)
|
||||
dummy_session = DummySession()
|
||||
monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session)
|
||||
monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock())
|
||||
|
||||
icon = {"type": "emoji", "emoji": "🔧"}
|
||||
|
||||
# Act
|
||||
result = WorkflowToolManageService.create_workflow_tool(
|
||||
user_id="user-id",
|
||||
tenant_id="tenant-id",
|
||||
workflow_app_id="app-id",
|
||||
name="tool_name",
|
||||
label="Tool",
|
||||
icon=icon,
|
||||
description="desc",
|
||||
parameters=_build_parameters(),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
assert len(dummy_session.added) == 1
|
||||
created: WorkflowToolProvider = dummy_session.added[0]
|
||||
assert created.name == "tool_name"
|
||||
assert created.label == "Tool"
|
||||
assert created.icon == json.dumps(icon)
|
||||
assert created.version == "2.0.0"
|
||||
|
||||
def test_should_call_label_manager_when_labels_provided(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Labels are forwarded to ToolLabelManager when provided."""
|
||||
# Arrange
|
||||
workflow = DummyWorkflow(graph_dict={"nodes": []})
|
||||
app = SimpleNamespace(workflow=workflow)
|
||||
|
||||
def query(m: type) -> FakeQuery:
|
||||
if m is WorkflowToolProvider:
|
||||
return FakeQuery(None)
|
||||
return FakeQuery(app)
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.session = SimpleNamespace(query=query)
|
||||
monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db)
|
||||
dummy_session = DummySession()
|
||||
monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session)
|
||||
monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock())
|
||||
mock_label_mgr = MagicMock()
|
||||
monkeypatch.setattr(workflow_tools_manage_service.ToolLabelManager, "update_tool_labels", mock_label_mgr)
|
||||
mock_to_ctrl = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_controller", mock_to_ctrl
|
||||
)
|
||||
|
||||
# Act
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id="u",
|
||||
tenant_id="t",
|
||||
workflow_app_id="app-id",
|
||||
name="n",
|
||||
label="L",
|
||||
icon={},
|
||||
description="",
|
||||
parameters=[],
|
||||
labels=["tag1", "tag2"],
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_label_mgr.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestUpdateWorkflowTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUpdateWorkflowTool:
|
||||
"""Tests for WorkflowToolManageService.update_workflow_tool."""
|
||||
|
||||
def _make_provider(self) -> WorkflowToolProvider:
|
||||
p = MagicMock(spec=WorkflowToolProvider)
|
||||
p.app_id = "app-id"
|
||||
p.tenant_id = "tenant-id"
|
||||
return p
|
||||
|
||||
def test_should_raise_when_name_duplicated(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""If another tool with the given name already exists, raise ValueError."""
|
||||
# Arrange
|
||||
existing = MagicMock(spec=WorkflowToolProvider)
|
||||
|
||||
def query(m: type) -> FakeQuery:
|
||||
return FakeQuery(existing)
|
||||
|
||||
monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query))
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
WorkflowToolManageService.update_workflow_tool(
|
||||
user_id="u",
|
||||
tenant_id="t",
|
||||
workflow_tool_id="tool-1",
|
||||
name="dup",
|
||||
label="L",
|
||||
icon={},
|
||||
description="",
|
||||
parameters=[],
|
||||
)
|
||||
|
||||
def test_should_raise_when_tool_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""ValueError when the workflow tool to update does not exist."""
|
||||
# Arrange
|
||||
call_count = {"n": 0}
|
||||
|
||||
def query(m: type) -> FakeQuery:
|
||||
call_count["n"] += 1
|
||||
# 1st call: name uniqueness check → None (no duplicate)
|
||||
# 2nd call: fetch tool by id → None (not found)
|
||||
return FakeQuery(None)
|
||||
|
||||
monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query))
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
WorkflowToolManageService.update_workflow_tool(
|
||||
user_id="u",
|
||||
tenant_id="t",
|
||||
workflow_tool_id="missing",
|
||||
name="n",
|
||||
label="L",
|
||||
icon={},
|
||||
description="",
|
||||
parameters=[],
|
||||
)
|
||||
|
||||
def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""ValueError when the tool's referenced App has been removed."""
|
||||
# Arrange
|
||||
provider = self._make_provider()
|
||||
call_count = {"n": 0}
|
||||
|
||||
def query(m: type) -> FakeQuery:
|
||||
call_count["n"] += 1
|
||||
if m is WorkflowToolProvider:
|
||||
# 1st: duplicate name check (None), 2nd: fetch provider
|
||||
return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider)
|
||||
return FakeQuery(None) # App not found
|
||||
|
||||
monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query))
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
WorkflowToolManageService.update_workflow_tool(
|
||||
user_id="u",
|
||||
tenant_id="t",
|
||||
workflow_tool_id="tool-1",
|
||||
name="n",
|
||||
label="L",
|
||||
icon={},
|
||||
description="",
|
||||
parameters=[],
|
||||
)
|
||||
|
||||
def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""ValueError when the App exists but has no Workflow."""
|
||||
# Arrange
|
||||
provider = self._make_provider()
|
||||
app_no_wf = SimpleNamespace(workflow=None)
|
||||
call_count = {"n": 0}
|
||||
|
||||
def query(m: type) -> FakeQuery:
|
||||
call_count["n"] += 1
|
||||
if m is WorkflowToolProvider:
|
||||
return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider)
|
||||
return FakeQuery(app_no_wf)
|
||||
|
||||
monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query))
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
WorkflowToolManageService.update_workflow_tool(
|
||||
user_id="u",
|
||||
tenant_id="t",
|
||||
workflow_tool_id="tool-1",
|
||||
name="n",
|
||||
label="L",
|
||||
icon={},
|
||||
description="",
|
||||
parameters=[],
|
||||
)
|
||||
|
||||
def test_should_raise_when_from_db_fails(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Exceptions from from_db are re-raised as ValueError."""
|
||||
# Arrange
|
||||
provider = self._make_provider()
|
||||
workflow = DummyWorkflow(graph_dict={"nodes": []})
|
||||
app = SimpleNamespace(workflow=workflow)
|
||||
call_count = {"n": 0}
|
||||
|
||||
def query(m: type) -> FakeQuery:
|
||||
call_count["n"] += 1
|
||||
if m is WorkflowToolProvider:
|
||||
return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider)
|
||||
return FakeQuery(app)
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.db,
|
||||
"session",
|
||||
SimpleNamespace(query=query, commit=MagicMock()),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.WorkflowToolProviderController,
|
||||
"from_db",
|
||||
MagicMock(side_effect=RuntimeError("from_db error")),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="from_db error"):
|
||||
WorkflowToolManageService.update_workflow_tool(
|
||||
user_id="u",
|
||||
tenant_id="t",
|
||||
workflow_tool_id="tool-1",
|
||||
name="n",
|
||||
label="L",
|
||||
icon={},
|
||||
description="",
|
||||
parameters=[],
|
||||
)
|
||||
|
||||
def test_should_succeed_and_call_commit(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Happy path: provider fields are updated and session committed."""
|
||||
# Arrange
|
||||
provider = self._make_provider()
|
||||
workflow = DummyWorkflow(graph_dict={"nodes": []}, version="3.0.0")
|
||||
app = SimpleNamespace(workflow=workflow)
|
||||
call_count = {"n": 0}
|
||||
|
||||
def query(m: type) -> FakeQuery:
|
||||
call_count["n"] += 1
|
||||
if m is WorkflowToolProvider:
|
||||
return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider)
|
||||
return FakeQuery(app)
|
||||
|
||||
mock_commit = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.db,
|
||||
"session",
|
||||
SimpleNamespace(query=query, commit=mock_commit),
|
||||
)
|
||||
monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock())
|
||||
|
||||
icon = {"type": "emoji", "emoji": "🛠"}
|
||||
|
||||
# Act
|
||||
result = WorkflowToolManageService.update_workflow_tool(
|
||||
user_id="u",
|
||||
tenant_id="t",
|
||||
workflow_tool_id="tool-1",
|
||||
name="new_name",
|
||||
label="New Label",
|
||||
icon=icon,
|
||||
description="new desc",
|
||||
parameters=_build_parameters(),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
mock_commit.assert_called_once()
|
||||
assert provider.name == "new_name"
|
||||
assert provider.label == "New Label"
|
||||
assert provider.icon == json.dumps(icon)
|
||||
assert provider.version == "3.0.0"
|
||||
|
||||
def test_should_call_label_manager_when_labels_provided(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Labels are forwarded to ToolLabelManager during update."""
|
||||
# Arrange
|
||||
provider = self._make_provider()
|
||||
workflow = DummyWorkflow(graph_dict={"nodes": []})
|
||||
app = SimpleNamespace(workflow=workflow)
|
||||
call_count = {"n": 0}
|
||||
|
||||
def query(m: type) -> FakeQuery:
|
||||
call_count["n"] += 1
|
||||
if m is WorkflowToolProvider:
|
||||
return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider)
|
||||
return FakeQuery(app)
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.db,
|
||||
"session",
|
||||
SimpleNamespace(query=query, commit=MagicMock()),
|
||||
)
|
||||
monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock())
|
||||
mock_label_mgr = MagicMock()
|
||||
monkeypatch.setattr(workflow_tools_manage_service.ToolLabelManager, "update_tool_labels", mock_label_mgr)
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_controller", MagicMock()
|
||||
)
|
||||
|
||||
# Act
|
||||
WorkflowToolManageService.update_workflow_tool(
|
||||
user_id="u",
|
||||
tenant_id="t",
|
||||
workflow_tool_id="tool-1",
|
||||
name="n",
|
||||
label="L",
|
||||
icon={},
|
||||
description="",
|
||||
parameters=[],
|
||||
labels=["a"],
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_label_mgr.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestListTenantWorkflowTools
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListTenantWorkflowTools:
|
||||
"""Tests for WorkflowToolManageService.list_tenant_workflow_tools."""
|
||||
|
||||
def test_should_return_empty_list_when_no_tools(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""An empty database yields an empty result list."""
|
||||
# Arrange
|
||||
fake_scalars = MagicMock()
|
||||
fake_scalars.all.return_value = []
|
||||
fake_db = MagicMock()
|
||||
fake_db.session.scalars.return_value = fake_scalars
|
||||
monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db)
|
||||
|
||||
# Act
|
||||
result = WorkflowToolManageService.list_tenant_workflow_tools("u", "t")
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
def test_should_skip_broken_providers_and_log(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Providers that fail to load are logged and skipped."""
|
||||
# Arrange
|
||||
good_provider = MagicMock(spec=WorkflowToolProvider)
|
||||
good_provider.id = "good-id"
|
||||
good_provider.app_id = "app-good"
|
||||
bad_provider = MagicMock(spec=WorkflowToolProvider)
|
||||
bad_provider.id = "bad-id"
|
||||
bad_provider.app_id = "app-bad"
|
||||
|
||||
fake_scalars = MagicMock()
|
||||
fake_scalars.all.return_value = [good_provider, bad_provider]
|
||||
fake_db = MagicMock()
|
||||
fake_db.session.scalars.return_value = fake_scalars
|
||||
monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db)
|
||||
|
||||
good_ctrl = MagicMock()
|
||||
good_ctrl.provider_id = "good-id"
|
||||
|
||||
def to_controller(provider: WorkflowToolProvider) -> MagicMock:
|
||||
if provider is bad_provider:
|
||||
raise RuntimeError("broken provider")
|
||||
return good_ctrl
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_controller", to_controller
|
||||
)
|
||||
mock_get_labels = MagicMock(return_value={})
|
||||
monkeypatch.setattr(workflow_tools_manage_service.ToolLabelManager, "get_tools_labels", mock_get_labels)
|
||||
mock_to_user = MagicMock()
|
||||
mock_to_user.return_value.tools = []
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_user_provider", mock_to_user
|
||||
)
|
||||
monkeypatch.setattr(workflow_tools_manage_service.ToolTransformService, "repack_provider", MagicMock())
|
||||
mock_get_tools = MagicMock(return_value=[MagicMock()])
|
||||
good_ctrl.get_tools = mock_get_tools
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.ToolTransformService, "convert_tool_entity_to_api_entity", MagicMock()
|
||||
)
|
||||
|
||||
# Act
|
||||
result = WorkflowToolManageService.list_tenant_workflow_tools("u", "t")
|
||||
|
||||
# Assert - only good provider contributed
|
||||
assert len(result) == 1
|
||||
|
||||
def test_should_return_tools_for_all_providers(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""All successfully loaded providers appear in the result."""
|
||||
# Arrange
|
||||
provider = MagicMock(spec=WorkflowToolProvider)
|
||||
provider.id = "p-1"
|
||||
provider.app_id = "app-1"
|
||||
|
||||
fake_scalars = MagicMock()
|
||||
fake_scalars.all.return_value = [provider]
|
||||
fake_db = MagicMock()
|
||||
fake_db.session.scalars.return_value = fake_scalars
|
||||
monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db)
|
||||
|
||||
ctrl = MagicMock()
|
||||
ctrl.provider_id = "p-1"
|
||||
ctrl.get_tools.return_value = [MagicMock()]
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.ToolTransformService,
|
||||
"workflow_provider_to_controller",
|
||||
MagicMock(return_value=ctrl),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.ToolLabelManager, "get_tools_labels", MagicMock(return_value={"p-1": []})
|
||||
)
|
||||
user_provider = MagicMock()
|
||||
user_provider.tools = []
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.ToolTransformService,
|
||||
"workflow_provider_to_user_provider",
|
||||
MagicMock(return_value=user_provider),
|
||||
)
|
||||
monkeypatch.setattr(workflow_tools_manage_service.ToolTransformService, "repack_provider", MagicMock())
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.ToolTransformService, "convert_tool_entity_to_api_entity", MagicMock()
|
||||
)
|
||||
|
||||
# Act
|
||||
result = WorkflowToolManageService.list_tenant_workflow_tools("u", "t")
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0] is user_provider
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestDeleteWorkflowTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeleteWorkflowTool:
|
||||
"""Tests for WorkflowToolManageService.delete_workflow_tool."""
|
||||
|
||||
def test_should_delete_and_commit(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""delete_workflow_tool queries, deletes, commits, and returns success."""
|
||||
# Arrange
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.delete.return_value = 1
|
||||
mock_commit = MagicMock()
|
||||
fake_session = SimpleNamespace(query=lambda m: mock_query, commit=mock_commit)
|
||||
monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session)
|
||||
|
||||
# Act
|
||||
result = WorkflowToolManageService.delete_workflow_tool("u", "t", "tool-1")
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
mock_commit.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestGetWorkflowToolByToolId / ByAppId
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetWorkflowToolByToolIdAndAppId:
|
||||
"""Tests for get_workflow_tool_by_tool_id and get_workflow_tool_by_app_id."""
|
||||
|
||||
def test_get_by_tool_id_should_raise_when_db_tool_is_none(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Raises ValueError when no WorkflowToolProvider found by tool id."""
|
||||
# Arrange
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.db,
|
||||
"session",
|
||||
SimpleNamespace(query=lambda m: FakeQuery(None)),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Tool not found"):
|
||||
WorkflowToolManageService.get_workflow_tool_by_tool_id("u", "t", "missing")
|
||||
|
||||
def test_get_by_app_id_should_raise_when_db_tool_is_none(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Raises ValueError when no WorkflowToolProvider found by app id."""
|
||||
# Arrange
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.db,
|
||||
"session",
|
||||
SimpleNamespace(query=lambda m: FakeQuery(None)),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Tool not found"):
|
||||
WorkflowToolManageService.get_workflow_tool_by_app_id("u", "t", "missing-app")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestGetWorkflowTool (private _get_workflow_tool)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetWorkflowTool:
|
||||
"""Tests for the internal _get_workflow_tool helper."""
|
||||
|
||||
def test_should_raise_when_db_tool_none(self) -> None:
|
||||
"""_get_workflow_tool raises ValueError when db_tool is None."""
|
||||
with pytest.raises(ValueError, match="Tool not found"):
|
||||
WorkflowToolManageService._get_workflow_tool("t", None)
|
||||
|
||||
def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""ValueError when the corresponding App row is missing."""
|
||||
# Arrange
|
||||
db_tool = MagicMock(spec=WorkflowToolProvider)
|
||||
db_tool.app_id = "app-1"
|
||||
db_tool.tenant_id = "t"
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.db,
|
||||
"session",
|
||||
SimpleNamespace(query=lambda m: FakeQuery(None)),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
WorkflowToolManageService._get_workflow_tool("t", db_tool)
|
||||
|
||||
def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""ValueError when App has no attached Workflow."""
|
||||
# Arrange
|
||||
db_tool = MagicMock(spec=WorkflowToolProvider)
|
||||
db_tool.app_id = "app-1"
|
||||
db_tool.tenant_id = "t"
|
||||
app = SimpleNamespace(workflow=None)
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.db,
|
||||
"session",
|
||||
SimpleNamespace(query=lambda m: FakeQuery(app)),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
WorkflowToolManageService._get_workflow_tool("t", db_tool)
|
||||
|
||||
def test_should_raise_when_no_workflow_tools(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""ValueError when the controller returns no WorkflowTool instances."""
|
||||
# Arrange
|
||||
db_tool = MagicMock(spec=WorkflowToolProvider)
|
||||
db_tool.app_id = "app-1"
|
||||
db_tool.tenant_id = "t"
|
||||
db_tool.id = "tool-1"
|
||||
workflow = DummyWorkflow(graph_dict={"nodes": []})
|
||||
app = SimpleNamespace(workflow=workflow)
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.db,
|
||||
"session",
|
||||
SimpleNamespace(query=lambda m: FakeQuery(app)),
|
||||
)
|
||||
ctrl = MagicMock()
|
||||
ctrl.get_tools.return_value = []
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.ToolTransformService,
|
||||
"workflow_provider_to_controller",
|
||||
MagicMock(return_value=ctrl),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
WorkflowToolManageService._get_workflow_tool("t", db_tool)
|
||||
|
||||
def test_should_return_dict_on_success(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Happy path: returns a dict with name, label, icon, synced, etc."""
|
||||
# Arrange
|
||||
db_tool = MagicMock(spec=WorkflowToolProvider)
|
||||
db_tool.app_id = "app-1"
|
||||
db_tool.tenant_id = "t"
|
||||
db_tool.id = "tool-1"
|
||||
db_tool.name = "my_tool"
|
||||
db_tool.label = "My Tool"
|
||||
db_tool.icon = json.dumps({"emoji": "🔧"})
|
||||
db_tool.description = "some desc"
|
||||
db_tool.privacy_policy = ""
|
||||
db_tool.version = "1.0"
|
||||
db_tool.parameter_configurations = []
|
||||
workflow = DummyWorkflow(graph_dict={"nodes": []}, version="1.0")
|
||||
app = SimpleNamespace(workflow=workflow)
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.db,
|
||||
"session",
|
||||
SimpleNamespace(query=lambda m: FakeQuery(app)),
|
||||
)
|
||||
|
||||
workflow_tool = MagicMock()
|
||||
workflow_tool.entity.output_schema = {"type": "object"}
|
||||
ctrl = MagicMock()
|
||||
ctrl.get_tools.return_value = [workflow_tool]
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.ToolTransformService,
|
||||
"workflow_provider_to_controller",
|
||||
MagicMock(return_value=ctrl),
|
||||
)
|
||||
mock_convert = MagicMock(return_value={"tool": "api_entity"})
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.ToolTransformService, "convert_tool_entity_to_api_entity", mock_convert
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.ToolLabelManager, "get_tool_labels", MagicMock(return_value=[])
|
||||
)
|
||||
|
||||
# Act
|
||||
result = WorkflowToolManageService._get_workflow_tool("t", db_tool)
|
||||
|
||||
# Assert
|
||||
assert result["name"] == "my_tool"
|
||||
assert result["label"] == "My Tool"
|
||||
assert result["synced"] is True
|
||||
assert "icon" in result
|
||||
assert "output_schema" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestListSingleWorkflowTools
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListSingleWorkflowTools:
|
||||
"""Tests for WorkflowToolManageService.list_single_workflow_tools."""
|
||||
|
||||
def test_should_raise_when_tool_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""ValueError when the specified tool does not exist in DB."""
|
||||
# Arrange
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.db,
|
||||
"session",
|
||||
SimpleNamespace(query=lambda m: FakeQuery(None)),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1")
|
||||
|
||||
def test_should_raise_when_no_workflow_tools(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""ValueError when the controller yields no tools for the provider."""
|
||||
# Arrange
|
||||
db_tool = MagicMock(spec=WorkflowToolProvider)
|
||||
db_tool.id = "tool-1"
|
||||
db_tool.tenant_id = "t"
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.db,
|
||||
"session",
|
||||
SimpleNamespace(query=lambda m: FakeQuery(db_tool)),
|
||||
)
|
||||
ctrl = MagicMock()
|
||||
ctrl.get_tools.return_value = []
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.ToolTransformService,
|
||||
"workflow_provider_to_controller",
|
||||
MagicMock(return_value=ctrl),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1")
|
||||
|
||||
def test_should_return_api_entity_list(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Happy path: returns list with one ToolApiEntity."""
|
||||
# Arrange
|
||||
db_tool = MagicMock(spec=WorkflowToolProvider)
|
||||
db_tool.id = "tool-1"
|
||||
db_tool.tenant_id = "t"
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.db,
|
||||
"session",
|
||||
SimpleNamespace(query=lambda m: FakeQuery(db_tool)),
|
||||
)
|
||||
workflow_tool = MagicMock()
|
||||
ctrl = MagicMock()
|
||||
ctrl.get_tools.return_value = [workflow_tool]
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.ToolTransformService,
|
||||
"workflow_provider_to_controller",
|
||||
MagicMock(return_value=ctrl),
|
||||
)
|
||||
api_entity = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.ToolTransformService,
|
||||
"convert_tool_entity_to_api_entity",
|
||||
MagicMock(return_value=api_entity),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_tools_manage_service.ToolLabelManager, "get_tool_labels", MagicMock(return_value=[])
|
||||
)
|
||||
|
||||
# Act
|
||||
result = WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1")
|
||||
|
||||
# Assert
|
||||
assert result == [api_entity]
|
||||
|
|
@ -121,6 +121,7 @@ import pytest
|
|||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.models.document import Document
|
||||
from models.dataset import ChildChunk, Dataset, DatasetDocument, DatasetProcessRule, DocumentSegment
|
||||
from services.vector_service import VectorService
|
||||
|
|
@ -151,7 +152,7 @@ class VectorServiceTestDataFactory:
|
|||
def create_dataset_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
tenant_id: str = "tenant-123",
|
||||
doc_form: str = "text_model",
|
||||
doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
|
||||
indexing_technique: str = "high_quality",
|
||||
embedding_model_provider: str = "openai",
|
||||
embedding_model: str = "text-embedding-ada-002",
|
||||
|
|
@ -493,7 +494,7 @@ class TestVectorService:
|
|||
"""
|
||||
# Arrange
|
||||
dataset = VectorServiceTestDataFactory.create_dataset_mock(
|
||||
doc_form="text_model", indexing_technique="high_quality"
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_technique="high_quality"
|
||||
)
|
||||
|
||||
segment = VectorServiceTestDataFactory.create_document_segment_mock()
|
||||
|
|
@ -505,7 +506,7 @@ class TestVectorService:
|
|||
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor
|
||||
|
||||
# Act
|
||||
VectorService.create_segments_vector(keywords_list, [segment], dataset, "text_model")
|
||||
VectorService.create_segments_vector(keywords_list, [segment], dataset, IndexStructureType.PARAGRAPH_INDEX)
|
||||
|
||||
# Assert
|
||||
mock_index_processor.load.assert_called_once()
|
||||
|
|
@ -649,7 +650,7 @@ class TestVectorService:
|
|||
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_index_processor
|
||||
|
||||
# Act
|
||||
VectorService.create_segments_vector(None, [], dataset, "text_model")
|
||||
VectorService.create_segments_vector(None, [], dataset, IndexStructureType.PARAGRAPH_INDEX)
|
||||
|
||||
# Assert
|
||||
mock_index_processor.load.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.enums import DataSourceType
|
||||
from tasks.clean_dataset_task import clean_dataset_task
|
||||
|
||||
|
|
@ -186,7 +187,7 @@ class TestErrorHandling:
|
|||
indexing_technique="high_quality",
|
||||
index_struct='{"type": "paragraph"}',
|
||||
collection_binding_id=collection_binding_id,
|
||||
doc_form="paragraph_index",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
|
||||
# Assert
|
||||
|
|
@ -231,7 +232,7 @@ class TestPipelineAndWorkflowDeletion:
|
|||
indexing_technique="high_quality",
|
||||
index_struct='{"type": "paragraph"}',
|
||||
collection_binding_id=collection_binding_id,
|
||||
doc_form="paragraph_index",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
pipeline_id=pipeline_id,
|
||||
)
|
||||
|
||||
|
|
@ -267,7 +268,7 @@ class TestPipelineAndWorkflowDeletion:
|
|||
indexing_technique="high_quality",
|
||||
index_struct='{"type": "paragraph"}',
|
||||
collection_binding_id=collection_binding_id,
|
||||
doc_form="paragraph_index",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
pipeline_id=None,
|
||||
)
|
||||
|
||||
|
|
@ -323,7 +324,7 @@ class TestSegmentAttachmentCleanup:
|
|||
indexing_technique="high_quality",
|
||||
index_struct='{"type": "paragraph"}',
|
||||
collection_binding_id=collection_binding_id,
|
||||
doc_form="paragraph_index",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
|
||||
# Assert
|
||||
|
|
@ -368,7 +369,7 @@ class TestSegmentAttachmentCleanup:
|
|||
indexing_technique="high_quality",
|
||||
index_struct='{"type": "paragraph"}',
|
||||
collection_binding_id=collection_binding_id,
|
||||
doc_form="paragraph_index",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
|
||||
# Assert - storage delete was attempted
|
||||
|
|
@ -410,7 +411,7 @@ class TestEdgeCases:
|
|||
indexing_technique="high_quality",
|
||||
index_struct='{"type": "paragraph"}',
|
||||
collection_binding_id=collection_binding_id,
|
||||
doc_form="paragraph_index",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
|
||||
# Assert
|
||||
|
|
@ -454,7 +455,7 @@ class TestIndexProcessorParameters:
|
|||
indexing_technique=indexing_technique,
|
||||
index_struct=index_struct,
|
||||
collection_binding_id=collection_binding_id,
|
||||
doc_form="paragraph_index",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
|
||||
# Assert
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from unittest.mock import MagicMock, Mock, patch
|
|||
import pytest
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_redis import redis_client
|
||||
|
|
@ -222,7 +223,7 @@ def mock_documents(document_ids, dataset_id):
|
|||
doc.stopped_at = None
|
||||
doc.processing_started_at = None
|
||||
# optional attribute used in some code paths
|
||||
doc.doc_form = "text_model"
|
||||
doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
documents.append(doc)
|
||||
return documents
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from unittest.mock import MagicMock, Mock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.dataset import Dataset, Document
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
|
|
@ -62,7 +63,7 @@ def mock_document(document_id, dataset_id, notion_workspace_id, notion_page_id,
|
|||
document.tenant_id = str(uuid.uuid4())
|
||||
document.data_source_type = "notion_import"
|
||||
document.indexing_status = "completed"
|
||||
document.doc_form = "text_model"
|
||||
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
document.data_source_info_dict = {
|
||||
"notion_workspace_id": notion_workspace_id,
|
||||
"notion_page_id": notion_page_id,
|
||||
|
|
|
|||
|
|
@ -69,6 +69,7 @@
|
|||
},
|
||||
"pnpm": {
|
||||
"overrides": {
|
||||
"flatted@<=3.4.1": "3.4.2",
|
||||
"rollup@>=4.0.0,<4.59.0": "4.59.0"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ settings:
|
|||
excludeLinksFromLockfile: false
|
||||
|
||||
overrides:
|
||||
flatted@<=3.4.1: 3.4.2
|
||||
rollup@>=4.0.0,<4.59.0: 4.59.0
|
||||
|
||||
importers:
|
||||
|
|
@ -324,66 +325,79 @@ packages:
|
|||
resolution: {integrity: sha512-t4ONHboXi/3E0rT6OZl1pKbl2Vgxf9vJfWgmUoCEVQVxhW6Cw/c8I6hbbu7DAvgp82RKiH7TpLwxnJeKv2pbsw==}
|
||||
cpu: [arm]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@rollup/rollup-linux-arm-musleabihf@4.59.0':
|
||||
resolution: {integrity: sha512-CikFT7aYPA2ufMD086cVORBYGHffBo4K8MQ4uPS/ZnY54GKj36i196u8U+aDVT2LX4eSMbyHtyOh7D7Zvk2VvA==}
|
||||
cpu: [arm]
|
||||
os: [linux]
|
||||
libc: [musl]
|
||||
|
||||
'@rollup/rollup-linux-arm64-gnu@4.59.0':
|
||||
resolution: {integrity: sha512-jYgUGk5aLd1nUb1CtQ8E+t5JhLc9x5WdBKew9ZgAXg7DBk0ZHErLHdXM24rfX+bKrFe+Xp5YuJo54I5HFjGDAA==}
|
||||
cpu: [arm64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@rollup/rollup-linux-arm64-musl@4.59.0':
|
||||
resolution: {integrity: sha512-peZRVEdnFWZ5Bh2KeumKG9ty7aCXzzEsHShOZEFiCQlDEepP1dpUl/SrUNXNg13UmZl+gzVDPsiCwnV1uI0RUA==}
|
||||
cpu: [arm64]
|
||||
os: [linux]
|
||||
libc: [musl]
|
||||
|
||||
'@rollup/rollup-linux-loong64-gnu@4.59.0':
|
||||
resolution: {integrity: sha512-gbUSW/97f7+r4gHy3Jlup8zDG190AuodsWnNiXErp9mT90iCy9NKKU0Xwx5k8VlRAIV2uU9CsMnEFg/xXaOfXg==}
|
||||
cpu: [loong64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@rollup/rollup-linux-loong64-musl@4.59.0':
|
||||
resolution: {integrity: sha512-yTRONe79E+o0FWFijasoTjtzG9EBedFXJMl888NBEDCDV9I2wGbFFfJQQe63OijbFCUZqxpHz1GzpbtSFikJ4Q==}
|
||||
cpu: [loong64]
|
||||
os: [linux]
|
||||
libc: [musl]
|
||||
|
||||
'@rollup/rollup-linux-ppc64-gnu@4.59.0':
|
||||
resolution: {integrity: sha512-sw1o3tfyk12k3OEpRddF68a1unZ5VCN7zoTNtSn2KndUE+ea3m3ROOKRCZxEpmT9nsGnogpFP9x6mnLTCaoLkA==}
|
||||
cpu: [ppc64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@rollup/rollup-linux-ppc64-musl@4.59.0':
|
||||
resolution: {integrity: sha512-+2kLtQ4xT3AiIxkzFVFXfsmlZiG5FXYW7ZyIIvGA7Bdeuh9Z0aN4hVyXS/G1E9bTP/vqszNIN/pUKCk/BTHsKA==}
|
||||
cpu: [ppc64]
|
||||
os: [linux]
|
||||
libc: [musl]
|
||||
|
||||
'@rollup/rollup-linux-riscv64-gnu@4.59.0':
|
||||
resolution: {integrity: sha512-NDYMpsXYJJaj+I7UdwIuHHNxXZ/b/N2hR15NyH3m2qAtb/hHPA4g4SuuvrdxetTdndfj9b1WOmy73kcPRoERUg==}
|
||||
cpu: [riscv64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@rollup/rollup-linux-riscv64-musl@4.59.0':
|
||||
resolution: {integrity: sha512-nLckB8WOqHIf1bhymk+oHxvM9D3tyPndZH8i8+35p/1YiVoVswPid2yLzgX7ZJP0KQvnkhM4H6QZ5m0LzbyIAg==}
|
||||
cpu: [riscv64]
|
||||
os: [linux]
|
||||
libc: [musl]
|
||||
|
||||
'@rollup/rollup-linux-s390x-gnu@4.59.0':
|
||||
resolution: {integrity: sha512-oF87Ie3uAIvORFBpwnCvUzdeYUqi2wY6jRFWJAy1qus/udHFYIkplYRW+wo+GRUP4sKzYdmE1Y3+rY5Gc4ZO+w==}
|
||||
cpu: [s390x]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@rollup/rollup-linux-x64-gnu@4.59.0':
|
||||
resolution: {integrity: sha512-3AHmtQq/ppNuUspKAlvA8HtLybkDflkMuLK4DPo77DfthRb71V84/c4MlWJXixZz4uruIH4uaa07IqoAkG64fg==}
|
||||
cpu: [x64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@rollup/rollup-linux-x64-musl@4.59.0':
|
||||
resolution: {integrity: sha512-2UdiwS/9cTAx7qIUZB/fWtToJwvt0Vbo0zmnYt7ED35KPg13Q0ym1g442THLC7VyI6JfYTP4PiSOWyoMdV2/xg==}
|
||||
cpu: [x64]
|
||||
os: [linux]
|
||||
libc: [musl]
|
||||
|
||||
'@rollup/rollup-openbsd-x64@4.59.0':
|
||||
resolution: {integrity: sha512-M3bLRAVk6GOwFlPTIxVBSYKUaqfLrn8l0psKinkCFxl4lQvOSz8ZrKDz2gxcBwHFpci0B6rttydI4IpS4IS/jQ==}
|
||||
|
|
@ -741,8 +755,8 @@ packages:
|
|||
resolution: {integrity: sha512-f7ccFPK3SXFHpx15UIGyRJ/FJQctuKZ0zVuN3frBo4HnK3cay9VEW0R6yPYFHC0AgqhukPzKjq22t5DmAyqGyw==}
|
||||
engines: {node: '>=16'}
|
||||
|
||||
flatted@3.4.1:
|
||||
resolution: {integrity: sha512-IxfVbRFVlV8V/yRaGzk0UVIcsKKHMSfYw66T/u4nTwlWteQePsxe//LjudR1AMX4tZW3WFCh3Zqa/sjlqpbURQ==}
|
||||
flatted@3.4.2:
|
||||
resolution: {integrity: sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==}
|
||||
|
||||
follow-redirects@1.15.11:
|
||||
resolution: {integrity: sha512-deG2P0JfjrTxl50XGCDyfI97ZGVCxIpfKYmfyrQ54n5FO/0gfIES8C/Psl6kWVDolizcaaxZJnTS0QSMxvnsBQ==}
|
||||
|
|
@ -1836,10 +1850,10 @@ snapshots:
|
|||
|
||||
flat-cache@4.0.1:
|
||||
dependencies:
|
||||
flatted: 3.4.1
|
||||
flatted: 3.4.2
|
||||
keyv: 4.5.4
|
||||
|
||||
flatted@3.4.1: {}
|
||||
flatted@3.4.2: {}
|
||||
|
||||
follow-redirects@1.15.11: {}
|
||||
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue