mirror of https://github.com/langgenius/dify.git
fix: missing types
[autofix.ci] apply automated fixes fix/missing import types fix/type checking error fix: api test
This commit is contained in:
parent
f5cc1c8b75
commit
75273b4be5
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Literal
|
from typing import Literal, cast
|
||||||
|
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restx import marshal
|
from flask_restx import marshal
|
||||||
|
|
@ -9,6 +9,7 @@ from controllers.common.schema import register_schema_model, register_schema_mod
|
||||||
from controllers.service_api import service_api_ns
|
from controllers.service_api import service_api_ns
|
||||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
||||||
from fields.dataset_fields import dataset_metadata_fields
|
from fields.dataset_fields import dataset_metadata_fields
|
||||||
|
from models.account import Account
|
||||||
from services.dataset_service import DatasetService
|
from services.dataset_service import DatasetService
|
||||||
from services.entities.knowledge_entities.knowledge_entities import (
|
from services.entities.knowledge_entities.knowledge_entities import (
|
||||||
DocumentMetadataOperation,
|
DocumentMetadataOperation,
|
||||||
|
|
@ -55,7 +56,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
DatasetService.check_dataset_permission(dataset, current_user)
|
DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
|
||||||
|
|
||||||
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
|
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
|
||||||
return marshal(metadata, dataset_metadata_fields), 201
|
return marshal(metadata, dataset_metadata_fields), 201
|
||||||
|
|
@ -102,7 +103,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
DatasetService.check_dataset_permission(dataset, current_user)
|
DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
|
||||||
|
|
||||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
|
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
|
||||||
return marshal(metadata, dataset_metadata_fields), 200
|
return marshal(metadata, dataset_metadata_fields), 200
|
||||||
|
|
@ -125,7 +126,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
DatasetService.check_dataset_permission(dataset, current_user)
|
DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
|
||||||
|
|
||||||
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
|
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
@ -166,7 +167,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
DatasetService.check_dataset_permission(dataset, current_user)
|
DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
|
||||||
|
|
||||||
match action:
|
match action:
|
||||||
case "enable":
|
case "enable":
|
||||||
|
|
@ -196,7 +197,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
DatasetService.check_dataset_permission(dataset, current_user)
|
DatasetService.check_dataset_permission(dataset, cast(Account | None, current_user))
|
||||||
|
|
||||||
metadata_args = MetadataOperationData.model_validate(service_api_ns.payload or {})
|
metadata_args = MetadataOperationData.model_validate(service_api_ns.payload or {})
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ import tempfile
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import clickzetta
|
import clickzetta
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
@ -39,7 +40,7 @@ class ClickZettaVolumeConfig(BaseModel):
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_config(cls, values: dict):
|
def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Validate the configuration values.
|
"""Validate the configuration values.
|
||||||
|
|
||||||
This method will first try to use CLICKZETTA_VOLUME_* environment variables,
|
This method will first try to use CLICKZETTA_VOLUME_* environment variables,
|
||||||
|
|
@ -110,7 +111,7 @@ class ClickZettaVolumeConfig(BaseModel):
|
||||||
class ClickZettaVolumeStorage(BaseStorage):
|
class ClickZettaVolumeStorage(BaseStorage):
|
||||||
"""ClickZetta Volume storage implementation."""
|
"""ClickZetta Volume storage implementation."""
|
||||||
|
|
||||||
def __init__(self, config: ClickZettaVolumeConfig):
|
def __init__(self, config: ClickZettaVolumeConfig) -> None:
|
||||||
"""Initialize ClickZetta Volume storage.
|
"""Initialize ClickZetta Volume storage.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -124,7 +125,7 @@ class ClickZettaVolumeStorage(BaseStorage):
|
||||||
|
|
||||||
logger.info("ClickZetta Volume storage initialized with type: %s", config.volume_type)
|
logger.info("ClickZetta Volume storage initialized with type: %s", config.volume_type)
|
||||||
|
|
||||||
def _init_connection(self):
|
def _init_connection(self) -> None:
|
||||||
"""Initialize ClickZetta connection."""
|
"""Initialize ClickZetta connection."""
|
||||||
try:
|
try:
|
||||||
self._connection = clickzetta.connect(
|
self._connection = clickzetta.connect(
|
||||||
|
|
@ -141,7 +142,7 @@ class ClickZettaVolumeStorage(BaseStorage):
|
||||||
logger.exception("Failed to connect to ClickZetta")
|
logger.exception("Failed to connect to ClickZetta")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _init_permission_manager(self):
|
def _init_permission_manager(self) -> None:
|
||||||
"""Initialize permission manager."""
|
"""Initialize permission manager."""
|
||||||
try:
|
try:
|
||||||
self._permission_manager = VolumePermissionManager(
|
self._permission_manager = VolumePermissionManager(
|
||||||
|
|
@ -201,7 +202,7 @@ class ClickZettaVolumeStorage(BaseStorage):
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
|
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
|
||||||
|
|
||||||
def _execute_sql(self, sql: str, fetch: bool = False):
|
def _execute_sql(self, sql: str, fetch: bool = False) -> Any:
|
||||||
"""Execute SQL command."""
|
"""Execute SQL command."""
|
||||||
try:
|
try:
|
||||||
if self._connection is None:
|
if self._connection is None:
|
||||||
|
|
@ -215,7 +216,7 @@ class ClickZettaVolumeStorage(BaseStorage):
|
||||||
logger.exception("SQL execution failed: %s", sql)
|
logger.exception("SQL execution failed: %s", sql)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _ensure_table_volume_exists(self, dataset_id: str):
|
def _ensure_table_volume_exists(self, dataset_id: str) -> None:
|
||||||
"""Ensure table volume exists for the given dataset_id."""
|
"""Ensure table volume exists for the given dataset_id."""
|
||||||
if self._config.volume_type != "table" or not dataset_id:
|
if self._config.volume_type != "table" or not dataset_id:
|
||||||
return
|
return
|
||||||
|
|
@ -250,7 +251,7 @@ class ClickZettaVolumeStorage(BaseStorage):
|
||||||
# Don't raise exception, let the operation continue
|
# Don't raise exception, let the operation continue
|
||||||
# The table might exist but not be visible due to permissions
|
# The table might exist but not be visible due to permissions
|
||||||
|
|
||||||
def save(self, filename: str, data: bytes):
|
def save(self, filename: str, data: bytes) -> None:
|
||||||
"""Save data to ClickZetta Volume.
|
"""Save data to ClickZetta Volume.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -381,7 +382,7 @@ class ClickZettaVolumeStorage(BaseStorage):
|
||||||
|
|
||||||
logger.debug("File %s loaded as stream from ClickZetta Volume", filename)
|
logger.debug("File %s loaded as stream from ClickZetta Volume", filename)
|
||||||
|
|
||||||
def download(self, filename: str, target_filepath: str):
|
def download(self, filename: str, target_filepath: str) -> None:
|
||||||
"""Download file from ClickZetta Volume to local path.
|
"""Download file from ClickZetta Volume to local path.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -435,7 +436,7 @@ class ClickZettaVolumeStorage(BaseStorage):
|
||||||
logger.warning("Error checking file existence for %s: %s", filename, e)
|
logger.warning("Error checking file existence for %s: %s", filename, e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def delete(self, filename: str):
|
def delete(self, filename: str) -> None:
|
||||||
"""Delete file from ClickZetta Volume.
|
"""Delete file from ClickZetta Volume.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ class FileMetadata:
|
||||||
tags: dict[str, str] | None = None
|
tags: dict[str, str] | None = None
|
||||||
parent_version: int | None = None
|
parent_version: int | None = None
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""Convert to dictionary format"""
|
"""Convert to dictionary format"""
|
||||||
data = asdict(self)
|
data = asdict(self)
|
||||||
data["created_at"] = self.created_at.isoformat()
|
data["created_at"] = self.created_at.isoformat()
|
||||||
|
|
@ -62,7 +62,7 @@ class FileMetadata:
|
||||||
class FileLifecycleManager:
|
class FileLifecycleManager:
|
||||||
"""File lifecycle manager"""
|
"""File lifecycle manager"""
|
||||||
|
|
||||||
def __init__(self, storage, dataset_id: str | None = None):
|
def __init__(self, storage: Any, dataset_id: str | None = None) -> None:
|
||||||
"""Initialize lifecycle manager
|
"""Initialize lifecycle manager
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -435,7 +435,7 @@ class FileLifecycleManager:
|
||||||
logger.exception("Failed to get storage statistics")
|
logger.exception("Failed to get storage statistics")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _create_version_backup(self, filename: str, metadata: dict):
|
def _create_version_backup(self, filename: str, metadata: dict[str, Any]) -> None:
|
||||||
"""Create version backup"""
|
"""Create version backup"""
|
||||||
try:
|
try:
|
||||||
# Read current file content
|
# Read current file content
|
||||||
|
|
@ -463,7 +463,7 @@ class FileLifecycleManager:
|
||||||
logger.warning("Failed to load metadata: %s", e)
|
logger.warning("Failed to load metadata: %s", e)
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _save_metadata(self, metadata_dict: dict):
|
def _save_metadata(self, metadata_dict: dict[str, Any]) -> None:
|
||||||
"""Save metadata file"""
|
"""Save metadata file"""
|
||||||
try:
|
try:
|
||||||
metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
|
metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ According to ClickZetta's permission model, different Volume types have differen
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -23,7 +24,9 @@ class VolumePermission(StrEnum):
|
||||||
class VolumePermissionManager:
|
class VolumePermissionManager:
|
||||||
"""Volume permission manager"""
|
"""Volume permission manager"""
|
||||||
|
|
||||||
def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: str | None = None):
|
def __init__(
|
||||||
|
self, connection_or_config: Any, volume_type: str | None = None, volume_name: str | None = None
|
||||||
|
) -> None:
|
||||||
"""Initialize permission manager
|
"""Initialize permission manager
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -434,7 +437,7 @@ class VolumePermissionManager:
|
||||||
self._permission_cache[cache_key] = permissions
|
self._permission_cache[cache_key] = permissions
|
||||||
return permissions
|
return permissions
|
||||||
|
|
||||||
def clear_permission_cache(self):
|
def clear_permission_cache(self) -> None:
|
||||||
"""Clear permission cache"""
|
"""Clear permission cache"""
|
||||||
self._permission_cache.clear()
|
self._permission_cache.clear()
|
||||||
logger.debug("Permission cache cleared")
|
logger.debug("Permission cache cleared")
|
||||||
|
|
@ -625,7 +628,9 @@ class VolumePermissionError(Exception):
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
def check_volume_permission(permission_manager: VolumePermissionManager, operation: str, dataset_id: str | None = None):
|
def check_volume_permission(
|
||||||
|
permission_manager: VolumePermissionManager, operation: str, dataset_id: str | None = None
|
||||||
|
) -> Any:
|
||||||
"""Permission check decorator function
|
"""Permission check decorator function
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
||||||
|
|
@ -108,7 +108,15 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class DatasetService:
|
class DatasetService:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
|
def get_datasets(
|
||||||
|
page: int,
|
||||||
|
per_page: int,
|
||||||
|
tenant_id: str | None = None,
|
||||||
|
user: Account | None = None,
|
||||||
|
search: str | None = None,
|
||||||
|
tag_ids: list[str] | None = None,
|
||||||
|
include_all: bool = False,
|
||||||
|
) -> tuple[list[Dataset], int]:
|
||||||
query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc(), Dataset.id)
|
query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc(), Dataset.id)
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
|
|
@ -176,10 +184,10 @@ class DatasetService:
|
||||||
|
|
||||||
datasets = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False)
|
datasets = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False)
|
||||||
|
|
||||||
return datasets.items, datasets.total
|
return datasets.items, datasets.total or 0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_process_rules(dataset_id):
|
def get_process_rules(dataset_id: str) -> dict[str, Any]:
|
||||||
# get the latest process rule
|
# get the latest process rule
|
||||||
dataset_process_rule = (
|
dataset_process_rule = (
|
||||||
db.session.query(DatasetProcessRule)
|
db.session.query(DatasetProcessRule)
|
||||||
|
|
@ -197,7 +205,7 @@ class DatasetService:
|
||||||
return {"mode": mode, "rules": rules}
|
return {"mode": mode, "rules": rules}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_datasets_by_ids(ids, tenant_id):
|
def get_datasets_by_ids(ids: list[str] | None, tenant_id: str) -> tuple[list[Dataset], int]:
|
||||||
# Check if ids is not empty to avoid WHERE false condition
|
# Check if ids is not empty to avoid WHERE false condition
|
||||||
if not ids or len(ids) == 0:
|
if not ids or len(ids) == 0:
|
||||||
return [], 0
|
return [], 0
|
||||||
|
|
@ -205,7 +213,7 @@ class DatasetService:
|
||||||
|
|
||||||
datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
|
datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
|
||||||
|
|
||||||
return datasets.items, datasets.total
|
return datasets.items, datasets.total or 0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_empty_dataset(
|
def create_empty_dataset(
|
||||||
|
|
@ -222,7 +230,7 @@ class DatasetService:
|
||||||
embedding_model_name: str | None = None,
|
embedding_model_name: str | None = None,
|
||||||
retrieval_model: RetrievalModel | None = None,
|
retrieval_model: RetrievalModel | None = None,
|
||||||
summary_index_setting: dict | None = None,
|
summary_index_setting: dict | None = None,
|
||||||
):
|
) -> Dataset:
|
||||||
# check if dataset name already exists
|
# check if dataset name already exists
|
||||||
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
|
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
|
||||||
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
|
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
|
||||||
|
|
@ -291,7 +299,7 @@ class DatasetService:
|
||||||
def create_empty_rag_pipeline_dataset(
|
def create_empty_rag_pipeline_dataset(
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
|
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
|
||||||
):
|
) -> Dataset:
|
||||||
if rag_pipeline_dataset_create_entity.name:
|
if rag_pipeline_dataset_create_entity.name:
|
||||||
# check if dataset name already exists
|
# check if dataset name already exists
|
||||||
if (
|
if (
|
||||||
|
|
@ -337,17 +345,17 @@ class DatasetService:
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_dataset(dataset_id) -> Dataset | None:
|
def get_dataset(dataset_id: str) -> Dataset | None:
|
||||||
dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_doc_form(dataset: Dataset, doc_form: str):
|
def check_doc_form(dataset: Dataset, doc_form: str) -> None:
|
||||||
if dataset.doc_form and doc_form != dataset.doc_form:
|
if dataset.doc_form and doc_form != dataset.doc_form:
|
||||||
raise ValueError("doc_form is different from the dataset doc_form.")
|
raise ValueError("doc_form is different from the dataset doc_form.")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_dataset_model_setting(dataset):
|
def check_dataset_model_setting(dataset: Dataset) -> None:
|
||||||
if dataset.indexing_technique == "high_quality":
|
if dataset.indexing_technique == "high_quality":
|
||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
|
|
@ -365,7 +373,7 @@ class DatasetService:
|
||||||
raise ValueError(f"The dataset is unavailable, due to: {ex.description}")
|
raise ValueError(f"The dataset is unavailable, due to: {ex.description}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str):
|
def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str) -> None:
|
||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
|
|
@ -382,7 +390,7 @@ class DatasetService:
|
||||||
raise ValueError(ex.description)
|
raise ValueError(ex.description)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str):
|
def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str) -> bool:
|
||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_instance = model_manager.get_model_instance(
|
model_instance = model_manager.get_model_instance(
|
||||||
|
|
@ -403,7 +411,7 @@ class DatasetService:
|
||||||
raise ValueError("No Model available. Please configure a valid provider in the Settings -> Model Provider.")
|
raise ValueError("No Model available. Please configure a valid provider in the Settings -> Model Provider.")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
|
def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str) -> None:
|
||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
|
|
@ -420,7 +428,7 @@ class DatasetService:
|
||||||
raise ValueError(ex.description)
|
raise ValueError(ex.description)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_dataset(dataset_id, data, user):
|
def update_dataset(dataset_id: str, data: dict[str, Any], user: Account) -> Dataset:
|
||||||
"""
|
"""
|
||||||
Update dataset configuration and settings.
|
Update dataset configuration and settings.
|
||||||
|
|
||||||
|
|
@ -459,7 +467,7 @@ class DatasetService:
|
||||||
return DatasetService._update_internal_dataset(dataset, data, user)
|
return DatasetService._update_internal_dataset(dataset, data, user)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str):
|
def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str) -> bool:
|
||||||
dataset = (
|
dataset = (
|
||||||
db.session.query(Dataset)
|
db.session.query(Dataset)
|
||||||
.where(
|
.where(
|
||||||
|
|
@ -472,7 +480,7 @@ class DatasetService:
|
||||||
return dataset is not None
|
return dataset is not None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _update_external_dataset(dataset, data, user):
|
def _update_external_dataset(dataset: Dataset, data: dict[str, Any], user: Account) -> Dataset:
|
||||||
"""
|
"""
|
||||||
Update external dataset configuration.
|
Update external dataset configuration.
|
||||||
|
|
||||||
|
|
@ -485,12 +493,12 @@ class DatasetService:
|
||||||
Dataset: Updated dataset object
|
Dataset: Updated dataset object
|
||||||
"""
|
"""
|
||||||
# Update retrieval model if provided
|
# Update retrieval model if provided
|
||||||
external_retrieval_model = data.get("external_retrieval_model", None)
|
external_retrieval_model = data.get("external_retrieval_model")
|
||||||
if external_retrieval_model:
|
if external_retrieval_model:
|
||||||
dataset.retrieval_model = external_retrieval_model
|
dataset.retrieval_model = external_retrieval_model
|
||||||
|
|
||||||
# Update summary index setting if provided
|
# Update summary index setting if provided
|
||||||
summary_index_setting = data.get("summary_index_setting", None)
|
summary_index_setting = data.get("summary_index_setting")
|
||||||
if summary_index_setting is not None:
|
if summary_index_setting is not None:
|
||||||
dataset.summary_index_setting = summary_index_setting
|
dataset.summary_index_setting = summary_index_setting
|
||||||
|
|
||||||
|
|
@ -504,8 +512,8 @@ class DatasetService:
|
||||||
dataset.permission = permission
|
dataset.permission = permission
|
||||||
|
|
||||||
# Validate and update external knowledge configuration
|
# Validate and update external knowledge configuration
|
||||||
external_knowledge_id = data.get("external_knowledge_id", None)
|
external_knowledge_id = data.get("external_knowledge_id")
|
||||||
external_knowledge_api_id = data.get("external_knowledge_api_id", None)
|
external_knowledge_api_id = data.get("external_knowledge_api_id")
|
||||||
|
|
||||||
if not external_knowledge_id:
|
if not external_knowledge_id:
|
||||||
raise ValueError("External knowledge id is required.")
|
raise ValueError("External knowledge id is required.")
|
||||||
|
|
@ -525,7 +533,9 @@ class DatasetService:
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _update_external_knowledge_binding(dataset_id, external_knowledge_id, external_knowledge_api_id):
|
def _update_external_knowledge_binding(
|
||||||
|
dataset_id: str, external_knowledge_id: str, external_knowledge_api_id: str
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Update external knowledge binding configuration.
|
Update external knowledge binding configuration.
|
||||||
|
|
||||||
|
|
@ -552,7 +562,7 @@ class DatasetService:
|
||||||
db.session.add(external_knowledge_binding)
|
db.session.add(external_knowledge_binding)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _update_internal_dataset(dataset, data, user):
|
def _update_internal_dataset(dataset: Dataset, data: dict[str, Any], user: Account) -> Dataset:
|
||||||
"""
|
"""
|
||||||
Update internal dataset configuration.
|
Update internal dataset configuration.
|
||||||
|
|
||||||
|
|
@ -590,7 +600,7 @@ class DatasetService:
|
||||||
filtered_data["icon_info"] = data.get("icon_info")
|
filtered_data["icon_info"] = data.get("icon_info")
|
||||||
|
|
||||||
# Update dataset in database
|
# Update dataset in database
|
||||||
db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
|
db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data) # type: ignore[arg-type]
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# Reload dataset to get updated values
|
# Reload dataset to get updated values
|
||||||
|
|
@ -618,7 +628,7 @@ class DatasetService:
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _update_pipeline_knowledge_base_node_data(dataset: Dataset, updata_user_id: str):
|
def _update_pipeline_knowledge_base_node_data(dataset: Dataset, updata_user_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
Update pipeline knowledge base node data.
|
Update pipeline knowledge base node data.
|
||||||
"""
|
"""
|
||||||
|
|
@ -701,7 +711,9 @@ class DatasetService:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _handle_indexing_technique_change(dataset, data, filtered_data):
|
def _handle_indexing_technique_change(
|
||||||
|
dataset: Dataset, data: dict[str, Any], filtered_data: dict[str, Any]
|
||||||
|
) -> str | None:
|
||||||
"""
|
"""
|
||||||
Handle changes in indexing technique and configure embedding models accordingly.
|
Handle changes in indexing technique and configure embedding models accordingly.
|
||||||
|
|
||||||
|
|
@ -732,7 +744,7 @@ class DatasetService:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_embedding_model_for_high_quality(data, filtered_data):
|
def _configure_embedding_model_for_high_quality(data: dict[str, Any], filtered_data: dict[str, Any]) -> None:
|
||||||
"""
|
"""
|
||||||
Configure embedding model settings for high quality indexing.
|
Configure embedding model settings for high quality indexing.
|
||||||
|
|
||||||
|
|
@ -767,7 +779,9 @@ class DatasetService:
|
||||||
raise ValueError(ex.description)
|
raise ValueError(ex.description)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data):
|
def _handle_embedding_model_update_when_technique_unchanged(
|
||||||
|
dataset: Dataset, data: dict[str, Any], filtered_data: dict[str, Any]
|
||||||
|
) -> str | None:
|
||||||
"""
|
"""
|
||||||
Handle embedding model updates when indexing technique remains the same.
|
Handle embedding model updates when indexing technique remains the same.
|
||||||
|
|
||||||
|
|
@ -792,7 +806,7 @@ class DatasetService:
|
||||||
return DatasetService._update_embedding_model_settings(dataset, data, filtered_data)
|
return DatasetService._update_embedding_model_settings(dataset, data, filtered_data)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _preserve_existing_embedding_settings(dataset, filtered_data):
|
def _preserve_existing_embedding_settings(dataset: Dataset, filtered_data: dict[str, Any]) -> None:
|
||||||
"""
|
"""
|
||||||
Preserve existing embedding model settings when not provided in update.
|
Preserve existing embedding model settings when not provided in update.
|
||||||
|
|
||||||
|
|
@ -815,7 +829,9 @@ class DatasetService:
|
||||||
del filtered_data["embedding_model"]
|
del filtered_data["embedding_model"]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _update_embedding_model_settings(dataset, data, filtered_data):
|
def _update_embedding_model_settings(
|
||||||
|
dataset: Dataset, data: dict[str, Any], filtered_data: dict[str, Any]
|
||||||
|
) -> str | None:
|
||||||
"""
|
"""
|
||||||
Update embedding model settings with new values.
|
Update embedding model settings with new values.
|
||||||
|
|
||||||
|
|
@ -849,7 +865,7 @@ class DatasetService:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _apply_new_embedding_settings(dataset, data, filtered_data):
|
def _apply_new_embedding_settings(dataset: Dataset, data: dict[str, Any], filtered_data: dict[str, Any]) -> None:
|
||||||
"""
|
"""
|
||||||
Apply new embedding model settings to the dataset.
|
Apply new embedding model settings to the dataset.
|
||||||
|
|
||||||
|
|
@ -946,7 +962,7 @@ class DatasetService:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_rag_pipeline_dataset_settings(
|
def update_rag_pipeline_dataset_settings(
|
||||||
session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
|
session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
|
||||||
):
|
) -> None:
|
||||||
if not current_user or not current_user.current_tenant_id:
|
if not current_user or not current_user.current_tenant_id:
|
||||||
raise ValueError("Current user or current tenant not found")
|
raise ValueError("Current user or current tenant not found")
|
||||||
dataset = session.merge(dataset)
|
dataset = session.merge(dataset)
|
||||||
|
|
@ -1101,7 +1117,7 @@ class DatasetService:
|
||||||
deal_dataset_index_update_task.delay(dataset.id, action)
|
deal_dataset_index_update_task.delay(dataset.id, action)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def delete_dataset(dataset_id, user):
|
def delete_dataset(dataset_id: str, user: Account) -> bool:
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
|
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
|
|
@ -1116,12 +1132,14 @@ class DatasetService:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def dataset_use_check(dataset_id) -> bool:
|
def dataset_use_check(dataset_id: str) -> bool:
|
||||||
stmt = select(exists().where(AppDatasetJoin.dataset_id == dataset_id))
|
stmt = select(exists().where(AppDatasetJoin.dataset_id == dataset_id))
|
||||||
return db.session.execute(stmt).scalar_one()
|
return db.session.execute(stmt).scalar_one()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_dataset_permission(dataset, user):
|
def check_dataset_permission(dataset: Dataset, user: Account | None) -> None:
|
||||||
|
if not user:
|
||||||
|
raise NoPermissionError("User not found.")
|
||||||
if dataset.tenant_id != user.current_tenant_id:
|
if dataset.tenant_id != user.current_tenant_id:
|
||||||
logger.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
|
logger.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
|
||||||
raise NoPermissionError("You do not have permission to access this dataset.")
|
raise NoPermissionError("You do not have permission to access this dataset.")
|
||||||
|
|
@ -1140,7 +1158,7 @@ class DatasetService:
|
||||||
raise NoPermissionError("You do not have permission to access this dataset.")
|
raise NoPermissionError("You do not have permission to access this dataset.")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_dataset_operator_permission(user: Account | None = None, dataset: Dataset | None = None):
|
def check_dataset_operator_permission(user: Account | None = None, dataset: Dataset | None = None) -> None:
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise ValueError("Dataset not found")
|
raise ValueError("Dataset not found")
|
||||||
|
|
||||||
|
|
@ -1160,15 +1178,15 @@ class DatasetService:
|
||||||
raise NoPermissionError("You do not have permission to access this dataset.")
|
raise NoPermissionError("You do not have permission to access this dataset.")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_dataset_queries(dataset_id: str, page: int, per_page: int):
|
def get_dataset_queries(dataset_id: str, page: int, per_page: int) -> tuple[list[DatasetQuery], int]:
|
||||||
stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at))
|
stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at))
|
||||||
|
|
||||||
dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)
|
dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)
|
||||||
|
|
||||||
return dataset_queries.items, dataset_queries.total
|
return dataset_queries.items, dataset_queries.total or 0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_related_apps(dataset_id: str):
|
def get_related_apps(dataset_id: str) -> list[AppDatasetJoin]:
|
||||||
return (
|
return (
|
||||||
db.session.query(AppDatasetJoin)
|
db.session.query(AppDatasetJoin)
|
||||||
.where(AppDatasetJoin.dataset_id == dataset_id)
|
.where(AppDatasetJoin.dataset_id == dataset_id)
|
||||||
|
|
@ -1177,7 +1195,7 @@ class DatasetService:
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_dataset_api_status(dataset_id: str, status: bool):
|
def update_dataset_api_status(dataset_id: str, status: bool) -> None:
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
|
|
@ -1189,7 +1207,7 @@ class DatasetService:
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_dataset_auto_disable_logs(dataset_id: str):
|
def get_dataset_auto_disable_logs(dataset_id: str) -> dict[str, Any]:
|
||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
assert current_user.current_tenant_id is not None
|
assert current_user.current_tenant_id is not None
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
|
|
@ -1288,7 +1306,7 @@ class DocumentService:
|
||||||
return cls.DISPLAY_STATUS_FILTERS[normalized]
|
return cls.DISPLAY_STATUS_FILTERS[normalized]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def apply_display_status_filter(cls, query, status: str | None):
|
def apply_display_status_filter(cls, query: Any, status: str | None) -> Any:
|
||||||
filters = cls.build_display_status_filters(status)
|
filters = cls.build_display_status_filters(status)
|
||||||
if not filters:
|
if not filters:
|
||||||
return query
|
return query
|
||||||
|
|
@ -1684,19 +1702,19 @@ class DocumentService:
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_document_file_detail(file_id: str):
|
def get_document_file_detail(file_id: str) -> UploadFile | None:
|
||||||
file_detail = db.session.query(UploadFile).where(UploadFile.id == file_id).one_or_none()
|
file_detail = db.session.query(UploadFile).where(UploadFile.id == file_id).one_or_none()
|
||||||
return file_detail
|
return file_detail
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_archived(document):
|
def check_archived(document: Document) -> bool:
|
||||||
if document.archived:
|
if document.archived:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def delete_document(document):
|
def delete_document(document: Document) -> None:
|
||||||
# trigger document_was_deleted signal
|
# trigger document_was_deleted signal
|
||||||
file_id = None
|
file_id = None
|
||||||
if document.data_source_type == DataSourceType.UPLOAD_FILE:
|
if document.data_source_type == DataSourceType.UPLOAD_FILE:
|
||||||
|
|
@ -1712,7 +1730,7 @@ class DocumentService:
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def delete_documents(dataset: Dataset, document_ids: list[str]):
|
def delete_documents(dataset: Dataset, document_ids: list[str]) -> None:
|
||||||
# Check if document_ids is not empty to avoid WHERE false condition
|
# Check if document_ids is not empty to avoid WHERE false condition
|
||||||
if not document_ids or len(document_ids) == 0:
|
if not document_ids or len(document_ids) == 0:
|
||||||
return
|
return
|
||||||
|
|
@ -1768,7 +1786,7 @@ class DocumentService:
|
||||||
return document
|
return document
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def pause_document(document):
|
def pause_document(document: Document) -> None:
|
||||||
if document.indexing_status not in {
|
if document.indexing_status not in {
|
||||||
IndexingStatus.WAITING,
|
IndexingStatus.WAITING,
|
||||||
IndexingStatus.PARSING,
|
IndexingStatus.PARSING,
|
||||||
|
|
@ -1790,7 +1808,7 @@ class DocumentService:
|
||||||
redis_client.setnx(indexing_cache_key, "True")
|
redis_client.setnx(indexing_cache_key, "True")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def recover_document(document):
|
def recover_document(document: Document) -> None:
|
||||||
if not document.is_paused:
|
if not document.is_paused:
|
||||||
raise DocumentIndexingError()
|
raise DocumentIndexingError()
|
||||||
# update document to be recover
|
# update document to be recover
|
||||||
|
|
@ -1807,7 +1825,7 @@ class DocumentService:
|
||||||
recover_document_indexing_task.delay(document.dataset_id, document.id)
|
recover_document_indexing_task.delay(document.dataset_id, document.id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def retry_document(dataset_id: str, documents: list[Document]):
|
def retry_document(dataset_id: str, documents: list[Document]) -> None:
|
||||||
for document in documents:
|
for document in documents:
|
||||||
# add retry flag
|
# add retry flag
|
||||||
retry_indexing_cache_key = f"document_{document.id}_is_retried"
|
retry_indexing_cache_key = f"document_{document.id}_is_retried"
|
||||||
|
|
@ -1827,7 +1845,7 @@ class DocumentService:
|
||||||
retry_document_indexing_task.delay(dataset_id, document_ids, current_user.id)
|
retry_document_indexing_task.delay(dataset_id, document_ids, current_user.id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sync_website_document(dataset_id: str, document: Document):
|
def sync_website_document(dataset_id: str, document: Document) -> None:
|
||||||
# add sync flag
|
# add sync flag
|
||||||
sync_indexing_cache_key = f"document_{document.id}_is_sync"
|
sync_indexing_cache_key = f"document_{document.id}_is_sync"
|
||||||
cache_result = redis_client.get(sync_indexing_cache_key)
|
cache_result = redis_client.get(sync_indexing_cache_key)
|
||||||
|
|
@ -1847,7 +1865,7 @@ class DocumentService:
|
||||||
sync_website_document_indexing_task.delay(dataset_id, document.id)
|
sync_website_document_indexing_task.delay(dataset_id, document.id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_documents_position(dataset_id):
|
def get_documents_position(dataset_id: str) -> int:
|
||||||
document = (
|
document = (
|
||||||
db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
|
db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -97,8 +97,8 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
|
||||||
raise self.retry(exc=e, countdown=60) # Retry after 60 seconds
|
raise self.retry(exc=e, countdown=60) # Retry after 60 seconds
|
||||||
|
|
||||||
|
|
||||||
def _delete_app_model_configs(tenant_id: str, app_id: str):
|
def _delete_app_model_configs(tenant_id: str, app_id: str) -> None:
|
||||||
def del_model_config(session, model_config_id: str):
|
def del_model_config(session: Any, model_config_id: str) -> None:
|
||||||
session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
|
session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
|
||||||
|
|
||||||
_delete_records(
|
_delete_records(
|
||||||
|
|
@ -109,8 +109,8 @@ def _delete_app_model_configs(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_app_site(tenant_id: str, app_id: str):
|
def _delete_app_site(tenant_id: str, app_id: str) -> None:
|
||||||
def del_site(session, site_id: str):
|
def del_site(session: Any, site_id: str) -> None:
|
||||||
session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
|
session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
|
||||||
|
|
||||||
_delete_records(
|
_delete_records(
|
||||||
|
|
@ -121,8 +121,8 @@ def _delete_app_site(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_app_mcp_servers(tenant_id: str, app_id: str):
|
def _delete_app_mcp_servers(tenant_id: str, app_id: str) -> None:
|
||||||
def del_mcp_server(session, mcp_server_id: str):
|
def del_mcp_server(session: Any, mcp_server_id: str) -> None:
|
||||||
session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
|
session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
|
||||||
|
|
||||||
_delete_records(
|
_delete_records(
|
||||||
|
|
@ -133,8 +133,8 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_app_api_tokens(tenant_id: str, app_id: str):
|
def _delete_app_api_tokens(tenant_id: str, app_id: str) -> None:
|
||||||
def del_api_token(session, api_token_id: str):
|
def del_api_token(session: Any, api_token_id: str) -> None:
|
||||||
# Fetch token details for cache invalidation
|
# Fetch token details for cache invalidation
|
||||||
token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first()
|
token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first()
|
||||||
if token_obj:
|
if token_obj:
|
||||||
|
|
@ -151,8 +151,8 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_installed_apps(tenant_id: str, app_id: str):
|
def _delete_installed_apps(tenant_id: str, app_id: str) -> None:
|
||||||
def del_installed_app(session, installed_app_id: str):
|
def del_installed_app(session: Any, installed_app_id: str) -> None:
|
||||||
session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
|
session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
|
||||||
|
|
||||||
_delete_records(
|
_delete_records(
|
||||||
|
|
@ -163,8 +163,8 @@ def _delete_installed_apps(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_recommended_apps(tenant_id: str, app_id: str):
|
def _delete_recommended_apps(tenant_id: str, app_id: str) -> None:
|
||||||
def del_recommended_app(session, recommended_app_id: str):
|
def del_recommended_app(session: Any, recommended_app_id: str) -> None:
|
||||||
session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False)
|
session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False)
|
||||||
|
|
||||||
_delete_records(
|
_delete_records(
|
||||||
|
|
@ -175,8 +175,8 @@ def _delete_recommended_apps(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_app_annotation_data(tenant_id: str, app_id: str):
|
def _delete_app_annotation_data(tenant_id: str, app_id: str) -> None:
|
||||||
def del_annotation_hit_history(session, annotation_hit_history_id: str):
|
def del_annotation_hit_history(session: Any, annotation_hit_history_id: str) -> None:
|
||||||
session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
|
session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
|
||||||
synchronize_session=False
|
synchronize_session=False
|
||||||
)
|
)
|
||||||
|
|
@ -188,7 +188,7 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
|
||||||
"annotation hit history",
|
"annotation hit history",
|
||||||
)
|
)
|
||||||
|
|
||||||
def del_annotation_setting(session, annotation_setting_id: str):
|
def del_annotation_setting(session: Any, annotation_setting_id: str) -> None:
|
||||||
session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
|
session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
|
||||||
synchronize_session=False
|
synchronize_session=False
|
||||||
)
|
)
|
||||||
|
|
@ -201,8 +201,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_app_dataset_joins(tenant_id: str, app_id: str):
|
def _delete_app_dataset_joins(tenant_id: str, app_id: str) -> None:
|
||||||
def del_dataset_join(session, dataset_join_id: str):
|
def del_dataset_join(session: Any, dataset_join_id: str) -> None:
|
||||||
session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
|
session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
|
||||||
|
|
||||||
_delete_records(
|
_delete_records(
|
||||||
|
|
@ -213,8 +213,8 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_app_workflows(tenant_id: str, app_id: str):
|
def _delete_app_workflows(tenant_id: str, app_id: str) -> None:
|
||||||
def del_workflow(session, workflow_id: str):
|
def del_workflow(session: Any, workflow_id: str) -> None:
|
||||||
session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
|
session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
|
||||||
|
|
||||||
_delete_records(
|
_delete_records(
|
||||||
|
|
@ -225,7 +225,7 @@ def _delete_app_workflows(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_app_workflow_runs(tenant_id: str, app_id: str):
|
def _delete_app_workflow_runs(tenant_id: str, app_id: str) -> None:
|
||||||
"""Delete all workflow runs for an app using the service repository."""
|
"""Delete all workflow runs for an app using the service repository."""
|
||||||
session_maker = sessionmaker(bind=db.engine)
|
session_maker = sessionmaker(bind=db.engine)
|
||||||
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||||
|
|
@ -239,7 +239,7 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
|
||||||
logger.info("Deleted %s workflow runs for app %s", deleted_count, app_id)
|
logger.info("Deleted %s workflow runs for app %s", deleted_count, app_id)
|
||||||
|
|
||||||
|
|
||||||
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
|
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str) -> None:
|
||||||
"""Delete all workflow node executions for an app using the service repository."""
|
"""Delete all workflow node executions for an app using the service repository."""
|
||||||
session_maker = sessionmaker(bind=db.engine)
|
session_maker = sessionmaker(bind=db.engine)
|
||||||
node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
|
node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
|
||||||
|
|
@ -253,8 +253,8 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
|
||||||
logger.info("Deleted %s workflow node executions for app %s", deleted_count, app_id)
|
logger.info("Deleted %s workflow node executions for app %s", deleted_count, app_id)
|
||||||
|
|
||||||
|
|
||||||
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
|
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str) -> None:
|
||||||
def del_workflow_app_log(session, workflow_app_log_id: str):
|
def del_workflow_app_log(session: Any, workflow_app_log_id: str) -> None:
|
||||||
session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False)
|
session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False)
|
||||||
|
|
||||||
_delete_records(
|
_delete_records(
|
||||||
|
|
@ -265,8 +265,8 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
|
def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str) -> None:
|
||||||
def del_workflow_archive_log(session, workflow_archive_log_id: str):
|
def del_workflow_archive_log(session: Any, workflow_archive_log_id: str) -> None:
|
||||||
session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
|
session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
|
||||||
synchronize_session=False
|
synchronize_session=False
|
||||||
)
|
)
|
||||||
|
|
@ -279,7 +279,7 @@ def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_archived_workflow_run_files(tenant_id: str, app_id: str):
|
def _delete_archived_workflow_run_files(tenant_id: str, app_id: str) -> None:
|
||||||
prefix = f"{tenant_id}/app_id={app_id}/"
|
prefix = f"{tenant_id}/app_id={app_id}/"
|
||||||
try:
|
try:
|
||||||
archive_storage = get_archive_storage()
|
archive_storage = get_archive_storage()
|
||||||
|
|
@ -304,8 +304,8 @@ def _delete_archived_workflow_run_files(tenant_id: str, app_id: str):
|
||||||
logger.info("Deleted %s archive objects for app %s", deleted, app_id)
|
logger.info("Deleted %s archive objects for app %s", deleted, app_id)
|
||||||
|
|
||||||
|
|
||||||
def _delete_app_conversations(tenant_id: str, app_id: str):
|
def _delete_app_conversations(tenant_id: str, app_id: str) -> None:
|
||||||
def del_conversation(session, conversation_id: str):
|
def del_conversation(session: Any, conversation_id: str) -> None:
|
||||||
session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
|
session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
|
||||||
synchronize_session=False
|
synchronize_session=False
|
||||||
)
|
)
|
||||||
|
|
@ -319,7 +319,7 @@ def _delete_app_conversations(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_conversation_variables(*, app_id: str):
|
def _delete_conversation_variables(app_id: str) -> None:
|
||||||
with session_factory.create_session() as session:
|
with session_factory.create_session() as session:
|
||||||
stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
|
stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
|
||||||
session.execute(stmt)
|
session.execute(stmt)
|
||||||
|
|
@ -327,8 +327,8 @@ def _delete_conversation_variables(*, app_id: str):
|
||||||
logger.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green"))
|
logger.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green"))
|
||||||
|
|
||||||
|
|
||||||
def _delete_app_messages(tenant_id: str, app_id: str):
|
def _delete_app_messages(tenant_id: str, app_id: str) -> None:
|
||||||
def del_message(session, message_id: str):
|
def del_message(session: Any, message_id: str) -> None:
|
||||||
session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False)
|
session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False)
|
||||||
session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
|
session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
|
||||||
synchronize_session=False
|
synchronize_session=False
|
||||||
|
|
@ -349,8 +349,8 @@ def _delete_app_messages(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
|
def _delete_workflow_tool_providers(tenant_id: str, app_id: str) -> None:
|
||||||
def del_tool_provider(session, tool_provider_id: str):
|
def del_tool_provider(session: Any, tool_provider_id: str) -> None:
|
||||||
session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
|
session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
|
||||||
synchronize_session=False
|
synchronize_session=False
|
||||||
)
|
)
|
||||||
|
|
@ -363,8 +363,8 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_app_tag_bindings(tenant_id: str, app_id: str):
|
def _delete_app_tag_bindings(tenant_id: str, app_id: str) -> None:
|
||||||
def del_tag_binding(session, tag_binding_id: str):
|
def del_tag_binding(session: Any, tag_binding_id: str) -> None:
|
||||||
session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
|
session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
|
||||||
|
|
||||||
_delete_records(
|
_delete_records(
|
||||||
|
|
@ -375,8 +375,8 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_end_users(tenant_id: str, app_id: str):
|
def _delete_end_users(tenant_id: str, app_id: str) -> None:
|
||||||
def del_end_user(session, end_user_id: str):
|
def del_end_user(session: Any, end_user_id: str) -> None:
|
||||||
session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
|
session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
|
||||||
|
|
||||||
_delete_records(
|
_delete_records(
|
||||||
|
|
@ -387,8 +387,8 @@ def _delete_end_users(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_trace_app_configs(tenant_id: str, app_id: str):
|
def _delete_trace_app_configs(tenant_id: str, app_id: str) -> None:
|
||||||
def del_trace_app_config(session, trace_app_config_id: str):
|
def del_trace_app_config(session: Any, trace_app_config_id: str) -> None:
|
||||||
session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False)
|
session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False)
|
||||||
|
|
||||||
_delete_records(
|
_delete_records(
|
||||||
|
|
@ -399,9 +399,9 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_draft_variables(app_id: str):
|
def _delete_draft_variables(app_id: str) -> None:
|
||||||
"""Delete all workflow draft variables for an app in batches."""
|
"""Delete all workflow draft variables for an app in batches."""
|
||||||
return delete_draft_variables_batch(app_id, batch_size=1000)
|
delete_draft_variables_batch(app_id, batch_size=1000)
|
||||||
|
|
||||||
|
|
||||||
def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
|
def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
|
||||||
|
|
@ -543,8 +543,8 @@ def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int:
|
||||||
return files_deleted
|
return files_deleted
|
||||||
|
|
||||||
|
|
||||||
def _delete_app_triggers(tenant_id: str, app_id: str):
|
def _delete_app_triggers(tenant_id: str, app_id: str) -> None:
|
||||||
def del_app_trigger(session, trigger_id: str):
|
def del_app_trigger(session: Any, trigger_id: str) -> None:
|
||||||
session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
|
session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
|
||||||
|
|
||||||
_delete_records(
|
_delete_records(
|
||||||
|
|
@ -555,8 +555,8 @@ def _delete_app_triggers(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
|
def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str) -> None:
|
||||||
def del_plugin_trigger(session, trigger_id: str):
|
def del_plugin_trigger(session: Any, trigger_id: str) -> None:
|
||||||
session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
|
session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
|
||||||
synchronize_session=False
|
synchronize_session=False
|
||||||
)
|
)
|
||||||
|
|
@ -569,8 +569,8 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
|
def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str) -> None:
|
||||||
def del_webhook_trigger(session, trigger_id: str):
|
def del_webhook_trigger(session: Any, trigger_id: str) -> None:
|
||||||
session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
|
session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
|
||||||
synchronize_session=False
|
synchronize_session=False
|
||||||
)
|
)
|
||||||
|
|
@ -583,8 +583,8 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
|
def _delete_workflow_schedule_plans(tenant_id: str, app_id: str) -> None:
|
||||||
def del_schedule_plan(session, plan_id: str):
|
def del_schedule_plan(session: Any, plan_id: str) -> None:
|
||||||
session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False)
|
session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False)
|
||||||
|
|
||||||
_delete_records(
|
_delete_records(
|
||||||
|
|
@ -595,8 +595,8 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
|
def _delete_workflow_trigger_logs(tenant_id: str, app_id: str) -> None:
|
||||||
def del_trigger_log(session, log_id: str):
|
def del_trigger_log(session: Any, log_id: str) -> None:
|
||||||
session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
|
session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
|
||||||
|
|
||||||
_delete_records(
|
_delete_records(
|
||||||
|
|
|
||||||
|
|
@ -28,12 +28,11 @@ class TestDeleteDraftVariablesBatch:
|
||||||
def test_delete_draft_variables_calls_batch_function(self, mock_batch_delete):
|
def test_delete_draft_variables_calls_batch_function(self, mock_batch_delete):
|
||||||
"""Test that _delete_draft_variables calls the batch function correctly."""
|
"""Test that _delete_draft_variables calls the batch function correctly."""
|
||||||
app_id = "test-app-id"
|
app_id = "test-app-id"
|
||||||
expected_return = 42
|
mock_batch_delete.return_value = 42
|
||||||
mock_batch_delete.return_value = expected_return
|
|
||||||
|
|
||||||
result = _delete_draft_variables(app_id)
|
result = _delete_draft_variables(app_id)
|
||||||
|
|
||||||
assert result == expected_return
|
assert result is None
|
||||||
mock_batch_delete.assert_called_once_with(app_id, batch_size=1000)
|
mock_batch_delete.assert_called_once_with(app_id, batch_size=1000)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue