fix: missing types

[autofix.ci] apply automated fixes

fix/missing import types

fix/type checking error

fix: api test
ckstck 2026-03-23 07:54:26 -07:00
parent f5cc1c8b75
commit 75273b4be5
7 changed files with 150 additions and 126 deletions

View File

@@ -1,4 +1,4 @@
-from typing import Literal
+from typing import Literal, cast
 from flask_login import current_user
 from flask_restx import marshal
@@ -9,6 +9,7 @@ from controllers.common.schema import register_schema_model, register_schema_mod
 from controllers.service_api import service_api_ns
 from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
 from fields.dataset_fields import dataset_metadata_fields
+from models.account import Account
 from services.dataset_service import DatasetService
 from services.entities.knowledge_entities.knowledge_entities import (
     DocumentMetadataOperation,
@@ -55,7 +56,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
         dataset = DatasetService.get_dataset(dataset_id_str)
         if dataset is None:
             raise NotFound("Dataset not found.")
-        DatasetService.check_dataset_permission(dataset, current_user)
+        DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
         metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
         return marshal(metadata, dataset_metadata_fields), 201
@@ -102,7 +103,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
         dataset = DatasetService.get_dataset(dataset_id_str)
         if dataset is None:
             raise NotFound("Dataset not found.")
-        DatasetService.check_dataset_permission(dataset, current_user)
+        DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
         metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
         return marshal(metadata, dataset_metadata_fields), 200
@@ -125,7 +126,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
         dataset = DatasetService.get_dataset(dataset_id_str)
         if dataset is None:
             raise NotFound("Dataset not found.")
-        DatasetService.check_dataset_permission(dataset, current_user)
+        DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
         MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
         return "", 204
@@ -166,7 +167,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
         dataset = DatasetService.get_dataset(dataset_id_str)
         if dataset is None:
             raise NotFound("Dataset not found.")
-        DatasetService.check_dataset_permission(dataset, current_user)
+        DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
         match action:
             case "enable":
@@ -196,7 +197,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
         dataset = DatasetService.get_dataset(dataset_id_str)
         if dataset is None:
             raise NotFound("Dataset not found.")
-        DatasetService.check_dataset_permission(dataset, current_user)
+        DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
         metadata_args = MetadataOperationData.model_validate(service_api_ns.payload or {})
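Note on the repeated cast above: flask_login's current_user is a LocalProxy, which type checkers cannot match against the Account | None parameter that check_dataset_permission now declares. typing.cast changes only the static type; nothing is converted or checked at runtime, so the "if not user" guard added to check_dataset_permission later in this commit still does the real validation. A minimal sketch, assuming this repository's models.account.Account:

    from typing import cast

    from flask_login import current_user

    from models.account import Account

    def resolved_account() -> Account | None:
        # no-op at runtime; it only tells the checker what to assume
        return cast(Account | None, current_user)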

View File

@@ -10,6 +10,7 @@ import tempfile
 from collections.abc import Generator
 from io import BytesIO
 from pathlib import Path
+from typing import Any
 import clickzetta
 from pydantic import BaseModel, model_validator
@@ -39,7 +40,7 @@ class ClickZettaVolumeConfig(BaseModel):
     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
         """Validate the configuration values.
         This method will first try to use CLICKZETTA_VOLUME_* environment variables,
@@ -110,7 +111,7 @@ class ClickZettaVolumeConfig(BaseModel):
 class ClickZettaVolumeStorage(BaseStorage):
     """ClickZetta Volume storage implementation."""
-    def __init__(self, config: ClickZettaVolumeConfig):
+    def __init__(self, config: ClickZettaVolumeConfig) -> None:
         """Initialize ClickZetta Volume storage.
         Args:
@@ -124,7 +125,7 @@ class ClickZettaVolumeStorage(BaseStorage):
         logger.info("ClickZetta Volume storage initialized with type: %s", config.volume_type)
-    def _init_connection(self):
+    def _init_connection(self) -> None:
         """Initialize ClickZetta connection."""
         try:
             self._connection = clickzetta.connect(
@@ -141,7 +142,7 @@ class ClickZettaVolumeStorage(BaseStorage):
             logger.exception("Failed to connect to ClickZetta")
             raise
-    def _init_permission_manager(self):
+    def _init_permission_manager(self) -> None:
         """Initialize permission manager."""
         try:
             self._permission_manager = VolumePermissionManager(
@@ -201,7 +202,7 @@ class ClickZettaVolumeStorage(BaseStorage):
         else:
             raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
-    def _execute_sql(self, sql: str, fetch: bool = False):
+    def _execute_sql(self, sql: str, fetch: bool = False) -> Any:
         """Execute SQL command."""
         try:
             if self._connection is None:
@@ -215,7 +216,7 @@ class ClickZettaVolumeStorage(BaseStorage):
             logger.exception("SQL execution failed: %s", sql)
             raise
-    def _ensure_table_volume_exists(self, dataset_id: str):
+    def _ensure_table_volume_exists(self, dataset_id: str) -> None:
         """Ensure table volume exists for the given dataset_id."""
         if self._config.volume_type != "table" or not dataset_id:
             return
@@ -250,7 +251,7 @@ class ClickZettaVolumeStorage(BaseStorage):
         # Don't raise exception, let the operation continue
         # The table might exist but not be visible due to permissions
-    def save(self, filename: str, data: bytes):
+    def save(self, filename: str, data: bytes) -> None:
         """Save data to ClickZetta Volume.
         Args:
@@ -381,7 +382,7 @@ class ClickZettaVolumeStorage(BaseStorage):
         logger.debug("File %s loaded as stream from ClickZetta Volume", filename)
-    def download(self, filename: str, target_filepath: str):
+    def download(self, filename: str, target_filepath: str) -> None:
         """Download file from ClickZetta Volume to local path.
         Args:
@@ -435,7 +436,7 @@ class ClickZettaVolumeStorage(BaseStorage):
             logger.warning("Error checking file existence for %s: %s", filename, e)
             return False
-    def delete(self, filename: str):
+    def delete(self, filename: str) -> None:
         """Delete file from ClickZetta Volume.
         Args:
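Typing the mode="before" validator as dict[str, Any] -> dict[str, Any] fits Pydantic v2's contract: a before-validator receives the raw input before field coercion and must return the (possibly amended) mapping. A minimal sketch, with an invented field standing in for the real ClickZetta settings:

    from typing import Any

    from pydantic import BaseModel, model_validator

    class VolumeConfigSketch(BaseModel):
        username: str = ""

        @model_validator(mode="before")
        @classmethod
        def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
            # fill a default before field validation runs, as the real
            # validate_config does with CLICKZETTA_VOLUME_* env vars
            values.setdefault("username", "anonymous")
            return values

    print(VolumeConfigSketch.model_validate({}).username)  # anonymous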

View File

@@ -41,7 +41,7 @@ class FileMetadata:
     tags: dict[str, str] | None = None
     parent_version: int | None = None
-    def to_dict(self):
+    def to_dict(self) -> dict[str, Any]:
         """Convert to dictionary format"""
         data = asdict(self)
         data["created_at"] = self.created_at.isoformat()
@@ -62,7 +62,7 @@ class FileMetadata:
 class FileLifecycleManager:
     """File lifecycle manager"""
-    def __init__(self, storage, dataset_id: str | None = None):
+    def __init__(self, storage: Any, dataset_id: str | None = None) -> None:
         """Initialize lifecycle manager
         Args:
@@ -435,7 +435,7 @@ class FileLifecycleManager:
            logger.exception("Failed to get storage statistics")
            return {}
-    def _create_version_backup(self, filename: str, metadata: dict):
+    def _create_version_backup(self, filename: str, metadata: dict[str, Any]) -> None:
         """Create version backup"""
         try:
             # Read current file content
@@ -463,7 +463,7 @@ class FileLifecycleManager:
             logger.warning("Failed to load metadata: %s", e)
             return {}
-    def _save_metadata(self, metadata_dict: dict):
+    def _save_metadata(self, metadata_dict: dict[str, Any]) -> None:
         """Save metadata file"""
         try:
             metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
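The to_dict(self) -> dict[str, Any] annotation matches what the method already does: asdict() flattens the dataclass, and the datetime field is rewritten as an ISO string so the result is JSON-serializable. A self-contained sketch with illustrative fields:

    from dataclasses import asdict, dataclass, field
    from datetime import datetime
    from typing import Any

    @dataclass
    class ExampleMetadata:
        name: str
        created_at: datetime = field(default_factory=datetime.now)

        def to_dict(self) -> dict[str, Any]:
            data = asdict(self)
            data["created_at"] = self.created_at.isoformat()  # datetime -> str
            return data

    print(ExampleMetadata("report.pdf").to_dict())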

View File

@@ -6,6 +6,7 @@ According to ClickZetta's permission model, different Volume types have differen
 import logging
 from enum import StrEnum
+from typing import Any
 logger = logging.getLogger(__name__)
@@ -23,7 +24,9 @@ class VolumePermission(StrEnum):
 class VolumePermissionManager:
     """Volume permission manager"""
-    def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: str | None = None):
+    def __init__(
+        self, connection_or_config: Any, volume_type: str | None = None, volume_name: str | None = None
+    ) -> None:
         """Initialize permission manager
         Args:
@@ -434,7 +437,7 @@ class VolumePermissionManager:
         self._permission_cache[cache_key] = permissions
         return permissions
-    def clear_permission_cache(self):
+    def clear_permission_cache(self) -> None:
         """Clear permission cache"""
         self._permission_cache.clear()
         logger.debug("Permission cache cleared")
@@ -625,7 +628,9 @@ class VolumePermissionError(Exception):
         super().__init__(message)
-def check_volume_permission(permission_manager: VolumePermissionManager, operation: str, dataset_id: str | None = None):
+def check_volume_permission(
+    permission_manager: VolumePermissionManager, operation: str, dataset_id: str | None = None
+) -> Any:
     """Permission check decorator function
     Args:
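check_volume_permission gains an explicit -> Any return, the lightest way to satisfy the checker for a decorator-style helper. A stricter signature (an assumption, not what this commit does) would spell out the decorator shape with Callable and a TypeVar:

    from collections.abc import Callable
    from typing import TypeVar

    F = TypeVar("F", bound=Callable[..., object])

    def check_volume_permission_typed(
        permission_manager: object, operation: str, dataset_id: str | None = None
    ) -> Callable[[F], F]:
        def decorator(func: F) -> F:
            # the real helper would verify permissions before delegating
            return func

        return decorator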

View File

@@ -108,7 +108,15 @@ logger = logging.getLogger(__name__)
 class DatasetService:
     @staticmethod
-    def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
+    def get_datasets(
+        page: int,
+        per_page: int,
+        tenant_id: str | None = None,
+        user: Account | None = None,
+        search: str | None = None,
+        tag_ids: list[str] | None = None,
+        include_all: bool = False,
+    ) -> tuple[list[Dataset], int]:
         query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc(), Dataset.id)
         if user:
@@ -176,10 +184,10 @@ class DatasetService:
         datasets = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False)
-        return datasets.items, datasets.total
+        return datasets.items, datasets.total or 0
     @staticmethod
-    def get_process_rules(dataset_id):
+    def get_process_rules(dataset_id: str) -> dict[str, Any]:
         # get the latest process rule
         dataset_process_rule = (
             db.session.query(DatasetProcessRule)
@@ -197,7 +205,7 @@ class DatasetService:
         return {"mode": mode, "rules": rules}
     @staticmethod
-    def get_datasets_by_ids(ids, tenant_id):
+    def get_datasets_by_ids(ids: list[str] | None, tenant_id: str) -> tuple[list[Dataset], int]:
         # Check if ids is not empty to avoid WHERE false condition
         if not ids or len(ids) == 0:
             return [], 0
@@ -205,7 +213,7 @@ class DatasetService:
         datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
-        return datasets.items, datasets.total
+        return datasets.items, datasets.total or 0
     @staticmethod
     def create_empty_dataset(
@@ -222,7 +230,7 @@ class DatasetService:
         embedding_model_name: str | None = None,
         retrieval_model: RetrievalModel | None = None,
         summary_index_setting: dict | None = None,
-    ):
+    ) -> Dataset:
         # check if dataset name already exists
         if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
             raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
@@ -291,7 +299,7 @@ class DatasetService:
     def create_empty_rag_pipeline_dataset(
         tenant_id: str,
         rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
-    ):
+    ) -> Dataset:
         if rag_pipeline_dataset_create_entity.name:
             # check if dataset name already exists
             if (
@@ -337,17 +345,17 @@ class DatasetService:
         return dataset
     @staticmethod
-    def get_dataset(dataset_id) -> Dataset | None:
+    def get_dataset(dataset_id: str) -> Dataset | None:
         dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
         return dataset
     @staticmethod
-    def check_doc_form(dataset: Dataset, doc_form: str):
+    def check_doc_form(dataset: Dataset, doc_form: str) -> None:
         if dataset.doc_form and doc_form != dataset.doc_form:
             raise ValueError("doc_form is different from the dataset doc_form.")
     @staticmethod
-    def check_dataset_model_setting(dataset):
+    def check_dataset_model_setting(dataset: Dataset) -> None:
         if dataset.indexing_technique == "high_quality":
             try:
                 model_manager = ModelManager()
@@ -365,7 +373,7 @@ class DatasetService:
             raise ValueError(f"The dataset is unavailable, due to: {ex.description}")
     @staticmethod
-    def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str):
+    def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str) -> None:
         try:
             model_manager = ModelManager()
             model_manager.get_model_instance(
@@ -382,7 +390,7 @@ class DatasetService:
             raise ValueError(ex.description)
     @staticmethod
-    def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str):
+    def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str) -> bool:
         try:
             model_manager = ModelManager()
             model_instance = model_manager.get_model_instance(
@@ -403,7 +411,7 @@ class DatasetService:
         raise ValueError("No Model available. Please configure a valid provider in the Settings -> Model Provider.")
     @staticmethod
-    def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
+    def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str) -> None:
         try:
             model_manager = ModelManager()
             model_manager.get_model_instance(
@@ -420,7 +428,7 @@ class DatasetService:
             raise ValueError(ex.description)
     @staticmethod
-    def update_dataset(dataset_id, data, user):
+    def update_dataset(dataset_id: str, data: dict[str, Any], user: Account) -> Dataset:
         """
         Update dataset configuration and settings.
@@ -459,7 +467,7 @@ class DatasetService:
         return DatasetService._update_internal_dataset(dataset, data, user)
     @staticmethod
-    def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str):
+    def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str) -> bool:
         dataset = (
             db.session.query(Dataset)
             .where(
@@ -472,7 +480,7 @@ class DatasetService:
         return dataset is not None
     @staticmethod
-    def _update_external_dataset(dataset, data, user):
+    def _update_external_dataset(dataset: Dataset, data: dict[str, Any], user: Account) -> Dataset:
         """
         Update external dataset configuration.
@@ -485,12 +493,12 @@ class DatasetService:
             Dataset: Updated dataset object
         """
         # Update retrieval model if provided
-        external_retrieval_model = data.get("external_retrieval_model", None)
+        external_retrieval_model = data.get("external_retrieval_model")
         if external_retrieval_model:
             dataset.retrieval_model = external_retrieval_model
         # Update summary index setting if provided
-        summary_index_setting = data.get("summary_index_setting", None)
+        summary_index_setting = data.get("summary_index_setting")
         if summary_index_setting is not None:
             dataset.summary_index_setting = summary_index_setting
@@ -504,8 +512,8 @@ class DatasetService:
             dataset.permission = permission
         # Validate and update external knowledge configuration
-        external_knowledge_id = data.get("external_knowledge_id", None)
-        external_knowledge_api_id = data.get("external_knowledge_api_id", None)
+        external_knowledge_id = data.get("external_knowledge_id")
+        external_knowledge_api_id = data.get("external_knowledge_api_id")
         if not external_knowledge_id:
             raise ValueError("External knowledge id is required.")
@@ -525,7 +533,9 @@ class DatasetService:
         return dataset
     @staticmethod
-    def _update_external_knowledge_binding(dataset_id, external_knowledge_id, external_knowledge_api_id):
+    def _update_external_knowledge_binding(
+        dataset_id: str, external_knowledge_id: str, external_knowledge_api_id: str
+    ) -> None:
         """
         Update external knowledge binding configuration.
@@ -552,7 +562,7 @@ class DatasetService:
         db.session.add(external_knowledge_binding)
     @staticmethod
-    def _update_internal_dataset(dataset, data, user):
+    def _update_internal_dataset(dataset: Dataset, data: dict[str, Any], user: Account) -> Dataset:
         """
         Update internal dataset configuration.
@@ -590,7 +600,7 @@ class DatasetService:
             filtered_data["icon_info"] = data.get("icon_info")
         # Update dataset in database
-        db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
+        db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)  # type: ignore[arg-type]
         db.session.commit()
         # Reload dataset to get updated values
@@ -618,7 +628,7 @@ class DatasetService:
         return dataset
     @staticmethod
-    def _update_pipeline_knowledge_base_node_data(dataset: Dataset, updata_user_id: str):
+    def _update_pipeline_knowledge_base_node_data(dataset: Dataset, updata_user_id: str) -> None:
         """
         Update pipeline knowledge base node data.
         """
@@ -701,7 +711,9 @@ class DatasetService:
             raise
     @staticmethod
-    def _handle_indexing_technique_change(dataset, data, filtered_data):
+    def _handle_indexing_technique_change(
+        dataset: Dataset, data: dict[str, Any], filtered_data: dict[str, Any]
+    ) -> str | None:
         """
         Handle changes in indexing technique and configure embedding models accordingly.
@@ -732,7 +744,7 @@ class DatasetService:
         return None
     @staticmethod
-    def _configure_embedding_model_for_high_quality(data, filtered_data):
+    def _configure_embedding_model_for_high_quality(data: dict[str, Any], filtered_data: dict[str, Any]) -> None:
         """
         Configure embedding model settings for high quality indexing.
@@ -767,7 +779,9 @@ class DatasetService:
             raise ValueError(ex.description)
     @staticmethod
-    def _handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data):
+    def _handle_embedding_model_update_when_technique_unchanged(
+        dataset: Dataset, data: dict[str, Any], filtered_data: dict[str, Any]
+    ) -> str | None:
         """
         Handle embedding model updates when indexing technique remains the same.
@@ -792,7 +806,7 @@ class DatasetService:
         return DatasetService._update_embedding_model_settings(dataset, data, filtered_data)
     @staticmethod
-    def _preserve_existing_embedding_settings(dataset, filtered_data):
+    def _preserve_existing_embedding_settings(dataset: Dataset, filtered_data: dict[str, Any]) -> None:
         """
         Preserve existing embedding model settings when not provided in update.
@@ -815,7 +829,9 @@ class DatasetService:
             del filtered_data["embedding_model"]
     @staticmethod
-    def _update_embedding_model_settings(dataset, data, filtered_data):
+    def _update_embedding_model_settings(
+        dataset: Dataset, data: dict[str, Any], filtered_data: dict[str, Any]
+    ) -> str | None:
         """
         Update embedding model settings with new values.
@@ -849,7 +865,7 @@ class DatasetService:
         return None
     @staticmethod
-    def _apply_new_embedding_settings(dataset, data, filtered_data):
+    def _apply_new_embedding_settings(dataset: Dataset, data: dict[str, Any], filtered_data: dict[str, Any]) -> None:
         """
         Apply new embedding model settings to the dataset.
@@ -946,7 +962,7 @@ class DatasetService:
     @staticmethod
     def update_rag_pipeline_dataset_settings(
         session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
-    ):
+    ) -> None:
         if not current_user or not current_user.current_tenant_id:
             raise ValueError("Current user or current tenant not found")
         dataset = session.merge(dataset)
@@ -1101,7 +1117,7 @@ class DatasetService:
         deal_dataset_index_update_task.delay(dataset.id, action)
     @staticmethod
-    def delete_dataset(dataset_id, user):
+    def delete_dataset(dataset_id: str, user: Account) -> bool:
         dataset = DatasetService.get_dataset(dataset_id)
         if dataset is None:
@@ -1116,12 +1132,14 @@ class DatasetService:
         return True
     @staticmethod
-    def dataset_use_check(dataset_id) -> bool:
+    def dataset_use_check(dataset_id: str) -> bool:
         stmt = select(exists().where(AppDatasetJoin.dataset_id == dataset_id))
         return db.session.execute(stmt).scalar_one()
     @staticmethod
-    def check_dataset_permission(dataset, user):
+    def check_dataset_permission(dataset: Dataset, user: Account | None) -> None:
+        if not user:
+            raise NoPermissionError("User not found.")
         if dataset.tenant_id != user.current_tenant_id:
             logger.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
             raise NoPermissionError("You do not have permission to access this dataset.")
@@ -1140,7 +1158,7 @@ class DatasetService:
             raise NoPermissionError("You do not have permission to access this dataset.")
     @staticmethod
-    def check_dataset_operator_permission(user: Account | None = None, dataset: Dataset | None = None):
+    def check_dataset_operator_permission(user: Account | None = None, dataset: Dataset | None = None) -> None:
         if not dataset:
             raise ValueError("Dataset not found")
@@ -1160,15 +1178,15 @@ class DatasetService:
             raise NoPermissionError("You do not have permission to access this dataset.")
     @staticmethod
-    def get_dataset_queries(dataset_id: str, page: int, per_page: int):
+    def get_dataset_queries(dataset_id: str, page: int, per_page: int) -> tuple[list[DatasetQuery], int]:
         stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at))
         dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)
-        return dataset_queries.items, dataset_queries.total
+        return dataset_queries.items, dataset_queries.total or 0
     @staticmethod
-    def get_related_apps(dataset_id: str):
+    def get_related_apps(dataset_id: str) -> list[AppDatasetJoin]:
         return (
             db.session.query(AppDatasetJoin)
             .where(AppDatasetJoin.dataset_id == dataset_id)
@@ -1177,7 +1195,7 @@ class DatasetService:
         )
     @staticmethod
-    def update_dataset_api_status(dataset_id: str, status: bool):
+    def update_dataset_api_status(dataset_id: str, status: bool) -> None:
         dataset = DatasetService.get_dataset(dataset_id)
         if dataset is None:
             raise NotFound("Dataset not found.")
@@ -1189,7 +1207,7 @@ class DatasetService:
         db.session.commit()
     @staticmethod
-    def get_dataset_auto_disable_logs(dataset_id: str):
+    def get_dataset_auto_disable_logs(dataset_id: str) -> dict[str, Any]:
         assert isinstance(current_user, Account)
         assert current_user.current_tenant_id is not None
         features = FeatureService.get_features(current_user.current_tenant_id)
@@ -1288,7 +1306,7 @@ class DocumentService:
         return cls.DISPLAY_STATUS_FILTERS[normalized]
     @classmethod
-    def apply_display_status_filter(cls, query, status: str | None):
+    def apply_display_status_filter(cls, query: Any, status: str | None) -> Any:
         filters = cls.build_display_status_filters(status)
         if not filters:
             return query
@@ -1684,19 +1702,19 @@ class DocumentService:
         return documents
     @staticmethod
-    def get_document_file_detail(file_id: str):
+    def get_document_file_detail(file_id: str) -> UploadFile | None:
         file_detail = db.session.query(UploadFile).where(UploadFile.id == file_id).one_or_none()
         return file_detail
     @staticmethod
-    def check_archived(document):
+    def check_archived(document: Document) -> bool:
         if document.archived:
             return True
         else:
             return False
     @staticmethod
-    def delete_document(document):
+    def delete_document(document: Document) -> None:
         # trigger document_was_deleted signal
         file_id = None
         if document.data_source_type == DataSourceType.UPLOAD_FILE:
@@ -1712,7 +1730,7 @@ class DocumentService:
         db.session.commit()
     @staticmethod
-    def delete_documents(dataset: Dataset, document_ids: list[str]):
+    def delete_documents(dataset: Dataset, document_ids: list[str]) -> None:
         # Check if document_ids is not empty to avoid WHERE false condition
         if not document_ids or len(document_ids) == 0:
             return
@@ -1768,7 +1786,7 @@ class DocumentService:
         return document
     @staticmethod
-    def pause_document(document):
+    def pause_document(document: Document) -> None:
         if document.indexing_status not in {
             IndexingStatus.WAITING,
             IndexingStatus.PARSING,
@@ -1790,7 +1808,7 @@ class DocumentService:
         redis_client.setnx(indexing_cache_key, "True")
     @staticmethod
-    def recover_document(document):
+    def recover_document(document: Document) -> None:
         if not document.is_paused:
             raise DocumentIndexingError()
         # update document to be recover
@@ -1807,7 +1825,7 @@ class DocumentService:
         recover_document_indexing_task.delay(document.dataset_id, document.id)
     @staticmethod
-    def retry_document(dataset_id: str, documents: list[Document]):
+    def retry_document(dataset_id: str, documents: list[Document]) -> None:
         for document in documents:
             # add retry flag
             retry_indexing_cache_key = f"document_{document.id}_is_retried"
@@ -1827,7 +1845,7 @@ class DocumentService:
         retry_document_indexing_task.delay(dataset_id, document_ids, current_user.id)
     @staticmethod
-    def sync_website_document(dataset_id: str, document: Document):
+    def sync_website_document(dataset_id: str, document: Document) -> None:
         # add sync flag
         sync_indexing_cache_key = f"document_{document.id}_is_sync"
         cache_result = redis_client.get(sync_indexing_cache_key)
@@ -1847,7 +1865,7 @@ class DocumentService:
         sync_website_document_indexing_task.delay(dataset_id, document.id)
     @staticmethod
-    def get_documents_position(dataset_id):
+    def get_documents_position(dataset_id: str) -> int:
         document = (
             db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
         )
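Throughout this file the commit pairs new tuple[list[...], int] return annotations with a "total or 0" normalization. Flask-SQLAlchemy's Pagination.total is typed int | None, so returning it unmodified would fail the stricter annotation; "or 0" collapses the None case (and, harmlessly, an explicit 0). A minimal sketch with illustrative names:

    def normalize_page(items: list[str], total: int | None) -> tuple[list[str], int]:
        # "total or 0" turns int | None into int for the annotated return
        return items, total or 0

    assert normalize_page(["a"], None) == (["a"], 0)
    assert normalize_page(["a"], 7) == (["a"], 7)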

View File

@@ -97,8 +97,8 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
         raise self.retry(exc=e, countdown=60)  # Retry after 60 seconds
-def _delete_app_model_configs(tenant_id: str, app_id: str):
-    def del_model_config(session, model_config_id: str):
+def _delete_app_model_configs(tenant_id: str, app_id: str) -> None:
+    def del_model_config(session: Any, model_config_id: str) -> None:
         session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
     _delete_records(
@@ -109,8 +109,8 @@ def _delete_app_model_configs(tenant_id: str, app_id: str):
     )
-def _delete_app_site(tenant_id: str, app_id: str):
-    def del_site(session, site_id: str):
+def _delete_app_site(tenant_id: str, app_id: str) -> None:
+    def del_site(session: Any, site_id: str) -> None:
         session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
     _delete_records(
@@ -121,8 +121,8 @@ def _delete_app_site(tenant_id: str, app_id: str):
     )
-def _delete_app_mcp_servers(tenant_id: str, app_id: str):
-    def del_mcp_server(session, mcp_server_id: str):
+def _delete_app_mcp_servers(tenant_id: str, app_id: str) -> None:
+    def del_mcp_server(session: Any, mcp_server_id: str) -> None:
         session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
     _delete_records(
@@ -133,8 +133,8 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
     )
-def _delete_app_api_tokens(tenant_id: str, app_id: str):
-    def del_api_token(session, api_token_id: str):
+def _delete_app_api_tokens(tenant_id: str, app_id: str) -> None:
+    def del_api_token(session: Any, api_token_id: str) -> None:
         # Fetch token details for cache invalidation
         token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first()
         if token_obj:
@@ -151,8 +151,8 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str):
     )
-def _delete_installed_apps(tenant_id: str, app_id: str):
-    def del_installed_app(session, installed_app_id: str):
+def _delete_installed_apps(tenant_id: str, app_id: str) -> None:
+    def del_installed_app(session: Any, installed_app_id: str) -> None:
         session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
     _delete_records(
@@ -163,8 +163,8 @@ def _delete_installed_apps(tenant_id: str, app_id: str):
     )
-def _delete_recommended_apps(tenant_id: str, app_id: str):
-    def del_recommended_app(session, recommended_app_id: str):
+def _delete_recommended_apps(tenant_id: str, app_id: str) -> None:
+    def del_recommended_app(session: Any, recommended_app_id: str) -> None:
         session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False)
     _delete_records(
@@ -175,8 +175,8 @@ def _delete_recommended_apps(tenant_id: str, app_id: str):
     )
-def _delete_app_annotation_data(tenant_id: str, app_id: str):
-    def del_annotation_hit_history(session, annotation_hit_history_id: str):
+def _delete_app_annotation_data(tenant_id: str, app_id: str) -> None:
+    def del_annotation_hit_history(session: Any, annotation_hit_history_id: str) -> None:
         session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
             synchronize_session=False
         )
@@ -188,7 +188,7 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
         "annotation hit history",
     )
-    def del_annotation_setting(session, annotation_setting_id: str):
+    def del_annotation_setting(session: Any, annotation_setting_id: str) -> None:
         session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
             synchronize_session=False
         )
@@ -201,8 +201,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
     )
-def _delete_app_dataset_joins(tenant_id: str, app_id: str):
-    def del_dataset_join(session, dataset_join_id: str):
+def _delete_app_dataset_joins(tenant_id: str, app_id: str) -> None:
+    def del_dataset_join(session: Any, dataset_join_id: str) -> None:
         session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
     _delete_records(
@@ -213,8 +213,8 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str):
     )
-def _delete_app_workflows(tenant_id: str, app_id: str):
-    def del_workflow(session, workflow_id: str):
+def _delete_app_workflows(tenant_id: str, app_id: str) -> None:
+    def del_workflow(session: Any, workflow_id: str) -> None:
         session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
     _delete_records(
@@ -225,7 +225,7 @@ def _delete_app_workflows(tenant_id: str, app_id: str):
     )
-def _delete_app_workflow_runs(tenant_id: str, app_id: str):
+def _delete_app_workflow_runs(tenant_id: str, app_id: str) -> None:
     """Delete all workflow runs for an app using the service repository."""
     session_maker = sessionmaker(bind=db.engine)
     workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@@ -239,7 +239,7 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
     logger.info("Deleted %s workflow runs for app %s", deleted_count, app_id)
-def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
+def _delete_app_workflow_node_executions(tenant_id: str, app_id: str) -> None:
     """Delete all workflow node executions for an app using the service repository."""
     session_maker = sessionmaker(bind=db.engine)
     node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
@@ -253,8 +253,8 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
     logger.info("Deleted %s workflow node executions for app %s", deleted_count, app_id)
-def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
-    def del_workflow_app_log(session, workflow_app_log_id: str):
+def _delete_app_workflow_app_logs(tenant_id: str, app_id: str) -> None:
+    def del_workflow_app_log(session: Any, workflow_app_log_id: str) -> None:
         session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False)
     _delete_records(
@@ -265,8 +265,8 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
     )
-def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
-    def del_workflow_archive_log(session, workflow_archive_log_id: str):
+def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str) -> None:
+    def del_workflow_archive_log(session: Any, workflow_archive_log_id: str) -> None:
         session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
             synchronize_session=False
         )
@@ -279,7 +279,7 @@ def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
     )
-def _delete_archived_workflow_run_files(tenant_id: str, app_id: str):
+def _delete_archived_workflow_run_files(tenant_id: str, app_id: str) -> None:
     prefix = f"{tenant_id}/app_id={app_id}/"
     try:
         archive_storage = get_archive_storage()
@@ -304,8 +304,8 @@ def _delete_archived_workflow_run_files(tenant_id: str, app_id: str):
     logger.info("Deleted %s archive objects for app %s", deleted, app_id)
-def _delete_app_conversations(tenant_id: str, app_id: str):
-    def del_conversation(session, conversation_id: str):
+def _delete_app_conversations(tenant_id: str, app_id: str) -> None:
+    def del_conversation(session: Any, conversation_id: str) -> None:
         session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
             synchronize_session=False
         )
@@ -319,7 +319,7 @@ def _delete_app_conversations(tenant_id: str, app_id: str):
     )
-def _delete_conversation_variables(*, app_id: str):
+def _delete_conversation_variables(app_id: str) -> None:
     with session_factory.create_session() as session:
         stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
         session.execute(stmt)
@@ -327,8 +327,8 @@ def _delete_conversation_variables(*, app_id: str):
     logger.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green"))
-def _delete_app_messages(tenant_id: str, app_id: str):
-    def del_message(session, message_id: str):
+def _delete_app_messages(tenant_id: str, app_id: str) -> None:
+    def del_message(session: Any, message_id: str) -> None:
         session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False)
         session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
             synchronize_session=False
@@ -349,8 +349,8 @@ def _delete_app_messages(tenant_id: str, app_id: str):
     )
-def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
-    def del_tool_provider(session, tool_provider_id: str):
+def _delete_workflow_tool_providers(tenant_id: str, app_id: str) -> None:
+    def del_tool_provider(session: Any, tool_provider_id: str) -> None:
         session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
             synchronize_session=False
         )
@@ -363,8 +363,8 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
     )
-def _delete_app_tag_bindings(tenant_id: str, app_id: str):
-    def del_tag_binding(session, tag_binding_id: str):
+def _delete_app_tag_bindings(tenant_id: str, app_id: str) -> None:
+    def del_tag_binding(session: Any, tag_binding_id: str) -> None:
         session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
     _delete_records(
@@ -375,8 +375,8 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str):
     )
-def _delete_end_users(tenant_id: str, app_id: str):
-    def del_end_user(session, end_user_id: str):
+def _delete_end_users(tenant_id: str, app_id: str) -> None:
+    def del_end_user(session: Any, end_user_id: str) -> None:
         session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
     _delete_records(
@@ -387,8 +387,8 @@ def _delete_end_users(tenant_id: str, app_id: str):
     )
-def _delete_trace_app_configs(tenant_id: str, app_id: str):
-    def del_trace_app_config(session, trace_app_config_id: str):
+def _delete_trace_app_configs(tenant_id: str, app_id: str) -> None:
+    def del_trace_app_config(session: Any, trace_app_config_id: str) -> None:
         session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False)
     _delete_records(
@@ -399,9 +399,9 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str):
     )
-def _delete_draft_variables(app_id: str):
+def _delete_draft_variables(app_id: str) -> None:
     """Delete all workflow draft variables for an app in batches."""
-    return delete_draft_variables_batch(app_id, batch_size=1000)
+    delete_draft_variables_batch(app_id, batch_size=1000)
 def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
@@ -543,8 +543,8 @@ def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int:
     return files_deleted
-def _delete_app_triggers(tenant_id: str, app_id: str):
-    def del_app_trigger(session, trigger_id: str):
+def _delete_app_triggers(tenant_id: str, app_id: str) -> None:
+    def del_app_trigger(session: Any, trigger_id: str) -> None:
         session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
     _delete_records(
@@ -555,8 +555,8 @@ def _delete_app_triggers(tenant_id: str, app_id: str):
     )
-def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
-    def del_plugin_trigger(session, trigger_id: str):
+def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str) -> None:
+    def del_plugin_trigger(session: Any, trigger_id: str) -> None:
         session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
             synchronize_session=False
         )
@@ -569,8 +569,8 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
     )
-def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
-    def del_webhook_trigger(session, trigger_id: str):
+def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str) -> None:
+    def del_webhook_trigger(session: Any, trigger_id: str) -> None:
         session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
             synchronize_session=False
         )
@@ -583,8 +583,8 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
     )
-def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
-    def del_schedule_plan(session, plan_id: str):
+def _delete_workflow_schedule_plans(tenant_id: str, app_id: str) -> None:
+    def del_schedule_plan(session: Any, plan_id: str) -> None:
         session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False)
     _delete_records(
@@ -595,8 +595,8 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
     )
-def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
-    def del_trigger_log(session, log_id: str):
+def _delete_workflow_trigger_logs(tenant_id: str, app_id: str) -> None:
+    def del_trigger_log(session: Any, log_id: str) -> None:
         session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
     _delete_records(
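All of the per-row callbacks here are typed (session: Any, id: str) -> None, matching the loose signature _delete_records presumably expects. A stricter alternative (an assumption; the commit deliberately stays with Any) is to annotate the SQLAlchemy session directly, which still type-checks the query(...).where(...).delete(...) chain:

    from sqlalchemy.orm import Session

    from models import Site  # import path assumed; model class as referenced in this diff

    def del_site(session: Session, site_id: str) -> None:
        # Query.where() is the 1.4+ alias of filter(); delete() issues a bulk DELETE
        session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)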

View File

@@ -28,12 +28,11 @@ class TestDeleteDraftVariablesBatch:
     def test_delete_draft_variables_calls_batch_function(self, mock_batch_delete):
         """Test that _delete_draft_variables calls the batch function correctly."""
         app_id = "test-app-id"
-        expected_return = 42
-        mock_batch_delete.return_value = expected_return
+        mock_batch_delete.return_value = 42
         result = _delete_draft_variables(app_id)
-        assert result == expected_return
+        assert result is None
         mock_batch_delete.assert_called_once_with(app_id, batch_size=1000)
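The assertion flips from == expected_return to is None because _delete_draft_variables is now annotated -> None and no longer forwards the batch function's count: the mocked return value (42) is discarded, and only the call arguments remain observable. A minimal standalone illustration of the same behavior:

    from unittest.mock import MagicMock

    batch = MagicMock(return_value=42)

    def wrapper(app_id: str) -> None:
        batch(app_id, batch_size=1000)  # return value intentionally dropped

    assert wrapper("test-app-id") is None
    batch.assert_called_once_with("test-app-id", batch_size=1000)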