refactor(api): add TypedDict definitions to models/model.py (#32925)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
statxc 2026-03-06 01:42:54 +02:00 committed by GitHub
parent 6bd1be9e16
commit 741d48560d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 453 additions and 142 deletions

View File

@ -1,3 +1,5 @@
from typing import Any, cast
from controllers.common import fields
from controllers.console import console_ns
from controllers.console.app.error import AppUnavailableError
@ -23,14 +25,14 @@ class AppParameterApi(InstalledAppResource):
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
features_dict: dict[str, Any] = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
features_dict = cast(dict[str, Any], app_model_config.to_dict())
user_input_form = features_dict.get("user_input_form", [])

View File

@ -1,3 +1,5 @@
from typing import Any, cast
from flask_restx import Resource
from controllers.common.fields import Parameters
@ -33,14 +35,14 @@ class AppParameterApi(Resource):
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
features_dict: dict[str, Any] = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
features_dict = cast(dict[str, Any], app_model_config.to_dict())
user_input_form = features_dict.get("user_input_form", [])

View File

@ -1,4 +1,5 @@
import logging
from typing import Any, cast
from flask import request
from flask_restx import Resource
@ -57,14 +58,14 @@ class AppParameterApi(WebApiResource):
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
features_dict: dict[str, Any] = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
features_dict = cast(dict[str, Any], app_model_config.to_dict())
user_input_form = features_dict.get("user_input_form", [])

View File

@ -1,10 +1,13 @@
from collections.abc import Mapping
from typing import Any
from core.app.app_config.entities import SensitiveWordAvoidanceEntity
from core.moderation.factory import ModerationFactory
class SensitiveWordAvoidanceConfigManager:
@classmethod
def convert(cls, config: dict) -> SensitiveWordAvoidanceEntity | None:
def convert(cls, config: Mapping[str, Any]) -> SensitiveWordAvoidanceEntity | None:
sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
if not sensitive_word_avoidance_dict:
return None
@ -12,7 +15,7 @@ class SensitiveWordAvoidanceConfigManager:
if sensitive_word_avoidance_dict.get("enabled"):
return SensitiveWordAvoidanceEntity(
type=sensitive_word_avoidance_dict.get("type"),
config=sensitive_word_avoidance_dict.get("config"),
config=sensitive_word_avoidance_dict.get("config", {}),
)
else:
return None

View File

@ -1,10 +1,13 @@
from typing import Any, cast
from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity
from core.agent.prompt.template import REACT_PROMPT_TEMPLATES
from models.model import AppModelConfigDict
class AgentConfigManager:
@classmethod
def convert(cls, config: dict) -> AgentEntity | None:
def convert(cls, config: AppModelConfigDict) -> AgentEntity | None:
"""
Convert model config to model config
@ -28,17 +31,17 @@ class AgentConfigManager:
agent_tools = []
for tool in agent_dict.get("tools", []):
keys = tool.keys()
if len(keys) >= 4:
if "enabled" not in tool or not tool["enabled"]:
tool_dict = cast(dict[str, Any], tool)
if len(tool_dict) >= 4:
if "enabled" not in tool_dict or not tool_dict["enabled"]:
continue
agent_tool_properties = {
"provider_type": tool["provider_type"],
"provider_id": tool["provider_id"],
"tool_name": tool["tool_name"],
"tool_parameters": tool.get("tool_parameters", {}),
"credential_id": tool.get("credential_id", None),
"provider_type": tool_dict["provider_type"],
"provider_id": tool_dict["provider_id"],
"tool_name": tool_dict["tool_name"],
"tool_parameters": tool_dict.get("tool_parameters", {}),
"credential_id": tool_dict.get("credential_id", None),
}
agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties))
@ -47,7 +50,8 @@ class AgentConfigManager:
"react_router",
"router",
}:
agent_prompt = agent_dict.get("prompt", None) or {}
agent_prompt_raw = agent_dict.get("prompt", None)
agent_prompt: dict[str, Any] = agent_prompt_raw if isinstance(agent_prompt_raw, dict) else {}
# check model mode
model_mode = config.get("model", {}).get("mode", "completion")
if model_mode == "completion":
@ -75,7 +79,7 @@ class AgentConfigManager:
strategy=strategy,
prompt=agent_prompt_entity,
tools=agent_tools,
max_iteration=agent_dict.get("max_iteration", 10),
max_iteration=cast(int, agent_dict.get("max_iteration", 10)),
)
return None

View File

@ -1,5 +1,5 @@
import uuid
from typing import Literal, cast
from typing import Any, Literal, cast
from core.app.app_config.entities import (
DatasetEntity,
@ -8,13 +8,13 @@ from core.app.app_config.entities import (
ModelConfig,
)
from core.entities.agent_entities import PlanningStrategy
from models.model import AppMode
from models.model import AppMode, AppModelConfigDict
from services.dataset_service import DatasetService
class DatasetConfigManager:
@classmethod
def convert(cls, config: dict) -> DatasetEntity | None:
def convert(cls, config: AppModelConfigDict) -> DatasetEntity | None:
"""
Convert model config to model config
@ -25,11 +25,15 @@ class DatasetConfigManager:
datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})
for dataset in datasets.get("datasets", []):
if not isinstance(dataset, dict):
continue
keys = list(dataset.keys())
if len(keys) == 0 or keys[0] != "dataset":
continue
dataset = dataset["dataset"]
if not isinstance(dataset, dict):
continue
if "enabled" not in dataset or not dataset["enabled"]:
continue
@ -47,15 +51,14 @@ class DatasetConfigManager:
agent_dict = config.get("agent_mode", {})
for tool in agent_dict.get("tools", []):
keys = tool.keys()
if len(keys) == 1:
if len(tool) == 1:
# old standard
key = list(tool.keys())[0]
if key != "dataset":
continue
tool_item = tool[key]
tool_item = cast(dict[str, Any], tool)[key]
if "enabled" not in tool_item or not tool_item["enabled"]:
continue

View File

@ -5,12 +5,13 @@ from core.app.app_config.entities import ModelConfigEntity
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from models.model import AppModelConfigDict
from models.provider_ids import ModelProviderID
class ModelConfigManager:
@classmethod
def convert(cls, config: dict) -> ModelConfigEntity:
def convert(cls, config: AppModelConfigDict) -> ModelConfigEntity:
"""
Convert model config to model config
@ -22,7 +23,7 @@ class ModelConfigManager:
if not model_config:
raise ValueError("model is required")
completion_params = model_config.get("completion_params")
completion_params = model_config.get("completion_params") or {}
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]

View File

@ -1,3 +1,5 @@
from typing import Any
from core.app.app_config.entities import (
AdvancedChatMessageEntity,
AdvancedChatPromptTemplateEntity,
@ -6,12 +8,12 @@ from core.app.app_config.entities import (
)
from core.prompt.simple_prompt_transform import ModelMode
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
from models.model import AppMode
from models.model import AppMode, AppModelConfigDict
class PromptTemplateConfigManager:
@classmethod
def convert(cls, config: dict) -> PromptTemplateEntity:
def convert(cls, config: AppModelConfigDict) -> PromptTemplateEntity:
if not config.get("prompt_type"):
raise ValueError("prompt_type is required")
@ -40,14 +42,15 @@ class PromptTemplateConfigManager:
advanced_completion_prompt_template = None
completion_prompt_config = config.get("completion_prompt_config", {})
if completion_prompt_config:
completion_prompt_template_params = {
completion_prompt_template_params: dict[str, Any] = {
"prompt": completion_prompt_config["prompt"]["text"],
}
if "conversation_histories_role" in completion_prompt_config:
conv_role = completion_prompt_config.get("conversation_histories_role")
if conv_role:
completion_prompt_template_params["role_prefix"] = {
"user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
"assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
"user": conv_role["user_prefix"],
"assistant": conv_role["assistant_prefix"],
}
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(

View File

@ -1,8 +1,10 @@
import re
from typing import cast
from core.app.app_config.entities import ExternalDataVariableEntity
from core.external_data_tool.factory import ExternalDataToolFactory
from dify_graph.variables.input_entities import VariableEntity, VariableEntityType
from models.model import AppModelConfigDict
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
[
@ -18,7 +20,7 @@ _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
class BasicVariablesConfigManager:
@classmethod
def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]:
def convert(cls, config: AppModelConfigDict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]:
"""
Convert model config to model config
@ -51,7 +53,9 @@ class BasicVariablesConfigManager:
external_data_variables.append(
ExternalDataVariableEntity(
variable=variable["variable"], type=variable["type"], config=variable["config"]
variable=variable["variable"],
type=variable.get("type", ""),
config=variable.get("config", {}),
)
)
elif variable_type in {
@ -64,10 +68,10 @@ class BasicVariablesConfigManager:
variable = variables[variable_type]
variable_entities.append(
VariableEntity(
type=variable_type,
variable=variable.get("variable"),
type=cast(VariableEntityType, variable_type),
variable=variable["variable"],
description=variable.get("description") or "",
label=variable.get("label"),
label=variable["label"],
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options") or [],

View File

@ -281,7 +281,7 @@ class EasyUIBasedAppConfig(AppConfig):
app_model_config_from: EasyUIBasedAppModelConfigFrom
app_model_config_id: str
app_model_config_dict: dict
app_model_config_dict: dict[str, Any]
model: ModelConfigEntity
prompt_template: PromptTemplateEntity
dataset: DatasetEntity | None = None

View File

@ -20,7 +20,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
)
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.entities.agent_entities import PlanningStrategy
from models.model import App, AppMode, AppModelConfig, Conversation
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation
OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
@ -40,7 +40,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
app_model: App,
app_model_config: AppModelConfig,
conversation: Conversation | None = None,
override_config_dict: dict | None = None,
override_config_dict: AppModelConfigDict | None = None,
) -> AgentChatAppConfig:
"""
Convert app model config to agent chat app config
@ -61,7 +61,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
config_dict = override_config_dict or {}
if not override_config_dict:
raise Exception("override_config_dict is required when config_from is ARGS")
config_dict = override_config_dict
app_mode = AppMode.value_of(app_model.mode)
app_config = AgentChatAppConfig(
@ -70,7 +72,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
app_mode=app_mode,
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
app_model_config_dict=cast(dict[str, Any], config_dict),
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
@ -86,7 +88,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]):
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> AppModelConfigDict:
"""
Validate for agent chat app model config
@ -157,7 +159,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config
return cast(AppModelConfigDict, filtered_config)
@classmethod
def validate_agent_mode_and_set_defaults(

View File

@ -1,3 +1,5 @@
from typing import Any, cast
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
@ -13,7 +15,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
SuggestedQuestionsAfterAnswerConfigManager,
)
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from models.model import App, AppMode, AppModelConfig, Conversation
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation
class ChatAppConfig(EasyUIBasedAppConfig):
@ -31,7 +33,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
app_model: App,
app_model_config: AppModelConfig,
conversation: Conversation | None = None,
override_config_dict: dict | None = None,
override_config_dict: AppModelConfigDict | None = None,
) -> ChatAppConfig:
"""
Convert app model config to chat app config
@ -64,7 +66,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
app_mode=app_mode,
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
app_model_config_dict=cast(dict[str, Any], config_dict),
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
@ -79,7 +81,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict):
def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict:
"""
Validate for chat app model config
@ -145,4 +147,4 @@ class ChatAppConfigManager(BaseAppConfigManager):
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config
return cast(AppModelConfigDict, filtered_config)

View File

@ -173,8 +173,10 @@ class ChatAppRunner(AppRunner):
memory=memory,
message_id=message.id,
inputs=inputs,
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
"enabled", False
vision_enabled=bool(
application_generate_entity.app_config.app_model_config_dict.get("file_upload", {})
.get("image", {})
.get("enabled", False)
),
)
context_files = retrieved_files or []

View File

@ -1,3 +1,5 @@
from typing import Any, cast
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
@ -8,7 +10,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from models.model import App, AppMode, AppModelConfig
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict
class CompletionAppConfig(EasyUIBasedAppConfig):
@ -22,7 +24,7 @@ class CompletionAppConfig(EasyUIBasedAppConfig):
class CompletionAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: dict | None = None
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: AppModelConfigDict | None = None
) -> CompletionAppConfig:
"""
Convert app model config to completion app config
@ -40,7 +42,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
config_dict = override_config_dict or {}
if not override_config_dict:
raise Exception("override_config_dict is required when config_from is ARGS")
config_dict = override_config_dict
app_mode = AppMode.value_of(app_model.mode)
app_config = CompletionAppConfig(
@ -49,7 +53,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
app_mode=app_mode,
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
app_model_config_dict=cast(dict[str, Any], config_dict),
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
@ -64,7 +68,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict):
def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict:
"""
Validate for completion app model config
@ -116,4 +120,4 @@ class CompletionAppConfigManager(BaseAppConfigManager):
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config
return cast(AppModelConfigDict, filtered_config)

View File

@ -275,7 +275,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
raise ValueError("Message app_model_config is None")
override_model_config_dict = app_model_config.to_dict()
model_dict = override_model_config_dict["model"]
completion_params = model_dict.get("completion_params")
completion_params = model_dict.get("completion_params", {})
completion_params["temperature"] = 0.9
model_dict["completion_params"] = completion_params
override_model_config_dict["model"] = model_dict

View File

@ -132,8 +132,10 @@ class CompletionAppRunner(AppRunner):
hit_callback=hit_callback,
message_id=message.id,
inputs=inputs,
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
"enabled", False
vision_enabled=bool(
application_generate_entity.app_config.app_model_config_dict.get("file_upload", {})
.get("image", {})
.get("enabled", False)
),
)
context_files = retrieved_files or []

View File

@ -2,7 +2,7 @@ import logging
import time
from collections.abc import Generator
from threading import Thread
from typing import Union, cast
from typing import Any, Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -219,14 +219,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
tenant_id = self._application_generate_entity.app_config.tenant_id
task_id = self._application_generate_entity.task_id
publisher = None
text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech")
text_to_speech_dict = cast(dict[str, Any], self._app_config.app_model_config_dict.get("text_to_speech"))
if (
text_to_speech_dict
and text_to_speech_dict.get("autoPlay") == "enabled"
and text_to_speech_dict.get("enabled")
):
publisher = AppGeneratorTTSPublisher(
tenant_id, text_to_speech_dict.get("voice", None), text_to_speech_dict.get("language", None)
tenant_id, text_to_speech_dict.get("voice", ""), text_to_speech_dict.get("language", None)
)
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:

View File

@ -1,6 +1,6 @@
import uuid
from collections.abc import Generator, Mapping
from typing import Union
from typing import Any, Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -34,14 +34,14 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
if workflow is None:
raise ValueError("unexpected app type")
features_dict = workflow.features_dict
features_dict: dict[str, Any] = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app.app_model_config
if app_model_config is None:
raise ValueError("unexpected app type")
features_dict = app_model_config.to_dict()
features_dict = cast(dict[str, Any], app_model_config.to_dict())
user_input_form = features_dict.get("user_input_form", [])

View File

@ -7,7 +7,7 @@ from collections.abc import Mapping, Sequence
from datetime import datetime
from decimal import Decimal
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast
from uuid import uuid4
import sqlalchemy as sa
@ -15,6 +15,7 @@ from flask import request
from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import TypedDict
from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
@ -36,6 +37,259 @@ if TYPE_CHECKING:
from .workflow import Workflow
# --- TypedDict definitions for structured dict return types ---
class EnabledConfig(TypedDict):
enabled: bool
class EmbeddingModelInfo(TypedDict):
embedding_provider_name: str
embedding_model_name: str
class AnnotationReplyDisabledConfig(TypedDict):
enabled: Literal[False]
class AnnotationReplyEnabledConfig(TypedDict):
id: str
enabled: Literal[True]
score_threshold: float
embedding_model: EmbeddingModelInfo
AnnotationReplyConfig = AnnotationReplyEnabledConfig | AnnotationReplyDisabledConfig
class SensitiveWordAvoidanceConfig(TypedDict):
enabled: bool
type: str
config: dict[str, Any]
class AgentToolConfig(TypedDict):
provider_type: str
provider_id: str
tool_name: str
tool_parameters: dict[str, Any]
plugin_unique_identifier: NotRequired[str | None]
credential_id: NotRequired[str | None]
class AgentModeConfig(TypedDict):
enabled: bool
strategy: str | None
tools: list[AgentToolConfig | dict[str, Any]]
prompt: str | None
class ImageUploadConfig(TypedDict):
enabled: bool
number_limits: int
detail: str
transfer_methods: list[str]
class FileUploadConfig(TypedDict):
image: ImageUploadConfig
class DeletedToolInfo(TypedDict):
type: str
tool_name: str
provider_id: str
class ExternalDataToolConfig(TypedDict):
enabled: bool
variable: str
type: str
config: dict[str, Any]
class UserInputFormItemConfig(TypedDict):
variable: str
label: str
description: NotRequired[str]
required: NotRequired[bool]
max_length: NotRequired[int]
options: NotRequired[list[str]]
default: NotRequired[str]
type: NotRequired[str]
config: NotRequired[dict[str, Any]]
# Each item is a single-key dict, e.g. {"text-input": UserInputFormItemConfig}
UserInputFormItem = dict[str, UserInputFormItemConfig]
class DatasetConfigs(TypedDict):
retrieval_model: str
datasets: NotRequired[dict[str, Any]]
top_k: NotRequired[int]
score_threshold: NotRequired[float]
score_threshold_enabled: NotRequired[bool]
reranking_model: NotRequired[dict[str, Any] | None]
weights: NotRequired[dict[str, Any] | None]
reranking_enabled: NotRequired[bool]
reranking_mode: NotRequired[str]
metadata_filtering_mode: NotRequired[str]
metadata_model_config: NotRequired[dict[str, Any] | None]
metadata_filtering_conditions: NotRequired[dict[str, Any] | None]
class ChatPromptMessage(TypedDict):
text: str
role: str
class ChatPromptConfig(TypedDict, total=False):
prompt: list[ChatPromptMessage]
class CompletionPromptText(TypedDict):
text: str
class ConversationHistoriesRole(TypedDict):
user_prefix: str
assistant_prefix: str
class CompletionPromptConfig(TypedDict):
prompt: CompletionPromptText
conversation_histories_role: NotRequired[ConversationHistoriesRole]
class ModelConfig(TypedDict):
provider: str
name: str
mode: str
completion_params: NotRequired[dict[str, Any]]
class AppModelConfigDict(TypedDict):
opening_statement: str | None
suggested_questions: list[str]
suggested_questions_after_answer: EnabledConfig
speech_to_text: EnabledConfig
text_to_speech: EnabledConfig
retriever_resource: EnabledConfig
annotation_reply: AnnotationReplyConfig
more_like_this: EnabledConfig
sensitive_word_avoidance: SensitiveWordAvoidanceConfig
external_data_tools: list[ExternalDataToolConfig]
model: ModelConfig
user_input_form: list[UserInputFormItem]
dataset_query_variable: str | None
pre_prompt: str | None
agent_mode: AgentModeConfig
prompt_type: str
chat_prompt_config: ChatPromptConfig
completion_prompt_config: CompletionPromptConfig
dataset_configs: DatasetConfigs
file_upload: FileUploadConfig
# Added dynamically in Conversation.model_config
model_id: NotRequired[str | None]
provider: NotRequired[str | None]
class ConversationDict(TypedDict):
id: str
app_id: str
app_model_config_id: str | None
model_provider: str | None
override_model_configs: str | None
model_id: str | None
mode: str
name: str
summary: str | None
inputs: dict[str, Any]
introduction: str | None
system_instruction: str | None
system_instruction_tokens: int
status: str
invoke_from: str | None
from_source: str
from_end_user_id: str | None
from_account_id: str | None
read_at: datetime | None
read_account_id: str | None
dialogue_count: int
created_at: datetime
updated_at: datetime
class MessageDict(TypedDict):
id: str
app_id: str
conversation_id: str
model_id: str | None
inputs: dict[str, Any]
query: str
total_price: Decimal | None
message: dict[str, Any]
answer: str
status: str
error: str | None
message_metadata: dict[str, Any]
from_source: str
from_end_user_id: str | None
from_account_id: str | None
created_at: str
updated_at: str
agent_based: bool
workflow_run_id: str | None
class MessageFeedbackDict(TypedDict):
id: str
app_id: str
conversation_id: str
message_id: str
rating: str
content: str | None
from_source: str
from_end_user_id: str | None
from_account_id: str | None
created_at: str
updated_at: str
class MessageFileInfo(TypedDict, total=False):
belongs_to: str | None
upload_file_id: str | None
id: str
tenant_id: str
type: str
transfer_method: str
remote_url: str | None
related_id: str | None
filename: str | None
extension: str | None
mime_type: str | None
size: int
dify_model_identity: str
url: str | None
class ExtraContentDict(TypedDict, total=False):
type: str
workflow_run_id: str
class TraceAppConfigDict(TypedDict):
id: str
app_id: str
tracing_provider: str | None
tracing_config: dict[str, Any]
is_active: bool
created_at: str | None
updated_at: str | None
class DifySetup(TypeBase):
__tablename__ = "dify_setups"
__table_args__ = (sa.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
@ -176,7 +430,7 @@ class App(Base):
return str(self.mode)
@property
def deleted_tools(self) -> list[dict[str, str]]:
def deleted_tools(self) -> list[DeletedToolInfo]:
from core.tools.tool_manager import ToolManager, ToolProviderType
from services.plugin.plugin_service import PluginService
@ -257,7 +511,7 @@ class App(Base):
provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids)
}
deleted_tools: list[dict[str, str]] = []
deleted_tools: list[DeletedToolInfo] = []
for tool in tools:
keys = list(tool.keys())
@ -364,35 +618,38 @@ class AppModelConfig(TypeBase):
return app
@property
def model_dict(self) -> dict[str, Any]:
return json.loads(self.model) if self.model else {}
def model_dict(self) -> ModelConfig:
return cast(ModelConfig, json.loads(self.model) if self.model else {})
@property
def suggested_questions_list(self) -> list[str]:
return json.loads(self.suggested_questions) if self.suggested_questions else []
@property
def suggested_questions_after_answer_dict(self) -> dict[str, Any]:
return (
def suggested_questions_after_answer_dict(self) -> EnabledConfig:
return cast(
EnabledConfig,
json.loads(self.suggested_questions_after_answer)
if self.suggested_questions_after_answer
else {"enabled": False}
else {"enabled": False},
)
@property
def speech_to_text_dict(self) -> dict[str, Any]:
return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}
def speech_to_text_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False})
@property
def text_to_speech_dict(self) -> dict[str, Any]:
return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}
def text_to_speech_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False})
@property
def retriever_resource_dict(self) -> dict[str, Any]:
return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
def retriever_resource_dict(self) -> EnabledConfig:
return cast(
EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
)
@property
def annotation_reply_dict(self) -> dict[str, Any]:
def annotation_reply_dict(self) -> AnnotationReplyConfig:
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
)
@ -415,56 +672,62 @@ class AppModelConfig(TypeBase):
return {"enabled": False}
@property
def more_like_this_dict(self) -> dict[str, Any]:
return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
def more_like_this_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False})
@property
def sensitive_word_avoidance_dict(self) -> dict[str, Any]:
return (
def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig:
return cast(
SensitiveWordAvoidanceConfig,
json.loads(self.sensitive_word_avoidance)
if self.sensitive_word_avoidance
else {"enabled": False, "type": "", "configs": []}
else {"enabled": False, "type": "", "config": {}},
)
@property
def external_data_tools_list(self) -> list[dict[str, Any]]:
def external_data_tools_list(self) -> list[ExternalDataToolConfig]:
return json.loads(self.external_data_tools) if self.external_data_tools else []
@property
def user_input_form_list(self) -> list[dict[str, Any]]:
def user_input_form_list(self) -> list[UserInputFormItem]:
return json.loads(self.user_input_form) if self.user_input_form else []
@property
def agent_mode_dict(self) -> dict[str, Any]:
return (
def agent_mode_dict(self) -> AgentModeConfig:
return cast(
AgentModeConfig,
json.loads(self.agent_mode)
if self.agent_mode
else {"enabled": False, "strategy": None, "tools": [], "prompt": None}
else {"enabled": False, "strategy": None, "tools": [], "prompt": None},
)
@property
def chat_prompt_config_dict(self) -> dict[str, Any]:
return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}
def chat_prompt_config_dict(self) -> ChatPromptConfig:
return cast(ChatPromptConfig, json.loads(self.chat_prompt_config) if self.chat_prompt_config else {})
@property
def completion_prompt_config_dict(self) -> dict[str, Any]:
return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}
def completion_prompt_config_dict(self) -> CompletionPromptConfig:
return cast(
CompletionPromptConfig,
json.loads(self.completion_prompt_config) if self.completion_prompt_config else {},
)
@property
def dataset_configs_dict(self) -> dict[str, Any]:
def dataset_configs_dict(self) -> DatasetConfigs:
if self.dataset_configs:
dataset_configs: dict[str, Any] = json.loads(self.dataset_configs)
dataset_configs = json.loads(self.dataset_configs)
if "retrieval_model" not in dataset_configs:
return {"retrieval_model": "single"}
else:
return dataset_configs
return cast(DatasetConfigs, dataset_configs)
return {
"retrieval_model": "multiple",
}
@property
def file_upload_dict(self) -> dict[str, Any]:
return (
def file_upload_dict(self) -> FileUploadConfig:
return cast(
FileUploadConfig,
json.loads(self.file_upload)
if self.file_upload
else {
@ -474,10 +737,10 @@ class AppModelConfig(TypeBase):
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
}
},
)
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> AppModelConfigDict:
return {
"opening_statement": self.opening_statement,
"suggested_questions": self.suggested_questions_list,
@ -501,36 +764,42 @@ class AppModelConfig(TypeBase):
"file_upload": self.file_upload_dict,
}
def from_model_config_dict(self, model_config: Mapping[str, Any]):
def from_model_config_dict(self, model_config: AppModelConfigDict):
self.opening_statement = model_config.get("opening_statement")
self.suggested_questions = (
json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None
json.dumps(model_config.get("suggested_questions")) if model_config.get("suggested_questions") else None
)
self.suggested_questions_after_answer = (
json.dumps(model_config["suggested_questions_after_answer"])
json.dumps(model_config.get("suggested_questions_after_answer"))
if model_config.get("suggested_questions_after_answer")
else None
)
self.speech_to_text = json.dumps(model_config["speech_to_text"]) if model_config.get("speech_to_text") else None
self.text_to_speech = json.dumps(model_config["text_to_speech"]) if model_config.get("text_to_speech") else None
self.more_like_this = json.dumps(model_config["more_like_this"]) if model_config.get("more_like_this") else None
self.speech_to_text = (
json.dumps(model_config.get("speech_to_text")) if model_config.get("speech_to_text") else None
)
self.text_to_speech = (
json.dumps(model_config.get("text_to_speech")) if model_config.get("text_to_speech") else None
)
self.more_like_this = (
json.dumps(model_config.get("more_like_this")) if model_config.get("more_like_this") else None
)
self.sensitive_word_avoidance = (
json.dumps(model_config["sensitive_word_avoidance"])
json.dumps(model_config.get("sensitive_word_avoidance"))
if model_config.get("sensitive_word_avoidance")
else None
)
self.external_data_tools = (
json.dumps(model_config["external_data_tools"]) if model_config.get("external_data_tools") else None
json.dumps(model_config.get("external_data_tools")) if model_config.get("external_data_tools") else None
)
self.model = json.dumps(model_config["model"]) if model_config.get("model") else None
self.model = json.dumps(model_config.get("model")) if model_config.get("model") else None
self.user_input_form = (
json.dumps(model_config["user_input_form"]) if model_config.get("user_input_form") else None
json.dumps(model_config.get("user_input_form")) if model_config.get("user_input_form") else None
)
self.dataset_query_variable = model_config.get("dataset_query_variable")
self.pre_prompt = model_config["pre_prompt"]
self.agent_mode = json.dumps(model_config["agent_mode"]) if model_config.get("agent_mode") else None
self.pre_prompt = model_config.get("pre_prompt")
self.agent_mode = json.dumps(model_config.get("agent_mode")) if model_config.get("agent_mode") else None
self.retriever_resource = (
json.dumps(model_config["retriever_resource"]) if model_config.get("retriever_resource") else None
json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None
)
self.prompt_type = model_config.get("prompt_type", "simple")
self.chat_prompt_config = (
@ -823,24 +1092,26 @@ class Conversation(Base):
self._inputs = inputs
@property
def model_config(self):
model_config = {}
def model_config(self) -> AppModelConfigDict:
model_config = cast(AppModelConfigDict, {})
app_model_config: AppModelConfig | None = None
if self.mode == AppMode.ADVANCED_CHAT:
if self.override_model_configs:
override_model_configs = json.loads(self.override_model_configs)
model_config = override_model_configs
model_config = cast(AppModelConfigDict, override_model_configs)
else:
if self.override_model_configs:
override_model_configs = json.loads(self.override_model_configs)
if "model" in override_model_configs:
# where is app_id?
app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(override_model_configs)
app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(
cast(AppModelConfigDict, override_model_configs)
)
model_config = app_model_config.to_dict()
else:
model_config["configs"] = override_model_configs
model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key]
else:
app_model_config = (
db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
@ -1015,7 +1286,7 @@ class Conversation(Base):
def in_debug_mode(self) -> bool:
return self.override_model_configs is not None
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> ConversationDict:
return {
"id": self.id,
"app_id": self.app_id,
@ -1295,7 +1566,7 @@ class Message(Base):
return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []
@property
def message_files(self) -> list[dict[str, Any]]:
def message_files(self) -> list[MessageFileInfo]:
from factories import file_factory
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
@ -1350,10 +1621,13 @@ class Message(Base):
)
files.append(file)
result: list[dict[str, Any]] = [
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
for (file, message_file) in zip(files, message_files)
]
result = cast(
list[MessageFileInfo],
[
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
for (file, message_file) in zip(files, message_files)
],
)
db.session.commit()
return result
@ -1363,7 +1637,7 @@ class Message(Base):
self._extra_contents = list(contents)
@property
def extra_contents(self) -> list[dict[str, Any]]:
def extra_contents(self) -> list[ExtraContentDict]:
return getattr(self, "_extra_contents", [])
@property
@ -1379,7 +1653,7 @@ class Message(Base):
return None
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> MessageDict:
return {
"id": self.id,
"app_id": self.app_id,
@ -1403,7 +1677,7 @@ class Message(Base):
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> Message:
def from_dict(cls, data: MessageDict) -> Message:
return cls(
id=data["id"],
app_id=data["app_id"],
@ -1463,7 +1737,7 @@ class MessageFeedback(TypeBase):
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
return account
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> MessageFeedbackDict:
return {
"id": str(self.id),
"app_id": str(self.app_id),
@ -1726,8 +2000,8 @@ class AppMCPServer(TypeBase):
return result
@property
def parameters_dict(self) -> dict[str, Any]:
return cast(dict[str, Any], json.loads(self.parameters))
def parameters_dict(self) -> dict[str, str]:
return cast(dict[str, str], json.loads(self.parameters))
class Site(Base):
@ -2167,7 +2441,7 @@ class TraceAppConfig(TypeBase):
def tracing_config_str(self) -> str:
return json.dumps(self.tracing_config_dict)
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> TraceAppConfigDict:
return {
"id": self.id,
"app_id": self.app_id,

View File

@ -4,6 +4,7 @@ import logging
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import cast
from urllib.parse import urlparse
from uuid import uuid4
@ -32,7 +33,7 @@ from extensions.ext_redis import redis_client
from factories import variable_factory
from libs.datetime_utils import naive_utc_now
from models import Account, App, AppMode
from models.model import AppModelConfig, IconType
from models.model import AppModelConfig, AppModelConfigDict, IconType
from models.workflow import Workflow
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.workflow_draft_variable_service import WorkflowDraftVariableService
@ -523,7 +524,7 @@ class AppDslService:
if not app.app_model_config:
app_model_config = AppModelConfig(
app_id=app.id, created_by=account.id, updated_by=account.id
).from_model_config_dict(model_config)
).from_model_config_dict(cast(AppModelConfigDict, model_config))
app_model_config.id = str(uuid4())
app.app_model_config_id = app_model_config.id

View File

@ -1,12 +1,12 @@
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from models.model import AppMode
from models.model import AppMode, AppModelConfigDict
class AppModelConfigService:
@classmethod
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode):
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict:
if app_mode == AppMode.CHAT:
return ChatAppConfigManager.config_validate(tenant_id, config)
elif app_mode == AppMode.AGENT_CHAT:

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import TypedDict, cast
from typing import Any, TypedDict, cast
import sqlalchemy as sa
from flask_sqlalchemy.pagination import Pagination
@ -187,7 +187,7 @@ class AppService:
for tool in agent_mode.get("tools") or []:
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
agent_tool_entity = AgentToolEntity(**tool)
agent_tool_entity = AgentToolEntity(**cast(dict[str, Any], tool))
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
@ -388,7 +388,7 @@ class AppService:
agent_config = app_model_config.agent_mode_dict
# get all tools
tools = agent_config.get("tools", [])
tools = cast(list[dict[str, Any]], agent_config.get("tools", []))
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"

View File

@ -2,6 +2,7 @@ import io
import logging
import uuid
from collections.abc import Generator
from typing import cast
from flask import Response, stream_with_context
from werkzeug.datastructures import FileStorage
@ -106,7 +107,7 @@ class AudioService:
if not text_to_speech_dict.get("enabled"):
raise ValueError("TTS is not enabled")
voice = text_to_speech_dict.get("voice")
voice = cast(str | None, text_to_speech_dict.get("voice"))
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(