mirror of https://github.com/langgenius/dify.git
refactor(api): replace dict/Mapping with TypedDict in trigger.py and workflow.py (#33562)
This commit is contained in:
parent
569748189e
commit
7e34faaf51
|
|
@ -3,7 +3,7 @@ import time
|
|||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from functools import cached_property
|
||||
from typing import Any, cast
|
||||
from typing import Any, TypedDict, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
|
@ -24,6 +24,44 @@ from .model import Account
|
|||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
|
||||
class WorkflowTriggerLogDict(TypedDict):
|
||||
id: str
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
workflow_id: str
|
||||
workflow_run_id: str | None
|
||||
root_node_id: str | None
|
||||
trigger_metadata: Any
|
||||
trigger_type: str
|
||||
trigger_data: Any
|
||||
inputs: Any
|
||||
outputs: Any
|
||||
status: str
|
||||
error: str | None
|
||||
queue_name: str
|
||||
celery_task_id: str | None
|
||||
retry_count: int
|
||||
elapsed_time: float | None
|
||||
total_tokens: int | None
|
||||
created_by_role: str
|
||||
created_by: str
|
||||
created_at: str | None
|
||||
triggered_at: str | None
|
||||
finished_at: str | None
|
||||
|
||||
|
||||
class WorkflowSchedulePlanDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
node_id: str
|
||||
tenant_id: str
|
||||
cron_expression: str
|
||||
timezone: str
|
||||
next_run_at: str | None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class TriggerSubscription(TypeBase):
|
||||
"""
|
||||
Trigger provider model for managing credentials
|
||||
|
|
@ -250,7 +288,7 @@ class WorkflowTriggerLog(TypeBase):
|
|||
created_by_role = CreatorUserRole(self.created_by_role)
|
||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> WorkflowTriggerLogDict:
|
||||
"""Convert to dictionary for API responses"""
|
||||
return {
|
||||
"id": self.id,
|
||||
|
|
@ -481,7 +519,7 @@ class WorkflowSchedulePlan(TypeBase):
|
|||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> WorkflowSchedulePlanDict:
|
||||
"""Convert to dictionary representation"""
|
||||
return {
|
||||
"id": self.id,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import logging
|
|||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
|
@ -60,6 +60,22 @@ from .types import EnumText, LongText, StringUUID
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowContentDict(TypedDict):
|
||||
graph: Mapping[str, Any]
|
||||
features: dict[str, Any]
|
||||
environment_variables: list[dict[str, Any]]
|
||||
conversation_variables: list[dict[str, Any]]
|
||||
rag_pipeline_variables: list[dict[str, Any]]
|
||||
|
||||
|
||||
class WorkflowRunSummaryDict(TypedDict):
|
||||
id: str
|
||||
status: str
|
||||
triggered_from: str
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class WorkflowType(StrEnum):
|
||||
"""
|
||||
Workflow Type Enum
|
||||
|
|
@ -502,14 +518,14 @@ class Workflow(Base): # bug
|
|||
)
|
||||
self._environment_variables = environment_variables_json
|
||||
|
||||
def to_dict(self, *, include_secret: bool = False) -> Mapping[str, Any]:
|
||||
def to_dict(self, *, include_secret: bool = False) -> WorkflowContentDict:
|
||||
environment_variables = list(self.environment_variables)
|
||||
environment_variables = [
|
||||
v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={"value": ""})
|
||||
for v in environment_variables
|
||||
]
|
||||
|
||||
result = {
|
||||
result: WorkflowContentDict = {
|
||||
"graph": self.graph_dict,
|
||||
"features": self.features_dict,
|
||||
"environment_variables": [var.model_dump(mode="json") for var in environment_variables],
|
||||
|
|
@ -1231,7 +1247,7 @@ class WorkflowArchiveLog(TypeBase):
|
|||
)
|
||||
|
||||
@property
|
||||
def workflow_run_summary(self) -> dict[str, Any]:
|
||||
def workflow_run_summary(self) -> WorkflowRunSummaryDict:
|
||||
return {
|
||||
"id": self.workflow_run_id,
|
||||
"status": self.run_status,
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from extensions.ext_database import db
|
|||
from models.account import Account
|
||||
from models.enums import CreatorUserRole, WorkflowTriggerStatus
|
||||
from models.model import App, EndUser
|
||||
from models.trigger import WorkflowTriggerLog
|
||||
from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict
|
||||
from models.workflow import Workflow
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
|
||||
|
|
@ -224,7 +224,9 @@ class AsyncWorkflowService:
|
|||
return cls.trigger_workflow_async(session, user, trigger_data)
|
||||
|
||||
@classmethod
|
||||
def get_trigger_log(cls, workflow_trigger_log_id: str, tenant_id: str | None = None) -> dict[str, Any] | None:
|
||||
def get_trigger_log(
|
||||
cls, workflow_trigger_log_id: str, tenant_id: str | None = None
|
||||
) -> WorkflowTriggerLogDict | None:
|
||||
"""
|
||||
Get trigger log by ID
|
||||
|
||||
|
|
@ -247,7 +249,7 @@ class AsyncWorkflowService:
|
|||
@classmethod
|
||||
def get_recent_logs(
|
||||
cls, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> list[WorkflowTriggerLogDict]:
|
||||
"""
|
||||
Get recent trigger logs
|
||||
|
||||
|
|
@ -272,7 +274,7 @@ class AsyncWorkflowService:
|
|||
@classmethod
|
||||
def get_failed_logs_for_retry(
|
||||
cls, tenant_id: str, max_retry_count: int = 3, limit: int = 100
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> list[WorkflowTriggerLogDict]:
|
||||
"""
|
||||
Get failed logs eligible for retry
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue