fix: missing types

[autofix.ci] apply automated fixes

fix/missing import types

fix/type checking error

fix: api test
ckstck 2026-03-23 07:54:26 -07:00
parent f5cc1c8b75
commit 75273b4be5
7 changed files with 150 additions and 126 deletions

View File

@ -1,4 +1,4 @@
from typing import Literal
from typing import Literal, cast
from flask_login import current_user
from flask_restx import marshal
@ -9,6 +9,7 @@ from controllers.common.schema import register_schema_model, register_schema_mod
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
from fields.dataset_fields import dataset_metadata_fields
from models.account import Account
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
DocumentMetadataOperation,
@ -55,7 +56,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
return marshal(metadata, dataset_metadata_fields), 201
@ -102,7 +103,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
return marshal(metadata, dataset_metadata_fields), 200
@ -125,7 +126,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
return "", 204
@ -166,7 +167,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
match action:
case "enable":
@ -196,7 +197,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
metadata_args = MetadataOperationData.model_validate(service_api_ns.payload or {})
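
The change repeated through this file is narrowing flask_login's current_user, which the checker sees as a LocalProxy rather than an Account, before passing it to check_dataset_permission, whose signature now declares Account | None (see the service diff below). A minimal sketch of why the cast satisfies the checker without changing runtime behavior, using a stand-in User class in place of the project's Account:

from typing import cast

class User:  # stand-in for the project's models.account.Account
    id = "u1"

def check_permission(user: User | None) -> None:
    if user is None:
        raise PermissionError("User not found.")

proxy: object = User()  # like current_user, statically opaque to the checker
check_permission(cast(User | None, proxy))  # narrows the static type; no runtime effect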

View File

@ -10,6 +10,7 @@ import tempfile
from collections.abc import Generator
from io import BytesIO
from pathlib import Path
from typing import Any
import clickzetta
from pydantic import BaseModel, model_validator
@ -39,7 +40,7 @@ class ClickZettaVolumeConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
"""Validate the configuration values.
This method will first try to use CLICKZETTA_VOLUME_* environment variables,
@ -110,7 +111,7 @@ class ClickZettaVolumeConfig(BaseModel):
class ClickZettaVolumeStorage(BaseStorage):
"""ClickZetta Volume storage implementation."""
def __init__(self, config: ClickZettaVolumeConfig):
def __init__(self, config: ClickZettaVolumeConfig) -> None:
"""Initialize ClickZetta Volume storage.
Args:
@ -124,7 +125,7 @@ class ClickZettaVolumeStorage(BaseStorage):
logger.info("ClickZetta Volume storage initialized with type: %s", config.volume_type)
def _init_connection(self):
def _init_connection(self) -> None:
"""Initialize ClickZetta connection."""
try:
self._connection = clickzetta.connect(
@ -141,7 +142,7 @@ class ClickZettaVolumeStorage(BaseStorage):
logger.exception("Failed to connect to ClickZetta")
raise
def _init_permission_manager(self):
def _init_permission_manager(self) -> None:
"""Initialize permission manager."""
try:
self._permission_manager = VolumePermissionManager(
@ -201,7 +202,7 @@ class ClickZettaVolumeStorage(BaseStorage):
else:
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
def _execute_sql(self, sql: str, fetch: bool = False):
def _execute_sql(self, sql: str, fetch: bool = False) -> Any:
"""Execute SQL command."""
try:
if self._connection is None:
@ -215,7 +216,7 @@ class ClickZettaVolumeStorage(BaseStorage):
logger.exception("SQL execution failed: %s", sql)
raise
def _ensure_table_volume_exists(self, dataset_id: str):
def _ensure_table_volume_exists(self, dataset_id: str) -> None:
"""Ensure table volume exists for the given dataset_id."""
if self._config.volume_type != "table" or not dataset_id:
return
@ -250,7 +251,7 @@ class ClickZettaVolumeStorage(BaseStorage):
# Don't raise exception, let the operation continue
# The table might exist but not be visible due to permissions
def save(self, filename: str, data: bytes):
def save(self, filename: str, data: bytes) -> None:
"""Save data to ClickZetta Volume.
Args:
@ -381,7 +382,7 @@ class ClickZettaVolumeStorage(BaseStorage):
logger.debug("File %s loaded as stream from ClickZetta Volume", filename)
def download(self, filename: str, target_filepath: str):
def download(self, filename: str, target_filepath: str) -> None:
"""Download file from ClickZetta Volume to local path.
Args:
@ -435,7 +436,7 @@ class ClickZettaVolumeStorage(BaseStorage):
logger.warning("Error checking file existence for %s: %s", filename, e)
return False
def delete(self, filename: str):
def delete(self, filename: str) -> None:
"""Delete file from ClickZetta Volume.
Args:
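
Most of this file's changes just spell out -> None on procedures, but _execute_sql gets -> Any because it returns fetched rows or nothing depending on fetch. A reduced sketch of that shape, with a FakeCursor standing in for the ClickZetta connection:

from typing import Any

class FakeCursor:  # stand-in for the real connection's cursor
    def execute(self, sql: str) -> None: ...
    def fetchall(self) -> list[tuple[Any, ...]]:
        return [("row",)]

def execute_sql(cursor: FakeCursor, sql: str, fetch: bool = False) -> Any:
    # Returns rows when fetch=True and None otherwise; Any is the honest
    # annotation short of a union every call site would have to unpack.
    cursor.execute(sql)
    return cursor.fetchall() if fetch else None

print(execute_sql(FakeCursor(), "SHOW VOLUMES", fetch=True))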

View File

@ -41,7 +41,7 @@ class FileMetadata:
tags: dict[str, str] | None = None
parent_version: int | None = None
def to_dict(self):
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary format"""
data = asdict(self)
data["created_at"] = self.created_at.isoformat()
@ -62,7 +62,7 @@ class FileMetadata:
class FileLifecycleManager:
"""File lifecycle manager"""
def __init__(self, storage, dataset_id: str | None = None):
def __init__(self, storage: Any, dataset_id: str | None = None) -> None:
"""Initialize lifecycle manager
Args:
@ -435,7 +435,7 @@ class FileLifecycleManager:
logger.exception("Failed to get storage statistics")
return {}
def _create_version_backup(self, filename: str, metadata: dict):
def _create_version_backup(self, filename: str, metadata: dict[str, Any]) -> None:
"""Create version backup"""
try:
# Read current file content
@ -463,7 +463,7 @@ class FileLifecycleManager:
logger.warning("Failed to load metadata: %s", e)
return {}
def _save_metadata(self, metadata_dict: dict):
def _save_metadata(self, metadata_dict: dict[str, Any]) -> None:
"""Save metadata file"""
try:
metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
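
FileMetadata.to_dict now declares dict[str, Any], which is exactly what dataclasses.asdict produces; the datetime field still needs explicit serialization. A self-contained sketch of the same pattern (field names are illustrative):

from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Any

@dataclass
class FileRecord:  # illustrative stand-in for FileMetadata
    name: str
    created_at: datetime

    def to_dict(self) -> dict[str, Any]:
        data = asdict(self)  # asdict is typed as returning dict[str, Any]
        data["created_at"] = self.created_at.isoformat()  # JSON-friendly
        return data

print(FileRecord("report.txt", datetime(2026, 3, 23)).to_dict())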

View File

@ -6,6 +6,7 @@ According to ClickZetta's permission model, different Volume types have differen
import logging
from enum import StrEnum
from typing import Any
logger = logging.getLogger(__name__)
@ -23,7 +24,9 @@ class VolumePermission(StrEnum):
class VolumePermissionManager:
"""Volume permission manager"""
def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: str | None = None):
def __init__(
self, connection_or_config: Any, volume_type: str | None = None, volume_name: str | None = None
) -> None:
"""Initialize permission manager
Args:
@ -434,7 +437,7 @@ class VolumePermissionManager:
self._permission_cache[cache_key] = permissions
return permissions
def clear_permission_cache(self):
def clear_permission_cache(self) -> None:
"""Clear permission cache"""
self._permission_cache.clear()
logger.debug("Permission cache cleared")
@ -625,7 +628,9 @@ class VolumePermissionError(Exception):
super().__init__(message)
def check_volume_permission(permission_manager: VolumePermissionManager, operation: str, dataset_id: str | None = None):
def check_volume_permission(
permission_manager: VolumePermissionManager, operation: str, dataset_id: str | None = None
) -> Any:
"""Permission check decorator function
Args:
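
check_volume_permission is a decorator factory, and the diff annotates it -> Any rather than spelling out the nested callable types. A sketch of that trade-off, with the permission check itself reduced to a set lookup:

from collections.abc import Callable
from typing import Any

def check_permission(operation: str) -> Any:
    # Any avoids Callable[[Callable[..., Any]], Callable[..., Any]];
    # the wrapper's own annotations still get checked.
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if operation not in {"read", "write"}:  # stand-in for real grants
                raise PermissionError(operation)
            return func(*args, **kwargs)
        return wrapper
    return decorator

@check_permission("read")
def list_files() -> list[str]:
    return ["a.txt"]

print(list_files())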

View File

@ -108,7 +108,15 @@ logger = logging.getLogger(__name__)
class DatasetService:
@staticmethod
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
def get_datasets(
page: int,
per_page: int,
tenant_id: str | None = None,
user: Account | None = None,
search: str | None = None,
tag_ids: list[str] | None = None,
include_all: bool = False,
) -> tuple[list[Dataset], int]:
query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc(), Dataset.id)
if user:
@ -176,10 +184,10 @@ class DatasetService:
datasets = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False)
return datasets.items, datasets.total
return datasets.items, datasets.total or 0
@staticmethod
def get_process_rules(dataset_id):
def get_process_rules(dataset_id: str) -> dict[str, Any]:
# get the latest process rule
dataset_process_rule = (
db.session.query(DatasetProcessRule)
@ -197,7 +205,7 @@ class DatasetService:
return {"mode": mode, "rules": rules}
@staticmethod
def get_datasets_by_ids(ids, tenant_id):
def get_datasets_by_ids(ids: list[str] | None, tenant_id: str) -> tuple[list[Dataset], int]:
# Check if ids is not empty to avoid WHERE false condition
if not ids or len(ids) == 0:
return [], 0
@ -205,7 +213,7 @@ class DatasetService:
datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
return datasets.items, datasets.total
return datasets.items, datasets.total or 0
@staticmethod
def create_empty_dataset(
@ -222,7 +230,7 @@ class DatasetService:
embedding_model_name: str | None = None,
retrieval_model: RetrievalModel | None = None,
summary_index_setting: dict | None = None,
):
) -> Dataset:
# check if dataset name already exists
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
@ -291,7 +299,7 @@ class DatasetService:
def create_empty_rag_pipeline_dataset(
tenant_id: str,
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
):
) -> Dataset:
if rag_pipeline_dataset_create_entity.name:
# check if dataset name already exists
if (
@ -337,17 +345,17 @@ class DatasetService:
return dataset
@staticmethod
def get_dataset(dataset_id) -> Dataset | None:
def get_dataset(dataset_id: str) -> Dataset | None:
dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
return dataset
@staticmethod
def check_doc_form(dataset: Dataset, doc_form: str):
def check_doc_form(dataset: Dataset, doc_form: str) -> None:
if dataset.doc_form and doc_form != dataset.doc_form:
raise ValueError("doc_form is different from the dataset doc_form.")
@staticmethod
def check_dataset_model_setting(dataset):
def check_dataset_model_setting(dataset: Dataset) -> None:
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
@ -365,7 +373,7 @@ class DatasetService:
raise ValueError(f"The dataset is unavailable, due to: {ex.description}")
@staticmethod
def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str):
def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str) -> None:
try:
model_manager = ModelManager()
model_manager.get_model_instance(
@ -382,7 +390,7 @@ class DatasetService:
raise ValueError(ex.description)
@staticmethod
def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str):
def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str) -> bool:
try:
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
@ -403,7 +411,7 @@ class DatasetService:
raise ValueError("No Model available. Please configure a valid provider in the Settings -> Model Provider.")
@staticmethod
def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str) -> None:
try:
model_manager = ModelManager()
model_manager.get_model_instance(
@ -420,7 +428,7 @@ class DatasetService:
raise ValueError(ex.description)
@staticmethod
def update_dataset(dataset_id, data, user):
def update_dataset(dataset_id: str, data: dict[str, Any], user: Account) -> Dataset:
"""
Update dataset configuration and settings.
@ -459,7 +467,7 @@ class DatasetService:
return DatasetService._update_internal_dataset(dataset, data, user)
@staticmethod
def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str):
def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str) -> bool:
dataset = (
db.session.query(Dataset)
.where(
@ -472,7 +480,7 @@ class DatasetService:
return dataset is not None
@staticmethod
def _update_external_dataset(dataset, data, user):
def _update_external_dataset(dataset: Dataset, data: dict[str, Any], user: Account) -> Dataset:
"""
Update external dataset configuration.
@ -485,12 +493,12 @@ class DatasetService:
Dataset: Updated dataset object
"""
# Update retrieval model if provided
external_retrieval_model = data.get("external_retrieval_model", None)
external_retrieval_model = data.get("external_retrieval_model")
if external_retrieval_model:
dataset.retrieval_model = external_retrieval_model
# Update summary index setting if provided
summary_index_setting = data.get("summary_index_setting", None)
summary_index_setting = data.get("summary_index_setting")
if summary_index_setting is not None:
dataset.summary_index_setting = summary_index_setting
@ -504,8 +512,8 @@ class DatasetService:
dataset.permission = permission
# Validate and update external knowledge configuration
external_knowledge_id = data.get("external_knowledge_id", None)
external_knowledge_api_id = data.get("external_knowledge_api_id", None)
external_knowledge_id = data.get("external_knowledge_id")
external_knowledge_api_id = data.get("external_knowledge_api_id")
if not external_knowledge_id:
raise ValueError("External knowledge id is required.")
@ -525,7 +533,9 @@ class DatasetService:
return dataset
@staticmethod
def _update_external_knowledge_binding(dataset_id, external_knowledge_id, external_knowledge_api_id):
def _update_external_knowledge_binding(
dataset_id: str, external_knowledge_id: str, external_knowledge_api_id: str
) -> None:
"""
Update external knowledge binding configuration.
@ -552,7 +562,7 @@ class DatasetService:
db.session.add(external_knowledge_binding)
@staticmethod
def _update_internal_dataset(dataset, data, user):
def _update_internal_dataset(dataset: Dataset, data: dict[str, Any], user: Account) -> Dataset:
"""
Update internal dataset configuration.
@ -590,7 +600,7 @@ class DatasetService:
filtered_data["icon_info"] = data.get("icon_info")
# Update dataset in database
db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data) # type: ignore[arg-type]
db.session.commit()
# Reload dataset to get updated values
@ -618,7 +628,7 @@ class DatasetService:
return dataset
@staticmethod
def _update_pipeline_knowledge_base_node_data(dataset: Dataset, updata_user_id: str):
def _update_pipeline_knowledge_base_node_data(dataset: Dataset, updata_user_id: str) -> None:
"""
Update pipeline knowledge base node data.
"""
@ -701,7 +711,9 @@ class DatasetService:
raise
@staticmethod
def _handle_indexing_technique_change(dataset, data, filtered_data):
def _handle_indexing_technique_change(
dataset: Dataset, data: dict[str, Any], filtered_data: dict[str, Any]
) -> str | None:
"""
Handle changes in indexing technique and configure embedding models accordingly.
@ -732,7 +744,7 @@ class DatasetService:
return None
@staticmethod
def _configure_embedding_model_for_high_quality(data, filtered_data):
def _configure_embedding_model_for_high_quality(data: dict[str, Any], filtered_data: dict[str, Any]) -> None:
"""
Configure embedding model settings for high quality indexing.
@ -767,7 +779,9 @@ class DatasetService:
raise ValueError(ex.description)
@staticmethod
def _handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data):
def _handle_embedding_model_update_when_technique_unchanged(
dataset: Dataset, data: dict[str, Any], filtered_data: dict[str, Any]
) -> str | None:
"""
Handle embedding model updates when indexing technique remains the same.
@ -792,7 +806,7 @@ class DatasetService:
return DatasetService._update_embedding_model_settings(dataset, data, filtered_data)
@staticmethod
def _preserve_existing_embedding_settings(dataset, filtered_data):
def _preserve_existing_embedding_settings(dataset: Dataset, filtered_data: dict[str, Any]) -> None:
"""
Preserve existing embedding model settings when not provided in update.
@ -815,7 +829,9 @@ class DatasetService:
del filtered_data["embedding_model"]
@staticmethod
def _update_embedding_model_settings(dataset, data, filtered_data):
def _update_embedding_model_settings(
dataset: Dataset, data: dict[str, Any], filtered_data: dict[str, Any]
) -> str | None:
"""
Update embedding model settings with new values.
@ -849,7 +865,7 @@ class DatasetService:
return None
@staticmethod
def _apply_new_embedding_settings(dataset, data, filtered_data):
def _apply_new_embedding_settings(dataset: Dataset, data: dict[str, Any], filtered_data: dict[str, Any]) -> None:
"""
Apply new embedding model settings to the dataset.
@ -946,7 +962,7 @@ class DatasetService:
@staticmethod
def update_rag_pipeline_dataset_settings(
session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
):
) -> None:
if not current_user or not current_user.current_tenant_id:
raise ValueError("Current user or current tenant not found")
dataset = session.merge(dataset)
@ -1101,7 +1117,7 @@ class DatasetService:
deal_dataset_index_update_task.delay(dataset.id, action)
@staticmethod
def delete_dataset(dataset_id, user):
def delete_dataset(dataset_id: str, user: Account) -> bool:
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
@ -1116,12 +1132,14 @@ class DatasetService:
return True
@staticmethod
def dataset_use_check(dataset_id) -> bool:
def dataset_use_check(dataset_id: str) -> bool:
stmt = select(exists().where(AppDatasetJoin.dataset_id == dataset_id))
return db.session.execute(stmt).scalar_one()
@staticmethod
def check_dataset_permission(dataset, user):
def check_dataset_permission(dataset: Dataset, user: Account | None) -> None:
if not user:
raise NoPermissionError("User not found.")
if dataset.tenant_id != user.current_tenant_id:
logger.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
raise NoPermissionError("You do not have permission to access this dataset.")
@ -1140,7 +1158,7 @@ class DatasetService:
raise NoPermissionError("You do not have permission to access this dataset.")
@staticmethod
def check_dataset_operator_permission(user: Account | None = None, dataset: Dataset | None = None):
def check_dataset_operator_permission(user: Account | None = None, dataset: Dataset | None = None) -> None:
if not dataset:
raise ValueError("Dataset not found")
@ -1160,15 +1178,15 @@ class DatasetService:
raise NoPermissionError("You do not have permission to access this dataset.")
@staticmethod
def get_dataset_queries(dataset_id: str, page: int, per_page: int):
def get_dataset_queries(dataset_id: str, page: int, per_page: int) -> tuple[list[DatasetQuery], int]:
stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at))
dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)
return dataset_queries.items, dataset_queries.total
return dataset_queries.items, dataset_queries.total or 0
@staticmethod
def get_related_apps(dataset_id: str):
def get_related_apps(dataset_id: str) -> list[AppDatasetJoin]:
return (
db.session.query(AppDatasetJoin)
.where(AppDatasetJoin.dataset_id == dataset_id)
@ -1177,7 +1195,7 @@ class DatasetService:
)
@staticmethod
def update_dataset_api_status(dataset_id: str, status: bool):
def update_dataset_api_status(dataset_id: str, status: bool) -> None:
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
@ -1189,7 +1207,7 @@ class DatasetService:
db.session.commit()
@staticmethod
def get_dataset_auto_disable_logs(dataset_id: str):
def get_dataset_auto_disable_logs(dataset_id: str) -> dict[str, Any]:
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
features = FeatureService.get_features(current_user.current_tenant_id)
@ -1288,7 +1306,7 @@ class DocumentService:
return cls.DISPLAY_STATUS_FILTERS[normalized]
@classmethod
def apply_display_status_filter(cls, query, status: str | None):
def apply_display_status_filter(cls, query: Any, status: str | None) -> Any:
filters = cls.build_display_status_filters(status)
if not filters:
return query
@ -1684,19 +1702,19 @@ class DocumentService:
return documents
@staticmethod
def get_document_file_detail(file_id: str):
def get_document_file_detail(file_id: str) -> UploadFile | None:
file_detail = db.session.query(UploadFile).where(UploadFile.id == file_id).one_or_none()
return file_detail
@staticmethod
def check_archived(document):
def check_archived(document: Document) -> bool:
if document.archived:
return True
else:
return False
@staticmethod
def delete_document(document):
def delete_document(document: Document) -> None:
# trigger document_was_deleted signal
file_id = None
if document.data_source_type == DataSourceType.UPLOAD_FILE:
@ -1712,7 +1730,7 @@ class DocumentService:
db.session.commit()
@staticmethod
def delete_documents(dataset: Dataset, document_ids: list[str]):
def delete_documents(dataset: Dataset, document_ids: list[str]) -> None:
# Check if document_ids is not empty to avoid WHERE false condition
if not document_ids or len(document_ids) == 0:
return
@ -1768,7 +1786,7 @@ class DocumentService:
return document
@staticmethod
def pause_document(document):
def pause_document(document: Document) -> None:
if document.indexing_status not in {
IndexingStatus.WAITING,
IndexingStatus.PARSING,
@ -1790,7 +1808,7 @@ class DocumentService:
redis_client.setnx(indexing_cache_key, "True")
@staticmethod
def recover_document(document):
def recover_document(document: Document) -> None:
if not document.is_paused:
raise DocumentIndexingError()
# update document to be recover
@ -1807,7 +1825,7 @@ class DocumentService:
recover_document_indexing_task.delay(document.dataset_id, document.id)
@staticmethod
def retry_document(dataset_id: str, documents: list[Document]):
def retry_document(dataset_id: str, documents: list[Document]) -> None:
for document in documents:
# add retry flag
retry_indexing_cache_key = f"document_{document.id}_is_retried"
@ -1827,7 +1845,7 @@ class DocumentService:
retry_document_indexing_task.delay(dataset_id, document_ids, current_user.id)
@staticmethod
def sync_website_document(dataset_id: str, document: Document):
def sync_website_document(dataset_id: str, document: Document) -> None:
# add sync flag
sync_indexing_cache_key = f"document_{document.id}_is_sync"
cache_result = redis_client.get(sync_indexing_cache_key)
@ -1847,7 +1865,7 @@ class DocumentService:
sync_website_document_indexing_task.delay(dataset_id, document.id)
@staticmethod
def get_documents_position(dataset_id):
def get_documents_position(dataset_id: str) -> int:
document = (
db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
)
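
A detail worth noting across this file: the paginated methods now declare tuple[list[...], int], and the matching body change is datasets.total to datasets.total or 0, because Flask-SQLAlchemy types Pagination.total as int | None. A reduced sketch of the coalesce, with a FakePagination standing in:

from dataclasses import dataclass, field
from typing import Any

@dataclass
class FakePagination:  # stand-in for flask_sqlalchemy's Pagination
    items: list[Any] = field(default_factory=list)
    total: int | None = None

def items_and_total(p: FakePagination) -> tuple[list[Any], int]:
    # total is int | None (None when counting is skipped), so the
    # coalesce is what makes the declared int return type true.
    return p.items, p.total or 0

print(items_and_total(FakePagination(items=["d1"], total=None)))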

View File

@ -97,8 +97,8 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
raise self.retry(exc=e, countdown=60) # Retry after 60 seconds
def _delete_app_model_configs(tenant_id: str, app_id: str):
def del_model_config(session, model_config_id: str):
def _delete_app_model_configs(tenant_id: str, app_id: str) -> None:
def del_model_config(session: Any, model_config_id: str) -> None:
session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
_delete_records(
@ -109,8 +109,8 @@ def _delete_app_model_configs(tenant_id: str, app_id: str):
)
def _delete_app_site(tenant_id: str, app_id: str):
def del_site(session, site_id: str):
def _delete_app_site(tenant_id: str, app_id: str) -> None:
def del_site(session: Any, site_id: str) -> None:
session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
_delete_records(
@ -121,8 +121,8 @@ def _delete_app_site(tenant_id: str, app_id: str):
)
def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def del_mcp_server(session, mcp_server_id: str):
def _delete_app_mcp_servers(tenant_id: str, app_id: str) -> None:
def del_mcp_server(session: Any, mcp_server_id: str) -> None:
session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
_delete_records(
@ -133,8 +133,8 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
)
def _delete_app_api_tokens(tenant_id: str, app_id: str):
def del_api_token(session, api_token_id: str):
def _delete_app_api_tokens(tenant_id: str, app_id: str) -> None:
def del_api_token(session: Any, api_token_id: str) -> None:
# Fetch token details for cache invalidation
token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first()
if token_obj:
@ -151,8 +151,8 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str):
)
def _delete_installed_apps(tenant_id: str, app_id: str):
def del_installed_app(session, installed_app_id: str):
def _delete_installed_apps(tenant_id: str, app_id: str) -> None:
def del_installed_app(session: Any, installed_app_id: str) -> None:
session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
_delete_records(
@ -163,8 +163,8 @@ def _delete_installed_apps(tenant_id: str, app_id: str):
)
def _delete_recommended_apps(tenant_id: str, app_id: str):
def del_recommended_app(session, recommended_app_id: str):
def _delete_recommended_apps(tenant_id: str, app_id: str) -> None:
def del_recommended_app(session: Any, recommended_app_id: str) -> None:
session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False)
_delete_records(
@ -175,8 +175,8 @@ def _delete_recommended_apps(tenant_id: str, app_id: str):
)
def _delete_app_annotation_data(tenant_id: str, app_id: str):
def del_annotation_hit_history(session, annotation_hit_history_id: str):
def _delete_app_annotation_data(tenant_id: str, app_id: str) -> None:
def del_annotation_hit_history(session: Any, annotation_hit_history_id: str) -> None:
session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
synchronize_session=False
)
@ -188,7 +188,7 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
"annotation hit history",
)
def del_annotation_setting(session, annotation_setting_id: str):
def del_annotation_setting(session: Any, annotation_setting_id: str) -> None:
session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
synchronize_session=False
)
@ -201,8 +201,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
)
def _delete_app_dataset_joins(tenant_id: str, app_id: str):
def del_dataset_join(session, dataset_join_id: str):
def _delete_app_dataset_joins(tenant_id: str, app_id: str) -> None:
def del_dataset_join(session: Any, dataset_join_id: str) -> None:
session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
_delete_records(
@ -213,8 +213,8 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str):
)
def _delete_app_workflows(tenant_id: str, app_id: str):
def del_workflow(session, workflow_id: str):
def _delete_app_workflows(tenant_id: str, app_id: str) -> None:
def del_workflow(session: Any, workflow_id: str) -> None:
session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
_delete_records(
@ -225,7 +225,7 @@ def _delete_app_workflows(tenant_id: str, app_id: str):
)
def _delete_app_workflow_runs(tenant_id: str, app_id: str):
def _delete_app_workflow_runs(tenant_id: str, app_id: str) -> None:
"""Delete all workflow runs for an app using the service repository."""
session_maker = sessionmaker(bind=db.engine)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@ -239,7 +239,7 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
logger.info("Deleted %s workflow runs for app %s", deleted_count, app_id)
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str) -> None:
"""Delete all workflow node executions for an app using the service repository."""
session_maker = sessionmaker(bind=db.engine)
node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
@ -253,8 +253,8 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
logger.info("Deleted %s workflow node executions for app %s", deleted_count, app_id)
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def del_workflow_app_log(session, workflow_app_log_id: str):
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str) -> None:
def del_workflow_app_log(session: Any, workflow_app_log_id: str) -> None:
session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False)
_delete_records(
@ -265,8 +265,8 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
)
def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
def del_workflow_archive_log(session, workflow_archive_log_id: str):
def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str) -> None:
def del_workflow_archive_log(session: Any, workflow_archive_log_id: str) -> None:
session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
synchronize_session=False
)
@ -279,7 +279,7 @@ def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
)
def _delete_archived_workflow_run_files(tenant_id: str, app_id: str):
def _delete_archived_workflow_run_files(tenant_id: str, app_id: str) -> None:
prefix = f"{tenant_id}/app_id={app_id}/"
try:
archive_storage = get_archive_storage()
@ -304,8 +304,8 @@ def _delete_archived_workflow_run_files(tenant_id: str, app_id: str):
logger.info("Deleted %s archive objects for app %s", deleted, app_id)
def _delete_app_conversations(tenant_id: str, app_id: str):
def del_conversation(session, conversation_id: str):
def _delete_app_conversations(tenant_id: str, app_id: str) -> None:
def del_conversation(session: Any, conversation_id: str) -> None:
session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
synchronize_session=False
)
@ -319,7 +319,7 @@ def _delete_app_conversations(tenant_id: str, app_id: str):
)
def _delete_conversation_variables(*, app_id: str):
def _delete_conversation_variables(app_id: str) -> None:
with session_factory.create_session() as session:
stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
session.execute(stmt)
@ -327,8 +327,8 @@ def _delete_conversation_variables(*, app_id: str):
logger.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green"))
def _delete_app_messages(tenant_id: str, app_id: str):
def del_message(session, message_id: str):
def _delete_app_messages(tenant_id: str, app_id: str) -> None:
def del_message(session: Any, message_id: str) -> None:
session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False)
session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
synchronize_session=False
@ -349,8 +349,8 @@ def _delete_app_messages(tenant_id: str, app_id: str):
)
def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
def del_tool_provider(session, tool_provider_id: str):
def _delete_workflow_tool_providers(tenant_id: str, app_id: str) -> None:
def del_tool_provider(session: Any, tool_provider_id: str) -> None:
session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
synchronize_session=False
)
@ -363,8 +363,8 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
)
def _delete_app_tag_bindings(tenant_id: str, app_id: str):
def del_tag_binding(session, tag_binding_id: str):
def _delete_app_tag_bindings(tenant_id: str, app_id: str) -> None:
def del_tag_binding(session: Any, tag_binding_id: str) -> None:
session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
_delete_records(
@ -375,8 +375,8 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str):
)
def _delete_end_users(tenant_id: str, app_id: str):
def del_end_user(session, end_user_id: str):
def _delete_end_users(tenant_id: str, app_id: str) -> None:
def del_end_user(session: Any, end_user_id: str) -> None:
session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
_delete_records(
@ -387,8 +387,8 @@ def _delete_end_users(tenant_id: str, app_id: str):
)
def _delete_trace_app_configs(tenant_id: str, app_id: str):
def del_trace_app_config(session, trace_app_config_id: str):
def _delete_trace_app_configs(tenant_id: str, app_id: str) -> None:
def del_trace_app_config(session: Any, trace_app_config_id: str) -> None:
session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False)
_delete_records(
@ -399,9 +399,9 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str):
)
def _delete_draft_variables(app_id: str):
def _delete_draft_variables(app_id: str) -> None:
"""Delete all workflow draft variables for an app in batches."""
return delete_draft_variables_batch(app_id, batch_size=1000)
delete_draft_variables_batch(app_id, batch_size=1000)
def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
@ -543,8 +543,8 @@ def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int:
return files_deleted
def _delete_app_triggers(tenant_id: str, app_id: str):
def del_app_trigger(session, trigger_id: str):
def _delete_app_triggers(tenant_id: str, app_id: str) -> None:
def del_app_trigger(session: Any, trigger_id: str) -> None:
session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
_delete_records(
@ -555,8 +555,8 @@ def _delete_app_triggers(tenant_id: str, app_id: str):
)
def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
def del_plugin_trigger(session, trigger_id: str):
def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str) -> None:
def del_plugin_trigger(session: Any, trigger_id: str) -> None:
session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
synchronize_session=False
)
@ -569,8 +569,8 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
)
def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
def del_webhook_trigger(session, trigger_id: str):
def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str) -> None:
def del_webhook_trigger(session: Any, trigger_id: str) -> None:
session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
synchronize_session=False
)
@ -583,8 +583,8 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
)
def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
def del_schedule_plan(session, plan_id: str):
def _delete_workflow_schedule_plans(tenant_id: str, app_id: str) -> None:
def del_schedule_plan(session: Any, plan_id: str) -> None:
session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False)
_delete_records(
@ -595,8 +595,8 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
)
def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
def del_trigger_log(session, log_id: str):
def _delete_workflow_trigger_logs(tenant_id: str, app_id: str) -> None:
def del_trigger_log(session: Any, log_id: str) -> None:
session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
_delete_records(
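
Every helper in this task module has the same shape: an outer function typed -> None and a nested per-record callback whose session is annotated Any, since the module forwards the session object rather than importing a concrete Session type. A simplified, self-contained sketch (the batching helper and session are stand-ins):

from collections.abc import Callable
from typing import Any

def _delete_records(ids: list[str], session: Any,
                    del_fn: Callable[[Any, str], None]) -> None:
    # stand-in for the module's batching helper: fetch ids, call del_fn
    for record_id in ids:
        del_fn(session, record_id)

class FakeSession:
    def delete_by_id(self, record_id: str) -> None:
        print(f"deleted {record_id}")

def _delete_app_sites(session: Any, site_ids: list[str]) -> None:
    def del_site(session: Any, site_id: str) -> None:
        session.delete_by_id(site_id)  # real code: session.query(Site)...delete()
    _delete_records(site_ids, session, del_site)

_delete_app_sites(FakeSession(), ["s1", "s2"])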

View File

@ -28,12 +28,11 @@ class TestDeleteDraftVariablesBatch:
def test_delete_draft_variables_calls_batch_function(self, mock_batch_delete):
"""Test that _delete_draft_variables calls the batch function correctly."""
app_id = "test-app-id"
expected_return = 42
mock_batch_delete.return_value = expected_return
mock_batch_delete.return_value = 42
result = _delete_draft_variables(app_id)
assert result == expected_return
assert result is None
mock_batch_delete.assert_called_once_with(app_id, batch_size=1000)
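
The last two files go together: _delete_draft_variables is now annotated -> None, so it deliberately discards the row count from delete_draft_variables_batch instead of returning it, and the unit test's expectation moves from result == 42 to result is None accordingly. A reduced sketch of the behavior under test:

def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
    return 42  # stand-in: the real function returns the rows deleted

def _delete_draft_variables(app_id: str) -> None:
    # returning the int here would contradict the -> None annotation,
    # hence the bare call; callers that want the count use the batch API
    delete_draft_variables_batch(app_id, batch_size=1000)

assert _delete_draft_variables("test-app-id") is None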