mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into refactor/auth-services-typeddict
This commit is contained in:
commit
e04b839ae7
|
|
@ -458,9 +458,7 @@ class ChatConversationApi(Resource):
|
|||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
subquery = (
|
||||
db.session.query(
|
||||
Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")
|
||||
)
|
||||
sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id"))
|
||||
.outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
|
||||
.subquery()
|
||||
)
|
||||
|
|
@ -595,10 +593,8 @@ class ChatConversationDetailApi(Resource):
|
|||
|
||||
def _get_conversation(app_model, conversation_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
|
||||
.first()
|
||||
conversation = db.session.scalar(
|
||||
sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1)
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ class InstructionGenerateApi(Resource):
|
|||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args.current in (code_template, "")) and args.node_id != "":
|
||||
app = db.session.query(App).where(App.id == args.flow_id).first()
|
||||
app = db.session.get(App, args.flow_id)
|
||||
if not app:
|
||||
return {"error": f"app {args.flow_id} not found"}, 400
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import json
|
|||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -47,7 +48,7 @@ class AppMCPServerController(Resource):
|
|||
@get_app_model
|
||||
@marshal_with(app_server_model)
|
||||
def get(self, app_model):
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
|
||||
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
|
||||
return server
|
||||
|
||||
@console_ns.doc("create_app_mcp_server")
|
||||
|
|
@ -98,7 +99,7 @@ class AppMCPServerController(Resource):
|
|||
@edit_permission_required
|
||||
def put(self, app_model):
|
||||
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
|
||||
server = db.session.get(AppMCPServer, payload.id)
|
||||
if not server:
|
||||
raise NotFound()
|
||||
|
||||
|
|
@ -135,11 +136,10 @@ class AppMCPServerRefreshController(Resource):
|
|||
@edit_permission_required
|
||||
def get(self, server_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
server = (
|
||||
db.session.query(AppMCPServer)
|
||||
.where(AppMCPServer.id == server_id)
|
||||
.where(AppMCPServer.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
server = db.session.scalar(
|
||||
select(AppMCPServer)
|
||||
.where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not server:
|
||||
raise NotFound()
|
||||
|
|
|
|||
|
|
@ -69,9 +69,7 @@ class ModelConfigResource(Resource):
|
|||
|
||||
if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
|
||||
# get original app model config
|
||||
original_app_model_config = (
|
||||
db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first()
|
||||
)
|
||||
original_app_model_config = db.session.get(AppModelConfig, app_model.app_model_config_id)
|
||||
if original_app_model_config is None:
|
||||
raise ValueError("Original app model config not found")
|
||||
agent_mode = original_app_model_config.agent_mode_dict
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from typing import Literal
|
|||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from constants.languages import supported_language
|
||||
|
|
@ -75,7 +76,7 @@ class AppSite(Resource):
|
|||
def post(self, app_model):
|
||||
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
if not site:
|
||||
raise NotFound
|
||||
|
||||
|
|
@ -124,7 +125,7 @@ class AppSiteAccessTokenReset(Resource):
|
|||
@marshal_with(app_site_model)
|
||||
def post(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
|
||||
if not site:
|
||||
raise NotFound
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ from collections.abc import Callable
|
|||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
|
|
@ -15,16 +17,14 @@ R1 = TypeVar("R1")
|
|||
|
||||
def _load_app_model(app_id: str) -> App | None:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
app_model = db.session.scalar(
|
||||
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
|
||||
)
|
||||
return app_model
|
||||
|
||||
|
||||
def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||
app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
|
||||
app_model = db.session.scalar(select(App).where(App.id == app_id, App.status == "normal").limit(1))
|
||||
return app_model
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ from core.rag.models.document import Document
|
|||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset, TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client import grpc # noqa
|
||||
|
|
@ -452,7 +453,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
|||
password=new_cluster["password"],
|
||||
tenant_id=dataset.tenant_id,
|
||||
active=True,
|
||||
status="ACTIVE",
|
||||
status=TidbAuthBindingStatus.ACTIVE,
|
||||
)
|
||||
db.session.add(new_tidb_auth_binding)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from configs import dify_config
|
|||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
|
||||
class TidbService:
|
||||
|
|
@ -170,7 +171,7 @@ class TidbService:
|
|||
userPrefix = item["userPrefix"]
|
||||
if state == "ACTIVE" and len(userPrefix) > 0:
|
||||
cluster_info = tidb_serverless_list_map[item["clusterId"]]
|
||||
cluster_info.status = "ACTIVE"
|
||||
cluster_info.status = TidbAuthBindingStatus.ACTIVE
|
||||
cluster_info.account = f"{userPrefix}.root"
|
||||
db.session.add(cluster_info)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -43,7 +43,9 @@ from .enums import (
|
|||
IndexingStatus,
|
||||
ProcessRuleMode,
|
||||
SegmentStatus,
|
||||
SegmentType,
|
||||
SummaryStatus,
|
||||
TidbAuthBindingStatus,
|
||||
)
|
||||
from .model import App, Tag, TagBinding, UploadFile
|
||||
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
|
||||
|
|
@ -998,7 +1000,9 @@ class ChildChunk(Base):
|
|||
# indexing fields
|
||||
index_node_id = mapped_column(String(255), nullable=True)
|
||||
index_node_hash = mapped_column(String(255), nullable=True)
|
||||
type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'"))
|
||||
type: Mapped[SegmentType] = mapped_column(
|
||||
EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'")
|
||||
)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
|
|
@ -1239,7 +1243,9 @@ class TidbAuthBinding(TypeBase):
|
|||
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'"))
|
||||
status: Mapped[TidbAuthBindingStatus] = mapped_column(
|
||||
EnumText(TidbAuthBindingStatus, length=255), nullable=False, server_default=sa.text("'CREATING'")
|
||||
)
|
||||
account: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
|
|
|||
|
|
@ -222,6 +222,13 @@ class DatasetMetadataType(StrEnum):
|
|||
TIME = "time"
|
||||
|
||||
|
||||
class SegmentType(StrEnum):
|
||||
"""Document segment type"""
|
||||
|
||||
AUTOMATIC = "automatic"
|
||||
CUSTOMIZED = "customized"
|
||||
|
||||
|
||||
class SegmentStatus(StrEnum):
|
||||
"""Document segment status"""
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from configs import dify_config
|
|||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
from core.tools.signature import sign_tool_file
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.helper import generate_string # type: ignore[import-not-found]
|
||||
|
|
@ -1785,7 +1785,7 @@ class MessageFile(TypeBase):
|
|||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||
)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
type: Mapped[FileType] = mapped_column(EnumText(FileType, length=255), nullable=False)
|
||||
transfer_method: Mapped[FileTransferMethod] = mapped_column(
|
||||
EnumText(FileTransferMethod, length=255), nullable=False
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from configs import dify_config
|
|||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
|
||||
@app.celery.task(queue="dataset")
|
||||
|
|
@ -57,7 +58,7 @@ def create_clusters(batch_size):
|
|||
account=new_cluster["account"],
|
||||
password=new_cluster["password"],
|
||||
active=False,
|
||||
status="CREATING",
|
||||
status=TidbAuthBindingStatus.CREATING,
|
||||
)
|
||||
db.session.add(tidb_auth_binding)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from configs import dify_config
|
|||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
|
||||
@app.celery.task(queue="dataset")
|
||||
|
|
@ -18,7 +19,10 @@ def update_tidb_serverless_status_task():
|
|||
try:
|
||||
# check the number of idle tidb serverless
|
||||
tidb_serverless_list = db.session.scalars(
|
||||
select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
|
||||
select(TidbAuthBinding).where(
|
||||
TidbAuthBinding.active == False,
|
||||
TidbAuthBinding.status == TidbAuthBindingStatus.CREATING,
|
||||
)
|
||||
).all()
|
||||
if len(tidb_serverless_list) == 0:
|
||||
return
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ from models.enums import (
|
|||
IndexingStatus,
|
||||
ProcessRuleMode,
|
||||
SegmentStatus,
|
||||
SegmentType,
|
||||
)
|
||||
from models.model import UploadFile
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
|
@ -3786,7 +3787,7 @@ class SegmentService:
|
|||
child_chunk.word_count = len(child_chunk.content)
|
||||
child_chunk.updated_by = current_user.id
|
||||
child_chunk.updated_at = naive_utc_now()
|
||||
child_chunk.type = "customized"
|
||||
child_chunk.type = SegmentType.CUSTOMIZED
|
||||
update_child_chunks.append(child_chunk)
|
||||
else:
|
||||
new_child_chunks_args.append(child_chunk_update_args)
|
||||
|
|
@ -3845,7 +3846,7 @@ class SegmentService:
|
|||
child_chunk.word_count = len(content)
|
||||
child_chunk.updated_by = current_user.id
|
||||
child_chunk.updated_at = naive_utc_now()
|
||||
child_chunk.type = "customized"
|
||||
child_chunk.type = SegmentType.CUSTOMIZED
|
||||
db.session.add(child_chunk)
|
||||
VectorService.update_child_chunk_vector([], [child_chunk], [], dataset)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -1,17 +1,10 @@
|
|||
"""
|
||||
Test suite for password reset authentication flows.
|
||||
"""Testcontainers integration tests for password reset authentication flows."""
|
||||
|
||||
This module tests the password reset mechanism including:
|
||||
- Password reset email sending
|
||||
- Verification code validation
|
||||
- Password reset with token
|
||||
- Rate limiting and security checks
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
|
|
@ -28,31 +21,12 @@ from controllers.console.auth.forgot_password import (
|
|||
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_forgot_password_session():
|
||||
with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls:
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
mock_session_cls.return_value.__exit__.return_value = None
|
||||
yield mock_session
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_forgot_password_db():
|
||||
with patch("controllers.console.auth.forgot_password.db") as mock_db:
|
||||
mock_db.engine = MagicMock()
|
||||
yield mock_db
|
||||
|
||||
|
||||
class TestForgotPasswordSendEmailApi:
|
||||
"""Test cases for sending password reset emails."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
|
|
@ -62,7 +36,6 @@ class TestForgotPasswordSendEmailApi:
|
|||
account.name = "Test User"
|
||||
return account
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
|
|
@ -73,20 +46,10 @@ class TestForgotPasswordSendEmailApi:
|
|||
mock_send_email,
|
||||
mock_get_account,
|
||||
mock_is_ip_limit,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
):
|
||||
"""
|
||||
Test successful password reset email sending.
|
||||
|
||||
Verifies that:
|
||||
- Email is sent to valid account
|
||||
- Reset token is generated and returned
|
||||
- IP rate limiting is checked
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_email.return_value = "reset_token_123"
|
||||
|
|
@ -104,9 +67,8 @@ class TestForgotPasswordSendEmailApi:
|
|||
assert response["data"] == "reset_token_123"
|
||||
mock_send_email.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, mock_db, app):
|
||||
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app):
|
||||
"""
|
||||
Test password reset email blocked by IP rate limit.
|
||||
|
||||
|
|
@ -115,7 +77,6 @@ class TestForgotPasswordSendEmailApi:
|
|||
- No email is sent when rate limited
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
|
|
@ -133,7 +94,6 @@ class TestForgotPasswordSendEmailApi:
|
|||
(None, "en-US"), # Defaults to en-US when not provided
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
|
|
@ -144,7 +104,6 @@ class TestForgotPasswordSendEmailApi:
|
|||
mock_send_email,
|
||||
mock_get_account,
|
||||
mock_is_ip_limit,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
language_input,
|
||||
|
|
@ -158,7 +117,6 @@ class TestForgotPasswordSendEmailApi:
|
|||
- Unsupported languages default to en-US
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_email.return_value = "token"
|
||||
|
|
@ -180,13 +138,9 @@ class TestForgotPasswordCheckApi:
|
|||
"""Test cases for verifying password reset codes."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
|
|
@ -199,7 +153,6 @@ class TestForgotPasswordCheckApi:
|
|||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_is_rate_limit,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
"""
|
||||
|
|
@ -212,7 +165,6 @@ class TestForgotPasswordCheckApi:
|
|||
- Rate limit is reset on success
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_generate_token.return_value = (None, "new_token")
|
||||
|
|
@ -236,7 +188,6 @@ class TestForgotPasswordCheckApi:
|
|||
)
|
||||
mock_reset_rate_limit.assert_called_once_with("test@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
|
|
@ -249,10 +200,8 @@ class TestForgotPasswordCheckApi:
|
|||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_is_rate_limit,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"}
|
||||
mock_generate_token.return_value = (None, "fresh-token")
|
||||
|
|
@ -271,9 +220,8 @@ class TestForgotPasswordCheckApi:
|
|||
mock_revoke_token.assert_called_once_with("upper_token")
|
||||
mock_reset_rate_limit.assert_called_once_with("user@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app):
|
||||
def test_verify_code_rate_limited(self, mock_is_rate_limit, app):
|
||||
"""
|
||||
Test code verification blocked by rate limit.
|
||||
|
||||
|
|
@ -282,7 +230,6 @@ class TestForgotPasswordCheckApi:
|
|||
- Prevents brute force attacks on verification codes
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
|
|
@ -295,10 +242,9 @@ class TestForgotPasswordCheckApi:
|
|||
with pytest.raises(EmailPasswordResetLimitError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, mock_db, app):
|
||||
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app):
|
||||
"""
|
||||
Test code verification with invalid token.
|
||||
|
||||
|
|
@ -306,7 +252,6 @@ class TestForgotPasswordCheckApi:
|
|||
- InvalidTokenError is raised for invalid/expired tokens
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = None
|
||||
|
||||
|
|
@ -320,10 +265,9 @@ class TestForgotPasswordCheckApi:
|
|||
with pytest.raises(InvalidTokenError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, mock_db, app):
|
||||
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app):
|
||||
"""
|
||||
Test code verification with mismatched email.
|
||||
|
||||
|
|
@ -332,7 +276,6 @@ class TestForgotPasswordCheckApi:
|
|||
- Prevents token abuse
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}
|
||||
|
||||
|
|
@ -346,11 +289,10 @@ class TestForgotPasswordCheckApi:
|
|||
with pytest.raises(InvalidEmailError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit")
|
||||
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, mock_db, app):
|
||||
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app):
|
||||
"""
|
||||
Test code verification with incorrect code.
|
||||
|
||||
|
|
@ -359,7 +301,6 @@ class TestForgotPasswordCheckApi:
|
|||
- Rate limit counter is incremented
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
|
||||
|
|
@ -380,11 +321,8 @@ class TestForgotPasswordResetApi:
|
|||
"""Test cases for resetting password with verified token."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
|
|
@ -394,7 +332,6 @@ class TestForgotPasswordResetApi:
|
|||
account.name = "Test User"
|
||||
return account
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
|
|
@ -405,7 +342,6 @@ class TestForgotPasswordResetApi:
|
|||
mock_get_account,
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
):
|
||||
|
|
@ -418,7 +354,6 @@ class TestForgotPasswordResetApi:
|
|||
- Success response is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_get_tenants.return_value = [MagicMock()]
|
||||
|
|
@ -436,9 +371,8 @@ class TestForgotPasswordResetApi:
|
|||
assert response["result"] == "success"
|
||||
mock_revoke_token.assert_called_once_with("valid_token")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_mismatch(self, mock_get_data, mock_db, app):
|
||||
def test_reset_password_mismatch(self, mock_get_data, app):
|
||||
"""
|
||||
Test password reset with mismatched passwords.
|
||||
|
||||
|
|
@ -447,7 +381,6 @@ class TestForgotPasswordResetApi:
|
|||
- No password update occurs
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
|
||||
|
||||
# Act & Assert
|
||||
|
|
@ -460,9 +393,8 @@ class TestForgotPasswordResetApi:
|
|||
with pytest.raises(PasswordMismatchError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_invalid_token(self, mock_get_data, mock_db, app):
|
||||
def test_reset_password_invalid_token(self, mock_get_data, app):
|
||||
"""
|
||||
Test password reset with invalid token.
|
||||
|
||||
|
|
@ -470,7 +402,6 @@ class TestForgotPasswordResetApi:
|
|||
- InvalidTokenError is raised for invalid/expired tokens
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
|
|
@ -483,9 +414,8 @@ class TestForgotPasswordResetApi:
|
|||
with pytest.raises(InvalidTokenError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_wrong_phase(self, mock_get_data, mock_db, app):
|
||||
def test_reset_password_wrong_phase(self, mock_get_data, app):
|
||||
"""
|
||||
Test password reset with token not in reset phase.
|
||||
|
||||
|
|
@ -494,7 +424,6 @@ class TestForgotPasswordResetApi:
|
|||
- Prevents use of verification-phase tokens for reset
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "verify"}
|
||||
|
||||
# Act & Assert
|
||||
|
|
@ -507,13 +436,10 @@ class TestForgotPasswordResetApi:
|
|||
with pytest.raises(InvalidTokenError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
def test_reset_password_account_not_found(
|
||||
self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app
|
||||
):
|
||||
def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app):
|
||||
"""
|
||||
Test password reset for non-existent account.
|
||||
|
||||
|
|
@ -521,7 +447,6 @@ class TestForgotPasswordResetApi:
|
|||
- AccountNotFound is raised when account doesn't exist
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"}
|
||||
mock_get_account.return_value = None
|
||||
|
||||
|
|
@ -8,6 +8,7 @@ import pytest
|
|||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dify_graph.file.enums import FileType
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
|
|
@ -253,7 +254,7 @@ class TestMessagesCleanServiceIntegration:
|
|||
# MessageFile
|
||||
file = MessageFile(
|
||||
message_id=message.id,
|
||||
type="image",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method="local_file",
|
||||
url="http://example.com/test.jpg",
|
||||
belongs_to=MessageFileBelongsTo.USER,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,174 @@
|
|||
"""Testcontainers integration tests for OAuthServerService."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import (
|
||||
OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY,
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY,
|
||||
OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY,
|
||||
OAuthGrantType,
|
||||
OAuthServerService,
|
||||
)
|
||||
|
||||
|
||||
class TestOAuthServerServiceGetProviderApp:
|
||||
"""DB-backed tests for get_oauth_provider_app."""
|
||||
|
||||
def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp:
|
||||
app = OAuthProviderApp(
|
||||
app_icon="icon.png",
|
||||
client_id=client_id,
|
||||
client_secret=str(uuid4()),
|
||||
app_label={"en-US": "Test OAuth App"},
|
||||
redirect_uris=["https://example.com/callback"],
|
||||
scope="read",
|
||||
)
|
||||
db_session_with_containers.add(app)
|
||||
db_session_with_containers.commit()
|
||||
return app
|
||||
|
||||
def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers):
|
||||
client_id = f"client-{uuid4()}"
|
||||
created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id)
|
||||
|
||||
result = OAuthServerService.get_oauth_provider_app(client_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.client_id == client_id
|
||||
assert result.id == created.id
|
||||
|
||||
def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers):
|
||||
result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestOAuthServerServiceTokenOperations:
|
||||
"""Redis-backed tests for token sign/validate operations."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis(self):
|
||||
with patch("services.oauth_server.redis_client") as mock:
|
||||
yield mock
|
||||
|
||||
def test_sign_authorization_code_stores_and_returns_code(self, mock_redis):
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111")
|
||||
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
|
||||
code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1")
|
||||
|
||||
assert code == str(deterministic_uuid)
|
||||
mock_redis.set.assert_called_once_with(
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=code),
|
||||
"user-1",
|
||||
ex=600,
|
||||
)
|
||||
|
||||
def test_sign_access_token_raises_bad_request_for_invalid_code(self, mock_redis):
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
with pytest.raises(BadRequest, match="invalid code"):
|
||||
OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
code="bad-code",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
def test_sign_access_token_issues_tokens_for_valid_code(self, mock_redis):
|
||||
token_uuids = [
|
||||
uuid.UUID("00000000-0000-0000-0000-000000000201"),
|
||||
uuid.UUID("00000000-0000-0000-0000-000000000202"),
|
||||
]
|
||||
with patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids):
|
||||
mock_redis.get.return_value = b"user-1"
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
code="code-1",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
assert access_token == str(token_uuids[0])
|
||||
assert refresh_token == str(token_uuids[1])
|
||||
code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1")
|
||||
mock_redis.delete.assert_called_once_with(code_key)
|
||||
mock_redis.set.assert_any_call(
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
mock_redis.set.assert_any_call(
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
def test_sign_access_token_raises_bad_request_for_invalid_refresh_token(self, mock_redis):
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
with pytest.raises(BadRequest, match="invalid refresh token"):
|
||||
OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||
refresh_token="stale-token",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
def test_sign_access_token_issues_new_token_for_valid_refresh(self, mock_redis):
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301")
|
||||
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
|
||||
mock_redis.get.return_value = b"user-1"
|
||||
|
||||
access_token, returned_refresh = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||
refresh_token="refresh-1",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
assert access_token == str(deterministic_uuid)
|
||||
assert returned_refresh == "refresh-1"
|
||||
|
||||
def test_sign_access_token_returns_none_for_unknown_grant_type(self, mock_redis):
|
||||
grant_type = cast(OAuthGrantType, "invalid-grant-type")
|
||||
|
||||
result = OAuthServerService.sign_oauth_access_token(grant_type=grant_type, client_id="client-1")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_sign_refresh_token_stores_with_expected_expiry(self, mock_redis):
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401")
|
||||
with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid):
|
||||
refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2")
|
||||
|
||||
assert refresh_token == str(deterministic_uuid)
|
||||
mock_redis.set.assert_called_once_with(
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token),
|
||||
"user-2",
|
||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
def test_validate_access_token_returns_none_when_not_found(self, mock_redis):
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_validate_access_token_loads_user_when_exists(self, mock_redis):
|
||||
mock_redis.get.return_value = b"user-88"
|
||||
expected_user = MagicMock()
|
||||
|
||||
with patch("services.oauth_server.AccountService.load_user", return_value=expected_user) as mock_load:
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
|
||||
|
||||
assert result is expected_user
|
||||
mock_load.assert_called_once_with("user-88")
|
||||
|
|
@ -536,3 +536,151 @@ class TestApiToolManageService:
|
|||
# Verify mock interactions
|
||||
mock_external_service_dependencies["encrypter"].assert_called_once()
|
||||
mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()
|
||||
|
||||
def test_delete_api_tool_provider_success(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test successful deletion of an API tool provider."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
schema = self._create_test_openapi_schema()
|
||||
provider_name = fake.unique.word()
|
||||
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
icon={"content": "🔧", "background": "#FFF"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=schema,
|
||||
privacy_policy="",
|
||||
custom_disclaimer="",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
provider = (
|
||||
db_session_with_containers.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
|
||||
.first()
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
result = ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, provider_name)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
deleted = (
|
||||
db_session_with_containers.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
|
||||
.first()
|
||||
)
|
||||
assert deleted is None
|
||||
|
||||
def test_delete_api_tool_provider_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test deletion raises ValueError when provider not found."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="you have not added provider"):
|
||||
ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent")
|
||||
|
||||
def test_update_api_tool_provider_not_found(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test update raises ValueError when original provider not found."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="does not exists"):
|
||||
ApiToolManageService.update_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name="new-name",
|
||||
original_provider="nonexistent",
|
||||
icon={},
|
||||
credentials={"auth_type": "none"},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=self._create_test_openapi_schema(),
|
||||
privacy_policy=None,
|
||||
custom_disclaimer="",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
def test_update_api_tool_provider_missing_auth_type(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test update raises ValueError when auth_type is missing from credentials."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
schema = self._create_test_openapi_schema()
|
||||
provider_name = fake.unique.word()
|
||||
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
icon={"content": "🔧", "background": "#FFF"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=schema,
|
||||
privacy_policy="",
|
||||
custom_disclaimer="",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
ApiToolManageService.update_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
original_provider=provider_name,
|
||||
icon={},
|
||||
credentials={},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema=schema,
|
||||
privacy_policy=None,
|
||||
custom_disclaimer="",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
def test_list_api_tool_provider_tools_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test listing tools raises ValueError when provider not found."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="you have not added provider"):
|
||||
ApiToolManageService.list_api_tool_provider_tools(account.id, tenant.id, "nonexistent")
|
||||
|
||||
def test_test_api_tool_preview_invalid_schema_type(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""Test preview raises ValueError for invalid schema type."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="invalid schema type"):
|
||||
ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id=tenant.id,
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={},
|
||||
schema_type="bad-schema-type",
|
||||
schema="schema",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -281,12 +281,10 @@ class TestSiteEndpoints:
|
|||
method = _unwrap(api.post)
|
||||
|
||||
site = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = site
|
||||
monkeypatch.setattr(
|
||||
site_module.db,
|
||||
"session",
|
||||
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
|
||||
MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
site_module,
|
||||
|
|
@ -305,12 +303,10 @@ class TestSiteEndpoints:
|
|||
method = _unwrap(api.post)
|
||||
|
||||
site = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = site
|
||||
monkeypatch.setattr(
|
||||
site_module.db,
|
||||
"session",
|
||||
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
|
||||
MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None),
|
||||
)
|
||||
monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code")
|
||||
monkeypatch.setattr(
|
||||
|
|
|
|||
|
|
@ -82,12 +82,8 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p
|
|||
def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
conversation = SimpleNamespace(id="c1", app_id="app-1")
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = conversation
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = conversation
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(conversation_module.db, "session", session)
|
||||
|
|
@ -101,12 +97,8 @@ def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> No
|
|||
|
||||
|
||||
def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = None
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = None
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(conversation_module.db, "session", session)
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ def test_get_conversation_mark_read_keeps_updated_at_unchanged():
|
|||
),
|
||||
patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session,
|
||||
):
|
||||
mock_session.query.return_value.where.return_value.first.return_value = conversation
|
||||
mock_session.scalar.return_value = conversation
|
||||
|
||||
_get_conversation(app_model, "conversation-id")
|
||||
|
||||
|
|
|
|||
|
|
@ -73,8 +73,7 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch
|
|||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: None))
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
|
|
@ -99,8 +98,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey
|
|||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
|
||||
_install_workflow_service(monkeypatch, workflow=None)
|
||||
|
||||
with app.test_request_context(
|
||||
|
|
@ -126,8 +124,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch)
|
|||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
|
||||
|
||||
workflow = SimpleNamespace(graph_dict={"nodes": []})
|
||||
_install_workflow_service(monkeypatch, workflow=workflow)
|
||||
|
|
@ -155,8 +152,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) ->
|
|||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model))
|
||||
|
||||
workflow = SimpleNamespace(
|
||||
graph_dict={
|
||||
|
|
|
|||
|
|
@ -92,10 +92,7 @@ def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatc
|
|||
)
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = original_config
|
||||
session.query.return_value = query
|
||||
session.get.return_value = original_config
|
||||
monkeypatch.setattr(model_config_module.db, "session", session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
|
|
|
|||
|
|
@ -11,10 +11,8 @@ from models.model import AppMode
|
|||
|
||||
def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
|
||||
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model))
|
||||
|
||||
@wraps_module.get_app_model
|
||||
def handler(app_model):
|
||||
|
|
@ -25,10 +23,8 @@ def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
|
||||
def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
|
||||
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model))
|
||||
|
||||
@wraps_module.get_app_model(mode=[AppMode.COMPLETION])
|
||||
def handler(app_model):
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from sqlalchemy.orm import Session
|
|||
|
||||
from core.app.entities.task_entities import MessageEndStreamResponse
|
||||
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||
from dify_graph.file.enums import FileTransferMethod
|
||||
from dify_graph.file.enums import FileTransferMethod, FileType
|
||||
from models.model import MessageFile, UploadFile
|
||||
|
||||
|
||||
|
|
@ -51,7 +51,7 @@ class TestMessageEndStreamResponseFiles:
|
|||
message_file.transfer_method = FileTransferMethod.LOCAL_FILE
|
||||
message_file.upload_file_id = str(uuid.uuid4())
|
||||
message_file.url = None
|
||||
message_file.type = "image"
|
||||
message_file.type = FileType.IMAGE
|
||||
return message_file
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -63,7 +63,7 @@ class TestMessageEndStreamResponseFiles:
|
|||
message_file.transfer_method = FileTransferMethod.REMOTE_URL
|
||||
message_file.upload_file_id = None
|
||||
message_file.url = "https://example.com/image.jpg"
|
||||
message_file.type = "image"
|
||||
message_file.type = FileType.IMAGE
|
||||
return message_file
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -75,7 +75,7 @@ class TestMessageEndStreamResponseFiles:
|
|||
message_file.transfer_method = FileTransferMethod.TOOL_FILE
|
||||
message_file.upload_file_id = None
|
||||
message_file.url = "tool_file_123.png"
|
||||
message_file.type = "image"
|
||||
message_file.type = FileType.IMAGE
|
||||
return message_file
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import pytest
|
|||
|
||||
from models.account import Account
|
||||
from models.dataset import ChildChunk, Dataset, Document, DocumentSegment
|
||||
from models.enums import SegmentType
|
||||
from services.dataset_service import SegmentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
|
||||
|
|
@ -77,7 +78,7 @@ class SegmentTestDataFactory:
|
|||
chunk.word_count = word_count
|
||||
chunk.index_node_id = f"node-{chunk_id}"
|
||||
chunk.index_node_hash = "hash-123"
|
||||
chunk.type = "automatic"
|
||||
chunk.type = SegmentType.AUTOMATIC
|
||||
chunk.created_by = "user-123"
|
||||
chunk.updated_by = None
|
||||
chunk.updated_at = None
|
||||
|
|
|
|||
|
|
@ -1,224 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from services.oauth_server import (
|
||||
OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY,
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY,
|
||||
OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY,
|
||||
OAuthGrantType,
|
||||
OAuthServerService,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(mocker: MockerFixture) -> MagicMock:
|
||||
return mocker.patch("services.oauth_server.redis_client")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(mocker: MockerFixture) -> MagicMock:
|
||||
"""Mock the OAuth server Session context manager."""
|
||||
mocker.patch("services.oauth_server.db", SimpleNamespace(engine=object()))
|
||||
session = MagicMock()
|
||||
session_cm = MagicMock()
|
||||
session_cm.__enter__.return_value = session
|
||||
mocker.patch("services.oauth_server.Session", return_value=session_cm)
|
||||
return session
|
||||
|
||||
|
||||
def test_get_oauth_provider_app_should_return_app_when_record_exists(mock_session: MagicMock) -> None:
|
||||
# Arrange
|
||||
mock_execute_result = MagicMock()
|
||||
expected_app = MagicMock()
|
||||
mock_execute_result.scalar_one_or_none.return_value = expected_app
|
||||
mock_session.execute.return_value = mock_execute_result
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.get_oauth_provider_app("client-1")
|
||||
|
||||
# Assert
|
||||
assert result is expected_app
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_execute_result.scalar_one_or_none.assert_called_once()
|
||||
|
||||
|
||||
def test_sign_oauth_authorization_code_should_store_code_and_return_value(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111")
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
||||
|
||||
# Act
|
||||
code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1")
|
||||
|
||||
# Assert
|
||||
expected_code = str(deterministic_uuid)
|
||||
assert code == expected_code
|
||||
mock_redis_client.set.assert_called_once_with(
|
||||
OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=expected_code),
|
||||
"user-1",
|
||||
ex=600,
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_raise_bad_request_when_authorization_code_is_invalid(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(BadRequest, match="invalid code"):
|
||||
OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
code="bad-code",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_issue_access_and_refresh_token_when_authorization_code_is_valid(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
token_uuids = [
|
||||
uuid.UUID("00000000-0000-0000-0000-000000000201"),
|
||||
uuid.UUID("00000000-0000-0000-0000-000000000202"),
|
||||
]
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids)
|
||||
mock_redis_client.get.return_value = b"user-1"
|
||||
code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1")
|
||||
|
||||
# Act
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
code="code-1",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert access_token == str(token_uuids[0])
|
||||
assert refresh_token == str(token_uuids[1])
|
||||
mock_redis_client.delete.assert_called_once_with(code_key)
|
||||
mock_redis_client.set.assert_any_call(
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
mock_redis_client.set.assert_any_call(
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_raise_bad_request_when_refresh_token_is_invalid(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(BadRequest, match="invalid refresh token"):
|
||||
OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||
refresh_token="stale-token",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_should_issue_new_access_token_when_refresh_token_is_valid(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301")
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
||||
mock_redis_client.get.return_value = b"user-1"
|
||||
|
||||
# Act
|
||||
access_token, returned_refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=OAuthGrantType.REFRESH_TOKEN,
|
||||
refresh_token="refresh-1",
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert access_token == str(deterministic_uuid)
|
||||
assert returned_refresh_token == "refresh-1"
|
||||
mock_redis_client.set.assert_called_once_with(
|
||||
OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token),
|
||||
b"user-1",
|
||||
ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
|
||||
def test_sign_oauth_access_token_with_unknown_grant_type_should_return_none() -> None:
|
||||
# Arrange
|
||||
grant_type = cast(OAuthGrantType, "invalid-grant-type")
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type=grant_type,
|
||||
client_id="client-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_sign_oauth_refresh_token_should_store_token_with_expected_expiry(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401")
|
||||
mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid)
|
||||
|
||||
# Act
|
||||
refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2")
|
||||
|
||||
# Assert
|
||||
assert refresh_token == str(deterministic_uuid)
|
||||
mock_redis_client.set.assert_called_once_with(
|
||||
OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token),
|
||||
"user-2",
|
||||
ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_oauth_access_token_should_return_none_when_token_not_found(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_validate_oauth_access_token_should_load_user_when_token_exists(
|
||||
mocker: MockerFixture, mock_redis_client: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = b"user-88"
|
||||
expected_user = MagicMock()
|
||||
mock_load_user = mocker.patch("services.oauth_server.AccountService.load_user", return_value=expected_user)
|
||||
|
||||
# Act
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
|
||||
|
||||
# Assert
|
||||
assert result is expected_user
|
||||
mock_load_user.assert_called_once_with("user-88")
|
||||
|
|
@ -1,643 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(mocker: MockerFixture) -> MagicMock:
|
||||
# Arrange
|
||||
mocked_db = mocker.patch("services.tools.api_tools_manage_service.db")
|
||||
mocked_db.session = MagicMock()
|
||||
return mocked_db
|
||||
|
||||
|
||||
def _tool_bundle(operation_id: str = "tool-1") -> SimpleNamespace:
|
||||
return SimpleNamespace(operation_id=operation_id)
|
||||
|
||||
|
||||
def test_parser_api_schema_should_return_schema_payload_when_schema_is_valid(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI.value),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.parser_api_schema("valid-schema")
|
||||
|
||||
# Assert
|
||||
assert result["schema_type"] == ApiProviderSchemaType.OPENAPI.value
|
||||
assert len(result["credentials_schema"]) == 3
|
||||
assert "warning" in result
|
||||
|
||||
|
||||
def test_parser_api_schema_should_raise_value_error_when_parser_raises(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
side_effect=RuntimeError("bad schema"),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid schema: invalid schema: bad schema"):
|
||||
ApiToolManageService.parser_api_schema("invalid")
|
||||
|
||||
|
||||
def test_convert_schema_to_tool_bundles_should_return_tool_bundles_when_valid(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
expected = ([_tool_bundle("a"), _tool_bundle("b")], ApiProviderSchemaType.SWAGGER)
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=expected,
|
||||
)
|
||||
extra_info: dict[str, str] = {}
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.convert_schema_to_tool_bundles("schema", extra_info=extra_info)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_convert_schema_to_tool_bundles_should_raise_value_error_when_parser_fails(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
side_effect=ValueError("parse failed"),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid schema: parse failed"):
|
||||
ApiToolManageService.convert_schema_to_tool_bundles("schema")
|
||||
|
||||
|
||||
def test_create_api_tool_provider_should_raise_error_when_provider_already_exists(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = object()
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="provider provider-a already exists"):
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name=" provider-a ",
|
||||
icon={"emoji": "X"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="custom",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
|
||||
def test_create_api_tool_provider_should_raise_error_when_tool_count_exceeds_limit(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
many_tools = [_tool_bundle(str(i)) for i in range(101)]
|
||||
mocker.patch.object(
|
||||
ApiToolManageService,
|
||||
"convert_schema_to_tool_bundles",
|
||||
return_value=(many_tools, ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="the number of apis should be less than 100"):
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
icon={"emoji": "X"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="custom",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
|
||||
def test_create_api_tool_provider_should_raise_error_when_auth_type_is_missing(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mocker.patch.object(
|
||||
ApiToolManageService,
|
||||
"convert_schema_to_tool_bundles",
|
||||
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
icon={"emoji": "X"},
|
||||
credentials={},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="custom",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
|
||||
def test_create_api_tool_provider_should_create_provider_when_input_is_valid(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mocker.patch.object(
|
||||
ApiToolManageService,
|
||||
"convert_schema_to_tool_bundles",
|
||||
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
mock_controller = MagicMock()
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
|
||||
return_value=mock_controller,
|
||||
)
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.encrypt.return_value = {"auth_type": "none"}
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
|
||||
return_value=(mock_encrypter, MagicMock()),
|
||||
)
|
||||
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels")
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.create_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
icon={"emoji": "X"},
|
||||
credentials={"auth_type": "none"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="custom",
|
||||
labels=["news"],
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
mock_controller.load_bundled_tools.assert_called_once()
|
||||
mock_db.session.add.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_api_tool_provider_remote_schema_should_return_schema_when_response_is_valid(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.get",
|
||||
return_value=SimpleNamespace(status_code=200, text="schema-content"),
|
||||
)
|
||||
mocker.patch.object(ApiToolManageService, "parser_api_schema", return_value={"ok": True})
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema")
|
||||
|
||||
# Assert
|
||||
assert result == {"schema": "schema-content"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("status_code", [400, 404, 500])
|
||||
def test_get_api_tool_provider_remote_schema_should_raise_error_when_remote_fetch_is_invalid(
|
||||
status_code: int,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.get",
|
||||
return_value=SimpleNamespace(status_code=status_code, text="schema-content"),
|
||||
)
|
||||
mock_logger = mocker.patch("services.tools.api_tools_manage_service.logger")
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid schema, please check the url you provided"):
|
||||
ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema")
|
||||
mock_logger.exception.assert_called_once()
|
||||
|
||||
|
||||
def test_list_api_tool_provider_tools_should_raise_error_when_provider_not_found(
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="you have not added provider provider-a"):
|
||||
ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a")
|
||||
|
||||
|
||||
def test_list_api_tool_provider_tools_should_return_converted_tools_when_provider_exists(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider = SimpleNamespace(tools=[_tool_bundle("tool-a"), _tool_bundle("tool-b")])
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = provider
|
||||
controller = MagicMock()
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller",
|
||||
return_value=controller,
|
||||
)
|
||||
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["search"])
|
||||
mock_convert = mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity",
|
||||
side_effect=[{"name": "tool-a"}, {"name": "tool-b"}],
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a")
|
||||
|
||||
# Assert
|
||||
assert result == [{"name": "tool-a"}, {"name": "tool-b"}]
|
||||
assert mock_convert.call_count == 2
|
||||
|
||||
|
||||
def test_update_api_tool_provider_should_raise_error_when_original_provider_not_found(
|
||||
mock_db: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="api provider provider-a does not exists"):
|
||||
ApiToolManageService.update_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
original_provider="provider-a",
|
||||
icon={},
|
||||
credentials={"auth_type": "none"},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy=None,
|
||||
custom_disclaimer="custom",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
|
||||
def test_update_api_tool_provider_should_raise_error_when_auth_type_missing(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider = SimpleNamespace(credentials={}, name="old")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = provider
|
||||
mocker.patch.object(
|
||||
ApiToolManageService,
|
||||
"convert_schema_to_tool_bundles",
|
||||
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
ApiToolManageService.update_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
original_provider="provider-a",
|
||||
icon={},
|
||||
credentials={},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy=None,
|
||||
custom_disclaimer="custom",
|
||||
labels=[],
|
||||
)
|
||||
|
||||
|
||||
def test_update_api_tool_provider_should_update_provider_and_preserve_masked_credentials(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider = SimpleNamespace(
|
||||
credentials={"auth_type": "none", "api_key_value": "encrypted-old"},
|
||||
name="old",
|
||||
icon="",
|
||||
schema="",
|
||||
description="",
|
||||
schema_type_str="",
|
||||
tools_str="",
|
||||
privacy_policy="",
|
||||
custom_disclaimer="",
|
||||
credentials_str="",
|
||||
)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = provider
|
||||
mocker.patch.object(
|
||||
ApiToolManageService,
|
||||
"convert_schema_to_tool_bundles",
|
||||
return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
controller = MagicMock()
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
|
||||
return_value=controller,
|
||||
)
|
||||
cache = MagicMock()
|
||||
encrypter = MagicMock()
|
||||
encrypter.decrypt.return_value = {"auth_type": "none", "api_key_value": "plain-old"}
|
||||
encrypter.mask_plugin_credentials.return_value = {"api_key_value": "***"}
|
||||
encrypter.encrypt.return_value = {"auth_type": "none", "api_key_value": "encrypted-new"}
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
|
||||
return_value=(encrypter, cache),
|
||||
)
|
||||
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels")
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.update_api_tool_provider(
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-new",
|
||||
original_provider="provider-old",
|
||||
icon={"emoji": "E"},
|
||||
credentials={"auth_type": "none", "api_key_value": "***"},
|
||||
_schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="custom",
|
||||
labels=["news"],
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
assert provider.name == "provider-new"
|
||||
assert provider.privacy_policy == "privacy"
|
||||
assert provider.credentials_str != ""
|
||||
cache.delete.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_delete_api_tool_provider_should_raise_error_when_provider_missing(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="you have not added provider provider-a"):
|
||||
ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a")
|
||||
|
||||
|
||||
def test_delete_api_tool_provider_should_delete_provider_when_exists(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
provider = object()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = provider
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a")
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
mock_db.session.delete.assert_called_once_with(provider)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_api_tool_provider_should_delegate_to_tool_manager(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
expected = {"provider": "value"}
|
||||
mock_get = mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolManager.user_get_api_provider",
|
||||
return_value=expected,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.get_api_tool_provider("user-1", "tenant-1", "provider-a")
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
mock_get.assert_called_once_with(provider="provider-a", tenant_id="tenant-1")
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_raise_error_for_invalid_schema_type() -> None:
|
||||
# Arrange
|
||||
schema_type = "bad-schema-type"
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid schema type"):
|
||||
ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={},
|
||||
schema_type=schema_type, # type: ignore[arg-type]
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_raise_error_when_schema_parser_fails(mocker: MockerFixture) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
side_effect=RuntimeError("invalid"),
|
||||
)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid schema"):
|
||||
ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_raise_error_when_tool_name_is_invalid(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id")
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="invalid tool name tool-b"):
|
||||
ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-b",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_raise_error_when_auth_type_missing(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id")
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={},
|
||||
parameters={},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_return_error_payload_when_tool_validation_raises(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"})
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
provider_controller = MagicMock()
|
||||
tool_obj = MagicMock()
|
||||
tool_obj.fork_tool_runtime.return_value = tool_obj
|
||||
tool_obj.validate_credentials.side_effect = ValueError("validation failed")
|
||||
provider_controller.get_tool.return_value = tool_obj
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
|
||||
return_value=provider_controller,
|
||||
)
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.decrypt.return_value = {"auth_type": "none"}
|
||||
mock_encrypter.mask_plugin_credentials.return_value = {}
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
|
||||
return_value=(mock_encrypter, MagicMock()),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"error": "validation failed"}
|
||||
|
||||
|
||||
def test_test_api_tool_preview_should_return_result_payload_when_validation_succeeds(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"})
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle",
|
||||
return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI),
|
||||
)
|
||||
provider_controller = MagicMock()
|
||||
tool_obj = MagicMock()
|
||||
tool_obj.fork_tool_runtime.return_value = tool_obj
|
||||
tool_obj.validate_credentials.return_value = {"ok": True}
|
||||
provider_controller.get_tool.return_value = tool_obj
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ApiToolProviderController.from_db",
|
||||
return_value=provider_controller,
|
||||
)
|
||||
mock_encrypter = MagicMock()
|
||||
mock_encrypter.decrypt.return_value = {"auth_type": "none"}
|
||||
mock_encrypter.mask_plugin_credentials.return_value = {}
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.create_tool_provider_encrypter",
|
||||
return_value=(mock_encrypter, MagicMock()),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.test_api_tool_preview(
|
||||
tenant_id="tenant-1",
|
||||
provider_name="provider-a",
|
||||
tool_name="tool-a",
|
||||
credentials={"auth_type": "none"},
|
||||
parameters={"x": "1"},
|
||||
schema_type=ApiProviderSchemaType.OPENAPI,
|
||||
schema="schema",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"result": {"ok": True}}
|
||||
|
||||
|
||||
def test_list_api_tools_should_return_all_user_providers_with_converted_tools(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
provider_one = SimpleNamespace(name="p1")
|
||||
provider_two = SimpleNamespace(name="p2")
|
||||
mock_db.session.scalars.return_value.all.return_value = [provider_one, provider_two]
|
||||
|
||||
controller_one = MagicMock()
|
||||
controller_one.get_tools.return_value = ["tool-a"]
|
||||
controller_two = MagicMock()
|
||||
controller_two.get_tools.return_value = ["tool-b", "tool-c"]
|
||||
|
||||
user_provider_one = SimpleNamespace(labels=[], tools=[])
|
||||
user_provider_two = SimpleNamespace(labels=[], tools=[])
|
||||
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller",
|
||||
side_effect=[controller_one, controller_two],
|
||||
)
|
||||
mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["news"])
|
||||
mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_user_provider",
|
||||
side_effect=[user_provider_one, user_provider_two],
|
||||
)
|
||||
mocker.patch("services.tools.api_tools_manage_service.ToolTransformService.repack_provider")
|
||||
mock_convert = mocker.patch(
|
||||
"services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity",
|
||||
side_effect=[{"name": "tool-a"}, {"name": "tool-b"}, {"name": "tool-c"}],
|
||||
)
|
||||
|
||||
# Act
|
||||
result = ApiToolManageService.list_api_tools("tenant-1")
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert user_provider_one.tools == [{"name": "tool-a"}]
|
||||
assert user_provider_two.tools == [{"name": "tool-b"}, {"name": "tool-c"}]
|
||||
assert mock_convert.call_count == 3
|
||||
Loading…
Reference in New Issue