mirror of https://github.com/langgenius/dify.git
refactor(workflow): inject credential/model access ports into LLM nodes (#32569)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
d20880d102
commit
a694533fc9
|
|
@ -89,7 +89,6 @@ forbidden_modules =
|
||||||
core.logging
|
core.logging
|
||||||
core.mcp
|
core.mcp
|
||||||
core.memory
|
core.memory
|
||||||
core.model_manager
|
|
||||||
core.moderation
|
core.moderation
|
||||||
core.ops
|
core.ops
|
||||||
core.plugin
|
core.plugin
|
||||||
|
|
@ -117,6 +116,7 @@ ignore_imports =
|
||||||
core.workflow.nodes.llm.llm_utils -> configs
|
core.workflow.nodes.llm.llm_utils -> configs
|
||||||
core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
|
core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
|
||||||
core.workflow.nodes.llm.llm_utils -> core.model_manager
|
core.workflow.nodes.llm.llm_utils -> core.model_manager
|
||||||
|
core.workflow.nodes.llm.protocols -> core.model_manager
|
||||||
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
|
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
|
||||||
core.workflow.nodes.llm.llm_utils -> models.model
|
core.workflow.nodes.llm.llm_utils -> models.model
|
||||||
core.workflow.nodes.llm.llm_utils -> models.provider
|
core.workflow.nodes.llm.llm_utils -> models.provider
|
||||||
|
|
|
||||||
|
|
@ -112,7 +112,7 @@ class BaseAgentRunner(AppRunner):
|
||||||
|
|
||||||
# check if model supports stream tool call
|
# check if model supports stream tool call
|
||||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||||
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
|
||||||
features = model_schema.features if model_schema and model_schema.features else []
|
features = model_schema.features if model_schema and model_schema.features else []
|
||||||
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
|
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
|
||||||
self.files = application_generate_entity.files if ModelFeature.VISION in features else []
|
self.files = application_generate_entity.files if ModelFeature.VISION in features else []
|
||||||
|
|
|
||||||
|
|
@ -245,7 +245,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||||
iteration_step += 1
|
iteration_step += 1
|
||||||
|
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
model=model_instance.model,
|
model=model_instance.model_name,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
delta=LLMResultChunkDelta(
|
delta=LLMResultChunkDelta(
|
||||||
index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
|
index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
|
||||||
|
|
@ -268,7 +268,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||||
self.queue_manager.publish(
|
self.queue_manager.publish(
|
||||||
QueueMessageEndEvent(
|
QueueMessageEndEvent(
|
||||||
llm_result=LLMResult(
|
llm_result=LLMResult(
|
||||||
model=model_instance.model,
|
model=model_instance.model_name,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
message=AssistantPromptMessage(content=final_answer),
|
message=AssistantPromptMessage(content=final_answer),
|
||||||
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
|
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
|
||||||
|
|
|
||||||
|
|
@ -178,7 +178,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||||
)
|
)
|
||||||
|
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
model=model_instance.model,
|
model=model_instance.model_name,
|
||||||
prompt_messages=result.prompt_messages,
|
prompt_messages=result.prompt_messages,
|
||||||
system_fingerprint=result.system_fingerprint,
|
system_fingerprint=result.system_fingerprint,
|
||||||
delta=LLMResultChunkDelta(
|
delta=LLMResultChunkDelta(
|
||||||
|
|
@ -308,7 +308,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||||
self.queue_manager.publish(
|
self.queue_manager.publish(
|
||||||
QueueMessageEndEvent(
|
QueueMessageEndEvent(
|
||||||
llm_result=LLMResult(
|
llm_result=LLMResult(
|
||||||
model=model_instance.model,
|
model=model_instance.model_name,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
message=AssistantPromptMessage(content=final_answer),
|
message=AssistantPromptMessage(content=final_answer),
|
||||||
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
|
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
|
||||||
|
|
|
||||||
|
|
@ -178,7 +178,7 @@ class AgentChatAppRunner(AppRunner):
|
||||||
|
|
||||||
# change function call strategy based on LLM model
|
# change function call strategy based on LLM model
|
||||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||||
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
|
||||||
if not model_schema:
|
if not model_schema:
|
||||||
raise ValueError("Model schema not found")
|
raise ValueError("Model schema not found")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
"""LLM-related application services."""
|
||||||
|
|
@ -0,0 +1,103 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||||
|
from core.errors.error import ProviderTokenNotInitError
|
||||||
|
from core.model_manager import ModelInstance, ModelManager
|
||||||
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
|
from core.provider_manager import ProviderManager
|
||||||
|
from core.workflow.nodes.llm.entities import ModelConfig
|
||||||
|
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
||||||
|
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||||
|
|
||||||
|
|
||||||
|
class DifyCredentialsProvider:
|
||||||
|
tenant_id: str
|
||||||
|
provider_manager: ProviderManager
|
||||||
|
|
||||||
|
def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None:
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.provider_manager = provider_manager or ProviderManager()
|
||||||
|
|
||||||
|
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
|
||||||
|
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
|
||||||
|
provider_configuration = provider_configurations.get(provider_name)
|
||||||
|
if not provider_configuration:
|
||||||
|
raise ValueError(f"Provider {provider_name} does not exist.")
|
||||||
|
|
||||||
|
provider_model = provider_configuration.get_provider_model(model_type=ModelType.LLM, model=model_name)
|
||||||
|
if provider_model is None:
|
||||||
|
raise ModelNotExistError(f"Model {model_name} not exist.")
|
||||||
|
provider_model.raise_for_status()
|
||||||
|
|
||||||
|
credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model_name)
|
||||||
|
if credentials is None:
|
||||||
|
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||||
|
|
||||||
|
return credentials
|
||||||
|
|
||||||
|
|
||||||
|
class DifyModelFactory:
|
||||||
|
tenant_id: str
|
||||||
|
model_manager: ModelManager
|
||||||
|
|
||||||
|
def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None:
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.model_manager = model_manager or ModelManager()
|
||||||
|
|
||||||
|
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
|
||||||
|
return self.model_manager.get_model_instance(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
provider=provider_name,
|
||||||
|
model_type=ModelType.LLM,
|
||||||
|
model=model_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_dify_model_access(tenant_id: str) -> tuple[CredentialsProvider, ModelFactory]:
|
||||||
|
return (
|
||||||
|
DifyCredentialsProvider(tenant_id=tenant_id),
|
||||||
|
DifyModelFactory(tenant_id=tenant_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_model_config(
|
||||||
|
*,
|
||||||
|
node_data_model: ModelConfig,
|
||||||
|
credentials_provider: CredentialsProvider,
|
||||||
|
model_factory: ModelFactory,
|
||||||
|
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||||
|
if not node_data_model.mode:
|
||||||
|
raise LLMModeRequiredError("LLM mode is required.")
|
||||||
|
|
||||||
|
credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name)
|
||||||
|
model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
|
||||||
|
provider_model_bundle = model_instance.provider_model_bundle
|
||||||
|
|
||||||
|
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||||
|
model=node_data_model.name,
|
||||||
|
model_type=ModelType.LLM,
|
||||||
|
)
|
||||||
|
if provider_model is None:
|
||||||
|
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||||
|
provider_model.raise_for_status()
|
||||||
|
|
||||||
|
stop: list[str] = []
|
||||||
|
if "stop" in node_data_model.completion_params:
|
||||||
|
stop = node_data_model.completion_params.pop("stop")
|
||||||
|
|
||||||
|
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
|
||||||
|
if not model_schema:
|
||||||
|
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||||
|
|
||||||
|
return model_instance, ModelConfigWithCredentialsEntity(
|
||||||
|
provider=node_data_model.provider,
|
||||||
|
model=node_data_model.name,
|
||||||
|
model_schema=model_schema,
|
||||||
|
mode=node_data_model.mode,
|
||||||
|
provider_model_bundle=provider_model_bundle,
|
||||||
|
credentials=credentials,
|
||||||
|
parameters=node_data_model.completion_params,
|
||||||
|
stop=stop,
|
||||||
|
)
|
||||||
|
|
@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, final
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from core.app.llm.model_access import build_dify_model_access
|
||||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
|
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
|
||||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||||
from core.helper.ssrf_proxy import ssrf_proxy
|
from core.helper.ssrf_proxy import ssrf_proxy
|
||||||
|
|
@ -20,8 +21,13 @@ from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
|
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
|
||||||
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
|
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
|
||||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||||
|
from core.workflow.nodes.llm.node import LLMNode
|
||||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||||
from core.workflow.nodes.template_transform.template_renderer import CodeExecutorJinja2TemplateRenderer
|
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||||
|
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||||
|
from core.workflow.nodes.template_transform.template_renderer import (
|
||||||
|
CodeExecutorJinja2TemplateRenderer,
|
||||||
|
)
|
||||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
@ -95,6 +101,8 @@ class DifyNodeFactory(NodeFactory):
|
||||||
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(graph_init_params.tenant_id)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def create_node(self, node_config: NodeConfigDict) -> Node:
|
def create_node(self, node_config: NodeConfigDict) -> Node:
|
||||||
"""
|
"""
|
||||||
|
|
@ -160,6 +168,16 @@ class DifyNodeFactory(NodeFactory):
|
||||||
file_manager=self._http_request_file_manager,
|
file_manager=self._http_request_file_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if node_type == NodeType.LLM:
|
||||||
|
return LLMNode(
|
||||||
|
id=node_id,
|
||||||
|
config=node_config,
|
||||||
|
graph_init_params=self.graph_init_params,
|
||||||
|
graph_runtime_state=self.graph_runtime_state,
|
||||||
|
credentials_provider=self._llm_credentials_provider,
|
||||||
|
model_factory=self._llm_model_factory,
|
||||||
|
)
|
||||||
|
|
||||||
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||||
return KnowledgeRetrievalNode(
|
return KnowledgeRetrievalNode(
|
||||||
id=node_id,
|
id=node_id,
|
||||||
|
|
@ -178,6 +196,26 @@ class DifyNodeFactory(NodeFactory):
|
||||||
unstructured_api_config=self._document_extractor_unstructured_api_config,
|
unstructured_api_config=self._document_extractor_unstructured_api_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if node_type == NodeType.QUESTION_CLASSIFIER:
|
||||||
|
return QuestionClassifierNode(
|
||||||
|
id=node_id,
|
||||||
|
config=node_config,
|
||||||
|
graph_init_params=self.graph_init_params,
|
||||||
|
graph_runtime_state=self.graph_runtime_state,
|
||||||
|
credentials_provider=self._llm_credentials_provider,
|
||||||
|
model_factory=self._llm_model_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
if node_type == NodeType.PARAMETER_EXTRACTOR:
|
||||||
|
return ParameterExtractorNode(
|
||||||
|
id=node_id,
|
||||||
|
config=node_config,
|
||||||
|
graph_init_params=self.graph_init_params,
|
||||||
|
graph_runtime_state=self.graph_runtime_state,
|
||||||
|
credentials_provider=self._llm_credentials_provider,
|
||||||
|
model_factory=self._llm_model_factory,
|
||||||
|
)
|
||||||
|
|
||||||
return node_class(
|
return node_class(
|
||||||
id=node_id,
|
id=node_id,
|
||||||
config=node_config,
|
config=node_config,
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ class ModelInstance:
|
||||||
|
|
||||||
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
|
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
|
||||||
self.provider_model_bundle = provider_model_bundle
|
self.provider_model_bundle = provider_model_bundle
|
||||||
self.model = model
|
self.model_name = model
|
||||||
self.provider = provider_model_bundle.configuration.provider.provider
|
self.provider = provider_model_bundle.configuration.provider.provider
|
||||||
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
|
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
|
||||||
self.model_type_instance = self.provider_model_bundle.model_type_instance
|
self.model_type_instance = self.provider_model_bundle.model_type_instance
|
||||||
|
|
@ -163,7 +163,7 @@ class ModelInstance:
|
||||||
Union[LLMResult, Generator],
|
Union[LLMResult, Generator],
|
||||||
self._round_robin_invoke(
|
self._round_robin_invoke(
|
||||||
function=self.model_type_instance.invoke,
|
function=self.model_type_instance.invoke,
|
||||||
model=self.model,
|
model=self.model_name,
|
||||||
credentials=self.credentials,
|
credentials=self.credentials,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
model_parameters=model_parameters,
|
model_parameters=model_parameters,
|
||||||
|
|
@ -191,7 +191,7 @@ class ModelInstance:
|
||||||
int,
|
int,
|
||||||
self._round_robin_invoke(
|
self._round_robin_invoke(
|
||||||
function=self.model_type_instance.get_num_tokens,
|
function=self.model_type_instance.get_num_tokens,
|
||||||
model=self.model,
|
model=self.model_name,
|
||||||
credentials=self.credentials,
|
credentials=self.credentials,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
|
@ -215,7 +215,7 @@ class ModelInstance:
|
||||||
EmbeddingResult,
|
EmbeddingResult,
|
||||||
self._round_robin_invoke(
|
self._round_robin_invoke(
|
||||||
function=self.model_type_instance.invoke,
|
function=self.model_type_instance.invoke,
|
||||||
model=self.model,
|
model=self.model_name,
|
||||||
credentials=self.credentials,
|
credentials=self.credentials,
|
||||||
texts=texts,
|
texts=texts,
|
||||||
user=user,
|
user=user,
|
||||||
|
|
@ -243,7 +243,7 @@ class ModelInstance:
|
||||||
EmbeddingResult,
|
EmbeddingResult,
|
||||||
self._round_robin_invoke(
|
self._round_robin_invoke(
|
||||||
function=self.model_type_instance.invoke,
|
function=self.model_type_instance.invoke,
|
||||||
model=self.model,
|
model=self.model_name,
|
||||||
credentials=self.credentials,
|
credentials=self.credentials,
|
||||||
multimodel_documents=multimodel_documents,
|
multimodel_documents=multimodel_documents,
|
||||||
user=user,
|
user=user,
|
||||||
|
|
@ -264,7 +264,7 @@ class ModelInstance:
|
||||||
list[int],
|
list[int],
|
||||||
self._round_robin_invoke(
|
self._round_robin_invoke(
|
||||||
function=self.model_type_instance.get_num_tokens,
|
function=self.model_type_instance.get_num_tokens,
|
||||||
model=self.model,
|
model=self.model_name,
|
||||||
credentials=self.credentials,
|
credentials=self.credentials,
|
||||||
texts=texts,
|
texts=texts,
|
||||||
),
|
),
|
||||||
|
|
@ -294,7 +294,7 @@ class ModelInstance:
|
||||||
RerankResult,
|
RerankResult,
|
||||||
self._round_robin_invoke(
|
self._round_robin_invoke(
|
||||||
function=self.model_type_instance.invoke,
|
function=self.model_type_instance.invoke,
|
||||||
model=self.model,
|
model=self.model_name,
|
||||||
credentials=self.credentials,
|
credentials=self.credentials,
|
||||||
query=query,
|
query=query,
|
||||||
docs=docs,
|
docs=docs,
|
||||||
|
|
@ -328,7 +328,7 @@ class ModelInstance:
|
||||||
RerankResult,
|
RerankResult,
|
||||||
self._round_robin_invoke(
|
self._round_robin_invoke(
|
||||||
function=self.model_type_instance.invoke_multimodal_rerank,
|
function=self.model_type_instance.invoke_multimodal_rerank,
|
||||||
model=self.model,
|
model=self.model_name,
|
||||||
credentials=self.credentials,
|
credentials=self.credentials,
|
||||||
query=query,
|
query=query,
|
||||||
docs=docs,
|
docs=docs,
|
||||||
|
|
@ -352,7 +352,7 @@ class ModelInstance:
|
||||||
bool,
|
bool,
|
||||||
self._round_robin_invoke(
|
self._round_robin_invoke(
|
||||||
function=self.model_type_instance.invoke,
|
function=self.model_type_instance.invoke,
|
||||||
model=self.model,
|
model=self.model_name,
|
||||||
credentials=self.credentials,
|
credentials=self.credentials,
|
||||||
text=text,
|
text=text,
|
||||||
user=user,
|
user=user,
|
||||||
|
|
@ -373,7 +373,7 @@ class ModelInstance:
|
||||||
str,
|
str,
|
||||||
self._round_robin_invoke(
|
self._round_robin_invoke(
|
||||||
function=self.model_type_instance.invoke,
|
function=self.model_type_instance.invoke,
|
||||||
model=self.model,
|
model=self.model_name,
|
||||||
credentials=self.credentials,
|
credentials=self.credentials,
|
||||||
file=file,
|
file=file,
|
||||||
user=user,
|
user=user,
|
||||||
|
|
@ -396,7 +396,7 @@ class ModelInstance:
|
||||||
Iterable[bytes],
|
Iterable[bytes],
|
||||||
self._round_robin_invoke(
|
self._round_robin_invoke(
|
||||||
function=self.model_type_instance.invoke,
|
function=self.model_type_instance.invoke,
|
||||||
model=self.model,
|
model=self.model_name,
|
||||||
credentials=self.credentials,
|
credentials=self.credentials,
|
||||||
content_text=content_text,
|
content_text=content_text,
|
||||||
user=user,
|
user=user,
|
||||||
|
|
@ -469,7 +469,7 @@ class ModelInstance:
|
||||||
if not isinstance(self.model_type_instance, TTSModel):
|
if not isinstance(self.model_type_instance, TTSModel):
|
||||||
raise Exception("Model type instance is not TTSModel")
|
raise Exception("Model type instance is not TTSModel")
|
||||||
return self.model_type_instance.get_tts_model_voices(
|
return self.model_type_instance.get_tts_model_voices(
|
||||||
model=self.model, credentials=self.credentials, language=language
|
model=self.model_name, credentials=self.credentials, language=language
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,9 @@ class AgentHistoryPromptTransform(PromptTransform):
|
||||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||||
|
|
||||||
curr_message_tokens = model_type_instance.get_num_tokens(
|
curr_message_tokens = model_type_instance.get_num_tokens(
|
||||||
self.memory.model_instance.model, self.memory.model_instance.credentials, self.history_messages
|
self.model_config.model,
|
||||||
|
self.model_config.credentials,
|
||||||
|
self.history_messages,
|
||||||
)
|
)
|
||||||
if curr_message_tokens <= max_token_limit:
|
if curr_message_tokens <= max_token_limit:
|
||||||
return self.history_messages
|
return self.history_messages
|
||||||
|
|
@ -63,7 +65,9 @@ class AgentHistoryPromptTransform(PromptTransform):
|
||||||
# a message is start with UserPromptMessage
|
# a message is start with UserPromptMessage
|
||||||
if isinstance(prompt_message, UserPromptMessage):
|
if isinstance(prompt_message, UserPromptMessage):
|
||||||
curr_message_tokens = model_type_instance.get_num_tokens(
|
curr_message_tokens = model_type_instance.get_num_tokens(
|
||||||
self.memory.model_instance.model, self.memory.model_instance.credentials, prompt_messages
|
self.model_config.model,
|
||||||
|
self.model_config.credentials,
|
||||||
|
prompt_messages,
|
||||||
)
|
)
|
||||||
# if current message token is overflow, drop all the prompts in current message and break
|
# if current message token is overflow, drop all the prompts in current message and break
|
||||||
if curr_message_tokens > max_token_limit:
|
if curr_message_tokens > max_token_limit:
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,9 @@ class CacheEmbedding(Embeddings):
|
||||||
embedding = (
|
embedding = (
|
||||||
db.session.query(Embedding)
|
db.session.query(Embedding)
|
||||||
.filter_by(
|
.filter_by(
|
||||||
model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider
|
model_name=self._model_instance.model_name,
|
||||||
|
hash=hash,
|
||||||
|
provider_name=self._model_instance.provider,
|
||||||
)
|
)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
@ -52,7 +54,7 @@ class CacheEmbedding(Embeddings):
|
||||||
try:
|
try:
|
||||||
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
|
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
|
||||||
model_schema = model_type_instance.get_model_schema(
|
model_schema = model_type_instance.get_model_schema(
|
||||||
self._model_instance.model, self._model_instance.credentials
|
self._model_instance.model_name, self._model_instance.credentials
|
||||||
)
|
)
|
||||||
max_chunks = (
|
max_chunks = (
|
||||||
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
|
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
|
||||||
|
|
@ -87,7 +89,7 @@ class CacheEmbedding(Embeddings):
|
||||||
hash = helper.generate_text_hash(texts[i])
|
hash = helper.generate_text_hash(texts[i])
|
||||||
if hash not in cache_embeddings:
|
if hash not in cache_embeddings:
|
||||||
embedding_cache = Embedding(
|
embedding_cache = Embedding(
|
||||||
model_name=self._model_instance.model,
|
model_name=self._model_instance.model_name,
|
||||||
hash=hash,
|
hash=hash,
|
||||||
provider_name=self._model_instance.provider,
|
provider_name=self._model_instance.provider,
|
||||||
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
|
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
|
||||||
|
|
@ -114,7 +116,9 @@ class CacheEmbedding(Embeddings):
|
||||||
embedding = (
|
embedding = (
|
||||||
db.session.query(Embedding)
|
db.session.query(Embedding)
|
||||||
.filter_by(
|
.filter_by(
|
||||||
model_name=self._model_instance.model, hash=file_id, provider_name=self._model_instance.provider
|
model_name=self._model_instance.model_name,
|
||||||
|
hash=file_id,
|
||||||
|
provider_name=self._model_instance.provider,
|
||||||
)
|
)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
@ -131,7 +135,7 @@ class CacheEmbedding(Embeddings):
|
||||||
try:
|
try:
|
||||||
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
|
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
|
||||||
model_schema = model_type_instance.get_model_schema(
|
model_schema = model_type_instance.get_model_schema(
|
||||||
self._model_instance.model, self._model_instance.credentials
|
self._model_instance.model_name, self._model_instance.credentials
|
||||||
)
|
)
|
||||||
max_chunks = (
|
max_chunks = (
|
||||||
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
|
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
|
||||||
|
|
@ -168,7 +172,7 @@ class CacheEmbedding(Embeddings):
|
||||||
file_id = multimodel_documents[i]["file_id"]
|
file_id = multimodel_documents[i]["file_id"]
|
||||||
if file_id not in cache_embeddings:
|
if file_id not in cache_embeddings:
|
||||||
embedding_cache = Embedding(
|
embedding_cache = Embedding(
|
||||||
model_name=self._model_instance.model,
|
model_name=self._model_instance.model_name,
|
||||||
hash=file_id,
|
hash=file_id,
|
||||||
provider_name=self._model_instance.provider,
|
provider_name=self._model_instance.provider,
|
||||||
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
|
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
|
||||||
|
|
@ -190,7 +194,7 @@ class CacheEmbedding(Embeddings):
|
||||||
"""Embed query text."""
|
"""Embed query text."""
|
||||||
# use doc embedding cache or store if not exists
|
# use doc embedding cache or store if not exists
|
||||||
hash = helper.generate_text_hash(text)
|
hash = helper.generate_text_hash(text)
|
||||||
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
|
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{hash}"
|
||||||
embedding = redis_client.get(embedding_cache_key)
|
embedding = redis_client.get(embedding_cache_key)
|
||||||
if embedding:
|
if embedding:
|
||||||
redis_client.expire(embedding_cache_key, 600)
|
redis_client.expire(embedding_cache_key, 600)
|
||||||
|
|
@ -233,7 +237,7 @@ class CacheEmbedding(Embeddings):
|
||||||
"""Embed multimodal documents."""
|
"""Embed multimodal documents."""
|
||||||
# use doc embedding cache or store if not exists
|
# use doc embedding cache or store if not exists
|
||||||
file_id = multimodel_document["file_id"]
|
file_id = multimodel_document["file_id"]
|
||||||
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{file_id}"
|
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{file_id}"
|
||||||
embedding = redis_client.get(embedding_cache_key)
|
embedding = redis_client.get(embedding_cache_key)
|
||||||
if embedding:
|
if embedding:
|
||||||
redis_client.expire(embedding_cache_key, 600)
|
redis_client.expire(embedding_cache_key, 600)
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ class RerankModelRunner(BaseRerankRunner):
|
||||||
is_support_vision = model_manager.check_model_support_vision(
|
is_support_vision = model_manager.check_model_support_vision(
|
||||||
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
|
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
|
||||||
provider=self.rerank_model_instance.provider,
|
provider=self.rerank_model_instance.provider,
|
||||||
model=self.rerank_model_instance.model,
|
model=self.rerank_model_instance.model_name,
|
||||||
model_type=ModelType.RERANK,
|
model_type=ModelType.RERANK,
|
||||||
)
|
)
|
||||||
if not is_support_vision:
|
if not is_support_vision:
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,7 @@ class ModelInvocationUtils:
|
||||||
raise InvokeModelError("Model not found")
|
raise InvokeModelError("Model not found")
|
||||||
|
|
||||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||||
schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
|
||||||
|
|
||||||
if not schema:
|
if not schema:
|
||||||
raise InvokeModelError("No model schema found")
|
raise InvokeModelError("No model schema found")
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from configs import dify_config
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelInstance, ModelManager
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
|
|
@ -17,6 +17,8 @@ from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegme
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.file.models import File
|
from core.workflow.file.models import File
|
||||||
from core.workflow.nodes.llm.entities import ModelConfig
|
from core.workflow.nodes.llm.entities import ModelConfig
|
||||||
|
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
||||||
|
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||||
from core.workflow.runtime import VariablePool
|
from core.workflow.runtime import VariablePool
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
|
|
@ -24,49 +26,46 @@ from models.model import Conversation
|
||||||
from models.provider import Provider, ProviderType
|
from models.provider import Provider, ProviderType
|
||||||
from models.provider_ids import ModelProviderID
|
from models.provider_ids import ModelProviderID
|
||||||
|
|
||||||
from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError
|
from .exc import InvalidVariableTypeError
|
||||||
|
|
||||||
|
|
||||||
def fetch_model_config(
|
def fetch_model_config(
|
||||||
tenant_id: str, node_data_model: ModelConfig
|
*,
|
||||||
|
node_data_model: ModelConfig,
|
||||||
|
credentials_provider: CredentialsProvider,
|
||||||
|
model_factory: ModelFactory,
|
||||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||||
if not node_data_model.mode:
|
if not node_data_model.mode:
|
||||||
raise LLMModeRequiredError("LLM mode is required.")
|
raise LLMModeRequiredError("LLM mode is required.")
|
||||||
|
|
||||||
model = ModelManager().get_model_instance(
|
credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name)
|
||||||
tenant_id=tenant_id,
|
model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
|
||||||
model_type=ModelType.LLM,
|
provider_model_bundle = model_instance.provider_model_bundle
|
||||||
provider=node_data_model.provider,
|
|
||||||
|
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||||
model=node_data_model.name,
|
model=node_data_model.name,
|
||||||
|
model_type=ModelType.LLM,
|
||||||
)
|
)
|
||||||
|
|
||||||
model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
|
|
||||||
|
|
||||||
# check model
|
|
||||||
provider_model = model.provider_model_bundle.configuration.get_provider_model(
|
|
||||||
model=node_data_model.name, model_type=ModelType.LLM
|
|
||||||
)
|
|
||||||
|
|
||||||
if provider_model is None:
|
if provider_model is None:
|
||||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||||
provider_model.raise_for_status()
|
provider_model.raise_for_status()
|
||||||
|
|
||||||
# model config
|
|
||||||
stop: list[str] = []
|
stop: list[str] = []
|
||||||
if "stop" in node_data_model.completion_params:
|
if "stop" in node_data_model.completion_params:
|
||||||
stop = node_data_model.completion_params.pop("stop")
|
stop = node_data_model.completion_params.pop("stop")
|
||||||
|
|
||||||
model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
|
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
|
||||||
if not model_schema:
|
if not model_schema:
|
||||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||||
|
|
||||||
return model, ModelConfigWithCredentialsEntity(
|
model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||||
|
return model_instance, ModelConfigWithCredentialsEntity(
|
||||||
provider=node_data_model.provider,
|
provider=node_data_model.provider,
|
||||||
model=node_data_model.name,
|
model=node_data_model.name,
|
||||||
model_schema=model_schema,
|
model_schema=model_schema,
|
||||||
mode=node_data_model.mode,
|
mode=node_data_model.mode,
|
||||||
provider_model_bundle=model.provider_model_bundle,
|
provider_model_bundle=provider_model_bundle,
|
||||||
credentials=model.credentials,
|
credentials=credentials,
|
||||||
parameters=node_data_model.completion_params,
|
parameters=node_data_model.completion_params,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
)
|
)
|
||||||
|
|
@ -131,7 +130,7 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
|
||||||
if quota_unit == QuotaUnit.TOKENS:
|
if quota_unit == QuotaUnit.TOKENS:
|
||||||
used_quota = usage.total_tokens
|
used_quota = usage.total_tokens
|
||||||
elif quota_unit == QuotaUnit.CREDITS:
|
elif quota_unit == QuotaUnit.CREDITS:
|
||||||
used_quota = dify_config.get_model_credits(model_instance.model)
|
used_quota = dify_config.get_model_credits(model_instance.model_name)
|
||||||
else:
|
else:
|
||||||
used_quota = 1
|
used_quota = 1
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||||
from core.llm_generator.output_parser.errors import OutputParserError
|
from core.llm_generator.output_parser.errors import OutputParserError
|
||||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelInstance, ModelManager
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities import (
|
from core.model_runtime.entities import (
|
||||||
ImagePromptMessageContent,
|
ImagePromptMessageContent,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
|
|
@ -38,11 +38,7 @@ from core.model_runtime.entities.message_entities import (
|
||||||
SystemPromptMessage,
|
SystemPromptMessage,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
from core.model_runtime.entities.model_entities import (
|
from core.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey
|
||||||
ModelFeature,
|
|
||||||
ModelPropertyKey,
|
|
||||||
ModelType,
|
|
||||||
)
|
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||||
|
|
@ -76,6 +72,7 @@ from core.workflow.node_events import (
|
||||||
from core.workflow.nodes.base.entities import VariableSelector
|
from core.workflow.nodes.base.entities import VariableSelector
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||||
|
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||||
from core.workflow.runtime import VariablePool
|
from core.workflow.runtime import VariablePool
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.dataset import SegmentAttachmentBinding
|
from models.dataset import SegmentAttachmentBinding
|
||||||
|
|
@ -93,7 +90,6 @@ from .exc import (
|
||||||
InvalidVariableTypeError,
|
InvalidVariableTypeError,
|
||||||
LLMNodeError,
|
LLMNodeError,
|
||||||
MemoryRolePrefixRequiredError,
|
MemoryRolePrefixRequiredError,
|
||||||
ModelNotExistError,
|
|
||||||
NoPromptFoundError,
|
NoPromptFoundError,
|
||||||
TemplateTypeNotSupportError,
|
TemplateTypeNotSupportError,
|
||||||
VariableNotFoundError,
|
VariableNotFoundError,
|
||||||
|
|
@ -118,6 +114,8 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
_file_outputs: list[File]
|
_file_outputs: list[File]
|
||||||
|
|
||||||
_llm_file_saver: LLMFileSaver
|
_llm_file_saver: LLMFileSaver
|
||||||
|
_credentials_provider: CredentialsProvider
|
||||||
|
_model_factory: ModelFactory
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -126,6 +124,8 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
graph_init_params: GraphInitParams,
|
graph_init_params: GraphInitParams,
|
||||||
graph_runtime_state: GraphRuntimeState,
|
graph_runtime_state: GraphRuntimeState,
|
||||||
*,
|
*,
|
||||||
|
credentials_provider: CredentialsProvider,
|
||||||
|
model_factory: ModelFactory,
|
||||||
llm_file_saver: LLMFileSaver | None = None,
|
llm_file_saver: LLMFileSaver | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
|
@ -137,6 +137,9 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
# LLM file outputs, used for MultiModal outputs.
|
# LLM file outputs, used for MultiModal outputs.
|
||||||
self._file_outputs = []
|
self._file_outputs = []
|
||||||
|
|
||||||
|
self._credentials_provider = credentials_provider
|
||||||
|
self._model_factory = model_factory
|
||||||
|
|
||||||
if llm_file_saver is None:
|
if llm_file_saver is None:
|
||||||
llm_file_saver = FileSaverImpl(
|
llm_file_saver = FileSaverImpl(
|
||||||
user_id=graph_init_params.user_id,
|
user_id=graph_init_params.user_id,
|
||||||
|
|
@ -199,10 +202,21 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
node_inputs["#context_files#"] = [file.model_dump() for file in context_files]
|
node_inputs["#context_files#"] = [file.model_dump() for file in context_files]
|
||||||
|
|
||||||
# fetch model config
|
# fetch model config
|
||||||
model_instance, model_config = LLMNode._fetch_model_config(
|
model_instance, model_config = self._fetch_model_config(
|
||||||
node_data_model=self.node_data.model,
|
node_data_model=self.node_data.model,
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
)
|
)
|
||||||
|
model_name = getattr(model_instance, "model_name", None)
|
||||||
|
if not isinstance(model_name, str):
|
||||||
|
model_name = model_config.model
|
||||||
|
model_provider = getattr(model_instance, "provider", None)
|
||||||
|
if not isinstance(model_provider, str):
|
||||||
|
model_provider = model_config.provider
|
||||||
|
model_schema = model_instance.model_type_instance.get_model_schema(
|
||||||
|
model_name,
|
||||||
|
model_instance.credentials,
|
||||||
|
)
|
||||||
|
if not model_schema:
|
||||||
|
raise ValueError(f"Model schema not found for {model_name}")
|
||||||
|
|
||||||
# fetch memory
|
# fetch memory
|
||||||
memory = llm_utils.fetch_memory(
|
memory = llm_utils.fetch_memory(
|
||||||
|
|
@ -225,14 +239,16 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
sys_files=files,
|
sys_files=files,
|
||||||
context=context,
|
context=context,
|
||||||
memory=memory,
|
memory=memory,
|
||||||
model_config=model_config,
|
model_instance=model_instance,
|
||||||
|
model_schema=model_schema,
|
||||||
|
model_parameters=self.node_data.model.completion_params,
|
||||||
|
stop=model_config.stop,
|
||||||
prompt_template=self.node_data.prompt_template,
|
prompt_template=self.node_data.prompt_template,
|
||||||
memory_config=self.node_data.memory,
|
memory_config=self.node_data.memory,
|
||||||
vision_enabled=self.node_data.vision.enabled,
|
vision_enabled=self.node_data.vision.enabled,
|
||||||
vision_detail=self.node_data.vision.configs.detail,
|
vision_detail=self.node_data.vision.configs.detail,
|
||||||
variable_pool=variable_pool,
|
variable_pool=variable_pool,
|
||||||
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
context_files=context_files,
|
context_files=context_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -286,14 +302,14 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
structured_output = event
|
structured_output = event
|
||||||
|
|
||||||
process_data = {
|
process_data = {
|
||||||
"model_mode": model_config.mode,
|
"model_mode": self.node_data.model.mode,
|
||||||
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||||
model_mode=model_config.mode, prompt_messages=prompt_messages
|
model_mode=self.node_data.model.mode, prompt_messages=prompt_messages
|
||||||
),
|
),
|
||||||
"usage": jsonable_encoder(usage),
|
"usage": jsonable_encoder(usage),
|
||||||
"finish_reason": finish_reason,
|
"finish_reason": finish_reason,
|
||||||
"model_provider": model_config.provider,
|
"model_provider": model_provider,
|
||||||
"model_name": model_config.model,
|
"model_name": model_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
outputs = {
|
outputs = {
|
||||||
|
|
@ -755,21 +771,18 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _fetch_model_config(
|
def _fetch_model_config(
|
||||||
|
self,
|
||||||
*,
|
*,
|
||||||
node_data_model: ModelConfig,
|
node_data_model: ModelConfig,
|
||||||
tenant_id: str,
|
|
||||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||||
model, model_config_with_cred = llm_utils.fetch_model_config(
|
model, model_config_with_cred = llm_utils.fetch_model_config(
|
||||||
tenant_id=tenant_id, node_data_model=node_data_model
|
node_data_model=node_data_model,
|
||||||
|
credentials_provider=self._credentials_provider,
|
||||||
|
model_factory=self._model_factory,
|
||||||
)
|
)
|
||||||
completion_params = model_config_with_cred.parameters
|
completion_params = model_config_with_cred.parameters
|
||||||
|
|
||||||
model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
|
|
||||||
if not model_schema:
|
|
||||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
|
||||||
|
|
||||||
model_config_with_cred.parameters = completion_params
|
model_config_with_cred.parameters = completion_params
|
||||||
# NOTE(-LAN-): This line modify the `self.node_data.model`, which is used in `_invoke_llm()`.
|
# NOTE(-LAN-): This line modify the `self.node_data.model`, which is used in `_invoke_llm()`.
|
||||||
node_data_model.completion_params = completion_params
|
node_data_model.completion_params = completion_params
|
||||||
|
|
@ -782,14 +795,16 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
sys_files: Sequence[File],
|
sys_files: Sequence[File],
|
||||||
context: str | None = None,
|
context: str | None = None,
|
||||||
memory: TokenBufferMemory | None = None,
|
memory: TokenBufferMemory | None = None,
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_instance: ModelInstance,
|
||||||
|
model_schema: AIModelEntity,
|
||||||
|
model_parameters: Mapping[str, Any],
|
||||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
|
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
|
||||||
|
stop: Sequence[str] | None = None,
|
||||||
memory_config: MemoryConfig | None = None,
|
memory_config: MemoryConfig | None = None,
|
||||||
vision_enabled: bool = False,
|
vision_enabled: bool = False,
|
||||||
vision_detail: ImagePromptMessageContent.DETAIL,
|
vision_detail: ImagePromptMessageContent.DETAIL,
|
||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
jinja2_variables: Sequence[VariableSelector],
|
jinja2_variables: Sequence[VariableSelector],
|
||||||
tenant_id: str,
|
|
||||||
context_files: list[File] | None = None,
|
context_files: list[File] | None = None,
|
||||||
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
|
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
|
||||||
prompt_messages: list[PromptMessage] = []
|
prompt_messages: list[PromptMessage] = []
|
||||||
|
|
@ -810,7 +825,9 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
memory_messages = _handle_memory_chat_mode(
|
memory_messages = _handle_memory_chat_mode(
|
||||||
memory=memory,
|
memory=memory,
|
||||||
memory_config=memory_config,
|
memory_config=memory_config,
|
||||||
model_config=model_config,
|
model_instance=model_instance,
|
||||||
|
model_schema=model_schema,
|
||||||
|
model_parameters=model_parameters,
|
||||||
)
|
)
|
||||||
# Extend prompt_messages with memory messages
|
# Extend prompt_messages with memory messages
|
||||||
prompt_messages.extend(memory_messages)
|
prompt_messages.extend(memory_messages)
|
||||||
|
|
@ -847,7 +864,9 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
memory_text = _handle_memory_completion_mode(
|
memory_text = _handle_memory_completion_mode(
|
||||||
memory=memory,
|
memory=memory,
|
||||||
memory_config=memory_config,
|
memory_config=memory_config,
|
||||||
model_config=model_config,
|
model_instance=model_instance,
|
||||||
|
model_schema=model_schema,
|
||||||
|
model_parameters=model_parameters,
|
||||||
)
|
)
|
||||||
# Insert histories into the prompt
|
# Insert histories into the prompt
|
||||||
prompt_content = prompt_messages[0].content
|
prompt_content = prompt_messages[0].content
|
||||||
|
|
@ -924,7 +943,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
prompt_message_content: list[PromptMessageContentUnionTypes] = []
|
prompt_message_content: list[PromptMessageContentUnionTypes] = []
|
||||||
for content_item in prompt_message.content:
|
for content_item in prompt_message.content:
|
||||||
# Skip content if features are not defined
|
# Skip content if features are not defined
|
||||||
if not model_config.model_schema.features:
|
if not model_schema.features:
|
||||||
if content_item.type != PromptMessageContentType.TEXT:
|
if content_item.type != PromptMessageContentType.TEXT:
|
||||||
continue
|
continue
|
||||||
prompt_message_content.append(content_item)
|
prompt_message_content.append(content_item)
|
||||||
|
|
@ -934,19 +953,19 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
if (
|
if (
|
||||||
(
|
(
|
||||||
content_item.type == PromptMessageContentType.IMAGE
|
content_item.type == PromptMessageContentType.IMAGE
|
||||||
and ModelFeature.VISION not in model_config.model_schema.features
|
and ModelFeature.VISION not in model_schema.features
|
||||||
)
|
)
|
||||||
or (
|
or (
|
||||||
content_item.type == PromptMessageContentType.DOCUMENT
|
content_item.type == PromptMessageContentType.DOCUMENT
|
||||||
and ModelFeature.DOCUMENT not in model_config.model_schema.features
|
and ModelFeature.DOCUMENT not in model_schema.features
|
||||||
)
|
)
|
||||||
or (
|
or (
|
||||||
content_item.type == PromptMessageContentType.VIDEO
|
content_item.type == PromptMessageContentType.VIDEO
|
||||||
and ModelFeature.VIDEO not in model_config.model_schema.features
|
and ModelFeature.VIDEO not in model_schema.features
|
||||||
)
|
)
|
||||||
or (
|
or (
|
||||||
content_item.type == PromptMessageContentType.AUDIO
|
content_item.type == PromptMessageContentType.AUDIO
|
||||||
and ModelFeature.AUDIO not in model_config.model_schema.features
|
and ModelFeature.AUDIO not in model_schema.features
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
|
@ -965,19 +984,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
"Please ensure a prompt is properly configured before proceeding."
|
"Please ensure a prompt is properly configured before proceeding."
|
||||||
)
|
)
|
||||||
|
|
||||||
model = ModelManager().get_model_instance(
|
return filtered_prompt_messages, stop
|
||||||
tenant_id=tenant_id,
|
|
||||||
model_type=ModelType.LLM,
|
|
||||||
provider=model_config.provider,
|
|
||||||
model=model_config.model,
|
|
||||||
)
|
|
||||||
model_schema = model.model_type_instance.get_model_schema(
|
|
||||||
model=model_config.model,
|
|
||||||
credentials=model.credentials,
|
|
||||||
)
|
|
||||||
if not model_schema:
|
|
||||||
raise ModelNotExistError(f"Model {model_config.model} not exist.")
|
|
||||||
return filtered_prompt_messages, model_config.stop
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_variable_selector_to_variable_mapping(
|
def _extract_variable_selector_to_variable_mapping(
|
||||||
|
|
@ -1306,26 +1313,26 @@ def _render_jinja2_message(
|
||||||
|
|
||||||
|
|
||||||
def _calculate_rest_token(
|
def _calculate_rest_token(
|
||||||
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
|
*,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_instance: ModelInstance,
|
||||||
|
model_schema: AIModelEntity,
|
||||||
|
model_parameters: Mapping[str, Any],
|
||||||
) -> int:
|
) -> int:
|
||||||
rest_tokens = 2000
|
rest_tokens = 2000
|
||||||
|
|
||||||
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||||
if model_context_tokens:
|
if model_context_tokens:
|
||||||
model_instance = ModelInstance(
|
|
||||||
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
|
|
||||||
)
|
|
||||||
|
|
||||||
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||||
|
|
||||||
max_tokens = 0
|
max_tokens = 0
|
||||||
for parameter_rule in model_config.model_schema.parameter_rules:
|
for parameter_rule in model_schema.parameter_rules:
|
||||||
if parameter_rule.name == "max_tokens" or (
|
if parameter_rule.name == "max_tokens" or (
|
||||||
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
|
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
|
||||||
):
|
):
|
||||||
max_tokens = (
|
max_tokens = (
|
||||||
model_config.parameters.get(parameter_rule.name)
|
model_parameters.get(parameter_rule.name)
|
||||||
or model_config.parameters.get(str(parameter_rule.use_template))
|
or model_parameters.get(str(parameter_rule.use_template))
|
||||||
or 0
|
or 0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1339,12 +1346,19 @@ def _handle_memory_chat_mode(
|
||||||
*,
|
*,
|
||||||
memory: TokenBufferMemory | None,
|
memory: TokenBufferMemory | None,
|
||||||
memory_config: MemoryConfig | None,
|
memory_config: MemoryConfig | None,
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_instance: ModelInstance,
|
||||||
|
model_schema: AIModelEntity,
|
||||||
|
model_parameters: Mapping[str, Any],
|
||||||
) -> Sequence[PromptMessage]:
|
) -> Sequence[PromptMessage]:
|
||||||
memory_messages: Sequence[PromptMessage] = []
|
memory_messages: Sequence[PromptMessage] = []
|
||||||
# Get messages from memory for chat model
|
# Get messages from memory for chat model
|
||||||
if memory and memory_config:
|
if memory and memory_config:
|
||||||
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
|
rest_tokens = _calculate_rest_token(
|
||||||
|
prompt_messages=[],
|
||||||
|
model_instance=model_instance,
|
||||||
|
model_schema=model_schema,
|
||||||
|
model_parameters=model_parameters,
|
||||||
|
)
|
||||||
memory_messages = memory.get_history_prompt_messages(
|
memory_messages = memory.get_history_prompt_messages(
|
||||||
max_token_limit=rest_tokens,
|
max_token_limit=rest_tokens,
|
||||||
message_limit=memory_config.window.size if memory_config.window.enabled else None,
|
message_limit=memory_config.window.size if memory_config.window.enabled else None,
|
||||||
|
|
@ -1356,12 +1370,19 @@ def _handle_memory_completion_mode(
|
||||||
*,
|
*,
|
||||||
memory: TokenBufferMemory | None,
|
memory: TokenBufferMemory | None,
|
||||||
memory_config: MemoryConfig | None,
|
memory_config: MemoryConfig | None,
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_instance: ModelInstance,
|
||||||
|
model_schema: AIModelEntity,
|
||||||
|
model_parameters: Mapping[str, Any],
|
||||||
) -> str:
|
) -> str:
|
||||||
memory_text = ""
|
memory_text = ""
|
||||||
# Get history text from memory for completion model
|
# Get history text from memory for completion model
|
||||||
if memory and memory_config:
|
if memory and memory_config:
|
||||||
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
|
rest_tokens = _calculate_rest_token(
|
||||||
|
prompt_messages=[],
|
||||||
|
model_instance=model_instance,
|
||||||
|
model_schema=model_schema,
|
||||||
|
model_parameters=model_parameters,
|
||||||
|
)
|
||||||
if not memory_config.role_prefix:
|
if not memory_config.role_prefix:
|
||||||
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
|
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
|
||||||
memory_text = memory.get_history_prompt_text(
|
memory_text = memory.get_history_prompt_text(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
from core.model_manager import ModelInstance
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialsProvider(Protocol):
|
||||||
|
"""Port for loading runtime credentials for a provider/model pair."""
|
||||||
|
|
||||||
|
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
|
||||||
|
"""Return credentials for the target provider/model or raise a domain error."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class ModelFactory(Protocol):
|
||||||
|
"""Port for creating initialized LLM model instances for execution."""
|
||||||
|
|
||||||
|
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
|
||||||
|
"""Create a model instance that is ready for schema lookup and invocation."""
|
||||||
|
...
|
||||||
|
|
@ -3,7 +3,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
|
|
@ -60,6 +60,11 @@ from .prompts import (
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.workflow.entities import GraphInitParams
|
||||||
|
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||||
|
from core.workflow.runtime import GraphRuntimeState
|
||||||
|
|
||||||
|
|
||||||
def extract_json(text):
|
def extract_json(text):
|
||||||
"""
|
"""
|
||||||
|
|
@ -92,6 +97,27 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||||
|
|
||||||
_model_instance: ModelInstance | None = None
|
_model_instance: ModelInstance | None = None
|
||||||
_model_config: ModelConfigWithCredentialsEntity | None = None
|
_model_config: ModelConfigWithCredentialsEntity | None = None
|
||||||
|
_credentials_provider: "CredentialsProvider"
|
||||||
|
_model_factory: "ModelFactory"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
id: str,
|
||||||
|
config: Mapping[str, Any],
|
||||||
|
graph_init_params: "GraphInitParams",
|
||||||
|
graph_runtime_state: "GraphRuntimeState",
|
||||||
|
*,
|
||||||
|
credentials_provider: "CredentialsProvider",
|
||||||
|
model_factory: "ModelFactory",
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
id=id,
|
||||||
|
config=config,
|
||||||
|
graph_init_params=graph_init_params,
|
||||||
|
graph_runtime_state=graph_runtime_state,
|
||||||
|
)
|
||||||
|
self._credentials_provider = credentials_provider
|
||||||
|
self._model_factory = model_factory
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||||
|
|
@ -806,7 +832,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||||
"""
|
"""
|
||||||
if not self._model_instance or not self._model_config:
|
if not self._model_instance or not self._model_config:
|
||||||
self._model_instance, self._model_config = llm_utils.fetch_model_config(
|
self._model_instance, self._model_config = llm_utils.fetch_model_config(
|
||||||
tenant_id=self.tenant_id, node_data_model=node_data_model
|
node_data_model=node_data_model,
|
||||||
|
credentials_provider=self._credentials_provider,
|
||||||
|
model_factory=self._model_factory,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._model_instance, self._model_config
|
return self._model_instance, self._model_config
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||||
from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils
|
from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils
|
||||||
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
||||||
|
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||||
|
|
||||||
from .entities import QuestionClassifierNodeData
|
from .entities import QuestionClassifierNodeData
|
||||||
|
|
@ -49,6 +50,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||||
|
|
||||||
_file_outputs: list["File"]
|
_file_outputs: list["File"]
|
||||||
_llm_file_saver: LLMFileSaver
|
_llm_file_saver: LLMFileSaver
|
||||||
|
_credentials_provider: "CredentialsProvider"
|
||||||
|
_model_factory: "ModelFactory"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -57,6 +60,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||||
graph_init_params: "GraphInitParams",
|
graph_init_params: "GraphInitParams",
|
||||||
graph_runtime_state: "GraphRuntimeState",
|
graph_runtime_state: "GraphRuntimeState",
|
||||||
*,
|
*,
|
||||||
|
credentials_provider: "CredentialsProvider",
|
||||||
|
model_factory: "ModelFactory",
|
||||||
llm_file_saver: LLMFileSaver | None = None,
|
llm_file_saver: LLMFileSaver | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
|
@ -68,6 +73,9 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||||
# LLM file outputs, used for MultiModal outputs.
|
# LLM file outputs, used for MultiModal outputs.
|
||||||
self._file_outputs = []
|
self._file_outputs = []
|
||||||
|
|
||||||
|
self._credentials_provider = credentials_provider
|
||||||
|
self._model_factory = model_factory
|
||||||
|
|
||||||
if llm_file_saver is None:
|
if llm_file_saver is None:
|
||||||
llm_file_saver = FileSaverImpl(
|
llm_file_saver = FileSaverImpl(
|
||||||
user_id=graph_init_params.user_id,
|
user_id=graph_init_params.user_id,
|
||||||
|
|
@ -89,9 +97,16 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||||
variables = {"query": query}
|
variables = {"query": query}
|
||||||
# fetch model config
|
# fetch model config
|
||||||
model_instance, model_config = llm_utils.fetch_model_config(
|
model_instance, model_config = llm_utils.fetch_model_config(
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
node_data_model=node_data.model,
|
node_data_model=node_data.model,
|
||||||
|
credentials_provider=self._credentials_provider,
|
||||||
|
model_factory=self._model_factory,
|
||||||
)
|
)
|
||||||
|
model_schema = model_instance.model_type_instance.get_model_schema(
|
||||||
|
model_instance.model_name,
|
||||||
|
model_instance.credentials,
|
||||||
|
)
|
||||||
|
if not model_schema:
|
||||||
|
raise ValueError(f"Model schema not found for {model_instance.model_name}")
|
||||||
# fetch memory
|
# fetch memory
|
||||||
memory = llm_utils.fetch_memory(
|
memory = llm_utils.fetch_memory(
|
||||||
variable_pool=variable_pool,
|
variable_pool=variable_pool,
|
||||||
|
|
@ -133,13 +148,15 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||||
prompt_template=prompt_template,
|
prompt_template=prompt_template,
|
||||||
sys_query="",
|
sys_query="",
|
||||||
memory=memory,
|
memory=memory,
|
||||||
model_config=model_config,
|
model_instance=model_instance,
|
||||||
|
model_schema=model_schema,
|
||||||
|
model_parameters=node_data.model.completion_params,
|
||||||
|
stop=model_config.stop,
|
||||||
sys_files=files,
|
sys_files=files,
|
||||||
vision_enabled=node_data.vision.enabled,
|
vision_enabled=node_data.vision.enabled,
|
||||||
vision_detail=node_data.vision.configs.detail,
|
vision_detail=node_data.vision.configs.detail,
|
||||||
variable_pool=variable_pool,
|
variable_pool=variable_pool,
|
||||||
jinja2_variables=[],
|
jinja2_variables=[],
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result_text = ""
|
result_text = ""
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.app.apps.exc import GenerateTaskStoppedError
|
from core.app.apps.exc import GenerateTaskStoppedError
|
||||||
|
|
@ -11,6 +10,7 @@ from core.app.workflow.layers.observability import ObservabilityLayer
|
||||||
from core.app.workflow.node_factory import DifyNodeFactory
|
from core.app.workflow.node_factory import DifyNodeFactory
|
||||||
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
|
from core.workflow.entities.graph_config import NodeConfigData, NodeConfigDict
|
||||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||||
from core.workflow.file.models import File
|
from core.workflow.file.models import File
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
|
|
@ -168,7 +168,8 @@ class WorkflowEntry:
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
node = node_factory.create_node(node_config)
|
typed_node_config = cast(dict[str, object], node_config)
|
||||||
|
node = cast(Any, node_factory).create_node(typed_node_config)
|
||||||
node_cls = type(node)
|
node_cls = type(node)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -256,7 +257,7 @@ class WorkflowEntry:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def run_free_node(
|
def run_free_node(
|
||||||
cls, node_data: dict, node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
|
cls, node_data: dict[str, Any], node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
|
||||||
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
|
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
|
||||||
"""
|
"""
|
||||||
Run free node
|
Run free node
|
||||||
|
|
@ -302,16 +303,15 @@ class WorkflowEntry:
|
||||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||||
|
|
||||||
# init workflow run state
|
# init workflow run state
|
||||||
node_config = {
|
node_config: NodeConfigDict = {
|
||||||
"id": node_id,
|
"id": node_id,
|
||||||
"data": node_data,
|
"data": cast(NodeConfigData, node_data),
|
||||||
}
|
}
|
||||||
node: Node = node_cls(
|
node_factory = DifyNodeFactory(
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
config=node_config,
|
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
|
node = node_factory.create_node(node_config)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# variable selector to variable mapping
|
# variable selector to variable mapping
|
||||||
|
|
|
||||||
|
|
@ -107,19 +107,19 @@ class AppService:
|
||||||
|
|
||||||
if model_instance:
|
if model_instance:
|
||||||
if (
|
if (
|
||||||
model_instance.model == default_model_config["model"]["name"]
|
model_instance.model_name == default_model_config["model"]["name"]
|
||||||
and model_instance.provider == default_model_config["model"]["provider"]
|
and model_instance.provider == default_model_config["model"]["provider"]
|
||||||
):
|
):
|
||||||
default_model_dict = default_model_config["model"]
|
default_model_dict = default_model_config["model"]
|
||||||
else:
|
else:
|
||||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||||
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
|
||||||
if model_schema is None:
|
if model_schema is None:
|
||||||
raise ValueError(f"model schema not found for model {model_instance.model}")
|
raise ValueError(f"model schema not found for model {model_instance.model_name}")
|
||||||
|
|
||||||
default_model_dict = {
|
default_model_dict = {
|
||||||
"provider": model_instance.provider,
|
"provider": model_instance.provider,
|
||||||
"name": model_instance.model,
|
"name": model_instance.model_name,
|
||||||
"mode": model_schema.model_properties.get(ModelPropertyKey.MODE),
|
"mode": model_schema.model_properties.get(ModelPropertyKey.MODE),
|
||||||
"completion_params": {},
|
"completion_params": {},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -252,7 +252,7 @@ class DatasetService:
|
||||||
dataset.updated_by = account.id
|
dataset.updated_by = account.id
|
||||||
dataset.tenant_id = tenant_id
|
dataset.tenant_id = tenant_id
|
||||||
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
|
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
|
||||||
dataset.embedding_model = embedding_model.model if embedding_model else None
|
dataset.embedding_model = embedding_model.model_name if embedding_model else None
|
||||||
dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
|
dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
|
||||||
dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
|
dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
|
||||||
dataset.provider = provider
|
dataset.provider = provider
|
||||||
|
|
@ -384,7 +384,7 @@ class DatasetService:
|
||||||
model=model,
|
model=model,
|
||||||
)
|
)
|
||||||
text_embedding_model = cast(TextEmbeddingModel, model_instance.model_type_instance)
|
text_embedding_model = cast(TextEmbeddingModel, model_instance.model_type_instance)
|
||||||
model_schema = text_embedding_model.get_model_schema(model_instance.model, model_instance.credentials)
|
model_schema = text_embedding_model.get_model_schema(model_instance.model_name, model_instance.credentials)
|
||||||
if not model_schema:
|
if not model_schema:
|
||||||
raise ValueError("Model schema not found")
|
raise ValueError("Model schema not found")
|
||||||
if model_schema.features and ModelFeature.VISION in model_schema.features:
|
if model_schema.features and ModelFeature.VISION in model_schema.features:
|
||||||
|
|
@ -743,10 +743,12 @@ class DatasetService:
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=data["embedding_model"],
|
model=data["embedding_model"],
|
||||||
)
|
)
|
||||||
filtered_data["embedding_model"] = embedding_model.model
|
embedding_model_name = embedding_model.model_name
|
||||||
|
filtered_data["embedding_model"] = embedding_model_name
|
||||||
filtered_data["embedding_model_provider"] = embedding_model.provider
|
filtered_data["embedding_model_provider"] = embedding_model.provider
|
||||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||||
embedding_model.provider, embedding_model.model
|
embedding_model.provider,
|
||||||
|
embedding_model_name,
|
||||||
)
|
)
|
||||||
filtered_data["collection_binding_id"] = dataset_collection_binding.id
|
filtered_data["collection_binding_id"] = dataset_collection_binding.id
|
||||||
except LLMBadRequestError:
|
except LLMBadRequestError:
|
||||||
|
|
@ -876,10 +878,12 @@ class DatasetService:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Apply new embedding model settings
|
# Apply new embedding model settings
|
||||||
filtered_data["embedding_model"] = embedding_model.model
|
embedding_model_name = embedding_model.model_name
|
||||||
|
filtered_data["embedding_model"] = embedding_model_name
|
||||||
filtered_data["embedding_model_provider"] = embedding_model.provider
|
filtered_data["embedding_model_provider"] = embedding_model.provider
|
||||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||||
embedding_model.provider, embedding_model.model
|
embedding_model.provider,
|
||||||
|
embedding_model_name,
|
||||||
)
|
)
|
||||||
filtered_data["collection_binding_id"] = dataset_collection_binding.id
|
filtered_data["collection_binding_id"] = dataset_collection_binding.id
|
||||||
|
|
||||||
|
|
@ -955,10 +959,12 @@ class DatasetService:
|
||||||
knowledge_configuration.embedding_model,
|
knowledge_configuration.embedding_model,
|
||||||
)
|
)
|
||||||
dataset.is_multimodal = is_multimodal
|
dataset.is_multimodal = is_multimodal
|
||||||
dataset.embedding_model = embedding_model.model
|
embedding_model_name = embedding_model.model_name
|
||||||
|
dataset.embedding_model = embedding_model_name
|
||||||
dataset.embedding_model_provider = embedding_model.provider
|
dataset.embedding_model_provider = embedding_model.provider
|
||||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||||
embedding_model.provider, embedding_model.model
|
embedding_model.provider,
|
||||||
|
embedding_model_name,
|
||||||
)
|
)
|
||||||
dataset.collection_binding_id = dataset_collection_binding.id
|
dataset.collection_binding_id = dataset_collection_binding.id
|
||||||
elif knowledge_configuration.indexing_technique == "economy":
|
elif knowledge_configuration.indexing_technique == "economy":
|
||||||
|
|
@ -989,10 +995,12 @@ class DatasetService:
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=knowledge_configuration.embedding_model,
|
model=knowledge_configuration.embedding_model,
|
||||||
)
|
)
|
||||||
dataset.embedding_model = embedding_model.model
|
embedding_model_name = embedding_model.model_name
|
||||||
|
dataset.embedding_model = embedding_model_name
|
||||||
dataset.embedding_model_provider = embedding_model.provider
|
dataset.embedding_model_provider = embedding_model.provider
|
||||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||||
embedding_model.provider, embedding_model.model
|
embedding_model.provider,
|
||||||
|
embedding_model_name,
|
||||||
)
|
)
|
||||||
is_multimodal = DatasetService.check_is_multimodal_model(
|
is_multimodal = DatasetService.check_is_multimodal_model(
|
||||||
current_user.current_tenant_id,
|
current_user.current_tenant_id,
|
||||||
|
|
@ -1049,11 +1057,13 @@ class DatasetService:
|
||||||
skip_embedding_update = True
|
skip_embedding_update = True
|
||||||
if not skip_embedding_update:
|
if not skip_embedding_update:
|
||||||
if embedding_model:
|
if embedding_model:
|
||||||
dataset.embedding_model = embedding_model.model
|
embedding_model_name = embedding_model.model_name
|
||||||
|
dataset.embedding_model = embedding_model_name
|
||||||
dataset.embedding_model_provider = embedding_model.provider
|
dataset.embedding_model_provider = embedding_model.provider
|
||||||
dataset_collection_binding = (
|
dataset_collection_binding = (
|
||||||
DatasetCollectionBindingService.get_dataset_collection_binding(
|
DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||||
embedding_model.provider, embedding_model.model
|
embedding_model.provider,
|
||||||
|
embedding_model_name,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
dataset.collection_binding_id = dataset_collection_binding.id
|
dataset.collection_binding_id = dataset_collection_binding.id
|
||||||
|
|
@ -1884,7 +1894,7 @@ class DocumentService:
|
||||||
embedding_model = model_manager.get_default_model_instance(
|
embedding_model = model_manager.get_default_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
|
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
|
||||||
)
|
)
|
||||||
dataset_embedding_model = embedding_model.model
|
dataset_embedding_model = embedding_model.model_name
|
||||||
dataset_embedding_model_provider = embedding_model.provider
|
dataset_embedding_model_provider = embedding_model.provider
|
||||||
dataset.embedding_model = dataset_embedding_model
|
dataset.embedding_model = dataset_embedding_model
|
||||||
dataset.embedding_model_provider = dataset_embedding_model_provider
|
dataset.embedding_model_provider = dataset_embedding_model_provider
|
||||||
|
|
|
||||||
|
|
@ -80,6 +80,8 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||||
config=config,
|
config=config,
|
||||||
graph_init_params=init_params,
|
graph_init_params=init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
|
credentials_provider=MagicMock(),
|
||||||
|
model_factory=MagicMock(),
|
||||||
)
|
)
|
||||||
|
|
||||||
return node
|
return node
|
||||||
|
|
@ -115,7 +117,7 @@ def test_execute_llm():
|
||||||
db.session.close = MagicMock()
|
db.session.close = MagicMock()
|
||||||
|
|
||||||
# Mock the _fetch_model_config to avoid database calls
|
# Mock the _fetch_model_config to avoid database calls
|
||||||
def mock_fetch_model_config(**_kwargs):
|
def mock_fetch_model_config(*_args, **_kwargs):
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
|
@ -227,7 +229,7 @@ def test_execute_llm_with_jinja2():
|
||||||
db.session.close = MagicMock()
|
db.session.close = MagicMock()
|
||||||
|
|
||||||
# Mock the _fetch_model_config method
|
# Mock the _fetch_model_config method
|
||||||
def mock_fetch_model_config(**_kwargs):
|
def mock_fetch_model_config(*_args, **_kwargs):
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from core.model_runtime.entities import AssistantPromptMessage
|
||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
|
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||||
from core.workflow.system_variable import SystemVariable
|
from core.workflow.system_variable import SystemVariable
|
||||||
|
|
@ -84,6 +85,8 @@ def init_parameter_extractor_node(config: dict):
|
||||||
config=config,
|
config=config,
|
||||||
graph_init_params=init_params,
|
graph_init_params=init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
|
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||||
|
model_factory=MagicMock(spec=ModelFactory),
|
||||||
)
|
)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -331,7 +331,7 @@ class TestDatasetServiceUpdateDataset:
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_model = Mock()
|
embedding_model = Mock()
|
||||||
embedding_model.model = "text-embedding-ada-002"
|
embedding_model.model_name = "text-embedding-ada-002"
|
||||||
embedding_model.provider = "openai"
|
embedding_model.provider = "openai"
|
||||||
|
|
||||||
binding = Mock()
|
binding = Mock()
|
||||||
|
|
@ -424,7 +424,7 @@ class TestDatasetServiceUpdateDataset:
|
||||||
)
|
)
|
||||||
|
|
||||||
embedding_model = Mock()
|
embedding_model = Mock()
|
||||||
embedding_model.model = "text-embedding-3-small"
|
embedding_model.model_name = "text-embedding-3-small"
|
||||||
embedding_model.provider = "openai"
|
embedding_model.provider = "openai"
|
||||||
|
|
||||||
binding = Mock()
|
binding = Mock()
|
||||||
|
|
|
||||||
|
|
@ -82,7 +82,7 @@ class TestCacheEmbeddingDocuments:
|
||||||
Mock: Configured ModelInstance with text embedding capabilities
|
Mock: Configured ModelInstance with text embedding capabilities
|
||||||
"""
|
"""
|
||||||
model_instance = Mock()
|
model_instance = Mock()
|
||||||
model_instance.model = "text-embedding-ada-002"
|
model_instance.model_name = "text-embedding-ada-002"
|
||||||
model_instance.provider = "openai"
|
model_instance.provider = "openai"
|
||||||
model_instance.credentials = {"api_key": "test-key"}
|
model_instance.credentials = {"api_key": "test-key"}
|
||||||
|
|
||||||
|
|
@ -597,7 +597,7 @@ class TestCacheEmbeddingQuery:
|
||||||
def mock_model_instance(self):
|
def mock_model_instance(self):
|
||||||
"""Create a mock ModelInstance for testing."""
|
"""Create a mock ModelInstance for testing."""
|
||||||
model_instance = Mock()
|
model_instance = Mock()
|
||||||
model_instance.model = "text-embedding-ada-002"
|
model_instance.model_name = "text-embedding-ada-002"
|
||||||
model_instance.provider = "openai"
|
model_instance.provider = "openai"
|
||||||
model_instance.credentials = {"api_key": "test-key"}
|
model_instance.credentials = {"api_key": "test-key"}
|
||||||
return model_instance
|
return model_instance
|
||||||
|
|
@ -830,7 +830,7 @@ class TestEmbeddingModelSwitching:
|
||||||
"""
|
"""
|
||||||
# Arrange
|
# Arrange
|
||||||
model_instance_ada = Mock()
|
model_instance_ada = Mock()
|
||||||
model_instance_ada.model = "text-embedding-ada-002"
|
model_instance_ada.model_name = "text-embedding-ada-002"
|
||||||
model_instance_ada.provider = "openai"
|
model_instance_ada.provider = "openai"
|
||||||
|
|
||||||
# Mock model type instance for ada
|
# Mock model type instance for ada
|
||||||
|
|
@ -841,7 +841,7 @@ class TestEmbeddingModelSwitching:
|
||||||
model_type_instance_ada.get_model_schema.return_value = model_schema_ada
|
model_type_instance_ada.get_model_schema.return_value = model_schema_ada
|
||||||
|
|
||||||
model_instance_3_small = Mock()
|
model_instance_3_small = Mock()
|
||||||
model_instance_3_small.model = "text-embedding-3-small"
|
model_instance_3_small.model_name = "text-embedding-3-small"
|
||||||
model_instance_3_small.provider = "openai"
|
model_instance_3_small.provider = "openai"
|
||||||
|
|
||||||
# Mock model type instance for 3-small
|
# Mock model type instance for 3-small
|
||||||
|
|
@ -914,11 +914,11 @@ class TestEmbeddingModelSwitching:
|
||||||
"""
|
"""
|
||||||
# Arrange
|
# Arrange
|
||||||
model_instance_openai = Mock()
|
model_instance_openai = Mock()
|
||||||
model_instance_openai.model = "text-embedding-ada-002"
|
model_instance_openai.model_name = "text-embedding-ada-002"
|
||||||
model_instance_openai.provider = "openai"
|
model_instance_openai.provider = "openai"
|
||||||
|
|
||||||
model_instance_cohere = Mock()
|
model_instance_cohere = Mock()
|
||||||
model_instance_cohere.model = "embed-english-v3.0"
|
model_instance_cohere.model_name = "embed-english-v3.0"
|
||||||
model_instance_cohere.provider = "cohere"
|
model_instance_cohere.provider = "cohere"
|
||||||
|
|
||||||
cache_openai = CacheEmbedding(model_instance_openai)
|
cache_openai = CacheEmbedding(model_instance_openai)
|
||||||
|
|
@ -1001,7 +1001,7 @@ class TestEmbeddingDimensionValidation:
|
||||||
def mock_model_instance(self):
|
def mock_model_instance(self):
|
||||||
"""Create a mock ModelInstance for testing."""
|
"""Create a mock ModelInstance for testing."""
|
||||||
model_instance = Mock()
|
model_instance = Mock()
|
||||||
model_instance.model = "text-embedding-ada-002"
|
model_instance.model_name = "text-embedding-ada-002"
|
||||||
model_instance.provider = "openai"
|
model_instance.provider = "openai"
|
||||||
model_instance.credentials = {"api_key": "test-key"}
|
model_instance.credentials = {"api_key": "test-key"}
|
||||||
|
|
||||||
|
|
@ -1123,7 +1123,7 @@ class TestEmbeddingDimensionValidation:
|
||||||
"""
|
"""
|
||||||
# Arrange - OpenAI ada-002 (1536 dimensions)
|
# Arrange - OpenAI ada-002 (1536 dimensions)
|
||||||
model_instance_ada = Mock()
|
model_instance_ada = Mock()
|
||||||
model_instance_ada.model = "text-embedding-ada-002"
|
model_instance_ada.model_name = "text-embedding-ada-002"
|
||||||
model_instance_ada.provider = "openai"
|
model_instance_ada.provider = "openai"
|
||||||
|
|
||||||
# Mock model type instance for ada
|
# Mock model type instance for ada
|
||||||
|
|
@ -1156,7 +1156,7 @@ class TestEmbeddingDimensionValidation:
|
||||||
|
|
||||||
# Arrange - Cohere embed-english-v3.0 (1024 dimensions)
|
# Arrange - Cohere embed-english-v3.0 (1024 dimensions)
|
||||||
model_instance_cohere = Mock()
|
model_instance_cohere = Mock()
|
||||||
model_instance_cohere.model = "embed-english-v3.0"
|
model_instance_cohere.model_name = "embed-english-v3.0"
|
||||||
model_instance_cohere.provider = "cohere"
|
model_instance_cohere.provider = "cohere"
|
||||||
|
|
||||||
# Mock model type instance for cohere
|
# Mock model type instance for cohere
|
||||||
|
|
@ -1225,7 +1225,7 @@ class TestEmbeddingEdgeCases:
|
||||||
- MAX_CHUNKS: 10
|
- MAX_CHUNKS: 10
|
||||||
"""
|
"""
|
||||||
model_instance = Mock()
|
model_instance = Mock()
|
||||||
model_instance.model = "text-embedding-ada-002"
|
model_instance.model_name = "text-embedding-ada-002"
|
||||||
model_instance.provider = "openai"
|
model_instance.provider = "openai"
|
||||||
|
|
||||||
model_type_instance = Mock()
|
model_type_instance = Mock()
|
||||||
|
|
@ -1702,7 +1702,7 @@ class TestEmbeddingCachePerformance:
|
||||||
- MAX_CHUNKS: 10
|
- MAX_CHUNKS: 10
|
||||||
"""
|
"""
|
||||||
model_instance = Mock()
|
model_instance = Mock()
|
||||||
model_instance.model = "text-embedding-ada-002"
|
model_instance.model_name = "text-embedding-ada-002"
|
||||||
model_instance.provider = "openai"
|
model_instance.provider = "openai"
|
||||||
|
|
||||||
model_type_instance = Mock()
|
model_type_instance = Mock()
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ def create_mock_model_instance():
|
||||||
mock_instance.provider_model_bundle.configuration = Mock()
|
mock_instance.provider_model_bundle.configuration = Mock()
|
||||||
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
|
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
|
||||||
mock_instance.provider = "test-provider"
|
mock_instance.provider = "test-provider"
|
||||||
mock_instance.model = "test-model"
|
mock_instance.model_name = "test-model"
|
||||||
return mock_instance
|
return mock_instance
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -65,7 +65,7 @@ class TestRerankModelRunner:
|
||||||
mock_instance.provider_model_bundle.configuration = Mock()
|
mock_instance.provider_model_bundle.configuration = Mock()
|
||||||
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
|
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
|
||||||
mock_instance.provider = "test-provider"
|
mock_instance.provider = "test-provider"
|
||||||
mock_instance.model = "test-model"
|
mock_instance.model_name = "test-model"
|
||||||
return mock_instance
|
return mock_instance
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
||||||
|
|
@ -199,11 +199,32 @@ def test_mock_config_builder():
|
||||||
|
|
||||||
def test_mock_factory_node_type_detection():
|
def test_mock_factory_node_type_detection():
|
||||||
"""Test that MockNodeFactory correctly identifies nodes to mock."""
|
"""Test that MockNodeFactory correctly identifies nodes to mock."""
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.workflow.entities import GraphInitParams
|
||||||
|
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||||
|
from models.enums import UserFrom
|
||||||
|
|
||||||
from .test_mock_factory import MockNodeFactory
|
from .test_mock_factory import MockNodeFactory
|
||||||
|
|
||||||
|
graph_init_params = GraphInitParams(
|
||||||
|
tenant_id="test",
|
||||||
|
app_id="test",
|
||||||
|
workflow_id="test",
|
||||||
|
graph_config={},
|
||||||
|
user_id="test",
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
|
call_depth=0,
|
||||||
|
)
|
||||||
|
graph_runtime_state = GraphRuntimeState(
|
||||||
|
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
|
||||||
|
start_at=0,
|
||||||
|
total_tokens=0,
|
||||||
|
node_run_steps=0,
|
||||||
|
)
|
||||||
factory = MockNodeFactory(
|
factory = MockNodeFactory(
|
||||||
graph_init_params=None, # Will be set by test
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=None, # Will be set by test
|
graph_runtime_state=graph_runtime_state,
|
||||||
mock_config=None,
|
mock_config=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -288,7 +309,11 @@ def test_workflow_without_auto_mock():
|
||||||
|
|
||||||
def test_register_custom_mock_node():
|
def test_register_custom_mock_node():
|
||||||
"""Test registering a custom mock implementation for a node type."""
|
"""Test registering a custom mock implementation for a node type."""
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.nodes.template_transform import TemplateTransformNode
|
from core.workflow.nodes.template_transform import TemplateTransformNode
|
||||||
|
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||||
|
from models.enums import UserFrom
|
||||||
|
|
||||||
from .test_mock_factory import MockNodeFactory
|
from .test_mock_factory import MockNodeFactory
|
||||||
|
|
||||||
|
|
@ -298,9 +323,25 @@ def test_register_custom_mock_node():
|
||||||
# Custom mock implementation
|
# Custom mock implementation
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
graph_init_params = GraphInitParams(
|
||||||
|
tenant_id="test",
|
||||||
|
app_id="test",
|
||||||
|
workflow_id="test",
|
||||||
|
graph_config={},
|
||||||
|
user_id="test",
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
|
call_depth=0,
|
||||||
|
)
|
||||||
|
graph_runtime_state = GraphRuntimeState(
|
||||||
|
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
|
||||||
|
start_at=0,
|
||||||
|
total_tokens=0,
|
||||||
|
node_run_steps=0,
|
||||||
|
)
|
||||||
factory = MockNodeFactory(
|
factory = MockNodeFactory(
|
||||||
graph_init_params=None,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=None,
|
graph_runtime_state=graph_runtime_state,
|
||||||
mock_config=None,
|
mock_config=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
import datetime
|
import datetime
|
||||||
import time
|
import time
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
from unittest import mock
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMMode
|
|
||||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
|
|
@ -82,7 +82,7 @@ def _build_branching_graph(
|
||||||
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
|
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
|
||||||
llm_data = LLMNodeData(
|
llm_data = LLMNodeData(
|
||||||
title=title,
|
title=title,
|
||||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
||||||
prompt_template=[
|
prompt_template=[
|
||||||
LLMNodeChatModelMessage(
|
LLMNodeChatModelMessage(
|
||||||
text=prompt_text,
|
text=prompt_text,
|
||||||
|
|
@ -101,6 +101,8 @@ def _build_branching_graph(
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
mock_config=mock_config,
|
mock_config=mock_config,
|
||||||
|
credentials_provider=mock.Mock(),
|
||||||
|
model_factory=mock.Mock(),
|
||||||
)
|
)
|
||||||
return llm_node
|
return llm_node
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
import datetime
|
import datetime
|
||||||
import time
|
import time
|
||||||
|
from unittest import mock
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMMode
|
|
||||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
|
|
@ -78,7 +78,7 @@ def _build_llm_human_llm_graph(
|
||||||
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
|
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
|
||||||
llm_data = LLMNodeData(
|
llm_data = LLMNodeData(
|
||||||
title=title,
|
title=title,
|
||||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
||||||
prompt_template=[
|
prompt_template=[
|
||||||
LLMNodeChatModelMessage(
|
LLMNodeChatModelMessage(
|
||||||
text=prompt_text,
|
text=prompt_text,
|
||||||
|
|
@ -97,6 +97,8 @@ def _build_llm_human_llm_graph(
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
mock_config=mock_config,
|
mock_config=mock_config,
|
||||||
|
credentials_provider=mock.Mock(),
|
||||||
|
model_factory=mock.Mock(),
|
||||||
)
|
)
|
||||||
return llm_node
|
return llm_node
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import time
|
import time
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMMode
|
from core.model_runtime.entities.llm_entities import LLMMode
|
||||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||||
|
|
@ -85,6 +86,8 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
mock_config=mock_config,
|
mock_config=mock_config,
|
||||||
|
credentials_provider=mock.Mock(),
|
||||||
|
model_factory=mock.Mock(),
|
||||||
)
|
)
|
||||||
return llm_node
|
return llm_node
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ This module provides a MockNodeFactory that automatically detects and mocks node
|
||||||
requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request).
|
requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Mapping
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from core.app.workflow.node_factory import DifyNodeFactory
|
from core.app.workflow.node_factory import DifyNodeFactory
|
||||||
|
|
@ -74,7 +75,7 @@ class MockNodeFactory(DifyNodeFactory):
|
||||||
NodeType.CODE: MockCodeNode,
|
NodeType.CODE: MockCodeNode,
|
||||||
}
|
}
|
||||||
|
|
||||||
def create_node(self, node_config: dict[str, Any]) -> Node:
|
def create_node(self, node_config: Mapping[str, Any]) -> Node:
|
||||||
"""
|
"""
|
||||||
Create a node instance, using mock implementations for third-party service nodes.
|
Create a node instance, using mock implementations for third-party service nodes.
|
||||||
|
|
||||||
|
|
@ -123,6 +124,16 @@ class MockNodeFactory(DifyNodeFactory):
|
||||||
mock_config=self.mock_config,
|
mock_config=self.mock_config,
|
||||||
http_request_config=self._http_request_config,
|
http_request_config=self._http_request_config,
|
||||||
)
|
)
|
||||||
|
elif node_type in {NodeType.LLM, NodeType.QUESTION_CLASSIFIER, NodeType.PARAMETER_EXTRACTOR}:
|
||||||
|
mock_instance = mock_class(
|
||||||
|
id=node_id,
|
||||||
|
config=node_config,
|
||||||
|
graph_init_params=self.graph_init_params,
|
||||||
|
graph_runtime_state=self.graph_runtime_state,
|
||||||
|
mock_config=self.mock_config,
|
||||||
|
credentials_provider=self._llm_credentials_provider,
|
||||||
|
model_factory=self._llm_model_factory,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
mock_instance = mock_class(
|
mock_instance = mock_class(
|
||||||
id=node_id,
|
id=node_id,
|
||||||
|
|
|
||||||
|
|
@ -16,9 +16,33 @@ from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNo
|
||||||
|
|
||||||
def test_mock_factory_registers_iteration_node():
|
def test_mock_factory_registers_iteration_node():
|
||||||
"""Test that MockNodeFactory has iteration node registered."""
|
"""Test that MockNodeFactory has iteration node registered."""
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.workflow.entities import GraphInitParams
|
||||||
|
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||||
|
from models.enums import UserFrom
|
||||||
|
|
||||||
# Create a MockNodeFactory instance
|
# Create a MockNodeFactory instance
|
||||||
factory = MockNodeFactory(graph_init_params=None, graph_runtime_state=None, mock_config=None)
|
graph_init_params = GraphInitParams(
|
||||||
|
tenant_id="test",
|
||||||
|
app_id="test",
|
||||||
|
workflow_id="test",
|
||||||
|
graph_config={"nodes": [], "edges": []},
|
||||||
|
user_id="test",
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
|
call_depth=0,
|
||||||
|
)
|
||||||
|
graph_runtime_state = GraphRuntimeState(
|
||||||
|
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
|
||||||
|
start_at=0,
|
||||||
|
total_tokens=0,
|
||||||
|
node_run_steps=0,
|
||||||
|
)
|
||||||
|
factory = MockNodeFactory(
|
||||||
|
graph_init_params=graph_init_params,
|
||||||
|
graph_runtime_state=graph_runtime_state,
|
||||||
|
mock_config=None,
|
||||||
|
)
|
||||||
|
|
||||||
# Check that iteration node is registered
|
# Check that iteration node is registered
|
||||||
assert NodeType.ITERATION in factory._mock_node_types
|
assert NodeType.ITERATION in factory._mock_node_types
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ allowing tests to run without external dependencies.
|
||||||
import time
|
import time
|
||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||||
|
|
@ -18,6 +19,7 @@ from core.workflow.nodes.document_extractor import DocumentExtractorNode
|
||||||
from core.workflow.nodes.http_request import HttpRequestNode
|
from core.workflow.nodes.http_request import HttpRequestNode
|
||||||
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
|
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
|
||||||
from core.workflow.nodes.llm import LLMNode
|
from core.workflow.nodes.llm import LLMNode
|
||||||
|
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||||
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||||
from core.workflow.nodes.question_classifier import QuestionClassifierNode
|
from core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||||
from core.workflow.nodes.template_transform import TemplateTransformNode
|
from core.workflow.nodes.template_transform import TemplateTransformNode
|
||||||
|
|
@ -42,6 +44,10 @@ class MockNodeMixin:
|
||||||
mock_config: Optional["MockConfig"] = None,
|
mock_config: Optional["MockConfig"] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
|
if isinstance(self, (LLMNode, QuestionClassifierNode)):
|
||||||
|
kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
|
||||||
|
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
id=id,
|
id=id,
|
||||||
config=config,
|
config=config,
|
||||||
|
|
|
||||||
|
|
@ -101,11 +101,32 @@ def test_node_mock_config():
|
||||||
|
|
||||||
def test_mock_factory_detection():
|
def test_mock_factory_detection():
|
||||||
"""Test MockNodeFactory node type detection."""
|
"""Test MockNodeFactory node type detection."""
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.workflow.entities import GraphInitParams
|
||||||
|
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||||
|
from models.enums import UserFrom
|
||||||
|
|
||||||
print("Testing MockNodeFactory detection...")
|
print("Testing MockNodeFactory detection...")
|
||||||
|
|
||||||
|
graph_init_params = GraphInitParams(
|
||||||
|
tenant_id="test",
|
||||||
|
app_id="test",
|
||||||
|
workflow_id="test",
|
||||||
|
graph_config={},
|
||||||
|
user_id="test",
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
|
call_depth=0,
|
||||||
|
)
|
||||||
|
graph_runtime_state = GraphRuntimeState(
|
||||||
|
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
|
||||||
|
start_at=0,
|
||||||
|
total_tokens=0,
|
||||||
|
node_run_steps=0,
|
||||||
|
)
|
||||||
factory = MockNodeFactory(
|
factory = MockNodeFactory(
|
||||||
graph_init_params=None,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=None,
|
graph_runtime_state=graph_runtime_state,
|
||||||
mock_config=None,
|
mock_config=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -133,11 +154,32 @@ def test_mock_factory_detection():
|
||||||
|
|
||||||
def test_mock_factory_registration():
|
def test_mock_factory_registration():
|
||||||
"""Test registering and unregistering mock node types."""
|
"""Test registering and unregistering mock node types."""
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.workflow.entities import GraphInitParams
|
||||||
|
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||||
|
from models.enums import UserFrom
|
||||||
|
|
||||||
print("Testing MockNodeFactory registration...")
|
print("Testing MockNodeFactory registration...")
|
||||||
|
|
||||||
|
graph_init_params = GraphInitParams(
|
||||||
|
tenant_id="test",
|
||||||
|
app_id="test",
|
||||||
|
workflow_id="test",
|
||||||
|
graph_config={},
|
||||||
|
user_id="test",
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
|
call_depth=0,
|
||||||
|
)
|
||||||
|
graph_runtime_state = GraphRuntimeState(
|
||||||
|
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
|
||||||
|
start_at=0,
|
||||||
|
total_tokens=0,
|
||||||
|
node_run_steps=0,
|
||||||
|
)
|
||||||
factory = MockNodeFactory(
|
factory = MockNodeFactory(
|
||||||
graph_init_params=None,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=None,
|
graph_runtime_state=graph_runtime_state,
|
||||||
mock_config=None,
|
mock_config=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from unittest import mock
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
||||||
|
from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config
|
||||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||||
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
|
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
|
@ -32,6 +33,7 @@ from core.workflow.nodes.llm.entities import (
|
||||||
)
|
)
|
||||||
from core.workflow.nodes.llm.file_saver import LLMFileSaver
|
from core.workflow.nodes.llm.file_saver import LLMFileSaver
|
||||||
from core.workflow.nodes.llm.node import LLMNode
|
from core.workflow.nodes.llm.node import LLMNode
|
||||||
|
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||||
from core.workflow.system_variable import SystemVariable
|
from core.workflow.system_variable import SystemVariable
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
|
|
@ -100,6 +102,8 @@ def llm_node(
|
||||||
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState
|
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState
|
||||||
) -> LLMNode:
|
) -> LLMNode:
|
||||||
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
|
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
|
||||||
|
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
|
||||||
|
mock_model_factory = mock.MagicMock(spec=ModelFactory)
|
||||||
node_config = {
|
node_config = {
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"data": llm_node_data.model_dump(),
|
"data": llm_node_data.model_dump(),
|
||||||
|
|
@ -109,13 +113,29 @@ def llm_node(
|
||||||
config=node_config,
|
config=node_config,
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
|
credentials_provider=mock_credentials_provider,
|
||||||
|
model_factory=mock_model_factory,
|
||||||
llm_file_saver=mock_file_saver,
|
llm_file_saver=mock_file_saver,
|
||||||
)
|
)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def model_config():
|
def model_config(monkeypatch):
|
||||||
|
from tests.integration_tests.model_runtime.__mock.plugin_model import MockModelClass
|
||||||
|
|
||||||
|
def mock_plugin_model_providers(_self):
|
||||||
|
providers = MockModelClass().fetch_model_providers("test")
|
||||||
|
for provider in providers:
|
||||||
|
provider.declaration.provider = f"{provider.plugin_id}/{provider.declaration.provider}"
|
||||||
|
return providers
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
ModelProviderFactory,
|
||||||
|
"get_plugin_model_providers",
|
||||||
|
mock_plugin_model_providers,
|
||||||
|
)
|
||||||
|
|
||||||
# Create actual provider and model type instances
|
# Create actual provider and model type instances
|
||||||
model_provider_factory = ModelProviderFactory(tenant_id="test")
|
model_provider_factory = ModelProviderFactory(tenant_id="test")
|
||||||
provider_instance = model_provider_factory.get_plugin_model_provider("openai")
|
provider_instance = model_provider_factory.get_plugin_model_provider("openai")
|
||||||
|
|
@ -125,7 +145,7 @@ def model_config():
|
||||||
provider_model_bundle = ProviderModelBundle(
|
provider_model_bundle = ProviderModelBundle(
|
||||||
configuration=ProviderConfiguration(
|
configuration=ProviderConfiguration(
|
||||||
tenant_id="1",
|
tenant_id="1",
|
||||||
provider=provider_instance,
|
provider=provider_instance.declaration,
|
||||||
preferred_provider_type=ProviderType.CUSTOM,
|
preferred_provider_type=ProviderType.CUSTOM,
|
||||||
using_provider_type=ProviderType.CUSTOM,
|
using_provider_type=ProviderType.CUSTOM,
|
||||||
system_configuration=SystemConfiguration(enabled=False),
|
system_configuration=SystemConfiguration(enabled=False),
|
||||||
|
|
@ -153,6 +173,89 @@ def model_config():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsEntity):
|
||||||
|
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
|
||||||
|
mock_model_factory = mock.MagicMock(spec=ModelFactory)
|
||||||
|
|
||||||
|
provider_model_bundle = model_config.provider_model_bundle
|
||||||
|
model_type_instance = provider_model_bundle.model_type_instance
|
||||||
|
provider_model = mock.MagicMock()
|
||||||
|
|
||||||
|
model_instance = mock.MagicMock(
|
||||||
|
model_type_instance=model_type_instance,
|
||||||
|
provider_model_bundle=provider_model_bundle,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_credentials_provider.fetch.return_value = {"api_key": "test"}
|
||||||
|
mock_model_factory.init_model_instance.return_value = model_instance
|
||||||
|
|
||||||
|
with (
|
||||||
|
mock.patch.object(
|
||||||
|
provider_model_bundle.configuration.__class__,
|
||||||
|
"get_provider_model",
|
||||||
|
return_value=provider_model,
|
||||||
|
),
|
||||||
|
mock.patch.object(
|
||||||
|
model_type_instance.__class__,
|
||||||
|
"get_model_schema",
|
||||||
|
return_value=model_config.model_schema,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
fetch_model_config(
|
||||||
|
node_data_model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
||||||
|
credentials_provider=mock_credentials_provider,
|
||||||
|
model_factory=mock_model_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_credentials_provider.fetch.assert_called_once_with("openai", "gpt-3.5-turbo")
|
||||||
|
mock_model_factory.init_model_instance.assert_called_once_with("openai", "gpt-3.5-turbo")
|
||||||
|
provider_model.raise_for_status.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_dify_model_access_adapters_call_managers():
|
||||||
|
mock_provider_manager = mock.MagicMock()
|
||||||
|
mock_model_manager = mock.MagicMock()
|
||||||
|
mock_configurations = mock.MagicMock()
|
||||||
|
mock_provider_configuration = mock.MagicMock()
|
||||||
|
mock_provider_model = mock.MagicMock()
|
||||||
|
|
||||||
|
mock_configurations.get.return_value = mock_provider_configuration
|
||||||
|
mock_provider_configuration.get_provider_model.return_value = mock_provider_model
|
||||||
|
mock_provider_configuration.get_current_credentials.return_value = {"api_key": "test"}
|
||||||
|
|
||||||
|
credentials_provider = DifyCredentialsProvider(
|
||||||
|
tenant_id="tenant",
|
||||||
|
provider_manager=mock_provider_manager,
|
||||||
|
)
|
||||||
|
model_factory = DifyModelFactory(
|
||||||
|
tenant_id="tenant",
|
||||||
|
model_manager=mock_model_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_provider_manager.get_configurations.return_value = mock_configurations
|
||||||
|
|
||||||
|
credentials_provider.fetch("openai", "gpt-3.5-turbo")
|
||||||
|
model_factory.init_model_instance("openai", "gpt-3.5-turbo")
|
||||||
|
|
||||||
|
mock_provider_manager.get_configurations.assert_called_once_with("tenant")
|
||||||
|
mock_configurations.get.assert_called_once_with("openai")
|
||||||
|
mock_provider_configuration.get_provider_model.assert_called_once_with(
|
||||||
|
model_type=ModelType.LLM,
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
)
|
||||||
|
mock_provider_configuration.get_current_credentials.assert_called_once_with(
|
||||||
|
model_type=ModelType.LLM,
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
)
|
||||||
|
mock_provider_model.raise_for_status.assert_called_once()
|
||||||
|
mock_model_manager.get_model_instance.assert_called_once_with(
|
||||||
|
tenant_id="tenant",
|
||||||
|
provider="openai",
|
||||||
|
model_type=ModelType.LLM,
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_fetch_files_with_file_segment():
|
def test_fetch_files_with_file_segment():
|
||||||
file = File(
|
file = File(
|
||||||
id="1",
|
id="1",
|
||||||
|
|
@ -485,6 +588,8 @@ def test_handle_list_messages_basic(llm_node):
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_state) -> tuple[LLMNode, LLMFileSaver]:
|
def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_state) -> tuple[LLMNode, LLMFileSaver]:
|
||||||
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
|
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
|
||||||
|
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
|
||||||
|
mock_model_factory = mock.MagicMock(spec=ModelFactory)
|
||||||
node_config = {
|
node_config = {
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"data": llm_node_data.model_dump(),
|
"data": llm_node_data.model_dump(),
|
||||||
|
|
@ -494,6 +599,8 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
|
||||||
config=node_config,
|
config=node_config,
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
|
credentials_provider=mock_credentials_provider,
|
||||||
|
model_factory=mock_model_factory,
|
||||||
llm_file_saver=mock_file_saver,
|
llm_file_saver=mock_file_saver,
|
||||||
)
|
)
|
||||||
return node, mock_file_saver
|
return node, mock_file_saver
|
||||||
|
|
|
||||||
|
|
@ -642,8 +642,16 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings:
|
||||||
|
|
||||||
# Mock embedding model
|
# Mock embedding model
|
||||||
mock_embedding_model = Mock()
|
mock_embedding_model = Mock()
|
||||||
mock_embedding_model.model = "text-embedding-ada-002"
|
mock_embedding_model.model_name = "text-embedding-ada-002"
|
||||||
mock_embedding_model.provider = "openai"
|
mock_embedding_model.provider = "openai"
|
||||||
|
mock_embedding_model.credentials = {}
|
||||||
|
|
||||||
|
mock_model_schema = Mock()
|
||||||
|
mock_model_schema.features = []
|
||||||
|
|
||||||
|
mock_text_embedding_model = Mock()
|
||||||
|
mock_text_embedding_model.get_model_schema.return_value = mock_model_schema
|
||||||
|
mock_embedding_model.model_type_instance = mock_text_embedding_model
|
||||||
|
|
||||||
mock_model_instance = Mock()
|
mock_model_instance = Mock()
|
||||||
mock_model_instance.get_model_instance.return_value = mock_embedding_model
|
mock_model_instance.get_model_instance.return_value = mock_embedding_model
|
||||||
|
|
|
||||||
|
|
@ -174,7 +174,7 @@ class DatasetServiceTestDataFactory:
|
||||||
Mock: Embedding model mock with model and provider attributes
|
Mock: Embedding model mock with model and provider attributes
|
||||||
"""
|
"""
|
||||||
embedding_model = Mock()
|
embedding_model = Mock()
|
||||||
embedding_model.model = model
|
embedding_model.model_name = model
|
||||||
embedding_model.provider = provider
|
embedding_model.provider = provider
|
||||||
return embedding_model
|
return embedding_model
|
||||||
|
|
||||||
|
|
@ -434,7 +434,7 @@ class TestDatasetServiceCreateDataset:
|
||||||
# Assert
|
# Assert
|
||||||
assert result.indexing_technique == "high_quality"
|
assert result.indexing_technique == "high_quality"
|
||||||
assert result.embedding_model_provider == embedding_model.provider
|
assert result.embedding_model_provider == embedding_model.provider
|
||||||
assert result.embedding_model == embedding_model.model
|
assert result.embedding_model == embedding_model.model_name
|
||||||
mock_model_manager_instance.get_default_model_instance.assert_called_once_with(
|
mock_model_manager_instance.get_default_model_instance.assert_called_once_with(
|
||||||
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
|
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,7 @@ class DatasetCreateTestDataFactory:
|
||||||
def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
|
def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
|
||||||
"""Create a mock embedding model."""
|
"""Create a mock embedding model."""
|
||||||
embedding_model = Mock()
|
embedding_model = Mock()
|
||||||
embedding_model.model = model
|
embedding_model.model_name = model
|
||||||
embedding_model.provider = provider
|
embedding_model.provider = provider
|
||||||
return embedding_model
|
return embedding_model
|
||||||
|
|
||||||
|
|
@ -244,7 +244,7 @@ class TestDatasetServiceCreateEmptyDataset:
|
||||||
# Assert
|
# Assert
|
||||||
assert result.indexing_technique == "high_quality"
|
assert result.indexing_technique == "high_quality"
|
||||||
assert result.embedding_model_provider == embedding_model.provider
|
assert result.embedding_model_provider == embedding_model.provider
|
||||||
assert result.embedding_model == embedding_model.model
|
assert result.embedding_model == embedding_model.model_name
|
||||||
mock_model_manager_instance.get_default_model_instance.assert_called_once_with(
|
mock_model_manager_instance.get_default_model_instance.assert_called_once_with(
|
||||||
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
|
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue