mirror of https://github.com/langgenius/dify.git
Merge 75273b4be5 into 508350ec6a
commit 07269e9013
@@ -1,4 +1,4 @@
-from typing import Literal
+from typing import Literal, cast
 
 from flask_login import current_user
 from flask_restx import marshal
@@ -9,6 +9,7 @@ from controllers.common.schema import register_schema_model, register_schema_mod
 from controllers.service_api import service_api_ns
 from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
 from fields.dataset_fields import dataset_metadata_fields
+from models.account import Account
 from services.dataset_service import DatasetService
 from services.entities.knowledge_entities.knowledge_entities import (
     DocumentMetadataOperation,
@@ -55,7 +56,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
         dataset = DatasetService.get_dataset(dataset_id_str)
         if dataset is None:
             raise NotFound("Dataset not found.")
-        DatasetService.check_dataset_permission(dataset, current_user)
+        DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
 
         metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
         return marshal(metadata, dataset_metadata_fields), 201
@@ -102,7 +103,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
         dataset = DatasetService.get_dataset(dataset_id_str)
         if dataset is None:
             raise NotFound("Dataset not found.")
-        DatasetService.check_dataset_permission(dataset, current_user)
+        DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
 
         metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
         return marshal(metadata, dataset_metadata_fields), 200
@@ -125,7 +126,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
         dataset = DatasetService.get_dataset(dataset_id_str)
         if dataset is None:
             raise NotFound("Dataset not found.")
-        DatasetService.check_dataset_permission(dataset, current_user)
+        DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
 
         MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
         return "", 204
@@ -166,7 +167,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
         dataset = DatasetService.get_dataset(dataset_id_str)
         if dataset is None:
             raise NotFound("Dataset not found.")
-        DatasetService.check_dataset_permission(dataset, current_user)
+        DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
 
         match action:
             case "enable":
@@ -196,7 +197,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
         dataset = DatasetService.get_dataset(dataset_id_str)
         if dataset is None:
             raise NotFound("Dataset not found.")
-        DatasetService.check_dataset_permission(dataset, current_user)
+        DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
 
         metadata_args = MetadataOperationData.model_validate(service_api_ns.payload or {})
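Every hunk above makes the same move: flask_login's current_user is a LocalProxy, so passing it straight into the newly typed check_dataset_permission(dataset, user: Account | None) fails static checking even though it works at runtime, and typing.cast resolves that at zero runtime cost. A minimal sketch of the pattern, with a stand-in Account rather than Dify's real model:

from typing import cast


class Account:
    """Stand-in for models.account.Account."""


def check_dataset_permission(dataset: object, user: Account | None) -> None:
    if user is None:
        raise PermissionError("User not found.")


# At runtime current_user proxies an Account (or None); cast() only informs
# the type checker and performs no conversion or validation whatsoever.
current_user: object = Account()
check_dataset_permission(object(), cast(Account | None, current_user))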
@@ -10,6 +10,7 @@ import tempfile
 from collections.abc import Generator
 from io import BytesIO
 from pathlib import Path
+from typing import Any
 
 import clickzetta
 from pydantic import BaseModel, model_validator
@@ -39,7 +40,7 @@ class ClickZettaVolumeConfig(BaseModel):
 
     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
         """Validate the configuration values.
 
         This method will first try to use CLICKZETTA_VOLUME_* environment variables,
@@ -110,7 +111,7 @@ class ClickZettaVolumeConfig(BaseModel):
 class ClickZettaVolumeStorage(BaseStorage):
     """ClickZetta Volume storage implementation."""
 
-    def __init__(self, config: ClickZettaVolumeConfig):
+    def __init__(self, config: ClickZettaVolumeConfig) -> None:
         """Initialize ClickZetta Volume storage.
 
         Args:
@@ -124,7 +125,7 @@ class ClickZettaVolumeStorage(BaseStorage):
 
         logger.info("ClickZetta Volume storage initialized with type: %s", config.volume_type)
 
-    def _init_connection(self):
+    def _init_connection(self) -> None:
         """Initialize ClickZetta connection."""
         try:
             self._connection = clickzetta.connect(
@@ -141,7 +142,7 @@ class ClickZettaVolumeStorage(BaseStorage):
             logger.exception("Failed to connect to ClickZetta")
             raise
 
-    def _init_permission_manager(self):
+    def _init_permission_manager(self) -> None:
         """Initialize permission manager."""
         try:
             self._permission_manager = VolumePermissionManager(
@@ -201,7 +202,7 @@ class ClickZettaVolumeStorage(BaseStorage):
         else:
             raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
 
-    def _execute_sql(self, sql: str, fetch: bool = False):
+    def _execute_sql(self, sql: str, fetch: bool = False) -> Any:
         """Execute SQL command."""
         try:
             if self._connection is None:
@@ -215,7 +216,7 @@ class ClickZettaVolumeStorage(BaseStorage):
             logger.exception("SQL execution failed: %s", sql)
             raise
 
-    def _ensure_table_volume_exists(self, dataset_id: str):
+    def _ensure_table_volume_exists(self, dataset_id: str) -> None:
         """Ensure table volume exists for the given dataset_id."""
         if self._config.volume_type != "table" or not dataset_id:
             return
@@ -250,7 +251,7 @@ class ClickZettaVolumeStorage(BaseStorage):
             # Don't raise exception, let the operation continue
             # The table might exist but not be visible due to permissions
 
-    def save(self, filename: str, data: bytes):
+    def save(self, filename: str, data: bytes) -> None:
         """Save data to ClickZetta Volume.
 
         Args:
@@ -381,7 +382,7 @@ class ClickZettaVolumeStorage(BaseStorage):
 
         logger.debug("File %s loaded as stream from ClickZetta Volume", filename)
 
-    def download(self, filename: str, target_filepath: str):
+    def download(self, filename: str, target_filepath: str) -> None:
         """Download file from ClickZetta Volume to local path.
 
         Args:
@@ -435,7 +436,7 @@ class ClickZettaVolumeStorage(BaseStorage):
             logger.warning("Error checking file existence for %s: %s", filename, e)
             return False
 
-    def delete(self, filename: str):
+    def delete(self, filename: str) -> None:
         """Delete file from ClickZetta Volume.
 
         Args:
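The validate_config change tightens a Pydantic v2 mode="before" validator: such a validator receives the raw input mapping and must return the (possibly amended) mapping, which is exactly what dict[str, Any] -> dict[str, Any] expresses. A runnable sketch with illustrative field and variable names, not the real ClickZetta settings:

import os
from typing import Any

from pydantic import BaseModel, model_validator


class VolumeConfig(BaseModel):
    username: str
    password: str

    @model_validator(mode="before")
    @classmethod
    def fill_from_env(cls, values: dict[str, Any]) -> dict[str, Any]:
        # A "before" validator sees raw input and must hand back a mapping;
        # the annotation documents that contract for type checkers.
        values.setdefault("username", os.environ.get("VOLUME_USERNAME", "anon"))
        values.setdefault("password", os.environ.get("VOLUME_PASSWORD", ""))
        return values


config = VolumeConfig.model_validate({})  # fields filled from env or defaults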
@@ -41,7 +41,7 @@ class FileMetadata:
     tags: dict[str, str] | None = None
     parent_version: int | None = None
 
-    def to_dict(self):
+    def to_dict(self) -> dict[str, Any]:
         """Convert to dictionary format"""
         data = asdict(self)
         data["created_at"] = self.created_at.isoformat()
@@ -62,7 +62,7 @@ class FileMetadata:
 class FileLifecycleManager:
     """File lifecycle manager"""
 
-    def __init__(self, storage, dataset_id: str | None = None):
+    def __init__(self, storage: Any, dataset_id: str | None = None) -> None:
         """Initialize lifecycle manager
 
         Args:
@@ -435,7 +435,7 @@ class FileLifecycleManager:
             logger.exception("Failed to get storage statistics")
             return {}
 
-    def _create_version_backup(self, filename: str, metadata: dict):
+    def _create_version_backup(self, filename: str, metadata: dict[str, Any]) -> None:
         """Create version backup"""
         try:
             # Read current file content
@@ -463,7 +463,7 @@ class FileLifecycleManager:
             logger.warning("Failed to load metadata: %s", e)
             return {}
 
-    def _save_metadata(self, metadata_dict: dict):
+    def _save_metadata(self, metadata_dict: dict[str, Any]) -> None:
         """Save metadata file"""
         try:
             metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
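The to_dict(self) -> dict[str, Any] annotation matches what the body actually produces: dataclasses.asdict yields dict[str, Any], and the datetime field is then replaced with an ISO string. A trimmed stand-in for the real FileMetadata:

from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from typing import Any


@dataclass
class FileMetadata:  # trimmed stand-in; the real class carries more fields
    filename: str
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def to_dict(self) -> dict[str, Any]:
        data = asdict(self)
        # datetime is not JSON-serializable; flatten to an ISO-8601 string,
        # which is why the value type must stay Any rather than str.
        data["created_at"] = self.created_at.isoformat()
        return data


print(FileMetadata("report.txt").to_dict())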
@@ -6,6 +6,7 @@ According to ClickZetta's permission model, different Volume types have differen
 
 import logging
 from enum import StrEnum
+from typing import Any
 
 logger = logging.getLogger(__name__)
 
@@ -23,7 +24,9 @@ class VolumePermission(StrEnum):
 class VolumePermissionManager:
     """Volume permission manager"""
 
-    def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: str | None = None):
+    def __init__(
+        self, connection_or_config: Any, volume_type: str | None = None, volume_name: str | None = None
+    ) -> None:
         """Initialize permission manager
 
         Args:
@@ -434,7 +437,7 @@ class VolumePermissionManager:
         self._permission_cache[cache_key] = permissions
         return permissions
 
-    def clear_permission_cache(self):
+    def clear_permission_cache(self) -> None:
         """Clear permission cache"""
         self._permission_cache.clear()
         logger.debug("Permission cache cleared")
@@ -625,7 +628,9 @@ class VolumePermissionError(Exception):
         super().__init__(message)
 
 
-def check_volume_permission(permission_manager: VolumePermissionManager, operation: str, dataset_id: str | None = None):
+def check_volume_permission(
+    permission_manager: VolumePermissionManager, operation: str, dataset_id: str | None = None
+) -> Any:
     """Permission check decorator function
 
     Args:
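check_volume_permission now returns Any, which satisfies the checker but tells callers nothing; since its docstring describes a decorator function, a ParamSpec-based signature would be the fully typed alternative. A sketch of that tighter shape (the permission logic here is a placeholder, not ClickZetta's):

from collections.abc import Callable
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def check_volume_permission(operation: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # placeholder check; the real manager consults Volume grants
            if operation not in {"read", "write"}:
                raise PermissionError(f"operation not permitted: {operation}")
            return func(*args, **kwargs)

        return wrapper

    return decorator


@check_volume_permission("read")
def load(filename: str) -> bytes:
    return b"..."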
@@ -109,7 +109,15 @@ logger = logging.getLogger(__name__)
 
 class DatasetService:
     @staticmethod
-    def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
+    def get_datasets(
+        page: int,
+        per_page: int,
+        tenant_id: str | None = None,
+        user: Account | None = None,
+        search: str | None = None,
+        tag_ids: list[str] | None = None,
+        include_all: bool = False,
+    ) -> tuple[list[Dataset], int]:
         query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc(), Dataset.id)
 
         if user:
@@ -177,10 +185,10 @@ class DatasetService:
 
         datasets = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False)
 
-        return datasets.items, datasets.total
+        return datasets.items, datasets.total or 0
 
     @staticmethod
-    def get_process_rules(dataset_id):
+    def get_process_rules(dataset_id: str) -> dict[str, Any]:
         # get the latest process rule
         dataset_process_rule = (
             db.session.query(DatasetProcessRule)
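The `or 0` fallback is what makes the new tuple[list[Dataset], int] return type honest: Flask-SQLAlchemy's Pagination.total is typed int | None, so it must be narrowed before the method can promise an int. The narrowing in isolation:

from typing import TypeVar

T = TypeVar("T")


def items_and_total(items: list[T], total: int | None) -> tuple[list[T], int]:
    # Pagination.total is Optional; `or 0` narrows it for the checker.
    # A genuine total of 0 maps to 0, so the fallback is value-preserving.
    return items, total or 0


assert items_and_total(["a"], None) == (["a"], 0)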
@@ -198,7 +206,7 @@ class DatasetService:
         return {"mode": mode, "rules": rules}
 
     @staticmethod
-    def get_datasets_by_ids(ids, tenant_id):
+    def get_datasets_by_ids(ids: list[str] | None, tenant_id: str) -> tuple[list[Dataset], int]:
         # Check if ids is not empty to avoid WHERE false condition
         if not ids or len(ids) == 0:
             return [], 0
@@ -206,7 +214,7 @@ class DatasetService:
 
         datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
 
-        return datasets.items, datasets.total
+        return datasets.items, datasets.total or 0
 
     @staticmethod
     def create_empty_dataset(
@@ -223,7 +231,7 @@ class DatasetService:
         embedding_model_name: str | None = None,
         retrieval_model: RetrievalModel | None = None,
         summary_index_setting: dict | None = None,
-    ):
+    ) -> Dataset:
         # check if dataset name already exists
         if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
             raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
@@ -292,7 +300,7 @@ class DatasetService:
     def create_empty_rag_pipeline_dataset(
         tenant_id: str,
         rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
-    ):
+    ) -> Dataset:
         if rag_pipeline_dataset_create_entity.name:
             # check if dataset name already exists
             if (
@@ -338,17 +346,17 @@ class DatasetService:
         return dataset
 
     @staticmethod
-    def get_dataset(dataset_id) -> Dataset | None:
+    def get_dataset(dataset_id: str) -> Dataset | None:
         dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
         return dataset
 
     @staticmethod
-    def check_doc_form(dataset: Dataset, doc_form: str):
+    def check_doc_form(dataset: Dataset, doc_form: str) -> None:
         if dataset.doc_form and doc_form != dataset.doc_form:
             raise ValueError("doc_form is different from the dataset doc_form.")
 
     @staticmethod
-    def check_dataset_model_setting(dataset):
+    def check_dataset_model_setting(dataset: Dataset) -> None:
         if dataset.indexing_technique == "high_quality":
             try:
                 model_manager = ModelManager()
@@ -366,7 +374,7 @@ class DatasetService:
                 raise ValueError(f"The dataset is unavailable, due to: {ex.description}")
 
     @staticmethod
-    def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str):
+    def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str) -> None:
         try:
             model_manager = ModelManager()
             model_manager.get_model_instance(
@@ -383,7 +391,7 @@ class DatasetService:
             raise ValueError(ex.description)
 
     @staticmethod
-    def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str):
+    def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str) -> bool:
         try:
             model_manager = ModelManager()
             model_instance = model_manager.get_model_instance(
@@ -404,7 +412,7 @@ class DatasetService:
             raise ValueError("No Model available. Please configure a valid provider in the Settings -> Model Provider.")
 
     @staticmethod
-    def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
+    def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str) -> None:
         try:
             model_manager = ModelManager()
             model_manager.get_model_instance(
@@ -421,7 +429,7 @@ class DatasetService:
             raise ValueError(ex.description)
 
     @staticmethod
-    def update_dataset(dataset_id, data, user):
+    def update_dataset(dataset_id: str, data: dict[str, Any], user: Account) -> Dataset:
         """
         Update dataset configuration and settings.
 
@@ -460,7 +468,7 @@ class DatasetService:
         return DatasetService._update_internal_dataset(dataset, data, user)
 
     @staticmethod
-    def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str):
+    def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str) -> bool:
         dataset = (
             db.session.query(Dataset)
             .where(
@@ -473,7 +481,7 @@ class DatasetService:
         return dataset is not None
 
     @staticmethod
-    def _update_external_dataset(dataset, data, user):
+    def _update_external_dataset(dataset: Dataset, data: dict[str, Any], user: Account) -> Dataset:
         """
         Update external dataset configuration.
 
@@ -486,12 +494,12 @@ class DatasetService:
             Dataset: Updated dataset object
         """
         # Update retrieval model if provided
-        external_retrieval_model = data.get("external_retrieval_model", None)
+        external_retrieval_model = data.get("external_retrieval_model")
         if external_retrieval_model:
             dataset.retrieval_model = external_retrieval_model
 
         # Update summary index setting if provided
-        summary_index_setting = data.get("summary_index_setting", None)
+        summary_index_setting = data.get("summary_index_setting")
         if summary_index_setting is not None:
             dataset.summary_index_setting = summary_index_setting
 
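The data.get("key", None) edits change nothing at runtime, since dict.get already defaults to None; they only trim redundancy:

data: dict[str, str] = {}
# Both spellings return None for a missing key; the explicit default is noise.
assert data.get("external_retrieval_model") is None
assert data.get("external_retrieval_model", None) is None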
@@ -505,8 +513,8 @@ class DatasetService:
             dataset.permission = permission
 
         # Validate and update external knowledge configuration
-        external_knowledge_id = data.get("external_knowledge_id", None)
-        external_knowledge_api_id = data.get("external_knowledge_api_id", None)
+        external_knowledge_id = data.get("external_knowledge_id")
+        external_knowledge_api_id = data.get("external_knowledge_api_id")
 
         if not external_knowledge_id:
             raise ValueError("External knowledge id is required.")
@@ -526,7 +534,9 @@ class DatasetService:
         return dataset
 
     @staticmethod
-    def _update_external_knowledge_binding(dataset_id, external_knowledge_id, external_knowledge_api_id):
+    def _update_external_knowledge_binding(
+        dataset_id: str, external_knowledge_id: str, external_knowledge_api_id: str
+    ) -> None:
         """
         Update external knowledge binding configuration.
 
@@ -553,7 +563,7 @@ class DatasetService:
         db.session.add(external_knowledge_binding)
 
     @staticmethod
-    def _update_internal_dataset(dataset, data, user):
+    def _update_internal_dataset(dataset: Dataset, data: dict[str, Any], user: Account) -> Dataset:
         """
         Update internal dataset configuration.
 
@@ -591,7 +601,7 @@ class DatasetService:
             filtered_data["icon_info"] = data.get("icon_info")
 
         # Update dataset in database
-        db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
+        db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)  # type: ignore[arg-type]
        db.session.commit()
 
         # Reload dataset to get updated values
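The lone suppression here, # type: ignore[arg-type], covers legacy Query.update, whose stubs expect column-keyed mappings while the code passes a plain dict[str, Any]. A sketch of a 2.0-style alternative that type-checks without the ignore, using a trimmed stand-in model rather than Dify's real Dataset:

from typing import Any

from sqlalchemy import String, create_engine, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Dataset(Base):  # trimmed stand-in for models.dataset.Dataset
    __tablename__ = "datasets"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    name: Mapped[str] = mapped_column(String, default="")


def apply_update(session: Session, dataset_id: str, filtered_data: dict[str, Any]) -> None:
    # update().values(**mapping) checks cleanly, unlike Query.update(mapping).
    session.execute(update(Dataset).where(Dataset.id == dataset_id).values(**filtered_data))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    apply_update(session, "ds-1", {"name": "renamed"})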
@@ -619,7 +629,7 @@ class DatasetService:
         return dataset
 
     @staticmethod
-    def _update_pipeline_knowledge_base_node_data(dataset: Dataset, updata_user_id: str):
+    def _update_pipeline_knowledge_base_node_data(dataset: Dataset, updata_user_id: str) -> None:
         """
         Update pipeline knowledge base node data.
         """
@@ -702,7 +712,9 @@ class DatasetService:
             raise
 
     @staticmethod
-    def _handle_indexing_technique_change(dataset, data, filtered_data):
+    def _handle_indexing_technique_change(
+        dataset: Dataset, data: dict[str, Any], filtered_data: dict[str, Any]
+    ) -> str | None:
         """
         Handle changes in indexing technique and configure embedding models accordingly.
 
@@ -733,7 +745,7 @@ class DatasetService:
         return None
 
     @staticmethod
-    def _configure_embedding_model_for_high_quality(data, filtered_data):
+    def _configure_embedding_model_for_high_quality(data: dict[str, Any], filtered_data: dict[str, Any]) -> None:
         """
         Configure embedding model settings for high quality indexing.
 
@@ -768,7 +780,9 @@ class DatasetService:
             raise ValueError(ex.description)
 
     @staticmethod
-    def _handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data):
+    def _handle_embedding_model_update_when_technique_unchanged(
+        dataset: Dataset, data: dict[str, Any], filtered_data: dict[str, Any]
+    ) -> str | None:
         """
         Handle embedding model updates when indexing technique remains the same.
 
@@ -793,7 +807,7 @@ class DatasetService:
         return DatasetService._update_embedding_model_settings(dataset, data, filtered_data)
 
     @staticmethod
-    def _preserve_existing_embedding_settings(dataset, filtered_data):
+    def _preserve_existing_embedding_settings(dataset: Dataset, filtered_data: dict[str, Any]) -> None:
         """
         Preserve existing embedding model settings when not provided in update.
 
@@ -816,7 +830,9 @@ class DatasetService:
             del filtered_data["embedding_model"]
 
     @staticmethod
-    def _update_embedding_model_settings(dataset, data, filtered_data):
+    def _update_embedding_model_settings(
+        dataset: Dataset, data: dict[str, Any], filtered_data: dict[str, Any]
+    ) -> str | None:
         """
         Update embedding model settings with new values.
 
@@ -850,7 +866,7 @@ class DatasetService:
         return None
 
     @staticmethod
-    def _apply_new_embedding_settings(dataset, data, filtered_data):
+    def _apply_new_embedding_settings(dataset: Dataset, data: dict[str, Any], filtered_data: dict[str, Any]) -> None:
         """
         Apply new embedding model settings to the dataset.
 
@@ -947,7 +963,7 @@ class DatasetService:
     @staticmethod
     def update_rag_pipeline_dataset_settings(
         session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
-    ):
+    ) -> None:
         if not current_user or not current_user.current_tenant_id:
             raise ValueError("Current user or current tenant not found")
         dataset = session.merge(dataset)
@@ -1102,7 +1118,7 @@ class DatasetService:
         deal_dataset_index_update_task.delay(dataset.id, action)
 
     @staticmethod
-    def delete_dataset(dataset_id, user):
+    def delete_dataset(dataset_id: str, user: Account) -> bool:
         dataset = DatasetService.get_dataset(dataset_id)
 
         if dataset is None:
@@ -1117,12 +1133,14 @@ class DatasetService:
         return True
 
     @staticmethod
-    def dataset_use_check(dataset_id) -> bool:
+    def dataset_use_check(dataset_id: str) -> bool:
         stmt = select(exists().where(AppDatasetJoin.dataset_id == dataset_id))
         return db.session.execute(stmt).scalar_one()
 
     @staticmethod
-    def check_dataset_permission(dataset, user):
+    def check_dataset_permission(dataset: Dataset, user: Account | None) -> None:
+        if not user:
+            raise NoPermissionError("User not found.")
         if dataset.tenant_id != user.current_tenant_id:
             logger.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
             raise NoPermissionError("You do not have permission to access this dataset.")
@@ -1141,7 +1159,7 @@ class DatasetService:
             raise NoPermissionError("You do not have permission to access this dataset.")
 
     @staticmethod
-    def check_dataset_operator_permission(user: Account | None = None, dataset: Dataset | None = None):
+    def check_dataset_operator_permission(user: Account | None = None, dataset: Dataset | None = None) -> None:
         if not dataset:
             raise ValueError("Dataset not found")
 
@@ -1161,15 +1179,15 @@ class DatasetService:
             raise NoPermissionError("You do not have permission to access this dataset.")
 
     @staticmethod
-    def get_dataset_queries(dataset_id: str, page: int, per_page: int):
+    def get_dataset_queries(dataset_id: str, page: int, per_page: int) -> tuple[list[DatasetQuery], int]:
         stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at))
 
         dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)
 
-        return dataset_queries.items, dataset_queries.total
+        return dataset_queries.items, dataset_queries.total or 0
 
     @staticmethod
-    def get_related_apps(dataset_id: str):
+    def get_related_apps(dataset_id: str) -> list[AppDatasetJoin]:
         return (
             db.session.query(AppDatasetJoin)
             .where(AppDatasetJoin.dataset_id == dataset_id)
@@ -1178,7 +1196,7 @@ class DatasetService:
         )
 
     @staticmethod
-    def update_dataset_api_status(dataset_id: str, status: bool):
+    def update_dataset_api_status(dataset_id: str, status: bool) -> None:
         dataset = DatasetService.get_dataset(dataset_id)
         if dataset is None:
             raise NotFound("Dataset not found.")
@@ -1190,7 +1208,7 @@ class DatasetService:
         db.session.commit()
 
     @staticmethod
-    def get_dataset_auto_disable_logs(dataset_id: str):
+    def get_dataset_auto_disable_logs(dataset_id: str) -> dict[str, Any]:
         assert isinstance(current_user, Account)
         assert current_user.current_tenant_id is not None
         features = FeatureService.get_features(current_user.current_tenant_id)
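dataset_use_check shows why the query itself needed no change: select(exists().where(...)) always yields exactly one boolean row, so scalar_one() is safe and the -> bool annotation is precise. Runnable against an in-memory database, with a trimmed stand-in model:

from sqlalchemy import String, create_engine, exists, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class AppDatasetJoin(Base):  # trimmed stand-in for the real join model
    __tablename__ = "app_dataset_joins"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    dataset_id: Mapped[str] = mapped_column(String)


def dataset_use_check(session: Session, dataset_id: str) -> bool:
    # EXISTS collapses to a single row, so scalar_one() cannot raise NoResultFound.
    stmt = select(exists().where(AppDatasetJoin.dataset_id == dataset_id))
    return session.execute(stmt).scalar_one()


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    assert dataset_use_check(session, "unused-id") is False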
@@ -1289,7 +1307,7 @@ class DocumentService:
         return cls.DISPLAY_STATUS_FILTERS[normalized]
 
     @classmethod
-    def apply_display_status_filter(cls, query, status: str | None):
+    def apply_display_status_filter(cls, query: Any, status: str | None) -> Any:
         filters = cls.build_display_status_filters(status)
         if not filters:
             return query
@@ -1685,19 +1703,19 @@ class DocumentService:
         return documents
 
     @staticmethod
-    def get_document_file_detail(file_id: str):
+    def get_document_file_detail(file_id: str) -> UploadFile | None:
         file_detail = db.session.query(UploadFile).where(UploadFile.id == file_id).one_or_none()
         return file_detail
 
     @staticmethod
-    def check_archived(document):
+    def check_archived(document: Document) -> bool:
         if document.archived:
             return True
         else:
             return False
 
     @staticmethod
-    def delete_document(document):
+    def delete_document(document: Document) -> None:
         # trigger document_was_deleted signal
         file_id = None
         if document.data_source_type == DataSourceType.UPLOAD_FILE:
@@ -1713,7 +1731,7 @@ class DocumentService:
         db.session.commit()
 
     @staticmethod
-    def delete_documents(dataset: Dataset, document_ids: list[str]):
+    def delete_documents(dataset: Dataset, document_ids: list[str]) -> None:
         # Check if document_ids is not empty to avoid WHERE false condition
         if not document_ids or len(document_ids) == 0:
             return
@@ -1769,7 +1787,7 @@ class DocumentService:
         return document
 
     @staticmethod
-    def pause_document(document):
+    def pause_document(document: Document) -> None:
         if document.indexing_status not in {
             IndexingStatus.WAITING,
             IndexingStatus.PARSING,
@@ -1791,7 +1809,7 @@ class DocumentService:
         redis_client.setnx(indexing_cache_key, "True")
 
     @staticmethod
-    def recover_document(document):
+    def recover_document(document: Document) -> None:
         if not document.is_paused:
             raise DocumentIndexingError()
         # update document to be recover
@@ -1808,7 +1826,7 @@ class DocumentService:
         recover_document_indexing_task.delay(document.dataset_id, document.id)
 
     @staticmethod
-    def retry_document(dataset_id: str, documents: list[Document]):
+    def retry_document(dataset_id: str, documents: list[Document]) -> None:
         for document in documents:
             # add retry flag
             retry_indexing_cache_key = f"document_{document.id}_is_retried"
@@ -1828,7 +1846,7 @@ class DocumentService:
         retry_document_indexing_task.delay(dataset_id, document_ids, current_user.id)
 
     @staticmethod
-    def sync_website_document(dataset_id: str, document: Document):
+    def sync_website_document(dataset_id: str, document: Document) -> None:
         # add sync flag
         sync_indexing_cache_key = f"document_{document.id}_is_sync"
         cache_result = redis_client.get(sync_indexing_cache_key)
@@ -1848,7 +1866,7 @@ class DocumentService:
         sync_website_document_indexing_task.delay(dataset_id, document.id)
 
     @staticmethod
-    def get_documents_position(dataset_id):
+    def get_documents_position(dataset_id: str) -> int:
         document = (
             db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
         )
@@ -97,8 +97,8 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
         raise self.retry(exc=e, countdown=60)  # Retry after 60 seconds
 
 
-def _delete_app_model_configs(tenant_id: str, app_id: str):
-    def del_model_config(session, model_config_id: str):
+def _delete_app_model_configs(tenant_id: str, app_id: str) -> None:
+    def del_model_config(session: Any, model_config_id: str) -> None:
         session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -109,8 +109,8 @@ def _delete_app_model_configs(tenant_id: str, app_id: str):
     )
 
 
-def _delete_app_site(tenant_id: str, app_id: str):
-    def del_site(session, site_id: str):
+def _delete_app_site(tenant_id: str, app_id: str) -> None:
+    def del_site(session: Any, site_id: str) -> None:
         session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -121,8 +121,8 @@ def _delete_app_site(tenant_id: str, app_id: str):
     )
 
 
-def _delete_app_mcp_servers(tenant_id: str, app_id: str):
-    def del_mcp_server(session, mcp_server_id: str):
+def _delete_app_mcp_servers(tenant_id: str, app_id: str) -> None:
+    def del_mcp_server(session: Any, mcp_server_id: str) -> None:
         session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -133,8 +133,8 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
     )
 
 
-def _delete_app_api_tokens(tenant_id: str, app_id: str):
-    def del_api_token(session, api_token_id: str):
+def _delete_app_api_tokens(tenant_id: str, app_id: str) -> None:
+    def del_api_token(session: Any, api_token_id: str) -> None:
         # Fetch token details for cache invalidation
         token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first()
         if token_obj:
@@ -151,8 +151,8 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str):
     )
 
 
-def _delete_installed_apps(tenant_id: str, app_id: str):
-    def del_installed_app(session, installed_app_id: str):
+def _delete_installed_apps(tenant_id: str, app_id: str) -> None:
+    def del_installed_app(session: Any, installed_app_id: str) -> None:
         session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -163,8 +163,8 @@ def _delete_installed_apps(tenant_id: str, app_id: str):
     )
 
 
-def _delete_recommended_apps(tenant_id: str, app_id: str):
-    def del_recommended_app(session, recommended_app_id: str):
+def _delete_recommended_apps(tenant_id: str, app_id: str) -> None:
+    def del_recommended_app(session: Any, recommended_app_id: str) -> None:
         session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -175,8 +175,8 @@ def _delete_recommended_apps(tenant_id: str, app_id: str):
     )
 
 
-def _delete_app_annotation_data(tenant_id: str, app_id: str):
-    def del_annotation_hit_history(session, annotation_hit_history_id: str):
+def _delete_app_annotation_data(tenant_id: str, app_id: str) -> None:
+    def del_annotation_hit_history(session: Any, annotation_hit_history_id: str) -> None:
         session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
             synchronize_session=False
         )
@@ -188,7 +188,7 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
         "annotation hit history",
     )
 
-    def del_annotation_setting(session, annotation_setting_id: str):
+    def del_annotation_setting(session: Any, annotation_setting_id: str) -> None:
         session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
             synchronize_session=False
         )
@@ -201,8 +201,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
     )
 
 
-def _delete_app_dataset_joins(tenant_id: str, app_id: str):
-    def del_dataset_join(session, dataset_join_id: str):
+def _delete_app_dataset_joins(tenant_id: str, app_id: str) -> None:
+    def del_dataset_join(session: Any, dataset_join_id: str) -> None:
         session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -213,8 +213,8 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str):
     )
 
 
-def _delete_app_workflows(tenant_id: str, app_id: str):
-    def del_workflow(session, workflow_id: str):
+def _delete_app_workflows(tenant_id: str, app_id: str) -> None:
+    def del_workflow(session: Any, workflow_id: str) -> None:
         session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -225,7 +225,7 @@ def _delete_app_workflows(tenant_id: str, app_id: str):
     )
 
 
-def _delete_app_workflow_runs(tenant_id: str, app_id: str):
+def _delete_app_workflow_runs(tenant_id: str, app_id: str) -> None:
     """Delete all workflow runs for an app using the service repository."""
     session_maker = sessionmaker(bind=db.engine)
     workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@@ -239,7 +239,7 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
     logger.info("Deleted %s workflow runs for app %s", deleted_count, app_id)
 
 
-def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
+def _delete_app_workflow_node_executions(tenant_id: str, app_id: str) -> None:
     """Delete all workflow node executions for an app using the service repository."""
     session_maker = sessionmaker(bind=db.engine)
    node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
@@ -253,8 +253,8 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
     logger.info("Deleted %s workflow node executions for app %s", deleted_count, app_id)
 
 
-def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
-    def del_workflow_app_log(session, workflow_app_log_id: str):
+def _delete_app_workflow_app_logs(tenant_id: str, app_id: str) -> None:
+    def del_workflow_app_log(session: Any, workflow_app_log_id: str) -> None:
         session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -265,8 +265,8 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
     )
 
 
-def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
-    def del_workflow_archive_log(session, workflow_archive_log_id: str):
+def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str) -> None:
+    def del_workflow_archive_log(session: Any, workflow_archive_log_id: str) -> None:
         session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
             synchronize_session=False
         )
@@ -279,7 +279,7 @@ def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
     )
 
 
-def _delete_archived_workflow_run_files(tenant_id: str, app_id: str):
+def _delete_archived_workflow_run_files(tenant_id: str, app_id: str) -> None:
     prefix = f"{tenant_id}/app_id={app_id}/"
     try:
         archive_storage = get_archive_storage()
@@ -304,8 +304,8 @@ def _delete_archived_workflow_run_files(tenant_id: str, app_id: str):
     logger.info("Deleted %s archive objects for app %s", deleted, app_id)
 
 
-def _delete_app_conversations(tenant_id: str, app_id: str):
-    def del_conversation(session, conversation_id: str):
+def _delete_app_conversations(tenant_id: str, app_id: str) -> None:
+    def del_conversation(session: Any, conversation_id: str) -> None:
         session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
             synchronize_session=False
         )
@@ -319,7 +319,7 @@ def _delete_app_conversations(tenant_id: str, app_id: str):
     )
 
 
-def _delete_conversation_variables(*, app_id: str):
+def _delete_conversation_variables(app_id: str) -> None:
     with session_factory.create_session() as session:
         stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
         session.execute(stmt)
@@ -327,8 +327,8 @@ def _delete_conversation_variables(*, app_id: str):
     logger.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green"))
 
 
-def _delete_app_messages(tenant_id: str, app_id: str):
-    def del_message(session, message_id: str):
+def _delete_app_messages(tenant_id: str, app_id: str) -> None:
+    def del_message(session: Any, message_id: str) -> None:
         session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False)
         session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
             synchronize_session=False
@@ -349,8 +349,8 @@ def _delete_app_messages(tenant_id: str, app_id: str):
     )
 
 
-def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
-    def del_tool_provider(session, tool_provider_id: str):
+def _delete_workflow_tool_providers(tenant_id: str, app_id: str) -> None:
+    def del_tool_provider(session: Any, tool_provider_id: str) -> None:
         session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
             synchronize_session=False
         )
@@ -363,8 +363,8 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
     )
 
 
-def _delete_app_tag_bindings(tenant_id: str, app_id: str):
-    def del_tag_binding(session, tag_binding_id: str):
+def _delete_app_tag_bindings(tenant_id: str, app_id: str) -> None:
+    def del_tag_binding(session: Any, tag_binding_id: str) -> None:
         session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -375,8 +375,8 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str):
     )
 
 
-def _delete_end_users(tenant_id: str, app_id: str):
-    def del_end_user(session, end_user_id: str):
+def _delete_end_users(tenant_id: str, app_id: str) -> None:
+    def del_end_user(session: Any, end_user_id: str) -> None:
         session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -387,8 +387,8 @@ def _delete_end_users(tenant_id: str, app_id: str):
     )
 
 
-def _delete_trace_app_configs(tenant_id: str, app_id: str):
-    def del_trace_app_config(session, trace_app_config_id: str):
+def _delete_trace_app_configs(tenant_id: str, app_id: str) -> None:
+    def del_trace_app_config(session: Any, trace_app_config_id: str) -> None:
         session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -399,9 +399,9 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str):
     )
 
 
-def _delete_draft_variables(app_id: str):
+def _delete_draft_variables(app_id: str) -> None:
     """Delete all workflow draft variables for an app in batches."""
-    return delete_draft_variables_batch(app_id, batch_size=1000)
+    delete_draft_variables_batch(app_id, batch_size=1000)
 
 
 def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
@@ -543,8 +543,8 @@ def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int:
     return files_deleted
 
 
-def _delete_app_triggers(tenant_id: str, app_id: str):
-    def del_app_trigger(session, trigger_id: str):
+def _delete_app_triggers(tenant_id: str, app_id: str) -> None:
+    def del_app_trigger(session: Any, trigger_id: str) -> None:
         session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -555,8 +555,8 @@ def _delete_app_triggers(tenant_id: str, app_id: str):
     )
 
 
-def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
-    def del_plugin_trigger(session, trigger_id: str):
+def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str) -> None:
+    def del_plugin_trigger(session: Any, trigger_id: str) -> None:
         session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
             synchronize_session=False
         )
@@ -569,8 +569,8 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
     )
 
 
-def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
-    def del_webhook_trigger(session, trigger_id: str):
+def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str) -> None:
+    def del_webhook_trigger(session: Any, trigger_id: str) -> None:
         session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
             synchronize_session=False
         )
@@ -583,8 +583,8 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
     )
 
 
-def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
-    def del_schedule_plan(session, plan_id: str):
+def _delete_workflow_schedule_plans(tenant_id: str, app_id: str) -> None:
+    def del_schedule_plan(session: Any, plan_id: str) -> None:
         session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -595,8 +595,8 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
     )
 
 
-def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
-    def del_trigger_log(session, log_id: str):
+def _delete_workflow_trigger_logs(tenant_id: str, app_id: str) -> None:
+    def del_trigger_log(session: Any, log_id: str) -> None:
         session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
 
     _delete_records(
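Every _delete_app_* helper in this task shares one shape: a module-level function defines a per-row deleter and hands it to _delete_records, and the nested deleters are typed session: Any because the driver supplies the session. The driver signature below is inferred from the call sites in this diff rather than copied from the source, and the bodies are stand-ins:

from collections.abc import Callable
from typing import Any


def _delete_records(
    query_sql: str,
    params: dict[str, Any],
    del_fn: Callable[[Any, str], None],
    label: str,
) -> None:
    """Assumed driver: batch-fetch ids matching query_sql, call del_fn per id."""
    session = object()  # stand-in; the real helper opens a DB session per batch
    for record_id in ("id-1", "id-2"):  # stand-in for fetched rows
        del_fn(session, record_id)


def _delete_app_sites(tenant_id: str, app_id: str) -> None:
    def del_site(session: Any, site_id: str) -> None:
        print(f"would delete site {site_id}")  # real code issues a bulk DELETE

    _delete_records(
        "SELECT id FROM sites WHERE tenant_id = :tenant_id AND app_id = :app_id",
        {"tenant_id": tenant_id, "app_id": app_id},
        del_site,
        "site",
    )


_delete_app_sites("tenant-1", "app-1")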
@@ -28,12 +28,11 @@ class TestDeleteDraftVariablesBatch:
     def test_delete_draft_variables_calls_batch_function(self, mock_batch_delete):
         """Test that _delete_draft_variables calls the batch function correctly."""
         app_id = "test-app-id"
-        expected_return = 42
-        mock_batch_delete.return_value = expected_return
+        mock_batch_delete.return_value = 42
 
         result = _delete_draft_variables(app_id)
 
-        assert result == expected_return
+        assert result is None
         mock_batch_delete.assert_called_once_with(app_id, batch_size=1000)
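The test update follows directly from the signature change above: _delete_draft_variables now discards the batch count and returns None, so the old equality assertion had to become an identity check against None. A self-contained rendering of the same assertion (module layout here is illustrative):

from unittest import mock


def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
    return 0  # stand-in for the real batched delete


def _delete_draft_variables(app_id: str) -> None:
    # The count is deliberately discarded now, hence the -> None contract.
    delete_draft_variables_batch(app_id, batch_size=1000)


with mock.patch(f"{__name__}.delete_draft_variables_batch", return_value=42) as m:
    assert _delete_draft_variables("test-app-id") is None
    m.assert_called_once_with("test-app-id", batch_size=1000)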