From 75273b4be5990f64da103c89fcc21fc79b3f10e3 Mon Sep 17 00:00:00 2001
From: ckstck
Date: Mon, 23 Mar 2026 07:54:26 -0700
Subject: [PATCH] fix: missing types

[autofix.ci] apply automated fixes
fix/missing import types
fix/type checking error
fix: api test
---
 .../service_api/dataset/metadata.py           |  13 +-
 .../clickzetta_volume_storage.py              |  19 +--
 .../clickzetta_volume/file_lifecycle.py       |   8 +-
 .../clickzetta_volume/volume_permissions.py   |  11 +-
 api/services/dataset_service.py               | 118 ++++++++++--------
 api/tasks/remove_app_and_related_data_task.py | 102 +++++++--------
 .../test_remove_app_and_related_data_task.py  |   5 +-
 7 files changed, 150 insertions(+), 126 deletions(-)

diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py
index 52166f7fcc..37f9ad30ab 100644
--- a/api/controllers/service_api/dataset/metadata.py
+++ b/api/controllers/service_api/dataset/metadata.py
@@ -1,4 +1,4 @@
-from typing import Literal
+from typing import Literal, cast
 
 from flask_login import current_user
 from flask_restx import marshal
@@ -9,6 +9,7 @@ from controllers.common.schema import register_schema_model, register_schema_mod
 from controllers.service_api import service_api_ns
 from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
 from fields.dataset_fields import dataset_metadata_fields
+from models.account import Account
 from services.dataset_service import DatasetService
 from services.entities.knowledge_entities.knowledge_entities import (
     DocumentMetadataOperation,
@@ -55,7 +56,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
         dataset = DatasetService.get_dataset(dataset_id_str)
         if dataset is None:
             raise NotFound("Dataset not found.")
-        DatasetService.check_dataset_permission(dataset, current_user)
+        DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
 
         metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
         return marshal(metadata, dataset_metadata_fields), 201
@@ -102,7 +103,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
         dataset = DatasetService.get_dataset(dataset_id_str)
         if dataset is None:
             raise NotFound("Dataset not found.")
-        DatasetService.check_dataset_permission(dataset, current_user)
+        DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
 
         metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
         return marshal(metadata, dataset_metadata_fields), 200
@@ -125,7 +126,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
         dataset = DatasetService.get_dataset(dataset_id_str)
         if dataset is None:
             raise NotFound("Dataset not found.")
-        DatasetService.check_dataset_permission(dataset, current_user)
+        DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
         MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
         return "", 204
 
@@ -166,7 +167,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
         dataset = DatasetService.get_dataset(dataset_id_str)
         if dataset is None:
             raise NotFound("Dataset not found.")
-        DatasetService.check_dataset_permission(dataset, current_user)
+        DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
 
         match action:
             case "enable":
@@ -196,7 +197,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
         dataset = DatasetService.get_dataset(dataset_id_str)
         if dataset is None:
             raise NotFound("Dataset not found.")
-        DatasetService.check_dataset_permission(dataset, current_user)
+        DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
 
         metadata_args = MetadataOperationData.model_validate(service_api_ns.payload or {})
 
diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
index 18eed4e481..5f261e4e3c 100644
--- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
+++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
@@ -10,6 +10,7 @@ import tempfile
 from collections.abc import Generator
 from io import BytesIO
 from pathlib import Path
+from typing import Any
 
 import clickzetta
 from pydantic import BaseModel, model_validator
@@ -39,7 +40,7 @@ class ClickZettaVolumeConfig(BaseModel):
 
     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
         """Validate the configuration values.
 
         This method will first try to use CLICKZETTA_VOLUME_* environment variables,
@@ -110,7 +111,7 @@ class ClickZettaVolumeConfig(BaseModel):
 class ClickZettaVolumeStorage(BaseStorage):
     """ClickZetta Volume storage implementation."""
 
-    def __init__(self, config: ClickZettaVolumeConfig):
+    def __init__(self, config: ClickZettaVolumeConfig) -> None:
         """Initialize ClickZetta Volume storage.
 
         Args:
@@ -124,7 +125,7 @@ class ClickZettaVolumeStorage(BaseStorage):
 
         logger.info("ClickZetta Volume storage initialized with type: %s", config.volume_type)
 
-    def _init_connection(self):
+    def _init_connection(self) -> None:
         """Initialize ClickZetta connection."""
         try:
             self._connection = clickzetta.connect(
@@ -141,7 +142,7 @@ class ClickZettaVolumeStorage(BaseStorage):
             logger.exception("Failed to connect to ClickZetta")
             raise
 
-    def _init_permission_manager(self):
+    def _init_permission_manager(self) -> None:
         """Initialize permission manager."""
         try:
             self._permission_manager = VolumePermissionManager(
@@ -201,7 +202,7 @@ class ClickZettaVolumeStorage(BaseStorage):
         else:
             raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
 
-    def _execute_sql(self, sql: str, fetch: bool = False):
+    def _execute_sql(self, sql: str, fetch: bool = False) -> Any:
         """Execute SQL command."""
         try:
             if self._connection is None:
@@ -215,7 +216,7 @@ class ClickZettaVolumeStorage(BaseStorage):
             logger.exception("SQL execution failed: %s", sql)
             raise
 
-    def _ensure_table_volume_exists(self, dataset_id: str):
+    def _ensure_table_volume_exists(self, dataset_id: str) -> None:
         """Ensure table volume exists for the given dataset_id."""
         if self._config.volume_type != "table" or not dataset_id:
             return
@@ -250,7 +251,7 @@ class ClickZettaVolumeStorage(BaseStorage):
             # Don't raise exception, let the operation continue
             # The table might exist but not be visible due to permissions
 
-    def save(self, filename: str, data: bytes):
+    def save(self, filename: str, data: bytes) -> None:
         """Save data to ClickZetta Volume.
 
         Args:
@@ -381,7 +382,7 @@ class ClickZettaVolumeStorage(BaseStorage):
 
         logger.debug("File %s loaded as stream from ClickZetta Volume", filename)
 
-    def download(self, filename: str, target_filepath: str):
+    def download(self, filename: str, target_filepath: str) -> None:
         """Download file from ClickZetta Volume to local path.
 
         Args:
@@ -435,7 +436,7 @@ class ClickZettaVolumeStorage(BaseStorage):
             logger.warning("Error checking file existence for %s: %s", filename, e)
             return False
 
-    def delete(self, filename: str):
+    def delete(self, filename: str) -> None:
         """Delete file from ClickZetta Volume.
 
         Args:
diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py
index 1d9911465b..fc9968d430 100644
--- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py
+++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py
@@ -41,7 +41,7 @@ class FileMetadata:
     tags: dict[str, str] | None = None
     parent_version: int | None = None
 
-    def to_dict(self):
+    def to_dict(self) -> dict[str, Any]:
         """Convert to dictionary format"""
         data = asdict(self)
         data["created_at"] = self.created_at.isoformat()
@@ -62,7 +62,7 @@ class FileMetadata:
 class FileLifecycleManager:
     """File lifecycle manager"""
 
-    def __init__(self, storage, dataset_id: str | None = None):
+    def __init__(self, storage: Any, dataset_id: str | None = None) -> None:
         """Initialize lifecycle manager
 
         Args:
@@ -435,7 +435,7 @@ class FileLifecycleManager:
             logger.exception("Failed to get storage statistics")
             return {}
 
-    def _create_version_backup(self, filename: str, metadata: dict):
+    def _create_version_backup(self, filename: str, metadata: dict[str, Any]) -> None:
         """Create version backup"""
         try:
             # Read current file content
@@ -463,7 +463,7 @@ class FileLifecycleManager:
             logger.warning("Failed to load metadata: %s", e)
             return {}
 
-    def _save_metadata(self, metadata_dict: dict):
+    def _save_metadata(self, metadata_dict: dict[str, Any]) -> None:
         """Save metadata file"""
         try:
             metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
diff --git a/api/extensions/storage/clickzetta_volume/volume_permissions.py b/api/extensions/storage/clickzetta_volume/volume_permissions.py
index 9d4ca689d8..805cb662c2 100644
--- a/api/extensions/storage/clickzetta_volume/volume_permissions.py
+++ b/api/extensions/storage/clickzetta_volume/volume_permissions.py
@@ -6,6 +6,7 @@ According to ClickZetta's permission model, different Volume types have differen
 
 import logging
 from enum import StrEnum
+from typing import Any
 
 logger = logging.getLogger(__name__)
 
@@ -23,7 +24,9 @@ class VolumePermission(StrEnum):
 class VolumePermissionManager:
     """Volume permission manager"""
 
-    def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: str | None = None):
+    def __init__(
+        self, connection_or_config: Any, volume_type: str | None = None, volume_name: str | None = None
+    ) -> None:
         """Initialize permission manager
 
         Args:
@@ -434,7 +437,7 @@ class VolumePermissionManager:
         self._permission_cache[cache_key] = permissions
         return permissions
 
-    def clear_permission_cache(self):
+    def clear_permission_cache(self) -> None:
         """Clear permission cache"""
         self._permission_cache.clear()
         logger.debug("Permission cache cleared")
@@ -625,7 +628,9 @@ class VolumePermissionError(Exception):
         super().__init__(message)
 
 
-def check_volume_permission(permission_manager: VolumePermissionManager, operation: str, dataset_id: str | None = None):
+def check_volume_permission(
+    permission_manager: VolumePermissionManager, operation: str, dataset_id: str | None = None
+) -> Any:
     """Permission check decorator function
 
     Args:
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index cdab90a3dc..643d2bf26f 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -108,7 +108,15 @@ logger = logging.getLogger(__name__)
 
 class DatasetService:
     @staticmethod
-    def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
+    def get_datasets(
+        page: int,
+        per_page: int,
+        tenant_id: str | None = None,
+        user: Account | None = None,
+        search: str | None = None,
+        tag_ids: list[str] | None = None,
+        include_all: bool = False,
+    ) -> tuple[list[Dataset], int]:
         query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc(), Dataset.id)
 
         if user:
@@ -176,10 +184,10 @@ class DatasetService:
 
         datasets = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False)
 
-        return datasets.items, datasets.total
+        return datasets.items, datasets.total or 0
 
     @staticmethod
-    def get_process_rules(dataset_id):
+    def get_process_rules(dataset_id: str) -> dict[str, Any]:
         # get the latest process rule
         dataset_process_rule = (
             db.session.query(DatasetProcessRule)
@@ -197,7 +205,7 @@ class DatasetService:
         return {"mode": mode, "rules": rules}
 
     @staticmethod
-    def get_datasets_by_ids(ids, tenant_id):
+    def get_datasets_by_ids(ids: list[str] | None, tenant_id: str) -> tuple[list[Dataset], int]:
         # Check if ids is not empty to avoid WHERE false condition
         if not ids or len(ids) == 0:
             return [], 0
@@ -205,7 +213,7 @@ class DatasetService:
 
         datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
 
-        return datasets.items, datasets.total
+        return datasets.items, datasets.total or 0
 
     @staticmethod
     def create_empty_dataset(
@@ -222,7 +230,7 @@ class DatasetService:
         embedding_model_name: str | None = None,
         retrieval_model: RetrievalModel | None = None,
         summary_index_setting: dict | None = None,
-    ):
+    ) -> Dataset:
         # check if dataset name already exists
         if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
             raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
@@ -291,7 +299,7 @@ class DatasetService:
     def create_empty_rag_pipeline_dataset(
         tenant_id: str,
         rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
-    ):
+    ) -> Dataset:
         if rag_pipeline_dataset_create_entity.name:
             # check if dataset name already exists
             if (
@@ -337,17 +345,17 @@ class DatasetService:
         return dataset
 
     @staticmethod
-    def get_dataset(dataset_id) -> Dataset | None:
+    def get_dataset(dataset_id: str) -> Dataset | None:
         dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
         return dataset
 
     @staticmethod
-    def check_doc_form(dataset: Dataset, doc_form: str):
+    def check_doc_form(dataset: Dataset, doc_form: str) -> None:
         if dataset.doc_form and doc_form != dataset.doc_form:
             raise ValueError("doc_form is different from the dataset doc_form.")
 
     @staticmethod
-    def check_dataset_model_setting(dataset):
+    def check_dataset_model_setting(dataset: Dataset) -> None:
         if dataset.indexing_technique == "high_quality":
             try:
                 model_manager = ModelManager()
@@ -365,7 +373,7 @@ class DatasetService:
                 raise ValueError(f"The dataset is unavailable, due to: {ex.description}")
 
     @staticmethod
-    def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str):
+    def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str) -> None:
         try:
             model_manager = ModelManager()
             model_manager.get_model_instance(
@@ -382,7 +390,7 @@ class DatasetService:
             raise ValueError(ex.description)
 
     @staticmethod
-    def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str):
+    def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str) -> bool:
         try:
             model_manager = ModelManager()
             model_instance = model_manager.get_model_instance(
@@ -403,7 +411,7 @@ class DatasetService:
         raise ValueError("No Model available. Please configure a valid provider in the Settings -> Model Provider.")
 
     @staticmethod
-    def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
+    def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str) -> None:
         try:
             model_manager = ModelManager()
             model_manager.get_model_instance(
@@ -420,7 +428,7 @@ class DatasetService:
             raise ValueError(ex.description)
 
     @staticmethod
-    def update_dataset(dataset_id, data, user):
+    def update_dataset(dataset_id: str, data: dict[str, Any], user: Account) -> Dataset:
         """
         Update dataset configuration and settings.
 
@@ -459,7 +467,7 @@ class DatasetService:
         return DatasetService._update_internal_dataset(dataset, data, user)
 
     @staticmethod
-    def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str):
+    def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str) -> bool:
         dataset = (
             db.session.query(Dataset)
             .where(
@@ -472,7 +480,7 @@ class DatasetService:
         return dataset is not None
 
     @staticmethod
-    def _update_external_dataset(dataset, data, user):
+    def _update_external_dataset(dataset: Dataset, data: dict[str, Any], user: Account) -> Dataset:
         """
         Update external dataset configuration.
 
@@ -485,12 +493,12 @@ class DatasetService:
             Dataset: Updated dataset object
         """
         # Update retrieval model if provided
-        external_retrieval_model = data.get("external_retrieval_model", None)
+        external_retrieval_model = data.get("external_retrieval_model")
         if external_retrieval_model:
             dataset.retrieval_model = external_retrieval_model
 
         # Update summary index setting if provided
-        summary_index_setting = data.get("summary_index_setting", None)
+        summary_index_setting = data.get("summary_index_setting")
         if summary_index_setting is not None:
             dataset.summary_index_setting = summary_index_setting
 
@@ -504,8 +512,8 @@ class DatasetService:
             dataset.permission = permission
 
         # Validate and update external knowledge configuration
-        external_knowledge_id = data.get("external_knowledge_id", None)
-        external_knowledge_api_id = data.get("external_knowledge_api_id", None)
+        external_knowledge_id = data.get("external_knowledge_id")
+        external_knowledge_api_id = data.get("external_knowledge_api_id")
 
         if not external_knowledge_id:
             raise ValueError("External knowledge id is required.")
@@ -525,7 +533,9 @@ class DatasetService:
         return dataset
 
     @staticmethod
-    def _update_external_knowledge_binding(dataset_id, external_knowledge_id, external_knowledge_api_id):
+    def _update_external_knowledge_binding(
+        dataset_id: str, external_knowledge_id: str, external_knowledge_api_id: str
+    ) -> None:
         """
         Update external knowledge binding configuration.
 
@@ -552,7 +562,7 @@ class DatasetService:
         db.session.add(external_knowledge_binding)
 
     @staticmethod
-    def _update_internal_dataset(dataset, data, user):
+    def _update_internal_dataset(dataset: Dataset, data: dict[str, Any], user: Account) -> Dataset:
         """
         Update internal dataset configuration.
 
@@ -590,7 +600,7 @@ class DatasetService:
             filtered_data["icon_info"] = data.get("icon_info")
 
         # Update dataset in database
-        db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
+        db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)  # type: ignore[arg-type]
         db.session.commit()
 
         # Reload dataset to get updated values
@@ -618,7 +628,7 @@ class DatasetService:
         return dataset
 
     @staticmethod
-    def _update_pipeline_knowledge_base_node_data(dataset: Dataset, updata_user_id: str):
+    def _update_pipeline_knowledge_base_node_data(dataset: Dataset, updata_user_id: str) -> None:
         """
         Update pipeline knowledge base node data.
         """
@@ -701,7 +711,9 @@ class DatasetService:
             raise
 
     @staticmethod
-    def _handle_indexing_technique_change(dataset, data, filtered_data):
+    def _handle_indexing_technique_change(
+        dataset: Dataset, data: dict[str, Any], filtered_data: dict[str, Any]
+    ) -> str | None:
         """
         Handle changes in indexing technique and configure embedding models accordingly.
 
@@ -732,7 +744,7 @@ class DatasetService:
         return None
 
     @staticmethod
-    def _configure_embedding_model_for_high_quality(data, filtered_data):
+    def _configure_embedding_model_for_high_quality(data: dict[str, Any], filtered_data: dict[str, Any]) -> None:
         """
         Configure embedding model settings for high quality indexing.
 
@@ -767,7 +779,9 @@ class DatasetService:
             raise ValueError(ex.description)
 
     @staticmethod
-    def _handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data):
+    def _handle_embedding_model_update_when_technique_unchanged(
+        dataset: Dataset, data: dict[str, Any], filtered_data: dict[str, Any]
+    ) -> str | None:
         """
         Handle embedding model updates when indexing technique remains the same.
 
@@ -792,7 +806,7 @@ class DatasetService:
         return DatasetService._update_embedding_model_settings(dataset, data, filtered_data)
 
     @staticmethod
-    def _preserve_existing_embedding_settings(dataset, filtered_data):
+    def _preserve_existing_embedding_settings(dataset: Dataset, filtered_data: dict[str, Any]) -> None:
         """
         Preserve existing embedding model settings when not provided in update.
 
@@ -815,7 +829,9 @@ class DatasetService:
             del filtered_data["embedding_model"]
 
     @staticmethod
-    def _update_embedding_model_settings(dataset, data, filtered_data):
+    def _update_embedding_model_settings(
+        dataset: Dataset, data: dict[str, Any], filtered_data: dict[str, Any]
+    ) -> str | None:
         """
         Update embedding model settings with new values.
 
@@ -849,7 +865,7 @@ class DatasetService:
         return None
 
     @staticmethod
-    def _apply_new_embedding_settings(dataset, data, filtered_data):
+    def _apply_new_embedding_settings(dataset: Dataset, data: dict[str, Any], filtered_data: dict[str, Any]) -> None:
         """
         Apply new embedding model settings to the dataset.
 
@@ -946,7 +962,7 @@ class DatasetService:
     @staticmethod
     def update_rag_pipeline_dataset_settings(
         session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
-    ):
+    ) -> None:
         if not current_user or not current_user.current_tenant_id:
             raise ValueError("Current user or current tenant not found")
         dataset = session.merge(dataset)
@@ -1101,7 +1117,7 @@ class DatasetService:
         deal_dataset_index_update_task.delay(dataset.id, action)
 
     @staticmethod
-    def delete_dataset(dataset_id, user):
+    def delete_dataset(dataset_id: str, user: Account) -> bool:
         dataset = DatasetService.get_dataset(dataset_id)
 
         if dataset is None:
@@ -1116,12 +1132,14 @@ class DatasetService:
         return True
 
     @staticmethod
-    def dataset_use_check(dataset_id) -> bool:
+    def dataset_use_check(dataset_id: str) -> bool:
         stmt = select(exists().where(AppDatasetJoin.dataset_id == dataset_id))
         return db.session.execute(stmt).scalar_one()
 
     @staticmethod
-    def check_dataset_permission(dataset, user):
+    def check_dataset_permission(dataset: Dataset, user: Account | None) -> None:
+        if not user:
+            raise NoPermissionError("User not found.")
         if dataset.tenant_id != user.current_tenant_id:
             logger.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
             raise NoPermissionError("You do not have permission to access this dataset.")
@@ -1140,7 +1158,7 @@ class DatasetService:
             raise NoPermissionError("You do not have permission to access this dataset.")
 
     @staticmethod
-    def check_dataset_operator_permission(user: Account | None = None, dataset: Dataset | None = None):
+    def check_dataset_operator_permission(user: Account | None = None, dataset: Dataset | None = None) -> None:
         if not dataset:
             raise ValueError("Dataset not found")
 
@@ -1160,15 +1178,15 @@ class DatasetService:
             raise NoPermissionError("You do not have permission to access this dataset.")
 
     @staticmethod
-    def get_dataset_queries(dataset_id: str, page: int, per_page: int):
+    def get_dataset_queries(dataset_id: str, page: int, per_page: int) -> tuple[list[DatasetQuery], int]:
         stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at))
 
         dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)
 
-        return dataset_queries.items, dataset_queries.total
+        return dataset_queries.items, dataset_queries.total or 0
 
     @staticmethod
-    def get_related_apps(dataset_id: str):
+    def get_related_apps(dataset_id: str) -> list[AppDatasetJoin]:
         return (
             db.session.query(AppDatasetJoin)
             .where(AppDatasetJoin.dataset_id == dataset_id)
@@ -1177,7 +1195,7 @@ class DatasetService:
         )
 
     @staticmethod
-    def update_dataset_api_status(dataset_id: str, status: bool):
+    def update_dataset_api_status(dataset_id: str, status: bool) -> None:
         dataset = DatasetService.get_dataset(dataset_id)
         if dataset is None:
             raise NotFound("Dataset not found.")
@@ -1189,7 +1207,7 @@ class DatasetService:
         db.session.commit()
 
     @staticmethod
-    def get_dataset_auto_disable_logs(dataset_id: str):
+    def get_dataset_auto_disable_logs(dataset_id: str) -> dict[str, Any]:
         assert isinstance(current_user, Account)
         assert current_user.current_tenant_id is not None
         features = FeatureService.get_features(current_user.current_tenant_id)
@@ -1288,7 +1306,7 @@ class DocumentService:
         return cls.DISPLAY_STATUS_FILTERS[normalized]
 
     @classmethod
-    def apply_display_status_filter(cls, query, status: str | None):
+    def apply_display_status_filter(cls, query: Any, status: str | None) -> Any:
         filters = cls.build_display_status_filters(status)
         if not filters:
             return query
@@ -1684,19 +1702,19 @@ class DocumentService:
         return documents
 
     @staticmethod
-    def get_document_file_detail(file_id: str):
+    def get_document_file_detail(file_id: str) -> UploadFile | None:
         file_detail = db.session.query(UploadFile).where(UploadFile.id == file_id).one_or_none()
         return file_detail
 
     @staticmethod
-    def check_archived(document):
+    def check_archived(document: Document) -> bool:
         if document.archived:
             return True
         else:
             return False
 
     @staticmethod
-    def delete_document(document):
+    def delete_document(document: Document) -> None:
         # trigger document_was_deleted signal
         file_id = None
         if document.data_source_type == DataSourceType.UPLOAD_FILE:
@@ -1712,7 +1730,7 @@ class DocumentService:
         db.session.commit()
 
     @staticmethod
-    def delete_documents(dataset: Dataset, document_ids: list[str]):
+    def delete_documents(dataset: Dataset, document_ids: list[str]) -> None:
         # Check if document_ids is not empty to avoid WHERE false condition
         if not document_ids or len(document_ids) == 0:
             return
@@ -1768,7 +1786,7 @@ class DocumentService:
         return document
 
     @staticmethod
-    def pause_document(document):
+    def pause_document(document: Document) -> None:
         if document.indexing_status not in {
             IndexingStatus.WAITING,
             IndexingStatus.PARSING,
@@ -1790,7 +1808,7 @@ class DocumentService:
         redis_client.setnx(indexing_cache_key, "True")
 
     @staticmethod
-    def recover_document(document):
+    def recover_document(document: Document) -> None:
         if not document.is_paused:
             raise DocumentIndexingError()
         # update document to be recover
@@ -1807,7 +1825,7 @@ class DocumentService:
         recover_document_indexing_task.delay(document.dataset_id, document.id)
 
     @staticmethod
-    def retry_document(dataset_id: str, documents: list[Document]):
+    def retry_document(dataset_id: str, documents: list[Document]) -> None:
         for document in documents:
             # add retry flag
             retry_indexing_cache_key = f"document_{document.id}_is_retried"
@@ -1827,7 +1845,7 @@ class DocumentService:
         retry_document_indexing_task.delay(dataset_id, document_ids, current_user.id)
 
     @staticmethod
-    def sync_website_document(dataset_id: str, document: Document):
+    def sync_website_document(dataset_id: str, document: Document) -> None:
         # add sync flag
         sync_indexing_cache_key = f"document_{document.id}_is_sync"
         cache_result = redis_client.get(sync_indexing_cache_key)
@@ -1847,7 +1865,7 @@ class DocumentService:
         sync_website_document_indexing_task.delay(dataset_id, document.id)
 
     @staticmethod
-    def get_documents_position(dataset_id):
+    def get_documents_position(dataset_id: str) -> int:
         document = (
             db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
         )
diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py
index b1840662ff..cd08bc87f3 100644
--- a/api/tasks/remove_app_and_related_data_task.py
+++ b/api/tasks/remove_app_and_related_data_task.py
@@ -97,8 +97,8 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
         raise self.retry(exc=e, countdown=60)  # Retry after 60 seconds
 
 
-def _delete_app_model_configs(tenant_id: str, app_id: str):
-    def del_model_config(session, model_config_id: str):
+def _delete_app_model_configs(tenant_id: str, app_id: str) -> None:
+    def del_model_config(session: Any, model_config_id: str) -> None:
         session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -109,8 +109,8 @@ def _delete_app_model_configs(tenant_id: str, app_id: str):
     )
 
 
-def _delete_app_site(tenant_id: str, app_id: str):
-    def del_site(session, site_id: str):
+def _delete_app_site(tenant_id: str, app_id: str) -> None:
+    def del_site(session: Any, site_id: str) -> None:
         session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -121,8 +121,8 @@ def _delete_app_site(tenant_id: str, app_id: str):
     )
 
 
-def _delete_app_mcp_servers(tenant_id: str, app_id: str):
-    def del_mcp_server(session, mcp_server_id: str):
+def _delete_app_mcp_servers(tenant_id: str, app_id: str) -> None:
+    def del_mcp_server(session: Any, mcp_server_id: str) -> None:
         session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -133,8 +133,8 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
     )
 
 
-def _delete_app_api_tokens(tenant_id: str, app_id: str):
-    def del_api_token(session, api_token_id: str):
+def _delete_app_api_tokens(tenant_id: str, app_id: str) -> None:
+    def del_api_token(session: Any, api_token_id: str) -> None:
         # Fetch token details for cache invalidation
         token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first()
         if token_obj:
@@ -151,8 +151,8 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str):
     )
 
 
-def _delete_installed_apps(tenant_id: str, app_id: str):
-    def del_installed_app(session, installed_app_id: str):
+def _delete_installed_apps(tenant_id: str, app_id: str) -> None:
+    def del_installed_app(session: Any, installed_app_id: str) -> None:
         session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -163,8 +163,8 @@ def _delete_installed_apps(tenant_id: str, app_id: str):
     )
 
 
-def _delete_recommended_apps(tenant_id: str, app_id: str):
-    def del_recommended_app(session, recommended_app_id: str):
+def _delete_recommended_apps(tenant_id: str, app_id: str) -> None:
+    def del_recommended_app(session: Any, recommended_app_id: str) -> None:
         session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -175,8 +175,8 @@ def _delete_recommended_apps(tenant_id: str, app_id: str):
     )
 
 
-def _delete_app_annotation_data(tenant_id: str, app_id: str):
-    def del_annotation_hit_history(session, annotation_hit_history_id: str):
+def _delete_app_annotation_data(tenant_id: str, app_id: str) -> None:
+    def del_annotation_hit_history(session: Any, annotation_hit_history_id: str) -> None:
         session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
             synchronize_session=False
         )
@@ -188,7 +188,7 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
         "annotation hit history",
     )
 
-    def del_annotation_setting(session, annotation_setting_id: str):
+    def del_annotation_setting(session: Any, annotation_setting_id: str) -> None:
         session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
             synchronize_session=False
         )
@@ -201,8 +201,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
     )
 
 
-def _delete_app_dataset_joins(tenant_id: str, app_id: str):
-    def del_dataset_join(session, dataset_join_id: str):
+def _delete_app_dataset_joins(tenant_id: str, app_id: str) -> None:
+    def del_dataset_join(session: Any, dataset_join_id: str) -> None:
         session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -213,8 +213,8 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str):
     )
 
 
-def _delete_app_workflows(tenant_id: str, app_id: str):
-    def del_workflow(session, workflow_id: str):
+def _delete_app_workflows(tenant_id: str, app_id: str) -> None:
+    def del_workflow(session: Any, workflow_id: str) -> None:
         session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -225,7 +225,7 @@ def _delete_app_workflows(tenant_id: str, app_id: str):
     )
 
 
-def _delete_app_workflow_runs(tenant_id: str, app_id: str):
+def _delete_app_workflow_runs(tenant_id: str, app_id: str) -> None:
     """Delete all workflow runs for an app using the service repository."""
     session_maker = sessionmaker(bind=db.engine)
     workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@@ -239,7 +239,7 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
     logger.info("Deleted %s workflow runs for app %s", deleted_count, app_id)
 
 
-def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
+def _delete_app_workflow_node_executions(tenant_id: str, app_id: str) -> None:
     """Delete all workflow node executions for an app using the service repository."""
     session_maker = sessionmaker(bind=db.engine)
     node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
@@ -253,8 +253,8 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
     logger.info("Deleted %s workflow node executions for app %s", deleted_count, app_id)
 
 
-def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
-    def del_workflow_app_log(session, workflow_app_log_id: str):
+def _delete_app_workflow_app_logs(tenant_id: str, app_id: str) -> None:
+    def del_workflow_app_log(session: Any, workflow_app_log_id: str) -> None:
         session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -265,8 +265,8 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
     )
 
 
-def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
-    def del_workflow_archive_log(session, workflow_archive_log_id: str):
+def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str) -> None:
+    def del_workflow_archive_log(session: Any, workflow_archive_log_id: str) -> None:
         session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
             synchronize_session=False
         )
@@ -279,7 +279,7 @@ def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
     )
 
 
-def _delete_archived_workflow_run_files(tenant_id: str, app_id: str):
+def _delete_archived_workflow_run_files(tenant_id: str, app_id: str) -> None:
     prefix = f"{tenant_id}/app_id={app_id}/"
     try:
         archive_storage = get_archive_storage()
@@ -304,8 +304,8 @@ def _delete_archived_workflow_run_files(tenant_id: str, app_id: str):
     logger.info("Deleted %s archive objects for app %s", deleted, app_id)
 
 
-def _delete_app_conversations(tenant_id: str, app_id: str):
-    def del_conversation(session, conversation_id: str):
+def _delete_app_conversations(tenant_id: str, app_id: str) -> None:
+    def del_conversation(session: Any, conversation_id: str) -> None:
         session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
             synchronize_session=False
         )
@@ -319,7 +319,7 @@ def _delete_app_conversations(tenant_id: str, app_id: str):
     )
 
 
-def _delete_conversation_variables(*, app_id: str):
+def _delete_conversation_variables(app_id: str) -> None:
     with session_factory.create_session() as session:
         stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
         session.execute(stmt)
@@ -327,8 +327,8 @@ def _delete_conversation_variables(*, app_id: str):
 
     logger.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green"))
 
 
-def _delete_app_messages(tenant_id: str, app_id: str):
-    def del_message(session, message_id: str):
+def _delete_app_messages(tenant_id: str, app_id: str) -> None:
+    def del_message(session: Any, message_id: str) -> None:
         session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False)
         session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
             synchronize_session=False
@@ -349,8 +349,8 @@ def _delete_app_messages(tenant_id: str, app_id: str):
     )
 
 
-def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
-    def del_tool_provider(session, tool_provider_id: str):
+def _delete_workflow_tool_providers(tenant_id: str, app_id: str) -> None:
+    def del_tool_provider(session: Any, tool_provider_id: str) -> None:
         session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
             synchronize_session=False
         )
@@ -363,8 +363,8 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
     )
 
 
-def _delete_app_tag_bindings(tenant_id: str, app_id: str):
-    def del_tag_binding(session, tag_binding_id: str):
+def _delete_app_tag_bindings(tenant_id: str, app_id: str) -> None:
+    def del_tag_binding(session: Any, tag_binding_id: str) -> None:
         session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -375,8 +375,8 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str):
     )
 
 
-def _delete_end_users(tenant_id: str, app_id: str):
-    def del_end_user(session, end_user_id: str):
+def _delete_end_users(tenant_id: str, app_id: str) -> None:
+    def del_end_user(session: Any, end_user_id: str) -> None:
         session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -387,8 +387,8 @@ def _delete_end_users(tenant_id: str, app_id: str):
     )
 
 
-def _delete_trace_app_configs(tenant_id: str, app_id: str):
-    def del_trace_app_config(session, trace_app_config_id: str):
+def _delete_trace_app_configs(tenant_id: str, app_id: str) -> None:
+    def del_trace_app_config(session: Any, trace_app_config_id: str) -> None:
         session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -399,9 +399,9 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str):
     )
 
 
-def _delete_draft_variables(app_id: str):
+def _delete_draft_variables(app_id: str) -> None:
     """Delete all workflow draft variables for an app in batches."""
-    return delete_draft_variables_batch(app_id, batch_size=1000)
+    delete_draft_variables_batch(app_id, batch_size=1000)
 
 
 def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
@@ -543,8 +543,8 @@ def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int:
     return files_deleted
 
 
-def _delete_app_triggers(tenant_id: str, app_id: str):
-    def del_app_trigger(session, trigger_id: str):
+def _delete_app_triggers(tenant_id: str, app_id: str) -> None:
+    def del_app_trigger(session: Any, trigger_id: str) -> None:
         session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -555,8 +555,8 @@ def _delete_app_triggers(tenant_id: str, app_id: str):
     )
 
 
-def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
-    def del_plugin_trigger(session, trigger_id: str):
+def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str) -> None:
+    def del_plugin_trigger(session: Any, trigger_id: str) -> None:
         session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
             synchronize_session=False
         )
@@ -569,8 +569,8 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
     )
 
 
-def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
-    def del_webhook_trigger(session, trigger_id: str):
+def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str) -> None:
+    def del_webhook_trigger(session: Any, trigger_id: str) -> None:
         session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
             synchronize_session=False
         )
@@ -583,8 +583,8 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
     )
 
 
-def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
-    def del_schedule_plan(session, plan_id: str):
+def _delete_workflow_schedule_plans(tenant_id: str, app_id: str) -> None:
+    def del_schedule_plan(session: Any, plan_id: str) -> None:
         session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -595,8 +595,8 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
     )
 
 
-def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
-    def del_trigger_log(session, log_id: str):
+def _delete_workflow_trigger_logs(tenant_id: str, app_id: str) -> None:
+    def del_trigger_log(session: Any, log_id: str) -> None:
         session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
 
     _delete_records(
diff --git a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py
index 0ed4ca05fa..d7d9a6cc38 100644
--- a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py
+++ b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py
@@ -28,12 +28,11 @@ class TestDeleteDraftVariablesBatch:
     def test_delete_draft_variables_calls_batch_function(self, mock_batch_delete):
         """Test that _delete_draft_variables calls the batch function correctly."""
         app_id = "test-app-id"
-        expected_return = 42
-        mock_batch_delete.return_value = expected_return
+        mock_batch_delete.return_value = 42
 
         result = _delete_draft_variables(app_id)
 
-        assert result == expected_return
+        assert result is None
         mock_batch_delete.assert_called_once_with(app_id, batch_size=1000)