mirror of https://github.com/langgenius/dify.git
refactor(api): add TypedDict definitions to models/model.py (#32925)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
6bd1be9e16
commit
741d48560d
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Any, cast
|
||||
|
||||
from controllers.common import fields
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import AppUnavailableError
|
||||
|
|
@ -23,14 +25,14 @@ class AppParameterApi(InstalledAppResource):
|
|||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
features_dict: dict[str, Any] = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Any, cast
|
||||
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.common.fields import Parameters
|
||||
|
|
@ -33,14 +35,14 @@ class AppParameterApi(Resource):
|
|||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
features_dict: dict[str, Any] = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
|
|
@ -57,14 +58,14 @@ class AppParameterApi(WebApiResource):
|
|||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
features_dict: dict[str, Any] = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import SensitiveWordAvoidanceEntity
|
||||
from core.moderation.factory import ModerationFactory
|
||||
|
||||
|
||||
class SensitiveWordAvoidanceConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> SensitiveWordAvoidanceEntity | None:
|
||||
def convert(cls, config: Mapping[str, Any]) -> SensitiveWordAvoidanceEntity | None:
|
||||
sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
|
||||
if not sensitive_word_avoidance_dict:
|
||||
return None
|
||||
|
|
@ -12,7 +15,7 @@ class SensitiveWordAvoidanceConfigManager:
|
|||
if sensitive_word_avoidance_dict.get("enabled"):
|
||||
return SensitiveWordAvoidanceEntity(
|
||||
type=sensitive_word_avoidance_dict.get("type"),
|
||||
config=sensitive_word_avoidance_dict.get("config"),
|
||||
config=sensitive_word_avoidance_dict.get("config", {}),
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
from typing import Any, cast
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity
|
||||
from core.agent.prompt.template import REACT_PROMPT_TEMPLATES
|
||||
from models.model import AppModelConfigDict
|
||||
|
||||
|
||||
class AgentConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> AgentEntity | None:
|
||||
def convert(cls, config: AppModelConfigDict) -> AgentEntity | None:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
|
|
@ -28,17 +31,17 @@ class AgentConfigManager:
|
|||
|
||||
agent_tools = []
|
||||
for tool in agent_dict.get("tools", []):
|
||||
keys = tool.keys()
|
||||
if len(keys) >= 4:
|
||||
if "enabled" not in tool or not tool["enabled"]:
|
||||
tool_dict = cast(dict[str, Any], tool)
|
||||
if len(tool_dict) >= 4:
|
||||
if "enabled" not in tool_dict or not tool_dict["enabled"]:
|
||||
continue
|
||||
|
||||
agent_tool_properties = {
|
||||
"provider_type": tool["provider_type"],
|
||||
"provider_id": tool["provider_id"],
|
||||
"tool_name": tool["tool_name"],
|
||||
"tool_parameters": tool.get("tool_parameters", {}),
|
||||
"credential_id": tool.get("credential_id", None),
|
||||
"provider_type": tool_dict["provider_type"],
|
||||
"provider_id": tool_dict["provider_id"],
|
||||
"tool_name": tool_dict["tool_name"],
|
||||
"tool_parameters": tool_dict.get("tool_parameters", {}),
|
||||
"credential_id": tool_dict.get("credential_id", None),
|
||||
}
|
||||
|
||||
agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties))
|
||||
|
|
@ -47,7 +50,8 @@ class AgentConfigManager:
|
|||
"react_router",
|
||||
"router",
|
||||
}:
|
||||
agent_prompt = agent_dict.get("prompt", None) or {}
|
||||
agent_prompt_raw = agent_dict.get("prompt", None)
|
||||
agent_prompt: dict[str, Any] = agent_prompt_raw if isinstance(agent_prompt_raw, dict) else {}
|
||||
# check model mode
|
||||
model_mode = config.get("model", {}).get("mode", "completion")
|
||||
if model_mode == "completion":
|
||||
|
|
@ -75,7 +79,7 @@ class AgentConfigManager:
|
|||
strategy=strategy,
|
||||
prompt=agent_prompt_entity,
|
||||
tools=agent_tools,
|
||||
max_iteration=agent_dict.get("max_iteration", 10),
|
||||
max_iteration=cast(int, agent_dict.get("max_iteration", 10)),
|
||||
)
|
||||
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import uuid
|
||||
from typing import Literal, cast
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
|
|
@ -8,13 +8,13 @@ from core.app.app_config.entities import (
|
|||
ModelConfig,
|
||||
)
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from models.model import AppMode
|
||||
from models.model import AppMode, AppModelConfigDict
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
|
||||
class DatasetConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> DatasetEntity | None:
|
||||
def convert(cls, config: AppModelConfigDict) -> DatasetEntity | None:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
|
|
@ -25,11 +25,15 @@ class DatasetConfigManager:
|
|||
datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})
|
||||
|
||||
for dataset in datasets.get("datasets", []):
|
||||
if not isinstance(dataset, dict):
|
||||
continue
|
||||
keys = list(dataset.keys())
|
||||
if len(keys) == 0 or keys[0] != "dataset":
|
||||
continue
|
||||
|
||||
dataset = dataset["dataset"]
|
||||
if not isinstance(dataset, dict):
|
||||
continue
|
||||
|
||||
if "enabled" not in dataset or not dataset["enabled"]:
|
||||
continue
|
||||
|
|
@ -47,15 +51,14 @@ class DatasetConfigManager:
|
|||
agent_dict = config.get("agent_mode", {})
|
||||
|
||||
for tool in agent_dict.get("tools", []):
|
||||
keys = tool.keys()
|
||||
if len(keys) == 1:
|
||||
if len(tool) == 1:
|
||||
# old standard
|
||||
key = list(tool.keys())[0]
|
||||
|
||||
if key != "dataset":
|
||||
continue
|
||||
|
||||
tool_item = tool[key]
|
||||
tool_item = cast(dict[str, Any], tool)[key]
|
||||
|
||||
if "enabled" not in tool_item or not tool_item["enabled"]:
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -5,12 +5,13 @@ from core.app.app_config.entities import ModelConfigEntity
|
|||
from core.provider_manager import ProviderManager
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from models.model import AppModelConfigDict
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
|
||||
class ModelConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> ModelConfigEntity:
|
||||
def convert(cls, config: AppModelConfigDict) -> ModelConfigEntity:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
|
|
@ -22,7 +23,7 @@ class ModelConfigManager:
|
|||
if not model_config:
|
||||
raise ValueError("model is required")
|
||||
|
||||
completion_params = model_config.get("completion_params")
|
||||
completion_params = model_config.get("completion_params") or {}
|
||||
stop = []
|
||||
if "stop" in completion_params:
|
||||
stop = completion_params["stop"]
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
AdvancedChatMessageEntity,
|
||||
AdvancedChatPromptTemplateEntity,
|
||||
|
|
@ -6,12 +8,12 @@ from core.app.app_config.entities import (
|
|||
)
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from models.model import AppMode
|
||||
from models.model import AppMode, AppModelConfigDict
|
||||
|
||||
|
||||
class PromptTemplateConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> PromptTemplateEntity:
|
||||
def convert(cls, config: AppModelConfigDict) -> PromptTemplateEntity:
|
||||
if not config.get("prompt_type"):
|
||||
raise ValueError("prompt_type is required")
|
||||
|
||||
|
|
@ -40,14 +42,15 @@ class PromptTemplateConfigManager:
|
|||
advanced_completion_prompt_template = None
|
||||
completion_prompt_config = config.get("completion_prompt_config", {})
|
||||
if completion_prompt_config:
|
||||
completion_prompt_template_params = {
|
||||
completion_prompt_template_params: dict[str, Any] = {
|
||||
"prompt": completion_prompt_config["prompt"]["text"],
|
||||
}
|
||||
|
||||
if "conversation_histories_role" in completion_prompt_config:
|
||||
conv_role = completion_prompt_config.get("conversation_histories_role")
|
||||
if conv_role:
|
||||
completion_prompt_template_params["role_prefix"] = {
|
||||
"user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
|
||||
"assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
|
||||
"user": conv_role["user_prefix"],
|
||||
"assistant": conv_role["assistant_prefix"],
|
||||
}
|
||||
|
||||
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
import re
|
||||
from typing import cast
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
from dify_graph.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from models.model import AppModelConfigDict
|
||||
|
||||
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
|
||||
[
|
||||
|
|
@ -18,7 +20,7 @@ _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
|
|||
|
||||
class BasicVariablesConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]:
|
||||
def convert(cls, config: AppModelConfigDict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
|
|
@ -51,7 +53,9 @@ class BasicVariablesConfigManager:
|
|||
|
||||
external_data_variables.append(
|
||||
ExternalDataVariableEntity(
|
||||
variable=variable["variable"], type=variable["type"], config=variable["config"]
|
||||
variable=variable["variable"],
|
||||
type=variable.get("type", ""),
|
||||
config=variable.get("config", {}),
|
||||
)
|
||||
)
|
||||
elif variable_type in {
|
||||
|
|
@ -64,10 +68,10 @@ class BasicVariablesConfigManager:
|
|||
variable = variables[variable_type]
|
||||
variable_entities.append(
|
||||
VariableEntity(
|
||||
type=variable_type,
|
||||
variable=variable.get("variable"),
|
||||
type=cast(VariableEntityType, variable_type),
|
||||
variable=variable["variable"],
|
||||
description=variable.get("description") or "",
|
||||
label=variable.get("label"),
|
||||
label=variable["label"],
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options") or [],
|
||||
|
|
|
|||
|
|
@ -281,7 +281,7 @@ class EasyUIBasedAppConfig(AppConfig):
|
|||
|
||||
app_model_config_from: EasyUIBasedAppModelConfigFrom
|
||||
app_model_config_id: str
|
||||
app_model_config_dict: dict
|
||||
app_model_config_dict: dict[str, Any]
|
||||
model: ModelConfigEntity
|
||||
prompt_template: PromptTemplateEntity
|
||||
dataset: DatasetEntity | None = None
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
|
|||
)
|
||||
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from models.model import App, AppMode, AppModelConfig, Conversation
|
||||
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation
|
||||
|
||||
OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
|
||||
|
||||
|
|
@ -40,7 +40,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
|||
app_model: App,
|
||||
app_model_config: AppModelConfig,
|
||||
conversation: Conversation | None = None,
|
||||
override_config_dict: dict | None = None,
|
||||
override_config_dict: AppModelConfigDict | None = None,
|
||||
) -> AgentChatAppConfig:
|
||||
"""
|
||||
Convert app model config to agent chat app config
|
||||
|
|
@ -61,7 +61,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
|||
app_model_config_dict = app_model_config.to_dict()
|
||||
config_dict = app_model_config_dict.copy()
|
||||
else:
|
||||
config_dict = override_config_dict or {}
|
||||
if not override_config_dict:
|
||||
raise Exception("override_config_dict is required when config_from is ARGS")
|
||||
config_dict = override_config_dict
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
app_config = AgentChatAppConfig(
|
||||
|
|
@ -70,7 +72,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
|||
app_mode=app_mode,
|
||||
app_model_config_from=config_from,
|
||||
app_model_config_id=app_model_config.id,
|
||||
app_model_config_dict=config_dict,
|
||||
app_model_config_dict=cast(dict[str, Any], config_dict),
|
||||
model=ModelConfigManager.convert(config=config_dict),
|
||||
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
|
||||
|
|
@ -86,7 +88,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
|||
return app_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]):
|
||||
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> AppModelConfigDict:
|
||||
"""
|
||||
Validate for agent chat app model config
|
||||
|
||||
|
|
@ -157,7 +159,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
|||
# Filter out extra parameters
|
||||
filtered_config = {key: config.get(key) for key in related_config_keys}
|
||||
|
||||
return filtered_config
|
||||
return cast(AppModelConfigDict, filtered_config)
|
||||
|
||||
@classmethod
|
||||
def validate_agent_mode_and_set_defaults(
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Any, cast
|
||||
|
||||
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
|
||||
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
|
||||
|
|
@ -13,7 +15,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
|
|||
SuggestedQuestionsAfterAnswerConfigManager,
|
||||
)
|
||||
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
|
||||
from models.model import App, AppMode, AppModelConfig, Conversation
|
||||
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation
|
||||
|
||||
|
||||
class ChatAppConfig(EasyUIBasedAppConfig):
|
||||
|
|
@ -31,7 +33,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
|
|||
app_model: App,
|
||||
app_model_config: AppModelConfig,
|
||||
conversation: Conversation | None = None,
|
||||
override_config_dict: dict | None = None,
|
||||
override_config_dict: AppModelConfigDict | None = None,
|
||||
) -> ChatAppConfig:
|
||||
"""
|
||||
Convert app model config to chat app config
|
||||
|
|
@ -64,7 +66,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
|
|||
app_mode=app_mode,
|
||||
app_model_config_from=config_from,
|
||||
app_model_config_id=app_model_config.id,
|
||||
app_model_config_dict=config_dict,
|
||||
app_model_config_dict=cast(dict[str, Any], config_dict),
|
||||
model=ModelConfigManager.convert(config=config_dict),
|
||||
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
|
||||
|
|
@ -79,7 +81,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
|
|||
return app_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(cls, tenant_id: str, config: dict):
|
||||
def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict:
|
||||
"""
|
||||
Validate for chat app model config
|
||||
|
||||
|
|
@ -145,4 +147,4 @@ class ChatAppConfigManager(BaseAppConfigManager):
|
|||
# Filter out extra parameters
|
||||
filtered_config = {key: config.get(key) for key in related_config_keys}
|
||||
|
||||
return filtered_config
|
||||
return cast(AppModelConfigDict, filtered_config)
|
||||
|
|
|
|||
|
|
@ -173,8 +173,10 @@ class ChatAppRunner(AppRunner):
|
|||
memory=memory,
|
||||
message_id=message.id,
|
||||
inputs=inputs,
|
||||
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
|
||||
"enabled", False
|
||||
vision_enabled=bool(
|
||||
application_generate_entity.app_config.app_model_config_dict.get("file_upload", {})
|
||||
.get("image", {})
|
||||
.get("enabled", False)
|
||||
),
|
||||
)
|
||||
context_files = retrieved_files or []
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Any, cast
|
||||
|
||||
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
|
||||
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
|
||||
|
|
@ -8,7 +10,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod
|
|||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
|
||||
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict
|
||||
|
||||
|
||||
class CompletionAppConfig(EasyUIBasedAppConfig):
|
||||
|
|
@ -22,7 +24,7 @@ class CompletionAppConfig(EasyUIBasedAppConfig):
|
|||
class CompletionAppConfigManager(BaseAppConfigManager):
|
||||
@classmethod
|
||||
def get_app_config(
|
||||
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: dict | None = None
|
||||
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: AppModelConfigDict | None = None
|
||||
) -> CompletionAppConfig:
|
||||
"""
|
||||
Convert app model config to completion app config
|
||||
|
|
@ -40,7 +42,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
|
|||
app_model_config_dict = app_model_config.to_dict()
|
||||
config_dict = app_model_config_dict.copy()
|
||||
else:
|
||||
config_dict = override_config_dict or {}
|
||||
if not override_config_dict:
|
||||
raise Exception("override_config_dict is required when config_from is ARGS")
|
||||
config_dict = override_config_dict
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
app_config = CompletionAppConfig(
|
||||
|
|
@ -49,7 +53,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
|
|||
app_mode=app_mode,
|
||||
app_model_config_from=config_from,
|
||||
app_model_config_id=app_model_config.id,
|
||||
app_model_config_dict=config_dict,
|
||||
app_model_config_dict=cast(dict[str, Any], config_dict),
|
||||
model=ModelConfigManager.convert(config=config_dict),
|
||||
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
|
||||
|
|
@ -64,7 +68,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
|
|||
return app_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(cls, tenant_id: str, config: dict):
|
||||
def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict:
|
||||
"""
|
||||
Validate for completion app model config
|
||||
|
||||
|
|
@ -116,4 +120,4 @@ class CompletionAppConfigManager(BaseAppConfigManager):
|
|||
# Filter out extra parameters
|
||||
filtered_config = {key: config.get(key) for key in related_config_keys}
|
||||
|
||||
return filtered_config
|
||||
return cast(AppModelConfigDict, filtered_config)
|
||||
|
|
|
|||
|
|
@ -275,7 +275,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
raise ValueError("Message app_model_config is None")
|
||||
override_model_config_dict = app_model_config.to_dict()
|
||||
model_dict = override_model_config_dict["model"]
|
||||
completion_params = model_dict.get("completion_params")
|
||||
completion_params = model_dict.get("completion_params", {})
|
||||
completion_params["temperature"] = 0.9
|
||||
model_dict["completion_params"] = completion_params
|
||||
override_model_config_dict["model"] = model_dict
|
||||
|
|
|
|||
|
|
@ -132,8 +132,10 @@ class CompletionAppRunner(AppRunner):
|
|||
hit_callback=hit_callback,
|
||||
message_id=message.id,
|
||||
inputs=inputs,
|
||||
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
|
||||
"enabled", False
|
||||
vision_enabled=bool(
|
||||
application_generate_entity.app_config.app_model_config_dict.get("file_upload", {})
|
||||
.get("image", {})
|
||||
.get("enabled", False)
|
||||
),
|
||||
)
|
||||
context_files = retrieved_files or []
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import logging
|
|||
import time
|
||||
from collections.abc import Generator
|
||||
from threading import Thread
|
||||
from typing import Union, cast
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -219,14 +219,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||
tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
task_id = self._application_generate_entity.task_id
|
||||
publisher = None
|
||||
text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech")
|
||||
text_to_speech_dict = cast(dict[str, Any], self._app_config.app_model_config_dict.get("text_to_speech"))
|
||||
if (
|
||||
text_to_speech_dict
|
||||
and text_to_speech_dict.get("autoPlay") == "enabled"
|
||||
and text_to_speech_dict.get("enabled")
|
||||
):
|
||||
publisher = AppGeneratorTTSPublisher(
|
||||
tenant_id, text_to_speech_dict.get("voice", None), text_to_speech_dict.get("language", None)
|
||||
tenant_id, text_to_speech_dict.get("voice", ""), text_to_speech_dict.get("language", None)
|
||||
)
|
||||
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Union
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -34,14 +34,14 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
|||
if workflow is None:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
features_dict: dict[str, Any] = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from collections.abc import Mapping, Sequence
|
|||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from enum import StrEnum, auto
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
|
@ -15,6 +15,7 @@ from flask import request
|
|||
from flask_login import UserMixin # type: ignore[import-untyped]
|
||||
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
|
|
@ -36,6 +37,259 @@ if TYPE_CHECKING:
|
|||
from .workflow import Workflow
|
||||
|
||||
|
||||
# --- TypedDict definitions for structured dict return types ---
|
||||
|
||||
|
||||
class EnabledConfig(TypedDict):
|
||||
enabled: bool
|
||||
|
||||
|
||||
class EmbeddingModelInfo(TypedDict):
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class AnnotationReplyDisabledConfig(TypedDict):
|
||||
enabled: Literal[False]
|
||||
|
||||
|
||||
class AnnotationReplyEnabledConfig(TypedDict):
|
||||
id: str
|
||||
enabled: Literal[True]
|
||||
score_threshold: float
|
||||
embedding_model: EmbeddingModelInfo
|
||||
|
||||
|
||||
AnnotationReplyConfig = AnnotationReplyEnabledConfig | AnnotationReplyDisabledConfig
|
||||
|
||||
|
||||
class SensitiveWordAvoidanceConfig(TypedDict):
|
||||
enabled: bool
|
||||
type: str
|
||||
config: dict[str, Any]
|
||||
|
||||
|
||||
class AgentToolConfig(TypedDict):
|
||||
provider_type: str
|
||||
provider_id: str
|
||||
tool_name: str
|
||||
tool_parameters: dict[str, Any]
|
||||
plugin_unique_identifier: NotRequired[str | None]
|
||||
credential_id: NotRequired[str | None]
|
||||
|
||||
|
||||
class AgentModeConfig(TypedDict):
|
||||
enabled: bool
|
||||
strategy: str | None
|
||||
tools: list[AgentToolConfig | dict[str, Any]]
|
||||
prompt: str | None
|
||||
|
||||
|
||||
class ImageUploadConfig(TypedDict):
|
||||
enabled: bool
|
||||
number_limits: int
|
||||
detail: str
|
||||
transfer_methods: list[str]
|
||||
|
||||
|
||||
class FileUploadConfig(TypedDict):
|
||||
image: ImageUploadConfig
|
||||
|
||||
|
||||
class DeletedToolInfo(TypedDict):
|
||||
type: str
|
||||
tool_name: str
|
||||
provider_id: str
|
||||
|
||||
|
||||
class ExternalDataToolConfig(TypedDict):
|
||||
enabled: bool
|
||||
variable: str
|
||||
type: str
|
||||
config: dict[str, Any]
|
||||
|
||||
|
||||
class UserInputFormItemConfig(TypedDict):
|
||||
variable: str
|
||||
label: str
|
||||
description: NotRequired[str]
|
||||
required: NotRequired[bool]
|
||||
max_length: NotRequired[int]
|
||||
options: NotRequired[list[str]]
|
||||
default: NotRequired[str]
|
||||
type: NotRequired[str]
|
||||
config: NotRequired[dict[str, Any]]
|
||||
|
||||
|
||||
# Each item is a single-key dict, e.g. {"text-input": UserInputFormItemConfig}
|
||||
UserInputFormItem = dict[str, UserInputFormItemConfig]
|
||||
|
||||
|
||||
class DatasetConfigs(TypedDict):
|
||||
retrieval_model: str
|
||||
datasets: NotRequired[dict[str, Any]]
|
||||
top_k: NotRequired[int]
|
||||
score_threshold: NotRequired[float]
|
||||
score_threshold_enabled: NotRequired[bool]
|
||||
reranking_model: NotRequired[dict[str, Any] | None]
|
||||
weights: NotRequired[dict[str, Any] | None]
|
||||
reranking_enabled: NotRequired[bool]
|
||||
reranking_mode: NotRequired[str]
|
||||
metadata_filtering_mode: NotRequired[str]
|
||||
metadata_model_config: NotRequired[dict[str, Any] | None]
|
||||
metadata_filtering_conditions: NotRequired[dict[str, Any] | None]
|
||||
|
||||
|
||||
class ChatPromptMessage(TypedDict):
|
||||
text: str
|
||||
role: str
|
||||
|
||||
|
||||
class ChatPromptConfig(TypedDict, total=False):
|
||||
prompt: list[ChatPromptMessage]
|
||||
|
||||
|
||||
class CompletionPromptText(TypedDict):
|
||||
text: str
|
||||
|
||||
|
||||
class ConversationHistoriesRole(TypedDict):
|
||||
user_prefix: str
|
||||
assistant_prefix: str
|
||||
|
||||
|
||||
class CompletionPromptConfig(TypedDict):
|
||||
prompt: CompletionPromptText
|
||||
conversation_histories_role: NotRequired[ConversationHistoriesRole]
|
||||
|
||||
|
||||
class ModelConfig(TypedDict):
|
||||
provider: str
|
||||
name: str
|
||||
mode: str
|
||||
completion_params: NotRequired[dict[str, Any]]
|
||||
|
||||
|
||||
class AppModelConfigDict(TypedDict):
|
||||
opening_statement: str | None
|
||||
suggested_questions: list[str]
|
||||
suggested_questions_after_answer: EnabledConfig
|
||||
speech_to_text: EnabledConfig
|
||||
text_to_speech: EnabledConfig
|
||||
retriever_resource: EnabledConfig
|
||||
annotation_reply: AnnotationReplyConfig
|
||||
more_like_this: EnabledConfig
|
||||
sensitive_word_avoidance: SensitiveWordAvoidanceConfig
|
||||
external_data_tools: list[ExternalDataToolConfig]
|
||||
model: ModelConfig
|
||||
user_input_form: list[UserInputFormItem]
|
||||
dataset_query_variable: str | None
|
||||
pre_prompt: str | None
|
||||
agent_mode: AgentModeConfig
|
||||
prompt_type: str
|
||||
chat_prompt_config: ChatPromptConfig
|
||||
completion_prompt_config: CompletionPromptConfig
|
||||
dataset_configs: DatasetConfigs
|
||||
file_upload: FileUploadConfig
|
||||
# Added dynamically in Conversation.model_config
|
||||
model_id: NotRequired[str | None]
|
||||
provider: NotRequired[str | None]
|
||||
|
||||
|
||||
class ConversationDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
app_model_config_id: str | None
|
||||
model_provider: str | None
|
||||
override_model_configs: str | None
|
||||
model_id: str | None
|
||||
mode: str
|
||||
name: str
|
||||
summary: str | None
|
||||
inputs: dict[str, Any]
|
||||
introduction: str | None
|
||||
system_instruction: str | None
|
||||
system_instruction_tokens: int
|
||||
status: str
|
||||
invoke_from: str | None
|
||||
from_source: str
|
||||
from_end_user_id: str | None
|
||||
from_account_id: str | None
|
||||
read_at: datetime | None
|
||||
read_account_id: str | None
|
||||
dialogue_count: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class MessageDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
conversation_id: str
|
||||
model_id: str | None
|
||||
inputs: dict[str, Any]
|
||||
query: str
|
||||
total_price: Decimal | None
|
||||
message: dict[str, Any]
|
||||
answer: str
|
||||
status: str
|
||||
error: str | None
|
||||
message_metadata: dict[str, Any]
|
||||
from_source: str
|
||||
from_end_user_id: str | None
|
||||
from_account_id: str | None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
agent_based: bool
|
||||
workflow_run_id: str | None
|
||||
|
||||
|
||||
class MessageFeedbackDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
rating: str
|
||||
content: str | None
|
||||
from_source: str
|
||||
from_end_user_id: str | None
|
||||
from_account_id: str | None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class MessageFileInfo(TypedDict, total=False):
|
||||
belongs_to: str | None
|
||||
upload_file_id: str | None
|
||||
id: str
|
||||
tenant_id: str
|
||||
type: str
|
||||
transfer_method: str
|
||||
remote_url: str | None
|
||||
related_id: str | None
|
||||
filename: str | None
|
||||
extension: str | None
|
||||
mime_type: str | None
|
||||
size: int
|
||||
dify_model_identity: str
|
||||
url: str | None
|
||||
|
||||
|
||||
class ExtraContentDict(TypedDict, total=False):
|
||||
type: str
|
||||
workflow_run_id: str
|
||||
|
||||
|
||||
class TraceAppConfigDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
tracing_provider: str | None
|
||||
tracing_config: dict[str, Any]
|
||||
is_active: bool
|
||||
created_at: str | None
|
||||
updated_at: str | None
|
||||
|
||||
|
||||
class DifySetup(TypeBase):
|
||||
__tablename__ = "dify_setups"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
|
||||
|
|
@ -176,7 +430,7 @@ class App(Base):
|
|||
return str(self.mode)
|
||||
|
||||
@property
|
||||
def deleted_tools(self) -> list[dict[str, str]]:
|
||||
def deleted_tools(self) -> list[DeletedToolInfo]:
|
||||
from core.tools.tool_manager import ToolManager, ToolProviderType
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
|
|
@ -257,7 +511,7 @@ class App(Base):
|
|||
provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids)
|
||||
}
|
||||
|
||||
deleted_tools: list[dict[str, str]] = []
|
||||
deleted_tools: list[DeletedToolInfo] = []
|
||||
|
||||
for tool in tools:
|
||||
keys = list(tool.keys())
|
||||
|
|
@ -364,35 +618,38 @@ class AppModelConfig(TypeBase):
|
|||
return app
|
||||
|
||||
@property
|
||||
def model_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.model) if self.model else {}
|
||||
def model_dict(self) -> ModelConfig:
|
||||
return cast(ModelConfig, json.loads(self.model) if self.model else {})
|
||||
|
||||
@property
|
||||
def suggested_questions_list(self) -> list[str]:
|
||||
return json.loads(self.suggested_questions) if self.suggested_questions else []
|
||||
|
||||
@property
|
||||
def suggested_questions_after_answer_dict(self) -> dict[str, Any]:
|
||||
return (
|
||||
def suggested_questions_after_answer_dict(self) -> EnabledConfig:
|
||||
return cast(
|
||||
EnabledConfig,
|
||||
json.loads(self.suggested_questions_after_answer)
|
||||
if self.suggested_questions_after_answer
|
||||
else {"enabled": False}
|
||||
else {"enabled": False},
|
||||
)
|
||||
|
||||
@property
|
||||
def speech_to_text_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}
|
||||
def speech_to_text_dict(self) -> EnabledConfig:
|
||||
return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False})
|
||||
|
||||
@property
|
||||
def text_to_speech_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}
|
||||
def text_to_speech_dict(self) -> EnabledConfig:
|
||||
return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False})
|
||||
|
||||
@property
|
||||
def retriever_resource_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
|
||||
def retriever_resource_dict(self) -> EnabledConfig:
|
||||
return cast(
|
||||
EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
|
||||
)
|
||||
|
||||
@property
|
||||
def annotation_reply_dict(self) -> dict[str, Any]:
|
||||
def annotation_reply_dict(self) -> AnnotationReplyConfig:
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
|
||||
)
|
||||
|
|
@ -415,56 +672,62 @@ class AppModelConfig(TypeBase):
|
|||
return {"enabled": False}
|
||||
|
||||
@property
|
||||
def more_like_this_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
|
||||
def more_like_this_dict(self) -> EnabledConfig:
|
||||
return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False})
|
||||
|
||||
@property
|
||||
def sensitive_word_avoidance_dict(self) -> dict[str, Any]:
|
||||
return (
|
||||
def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig:
|
||||
return cast(
|
||||
SensitiveWordAvoidanceConfig,
|
||||
json.loads(self.sensitive_word_avoidance)
|
||||
if self.sensitive_word_avoidance
|
||||
else {"enabled": False, "type": "", "configs": []}
|
||||
else {"enabled": False, "type": "", "config": {}},
|
||||
)
|
||||
|
||||
@property
|
||||
def external_data_tools_list(self) -> list[dict[str, Any]]:
|
||||
def external_data_tools_list(self) -> list[ExternalDataToolConfig]:
|
||||
return json.loads(self.external_data_tools) if self.external_data_tools else []
|
||||
|
||||
@property
|
||||
def user_input_form_list(self) -> list[dict[str, Any]]:
|
||||
def user_input_form_list(self) -> list[UserInputFormItem]:
|
||||
return json.loads(self.user_input_form) if self.user_input_form else []
|
||||
|
||||
@property
|
||||
def agent_mode_dict(self) -> dict[str, Any]:
|
||||
return (
|
||||
def agent_mode_dict(self) -> AgentModeConfig:
|
||||
return cast(
|
||||
AgentModeConfig,
|
||||
json.loads(self.agent_mode)
|
||||
if self.agent_mode
|
||||
else {"enabled": False, "strategy": None, "tools": [], "prompt": None}
|
||||
else {"enabled": False, "strategy": None, "tools": [], "prompt": None},
|
||||
)
|
||||
|
||||
@property
|
||||
def chat_prompt_config_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}
|
||||
def chat_prompt_config_dict(self) -> ChatPromptConfig:
|
||||
return cast(ChatPromptConfig, json.loads(self.chat_prompt_config) if self.chat_prompt_config else {})
|
||||
|
||||
@property
|
||||
def completion_prompt_config_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}
|
||||
def completion_prompt_config_dict(self) -> CompletionPromptConfig:
|
||||
return cast(
|
||||
CompletionPromptConfig,
|
||||
json.loads(self.completion_prompt_config) if self.completion_prompt_config else {},
|
||||
)
|
||||
|
||||
@property
|
||||
def dataset_configs_dict(self) -> dict[str, Any]:
|
||||
def dataset_configs_dict(self) -> DatasetConfigs:
|
||||
if self.dataset_configs:
|
||||
dataset_configs: dict[str, Any] = json.loads(self.dataset_configs)
|
||||
dataset_configs = json.loads(self.dataset_configs)
|
||||
if "retrieval_model" not in dataset_configs:
|
||||
return {"retrieval_model": "single"}
|
||||
else:
|
||||
return dataset_configs
|
||||
return cast(DatasetConfigs, dataset_configs)
|
||||
return {
|
||||
"retrieval_model": "multiple",
|
||||
}
|
||||
|
||||
@property
|
||||
def file_upload_dict(self) -> dict[str, Any]:
|
||||
return (
|
||||
def file_upload_dict(self) -> FileUploadConfig:
|
||||
return cast(
|
||||
FileUploadConfig,
|
||||
json.loads(self.file_upload)
|
||||
if self.file_upload
|
||||
else {
|
||||
|
|
@ -474,10 +737,10 @@ class AppModelConfig(TypeBase):
|
|||
"detail": "high",
|
||||
"transfer_methods": ["remote_url", "local_file"],
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> AppModelConfigDict:
|
||||
return {
|
||||
"opening_statement": self.opening_statement,
|
||||
"suggested_questions": self.suggested_questions_list,
|
||||
|
|
@ -501,36 +764,42 @@ class AppModelConfig(TypeBase):
|
|||
"file_upload": self.file_upload_dict,
|
||||
}
|
||||
|
||||
def from_model_config_dict(self, model_config: Mapping[str, Any]):
|
||||
def from_model_config_dict(self, model_config: AppModelConfigDict):
|
||||
self.opening_statement = model_config.get("opening_statement")
|
||||
self.suggested_questions = (
|
||||
json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None
|
||||
json.dumps(model_config.get("suggested_questions")) if model_config.get("suggested_questions") else None
|
||||
)
|
||||
self.suggested_questions_after_answer = (
|
||||
json.dumps(model_config["suggested_questions_after_answer"])
|
||||
json.dumps(model_config.get("suggested_questions_after_answer"))
|
||||
if model_config.get("suggested_questions_after_answer")
|
||||
else None
|
||||
)
|
||||
self.speech_to_text = json.dumps(model_config["speech_to_text"]) if model_config.get("speech_to_text") else None
|
||||
self.text_to_speech = json.dumps(model_config["text_to_speech"]) if model_config.get("text_to_speech") else None
|
||||
self.more_like_this = json.dumps(model_config["more_like_this"]) if model_config.get("more_like_this") else None
|
||||
self.speech_to_text = (
|
||||
json.dumps(model_config.get("speech_to_text")) if model_config.get("speech_to_text") else None
|
||||
)
|
||||
self.text_to_speech = (
|
||||
json.dumps(model_config.get("text_to_speech")) if model_config.get("text_to_speech") else None
|
||||
)
|
||||
self.more_like_this = (
|
||||
json.dumps(model_config.get("more_like_this")) if model_config.get("more_like_this") else None
|
||||
)
|
||||
self.sensitive_word_avoidance = (
|
||||
json.dumps(model_config["sensitive_word_avoidance"])
|
||||
json.dumps(model_config.get("sensitive_word_avoidance"))
|
||||
if model_config.get("sensitive_word_avoidance")
|
||||
else None
|
||||
)
|
||||
self.external_data_tools = (
|
||||
json.dumps(model_config["external_data_tools"]) if model_config.get("external_data_tools") else None
|
||||
json.dumps(model_config.get("external_data_tools")) if model_config.get("external_data_tools") else None
|
||||
)
|
||||
self.model = json.dumps(model_config["model"]) if model_config.get("model") else None
|
||||
self.model = json.dumps(model_config.get("model")) if model_config.get("model") else None
|
||||
self.user_input_form = (
|
||||
json.dumps(model_config["user_input_form"]) if model_config.get("user_input_form") else None
|
||||
json.dumps(model_config.get("user_input_form")) if model_config.get("user_input_form") else None
|
||||
)
|
||||
self.dataset_query_variable = model_config.get("dataset_query_variable")
|
||||
self.pre_prompt = model_config["pre_prompt"]
|
||||
self.agent_mode = json.dumps(model_config["agent_mode"]) if model_config.get("agent_mode") else None
|
||||
self.pre_prompt = model_config.get("pre_prompt")
|
||||
self.agent_mode = json.dumps(model_config.get("agent_mode")) if model_config.get("agent_mode") else None
|
||||
self.retriever_resource = (
|
||||
json.dumps(model_config["retriever_resource"]) if model_config.get("retriever_resource") else None
|
||||
json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None
|
||||
)
|
||||
self.prompt_type = model_config.get("prompt_type", "simple")
|
||||
self.chat_prompt_config = (
|
||||
|
|
@ -823,24 +1092,26 @@ class Conversation(Base):
|
|||
self._inputs = inputs
|
||||
|
||||
@property
|
||||
def model_config(self):
|
||||
model_config = {}
|
||||
def model_config(self) -> AppModelConfigDict:
|
||||
model_config = cast(AppModelConfigDict, {})
|
||||
app_model_config: AppModelConfig | None = None
|
||||
|
||||
if self.mode == AppMode.ADVANCED_CHAT:
|
||||
if self.override_model_configs:
|
||||
override_model_configs = json.loads(self.override_model_configs)
|
||||
model_config = override_model_configs
|
||||
model_config = cast(AppModelConfigDict, override_model_configs)
|
||||
else:
|
||||
if self.override_model_configs:
|
||||
override_model_configs = json.loads(self.override_model_configs)
|
||||
|
||||
if "model" in override_model_configs:
|
||||
# where is app_id?
|
||||
app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(override_model_configs)
|
||||
app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(
|
||||
cast(AppModelConfigDict, override_model_configs)
|
||||
)
|
||||
model_config = app_model_config.to_dict()
|
||||
else:
|
||||
model_config["configs"] = override_model_configs
|
||||
model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key]
|
||||
else:
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
|
||||
|
|
@ -1015,7 +1286,7 @@ class Conversation(Base):
|
|||
def in_debug_mode(self) -> bool:
|
||||
return self.override_model_configs is not None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> ConversationDict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"app_id": self.app_id,
|
||||
|
|
@ -1295,7 +1566,7 @@ class Message(Base):
|
|||
return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []
|
||||
|
||||
@property
|
||||
def message_files(self) -> list[dict[str, Any]]:
|
||||
def message_files(self) -> list[MessageFileInfo]:
|
||||
from factories import file_factory
|
||||
|
||||
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
|
||||
|
|
@ -1350,10 +1621,13 @@ class Message(Base):
|
|||
)
|
||||
files.append(file)
|
||||
|
||||
result: list[dict[str, Any]] = [
|
||||
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
|
||||
for (file, message_file) in zip(files, message_files)
|
||||
]
|
||||
result = cast(
|
||||
list[MessageFileInfo],
|
||||
[
|
||||
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
|
||||
for (file, message_file) in zip(files, message_files)
|
||||
],
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
return result
|
||||
|
|
@ -1363,7 +1637,7 @@ class Message(Base):
|
|||
self._extra_contents = list(contents)
|
||||
|
||||
@property
|
||||
def extra_contents(self) -> list[dict[str, Any]]:
|
||||
def extra_contents(self) -> list[ExtraContentDict]:
|
||||
return getattr(self, "_extra_contents", [])
|
||||
|
||||
@property
|
||||
|
|
@ -1379,7 +1653,7 @@ class Message(Base):
|
|||
|
||||
return None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> MessageDict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"app_id": self.app_id,
|
||||
|
|
@ -1403,7 +1677,7 @@ class Message(Base):
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> Message:
|
||||
def from_dict(cls, data: MessageDict) -> Message:
|
||||
return cls(
|
||||
id=data["id"],
|
||||
app_id=data["app_id"],
|
||||
|
|
@ -1463,7 +1737,7 @@ class MessageFeedback(TypeBase):
|
|||
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
|
||||
return account
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> MessageFeedbackDict:
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"app_id": str(self.app_id),
|
||||
|
|
@ -1726,8 +2000,8 @@ class AppMCPServer(TypeBase):
|
|||
return result
|
||||
|
||||
@property
|
||||
def parameters_dict(self) -> dict[str, Any]:
|
||||
return cast(dict[str, Any], json.loads(self.parameters))
|
||||
def parameters_dict(self) -> dict[str, str]:
|
||||
return cast(dict[str, str], json.loads(self.parameters))
|
||||
|
||||
|
||||
class Site(Base):
|
||||
|
|
@ -2167,7 +2441,7 @@ class TraceAppConfig(TypeBase):
|
|||
def tracing_config_str(self) -> str:
|
||||
return json.dumps(self.tracing_config_dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> TraceAppConfigDict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"app_id": self.app_id,
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import logging
|
|||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
|
|
@ -32,7 +33,7 @@ from extensions.ext_redis import redis_client
|
|||
from factories import variable_factory
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, App, AppMode
|
||||
from models.model import AppModelConfig, IconType
|
||||
from models.model import AppModelConfig, AppModelConfigDict, IconType
|
||||
from models.workflow import Workflow
|
||||
from services.plugin.dependencies_analysis import DependenciesAnalysisService
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableService
|
||||
|
|
@ -523,7 +524,7 @@ class AppDslService:
|
|||
if not app.app_model_config:
|
||||
app_model_config = AppModelConfig(
|
||||
app_id=app.id, created_by=account.id, updated_by=account.id
|
||||
).from_model_config_dict(model_config)
|
||||
).from_model_config_dict(cast(AppModelConfigDict, model_config))
|
||||
app_model_config.id = str(uuid4())
|
||||
app.app_model_config_id = app_model_config.id
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||
from models.model import AppMode
|
||||
from models.model import AppMode, AppModelConfigDict
|
||||
|
||||
|
||||
class AppModelConfigService:
|
||||
@classmethod
|
||||
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode):
|
||||
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict:
|
||||
if app_mode == AppMode.CHAT:
|
||||
return ChatAppConfigManager.config_validate(tenant_id, config)
|
||||
elif app_mode == AppMode.AGENT_CHAT:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import TypedDict, cast
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask_sqlalchemy.pagination import Pagination
|
||||
|
|
@ -187,7 +187,7 @@ class AppService:
|
|||
for tool in agent_mode.get("tools") or []:
|
||||
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
|
||||
continue
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
agent_tool_entity = AgentToolEntity(**cast(dict[str, Any], tool))
|
||||
# get tool
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
|
|
@ -388,7 +388,7 @@ class AppService:
|
|||
agent_config = app_model_config.agent_mode_dict
|
||||
|
||||
# get all tools
|
||||
tools = agent_config.get("tools", [])
|
||||
tools = cast(list[dict[str, Any]], agent_config.get("tools", []))
|
||||
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import io
|
|||
import logging
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
|
@ -106,7 +107,7 @@ class AudioService:
|
|||
if not text_to_speech_dict.get("enabled"):
|
||||
raise ValueError("TTS is not enabled")
|
||||
|
||||
voice = text_to_speech_dict.get("voice")
|
||||
voice = cast(str | None, text_to_speech_dict.get("voice"))
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
|
|
|
|||
Loading…
Reference in New Issue