Merge branch 'main' into refactor/auth-services-typeddict

This commit is contained in:
BitToby 2026-03-24 02:06:54 +02:00 committed by GitHub
commit e04b839ae7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 410 additions and 1035 deletions

View File

@ -458,9 +458,7 @@ class ChatConversationApi(Resource):
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
subquery = (
db.session.query(
Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")
)
sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id"))
.outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
.subquery()
)
@ -595,10 +593,8 @@ class ChatConversationDetailApi(Resource):
def _get_conversation(app_model, conversation_id):
current_user, _ = current_account_with_tenant()
conversation = (
db.session.query(Conversation)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
conversation = db.session.scalar(
sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1)
)
if not conversation:

View File

@ -168,7 +168,7 @@ class InstructionGenerateApi(Resource):
try:
# Generate from nothing for a workflow node
if (args.current in (code_template, "")) and args.node_id != "":
app = db.session.query(App).where(App.id == args.flow_id).first()
app = db.session.get(App, args.flow_id)
if not app:
return {"error": f"app {args.flow_id} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app)

View File

@ -2,6 +2,7 @@ import json
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
@ -47,7 +48,7 @@ class AppMCPServerController(Resource):
@get_app_model
@marshal_with(app_server_model)
def get(self, app_model):
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
return server
@console_ns.doc("create_app_mcp_server")
@ -98,7 +99,7 @@ class AppMCPServerController(Resource):
@edit_permission_required
def put(self, app_model):
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
server = db.session.get(AppMCPServer, payload.id)
if not server:
raise NotFound()
@ -135,11 +136,10 @@ class AppMCPServerRefreshController(Resource):
@edit_permission_required
def get(self, server_id):
_, current_tenant_id = current_account_with_tenant()
server = (
db.session.query(AppMCPServer)
.where(AppMCPServer.id == server_id)
.where(AppMCPServer.tenant_id == current_tenant_id)
.first()
server = db.session.scalar(
select(AppMCPServer)
.where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id)
.limit(1)
)
if not server:
raise NotFound()

View File

@ -69,9 +69,7 @@ class ModelConfigResource(Resource):
if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
# get original app model config
original_app_model_config = (
db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first()
)
original_app_model_config = db.session.get(AppModelConfig, app_model.app_model_config_id)
if original_app_model_config is None:
raise ValueError("Original app model config not found")
agent_mode = original_app_model_config.agent_mode_dict

View File

@ -2,6 +2,7 @@ from typing import Literal
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from constants.languages import supported_language
@ -75,7 +76,7 @@ class AppSite(Resource):
def post(self, app_model):
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
current_user, _ = current_account_with_tenant()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if not site:
raise NotFound
@ -124,7 +125,7 @@ class AppSiteAccessTokenReset(Resource):
@marshal_with(app_site_model)
def post(self, app_model):
current_user, _ = current_account_with_tenant()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if not site:
raise NotFound

View File

@ -2,6 +2,8 @@ from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar, Union
from sqlalchemy import select
from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db
from libs.login import current_account_with_tenant
@ -15,16 +17,14 @@ R1 = TypeVar("R1")
def _load_app_model(app_id: str) -> App | None:
_, current_tenant_id = current_account_with_tenant()
app_model = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app_model = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
return app_model
def _load_app_model_with_trial(app_id: str) -> App | None:
app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
app_model = db.session.scalar(select(App).where(App.id == app_id, App.status == "normal").limit(1))
return app_model

View File

@ -33,6 +33,7 @@ from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, TidbAuthBinding
from models.enums import TidbAuthBindingStatus
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
@ -452,7 +453,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
password=new_cluster["password"],
tenant_id=dataset.tenant_id,
active=True,
status="ACTIVE",
status=TidbAuthBindingStatus.ACTIVE,
)
db.session.add(new_tidb_auth_binding)
db.session.commit()

View File

@ -9,6 +9,7 @@ from configs import dify_config
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import TidbAuthBinding
from models.enums import TidbAuthBindingStatus
class TidbService:
@ -170,7 +171,7 @@ class TidbService:
userPrefix = item["userPrefix"]
if state == "ACTIVE" and len(userPrefix) > 0:
cluster_info = tidb_serverless_list_map[item["clusterId"]]
cluster_info.status = "ACTIVE"
cluster_info.status = TidbAuthBindingStatus.ACTIVE
cluster_info.account = f"{userPrefix}.root"
db.session.add(cluster_info)
db.session.commit()

View File

@ -43,7 +43,9 @@ from .enums import (
IndexingStatus,
ProcessRuleMode,
SegmentStatus,
SegmentType,
SummaryStatus,
TidbAuthBindingStatus,
)
from .model import App, Tag, TagBinding, UploadFile
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
@ -998,7 +1000,9 @@ class ChildChunk(Base):
# indexing fields
index_node_id = mapped_column(String(255), nullable=True)
index_node_hash = mapped_column(String(255), nullable=True)
type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'"))
type: Mapped[SegmentType] = mapped_column(
EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'")
)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
@ -1239,7 +1243,9 @@ class TidbAuthBinding(TypeBase):
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'"))
status: Mapped[TidbAuthBindingStatus] = mapped_column(
EnumText(TidbAuthBindingStatus, length=255), nullable=False, server_default=sa.text("'CREATING'")
)
account: Mapped[str] = mapped_column(String(255), nullable=False)
password: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(

View File

@ -222,6 +222,13 @@ class DatasetMetadataType(StrEnum):
TIME = "time"
class SegmentType(StrEnum):
"""Document segment type"""
AUTOMATIC = "automatic"
CUSTOMIZED = "customized"
class SegmentStatus(StrEnum):
"""Document segment status"""

View File

@ -21,7 +21,7 @@ from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.tools.signature import sign_tool_file
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from dify_graph.file import helpers as file_helpers
from extensions.storage.storage_type import StorageType
from libs.helper import generate_string # type: ignore[import-not-found]
@ -1785,7 +1785,7 @@ class MessageFile(TypeBase):
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[FileType] = mapped_column(EnumText(FileType, length=255), nullable=False)
transfer_method: Mapped[FileTransferMethod] = mapped_column(
EnumText(FileTransferMethod, length=255), nullable=False
)

View File

@ -8,6 +8,7 @@ from configs import dify_config
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from extensions.ext_database import db
from models.dataset import TidbAuthBinding
from models.enums import TidbAuthBindingStatus
@app.celery.task(queue="dataset")
@ -57,7 +58,7 @@ def create_clusters(batch_size):
account=new_cluster["account"],
password=new_cluster["password"],
active=False,
status="CREATING",
status=TidbAuthBindingStatus.CREATING,
)
db.session.add(tidb_auth_binding)
db.session.commit()

View File

@ -9,6 +9,7 @@ from configs import dify_config
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from extensions.ext_database import db
from models.dataset import TidbAuthBinding
from models.enums import TidbAuthBindingStatus
@app.celery.task(queue="dataset")
@ -18,7 +19,10 @@ def update_tidb_serverless_status_task():
try:
# check the number of idle tidb serverless
tidb_serverless_list = db.session.scalars(
select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
select(TidbAuthBinding).where(
TidbAuthBinding.active == False,
TidbAuthBinding.status == TidbAuthBindingStatus.CREATING,
)
).all()
if len(tidb_serverless_list) == 0:
return

View File

@ -58,6 +58,7 @@ from models.enums import (
IndexingStatus,
ProcessRuleMode,
SegmentStatus,
SegmentType,
)
from models.model import UploadFile
from models.provider_ids import ModelProviderID
@ -3786,7 +3787,7 @@ class SegmentService:
child_chunk.word_count = len(child_chunk.content)
child_chunk.updated_by = current_user.id
child_chunk.updated_at = naive_utc_now()
child_chunk.type = "customized"
child_chunk.type = SegmentType.CUSTOMIZED
update_child_chunks.append(child_chunk)
else:
new_child_chunks_args.append(child_chunk_update_args)
@ -3845,7 +3846,7 @@ class SegmentService:
child_chunk.word_count = len(content)
child_chunk.updated_by = current_user.id
child_chunk.updated_at = naive_utc_now()
child_chunk.type = "customized"
child_chunk.type = SegmentType.CUSTOMIZED
db.session.add(child_chunk)
VectorService.update_child_chunk_vector([], [child_chunk], [], dataset)
db.session.commit()

View File

@ -1,17 +1,10 @@
"""
Test suite for password reset authentication flows.
"""Testcontainers integration tests for password reset authentication flows."""
This module tests the password reset mechanism including:
- Password reset email sending
- Verification code validation
- Password reset with token
- Rate limiting and security checks
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.error import (
EmailCodeError,
@ -28,31 +21,12 @@ from controllers.console.auth.forgot_password import (
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
@pytest.fixture(autouse=True)
def _mock_forgot_password_session():
with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls:
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_session_cls.return_value.__exit__.return_value = None
yield mock_session
@pytest.fixture(autouse=True)
def _mock_forgot_password_db():
with patch("controllers.console.auth.forgot_password.db") as mock_db:
mock_db.engine = MagicMock()
yield mock_db
class TestForgotPasswordSendEmailApi:
"""Test cases for sending password reset emails."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
def app(self, flask_app_with_containers):
return flask_app_with_containers
@pytest.fixture
def mock_account(self):
@ -62,7 +36,6 @@ class TestForgotPasswordSendEmailApi:
account.name = "Test User"
return account
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
@ -73,20 +46,10 @@ class TestForgotPasswordSendEmailApi:
mock_send_email,
mock_get_account,
mock_is_ip_limit,
mock_wraps_db,
app,
mock_account,
):
"""
Test successful password reset email sending.
Verifies that:
- Email is sent to valid account
- Reset token is generated and returned
- IP rate limiting is checked
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_account.return_value = mock_account
mock_send_email.return_value = "reset_token_123"
@ -104,9 +67,8 @@ class TestForgotPasswordSendEmailApi:
assert response["data"] == "reset_token_123"
mock_send_email.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, mock_db, app):
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app):
"""
Test password reset email blocked by IP rate limit.
@ -115,7 +77,6 @@ class TestForgotPasswordSendEmailApi:
- No email is sent when rate limited
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = True
# Act & Assert
@ -133,7 +94,6 @@ class TestForgotPasswordSendEmailApi:
(None, "en-US"), # Defaults to en-US when not provided
],
)
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
@ -144,7 +104,6 @@ class TestForgotPasswordSendEmailApi:
mock_send_email,
mock_get_account,
mock_is_ip_limit,
mock_wraps_db,
app,
mock_account,
language_input,
@ -158,7 +117,6 @@ class TestForgotPasswordSendEmailApi:
- Unsupported languages default to en-US
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_account.return_value = mock_account
mock_send_email.return_value = "token"
@ -180,13 +138,9 @@ class TestForgotPasswordCheckApi:
"""Test cases for verifying password reset codes."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
def app(self, flask_app_with_containers):
return flask_app_with_containers
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@ -199,7 +153,6 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_get_data,
mock_is_rate_limit,
mock_db,
app,
):
"""
@ -212,7 +165,6 @@ class TestForgotPasswordCheckApi:
- Rate limit is reset on success
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_generate_token.return_value = (None, "new_token")
@ -236,7 +188,6 @@ class TestForgotPasswordCheckApi:
)
mock_reset_rate_limit.assert_called_once_with("test@example.com")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@ -249,10 +200,8 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_get_data,
mock_is_rate_limit,
mock_db,
app,
):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"}
mock_generate_token.return_value = (None, "fresh-token")
@ -271,9 +220,8 @@ class TestForgotPasswordCheckApi:
mock_revoke_token.assert_called_once_with("upper_token")
mock_reset_rate_limit.assert_called_once_with("user@example.com")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app):
def test_verify_code_rate_limited(self, mock_is_rate_limit, app):
"""
Test code verification blocked by rate limit.
@ -282,7 +230,6 @@ class TestForgotPasswordCheckApi:
- Prevents brute force attacks on verification codes
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = True
# Act & Assert
@ -295,10 +242,9 @@ class TestForgotPasswordCheckApi:
with pytest.raises(EmailPasswordResetLimitError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, mock_db, app):
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app):
"""
Test code verification with invalid token.
@ -306,7 +252,6 @@ class TestForgotPasswordCheckApi:
- InvalidTokenError is raised for invalid/expired tokens
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = None
@ -320,10 +265,9 @@ class TestForgotPasswordCheckApi:
with pytest.raises(InvalidTokenError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, mock_db, app):
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app):
"""
Test code verification with mismatched email.
@ -332,7 +276,6 @@ class TestForgotPasswordCheckApi:
- Prevents token abuse
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}
@ -346,11 +289,10 @@ class TestForgotPasswordCheckApi:
with pytest.raises(InvalidEmailError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit")
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, mock_db, app):
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app):
"""
Test code verification with incorrect code.
@ -359,7 +301,6 @@ class TestForgotPasswordCheckApi:
- Rate limit counter is incremented
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
@ -380,11 +321,8 @@ class TestForgotPasswordResetApi:
"""Test cases for resetting password with verified token."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
def app(self, flask_app_with_containers):
return flask_app_with_containers
@pytest.fixture
def mock_account(self):
@ -394,7 +332,6 @@ class TestForgotPasswordResetApi:
account.name = "Test User"
return account
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@ -405,7 +342,6 @@ class TestForgotPasswordResetApi:
mock_get_account,
mock_revoke_token,
mock_get_data,
mock_wraps_db,
app,
mock_account,
):
@ -418,7 +354,6 @@ class TestForgotPasswordResetApi:
- Success response is returned
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
mock_get_account.return_value = mock_account
mock_get_tenants.return_value = [MagicMock()]
@ -436,9 +371,8 @@ class TestForgotPasswordResetApi:
assert response["result"] == "success"
mock_revoke_token.assert_called_once_with("valid_token")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_mismatch(self, mock_get_data, mock_db, app):
def test_reset_password_mismatch(self, mock_get_data, app):
"""
Test password reset with mismatched passwords.
@ -447,7 +381,6 @@ class TestForgotPasswordResetApi:
- No password update occurs
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
# Act & Assert
@ -460,9 +393,8 @@ class TestForgotPasswordResetApi:
with pytest.raises(PasswordMismatchError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_invalid_token(self, mock_get_data, mock_db, app):
def test_reset_password_invalid_token(self, mock_get_data, app):
"""
Test password reset with invalid token.
@ -470,7 +402,6 @@ class TestForgotPasswordResetApi:
- InvalidTokenError is raised for invalid/expired tokens
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = None
# Act & Assert
@ -483,9 +414,8 @@ class TestForgotPasswordResetApi:
with pytest.raises(InvalidTokenError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_wrong_phase(self, mock_get_data, mock_db, app):
def test_reset_password_wrong_phase(self, mock_get_data, app):
"""
Test password reset with token not in reset phase.
@ -494,7 +424,6 @@ class TestForgotPasswordResetApi:
- Prevents use of verification-phase tokens for reset
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "phase": "verify"}
# Act & Assert
@ -507,13 +436,10 @@ class TestForgotPasswordResetApi:
with pytest.raises(InvalidTokenError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
def test_reset_password_account_not_found(
self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app
):
def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app):
"""
Test password reset for non-existent account.
@ -521,7 +447,6 @@ class TestForgotPasswordResetApi:
- AccountNotFound is raised when account doesn't exist
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"}
mock_get_account.return_value = None

View File

@ -8,6 +8,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from dify_graph.file.enums import FileType
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -253,7 +254,7 @@ class TestMessagesCleanServiceIntegration:
# MessageFile
file = MessageFile(
message_id=message.id,
type="image",
type=FileType.IMAGE,
transfer_method="local_file",
url="http://example.com/test.jpg",
belongs_to=MessageFileBelongsTo.USER,

View File

@ -0,0 +1,174 @@
"""Testcontainers integration tests for OAuthServerService."""
from __future__ import annotations
import uuid
from typing import cast
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from werkzeug.exceptions import BadRequest
from models.model import OAuthProviderApp
from services.oauth_server import (
OAUTH_ACCESS_TOKEN_EXPIRES_IN,
OAUTH_ACCESS_TOKEN_REDIS_KEY,
OAUTH_AUTHORIZATION_CODE_REDIS_KEY,
OAUTH_REFRESH_TOKEN_EXPIRES_IN,
OAUTH_REFRESH_TOKEN_REDIS_KEY,
OAuthGrantType,
OAuthServerService,
)
class TestOAuthServerServiceGetProviderApp:
"""DB-backed tests for get_oauth_provider_app."""
def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp:
app = OAuthProviderApp(
app_icon="icon.png",
client_id=client_id,
client_secret=str(uuid4()),
app_label={"en-US": "Test OAuth App"},
redirect_uris=["https://example.com/callback"],
scope="read",
)
db_session_with_containers.add(app)
db_session_with_containers.commit()
return app
def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers):
client_id = f"client-{uuid4()}"
created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id)
result = OAuthServerService.get_oauth_provider_app(client_id)
assert result is not None
assert result.client_id == client_id
assert result.id == created.id
def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers):
result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}")
assert result is None
class TestOAuthServerServiceTokenOperations:
"""Redis-backed tests for token sign/validate operations."""
@pytest.fixture
def mock_redis(self):
with patch("services.oauth_server.redis_client") as mock:
yield mock
def test_sign_authorization_code_stores_and_returns_code(self, mock_redis):
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111")
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1")
assert code == str(deterministic_uuid)
mock_redis.set.assert_called_once_with(
OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=code),
"user-1",
ex=600,
)
def test_sign_access_token_raises_bad_request_for_invalid_code(self, mock_redis):
mock_redis.get.return_value = None
with pytest.raises(BadRequest, match="invalid code"):
OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
code="bad-code",
client_id="client-1",
)
def test_sign_access_token_issues_tokens_for_valid_code(self, mock_redis):
token_uuids = [
uuid.UUID("00000000-0000-0000-0000-000000000201"),
uuid.UUID("00000000-0000-0000-0000-000000000202"),
]
with patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids):
mock_redis.get.return_value = b"user-1"
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
code="code-1",
client_id="client-1",
)
assert access_token == str(token_uuids[0])
assert refresh_token == str(token_uuids[1])
code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1")
mock_redis.delete.assert_called_once_with(code_key)
mock_redis.set.assert_any_call(
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
b"user-1",
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
)
mock_redis.set.assert_any_call(
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token),
b"user-1",
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
)
def test_sign_access_token_raises_bad_request_for_invalid_refresh_token(self, mock_redis):
mock_redis.get.return_value = None
with pytest.raises(BadRequest, match="invalid refresh token"):
OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.REFRESH_TOKEN,
refresh_token="stale-token",
client_id="client-1",
)
def test_sign_access_token_issues_new_token_for_valid_refresh(self, mock_redis):
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301")
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
mock_redis.get.return_value = b"user-1"
access_token, returned_refresh = OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.REFRESH_TOKEN,
refresh_token="refresh-1",
client_id="client-1",
)
assert access_token == str(deterministic_uuid)
assert returned_refresh == "refresh-1"
def test_sign_access_token_returns_none_for_unknown_grant_type(self, mock_redis):
grant_type = cast(OAuthGrantType, "invalid-grant-type")
result = OAuthServerService.sign_oauth_access_token(grant_type=grant_type, client_id="client-1")
assert result is None
def test_sign_refresh_token_stores_with_expected_expiry(self, mock_redis):
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401")
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2")
assert refresh_token == str(deterministic_uuid)
mock_redis.set.assert_called_once_with(
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token),
"user-2",
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
)
def test_validate_access_token_returns_none_when_not_found(self, mock_redis):
mock_redis.get.return_value = None
result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token")
assert result is None
def test_validate_access_token_loads_user_when_exists(self, mock_redis):
mock_redis.get.return_value = b"user-88"
expected_user = MagicMock()
with patch("services.oauth_server.AccountService.load_user", return_value=expected_user) as mock_load:
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
assert result is expected_user
mock_load.assert_called_once_with("user-88")

View File

@ -536,3 +536,151 @@ class TestApiToolManageService:
# Verify mock interactions
mock_external_service_dependencies["encrypter"].assert_called_once()
mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()
def test_delete_api_tool_provider_success(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test successful deletion of an API tool provider."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
schema = self._create_test_openapi_schema()
provider_name = fake.unique.word()
ApiToolManageService.create_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
icon={"content": "🔧", "background": "#FFF"},
credentials={"auth_type": "none"},
schema_type=ApiProviderSchemaType.OPENAPI,
schema=schema,
privacy_policy="",
custom_disclaimer="",
labels=[],
)
provider = (
db_session_with_containers.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
.first()
)
assert provider is not None
result = ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, provider_name)
assert result == {"result": "success"}
deleted = (
db_session_with_containers.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
.first()
)
assert deleted is None
def test_delete_api_tool_provider_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test deletion raises ValueError when provider not found."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="you have not added provider"):
ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent")
def test_update_api_tool_provider_not_found(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test update raises ValueError when original provider not found."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="does not exists"):
ApiToolManageService.update_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name="new-name",
original_provider="nonexistent",
icon={},
credentials={"auth_type": "none"},
_schema_type=ApiProviderSchemaType.OPENAPI,
schema=self._create_test_openapi_schema(),
privacy_policy=None,
custom_disclaimer="",
labels=[],
)
def test_update_api_tool_provider_missing_auth_type(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test update raises ValueError when auth_type is missing from credentials."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
schema = self._create_test_openapi_schema()
provider_name = fake.unique.word()
ApiToolManageService.create_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
icon={"content": "🔧", "background": "#FFF"},
credentials={"auth_type": "none"},
schema_type=ApiProviderSchemaType.OPENAPI,
schema=schema,
privacy_policy="",
custom_disclaimer="",
labels=[],
)
with pytest.raises(ValueError, match="auth_type is required"):
ApiToolManageService.update_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
original_provider=provider_name,
icon={},
credentials={},
_schema_type=ApiProviderSchemaType.OPENAPI,
schema=schema,
privacy_policy=None,
custom_disclaimer="",
labels=[],
)
def test_list_api_tool_provider_tools_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test listing tools raises ValueError when provider not found."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="you have not added provider"):
ApiToolManageService.list_api_tool_provider_tools(account.id, tenant.id, "nonexistent")
def test_test_api_tool_preview_invalid_schema_type(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test preview raises ValueError for invalid schema type."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
with pytest.raises(ValueError, match="invalid schema type"):
ApiToolManageService.test_api_tool_preview(
tenant_id=tenant.id,
provider_name="provider-a",
tool_name="tool-a",
credentials={"auth_type": "none"},
parameters={},
schema_type="bad-schema-type",
schema="schema",
)

View File

@ -281,12 +281,10 @@ class TestSiteEndpoints:
method = _unwrap(api.post)
site = MagicMock()
query = MagicMock()
query.where.return_value.first.return_value = site
monkeypatch.setattr(
site_module.db,
"session",
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None),
)
monkeypatch.setattr(
site_module,
@ -305,12 +303,10 @@ class TestSiteEndpoints:
method = _unwrap(api.post)
site = MagicMock()
query = MagicMock()
query.where.return_value.first.return_value = site
monkeypatch.setattr(
site_module.db,
"session",
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None),
)
monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code")
monkeypatch.setattr(

View File

@ -82,12 +82,8 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p
def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None:
conversation = SimpleNamespace(id="c1", app_id="app-1")
query = MagicMock()
query.where.return_value = query
query.first.return_value = conversation
session = MagicMock()
session.query.return_value = query
session.scalar.return_value = conversation
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
monkeypatch.setattr(conversation_module.db, "session", session)
@ -101,12 +97,8 @@ def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> No
def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
query = MagicMock()
query.where.return_value = query
query.first.return_value = None
session = MagicMock()
session.query.return_value = query
session.scalar.return_value = None
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
monkeypatch.setattr(conversation_module.db, "session", session)

View File

@ -24,7 +24,7 @@ def test_get_conversation_mark_read_keeps_updated_at_unchanged():
),
patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session,
):
mock_session.query.return_value.where.return_value.first.return_value = conversation
mock_session.scalar.return_value = conversation
_get_conversation(app_model, "conversation-id")

View File

@ -73,8 +73,7 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: None))
with app.test_request_context(
"/console/api/instruction-generate",
@ -99,8 +98,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
app_model = SimpleNamespace(id="app-1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
_install_workflow_service(monkeypatch, workflow=None)
with app.test_request_context(
@ -126,8 +124,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
app_model = SimpleNamespace(id="app-1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
workflow = SimpleNamespace(graph_dict={"nodes": []})
_install_workflow_service(monkeypatch, workflow=workflow)
@ -155,8 +152,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) ->
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
app_model = SimpleNamespace(id="app-1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
workflow = SimpleNamespace(
graph_dict={

View File

@ -92,10 +92,7 @@ def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatc
)
session = MagicMock()
query = MagicMock()
query.where.return_value = query
query.first.return_value = original_config
session.query.return_value = query
session.get.return_value = original_config
monkeypatch.setattr(model_config_module.db, "session", session)
monkeypatch.setattr(

View File

@ -11,10 +11,8 @@ from models.model import AppMode
def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model))
@wraps_module.get_app_model
def handler(app_model):
@ -25,10 +23,8 @@ def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None:
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model))
@wraps_module.get_app_model(mode=[AppMode.COMPLETION])
def handler(app_model):

View File

@ -21,7 +21,7 @@ from sqlalchemy.orm import Session
from core.app.entities.task_entities import MessageEndStreamResponse
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from dify_graph.file.enums import FileTransferMethod
from dify_graph.file.enums import FileTransferMethod, FileType
from models.model import MessageFile, UploadFile
@ -51,7 +51,7 @@ class TestMessageEndStreamResponseFiles:
message_file.transfer_method = FileTransferMethod.LOCAL_FILE
message_file.upload_file_id = str(uuid.uuid4())
message_file.url = None
message_file.type = "image"
message_file.type = FileType.IMAGE
return message_file
@pytest.fixture
@ -63,7 +63,7 @@ class TestMessageEndStreamResponseFiles:
message_file.transfer_method = FileTransferMethod.REMOTE_URL
message_file.upload_file_id = None
message_file.url = "https://example.com/image.jpg"
message_file.type = "image"
message_file.type = FileType.IMAGE
return message_file
@pytest.fixture
@ -75,7 +75,7 @@ class TestMessageEndStreamResponseFiles:
message_file.transfer_method = FileTransferMethod.TOOL_FILE
message_file.upload_file_id = None
message_file.url = "tool_file_123.png"
message_file.type = "image"
message_file.type = FileType.IMAGE
return message_file
@pytest.fixture

View File

@ -4,6 +4,7 @@ import pytest
from models.account import Account
from models.dataset import ChildChunk, Dataset, Document, DocumentSegment
from models.enums import SegmentType
from services.dataset_service import SegmentService
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
@ -77,7 +78,7 @@ class SegmentTestDataFactory:
chunk.word_count = word_count
chunk.index_node_id = f"node-{chunk_id}"
chunk.index_node_hash = "hash-123"
chunk.type = "automatic"
chunk.type = SegmentType.AUTOMATIC
chunk.created_by = "user-123"
chunk.updated_by = None
chunk.updated_at = None

View File

@ -1,224 +0,0 @@
from __future__ import annotations
import uuid
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from werkzeug.exceptions import BadRequest
from services.oauth_server import (
OAUTH_ACCESS_TOKEN_EXPIRES_IN,
OAUTH_ACCESS_TOKEN_REDIS_KEY,
OAUTH_AUTHORIZATION_CODE_REDIS_KEY,
OAUTH_REFRESH_TOKEN_EXPIRES_IN,
OAUTH_REFRESH_TOKEN_REDIS_KEY,
OAuthGrantType,
OAuthServerService,
)
@pytest.fixture
def mock_redis_client(mocker: MockerFixture) -> MagicMock:
return mocker.patch("services.oauth_server.redis_client")
@pytest.fixture
def mock_session(mocker: MockerFixture) -> MagicMock:
"""Mock the OAuth server Session context manager."""
mocker.patch("services.oauth_server.db", SimpleNamespace(engine=object()))
session = MagicMock()
session_cm = MagicMock()
session_cm.__enter__.return_value = session
mocker.patch("services.oauth_server.Session", return_value=session_cm)
return session
def test_get_oauth_provider_app_should_return_app_when_record_exists(mock_session: MagicMock) -> None:
# Arrange
mock_execute_result = MagicMock()
expected_app = MagicMock()
mock_execute_result.scalar_one_or_none.return_value = expected_app
mock_session.execute.return_value = mock_execute_result
# Act
result = OAuthServerService.get_oauth_provider_app("client-1")
# Assert
assert result is expected_app
mock_session.execute.assert_called_once()
mock_execute_result.scalar_one_or_none.assert_called_once()
def test_sign_oauth_authorization_code_should_store_code_and_return_value(
mocker: MockerFixture, mock_redis_client: MagicMock
) -> None:
# Arrange
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111")
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
# Act
code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1")
# Assert
expected_code = str(deterministic_uuid)
assert code == expected_code
mock_redis_client.set.assert_called_once_with(
OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=expected_code),
"user-1",
ex=600,
)
def test_sign_oauth_access_token_should_raise_bad_request_when_authorization_code_is_invalid(
mock_redis_client: MagicMock,
) -> None:
# Arrange
mock_redis_client.get.return_value = None
# Act + Assert
with pytest.raises(BadRequest, match="invalid code"):
OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
code="bad-code",
client_id="client-1",
)
def test_sign_oauth_access_token_should_issue_access_and_refresh_token_when_authorization_code_is_valid(
mocker: MockerFixture, mock_redis_client: MagicMock
) -> None:
# Arrange
token_uuids = [
uuid.UUID("00000000-0000-0000-0000-000000000201"),
uuid.UUID("00000000-0000-0000-0000-000000000202"),
]
mocker.patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids)
mock_redis_client.get.return_value = b"user-1"
code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1")
# Act
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
code="code-1",
client_id="client-1",
)
# Assert
assert access_token == str(token_uuids[0])
assert refresh_token == str(token_uuids[1])
mock_redis_client.delete.assert_called_once_with(code_key)
mock_redis_client.set.assert_any_call(
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
b"user-1",
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
)
mock_redis_client.set.assert_any_call(
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token),
b"user-1",
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
)
def test_sign_oauth_access_token_should_raise_bad_request_when_refresh_token_is_invalid(
mock_redis_client: MagicMock,
) -> None:
# Arrange
mock_redis_client.get.return_value = None
# Act + Assert
with pytest.raises(BadRequest, match="invalid refresh token"):
OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.REFRESH_TOKEN,
refresh_token="stale-token",
client_id="client-1",
)
def test_sign_oauth_access_token_should_issue_new_access_token_when_refresh_token_is_valid(
mocker: MockerFixture, mock_redis_client: MagicMock
) -> None:
# Arrange
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301")
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
mock_redis_client.get.return_value = b"user-1"
# Act
access_token, returned_refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type=OAuthGrantType.REFRESH_TOKEN,
refresh_token="refresh-1",
client_id="client-1",
)
# Assert
assert access_token == str(deterministic_uuid)
assert returned_refresh_token == "refresh-1"
mock_redis_client.set.assert_called_once_with(
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
b"user-1",
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
)
def test_sign_oauth_access_token_with_unknown_grant_type_should_return_none() -> None:
# Arrange
grant_type = cast(OAuthGrantType, "invalid-grant-type")
# Act
result = OAuthServerService.sign_oauth_access_token(
grant_type=grant_type,
client_id="client-1",
)
# Assert
assert result is None
def test_sign_oauth_refresh_token_should_store_token_with_expected_expiry(
mocker: MockerFixture, mock_redis_client: MagicMock
) -> None:
# Arrange
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401")
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
# Act
refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2")
# Assert
assert refresh_token == str(deterministic_uuid)
mock_redis_client.set.assert_called_once_with(
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token),
"user-2",
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
)
def test_validate_oauth_access_token_should_return_none_when_token_not_found(
mock_redis_client: MagicMock,
) -> None:
# Arrange
mock_redis_client.get.return_value = None
# Act
result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token")
# Assert
assert result is None
def test_validate_oauth_access_token_should_load_user_when_token_exists(
mocker: MockerFixture, mock_redis_client: MagicMock
) -> None:
# Arrange
mock_redis_client.get.return_value = b"user-88"
expected_user = MagicMock()
mock_load_user = mocker.patch("services.oauth_server.AccountService.load_user", return_value=expected_user)
# Act
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
# Assert
assert result is expected_user
mock_load_user.assert_called_once_with("user-88")

View File

@ -1,643 +0,0 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.tools.entities.tool_entities import ApiProviderSchemaType
from services.tools.api_tools_manage_service import ApiToolManageService
@pytest.fixture
def mock_db(mocker: MockerFixture) -> MagicMock:
# Arrange
mocked_db = mocker.patch("services.tools.api_tools_manage_service.db")
mocked_db.session = MagicMock()
return mocked_db
def _tool_bundle(operation_id: str = "tool-1") -> SimpleNamespace:
return SimpleNamespace(operation_id=operation_id)
def test_parser_api_schema_should_return_schema_payload_when_schema_is_valid(mocker: MockerFixture) -> None:
# Arrange
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI.value),
)
# Act
result = ApiToolManageService.parser_api_schema("valid-schema")
# Assert
assert result["schema_type"] == ApiProviderSchemaType.OPENAPI.value
assert len(result["credentials_schema"]) == 3
assert "warning" in result
def test_parser_api_schema_should_raise_value_error_when_parser_raises(mocker: MockerFixture) -> None:
# Arrange
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
side_effect=RuntimeError("bad schema"),
)
# Act + Assert
with pytest.raises(ValueError, match="invalid schema: invalid schema: bad schema"):
ApiToolManageService.parser_api_schema("invalid")
def test_convert_schema_to_tool_bundles_should_return_tool_bundles_when_valid(mocker: MockerFixture) -> None:
# Arrange
expected = ([_tool_bundle("a"), _tool_bundle("b")], ApiProviderSchemaType.SWAGGER)
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
return_value=expected,
)
extra_info: dict[str, str] = {}
# Act
result = ApiToolManageService.convert_schema_to_tool_bundles("schema", extra_info=extra_info)
# Assert
assert result == expected
def test_convert_schema_to_tool_bundles_should_raise_value_error_when_parser_fails(mocker: MockerFixture) -> None:
# Arrange
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
side_effect=ValueError("parse failed"),
)
# Act + Assert
with pytest.raises(ValueError, match="invalid schema: parse failed"):
ApiToolManageService.convert_schema_to_tool_bundles("schema")
def test_create_api_tool_provider_should_raise_error_when_provider_already_exists(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_db.session.query.return_value.where.return_value.first.return_value = object()
# Act + Assert
with pytest.raises(ValueError, match="provider provider-a already exists"):
ApiToolManageService.create_api_tool_provider(
user_id="user-1",
tenant_id="tenant-1",
provider_name=" provider-a ",
icon={"emoji": "X"},
credentials={"auth_type": "none"},
schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
privacy_policy="privacy",
custom_disclaimer="custom",
labels=[],
)
def test_create_api_tool_provider_should_raise_error_when_tool_count_exceeds_limit(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_db.session.query.return_value.where.return_value.first.return_value = None
many_tools = [_tool_bundle(str(i)) for i in range(101)]
mocker.patch.object(
ApiToolManageService,
"convert_schema_to_tool_bundles",
return_value=(many_tools, ApiProviderSchemaType.OPENAPI),
)
# Act + Assert
with pytest.raises(ValueError, match="the number of apis should be less than 100"):
ApiToolManageService.create_api_tool_provider(
user_id="user-1",
tenant_id="tenant-1",
provider_name="provider-a",
icon={"emoji": "X"},
credentials={"auth_type": "none"},
schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
privacy_policy="privacy",
custom_disclaimer="custom",
labels=[],
)
def test_create_api_tool_provider_should_raise_error_when_auth_type_is_missing(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_db.session.query.return_value.where.return_value.first.return_value = None
mocker.patch.object(
ApiToolManageService,
"convert_schema_to_tool_bundles",
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
)
# Act + Assert
with pytest.raises(ValueError, match="auth_type is required"):
ApiToolManageService.create_api_tool_provider(
user_id="user-1",
tenant_id="tenant-1",
provider_name="provider-a",
icon={"emoji": "X"},
credentials={},
schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
privacy_policy="privacy",
custom_disclaimer="custom",
labels=[],
)
def test_create_api_tool_provider_should_create_provider_when_input_is_valid(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_db.session.query.return_value.where.return_value.first.return_value = None
mocker.patch.object(
ApiToolManageService,
"convert_schema_to_tool_bundles",
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
)
mock_controller = MagicMock()
mocker.patch(
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
return_value=mock_controller,
)
mock_encrypter = MagicMock()
mock_encrypter.encrypt.return_value = {"auth_type": "none"}
mocker.patch(
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
return_value=(mock_encrypter, MagicMock()),
)
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels")
# Act
result = ApiToolManageService.create_api_tool_provider(
user_id="user-1",
tenant_id="tenant-1",
provider_name="provider-a",
icon={"emoji": "X"},
credentials={"auth_type": "none"},
schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
privacy_policy="privacy",
custom_disclaimer="custom",
labels=["news"],
)
# Assert
assert result == {"result": "success"}
mock_controller.load_bundled_tools.assert_called_once()
mock_db.session.add.assert_called_once()
mock_db.session.commit.assert_called_once()
def test_get_api_tool_provider_remote_schema_should_return_schema_when_response_is_valid(
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch(
"services.tools.api_tools_manage_service.get",
return_value=SimpleNamespace(status_code=200, text="schema-content"),
)
mocker.patch.object(ApiToolManageService, "parser_api_schema", return_value={"ok": True})
# Act
result = ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema")
# Assert
assert result == {"schema": "schema-content"}
@pytest.mark.parametrize("status_code", [400, 404, 500])
def test_get_api_tool_provider_remote_schema_should_raise_error_when_remote_fetch_is_invalid(
status_code: int,
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch(
"services.tools.api_tools_manage_service.get",
return_value=SimpleNamespace(status_code=status_code, text="schema-content"),
)
mock_logger = mocker.patch("services.tools.api_tools_manage_service.logger")
# Act + Assert
with pytest.raises(ValueError, match="invalid schema, please check the url you provided"):
ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema")
mock_logger.exception.assert_called_once()
def test_list_api_tool_provider_tools_should_raise_error_when_provider_not_found(
mock_db: MagicMock,
) -> None:
# Arrange
mock_db.session.query.return_value.where.return_value.first.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="you have not added provider provider-a"):
ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a")
def test_list_api_tool_provider_tools_should_return_converted_tools_when_provider_exists(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
provider = SimpleNamespace(tools=[_tool_bundle("tool-a"), _tool_bundle("tool-b")])
mock_db.session.query.return_value.where.return_value.first.return_value = provider
controller = MagicMock()
mocker.patch(
"services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller",
return_value=controller,
)
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["search"])
mock_convert = mocker.patch(
"services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity",
side_effect=[{"name": "tool-a"}, {"name": "tool-b"}],
)
# Act
result = ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a")
# Assert
assert result == [{"name": "tool-a"}, {"name": "tool-b"}]
assert mock_convert.call_count == 2
def test_update_api_tool_provider_should_raise_error_when_original_provider_not_found(
mock_db: MagicMock,
) -> None:
# Arrange
mock_db.session.query.return_value.where.return_value.first.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="api provider provider-a does not exists"):
ApiToolManageService.update_api_tool_provider(
user_id="user-1",
tenant_id="tenant-1",
provider_name="provider-a",
original_provider="provider-a",
icon={},
credentials={"auth_type": "none"},
_schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
privacy_policy=None,
custom_disclaimer="custom",
labels=[],
)
def test_update_api_tool_provider_should_raise_error_when_auth_type_missing(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
provider = SimpleNamespace(credentials={}, name="old")
mock_db.session.query.return_value.where.return_value.first.return_value = provider
mocker.patch.object(
ApiToolManageService,
"convert_schema_to_tool_bundles",
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
)
# Act + Assert
with pytest.raises(ValueError, match="auth_type is required"):
ApiToolManageService.update_api_tool_provider(
user_id="user-1",
tenant_id="tenant-1",
provider_name="provider-a",
original_provider="provider-a",
icon={},
credentials={},
_schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
privacy_policy=None,
custom_disclaimer="custom",
labels=[],
)
def test_update_api_tool_provider_should_update_provider_and_preserve_masked_credentials(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
provider = SimpleNamespace(
credentials={"auth_type": "none", "api_key_value": "encrypted-old"},
name="old",
icon="",
schema="",
description="",
schema_type_str="",
tools_str="",
privacy_policy="",
custom_disclaimer="",
credentials_str="",
)
mock_db.session.query.return_value.where.return_value.first.return_value = provider
mocker.patch.object(
ApiToolManageService,
"convert_schema_to_tool_bundles",
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
)
controller = MagicMock()
mocker.patch(
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
return_value=controller,
)
cache = MagicMock()
encrypter = MagicMock()
encrypter.decrypt.return_value = {"auth_type": "none", "api_key_value": "plain-old"}
encrypter.mask_plugin_credentials.return_value = {"api_key_value": "***"}
encrypter.encrypt.return_value = {"auth_type": "none", "api_key_value": "encrypted-new"}
mocker.patch(
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
return_value=(encrypter, cache),
)
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels")
# Act
result = ApiToolManageService.update_api_tool_provider(
user_id="user-1",
tenant_id="tenant-1",
provider_name="provider-new",
original_provider="provider-old",
icon={"emoji": "E"},
credentials={"auth_type": "none", "api_key_value": "***"},
_schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
privacy_policy="privacy",
custom_disclaimer="custom",
labels=["news"],
)
# Assert
assert result == {"result": "success"}
assert provider.name == "provider-new"
assert provider.privacy_policy == "privacy"
assert provider.credentials_str != ""
cache.delete.assert_called_once()
mock_db.session.commit.assert_called_once()
def test_delete_api_tool_provider_should_raise_error_when_provider_missing(mock_db: MagicMock) -> None:
# Arrange
mock_db.session.query.return_value.where.return_value.first.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="you have not added provider provider-a"):
ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a")
def test_delete_api_tool_provider_should_delete_provider_when_exists(mock_db: MagicMock) -> None:
# Arrange
provider = object()
mock_db.session.query.return_value.where.return_value.first.return_value = provider
# Act
result = ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a")
# Assert
assert result == {"result": "success"}
mock_db.session.delete.assert_called_once_with(provider)
mock_db.session.commit.assert_called_once()
def test_get_api_tool_provider_should_delegate_to_tool_manager(mocker: MockerFixture) -> None:
# Arrange
expected = {"provider": "value"}
mock_get = mocker.patch(
"services.tools.api_tools_manage_service.ToolManager.user_get_api_provider",
return_value=expected,
)
# Act
result = ApiToolManageService.get_api_tool_provider("user-1", "tenant-1", "provider-a")
# Assert
assert result == expected
mock_get.assert_called_once_with(provider="provider-a", tenant_id="tenant-1")
def test_test_api_tool_preview_should_raise_error_for_invalid_schema_type() -> None:
# Arrange
schema_type = "bad-schema-type"
# Act + Assert
with pytest.raises(ValueError, match="invalid schema type"):
ApiToolManageService.test_api_tool_preview(
tenant_id="tenant-1",
provider_name="provider-a",
tool_name="tool-a",
credentials={"auth_type": "none"},
parameters={},
schema_type=schema_type, # type: ignore[arg-type]
schema="schema",
)
def test_test_api_tool_preview_should_raise_error_when_schema_parser_fails(mocker: MockerFixture) -> None:
# Arrange
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
side_effect=RuntimeError("invalid"),
)
# Act + Assert
with pytest.raises(ValueError, match="invalid schema"):
ApiToolManageService.test_api_tool_preview(
tenant_id="tenant-1",
provider_name="provider-a",
tool_name="tool-a",
credentials={"auth_type": "none"},
parameters={},
schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
)
def test_test_api_tool_preview_should_raise_error_when_tool_name_is_invalid(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
)
mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id")
# Act + Assert
with pytest.raises(ValueError, match="invalid tool name tool-b"):
ApiToolManageService.test_api_tool_preview(
tenant_id="tenant-1",
provider_name="provider-a",
tool_name="tool-b",
credentials={"auth_type": "none"},
parameters={},
schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
)
def test_test_api_tool_preview_should_raise_error_when_auth_type_missing(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
)
mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id")
# Act + Assert
with pytest.raises(ValueError, match="auth_type is required"):
ApiToolManageService.test_api_tool_preview(
tenant_id="tenant-1",
provider_name="provider-a",
tool_name="tool-a",
credentials={},
parameters={},
schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
)
def test_test_api_tool_preview_should_return_error_payload_when_tool_validation_raises(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"})
mock_db.session.query.return_value.where.return_value.first.return_value = db_provider
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
)
provider_controller = MagicMock()
tool_obj = MagicMock()
tool_obj.fork_tool_runtime.return_value = tool_obj
tool_obj.validate_credentials.side_effect = ValueError("validation failed")
provider_controller.get_tool.return_value = tool_obj
mocker.patch(
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
return_value=provider_controller,
)
mock_encrypter = MagicMock()
mock_encrypter.decrypt.return_value = {"auth_type": "none"}
mock_encrypter.mask_plugin_credentials.return_value = {}
mocker.patch(
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
return_value=(mock_encrypter, MagicMock()),
)
# Act
result = ApiToolManageService.test_api_tool_preview(
tenant_id="tenant-1",
provider_name="provider-a",
tool_name="tool-a",
credentials={"auth_type": "none"},
parameters={},
schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
)
# Assert
assert result == {"error": "validation failed"}
def test_test_api_tool_preview_should_return_result_payload_when_validation_succeeds(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"})
mock_db.session.query.return_value.where.return_value.first.return_value = db_provider
mocker.patch(
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
)
provider_controller = MagicMock()
tool_obj = MagicMock()
tool_obj.fork_tool_runtime.return_value = tool_obj
tool_obj.validate_credentials.return_value = {"ok": True}
provider_controller.get_tool.return_value = tool_obj
mocker.patch(
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
return_value=provider_controller,
)
mock_encrypter = MagicMock()
mock_encrypter.decrypt.return_value = {"auth_type": "none"}
mock_encrypter.mask_plugin_credentials.return_value = {}
mocker.patch(
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
return_value=(mock_encrypter, MagicMock()),
)
# Act
result = ApiToolManageService.test_api_tool_preview(
tenant_id="tenant-1",
provider_name="provider-a",
tool_name="tool-a",
credentials={"auth_type": "none"},
parameters={"x": "1"},
schema_type=ApiProviderSchemaType.OPENAPI,
schema="schema",
)
# Assert
assert result == {"result": {"ok": True}}
def test_list_api_tools_should_return_all_user_providers_with_converted_tools(
mock_db: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
provider_one = SimpleNamespace(name="p1")
provider_two = SimpleNamespace(name="p2")
mock_db.session.scalars.return_value.all.return_value = [provider_one, provider_two]
controller_one = MagicMock()
controller_one.get_tools.return_value = ["tool-a"]
controller_two = MagicMock()
controller_two.get_tools.return_value = ["tool-b", "tool-c"]
user_provider_one = SimpleNamespace(labels=[], tools=[])
user_provider_two = SimpleNamespace(labels=[], tools=[])
mocker.patch(
"services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller",
side_effect=[controller_one, controller_two],
)
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["news"])
mocker.patch(
"services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_user_provider",
side_effect=[user_provider_one, user_provider_two],
)
mocker.patch("services.tools.api_tools_manage_service.ToolTransformService.repack_provider")
mock_convert = mocker.patch(
"services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity",
side_effect=[{"name": "tool-a"}, {"name": "tool-b"}, {"name": "tool-c"}],
)
# Act
result = ApiToolManageService.list_api_tools("tenant-1")
# Assert
assert len(result) == 2
assert user_provider_one.tools == [{"name": "tool-a"}]
assert user_provider_two.tools == [{"name": "tool-b"}, {"name": "tool-c"}]
assert mock_convert.call_count == 3