From 741d48560da5d22a868ab5a249650c4aaa5e55b8 Mon Sep 17 00:00:00 2001 From: statxc <181730535+statxc@users.noreply.github.com> Date: Fri, 6 Mar 2026 01:42:54 +0200 Subject: [PATCH] refactor(api): add TypedDict definitions to models/model.py (#32925) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/explore/parameter.py | 6 +- api/controllers/service_api/app/app.py | 6 +- api/controllers/web/app.py | 5 +- .../sensitive_word_avoidance/manager.py | 7 +- .../easy_ui_based_app/agent/manager.py | 26 +- .../easy_ui_based_app/dataset/manager.py | 15 +- .../easy_ui_based_app/model_config/manager.py | 5 +- .../prompt_template/manager.py | 15 +- .../easy_ui_based_app/variables/manager.py | 14 +- api/core/app/app_config/entities.py | 2 +- .../app/apps/agent_chat/app_config_manager.py | 14 +- api/core/app/apps/chat/app_config_manager.py | 12 +- api/core/app/apps/chat/app_runner.py | 6 +- .../app/apps/completion/app_config_manager.py | 16 +- api/core/app/apps/completion/app_generator.py | 2 +- api/core/app/apps/completion/app_runner.py | 6 +- .../easy_ui_based_generate_task_pipeline.py | 6 +- api/core/plugin/backwards_invocation/app.py | 6 +- api/models/model.py | 408 +++++++++++++++--- api/services/app_dsl_service.py | 5 +- api/services/app_model_config_service.py | 4 +- api/services/app_service.py | 6 +- api/services/audio_service.py | 3 +- 23 files changed, 453 insertions(+), 142 deletions(-) diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 660a4d5aea..0f29627746 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,3 +1,5 @@ +from typing import Any, cast + from controllers.common import fields from controllers.console import console_ns from controllers.console.app.error import AppUnavailableError @@ -23,14 +25,14 @@ class AppParameterApi(InstalledAppResource): if workflow is None: raise AppUnavailableError() - features_dict = workflow.features_dict + features_dict: dict[str, Any] = workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config if app_model_config is None: raise AppUnavailableError() - features_dict = app_model_config.to_dict() + features_dict = cast(dict[str, Any], app_model_config.to_dict()) user_input_form = features_dict.get("user_input_form", []) diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 562f5e33cc..abcaa0e240 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,3 +1,5 @@ +from typing import Any, cast + from flask_restx import Resource from controllers.common.fields import Parameters @@ -33,14 +35,14 @@ class AppParameterApi(Resource): if workflow is None: raise AppUnavailableError() - features_dict = workflow.features_dict + features_dict: dict[str, Any] = workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config if app_model_config is None: raise AppUnavailableError() - features_dict = app_model_config.to_dict() + features_dict = cast(dict[str, Any], app_model_config.to_dict()) user_input_form = features_dict.get("user_input_form", []) diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 62ea532eac..25bbedce54 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,4 +1,5 @@ import logging +from typing import Any, cast from flask import request from flask_restx import Resource @@ -57,14 +58,14 @@ class AppParameterApi(WebApiResource): if workflow is None: raise AppUnavailableError() - features_dict = workflow.features_dict + features_dict: dict[str, Any] = workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config if app_model_config is None: raise AppUnavailableError() - features_dict = app_model_config.to_dict() + features_dict = cast(dict[str, Any], app_model_config.to_dict()) user_input_form = features_dict.get("user_input_form", []) diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py index e925d6dd52..7d1b11c008 100644 --- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -1,10 +1,13 @@ +from collections.abc import Mapping +from typing import Any + from core.app.app_config.entities import SensitiveWordAvoidanceEntity from core.moderation.factory import ModerationFactory class SensitiveWordAvoidanceConfigManager: @classmethod - def convert(cls, config: dict) -> SensitiveWordAvoidanceEntity | None: + def convert(cls, config: Mapping[str, Any]) -> SensitiveWordAvoidanceEntity | None: sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance") if not sensitive_word_avoidance_dict: return None @@ -12,7 +15,7 @@ class SensitiveWordAvoidanceConfigManager: if sensitive_word_avoidance_dict.get("enabled"): return SensitiveWordAvoidanceEntity( type=sensitive_word_avoidance_dict.get("type"), - config=sensitive_word_avoidance_dict.get("config"), + config=sensitive_word_avoidance_dict.get("config", {}), ) else: return None diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index 9b981dfc09..10db380d1f 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -1,10 +1,13 @@ +from typing import Any, cast + from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity from core.agent.prompt.template import REACT_PROMPT_TEMPLATES +from models.model import AppModelConfigDict class AgentConfigManager: @classmethod - def convert(cls, config: dict) -> AgentEntity | None: + def convert(cls, config: AppModelConfigDict) -> AgentEntity | None: """ Convert model config to model config @@ -28,17 +31,17 @@ class AgentConfigManager: agent_tools = [] for tool in agent_dict.get("tools", []): - keys = tool.keys() - if len(keys) >= 4: - if "enabled" not in tool or not tool["enabled"]: + tool_dict = cast(dict[str, Any], tool) + if len(tool_dict) >= 4: + if "enabled" not in tool_dict or not tool_dict["enabled"]: continue agent_tool_properties = { - "provider_type": tool["provider_type"], - "provider_id": tool["provider_id"], - "tool_name": tool["tool_name"], - "tool_parameters": tool.get("tool_parameters", {}), - "credential_id": tool.get("credential_id", None), + "provider_type": tool_dict["provider_type"], + "provider_id": tool_dict["provider_id"], + "tool_name": tool_dict["tool_name"], + "tool_parameters": tool_dict.get("tool_parameters", {}), + "credential_id": tool_dict.get("credential_id", None), } agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties)) @@ -47,7 +50,8 @@ class AgentConfigManager: "react_router", "router", }: - agent_prompt = agent_dict.get("prompt", None) or {} + agent_prompt_raw = agent_dict.get("prompt", None) + agent_prompt: dict[str, Any] = agent_prompt_raw if isinstance(agent_prompt_raw, dict) else {} # check model mode model_mode = config.get("model", {}).get("mode", "completion") if model_mode == "completion": @@ -75,7 +79,7 @@ class AgentConfigManager: strategy=strategy, prompt=agent_prompt_entity, tools=agent_tools, - max_iteration=agent_dict.get("max_iteration", 10), + max_iteration=cast(int, agent_dict.get("max_iteration", 10)), ) return None diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index aacafb2dad..70f43b2c83 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -1,5 +1,5 @@ import uuid -from typing import Literal, cast +from typing import Any, Literal, cast from core.app.app_config.entities import ( DatasetEntity, @@ -8,13 +8,13 @@ from core.app.app_config.entities import ( ModelConfig, ) from core.entities.agent_entities import PlanningStrategy -from models.model import AppMode +from models.model import AppMode, AppModelConfigDict from services.dataset_service import DatasetService class DatasetConfigManager: @classmethod - def convert(cls, config: dict) -> DatasetEntity | None: + def convert(cls, config: AppModelConfigDict) -> DatasetEntity | None: """ Convert model config to model config @@ -25,11 +25,15 @@ class DatasetConfigManager: datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []}) for dataset in datasets.get("datasets", []): + if not isinstance(dataset, dict): + continue keys = list(dataset.keys()) if len(keys) == 0 or keys[0] != "dataset": continue dataset = dataset["dataset"] + if not isinstance(dataset, dict): + continue if "enabled" not in dataset or not dataset["enabled"]: continue @@ -47,15 +51,14 @@ class DatasetConfigManager: agent_dict = config.get("agent_mode", {}) for tool in agent_dict.get("tools", []): - keys = tool.keys() - if len(keys) == 1: + if len(tool) == 1: # old standard key = list(tool.keys())[0] if key != "dataset": continue - tool_item = tool[key] + tool_item = cast(dict[str, Any], tool)[key] if "enabled" not in tool_item or not tool_item["enabled"]: continue diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index e4e750c735..0929f52e33 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -5,12 +5,13 @@ from core.app.app_config.entities import ModelConfigEntity from core.provider_manager import ProviderManager from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from models.model import AppModelConfigDict from models.provider_ids import ModelProviderID class ModelConfigManager: @classmethod - def convert(cls, config: dict) -> ModelConfigEntity: + def convert(cls, config: AppModelConfigDict) -> ModelConfigEntity: """ Convert model config to model config @@ -22,7 +23,7 @@ class ModelConfigManager: if not model_config: raise ValueError("model is required") - completion_params = model_config.get("completion_params") + completion_params = model_config.get("completion_params") or {} stop = [] if "stop" in completion_params: stop = completion_params["stop"] diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 01b9601965..b7073898d6 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -1,3 +1,5 @@ +from typing import Any + from core.app.app_config.entities import ( AdvancedChatMessageEntity, AdvancedChatPromptTemplateEntity, @@ -6,12 +8,12 @@ from core.app.app_config.entities import ( ) from core.prompt.simple_prompt_transform import ModelMode from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from models.model import AppMode +from models.model import AppMode, AppModelConfigDict class PromptTemplateConfigManager: @classmethod - def convert(cls, config: dict) -> PromptTemplateEntity: + def convert(cls, config: AppModelConfigDict) -> PromptTemplateEntity: if not config.get("prompt_type"): raise ValueError("prompt_type is required") @@ -40,14 +42,15 @@ class PromptTemplateConfigManager: advanced_completion_prompt_template = None completion_prompt_config = config.get("completion_prompt_config", {}) if completion_prompt_config: - completion_prompt_template_params = { + completion_prompt_template_params: dict[str, Any] = { "prompt": completion_prompt_config["prompt"]["text"], } - if "conversation_histories_role" in completion_prompt_config: + conv_role = completion_prompt_config.get("conversation_histories_role") + if conv_role: completion_prompt_template_params["role_prefix"] = { - "user": completion_prompt_config["conversation_histories_role"]["user_prefix"], - "assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"], + "user": conv_role["user_prefix"], + "assistant": conv_role["assistant_prefix"], } advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity( diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index 157e5d8bc0..8de1224a89 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -1,8 +1,10 @@ import re +from typing import cast from core.app.app_config.entities import ExternalDataVariableEntity from core.external_data_tool.factory import ExternalDataToolFactory from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from models.model import AppModelConfigDict _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset( [ @@ -18,7 +20,7 @@ _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset( class BasicVariablesConfigManager: @classmethod - def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]: + def convert(cls, config: AppModelConfigDict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]: """ Convert model config to model config @@ -51,7 +53,9 @@ class BasicVariablesConfigManager: external_data_variables.append( ExternalDataVariableEntity( - variable=variable["variable"], type=variable["type"], config=variable["config"] + variable=variable["variable"], + type=variable.get("type", ""), + config=variable.get("config", {}), ) ) elif variable_type in { @@ -64,10 +68,10 @@ class BasicVariablesConfigManager: variable = variables[variable_type] variable_entities.append( VariableEntity( - type=variable_type, - variable=variable.get("variable"), + type=cast(VariableEntityType, variable_type), + variable=variable["variable"], description=variable.get("description") or "", - label=variable.get("label"), + label=variable["label"], required=variable.get("required", False), max_length=variable.get("max_length"), options=variable.get("options") or [], diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index f26351d93e..ac21577d57 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -281,7 +281,7 @@ class EasyUIBasedAppConfig(AppConfig): app_model_config_from: EasyUIBasedAppModelConfigFrom app_model_config_id: str - app_model_config_dict: dict + app_model_config_dict: dict[str, Any] model: ModelConfigEntity prompt_template: PromptTemplateEntity dataset: DatasetEntity | None = None diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index 801619ddbc..f0d81e0c59 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -20,7 +20,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor ) from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager from core.entities.agent_entities import PlanningStrategy -from models.model import App, AppMode, AppModelConfig, Conversation +from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] @@ -40,7 +40,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): app_model: App, app_model_config: AppModelConfig, conversation: Conversation | None = None, - override_config_dict: dict | None = None, + override_config_dict: AppModelConfigDict | None = None, ) -> AgentChatAppConfig: """ Convert app model config to agent chat app config @@ -61,7 +61,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager): app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: - config_dict = override_config_dict or {} + if not override_config_dict: + raise Exception("override_config_dict is required when config_from is ARGS") + config_dict = override_config_dict app_mode = AppMode.value_of(app_model.mode) app_config = AgentChatAppConfig( @@ -70,7 +72,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): app_mode=app_mode, app_model_config_from=config_from, app_model_config_id=app_model_config.id, - app_model_config_dict=config_dict, + app_model_config_dict=cast(dict[str, Any], config_dict), model=ModelConfigManager.convert(config=config_dict), prompt_template=PromptTemplateConfigManager.convert(config=config_dict), sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), @@ -86,7 +88,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): return app_config @classmethod - def config_validate(cls, tenant_id: str, config: Mapping[str, Any]): + def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> AppModelConfigDict: """ Validate for agent chat app model config @@ -157,7 +159,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): # Filter out extra parameters filtered_config = {key: config.get(key) for key in related_config_keys} - return filtered_config + return cast(AppModelConfigDict, filtered_config) @classmethod def validate_agent_mode_and_set_defaults( diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py index 4b6720a3c3..5f087f6066 100644 --- a/api/core/app/apps/chat/app_config_manager.py +++ b/api/core/app/apps/chat/app_config_manager.py @@ -1,3 +1,5 @@ +from typing import Any, cast + from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager @@ -13,7 +15,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor SuggestedQuestionsAfterAnswerConfigManager, ) from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager -from models.model import App, AppMode, AppModelConfig, Conversation +from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation class ChatAppConfig(EasyUIBasedAppConfig): @@ -31,7 +33,7 @@ class ChatAppConfigManager(BaseAppConfigManager): app_model: App, app_model_config: AppModelConfig, conversation: Conversation | None = None, - override_config_dict: dict | None = None, + override_config_dict: AppModelConfigDict | None = None, ) -> ChatAppConfig: """ Convert app model config to chat app config @@ -64,7 +66,7 @@ class ChatAppConfigManager(BaseAppConfigManager): app_mode=app_mode, app_model_config_from=config_from, app_model_config_id=app_model_config.id, - app_model_config_dict=config_dict, + app_model_config_dict=cast(dict[str, Any], config_dict), model=ModelConfigManager.convert(config=config_dict), prompt_template=PromptTemplateConfigManager.convert(config=config_dict), sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), @@ -79,7 +81,7 @@ class ChatAppConfigManager(BaseAppConfigManager): return app_config @classmethod - def config_validate(cls, tenant_id: str, config: dict): + def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict: """ Validate for chat app model config @@ -145,4 +147,4 @@ class ChatAppConfigManager(BaseAppConfigManager): # Filter out extra parameters filtered_config = {key: config.get(key) for key in related_config_keys} - return filtered_config + return cast(AppModelConfigDict, filtered_config) diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 23546a47bb..f63b38fc86 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -173,8 +173,10 @@ class ChatAppRunner(AppRunner): memory=memory, message_id=message.id, inputs=inputs, - vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get( - "enabled", False + vision_enabled=bool( + application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}) + .get("image", {}) + .get("enabled", False) ), ) context_files = retrieved_files or [] diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index eb1902f12e..f49e7b8b5e 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -1,3 +1,5 @@ +from typing import Any, cast + from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager @@ -8,7 +10,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager -from models.model import App, AppMode, AppModelConfig +from models.model import App, AppMode, AppModelConfig, AppModelConfigDict class CompletionAppConfig(EasyUIBasedAppConfig): @@ -22,7 +24,7 @@ class CompletionAppConfig(EasyUIBasedAppConfig): class CompletionAppConfigManager(BaseAppConfigManager): @classmethod def get_app_config( - cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: dict | None = None + cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: AppModelConfigDict | None = None ) -> CompletionAppConfig: """ Convert app model config to completion app config @@ -40,7 +42,9 @@ class CompletionAppConfigManager(BaseAppConfigManager): app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: - config_dict = override_config_dict or {} + if not override_config_dict: + raise Exception("override_config_dict is required when config_from is ARGS") + config_dict = override_config_dict app_mode = AppMode.value_of(app_model.mode) app_config = CompletionAppConfig( @@ -49,7 +53,7 @@ class CompletionAppConfigManager(BaseAppConfigManager): app_mode=app_mode, app_model_config_from=config_from, app_model_config_id=app_model_config.id, - app_model_config_dict=config_dict, + app_model_config_dict=cast(dict[str, Any], config_dict), model=ModelConfigManager.convert(config=config_dict), prompt_template=PromptTemplateConfigManager.convert(config=config_dict), sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), @@ -64,7 +68,7 @@ class CompletionAppConfigManager(BaseAppConfigManager): return app_config @classmethod - def config_validate(cls, tenant_id: str, config: dict): + def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict: """ Validate for completion app model config @@ -116,4 +120,4 @@ class CompletionAppConfigManager(BaseAppConfigManager): # Filter out extra parameters filtered_config = {key: config.get(key) for key in related_config_keys} - return filtered_config + return cast(AppModelConfigDict, filtered_config) diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index e8b0e4f179..002b914ef1 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -275,7 +275,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): raise ValueError("Message app_model_config is None") override_model_config_dict = app_model_config.to_dict() model_dict = override_model_config_dict["model"] - completion_params = model_dict.get("completion_params") + completion_params = model_dict.get("completion_params", {}) completion_params["temperature"] = 0.9 model_dict["completion_params"] = completion_params override_model_config_dict["model"] = model_dict diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index ac05172945..56a4519879 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -132,8 +132,10 @@ class CompletionAppRunner(AppRunner): hit_callback=hit_callback, message_id=message.id, inputs=inputs, - vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get( - "enabled", False + vision_enabled=bool( + application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}) + .get("image", {}) + .get("enabled", False) ), ) context_files = retrieved_files or [] diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 1fa782eb6c..57ef0c078f 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -2,7 +2,7 @@ import logging import time from collections.abc import Generator from threading import Thread -from typing import Union, cast +from typing import Any, Union, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -219,14 +219,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): tenant_id = self._application_generate_entity.app_config.tenant_id task_id = self._application_generate_entity.task_id publisher = None - text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech") + text_to_speech_dict = cast(dict[str, Any], self._app_config.app_model_config_dict.get("text_to_speech")) if ( text_to_speech_dict and text_to_speech_dict.get("autoPlay") == "enabled" and text_to_speech_dict.get("enabled") ): publisher = AppGeneratorTTSPublisher( - tenant_id, text_to_speech_dict.get("voice", None), text_to_speech_dict.get("language", None) + tenant_id, text_to_speech_dict.get("voice", ""), text_to_speech_dict.get("language", None) ) for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): while True: diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 3c5df2b905..60d08b26c9 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Generator, Mapping -from typing import Union +from typing import Any, Union, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -34,14 +34,14 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): if workflow is None: raise ValueError("unexpected app type") - features_dict = workflow.features_dict + features_dict: dict[str, Any] = workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app.app_model_config if app_model_config is None: raise ValueError("unexpected app type") - features_dict = app_model_config.to_dict() + features_dict = cast(dict[str, Any], app_model_config.to_dict()) user_input_form = features_dict.get("user_input_form", []) diff --git a/api/models/model.py b/api/models/model.py index 2bf80edb80..ed0614c195 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -7,7 +7,7 @@ from collections.abc import Mapping, Sequence from datetime import datetime from decimal import Decimal from enum import StrEnum, auto -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast from uuid import uuid4 import sqlalchemy as sa @@ -15,6 +15,7 @@ from flask import request from flask_login import UserMixin # type: ignore[import-untyped] from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy.orm import Mapped, Session, mapped_column +from typing_extensions import TypedDict from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS @@ -36,6 +37,259 @@ if TYPE_CHECKING: from .workflow import Workflow +# --- TypedDict definitions for structured dict return types --- + + +class EnabledConfig(TypedDict): + enabled: bool + + +class EmbeddingModelInfo(TypedDict): + embedding_provider_name: str + embedding_model_name: str + + +class AnnotationReplyDisabledConfig(TypedDict): + enabled: Literal[False] + + +class AnnotationReplyEnabledConfig(TypedDict): + id: str + enabled: Literal[True] + score_threshold: float + embedding_model: EmbeddingModelInfo + + +AnnotationReplyConfig = AnnotationReplyEnabledConfig | AnnotationReplyDisabledConfig + + +class SensitiveWordAvoidanceConfig(TypedDict): + enabled: bool + type: str + config: dict[str, Any] + + +class AgentToolConfig(TypedDict): + provider_type: str + provider_id: str + tool_name: str + tool_parameters: dict[str, Any] + plugin_unique_identifier: NotRequired[str | None] + credential_id: NotRequired[str | None] + + +class AgentModeConfig(TypedDict): + enabled: bool + strategy: str | None + tools: list[AgentToolConfig | dict[str, Any]] + prompt: str | None + + +class ImageUploadConfig(TypedDict): + enabled: bool + number_limits: int + detail: str + transfer_methods: list[str] + + +class FileUploadConfig(TypedDict): + image: ImageUploadConfig + + +class DeletedToolInfo(TypedDict): + type: str + tool_name: str + provider_id: str + + +class ExternalDataToolConfig(TypedDict): + enabled: bool + variable: str + type: str + config: dict[str, Any] + + +class UserInputFormItemConfig(TypedDict): + variable: str + label: str + description: NotRequired[str] + required: NotRequired[bool] + max_length: NotRequired[int] + options: NotRequired[list[str]] + default: NotRequired[str] + type: NotRequired[str] + config: NotRequired[dict[str, Any]] + + +# Each item is a single-key dict, e.g. {"text-input": UserInputFormItemConfig} +UserInputFormItem = dict[str, UserInputFormItemConfig] + + +class DatasetConfigs(TypedDict): + retrieval_model: str + datasets: NotRequired[dict[str, Any]] + top_k: NotRequired[int] + score_threshold: NotRequired[float] + score_threshold_enabled: NotRequired[bool] + reranking_model: NotRequired[dict[str, Any] | None] + weights: NotRequired[dict[str, Any] | None] + reranking_enabled: NotRequired[bool] + reranking_mode: NotRequired[str] + metadata_filtering_mode: NotRequired[str] + metadata_model_config: NotRequired[dict[str, Any] | None] + metadata_filtering_conditions: NotRequired[dict[str, Any] | None] + + +class ChatPromptMessage(TypedDict): + text: str + role: str + + +class ChatPromptConfig(TypedDict, total=False): + prompt: list[ChatPromptMessage] + + +class CompletionPromptText(TypedDict): + text: str + + +class ConversationHistoriesRole(TypedDict): + user_prefix: str + assistant_prefix: str + + +class CompletionPromptConfig(TypedDict): + prompt: CompletionPromptText + conversation_histories_role: NotRequired[ConversationHistoriesRole] + + +class ModelConfig(TypedDict): + provider: str + name: str + mode: str + completion_params: NotRequired[dict[str, Any]] + + +class AppModelConfigDict(TypedDict): + opening_statement: str | None + suggested_questions: list[str] + suggested_questions_after_answer: EnabledConfig + speech_to_text: EnabledConfig + text_to_speech: EnabledConfig + retriever_resource: EnabledConfig + annotation_reply: AnnotationReplyConfig + more_like_this: EnabledConfig + sensitive_word_avoidance: SensitiveWordAvoidanceConfig + external_data_tools: list[ExternalDataToolConfig] + model: ModelConfig + user_input_form: list[UserInputFormItem] + dataset_query_variable: str | None + pre_prompt: str | None + agent_mode: AgentModeConfig + prompt_type: str + chat_prompt_config: ChatPromptConfig + completion_prompt_config: CompletionPromptConfig + dataset_configs: DatasetConfigs + file_upload: FileUploadConfig + # Added dynamically in Conversation.model_config + model_id: NotRequired[str | None] + provider: NotRequired[str | None] + + +class ConversationDict(TypedDict): + id: str + app_id: str + app_model_config_id: str | None + model_provider: str | None + override_model_configs: str | None + model_id: str | None + mode: str + name: str + summary: str | None + inputs: dict[str, Any] + introduction: str | None + system_instruction: str | None + system_instruction_tokens: int + status: str + invoke_from: str | None + from_source: str + from_end_user_id: str | None + from_account_id: str | None + read_at: datetime | None + read_account_id: str | None + dialogue_count: int + created_at: datetime + updated_at: datetime + + +class MessageDict(TypedDict): + id: str + app_id: str + conversation_id: str + model_id: str | None + inputs: dict[str, Any] + query: str + total_price: Decimal | None + message: dict[str, Any] + answer: str + status: str + error: str | None + message_metadata: dict[str, Any] + from_source: str + from_end_user_id: str | None + from_account_id: str | None + created_at: str + updated_at: str + agent_based: bool + workflow_run_id: str | None + + +class MessageFeedbackDict(TypedDict): + id: str + app_id: str + conversation_id: str + message_id: str + rating: str + content: str | None + from_source: str + from_end_user_id: str | None + from_account_id: str | None + created_at: str + updated_at: str + + +class MessageFileInfo(TypedDict, total=False): + belongs_to: str | None + upload_file_id: str | None + id: str + tenant_id: str + type: str + transfer_method: str + remote_url: str | None + related_id: str | None + filename: str | None + extension: str | None + mime_type: str | None + size: int + dify_model_identity: str + url: str | None + + +class ExtraContentDict(TypedDict, total=False): + type: str + workflow_run_id: str + + +class TraceAppConfigDict(TypedDict): + id: str + app_id: str + tracing_provider: str | None + tracing_config: dict[str, Any] + is_active: bool + created_at: str | None + updated_at: str | None + + class DifySetup(TypeBase): __tablename__ = "dify_setups" __table_args__ = (sa.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) @@ -176,7 +430,7 @@ class App(Base): return str(self.mode) @property - def deleted_tools(self) -> list[dict[str, str]]: + def deleted_tools(self) -> list[DeletedToolInfo]: from core.tools.tool_manager import ToolManager, ToolProviderType from services.plugin.plugin_service import PluginService @@ -257,7 +511,7 @@ class App(Base): provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids) } - deleted_tools: list[dict[str, str]] = [] + deleted_tools: list[DeletedToolInfo] = [] for tool in tools: keys = list(tool.keys()) @@ -364,35 +618,38 @@ class AppModelConfig(TypeBase): return app @property - def model_dict(self) -> dict[str, Any]: - return json.loads(self.model) if self.model else {} + def model_dict(self) -> ModelConfig: + return cast(ModelConfig, json.loads(self.model) if self.model else {}) @property def suggested_questions_list(self) -> list[str]: return json.loads(self.suggested_questions) if self.suggested_questions else [] @property - def suggested_questions_after_answer_dict(self) -> dict[str, Any]: - return ( + def suggested_questions_after_answer_dict(self) -> EnabledConfig: + return cast( + EnabledConfig, json.loads(self.suggested_questions_after_answer) if self.suggested_questions_after_answer - else {"enabled": False} + else {"enabled": False}, ) @property - def speech_to_text_dict(self) -> dict[str, Any]: - return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False} + def speech_to_text_dict(self) -> EnabledConfig: + return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}) @property - def text_to_speech_dict(self) -> dict[str, Any]: - return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False} + def text_to_speech_dict(self) -> EnabledConfig: + return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}) @property - def retriever_resource_dict(self) -> dict[str, Any]: - return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} + def retriever_resource_dict(self) -> EnabledConfig: + return cast( + EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} + ) @property - def annotation_reply_dict(self) -> dict[str, Any]: + def annotation_reply_dict(self) -> AnnotationReplyConfig: annotation_setting = ( db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first() ) @@ -415,56 +672,62 @@ class AppModelConfig(TypeBase): return {"enabled": False} @property - def more_like_this_dict(self) -> dict[str, Any]: - return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} + def more_like_this_dict(self) -> EnabledConfig: + return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}) @property - def sensitive_word_avoidance_dict(self) -> dict[str, Any]: - return ( + def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig: + return cast( + SensitiveWordAvoidanceConfig, json.loads(self.sensitive_word_avoidance) if self.sensitive_word_avoidance - else {"enabled": False, "type": "", "configs": []} + else {"enabled": False, "type": "", "config": {}}, ) @property - def external_data_tools_list(self) -> list[dict[str, Any]]: + def external_data_tools_list(self) -> list[ExternalDataToolConfig]: return json.loads(self.external_data_tools) if self.external_data_tools else [] @property - def user_input_form_list(self) -> list[dict[str, Any]]: + def user_input_form_list(self) -> list[UserInputFormItem]: return json.loads(self.user_input_form) if self.user_input_form else [] @property - def agent_mode_dict(self) -> dict[str, Any]: - return ( + def agent_mode_dict(self) -> AgentModeConfig: + return cast( + AgentModeConfig, json.loads(self.agent_mode) if self.agent_mode - else {"enabled": False, "strategy": None, "tools": [], "prompt": None} + else {"enabled": False, "strategy": None, "tools": [], "prompt": None}, ) @property - def chat_prompt_config_dict(self) -> dict[str, Any]: - return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {} + def chat_prompt_config_dict(self) -> ChatPromptConfig: + return cast(ChatPromptConfig, json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}) @property - def completion_prompt_config_dict(self) -> dict[str, Any]: - return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {} + def completion_prompt_config_dict(self) -> CompletionPromptConfig: + return cast( + CompletionPromptConfig, + json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}, + ) @property - def dataset_configs_dict(self) -> dict[str, Any]: + def dataset_configs_dict(self) -> DatasetConfigs: if self.dataset_configs: - dataset_configs: dict[str, Any] = json.loads(self.dataset_configs) + dataset_configs = json.loads(self.dataset_configs) if "retrieval_model" not in dataset_configs: return {"retrieval_model": "single"} else: - return dataset_configs + return cast(DatasetConfigs, dataset_configs) return { "retrieval_model": "multiple", } @property - def file_upload_dict(self) -> dict[str, Any]: - return ( + def file_upload_dict(self) -> FileUploadConfig: + return cast( + FileUploadConfig, json.loads(self.file_upload) if self.file_upload else { @@ -474,10 +737,10 @@ class AppModelConfig(TypeBase): "detail": "high", "transfer_methods": ["remote_url", "local_file"], } - } + }, ) - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> AppModelConfigDict: return { "opening_statement": self.opening_statement, "suggested_questions": self.suggested_questions_list, @@ -501,36 +764,42 @@ class AppModelConfig(TypeBase): "file_upload": self.file_upload_dict, } - def from_model_config_dict(self, model_config: Mapping[str, Any]): + def from_model_config_dict(self, model_config: AppModelConfigDict): self.opening_statement = model_config.get("opening_statement") self.suggested_questions = ( - json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None + json.dumps(model_config.get("suggested_questions")) if model_config.get("suggested_questions") else None ) self.suggested_questions_after_answer = ( - json.dumps(model_config["suggested_questions_after_answer"]) + json.dumps(model_config.get("suggested_questions_after_answer")) if model_config.get("suggested_questions_after_answer") else None ) - self.speech_to_text = json.dumps(model_config["speech_to_text"]) if model_config.get("speech_to_text") else None - self.text_to_speech = json.dumps(model_config["text_to_speech"]) if model_config.get("text_to_speech") else None - self.more_like_this = json.dumps(model_config["more_like_this"]) if model_config.get("more_like_this") else None + self.speech_to_text = ( + json.dumps(model_config.get("speech_to_text")) if model_config.get("speech_to_text") else None + ) + self.text_to_speech = ( + json.dumps(model_config.get("text_to_speech")) if model_config.get("text_to_speech") else None + ) + self.more_like_this = ( + json.dumps(model_config.get("more_like_this")) if model_config.get("more_like_this") else None + ) self.sensitive_word_avoidance = ( - json.dumps(model_config["sensitive_word_avoidance"]) + json.dumps(model_config.get("sensitive_word_avoidance")) if model_config.get("sensitive_word_avoidance") else None ) self.external_data_tools = ( - json.dumps(model_config["external_data_tools"]) if model_config.get("external_data_tools") else None + json.dumps(model_config.get("external_data_tools")) if model_config.get("external_data_tools") else None ) - self.model = json.dumps(model_config["model"]) if model_config.get("model") else None + self.model = json.dumps(model_config.get("model")) if model_config.get("model") else None self.user_input_form = ( - json.dumps(model_config["user_input_form"]) if model_config.get("user_input_form") else None + json.dumps(model_config.get("user_input_form")) if model_config.get("user_input_form") else None ) self.dataset_query_variable = model_config.get("dataset_query_variable") - self.pre_prompt = model_config["pre_prompt"] - self.agent_mode = json.dumps(model_config["agent_mode"]) if model_config.get("agent_mode") else None + self.pre_prompt = model_config.get("pre_prompt") + self.agent_mode = json.dumps(model_config.get("agent_mode")) if model_config.get("agent_mode") else None self.retriever_resource = ( - json.dumps(model_config["retriever_resource"]) if model_config.get("retriever_resource") else None + json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None ) self.prompt_type = model_config.get("prompt_type", "simple") self.chat_prompt_config = ( @@ -823,24 +1092,26 @@ class Conversation(Base): self._inputs = inputs @property - def model_config(self): - model_config = {} + def model_config(self) -> AppModelConfigDict: + model_config = cast(AppModelConfigDict, {}) app_model_config: AppModelConfig | None = None if self.mode == AppMode.ADVANCED_CHAT: if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) - model_config = override_model_configs + model_config = cast(AppModelConfigDict, override_model_configs) else: if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) if "model" in override_model_configs: # where is app_id? - app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(override_model_configs) + app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict( + cast(AppModelConfigDict, override_model_configs) + ) model_config = app_model_config.to_dict() else: - model_config["configs"] = override_model_configs + model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key] else: app_model_config = ( db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() @@ -1015,7 +1286,7 @@ class Conversation(Base): def in_debug_mode(self) -> bool: return self.override_model_configs is not None - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> ConversationDict: return { "id": self.id, "app_id": self.app_id, @@ -1295,7 +1566,7 @@ class Message(Base): return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else [] @property - def message_files(self) -> list[dict[str, Any]]: + def message_files(self) -> list[MessageFileInfo]: from factories import file_factory message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all() @@ -1350,10 +1621,13 @@ class Message(Base): ) files.append(file) - result: list[dict[str, Any]] = [ - {"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()} - for (file, message_file) in zip(files, message_files) - ] + result = cast( + list[MessageFileInfo], + [ + {"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()} + for (file, message_file) in zip(files, message_files) + ], + ) db.session.commit() return result @@ -1363,7 +1637,7 @@ class Message(Base): self._extra_contents = list(contents) @property - def extra_contents(self) -> list[dict[str, Any]]: + def extra_contents(self) -> list[ExtraContentDict]: return getattr(self, "_extra_contents", []) @property @@ -1379,7 +1653,7 @@ class Message(Base): return None - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> MessageDict: return { "id": self.id, "app_id": self.app_id, @@ -1403,7 +1677,7 @@ class Message(Base): } @classmethod - def from_dict(cls, data: dict[str, Any]) -> Message: + def from_dict(cls, data: MessageDict) -> Message: return cls( id=data["id"], app_id=data["app_id"], @@ -1463,7 +1737,7 @@ class MessageFeedback(TypeBase): account = db.session.query(Account).where(Account.id == self.from_account_id).first() return account - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> MessageFeedbackDict: return { "id": str(self.id), "app_id": str(self.app_id), @@ -1726,8 +2000,8 @@ class AppMCPServer(TypeBase): return result @property - def parameters_dict(self) -> dict[str, Any]: - return cast(dict[str, Any], json.loads(self.parameters)) + def parameters_dict(self) -> dict[str, str]: + return cast(dict[str, str], json.loads(self.parameters)) class Site(Base): @@ -2167,7 +2441,7 @@ class TraceAppConfig(TypeBase): def tracing_config_str(self) -> str: return json.dumps(self.tracing_config_dict) - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> TraceAppConfigDict: return { "id": self.id, "app_id": self.app_id, diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 5790c8b9ec..06f4ccb90e 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -4,6 +4,7 @@ import logging import uuid from collections.abc import Mapping from enum import StrEnum +from typing import cast from urllib.parse import urlparse from uuid import uuid4 @@ -32,7 +33,7 @@ from extensions.ext_redis import redis_client from factories import variable_factory from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode -from models.model import AppModelConfig, IconType +from models.model import AppModelConfig, AppModelConfigDict, IconType from models.workflow import Workflow from services.plugin.dependencies_analysis import DependenciesAnalysisService from services.workflow_draft_variable_service import WorkflowDraftVariableService @@ -523,7 +524,7 @@ class AppDslService: if not app.app_model_config: app_model_config = AppModelConfig( app_id=app.id, created_by=account.id, updated_by=account.id - ).from_model_config_dict(model_config) + ).from_model_config_dict(cast(AppModelConfigDict, model_config)) app_model_config.id = str(uuid4()) app.app_model_config_id = app_model_config.id diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 6f54f90734..3bc30cb323 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -1,12 +1,12 @@ from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager -from models.model import AppMode +from models.model import AppMode, AppModelConfigDict class AppModelConfigService: @classmethod - def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode): + def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict: if app_mode == AppMode.CHAT: return ChatAppConfigManager.config_validate(tenant_id, config) elif app_mode == AppMode.AGENT_CHAT: diff --git a/api/services/app_service.py b/api/services/app_service.py index ce6826ef5c..aba8954f1a 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -1,6 +1,6 @@ import json import logging -from typing import TypedDict, cast +from typing import Any, TypedDict, cast import sqlalchemy as sa from flask_sqlalchemy.pagination import Pagination @@ -187,7 +187,7 @@ class AppService: for tool in agent_mode.get("tools") or []: if not isinstance(tool, dict) or len(tool.keys()) <= 3: continue - agent_tool_entity = AgentToolEntity(**tool) + agent_tool_entity = AgentToolEntity(**cast(dict[str, Any], tool)) # get tool try: tool_runtime = ToolManager.get_agent_tool_runtime( @@ -388,7 +388,7 @@ class AppService: agent_config = app_model_config.agent_mode_dict # get all tools - tools = agent_config.get("tools", []) + tools = cast(list[dict[str, Any]], agent_config.get("tools", [])) url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/" diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 1b698fad17..1794ea9947 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -2,6 +2,7 @@ import io import logging import uuid from collections.abc import Generator +from typing import cast from flask import Response, stream_with_context from werkzeug.datastructures import FileStorage @@ -106,7 +107,7 @@ class AudioService: if not text_to_speech_dict.get("enabled"): raise ValueError("TTS is not enabled") - voice = text_to_speech_dict.get("voice") + voice = cast(str | None, text_to_speech_dict.get("voice")) model_manager = ModelManager() model_instance = model_manager.get_default_model_instance(