# dify/api/core/plugin/impl/model_runtime.py

from __future__ import annotations
import hashlib
import logging
from collections.abc import Generator, Iterable, Sequence
from threading import Lock
from typing import IO, Any, Union
from pydantic import ValidationError
from redis import RedisError
from configs import dify_config
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.model import PluginModelClient
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.model_runtime.entities.provider_entities import ProviderEntity
from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
from dify_graph.model_runtime.runtime import ModelRuntime
from extensions.ext_redis import redis_client
from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
class PluginModelRuntime(ModelRuntime):
    """Plugin-backed runtime adapter bound to tenant context and a default user.

    Every model operation is delegated to a :class:`PluginModelClient`, with the
    tenant id and (optional) default user id attached to each call, so callers
    can treat this object as a plain ``ModelRuntime`` without carrying tenant
    context themselves.
    """

    # Tenant that owns every plugin call made through this runtime.
    tenant_id: str
    # Default user attributed to plugin invocations; may be None.
    user_id: str | None
    # Transport to the plugin daemon.
    client: PluginModelClient
    # Lazily-built, immutable per-instance cache of provider declarations
    # (None until the first fetch_model_providers call).
    _provider_entities: tuple[ProviderEntity, ...] | None
    # Guards the one-time population of _provider_entities.
    _provider_entities_lock: Lock

    def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None:
        """Bind the runtime to a tenant/user pair and a plugin daemon client.

        Raises:
            ValueError: If ``client`` is None.
        """
        if client is None:
            raise ValueError("client is required.")
        self.tenant_id = tenant_id
        self.user_id = user_id
        self.client = client
        self._provider_entities = None
        self._provider_entities_lock = Lock()

    def fetch_model_providers(self) -> Sequence[ProviderEntity]:
        """Return all model providers visible to the tenant, cached per instance.

        Double-checked locking ensures at most one daemon round-trip even when
        the first calls race; the cached tuple is immutable and reused for the
        lifetime of this runtime object.
        """
        if self._provider_entities is not None:
            return self._provider_entities
        with self._provider_entities_lock:
            if self._provider_entities is None:
                self._provider_entities = tuple(
                    self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id)
                )
            return self._provider_entities

    def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
        """Fetch a provider icon as raw bytes together with its MIME type.

        Args:
            provider: Provider identifier (fully-qualified or short alias).
            icon_type: ``icon_small`` or ``icon_small_dark`` (case-insensitive).
            lang: ``zh_hans`` selects the zh_Hans asset; anything else en_US.

        Returns:
            A ``(bytes, mime_type)`` pair for the icon asset.

        Raises:
            ValueError: If the icon type is unsupported or the provider has no
                such icon configured.
        """
        provider_schema = self._get_provider_schema(provider)
        if icon_type.lower() == "icon_small":
            if not provider_schema.icon_small:
                raise ValueError(f"Provider {provider} does not have small icon.")
            file_name = (
                provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
            )
        elif icon_type.lower() == "icon_small_dark":
            if not provider_schema.icon_small_dark:
                raise ValueError(f"Provider {provider} does not have small dark icon.")
            file_name = (
                provider_schema.icon_small_dark.zh_Hans
                if lang.lower() == "zh_hans"
                else provider_schema.icon_small_dark.en_US
            )
        else:
            raise ValueError(f"Unsupported icon type: {icon_type}.")
        if not file_name:
            raise ValueError(f"Provider {provider} does not have icon.")
        image_mime_types = {
            "jpg": "image/jpeg",
            "jpeg": "image/jpeg",
            "png": "image/png",
            "gif": "image/gif",
            "bmp": "image/bmp",
            "tiff": "image/tiff",
            "tif": "image/tiff",
            "webp": "image/webp",
            "svg": "image/svg+xml",
            "ico": "image/vnd.microsoft.icon",
            "heif": "image/heif",
            "heic": "image/heic",
        }
        # Lowercase the extension so e.g. "logo.PNG" still resolves its real
        # MIME type; truly unknown extensions fall back to image/png.
        extension = file_name.split(".")[-1].lower()
        mime_type = image_mime_types.get(extension, "image/png")
        return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type

    def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None:
        """Ask the plugin daemon to validate provider-level credentials."""
        plugin_id, provider_name = self._split_provider(provider)
        self.client.validate_provider_credentials(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            credentials=credentials,
        )

    def validate_model_credentials(
        self,
        *,
        provider: str,
        model_type: ModelType,
        model: str,
        credentials: dict[str, Any],
    ) -> None:
        """Ask the plugin daemon to validate credentials for a specific model."""
        plugin_id, provider_name = self._split_provider(provider)
        self.client.validate_model_credentials(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model_type=model_type.value,
            model=model,
            credentials=credentials,
        )

    def get_model_schema(
        self,
        *,
        provider: str,
        model_type: ModelType,
        model: str,
        credentials: dict[str, Any],
    ) -> AIModelEntity | None:
        """Return the schema for a model, using Redis as a best-effort cache.

        Cache read/write failures are logged and ignored so schema lookup still
        works when Redis is unavailable. A cached entry that fails validation is
        evicted and re-fetched from the plugin daemon.
        """
        cache_key = self._get_schema_cache_key(
            provider=provider,
            model_type=model_type,
            model=model,
            credentials=credentials,
        )
        cached_schema_json = None
        try:
            cached_schema_json = redis_client.get(cache_key)
        except (RedisError, RuntimeError) as exc:
            logger.warning(
                "Failed to read plugin model schema cache for model %s: %s",
                model,
                str(exc),
                exc_info=True,
            )
        if cached_schema_json:
            try:
                return AIModelEntity.model_validate_json(cached_schema_json)
            except ValidationError:
                logger.warning("Failed to validate cached plugin model schema for model %s", model, exc_info=True)
                # Stale or corrupt cache entry: evict it (best effort) and fall
                # through to a fresh fetch from the daemon.
                try:
                    redis_client.delete(cache_key)
                except (RedisError, RuntimeError) as exc:
                    logger.warning(
                        "Failed to delete invalid plugin model schema cache for model %s: %s",
                        model,
                        str(exc),
                        exc_info=True,
                    )
        plugin_id, provider_name = self._split_provider(provider)
        schema = self.client.get_model_schema(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model_type=model_type.value,
            model=model,
            credentials=credentials,
        )
        if schema:
            try:
                redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
            except (RedisError, RuntimeError) as exc:
                logger.warning(
                    "Failed to write plugin model schema cache for model %s: %s",
                    model,
                    str(exc),
                    exc_info=True,
                )
        return schema

    def invoke_llm(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        model_parameters: dict[str, Any],
        prompt_messages: Sequence[PromptMessage],
        tools: list[PromptMessageTool] | None,
        stop: Sequence[str] | None,
        stream: bool,
    ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
        """Invoke an LLM; returns a full result or a chunk generator when streaming."""
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.invoke_llm(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model=model,
            credentials=credentials,
            model_parameters=model_parameters,
            prompt_messages=list(prompt_messages),
            tools=tools,
            stop=list(stop) if stop else None,
            stream=stream,
        )

    def get_llm_num_tokens(
        self,
        *,
        provider: str,
        model_type: ModelType,
        model: str,
        credentials: dict[str, Any],
        prompt_messages: Sequence[PromptMessage],
        tools: Sequence[PromptMessageTool] | None,
    ) -> int:
        """Count prompt tokens via the plugin; returns 0 when counting is disabled."""
        if not dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
            return 0
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.get_llm_num_tokens(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model_type=model_type.value,
            model=model,
            credentials=credentials,
            prompt_messages=list(prompt_messages),
            tools=list(tools) if tools else None,
        )

    def invoke_text_embedding(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        texts: list[str],
        input_type: EmbeddingInputType,
    ) -> EmbeddingResult:
        """Embed a batch of texts with the given model."""
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.invoke_text_embedding(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model=model,
            credentials=credentials,
            texts=texts,
            input_type=input_type,
        )

    def invoke_multimodal_embedding(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        documents: list[dict[str, Any]],
        input_type: EmbeddingInputType,
    ) -> EmbeddingResult:
        """Embed a batch of multimodal documents with the given model."""
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.invoke_multimodal_embedding(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model=model,
            credentials=credentials,
            documents=documents,
            input_type=input_type,
        )

    def get_text_embedding_num_tokens(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        texts: list[str],
    ) -> list[int]:
        """Return the per-text token counts for an embedding request."""
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.get_text_embedding_num_tokens(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model=model,
            credentials=credentials,
            texts=texts,
        )

    def invoke_rerank(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        query: str,
        docs: list[str],
        score_threshold: float | None,
        top_n: int | None,
    ) -> RerankResult:
        """Rerank text documents against a query."""
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.invoke_rerank(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model=model,
            credentials=credentials,
            query=query,
            docs=docs,
            score_threshold=score_threshold,
            top_n=top_n,
        )

    def invoke_multimodal_rerank(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        query: MultimodalRerankInput,
        docs: list[MultimodalRerankInput],
        score_threshold: float | None,
        top_n: int | None,
    ) -> RerankResult:
        """Rerank multimodal documents against a multimodal query."""
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.invoke_multimodal_rerank(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model=model,
            credentials=credentials,
            query=query,
            docs=docs,
            score_threshold=score_threshold,
            top_n=top_n,
        )

    def invoke_tts(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        content_text: str,
        voice: str,
    ) -> Iterable[bytes]:
        """Synthesize speech for the given text; yields raw audio chunks."""
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.invoke_tts(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model=model,
            credentials=credentials,
            content_text=content_text,
            voice=voice,
        )

    def get_tts_model_voices(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        language: str | None,
    ) -> Any:
        """List the voices a TTS model supports, optionally filtered by language."""
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.get_tts_model_voices(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model=model,
            credentials=credentials,
            language=language,
        )

    def invoke_speech_to_text(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        file: IO[bytes],
    ) -> str:
        """Transcribe an audio stream to text."""
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.invoke_speech_to_text(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model=model,
            credentials=credentials,
            file=file,
        )

    def invoke_moderation(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        text: str,
    ) -> bool:
        """Run moderation over *text*; True indicates the content was flagged.

        NOTE(review): the flag polarity is defined by the plugin client — confirm
        against PluginModelClient.invoke_moderation.
        """
        plugin_id, provider_name = self._split_provider(provider)
        return self.client.invoke_moderation(
            tenant_id=self.tenant_id,
            user_id=self.user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            model=model,
            credentials=credentials,
            text=text,
        )

    def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str:
        """
        Expose a bare provider alias only for the canonical provider mapping.
        Multiple plugins can publish the same short provider slug. If every
        provider entity keeps that slug in ``provider_name``, callers that still
        resolve by short name become order-dependent. Restrict the alias to the
        provider selected by ``ModelProviderID`` so legacy short-name lookups
        remain deterministic while the runtime surface stays canonical.
        """
        try:
            canonical_provider_id = ModelProviderID(provider.provider)
        except ValueError:
            # Not a recognized provider id at all — no alias.
            return ""
        if canonical_provider_id.plugin_id != provider.plugin_id:
            return ""
        if canonical_provider_id.provider_name != provider.provider:
            return ""
        return provider.provider

    def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity:
        """Convert a daemon provider entity into a runtime ``ProviderEntity``.

        The declaration is deep-copied, its ``provider`` rewritten to the
        fully-qualified ``plugin_id/provider`` form, and ``provider_name`` set
        to the short-name alias (empty when this provider is not canonical).
        """
        declaration = provider.declaration.model_copy(deep=True)
        declaration.provider = f"{provider.plugin_id}/{provider.provider}"
        declaration.provider_name = self._get_provider_short_name_alias(provider)
        return declaration

    def _get_provider_schema(self, provider: str) -> ProviderEntity:
        """Resolve a provider schema by fully-qualified id, then by short alias.

        Raises:
            ValueError: If no provider matches either form.
        """
        providers = self.fetch_model_providers()
        provider_entity = next((item for item in providers if item.provider == provider), None)
        if provider_entity is None:
            provider_entity = next((item for item in providers if provider == item.provider_name), None)
        if provider_entity is None:
            raise ValueError(f"Invalid provider: {provider}")
        return provider_entity

    def _get_schema_cache_key(
        self,
        *,
        provider: str,
        model_type: ModelType,
        model: str,
        credentials: dict[str, Any],
    ) -> str:
        """Build the Redis key for a model-schema cache entry.

        Credentials are folded in as per-pair MD5 digests (sorted by key for
        determinism) so raw secret values never appear in the key itself.
        """
        cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}"
        sorted_credentials = sorted(credentials.items()) if credentials else []
        if not sorted_credentials:
            return cache_key
        hashed_credentials = ":".join(
            [hashlib.md5(f"{key}:{value}".encode()).hexdigest() for key, value in sorted_credentials]
        )
        return f"{cache_key}:{hashed_credentials}"

    def _split_provider(self, provider: str) -> tuple[str, str]:
        """Split a provider identifier into ``(plugin_id, provider_name)``."""
        provider_id = ModelProviderID(provider)
        return provider_id.plugin_id, provider_id.provider_name