mirror of https://github.com/langgenius/dify.git
refactor: replace sa.String with EnumText in mapped_column for type s… (#33332)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
6043ec4423
commit
e64f4d6039
|
|
@ -43,6 +43,7 @@ from libs.datetime_utils import naive_utc_now
|
|||
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from models.account import AccountStatus
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
|
|
@ -231,7 +232,7 @@ class AccountInitApi(Resource):
|
|||
account.interface_language = args.interface_language
|
||||
account.timezone = args.timezone
|
||||
account.interface_theme = "light"
|
||||
account.status = "active"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from core.rag.models.document import Document
|
|||
from extensions.ext_database import db
|
||||
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -38,7 +39,9 @@ class DatasetIndexToolCallbackHandler:
|
|||
source="app",
|
||||
source_app_id=self._app_id,
|
||||
created_by_role=(
|
||||
"account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
|
||||
CreatorUserRole.ACCOUNT
|
||||
if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||
else CreatorUserRole.END_USER
|
||||
),
|
||||
created_by=self._user_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -628,10 +628,10 @@ class TraceTask:
|
|||
if not message_data:
|
||||
return {}
|
||||
conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
|
||||
conversation_mode = db.session.scalars(conversation_mode_stmt).all()
|
||||
if not conversation_mode or len(conversation_mode) == 0:
|
||||
conversation_modes = db.session.scalars(conversation_mode_stmt).all()
|
||||
if not conversation_modes or len(conversation_modes) == 0:
|
||||
return {}
|
||||
conversation_mode = conversation_mode[0]
|
||||
conversation_mode = conversation_modes[0]
|
||||
created_at = message_data.created_at
|
||||
inputs = message_data.message
|
||||
|
||||
|
|
|
|||
|
|
@ -627,7 +627,7 @@ class ProviderManager:
|
|||
tenant_id=tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
provider_name=ModelProviderID(provider_name).provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
provider_type=ProviderType.SYSTEM,
|
||||
quota_type=quota.quota_type,
|
||||
quota_limit=0, # type: ignore
|
||||
quota_used=0,
|
||||
|
|
|
|||
|
|
@ -83,6 +83,7 @@ from models.dataset import (
|
|||
)
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.dataset import Document as DocumentModel
|
||||
from models.enums import CreatorUserRole
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
|
@ -1009,7 +1010,7 @@ class DatasetRetrieval:
|
|||
content=json.dumps(contents),
|
||||
source="app",
|
||||
source_app_id=app_id,
|
||||
created_by_role=user_from,
|
||||
created_by_role=CreatorUserRole(user_from),
|
||||
created_by=user_id,
|
||||
)
|
||||
dataset_queries.append(dataset_query)
|
||||
|
|
|
|||
|
|
@ -146,7 +146,9 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
|||
|
||||
# No sequence number generation needed anymore
|
||||
|
||||
db_model.type = domain_model.workflow_type
|
||||
from models.workflow import WorkflowType as ModelWorkflowType
|
||||
|
||||
db_model.type = ModelWorkflowType(domain_model.workflow_type.value)
|
||||
db_model.version = domain_model.workflow_version
|
||||
db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None
|
||||
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
|
||||
|
|
|
|||
|
|
@ -17,7 +17,8 @@ from dify_graph.enums import WorkflowNodeExecutionStatus
|
|||
from extensions.logstore.aliyun_logstore import AliyunLogStore
|
||||
from extensions.logstore.repositories import safe_float, safe_int
|
||||
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
|
||||
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -47,12 +48,28 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode
|
|||
model.tenant_id = data.get("tenant_id") or ""
|
||||
model.app_id = data.get("app_id") or ""
|
||||
model.workflow_id = data.get("workflow_id") or ""
|
||||
model.triggered_from = data.get("triggered_from") or ""
|
||||
triggered_from_val = data.get("triggered_from")
|
||||
try:
|
||||
model.triggered_from = (
|
||||
WorkflowNodeExecutionTriggeredFrom(str(triggered_from_val))
|
||||
if triggered_from_val
|
||||
else WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
|
||||
)
|
||||
except ValueError:
|
||||
logger.warning("Invalid triggered_from value: %s, falling back to WORKFLOW_RUN", triggered_from_val)
|
||||
model.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
|
||||
model.node_id = data.get("node_id") or ""
|
||||
model.node_type = data.get("node_type") or ""
|
||||
model.status = data.get("status") or "running" # Default status if missing
|
||||
model.title = data.get("title") or ""
|
||||
model.created_by_role = data.get("created_by_role") or ""
|
||||
created_by_role_val = data.get("created_by_role")
|
||||
try:
|
||||
model.created_by_role = (
|
||||
CreatorUserRole(str(created_by_role_val)) if created_by_role_val else CreatorUserRole.ACCOUNT
|
||||
)
|
||||
except ValueError:
|
||||
logger.warning("Invalid created_by_role value: %s, falling back to ACCOUNT", created_by_role_val)
|
||||
model.created_by_role = CreatorUserRole.ACCOUNT
|
||||
model.created_by = data.get("created_by") or ""
|
||||
|
||||
model.index = safe_int(data.get("index", 0))
|
||||
|
|
|
|||
|
|
@ -22,12 +22,13 @@ from typing import Any, cast
|
|||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from extensions.logstore.aliyun_logstore import AliyunLogStore
|
||||
from extensions.logstore.repositories import safe_float, safe_int
|
||||
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowRun
|
||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowRun, WorkflowType
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.types import (
|
||||
AverageInteractionStats,
|
||||
|
|
@ -59,11 +60,37 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
|
|||
model.tenant_id = data.get("tenant_id") or ""
|
||||
model.app_id = data.get("app_id") or ""
|
||||
model.workflow_id = data.get("workflow_id") or ""
|
||||
model.type = data.get("type") or ""
|
||||
model.triggered_from = data.get("triggered_from") or ""
|
||||
type_val = data.get("type")
|
||||
try:
|
||||
model.type = WorkflowType(str(type_val)) if type_val else WorkflowType.WORKFLOW
|
||||
except ValueError:
|
||||
logger.warning("Invalid type value: %s, falling back to WORKFLOW", type_val)
|
||||
model.type = WorkflowType.WORKFLOW
|
||||
triggered_from_val = data.get("triggered_from")
|
||||
try:
|
||||
model.triggered_from = (
|
||||
WorkflowRunTriggeredFrom(str(triggered_from_val))
|
||||
if triggered_from_val
|
||||
else WorkflowRunTriggeredFrom.APP_RUN
|
||||
)
|
||||
except ValueError:
|
||||
logger.warning("Invalid triggered_from value: %s, falling back to APP_RUN", triggered_from_val)
|
||||
model.triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
model.version = data.get("version") or ""
|
||||
model.status = data.get("status") or "running" # Default status if missing
|
||||
model.created_by_role = data.get("created_by_role") or ""
|
||||
status_val = data.get("status")
|
||||
try:
|
||||
model.status = WorkflowExecutionStatus(str(status_val)) if status_val else WorkflowExecutionStatus.RUNNING
|
||||
except ValueError:
|
||||
logger.warning("Invalid status value: %s, falling back to RUNNING", status_val)
|
||||
model.status = WorkflowExecutionStatus.RUNNING
|
||||
created_by_role_val = data.get("created_by_role")
|
||||
try:
|
||||
model.created_by_role = (
|
||||
CreatorUserRole(str(created_by_role_val)) if created_by_role_val else CreatorUserRole.ACCOUNT
|
||||
)
|
||||
except ValueError:
|
||||
logger.warning("Invalid created_by_role value: %s, falling back to ACCOUNT", created_by_role_val)
|
||||
model.created_by_role = CreatorUserRole.ACCOUNT
|
||||
model.created_by = data.get("created_by") or ""
|
||||
|
||||
model.total_tokens = safe_int(data.get("total_tokens", 0))
|
||||
|
|
|
|||
|
|
@ -8,12 +8,12 @@ from uuid import uuid4
|
|||
import sqlalchemy as sa
|
||||
from flask_login import UserMixin
|
||||
from sqlalchemy import DateTime, String, func, select
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column, validates
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from .base import TypeBase
|
||||
from .engine import db
|
||||
from .types import LongText, StringUUID
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
|
||||
class TenantAccountRole(enum.StrEnum):
|
||||
|
|
@ -104,7 +104,9 @@ class Account(UserMixin, TypeBase):
|
|||
last_active_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||
)
|
||||
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'"), default="active")
|
||||
status: Mapped[AccountStatus] = mapped_column(
|
||||
EnumText(AccountStatus, length=16), server_default=sa.text("'active'"), default=AccountStatus.ACTIVE
|
||||
)
|
||||
initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||
|
|
@ -116,12 +118,6 @@ class Account(UserMixin, TypeBase):
|
|||
role: TenantAccountRole | None = field(default=None, init=False)
|
||||
_current_tenant: "Tenant | None" = field(default=None, init=False)
|
||||
|
||||
@validates("status")
|
||||
def _normalize_status(self, _key: str, value: str | AccountStatus) -> str:
|
||||
if isinstance(value, AccountStatus):
|
||||
return value.value
|
||||
return value
|
||||
|
||||
@property
|
||||
def is_password_set(self):
|
||||
return self.password is not None
|
||||
|
|
@ -177,8 +173,7 @@ class Account(UserMixin, TypeBase):
|
|||
return self.role
|
||||
|
||||
def get_status(self) -> AccountStatus:
|
||||
status_str = self.status
|
||||
return AccountStatus(status_str)
|
||||
return self.status
|
||||
|
||||
@classmethod
|
||||
def get_by_openid(cls, provider: str, open_id: str):
|
||||
|
|
@ -249,7 +244,9 @@ class Tenant(TypeBase):
|
|||
name: Mapped[str] = mapped_column(String(255))
|
||||
encrypt_public_key: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'"), default="basic")
|
||||
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'"), default="normal")
|
||||
status: Mapped[TenantStatus] = mapped_column(
|
||||
EnumText(TenantStatus, length=255), server_default=sa.text("'normal'"), default=TenantStatus.NORMAL
|
||||
)
|
||||
custom_config: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||
|
|
@ -291,7 +288,9 @@ class TenantAccountJoin(TypeBase):
|
|||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
account_id: Mapped[str] = mapped_column(StringUUID)
|
||||
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"), default=False)
|
||||
role: Mapped[str] = mapped_column(String(16), server_default="normal", default="normal")
|
||||
role: Mapped[TenantAccountRole] = mapped_column(
|
||||
EnumText(TenantAccountRole, length=16), server_default="normal", default=TenantAccountRole.NORMAL
|
||||
)
|
||||
invited_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
|
||||
|
|
|
|||
|
|
@ -30,8 +30,9 @@ from services.entities.knowledge_entities.knowledge_entities import ParentMode,
|
|||
from .account import Account
|
||||
from .base import Base, TypeBase
|
||||
from .engine import db
|
||||
from .enums import CreatorUserRole
|
||||
from .model import App, Tag, TagBinding, UploadFile
|
||||
from .types import AdjustedJSON, BinaryData, LongText, StringUUID, adjusted_json_index
|
||||
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -59,7 +60,11 @@ class Dataset(Base):
|
|||
name: Mapped[str] = mapped_column(String(255))
|
||||
description = mapped_column(LongText, nullable=True)
|
||||
provider: Mapped[str] = mapped_column(String(255), server_default=sa.text("'vendor'"))
|
||||
permission: Mapped[str] = mapped_column(String(255), server_default=sa.text("'only_me'"))
|
||||
permission: Mapped[DatasetPermissionEnum] = mapped_column(
|
||||
EnumText(DatasetPermissionEnum, length=255),
|
||||
server_default=sa.text("'only_me'"),
|
||||
default=DatasetPermissionEnum.ONLY_ME,
|
||||
)
|
||||
data_source_type = mapped_column(String(255))
|
||||
indexing_technique: Mapped[str | None] = mapped_column(String(255))
|
||||
index_struct = mapped_column(LongText, nullable=True)
|
||||
|
|
@ -1003,7 +1008,7 @@ class DatasetQuery(TypeBase):
|
|||
content: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
source_app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
|
||||
|
|
|
|||
|
|
@ -29,9 +29,9 @@ from libs.uuid_utils import uuidv7
|
|||
from .account import Account, Tenant
|
||||
from .base import Base, TypeBase, gen_uuidv4_string
|
||||
from .engine import db
|
||||
from .enums import CreatorUserRole
|
||||
from .enums import CreatorUserRole, MessageStatus
|
||||
from .provider_ids import GenericProviderID
|
||||
from .types import LongText, StringUUID
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .workflow import Workflow
|
||||
|
|
@ -337,8 +337,8 @@ class App(Base):
|
|||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
name: Mapped[str] = mapped_column(String(255))
|
||||
description: Mapped[str] = mapped_column(LongText, default=sa.text("''"))
|
||||
mode: Mapped[str] = mapped_column(String(255))
|
||||
icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji, link
|
||||
mode: Mapped[AppMode] = mapped_column(EnumText(AppMode, length=255))
|
||||
icon_type: Mapped[IconType | None] = mapped_column(EnumText(IconType, length=255))
|
||||
icon = mapped_column(String(255))
|
||||
icon_background: Mapped[str | None] = mapped_column(String(255))
|
||||
app_model_config_id = mapped_column(StringUUID, nullable=True)
|
||||
|
|
@ -1000,7 +1000,7 @@ class Conversation(Base):
|
|||
model_provider = mapped_column(String(255), nullable=True)
|
||||
override_model_configs = mapped_column(LongText)
|
||||
model_id = mapped_column(String(255), nullable=True)
|
||||
mode: Mapped[str] = mapped_column(String(255))
|
||||
mode: Mapped[AppMode] = mapped_column(EnumText(AppMode, length=255))
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
summary = mapped_column(LongText)
|
||||
_inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON)
|
||||
|
|
@ -1351,7 +1351,12 @@ class Message(Base):
|
|||
provider_response_latency: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
|
||||
total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric(10, 7))
|
||||
currency: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
status: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'"))
|
||||
status: Mapped[MessageStatus] = mapped_column(
|
||||
EnumText(MessageStatus, length=255),
|
||||
nullable=False,
|
||||
server_default=sa.text("'normal'"),
|
||||
default=MessageStatus.NORMAL,
|
||||
)
|
||||
error: Mapped[str | None] = mapped_column(LongText)
|
||||
message_metadata: Mapped[str | None] = mapped_column(LongText)
|
||||
invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
|
|
@ -1364,7 +1369,7 @@ class Message(Base):
|
|||
)
|
||||
agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
app_mode: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
app_mode: Mapped[AppMode | None] = mapped_column(EnumText(AppMode, length=255), nullable=True)
|
||||
|
||||
@property
|
||||
def inputs(self) -> dict[str, Any]:
|
||||
|
|
@ -1767,7 +1772,7 @@ class MessageFile(TypeBase):
|
|||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
transfer_method: Mapped[FileTransferMethod] = mapped_column(String(255), nullable=False)
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(String(255), nullable=False)
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
url: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
|
|
@ -2015,7 +2020,7 @@ class Site(Base):
|
|||
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
title: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
icon_type = mapped_column(String(255), nullable=True)
|
||||
icon_type: Mapped[IconType | None] = mapped_column(EnumText(IconType, length=255), nullable=True)
|
||||
icon = mapped_column(String(255))
|
||||
icon_background = mapped_column(String(255))
|
||||
description = mapped_column(LongText)
|
||||
|
|
@ -2110,7 +2115,12 @@ class UploadFile(Base):
|
|||
|
||||
# The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`.
|
||||
# Its value is derived from the `CreatorUserRole` enumeration.
|
||||
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'account'"))
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(
|
||||
EnumText(CreatorUserRole, length=255),
|
||||
nullable=False,
|
||||
server_default=sa.text("'account'"),
|
||||
default=CreatorUserRole.ACCOUNT,
|
||||
)
|
||||
|
||||
# The `created_by` field stores the ID of the entity that created this upload file.
|
||||
#
|
||||
|
|
@ -2163,7 +2173,7 @@ class UploadFile(Base):
|
|||
self.size = size
|
||||
self.extension = extension
|
||||
self.mime_type = mime_type
|
||||
self.created_by_role = created_by_role.value
|
||||
self.created_by_role = created_by_role
|
||||
self.created_by = created_by
|
||||
self.created_at = created_at
|
||||
self.used = used
|
||||
|
|
@ -2226,7 +2236,7 @@ class MessageAgentThought(TypeBase):
|
|||
)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
message_chain_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
thought: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from libs.uuid_utils import uuidv7
|
|||
|
||||
from .base import TypeBase
|
||||
from .engine import db
|
||||
from .types import LongText, StringUUID
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
|
||||
class ProviderType(StrEnum):
|
||||
|
|
@ -69,8 +69,8 @@ class Provider(TypeBase):
|
|||
)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
provider_type: Mapped[str] = mapped_column(
|
||||
String(40), nullable=False, server_default=text("'custom'"), default="custom"
|
||||
provider_type: Mapped[ProviderType] = mapped_column(
|
||||
EnumText(ProviderType, length=40), nullable=False, server_default=text("'custom'"), default=ProviderType.CUSTOM
|
||||
)
|
||||
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"), default=False)
|
||||
last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, init=False)
|
||||
|
|
|
|||
|
|
@ -227,7 +227,7 @@ class WorkflowTriggerLog(TypeBase):
|
|||
|
||||
queue_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
celery_task_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
retry_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
elapsed_time: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None)
|
||||
|
|
|
|||
|
|
@ -2,13 +2,14 @@ from datetime import datetime
|
|||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, func
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import TypeBase
|
||||
from .engine import db
|
||||
from .enums import CreatorUserRole
|
||||
from .model import Message
|
||||
from .types import StringUUID
|
||||
from .types import EnumText, StringUUID
|
||||
|
||||
|
||||
class SavedMessage(TypeBase):
|
||||
|
|
@ -24,7 +25,9 @@ class SavedMessage(TypeBase):
|
|||
)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'end_user'"))
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(
|
||||
EnumText(CreatorUserRole, length=255), nullable=False, server_default=sa.text("'end_user'")
|
||||
)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
|
|
@ -50,8 +53,8 @@ class PinnedConversation(TypeBase):
|
|||
)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID)
|
||||
created_by_role: Mapped[str] = mapped_column(
|
||||
String(255),
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(
|
||||
EnumText(CreatorUserRole, length=255),
|
||||
nullable=False,
|
||||
server_default=sa.text("'end_user'"),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ from libs import helper
|
|||
from .account import Account
|
||||
from .base import Base, DefaultFieldsMixin, TypeBase
|
||||
from .engine import db
|
||||
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType
|
||||
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -141,7 +141,7 @@ class Workflow(Base): # bug
|
|||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
type: Mapped[WorkflowType] = mapped_column(EnumText(WorkflowType, length=255), nullable=False)
|
||||
version: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
marked_name: Mapped[str] = mapped_column(String(255), default="", server_default="")
|
||||
marked_comment: Mapped[str] = mapped_column(String(255), default="", server_default="")
|
||||
|
|
@ -188,7 +188,7 @@ class Workflow(Base): # bug
|
|||
workflow.id = str(uuid4())
|
||||
workflow.tenant_id = tenant_id
|
||||
workflow.app_id = app_id
|
||||
workflow.type = type
|
||||
workflow.type = WorkflowType(type)
|
||||
workflow.version = version
|
||||
workflow.graph = graph
|
||||
workflow.features = features
|
||||
|
|
@ -608,8 +608,8 @@ class WorkflowRun(Base):
|
|||
app_id: Mapped[str] = mapped_column(StringUUID)
|
||||
|
||||
workflow_id: Mapped[str] = mapped_column(StringUUID)
|
||||
type: Mapped[str] = mapped_column(String(255))
|
||||
triggered_from: Mapped[str] = mapped_column(String(255))
|
||||
type: Mapped[WorkflowType] = mapped_column(EnumText(WorkflowType, length=255))
|
||||
triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column(EnumText(WorkflowRunTriggeredFrom, length=255))
|
||||
version: Mapped[str] = mapped_column(String(255))
|
||||
graph: Mapped[str | None] = mapped_column(LongText)
|
||||
inputs: Mapped[str | None] = mapped_column(LongText)
|
||||
|
|
@ -830,7 +830,9 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
|||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID)
|
||||
workflow_id: Mapped[str] = mapped_column(StringUUID)
|
||||
triggered_from: Mapped[str] = mapped_column(String(255))
|
||||
triggered_from: Mapped[WorkflowNodeExecutionTriggeredFrom] = mapped_column(
|
||||
EnumText(WorkflowNodeExecutionTriggeredFrom, length=255)
|
||||
)
|
||||
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
index: Mapped[int] = mapped_column(sa.Integer)
|
||||
predecessor_node_id: Mapped[str | None] = mapped_column(String(255))
|
||||
|
|
@ -846,7 +848,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
|||
elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0"))
|
||||
execution_metadata: Mapped[str | None] = mapped_column(LongText)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
||||
created_by_role: Mapped[str] = mapped_column(String(255))
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255))
|
||||
created_by: Mapped[str] = mapped_column(StringUUID)
|
||||
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
|
||||
|
|
@ -1130,7 +1132,7 @@ class WorkflowAppLog(TypeBase):
|
|||
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
workflow_run_id: Mapped[str] = mapped_column(StringUUID)
|
||||
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
|
|
@ -1204,7 +1206,7 @@ class WorkflowArchiveLog(TypeBase):
|
|||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
|
||||
log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
|
|
@ -1213,7 +1215,9 @@ class WorkflowArchiveLog(TypeBase):
|
|||
|
||||
run_version: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
run_status: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
run_triggered_from: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
run_triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column(
|
||||
EnumText(WorkflowRunTriggeredFrom, length=255), nullable=False
|
||||
)
|
||||
run_error: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
run_elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
|
||||
run_total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
|
||||
|
|
|
|||
|
|
@ -1089,9 +1089,9 @@ class TenantService:
|
|||
|
||||
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
if ta:
|
||||
ta.role = role
|
||||
ta.role = TenantAccountRole(role)
|
||||
else:
|
||||
ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role)
|
||||
ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=TenantAccountRole(role))
|
||||
db.session.add(ta)
|
||||
|
||||
db.session.commit()
|
||||
|
|
@ -1319,10 +1319,10 @@ class TenantService:
|
|||
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first()
|
||||
)
|
||||
if current_owner_join:
|
||||
current_owner_join.role = "admin"
|
||||
current_owner_join.role = TenantAccountRole.ADMIN
|
||||
|
||||
# Update the role of the target member
|
||||
target_member_join.role = new_role
|
||||
target_member_join.role = TenantAccountRole(new_role)
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -429,17 +429,18 @@ class AppDslService:
|
|||
|
||||
# Set icon type
|
||||
icon_type_value = icon_type or app_data.get("icon_type")
|
||||
resolved_icon_type: IconType
|
||||
if icon_type_value in [IconType.EMOJI, IconType.IMAGE, IconType.LINK]:
|
||||
icon_type = icon_type_value
|
||||
resolved_icon_type = IconType(icon_type_value)
|
||||
else:
|
||||
icon_type = IconType.EMOJI
|
||||
resolved_icon_type = IconType.EMOJI
|
||||
icon = icon or str(app_data.get("icon", ""))
|
||||
|
||||
if app:
|
||||
# Update existing app
|
||||
app.name = name or app_data.get("name", app.name)
|
||||
app.description = description or app_data.get("description", app.description)
|
||||
app.icon_type = icon_type
|
||||
app.icon_type = resolved_icon_type
|
||||
app.icon = icon
|
||||
app.icon_background = icon_background or app_data.get("icon_background", app.icon_background)
|
||||
app.updated_by = account.id
|
||||
|
|
@ -452,10 +453,10 @@ class AppDslService:
|
|||
app = App()
|
||||
app.id = str(uuid4())
|
||||
app.tenant_id = account.current_tenant_id
|
||||
app.mode = app_mode.value
|
||||
app.mode = app_mode
|
||||
app.name = name or app_data.get("name", "")
|
||||
app.description = description or app_data.get("description", "")
|
||||
app.icon_type = icon_type
|
||||
app.icon_type = resolved_icon_type
|
||||
app.icon = icon
|
||||
app.icon_background = icon_background or app_data.get("icon_background", "#FFFFFF")
|
||||
app.enable_site = True
|
||||
|
|
@ -549,7 +550,7 @@ class AppDslService:
|
|||
"kind": "app",
|
||||
"app": {
|
||||
"name": app_model.name,
|
||||
"mode": app_model.mode,
|
||||
"mode": app_model.mode.value if isinstance(app_model.mode, AppMode) else app_model.mode,
|
||||
"icon": app_model.icon if app_model.icon_type == "image" else "🤖",
|
||||
"icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background,
|
||||
"description": app_model.description,
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from extensions.ext_database import db
|
|||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import App, AppMode, AppModelConfig, Site
|
||||
from models.model import App, AppMode, AppModelConfig, IconType, Site
|
||||
from models.tools import ApiToolProvider
|
||||
from services.billing_service import BillingService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
|
@ -254,7 +254,7 @@ class AppService:
|
|||
assert current_user is not None
|
||||
app.name = args["name"]
|
||||
app.description = args["description"]
|
||||
app.icon_type = args["icon_type"]
|
||||
app.icon_type = IconType(args["icon_type"]) if args["icon_type"] else None
|
||||
app.icon = args["icon"]
|
||||
app.icon_background = args["icon_background"]
|
||||
app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False)
|
||||
|
|
|
|||
|
|
@ -254,7 +254,7 @@ class DatasetService:
|
|||
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
|
||||
dataset.embedding_model = embedding_model.model_name if embedding_model else None
|
||||
dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
|
||||
dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
|
||||
dataset.permission = DatasetPermissionEnum(permission) if permission else DatasetPermissionEnum.ONLY_ME
|
||||
dataset.provider = provider
|
||||
if summary_index_setting is not None:
|
||||
dataset.summary_index_setting = summary_index_setting
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from dify_graph.model_runtime.entities import LLMMode
|
|||
from extensions.ext_database import db
|
||||
from models import Account
|
||||
from models.dataset import Dataset, DatasetQuery
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -98,7 +99,7 @@ class HitTestingService:
|
|||
content=json.dumps(dataset_queries),
|
||||
source="hit_testing",
|
||||
source_app_id=None,
|
||||
created_by_role="account",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
db.session.add(dataset_query)
|
||||
|
|
@ -138,7 +139,7 @@ class HitTestingService:
|
|||
content=query,
|
||||
source="hit_testing",
|
||||
source_app_id=None,
|
||||
created_by_role="account",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from typing import Union
|
|||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import Account
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, EndUser
|
||||
from models.web import SavedMessage
|
||||
from services.message_service import MessageService
|
||||
|
|
@ -54,7 +55,7 @@ class SavedMessageService:
|
|||
saved_message = SavedMessage(
|
||||
app_id=app_model.id,
|
||||
message_id=message.id,
|
||||
created_by_role="account" if isinstance(user, Account) else "end_user",
|
||||
created_by_role=CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER,
|
||||
created_by=user.id,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import Account
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, EndUser
|
||||
from models.web import PinnedConversation
|
||||
from services.conversation_service import ConversationService
|
||||
|
|
@ -84,7 +85,7 @@ class WebConversationService:
|
|||
pinned_conversation = PinnedConversation(
|
||||
app_id=app_model.id,
|
||||
conversation_id=conversation.id,
|
||||
created_by_role="account" if isinstance(user, Account) else "end_user",
|
||||
created_by_role=CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER,
|
||||
created_by=user.id,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from events.app_event import app_was_created
|
|||
from extensions.ext_database import db
|
||||
from models import Account
|
||||
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from models.model import App, AppMode, AppModelConfig, IconType
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
|
||||
|
||||
|
|
@ -72,7 +72,7 @@ class WorkflowConverter:
|
|||
new_app.tenant_id = app_model.tenant_id
|
||||
new_app.name = name or app_model.name + "(workflow)"
|
||||
new_app.mode = AppMode.ADVANCED_CHAT if app_model.mode == AppMode.CHAT else AppMode.WORKFLOW
|
||||
new_app.icon_type = icon_type or app_model.icon_type
|
||||
new_app.icon_type = IconType(icon_type) if icon_type else app_model.icon_type
|
||||
new_app.icon = icon or app_model.icon
|
||||
new_app.icon_background = icon_background or app_model.icon_background
|
||||
new_app.enable_site = app_model.enable_site
|
||||
|
|
|
|||
|
|
@ -164,7 +164,7 @@ def _record_trigger_failure_log(
|
|||
elapsed_time=0.0,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
created_by_role=created_by_role.value,
|
||||
created_by_role=created_by_role,
|
||||
created_by=created_by,
|
||||
created_at=now,
|
||||
finished_at=now,
|
||||
|
|
@ -179,7 +179,7 @@ def _record_trigger_failure_log(
|
|||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value,
|
||||
created_by_role=created_by_role.value,
|
||||
created_by_role=created_by_role,
|
||||
created_by=created_by,
|
||||
)
|
||||
session.add(workflow_app_log)
|
||||
|
|
@ -212,7 +212,7 @@ def _record_trigger_failure_log(
|
|||
error=error_message,
|
||||
queue_name=queue_name,
|
||||
retry_count=0,
|
||||
created_by_role=created_by_role.value,
|
||||
created_by_role=created_by_role,
|
||||
created_by=created_by,
|
||||
triggered_at=now,
|
||||
finished_at=now,
|
||||
|
|
|
|||
|
|
@ -94,13 +94,15 @@ def _create_workflow_run_from_execution(
|
|||
workflow_run.tenant_id = tenant_id
|
||||
workflow_run.app_id = app_id
|
||||
workflow_run.workflow_id = execution.workflow_id
|
||||
workflow_run.type = execution.workflow_type.value
|
||||
workflow_run.triggered_from = triggered_from.value
|
||||
from models.workflow import WorkflowType as ModelWorkflowType
|
||||
|
||||
workflow_run.type = ModelWorkflowType(execution.workflow_type.value)
|
||||
workflow_run.triggered_from = triggered_from
|
||||
workflow_run.version = execution.workflow_version
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
workflow_run.graph = json.dumps(json_converter.to_json_encodable(execution.graph))
|
||||
workflow_run.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs))
|
||||
workflow_run.status = execution.status.value
|
||||
workflow_run.status = execution.status
|
||||
workflow_run.outputs = (
|
||||
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
|
||||
)
|
||||
|
|
@ -108,7 +110,7 @@ def _create_workflow_run_from_execution(
|
|||
workflow_run.elapsed_time = execution.elapsed_time
|
||||
workflow_run.total_tokens = execution.total_tokens
|
||||
workflow_run.total_steps = execution.total_steps
|
||||
workflow_run.created_by_role = creator_user_role.value
|
||||
workflow_run.created_by_role = creator_user_role
|
||||
workflow_run.created_by = creator_user_id
|
||||
workflow_run.created_at = execution.started_at
|
||||
workflow_run.finished_at = execution.finished_at
|
||||
|
|
@ -121,7 +123,7 @@ def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: Wo
|
|||
Update a WorkflowRun database model from a WorkflowExecution domain entity.
|
||||
"""
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
workflow_run.status = execution.status.value
|
||||
workflow_run.status = execution.status
|
||||
workflow_run.outputs = (
|
||||
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ def _create_node_execution_from_domain(
|
|||
node_execution.tenant_id = tenant_id
|
||||
node_execution.app_id = app_id
|
||||
node_execution.workflow_id = execution.workflow_id
|
||||
node_execution.triggered_from = triggered_from.value
|
||||
node_execution.triggered_from = triggered_from
|
||||
node_execution.workflow_run_id = execution.workflow_execution_id
|
||||
node_execution.index = execution.index
|
||||
node_execution.predecessor_node_id = execution.predecessor_node_id
|
||||
|
|
@ -128,7 +128,7 @@ def _create_node_execution_from_domain(
|
|||
node_execution.status = execution.status.value
|
||||
node_execution.error = execution.error
|
||||
node_execution.elapsed_time = execution.elapsed_time
|
||||
node_execution.created_by_role = creator_user_role.value
|
||||
node_execution.created_by_role = creator_user_role
|
||||
node_execution.created_by = creator_user_id
|
||||
node_execution.created_at = execution.created_at
|
||||
node_execution.finished_at = execution.finished_at
|
||||
|
|
|
|||
|
|
@ -165,7 +165,7 @@ class TestChatMessageApiPermissions:
|
|||
agent_thoughts=[],
|
||||
message_files=[],
|
||||
message_metadata_dict={},
|
||||
status="success",
|
||||
status="normal",
|
||||
error="",
|
||||
parent_message_id=None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3331,7 +3331,7 @@ class TestRegisterService:
|
|||
TenantService.create_tenant_member(tenant, account, role="normal")
|
||||
|
||||
# Change tenant status to non-normal
|
||||
tenant.status = "suspended"
|
||||
tenant.status = "archive"
|
||||
|
||||
db_session_with_containers.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import uuid
|
|||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
|
@ -492,20 +493,20 @@ class TestAppGenerateService:
|
|||
)
|
||||
|
||||
# Manually set invalid mode after creation
|
||||
# With EnumText, invalid values are rejected at the DB level during flush,
|
||||
# raising StatementError wrapping ValueError
|
||||
app.mode = "invalid_mode"
|
||||
|
||||
# Setup test arguments
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
|
||||
|
||||
# Execute the method under test and expect ValueError
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
# Execute the method under test and expect either ValueError (direct) or
|
||||
# StatementError (from EnumText validation during autoflush)
|
||||
with pytest.raises((ValueError, sa.exc.StatementError)):
|
||||
AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
|
||||
# Verify error message
|
||||
assert "Invalid app mode" in str(exc_info.value)
|
||||
|
||||
def test_generate_with_workflow_id_format_error(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
|
|
|
|||
|
|
@ -163,7 +163,7 @@ class TestSavedMessageService:
|
|||
answer_unit_price=0.002,
|
||||
total_price=0.003,
|
||||
currency="USD",
|
||||
status="success",
|
||||
status="normal",
|
||||
)
|
||||
|
||||
db_session_with_containers.add(message)
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class TestWorkflowService:
|
|||
tenant = Tenant(
|
||||
name=f"Test Tenant {fake.company()}",
|
||||
plan="basic",
|
||||
status="active",
|
||||
status="normal",
|
||||
)
|
||||
tenant.id = account.current_tenant_id
|
||||
tenant.created_at = fake.date_time_this_year()
|
||||
|
|
@ -1090,20 +1090,19 @@ class TestWorkflowService:
|
|||
|
||||
This test ensures that the service correctly handles feature validation
|
||||
for unsupported app modes, preventing invalid operations.
|
||||
With EnumText, invalid values are rejected at the DB level during flush,
|
||||
raising StatementError wrapping ValueError.
|
||||
"""
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, fake)
|
||||
app.mode = "invalid_mode" # Invalid mode
|
||||
|
||||
db_session_with_containers.commit()
|
||||
# Act & Assert - EnumText validation rejects invalid values at DB flush
|
||||
import sqlalchemy as sa
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
features = {"test": "value"}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Invalid app mode: invalid_mode"):
|
||||
workflow_service.validate_features_structure(app_model=app, features=features)
|
||||
with pytest.raises((ValueError, sa.exc.StatementError)):
|
||||
db_session_with_containers.commit()
|
||||
|
||||
def test_update_workflow_success(self, db_session_with_containers: Session):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ class TestCleanDatasetTask:
|
|||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
plan="basic",
|
||||
status="active",
|
||||
status="normal",
|
||||
)
|
||||
|
||||
db_session_with_containers.add(tenant)
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ class TestDeleteSegmentFromIndexTask:
|
|||
Tenant: Created test tenant instance
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="active")
|
||||
tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="normal")
|
||||
tenant.id = fake.uuid4()
|
||||
tenant.created_at = fake.date_time_this_year()
|
||||
tenant.updated_at = tenant.created_at
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ class TestDisableSegmentsFromIndexTask:
|
|||
tenant = Tenant(
|
||||
name=f"Test Tenant {fake.company()}",
|
||||
plan="basic",
|
||||
status="active",
|
||||
status="normal",
|
||||
)
|
||||
tenant.id = account.tenant_id
|
||||
tenant.created_at = fake.date_time_this_year()
|
||||
|
|
|
|||
|
|
@ -118,7 +118,7 @@ class TestSendEmailCodeLoginMailTask:
|
|||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
plan="basic",
|
||||
status="active",
|
||||
status="normal",
|
||||
)
|
||||
|
||||
db_session_with_containers.add(tenant)
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ def make_message():
|
|||
msg.query = "hello"
|
||||
msg.re_sign_file_url_answer = ""
|
||||
msg.user_feedback = MagicMock(rating=None)
|
||||
msg.status = "success"
|
||||
msg.status = "normal"
|
||||
msg.error = None
|
||||
return msg
|
||||
|
||||
|
|
|
|||
|
|
@ -137,7 +137,7 @@ def test_message_list_mapping(app: Flask) -> None:
|
|||
{"id": "file-dict", "filename": "a.txt", "type": "file", "transfer_method": "local"},
|
||||
message_file_obj,
|
||||
],
|
||||
status="success",
|
||||
status="normal",
|
||||
error=None,
|
||||
message_metadata_dict={"meta": "value"},
|
||||
extra_contents=[
|
||||
|
|
|
|||
|
|
@ -3730,7 +3730,7 @@ class TestDatasetRetrievalAdditionalHelpers:
|
|||
attachment_ids=None,
|
||||
dataset_ids=["d1"],
|
||||
app_id="a1",
|
||||
user_from="web",
|
||||
user_from="account",
|
||||
user_id="u1",
|
||||
)
|
||||
mock_session.add_all.assert_not_called()
|
||||
|
|
@ -3740,7 +3740,7 @@ class TestDatasetRetrievalAdditionalHelpers:
|
|||
attachment_ids=["f1"],
|
||||
dataset_ids=["d1", "d2"],
|
||||
app_id="a1",
|
||||
user_from="web",
|
||||
user_from="account",
|
||||
user_id="u1",
|
||||
)
|
||||
mock_session.add_all.assert_called()
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from typing import Any
|
|||
from unittest.mock import patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.helper.tool_parameter_cache import ToolParameterCache
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
|
@ -112,37 +113,38 @@ def test_encrypt_tool_parameters():
|
|||
def test_decrypt_tool_parameters_cache_hit_and_miss():
|
||||
manager = _build_manager()
|
||||
|
||||
with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls:
|
||||
cache = cache_cls.return_value
|
||||
cache.get.return_value = {"secret": "cached"}
|
||||
with (
|
||||
patch.object(ToolParameterCache, "get", return_value={"secret": "cached"}),
|
||||
patch.object(ToolParameterCache, "set") as mock_set,
|
||||
):
|
||||
assert manager.decrypt_tool_parameters({"secret": "enc"}) == {"secret": "cached"}
|
||||
cache.set.assert_not_called()
|
||||
mock_set.assert_not_called()
|
||||
|
||||
with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls:
|
||||
cache = cache_cls.return_value
|
||||
cache.get.return_value = None
|
||||
with patch("core.tools.utils.configuration.encrypter.decrypt_token", return_value="dec"):
|
||||
decrypted = manager.decrypt_tool_parameters({"secret": "enc", "plain": "x"})
|
||||
|
||||
assert decrypted["secret"] == "dec"
|
||||
cache.set.assert_called_once()
|
||||
with (
|
||||
patch.object(ToolParameterCache, "get", return_value=None),
|
||||
patch.object(ToolParameterCache, "set") as mock_set,
|
||||
patch("core.tools.utils.configuration.encrypter.decrypt_token", return_value="dec"),
|
||||
):
|
||||
decrypted = manager.decrypt_tool_parameters({"secret": "enc", "plain": "x"})
|
||||
assert decrypted["secret"] == "dec"
|
||||
mock_set.assert_called_once()
|
||||
|
||||
|
||||
def test_delete_tool_parameters_cache():
|
||||
manager = _build_manager()
|
||||
|
||||
with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls:
|
||||
with patch.object(ToolParameterCache, "delete") as mock_delete:
|
||||
manager.delete_tool_parameters_cache()
|
||||
|
||||
cache_cls.return_value.delete.assert_called_once()
|
||||
mock_delete.assert_called_once()
|
||||
|
||||
|
||||
def test_configuration_manager_decrypt_suppresses_errors():
|
||||
manager = _build_manager()
|
||||
with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls:
|
||||
cache = cache_cls.return_value
|
||||
cache.get.return_value = None
|
||||
with patch("core.tools.utils.configuration.encrypter.decrypt_token", side_effect=RuntimeError("boom")):
|
||||
decrypted = manager.decrypt_tool_parameters({"secret": "enc"})
|
||||
with (
|
||||
patch.object(ToolParameterCache, "get", return_value=None),
|
||||
patch("core.tools.utils.configuration.encrypter.decrypt_token", side_effect=RuntimeError("boom")),
|
||||
):
|
||||
decrypted = manager.decrypt_tool_parameters({"secret": "enc"})
|
||||
# decryption failure is suppressed, original value is retained.
|
||||
assert decrypted["secret"] == "enc"
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ class TestAccountModelValidation:
|
|||
)
|
||||
|
||||
# Assert
|
||||
assert account.status == "active"
|
||||
assert account.status == AccountStatus.ACTIVE
|
||||
|
||||
def test_account_get_status_method(self):
|
||||
"""Test the get_status method returns AccountStatus enum."""
|
||||
|
|
@ -106,7 +106,7 @@ class TestAccountModelValidation:
|
|||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
status="pending",
|
||||
status=AccountStatus.PENDING,
|
||||
)
|
||||
|
||||
# Act
|
||||
|
|
|
|||
Loading…
Reference in New Issue