mirror of https://github.com/langgenius/dify.git
Compare commits
77 Commits
27877b22fc
...
085af2fdfd
| Author | SHA1 | Date |
|---|---|---|
|
|
085af2fdfd | |
|
|
d14635625c | |
|
|
0c3d11f920 | |
|
|
bf422dfd13 | |
|
|
1674f8c2fb | |
|
|
7fe25f1365 | |
|
|
508350ec6a | |
|
|
b0920ecd17 | |
|
|
8b634a9bee | |
|
|
ecd3a964c1 | |
|
|
0589fa423b | |
|
|
27c4faad4f | |
|
|
fbd558762d | |
|
|
075b8bf1ae | |
|
|
49a1fae555 | |
|
|
cc17c8e883 | |
|
|
5d2cb3cd80 | |
|
|
f2c71f3668 | |
|
|
0492ed7034 | |
|
|
dd4f504b39 | |
|
|
75c3ef82d9 | |
|
|
8ca1ebb96d | |
|
|
3f086b97b6 | |
|
|
4a2e9633db | |
|
|
20fc69ae7f | |
|
|
f5cc1c8b75 | |
|
|
6698b42f97 | |
|
|
848a041c25 | |
|
|
29cff809b9 | |
|
|
30deeb6f1c | |
|
|
30dd36505c | |
|
|
65223c8092 | |
|
|
72e3fcd25f | |
|
|
4b4a5c058e | |
|
|
56e0907548 | |
|
|
d956b919a0 | |
|
|
8b6fc07019 | |
|
|
1b1df37d23 | |
|
|
6be7ba2928 | |
|
|
2c8322c7b9 | |
|
|
fdc880bc67 | |
|
|
abda859075 | |
|
|
dc1a68661c | |
|
|
edb261bc90 | |
|
|
407f5f0cde | |
|
|
d7cafc6296 | |
|
|
9336935295 | |
|
|
e5e8c0711c | |
|
|
02e13e6d05 | |
|
|
a942d4c926 | |
|
|
df69997d8e | |
|
|
4ab7ba4f2e | |
|
|
76a23deba7 | |
|
|
25a83065d2 | |
|
|
82b094a2d5 | |
|
|
3c672703bc | |
|
|
33000d1c60 | |
|
|
2809e4cc40 | |
|
|
3f8f1fa003 | |
|
|
6604f8d506 | |
|
|
368fc0bbe5 | |
|
|
6014853d45 | |
|
|
a71b7909fd | |
|
|
1bf296982b | |
|
|
2b6f761dfe | |
|
|
6ecf89e262 | |
|
|
e844edcf26 | |
|
|
244f9e0c11 | |
|
|
abd68d2ea6 | |
|
|
01d97fa2cf | |
|
|
0478023900 | |
|
|
110b8c925e | |
|
|
eae821d645 | |
|
|
282e76b1ee | |
|
|
8384a836b4 | |
|
|
886854eff8 | |
|
|
6a8fa7b54e |
|
|
@ -4,10 +4,9 @@ runs:
|
|||
using: composite
|
||||
steps:
|
||||
- name: Setup Vite+
|
||||
uses: voidzero-dev/setup-vp@4a524139920f87f9f7080d3b8545acac019e1852 # v1.0.0
|
||||
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
|
||||
with:
|
||||
node-version-file: web/.nvmrc
|
||||
working-directory: web
|
||||
node-version-file: .nvmrc
|
||||
cache: true
|
||||
cache-dependency-path: web/pnpm-lock.yaml
|
||||
run-install: |
|
||||
cwd: ./web
|
||||
run-install: true
|
||||
|
|
|
|||
|
|
@ -84,20 +84,20 @@ jobs:
|
|||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Restore ESLint cache
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
id: eslint-cache-restore
|
||||
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
|
||||
with:
|
||||
path: web/.eslintcache
|
||||
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
|
||||
|
||||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: |
|
||||
vp run lint:ci
|
||||
# pnpm run lint:report
|
||||
# continue-on-error: true
|
||||
|
||||
# - name: Annotate Code
|
||||
# if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request'
|
||||
# uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae
|
||||
# with:
|
||||
# eslint-report: web/eslint_report.json
|
||||
# github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: vp run lint:ci
|
||||
|
||||
- name: Web tsslint
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
|
|
@ -114,6 +114,13 @@ jobs:
|
|||
working-directory: ./web
|
||||
run: vp run knip
|
||||
|
||||
- name: Save ESLint cache
|
||||
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
|
||||
uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
|
||||
with:
|
||||
path: web/.eslintcache
|
||||
key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
|
||||
|
||||
superlinter:
|
||||
name: SuperLinter
|
||||
runs-on: ubuntu-latest
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ jobs:
|
|||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.detect_changes.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@df37d2f0760a4b5683a6e617c9325bc1a36443f6 # v1.0.75
|
||||
uses: anthropics/claude-code-action@ff9acae5886d41a99ed4ec14b7dc147d55834722 # v1.0.77
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
|
|
|||
|
|
@ -353,6 +353,9 @@ BAIDU_VECTOR_DB_SHARD=1
|
|||
BAIDU_VECTOR_DB_REPLICAS=3
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE
|
||||
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT=500
|
||||
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO=0.05
|
||||
BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS=300
|
||||
|
||||
# Upstash configuration
|
||||
UPSTASH_VECTOR_URL=your-server-url
|
||||
|
|
|
|||
|
|
@ -143,6 +143,7 @@ def initialize_extensions(app: DifyApp):
|
|||
ext_commands,
|
||||
ext_compress,
|
||||
ext_database,
|
||||
ext_enterprise_telemetry,
|
||||
ext_fastopenapi,
|
||||
ext_forward_refs,
|
||||
ext_hosting_provider,
|
||||
|
|
@ -193,6 +194,7 @@ def initialize_extensions(app: DifyApp):
|
|||
ext_commands,
|
||||
ext_fastopenapi,
|
||||
ext_otel,
|
||||
ext_enterprise_telemetry,
|
||||
ext_request_logging,
|
||||
ext_session_factory,
|
||||
]
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from configs import dify_config
|
|||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.models.document import ChildDocument, Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
|
||||
|
|
@ -269,7 +270,7 @@ def migrate_knowledge_vector_database():
|
|||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
if dataset_document.doc_form == "hierarchical_model":
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
child_chunks = segment.get_child_chunks()
|
||||
if child_chunks:
|
||||
child_documents = []
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, Settings
|
|||
from libs.file_utils import search_file_upwards
|
||||
|
||||
from .deploy import DeploymentConfig
|
||||
from .enterprise import EnterpriseFeatureConfig
|
||||
from .enterprise import EnterpriseFeatureConfig, EnterpriseTelemetryConfig
|
||||
from .extra import ExtraServiceConfig
|
||||
from .feature import FeatureConfig
|
||||
from .middleware import MiddlewareConfig
|
||||
|
|
@ -73,6 +73,8 @@ class DifyConfig(
|
|||
# Enterprise feature configs
|
||||
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
|
||||
EnterpriseFeatureConfig,
|
||||
# Enterprise telemetry configs
|
||||
EnterpriseTelemetryConfig,
|
||||
):
|
||||
model_config = SettingsConfigDict(
|
||||
# read from dotenv format config file
|
||||
|
|
|
|||
|
|
@ -22,3 +22,49 @@ class EnterpriseFeatureConfig(BaseSettings):
|
|||
ENTERPRISE_REQUEST_TIMEOUT: int = Field(
|
||||
ge=1, description="Maximum timeout in seconds for enterprise requests", default=5
|
||||
)
|
||||
|
||||
|
||||
class EnterpriseTelemetryConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for enterprise telemetry.
|
||||
"""
|
||||
|
||||
ENTERPRISE_TELEMETRY_ENABLED: bool = Field(
|
||||
description="Enable enterprise telemetry collection (also requires ENTERPRISE_ENABLED=true).",
|
||||
default=False,
|
||||
)
|
||||
|
||||
ENTERPRISE_OTLP_ENDPOINT: str = Field(
|
||||
description="Enterprise OTEL collector endpoint.",
|
||||
default="",
|
||||
)
|
||||
|
||||
ENTERPRISE_OTLP_HEADERS: str = Field(
|
||||
description="Auth headers for OTLP export (key=value,key2=value2).",
|
||||
default="",
|
||||
)
|
||||
|
||||
ENTERPRISE_OTLP_PROTOCOL: str = Field(
|
||||
description="OTLP protocol: 'http' or 'grpc' (default: http).",
|
||||
default="http",
|
||||
)
|
||||
|
||||
ENTERPRISE_OTLP_API_KEY: str = Field(
|
||||
description="Bearer token for enterprise OTLP export authentication.",
|
||||
default="",
|
||||
)
|
||||
|
||||
ENTERPRISE_INCLUDE_CONTENT: bool = Field(
|
||||
description="Include input/output content in traces (privacy toggle).",
|
||||
default=True,
|
||||
)
|
||||
|
||||
ENTERPRISE_SERVICE_NAME: str = Field(
|
||||
description="Service name for OTEL resource.",
|
||||
default="dify",
|
||||
)
|
||||
|
||||
ENTERPRISE_OTEL_SAMPLING_RATE: float = Field(
|
||||
description="Sampling rate for enterprise traces (0.0 to 1.0, default 1.0 = 100%).",
|
||||
default=1.0,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -51,3 +51,18 @@ class BaiduVectorDBConfig(BaseSettings):
|
|||
description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)",
|
||||
default="COARSE_MODE",
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT: int = Field(
|
||||
description="Auto build row count increment threshold (default is 500)",
|
||||
default=500,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO: float = Field(
|
||||
description="Auto build row count increment ratio threshold (default is 0.05)",
|
||||
default=0.05,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS: int = Field(
|
||||
description="Timeout in seconds for rebuilding the index in Baidu Vector Database (default is 300 seconds)",
|
||||
default=300,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from extensions.ext_database import db
|
|||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Dataset
|
||||
from models.enums import ApiTokenType
|
||||
from models.model import ApiToken, App
|
||||
from services.api_token_service import ApiTokenCache
|
||||
|
||||
|
|
@ -47,7 +48,7 @@ def _get_resource(resource_id, tenant_id, resource_model):
|
|||
class BaseApiKeyListResource(Resource):
|
||||
method_decorators = [account_initialization_required, login_required, setup_required]
|
||||
|
||||
resource_type: str | None = None
|
||||
resource_type: ApiTokenType | None = None
|
||||
resource_model: type | None = None
|
||||
resource_id_field: str | None = None
|
||||
token_prefix: str | None = None
|
||||
|
|
@ -91,6 +92,7 @@ class BaseApiKeyListResource(Resource):
|
|||
)
|
||||
|
||||
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
||||
assert self.resource_type is not None, "resource_type must be set"
|
||||
api_token = ApiToken()
|
||||
setattr(api_token, self.resource_id_field, resource_id)
|
||||
api_token.tenant_id = current_tenant_id
|
||||
|
|
@ -104,7 +106,7 @@ class BaseApiKeyListResource(Resource):
|
|||
class BaseApiKeyResource(Resource):
|
||||
method_decorators = [account_initialization_required, login_required, setup_required]
|
||||
|
||||
resource_type: str | None = None
|
||||
resource_type: ApiTokenType | None = None
|
||||
resource_model: type | None = None
|
||||
resource_id_field: str | None = None
|
||||
|
||||
|
|
@ -159,7 +161,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
|||
"""Create a new API key for an app"""
|
||||
return super().post(resource_id)
|
||||
|
||||
resource_type = "app"
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
token_prefix = "app-"
|
||||
|
|
@ -175,7 +177,7 @@ class AppApiKeyResource(BaseApiKeyResource):
|
|||
"""Delete an API key for an app"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
resource_type = "app"
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
|
||||
|
|
@ -199,7 +201,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
|||
"""Create a new API key for a dataset"""
|
||||
return super().post(resource_id)
|
||||
|
||||
resource_type = "dataset"
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_model = Dataset
|
||||
resource_id_field = "dataset_id"
|
||||
token_prefix = "ds-"
|
||||
|
|
@ -215,6 +217,6 @@ class DatasetApiKeyResource(BaseApiKeyResource):
|
|||
"""Delete an API key for a dataset"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
resource_type = "dataset"
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_model = Dataset
|
||||
resource_id_field = "dataset_id"
|
||||
|
|
|
|||
|
|
@ -95,7 +95,7 @@ class CreateAppPayload(BaseModel):
|
|||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon_type: IconType | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
|
|
@ -103,7 +103,7 @@ class CreateAppPayload(BaseModel):
|
|||
class UpdateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon_type: IconType | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
|
||||
|
|
@ -113,7 +113,7 @@ class UpdateAppPayload(BaseModel):
|
|||
class CopyAppPayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="Name for the copied app")
|
||||
description: str | None = Field(default=None, description="Description for the copied app", max_length=400)
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon_type: IconType | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
|
|
@ -594,7 +594,7 @@ class AppApi(Resource):
|
|||
args_dict: AppService.ArgsDict = {
|
||||
"name": args.name,
|
||||
"description": args.description or "",
|
||||
"icon_type": args.icon_type or "",
|
||||
"icon_type": args.icon_type,
|
||||
"icon": args.icon or "",
|
||||
"icon_background": args.icon_background or "",
|
||||
"use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,
|
||||
|
|
|
|||
|
|
@ -458,9 +458,7 @@ class ChatConversationApi(Resource):
|
|||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
subquery = (
|
||||
db.session.query(
|
||||
Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")
|
||||
)
|
||||
sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id"))
|
||||
.outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
|
||||
.subquery()
|
||||
)
|
||||
|
|
@ -595,10 +593,8 @@ class ChatConversationDetailApi(Resource):
|
|||
|
||||
def _get_conversation(app_model, conversation_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
|
||||
.first()
|
||||
conversation = db.session.scalar(
|
||||
sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1)
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ class InstructionGenerateApi(Resource):
|
|||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args.current in (code_template, "")) and args.node_id != "":
|
||||
app = db.session.query(App).where(App.id == args.flow_id).first()
|
||||
app = db.session.get(App, args.flow_id)
|
||||
if not app:
|
||||
return {"error": f"app {args.flow_id} not found"}, 400
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import json
|
|||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -47,7 +48,7 @@ class AppMCPServerController(Resource):
|
|||
@get_app_model
|
||||
@marshal_with(app_server_model)
|
||||
def get(self, app_model):
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
|
||||
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
|
||||
return server
|
||||
|
||||
@console_ns.doc("create_app_mcp_server")
|
||||
|
|
@ -98,7 +99,7 @@ class AppMCPServerController(Resource):
|
|||
@edit_permission_required
|
||||
def put(self, app_model):
|
||||
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
|
||||
server = db.session.get(AppMCPServer, payload.id)
|
||||
if not server:
|
||||
raise NotFound()
|
||||
|
||||
|
|
@ -135,11 +136,10 @@ class AppMCPServerRefreshController(Resource):
|
|||
@edit_permission_required
|
||||
def get(self, server_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
server = (
|
||||
db.session.query(AppMCPServer)
|
||||
.where(AppMCPServer.id == server_id)
|
||||
.where(AppMCPServer.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
server = db.session.scalar(
|
||||
select(AppMCPServer)
|
||||
.where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not server:
|
||||
raise NotFound()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import Literal
|
|||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy import exists, func, select
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
|
|
@ -244,27 +244,25 @@ class ChatMessageListApi(Resource):
|
|||
def get(self, app_model):
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict())
|
||||
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
conversation = db.session.scalar(
|
||||
select(Conversation)
|
||||
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
if args.first_id:
|
||||
first_message = (
|
||||
db.session.query(Message)
|
||||
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
|
||||
.first()
|
||||
first_message = db.session.scalar(
|
||||
select(Message).where(Message.conversation_id == conversation.id, Message.id == args.first_id).limit(1)
|
||||
)
|
||||
|
||||
if not first_message:
|
||||
raise NotFound("First message not found")
|
||||
|
||||
history_messages = (
|
||||
db.session.query(Message)
|
||||
history_messages = db.session.scalars(
|
||||
select(Message)
|
||||
.where(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < first_message.created_at,
|
||||
|
|
@ -272,16 +270,14 @@ class ChatMessageListApi(Resource):
|
|||
)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args.limit)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
history_messages = (
|
||||
db.session.query(Message)
|
||||
history_messages = db.session.scalars(
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args.limit)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
# Initialize has_more based on whether we have a full page
|
||||
if len(history_messages) == args.limit:
|
||||
|
|
@ -326,7 +322,9 @@ class MessageFeedbackApi(Resource):
|
|||
|
||||
message_id = str(args.message_id)
|
||||
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
message = db.session.scalar(
|
||||
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
|
||||
)
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
|
@ -375,7 +373,9 @@ class MessageAnnotationCountApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()
|
||||
count = db.session.scalar(
|
||||
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id)
|
||||
)
|
||||
|
||||
return {"count": count}
|
||||
|
||||
|
|
@ -479,7 +479,9 @@ class MessageApi(Resource):
|
|||
def get(self, app_model, message_id: str):
|
||||
message_id = str(message_id)
|
||||
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
message = db.session.scalar(
|
||||
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
|
||||
)
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
|
|
|||
|
|
@ -69,9 +69,7 @@ class ModelConfigResource(Resource):
|
|||
|
||||
if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
|
||||
# get original app model config
|
||||
original_app_model_config = (
|
||||
db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first()
|
||||
)
|
||||
original_app_model_config = db.session.get(AppModelConfig, app_model.app_model_config_id)
|
||||
if original_app_model_config is None:
|
||||
raise ValueError("Original app model config not found")
|
||||
agent_mode = original_app_model_config.agent_mode_dict
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from typing import Literal
|
|||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from constants.languages import supported_language
|
||||
|
|
@ -75,7 +76,7 @@ class AppSite(Resource):
|
|||
def post(self, app_model):
|
||||
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
if not site:
|
||||
raise NotFound
|
||||
|
||||
|
|
@ -124,7 +125,7 @@ class AppSiteAccessTokenReset(Resource):
|
|||
@marshal_with(app_site_model)
|
||||
def post(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
|
||||
if not site:
|
||||
raise NotFound
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ from collections.abc import Callable
|
|||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
|
|
@ -15,16 +17,14 @@ R1 = TypeVar("R1")
|
|||
|
||||
def _load_app_model(app_id: str) -> App | None:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
app_model = db.session.scalar(
|
||||
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
|
||||
)
|
||||
return app_model
|
||||
|
||||
|
||||
def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||
app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
|
||||
app_model = db.session.scalar(select(App).where(App.id == app_id, App.status == "normal").limit(1))
|
||||
return app_model
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
|
|
@ -112,6 +113,9 @@ class OAuthCallback(Resource):
|
|||
error_text = e.response.text
|
||||
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
|
||||
return {"error": "OAuth process failed"}, 400
|
||||
except ValueError as e:
|
||||
logger.warning("OAuth error with %s", provider, exc_info=True)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={urllib.parse.quote(str(e))}")
|
||||
|
||||
if invite_token and RegisterService.is_valid_invite_token(invite_token):
|
||||
invitation = RegisterService.get_invitation_by_token(token=invite_token)
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ from fields.document_fields import document_status_fields
|
|||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermission, DatasetPermissionEnum
|
||||
from models.enums import SegmentStatus
|
||||
from models.enums import ApiTokenType, SegmentStatus
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.api_token_service import ApiTokenCache
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
|
@ -777,7 +777,7 @@ class DatasetIndexingStatusApi(Resource):
|
|||
class DatasetApiKeyApi(Resource):
|
||||
max_keys = 10
|
||||
token_prefix = "dataset-"
|
||||
resource_type = "dataset"
|
||||
resource_type = ApiTokenType.DATASET
|
||||
|
||||
@console_ns.doc("get_dataset_api_keys")
|
||||
@console_ns.doc(description="Get dataset API keys")
|
||||
|
|
@ -826,7 +826,7 @@ class DatasetApiKeyApi(Resource):
|
|||
|
||||
@console_ns.route("/datasets/api-keys/<uuid:api_key_id>")
|
||||
class DatasetApiDeleteApi(Resource):
|
||||
resource_type = "dataset"
|
||||
resource_type = ApiTokenType.DATASET
|
||||
|
||||
@console_ns.doc("delete_dataset_api_key")
|
||||
@console_ns.doc(description="Delete dataset API key")
|
||||
|
|
|
|||
|
|
@ -705,7 +705,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
workflow_id=self._workflow.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
created_from=created_from.value,
|
||||
created_from=created_from,
|
||||
created_by_role=self._created_by_role,
|
||||
created_by=self._user_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ class RateLimit:
|
|||
_REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes
|
||||
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
|
||||
_instance_dict: dict[str, "RateLimit"] = {}
|
||||
max_active_requests: int
|
||||
|
||||
def __new__(cls, client_id: str, max_active_requests: int):
|
||||
if client_id not in cls._instance_dict:
|
||||
|
|
@ -27,7 +28,13 @@ class RateLimit:
|
|||
return cls._instance_dict[client_id]
|
||||
|
||||
def __init__(self, client_id: str, max_active_requests: int):
|
||||
flush_cache = hasattr(self, "max_active_requests") and self.max_active_requests != max_active_requests
|
||||
self.max_active_requests = max_active_requests
|
||||
# Only flush here if this instance has already been fully initialized,
|
||||
# i.e. the Redis key attributes exist. Otherwise, rely on the flush at
|
||||
# the end of initialization below.
|
||||
if flush_cache and hasattr(self, "active_requests_key") and hasattr(self, "max_active_requests_key"):
|
||||
self.flush_cache(use_local_value=True)
|
||||
# must be called after max_active_requests is set
|
||||
if self.disabled():
|
||||
return
|
||||
|
|
@ -41,8 +48,6 @@ class RateLimit:
|
|||
self.flush_cache(use_local_value=True)
|
||||
|
||||
def flush_cache(self, use_local_value=False):
|
||||
if self.disabled():
|
||||
return
|
||||
self.last_recalculate_time = time.time()
|
||||
# flush max active requests
|
||||
if use_local_value or not redis_client.exists(self.max_active_requests_key):
|
||||
|
|
@ -50,7 +55,8 @@ class RateLimit:
|
|||
else:
|
||||
self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8"))
|
||||
redis_client.expire(self.max_active_requests_key, timedelta(days=1))
|
||||
|
||||
if self.disabled():
|
||||
return
|
||||
# flush active requests (in-transit request list)
|
||||
if not redis_client.exists(self.active_requests_key):
|
||||
return
|
||||
|
|
|
|||
|
|
@ -181,10 +181,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
arize_phoenix_config: ArizeConfig | PhoenixConfig,
|
||||
):
|
||||
super().__init__(arize_phoenix_config)
|
||||
import logging
|
||||
|
||||
logging.basicConfig()
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
self.arize_phoenix_config = arize_phoenix_config
|
||||
self.tracer, self.processor = setup_tracer(arize_phoenix_config)
|
||||
self.project = arize_phoenix_config.project
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
|
|||
class BaseTraceInfo(BaseModel):
|
||||
message_id: str | None = None
|
||||
message_data: Any | None = None
|
||||
inputs: Union[str, dict[str, Any], list] | None = None
|
||||
outputs: Union[str, dict[str, Any], list] | None = None
|
||||
inputs: Union[str, dict[str, Any], list[Any]] | None = None
|
||||
outputs: Union[str, dict[str, Any], list[Any]] | None = None
|
||||
start_time: datetime | None = None
|
||||
end_time: datetime | None = None
|
||||
metadata: dict[str, Any]
|
||||
|
|
@ -18,7 +18,7 @@ class BaseTraceInfo(BaseModel):
|
|||
|
||||
@field_validator("inputs", "outputs")
|
||||
@classmethod
|
||||
def ensure_type(cls, v):
|
||||
def ensure_type(cls, v: str | dict[str, Any] | list[Any] | None) -> str | dict[str, Any] | list[Any] | None:
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str | dict | list):
|
||||
|
|
@ -27,6 +27,48 @@ class BaseTraceInfo(BaseModel):
|
|||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@property
|
||||
def resolved_trace_id(self) -> str | None:
|
||||
"""Get trace_id with intelligent fallback.
|
||||
|
||||
Priority:
|
||||
1. External trace_id (from X-Trace-Id header)
|
||||
2. workflow_run_id (if this trace type has it)
|
||||
3. message_id (as final fallback)
|
||||
"""
|
||||
if self.trace_id:
|
||||
return self.trace_id
|
||||
|
||||
# Try workflow_run_id (only exists on workflow-related traces)
|
||||
workflow_run_id = getattr(self, "workflow_run_id", None)
|
||||
if workflow_run_id:
|
||||
return workflow_run_id
|
||||
|
||||
# Final fallback to message_id
|
||||
return str(self.message_id) if self.message_id else None
|
||||
|
||||
@property
|
||||
def resolved_parent_context(self) -> tuple[str | None, str | None]:
|
||||
"""Resolve cross-workflow parent linking from metadata.
|
||||
|
||||
Extracts typed parent IDs from the untyped ``parent_trace_context``
|
||||
metadata dict (set by tool_node when invoking nested workflows).
|
||||
|
||||
Returns:
|
||||
(trace_correlation_override, parent_span_id_source) where
|
||||
trace_correlation_override is the outer workflow_run_id and
|
||||
parent_span_id_source is the outer node_execution_id.
|
||||
"""
|
||||
parent_ctx = self.metadata.get("parent_trace_context")
|
||||
if not isinstance(parent_ctx, dict):
|
||||
return None, None
|
||||
trace_override = parent_ctx.get("parent_workflow_run_id")
|
||||
parent_span = parent_ctx.get("parent_node_execution_id")
|
||||
return (
|
||||
trace_override if isinstance(trace_override, str) else None,
|
||||
parent_span if isinstance(parent_span, str) else None,
|
||||
)
|
||||
|
||||
@field_serializer("start_time", "end_time")
|
||||
def serialize_datetime(self, dt: datetime | None) -> str | None:
|
||||
if dt is None:
|
||||
|
|
@ -48,7 +90,10 @@ class WorkflowTraceInfo(BaseTraceInfo):
|
|||
workflow_run_version: str
|
||||
error: str | None = None
|
||||
total_tokens: int
|
||||
prompt_tokens: int | None = None
|
||||
completion_tokens: int | None = None
|
||||
file_list: list[str]
|
||||
invoked_by: str | None = None
|
||||
query: str
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
|
@ -59,7 +104,7 @@ class MessageTraceInfo(BaseTraceInfo):
|
|||
answer_tokens: int
|
||||
total_tokens: int
|
||||
error: str | None = None
|
||||
file_list: Union[str, dict[str, Any], list] | None = None
|
||||
file_list: Union[str, dict[str, Any], list[Any]] | None = None
|
||||
message_file_data: Any | None = None
|
||||
conversation_mode: str
|
||||
gen_ai_server_time_to_first_token: float | None = None
|
||||
|
|
@ -106,7 +151,7 @@ class ToolTraceInfo(BaseTraceInfo):
|
|||
tool_config: dict[str, Any]
|
||||
time_cost: Union[int, float]
|
||||
tool_parameters: dict[str, Any]
|
||||
file_url: Union[str, None, list] = None
|
||||
file_url: Union[str, None, list[str]] = None
|
||||
|
||||
|
||||
class GenerateNameTraceInfo(BaseTraceInfo):
|
||||
|
|
@ -114,6 +159,79 @@ class GenerateNameTraceInfo(BaseTraceInfo):
|
|||
tenant_id: str
|
||||
|
||||
|
||||
class PromptGenerationTraceInfo(BaseTraceInfo):
|
||||
"""Trace information for prompt generation operations (rule-generate, code-generate, etc.)."""
|
||||
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
app_id: str | None = None
|
||||
|
||||
operation_type: str
|
||||
instruction: str
|
||||
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
model_provider: str
|
||||
model_name: str
|
||||
|
||||
latency: float
|
||||
|
||||
total_price: float | None = None
|
||||
currency: str | None = None
|
||||
|
||||
error: str | None = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class WorkflowNodeTraceInfo(BaseTraceInfo):
|
||||
workflow_id: str
|
||||
workflow_run_id: str
|
||||
tenant_id: str
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
|
||||
status: str
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
|
||||
index: int
|
||||
predecessor_node_id: str | None = None
|
||||
|
||||
total_tokens: int = 0
|
||||
total_price: float = 0.0
|
||||
currency: str | None = None
|
||||
|
||||
model_provider: str | None = None
|
||||
model_name: str | None = None
|
||||
prompt_tokens: int | None = None
|
||||
completion_tokens: int | None = None
|
||||
|
||||
tool_name: str | None = None
|
||||
|
||||
iteration_id: str | None = None
|
||||
iteration_index: int | None = None
|
||||
loop_id: str | None = None
|
||||
loop_index: int | None = None
|
||||
parallel_id: str | None = None
|
||||
|
||||
node_inputs: Mapping[str, Any] | None = None
|
||||
node_outputs: Mapping[str, Any] | None = None
|
||||
process_data: Mapping[str, Any] | None = None
|
||||
|
||||
invoked_by: str | None = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class DraftNodeExecutionTrace(WorkflowNodeTraceInfo):
|
||||
pass
|
||||
|
||||
|
||||
class TaskData(BaseModel):
|
||||
app_id: str
|
||||
trace_info_type: str
|
||||
|
|
@ -128,11 +246,31 @@ trace_info_info_map = {
|
|||
"DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo,
|
||||
"ToolTraceInfo": ToolTraceInfo,
|
||||
"GenerateNameTraceInfo": GenerateNameTraceInfo,
|
||||
"PromptGenerationTraceInfo": PromptGenerationTraceInfo,
|
||||
"WorkflowNodeTraceInfo": WorkflowNodeTraceInfo,
|
||||
"DraftNodeExecutionTrace": DraftNodeExecutionTrace,
|
||||
}
|
||||
|
||||
|
||||
class OperationType(StrEnum):
|
||||
"""Operation type for token metric labels.
|
||||
|
||||
Used as a metric attribute on ``dify.tokens.input`` / ``dify.tokens.output``
|
||||
counters so consumers can break down token usage by operation.
|
||||
"""
|
||||
|
||||
WORKFLOW = "workflow"
|
||||
NODE_EXECUTION = "node_execution"
|
||||
MESSAGE = "message"
|
||||
RULE_GENERATE = "rule_generate"
|
||||
CODE_GENERATE = "code_generate"
|
||||
STRUCTURED_OUTPUT = "structured_output"
|
||||
INSTRUCTION_MODIFY = "instruction_modify"
|
||||
|
||||
|
||||
class TraceTaskName(StrEnum):
|
||||
CONVERSATION_TRACE = "conversation"
|
||||
DRAFT_NODE_EXECUTION_TRACE = "draft_node_execution"
|
||||
WORKFLOW_TRACE = "workflow"
|
||||
MESSAGE_TRACE = "message"
|
||||
MODERATION_TRACE = "moderation"
|
||||
|
|
@ -140,4 +278,6 @@ class TraceTaskName(StrEnum):
|
|||
DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
|
||||
TOOL_TRACE = "tool"
|
||||
GENERATE_NAME_TRACE = "generate_conversation_name"
|
||||
PROMPT_GENERATION_TRACE = "prompt_generation"
|
||||
NODE_EXECUTION_TRACE = "node_execution"
|
||||
DATASOURCE_TRACE = "datasource"
|
||||
|
|
|
|||
|
|
@ -15,22 +15,32 @@ from sqlalchemy import select
|
|||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
|
||||
from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum
|
||||
from core.ops.entities.config_entity import (
|
||||
OPS_FILE_PATH,
|
||||
TracingProviderEnum,
|
||||
)
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
DraftNodeExecutionTrace,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
PromptGenerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
TaskData,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowNodeTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import get_message_data
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.engine import db
|
||||
from models.account import Tenant
|
||||
from models.dataset import Dataset
|
||||
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
|
||||
from models.provider import Provider, ProviderCredential, ProviderModel, ProviderModelCredential, ProviderType
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from models.workflow import WorkflowAppLog
|
||||
from tasks.ops_trace_task import process_trace_tasks
|
||||
|
||||
|
|
@ -40,9 +50,142 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]:
|
||||
"""Return (app_name, workspace_name) for the given IDs. Falls back to empty strings."""
|
||||
app_name = ""
|
||||
workspace_name = ""
|
||||
if not app_id and not tenant_id:
|
||||
return app_name, workspace_name
|
||||
with Session(db.engine) as session:
|
||||
if app_id:
|
||||
name = session.scalar(select(App.name).where(App.id == app_id))
|
||||
if name:
|
||||
app_name = name
|
||||
if tenant_id:
|
||||
name = session.scalar(select(Tenant.name).where(Tenant.id == tenant_id))
|
||||
if name:
|
||||
workspace_name = name
|
||||
return app_name, workspace_name
|
||||
|
||||
|
||||
_PROVIDER_TYPE_TO_MODEL: dict[str, type] = {
|
||||
"builtin": BuiltinToolProvider,
|
||||
"plugin": BuiltinToolProvider,
|
||||
"api": ApiToolProvider,
|
||||
"workflow": WorkflowToolProvider,
|
||||
"mcp": MCPToolProvider,
|
||||
}
|
||||
|
||||
|
||||
def _lookup_credential_name(credential_id: str | None, provider_type: str | None) -> str:
|
||||
if not credential_id:
|
||||
return ""
|
||||
model_cls = _PROVIDER_TYPE_TO_MODEL.get(provider_type or "")
|
||||
if not model_cls:
|
||||
return ""
|
||||
with Session(db.engine) as session:
|
||||
name = session.scalar(select(model_cls.name).where(model_cls.id == credential_id)) # type: ignore[attr-defined]
|
||||
return str(name) if name else ""
|
||||
|
||||
|
||||
def _lookup_llm_credential_info(
|
||||
tenant_id: str | None, provider: str | None, model: str | None, model_type: str | None = "llm"
|
||||
) -> tuple[str | None, str]:
|
||||
"""
|
||||
Lookup LLM credential ID and name for the given provider and model.
|
||||
Returns (credential_id, credential_name).
|
||||
|
||||
Handles async timing issues gracefully - if credential is deleted between lookups,
|
||||
returns the ID but empty name rather than failing.
|
||||
"""
|
||||
if not tenant_id or not provider:
|
||||
return None, ""
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
# Try to find provider-level or model-level configuration
|
||||
provider_record = session.scalar(
|
||||
select(Provider).where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == provider,
|
||||
Provider.provider_type == ProviderType.CUSTOM,
|
||||
)
|
||||
)
|
||||
|
||||
if not provider_record:
|
||||
return None, ""
|
||||
|
||||
# Check if there's a model-specific config
|
||||
credential_id = None
|
||||
credential_name = ""
|
||||
is_model_level = False
|
||||
|
||||
if model:
|
||||
# Try model-level first
|
||||
model_record = session.scalar(
|
||||
select(ProviderModel).where(
|
||||
ProviderModel.tenant_id == tenant_id,
|
||||
ProviderModel.provider_name == provider,
|
||||
ProviderModel.model_name == model,
|
||||
ProviderModel.model_type == model_type,
|
||||
)
|
||||
)
|
||||
|
||||
if model_record and model_record.credential_id:
|
||||
credential_id = model_record.credential_id
|
||||
is_model_level = True
|
||||
|
||||
if not credential_id and provider_record.credential_id:
|
||||
# Fall back to provider-level credential
|
||||
credential_id = provider_record.credential_id
|
||||
is_model_level = False
|
||||
|
||||
# Lookup credential_name if we have credential_id
|
||||
if credential_id:
|
||||
try:
|
||||
if is_model_level:
|
||||
# Query ProviderModelCredential
|
||||
cred_name = session.scalar(
|
||||
select(ProviderModelCredential.credential_name).where(
|
||||
ProviderModelCredential.id == credential_id
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Query ProviderCredential
|
||||
cred_name = session.scalar(
|
||||
select(ProviderCredential.credential_name).where(ProviderCredential.id == credential_id)
|
||||
)
|
||||
|
||||
if cred_name:
|
||||
credential_name = str(cred_name)
|
||||
except Exception as e:
|
||||
# Credential might have been deleted between lookups (async timing)
|
||||
# Return ID but empty name rather than failing
|
||||
logger.warning(
|
||||
"Failed to lookup credential name for credential_id=%s (provider=%s, model=%s): %s",
|
||||
credential_id,
|
||||
provider,
|
||||
model,
|
||||
str(e),
|
||||
)
|
||||
|
||||
return credential_id, credential_name
|
||||
except Exception as e:
|
||||
# Database query failed or other unexpected error
|
||||
# Return empty rather than propagating error to telemetry emission
|
||||
logger.warning(
|
||||
"Failed to lookup LLM credential info for tenant_id=%s, provider=%s, model=%s: %s",
|
||||
tenant_id,
|
||||
provider,
|
||||
model,
|
||||
str(e),
|
||||
)
|
||||
return None, ""
|
||||
|
||||
|
||||
class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
|
||||
def __getitem__(self, key: str) -> dict[str, Any]:
|
||||
match key:
|
||||
def __getitem__(self, provider: str) -> dict[str, Any]:
|
||||
match provider:
|
||||
case TracingProviderEnum.LANGFUSE:
|
||||
from core.ops.entities.config_entity import LangfuseConfig
|
||||
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
|
||||
|
|
@ -149,7 +292,7 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
|
|||
}
|
||||
|
||||
case _:
|
||||
raise KeyError(f"Unsupported tracing provider: {key}")
|
||||
raise KeyError(f"Unsupported tracing provider: {provider}")
|
||||
|
||||
|
||||
provider_config_map = OpsTraceProviderConfigMap()
|
||||
|
|
@ -314,6 +457,10 @@ class OpsTraceManager:
|
|||
if app_id is None:
|
||||
return None
|
||||
|
||||
# Handle storage_id format (tenant-{uuid}) - not a real app_id
|
||||
if isinstance(app_id, str) and app_id.startswith("tenant-"):
|
||||
return None
|
||||
|
||||
app: App | None = db.session.query(App).where(App.id == app_id).first()
|
||||
|
||||
if app is None:
|
||||
|
|
@ -466,8 +613,6 @@ class TraceTask:
|
|||
|
||||
@classmethod
|
||||
def _get_workflow_run_repo(cls):
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
if cls._workflow_run_repo is None:
|
||||
with cls._repo_lock:
|
||||
if cls._workflow_run_repo is None:
|
||||
|
|
@ -478,6 +623,56 @@ class TraceTask:
|
|||
cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
return cls._workflow_run_repo
|
||||
|
||||
@classmethod
|
||||
def _get_user_id_from_metadata(cls, metadata: dict[str, Any]) -> str:
|
||||
"""Extract user ID from metadata, prioritizing end_user over account.
|
||||
|
||||
Returns the actual user ID (end_user or account) who invoked the workflow,
|
||||
regardless of invoke_from context.
|
||||
"""
|
||||
# Priority 1: End user (external users via API/WebApp)
|
||||
if user_id := metadata.get("from_end_user_id"):
|
||||
return f"end_user:{user_id}"
|
||||
|
||||
# Priority 2: Account user (internal users via console/debugger)
|
||||
if user_id := metadata.get("from_account_id"):
|
||||
return f"account:{user_id}"
|
||||
|
||||
# Priority 3: User (internal users via console/debugger)
|
||||
if user_id := metadata.get("user_id"):
|
||||
return f"user:{user_id}"
|
||||
|
||||
return "anonymous"
|
||||
|
||||
@classmethod
|
||||
def _calculate_workflow_token_split(cls, workflow_run_id: str, tenant_id: str) -> tuple[int, int]:
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
with Session(db.engine) as session:
|
||||
node_executions = session.scalars(
|
||||
select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.tenant_id == tenant_id,
|
||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||
)
|
||||
).all()
|
||||
|
||||
total_prompt = 0
|
||||
total_completion = 0
|
||||
|
||||
for node_exec in node_executions:
|
||||
metadata = node_exec.execution_metadata_dict
|
||||
|
||||
prompt = metadata.get(WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS)
|
||||
if prompt is not None:
|
||||
total_prompt += prompt
|
||||
|
||||
completion = metadata.get(WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS)
|
||||
if completion is not None:
|
||||
total_completion += completion
|
||||
|
||||
return (total_prompt, total_completion)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trace_type: Any,
|
||||
|
|
@ -498,6 +693,8 @@ class TraceTask:
|
|||
self.app_id = None
|
||||
self.trace_id = None
|
||||
self.kwargs = kwargs
|
||||
if user_id is not None and "user_id" not in self.kwargs:
|
||||
self.kwargs["user_id"] = user_id
|
||||
external_trace_id = kwargs.get("external_trace_id")
|
||||
if external_trace_id:
|
||||
self.trace_id = external_trace_id
|
||||
|
|
@ -511,7 +708,7 @@ class TraceTask:
|
|||
TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
|
||||
workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
|
||||
),
|
||||
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
|
||||
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id, **self.kwargs),
|
||||
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
|
||||
message_id=self.message_id, timer=self.timer, **self.kwargs
|
||||
),
|
||||
|
|
@ -527,6 +724,9 @@ class TraceTask:
|
|||
TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
|
||||
conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
|
||||
),
|
||||
TraceTaskName.PROMPT_GENERATION_TRACE: lambda: self.prompt_generation_trace(**self.kwargs),
|
||||
TraceTaskName.NODE_EXECUTION_TRACE: lambda: self.node_execution_trace(**self.kwargs),
|
||||
TraceTaskName.DRAFT_NODE_EXECUTION_TRACE: lambda: self.draft_node_execution_trace(**self.kwargs),
|
||||
}
|
||||
|
||||
return preprocess_map.get(self.trace_type, lambda: None)()
|
||||
|
|
@ -562,6 +762,10 @@ class TraceTask:
|
|||
|
||||
total_tokens = workflow_run.total_tokens
|
||||
|
||||
prompt_tokens, completion_tokens = self._calculate_workflow_token_split(
|
||||
workflow_run_id=workflow_run_id, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
file_list = workflow_run_inputs.get("sys.file") or []
|
||||
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
|
||||
|
||||
|
|
@ -582,7 +786,14 @@ class TraceTask:
|
|||
)
|
||||
message_id = session.scalar(message_data_stmt)
|
||||
|
||||
metadata = {
|
||||
from core.telemetry.gateway import is_enterprise_telemetry_enabled
|
||||
|
||||
if is_enterprise_telemetry_enabled():
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names(workflow_run.app_id, tenant_id)
|
||||
else:
|
||||
app_name, workspace_name = "", ""
|
||||
|
||||
metadata: dict[str, Any] = {
|
||||
"workflow_id": workflow_id,
|
||||
"conversation_id": conversation_id,
|
||||
"workflow_run_id": workflow_run_id,
|
||||
|
|
@ -595,8 +806,14 @@ class TraceTask:
|
|||
"triggered_from": workflow_run.triggered_from,
|
||||
"user_id": user_id,
|
||||
"app_id": workflow_run.app_id,
|
||||
"app_name": app_name,
|
||||
"workspace_name": workspace_name,
|
||||
}
|
||||
|
||||
parent_trace_context = self.kwargs.get("parent_trace_context")
|
||||
if parent_trace_context:
|
||||
metadata["parent_trace_context"] = parent_trace_context
|
||||
|
||||
workflow_trace_info = WorkflowTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
workflow_data=workflow_run.to_dict(),
|
||||
|
|
@ -611,6 +828,8 @@ class TraceTask:
|
|||
workflow_run_version=workflow_run_version,
|
||||
error=error,
|
||||
total_tokens=total_tokens,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
file_list=file_list,
|
||||
query=query,
|
||||
metadata=metadata,
|
||||
|
|
@ -618,10 +837,11 @@ class TraceTask:
|
|||
message_id=message_id,
|
||||
start_time=workflow_run.created_at,
|
||||
end_time=workflow_run.finished_at,
|
||||
invoked_by=self._get_user_id_from_metadata(metadata),
|
||||
)
|
||||
return workflow_trace_info
|
||||
|
||||
def message_trace(self, message_id: str | None):
|
||||
def message_trace(self, message_id: str | None, **kwargs):
|
||||
if not message_id:
|
||||
return {}
|
||||
message_data = get_message_data(message_id)
|
||||
|
|
@ -644,6 +864,19 @@ class TraceTask:
|
|||
|
||||
streaming_metrics = self._extract_streaming_metrics(message_data)
|
||||
|
||||
tenant_id = ""
|
||||
with Session(db.engine) as session:
|
||||
tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id))
|
||||
if tid:
|
||||
tenant_id = str(tid)
|
||||
|
||||
from core.telemetry.gateway import is_enterprise_telemetry_enabled
|
||||
|
||||
if is_enterprise_telemetry_enabled():
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
|
||||
else:
|
||||
app_name, workspace_name = "", ""
|
||||
|
||||
metadata = {
|
||||
"conversation_id": message_data.conversation_id,
|
||||
"ls_provider": message_data.model_provider,
|
||||
|
|
@ -655,7 +888,14 @@ class TraceTask:
|
|||
"workflow_run_id": message_data.workflow_run_id,
|
||||
"from_source": message_data.from_source,
|
||||
"message_id": message_id,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": message_data.app_id,
|
||||
"user_id": message_data.from_end_user_id or message_data.from_account_id,
|
||||
"app_name": app_name,
|
||||
"workspace_name": workspace_name,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
message_tokens = message_data.message_tokens
|
||||
|
||||
|
|
@ -672,7 +912,9 @@ class TraceTask:
|
|||
outputs=message_data.answer,
|
||||
file_list=file_list,
|
||||
start_time=created_at,
|
||||
end_time=created_at + timedelta(seconds=message_data.provider_response_latency),
|
||||
end_time=message_data.updated_at
|
||||
if message_data.updated_at and message_data.updated_at > created_at
|
||||
else created_at + timedelta(seconds=message_data.provider_response_latency),
|
||||
metadata=metadata,
|
||||
message_file_data=message_file_data,
|
||||
conversation_mode=conversation_mode,
|
||||
|
|
@ -697,6 +939,8 @@ class TraceTask:
|
|||
"preset_response": moderation_result.preset_response,
|
||||
"query": moderation_result.query,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
# get workflow_app_log_id
|
||||
workflow_app_log_id = None
|
||||
|
|
@ -738,6 +982,8 @@ class TraceTask:
|
|||
"workflow_run_id": message_data.workflow_run_id,
|
||||
"from_source": message_data.from_source,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
# get workflow_app_log_id
|
||||
workflow_app_log_id = None
|
||||
|
|
@ -777,6 +1023,52 @@ class TraceTask:
|
|||
if not message_data:
|
||||
return {}
|
||||
|
||||
tenant_id = ""
|
||||
with Session(db.engine) as session:
|
||||
tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id))
|
||||
if tid:
|
||||
tenant_id = str(tid)
|
||||
|
||||
from core.telemetry.gateway import is_enterprise_telemetry_enabled
|
||||
|
||||
if is_enterprise_telemetry_enabled():
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
|
||||
else:
|
||||
app_name, workspace_name = "", ""
|
||||
|
||||
doc_list = [doc.model_dump() for doc in documents] if documents else []
|
||||
dataset_ids: set[str] = set()
|
||||
for doc in doc_list:
|
||||
doc_meta = doc.get("metadata") or {}
|
||||
did = doc_meta.get("dataset_id")
|
||||
if did:
|
||||
dataset_ids.add(did)
|
||||
|
||||
embedding_models: dict[str, dict[str, str]] = {}
|
||||
if dataset_ids:
|
||||
with Session(db.engine) as session:
|
||||
rows = session.execute(
|
||||
select(Dataset.id, Dataset.embedding_model, Dataset.embedding_model_provider).where(
|
||||
Dataset.id.in_(list(dataset_ids))
|
||||
)
|
||||
).all()
|
||||
for row in rows:
|
||||
embedding_models[str(row[0])] = {
|
||||
"embedding_model": row[1] or "",
|
||||
"embedding_model_provider": row[2] or "",
|
||||
}
|
||||
|
||||
# Extract rerank model info from retrieval_model kwargs
|
||||
rerank_model_provider = ""
|
||||
rerank_model_name = ""
|
||||
if "retrieval_model" in kwargs:
|
||||
retrieval_model = kwargs["retrieval_model"]
|
||||
if isinstance(retrieval_model, dict):
|
||||
reranking_model = retrieval_model.get("reranking_model")
|
||||
if isinstance(reranking_model, dict):
|
||||
rerank_model_provider = reranking_model.get("reranking_provider_name", "")
|
||||
rerank_model_name = reranking_model.get("reranking_model_name", "")
|
||||
|
||||
metadata = {
|
||||
"message_id": message_id,
|
||||
"ls_provider": message_data.model_provider,
|
||||
|
|
@ -787,13 +1079,23 @@ class TraceTask:
|
|||
"agent_based": message_data.agent_based,
|
||||
"workflow_run_id": message_data.workflow_run_id,
|
||||
"from_source": message_data.from_source,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": message_data.app_id,
|
||||
"user_id": message_data.from_end_user_id or message_data.from_account_id,
|
||||
"app_name": app_name,
|
||||
"workspace_name": workspace_name,
|
||||
"embedding_models": embedding_models,
|
||||
"rerank_model_provider": rerank_model_provider,
|
||||
"rerank_model_name": rerank_model_name,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
message_id=message_id,
|
||||
inputs=message_data.query or message_data.inputs,
|
||||
documents=[doc.model_dump() for doc in documents] if documents else [],
|
||||
documents=doc_list,
|
||||
start_time=timer.get("start"),
|
||||
end_time=timer.get("end"),
|
||||
metadata=metadata,
|
||||
|
|
@ -836,6 +1138,10 @@ class TraceTask:
|
|||
"error": error,
|
||||
"tool_parameters": tool_parameters,
|
||||
}
|
||||
if message_data.workflow_run_id:
|
||||
metadata["workflow_run_id"] = message_data.workflow_run_id
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
file_url = ""
|
||||
message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
|
||||
|
|
@ -890,6 +1196,8 @@ class TraceTask:
|
|||
"conversation_id": conversation_id,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
generate_name_trace_info = GenerateNameTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
|
|
@ -904,6 +1212,182 @@ class TraceTask:
|
|||
|
||||
return generate_name_trace_info
|
||||
|
||||
def prompt_generation_trace(self, **kwargs) -> PromptGenerationTraceInfo | dict:
|
||||
tenant_id = kwargs.get("tenant_id", "")
|
||||
user_id = kwargs.get("user_id", "")
|
||||
app_id = kwargs.get("app_id")
|
||||
operation_type = kwargs.get("operation_type", "")
|
||||
instruction = kwargs.get("instruction", "")
|
||||
generated_output = kwargs.get("generated_output", "")
|
||||
|
||||
prompt_tokens = kwargs.get("prompt_tokens", 0)
|
||||
completion_tokens = kwargs.get("completion_tokens", 0)
|
||||
total_tokens = kwargs.get("total_tokens", 0)
|
||||
|
||||
model_provider = kwargs.get("model_provider", "")
|
||||
model_name = kwargs.get("model_name", "")
|
||||
|
||||
latency = kwargs.get("latency", 0.0)
|
||||
|
||||
timer = kwargs.get("timer")
|
||||
start_time = timer.get("start") if timer else None
|
||||
end_time = timer.get("end") if timer else None
|
||||
|
||||
total_price = kwargs.get("total_price")
|
||||
currency = kwargs.get("currency")
|
||||
|
||||
error = kwargs.get("error")
|
||||
|
||||
app_name = None
|
||||
workspace_name = None
|
||||
if app_id:
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names(app_id, tenant_id)
|
||||
|
||||
metadata = {
|
||||
"tenant_id": tenant_id,
|
||||
"user_id": user_id,
|
||||
"app_id": app_id or "",
|
||||
"app_name": app_name,
|
||||
"workspace_name": workspace_name,
|
||||
"operation_type": operation_type,
|
||||
"model_provider": model_provider,
|
||||
"model_name": model_name,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
return PromptGenerationTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
inputs=instruction,
|
||||
outputs=generated_output,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
metadata=metadata,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
operation_type=operation_type,
|
||||
instruction=instruction,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
model_provider=model_provider,
|
||||
model_name=model_name,
|
||||
latency=latency,
|
||||
total_price=total_price,
|
||||
currency=currency,
|
||||
error=error,
|
||||
)
|
||||
|
||||
def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict:
|
||||
node_data: dict = kwargs.get("node_execution_data", {})
|
||||
if not node_data:
|
||||
return {}
|
||||
|
||||
from core.telemetry.gateway import is_enterprise_telemetry_enabled
|
||||
|
||||
if is_enterprise_telemetry_enabled():
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names(
|
||||
node_data.get("app_id"), node_data.get("tenant_id")
|
||||
)
|
||||
else:
|
||||
app_name, workspace_name = "", ""
|
||||
|
||||
# Try tool credential lookup first
|
||||
credential_id = node_data.get("credential_id")
|
||||
if is_enterprise_telemetry_enabled():
|
||||
credential_name = _lookup_credential_name(credential_id, node_data.get("credential_provider_type"))
|
||||
# If no credential_id found (e.g., LLM nodes), try LLM credential lookup
|
||||
if not credential_id:
|
||||
llm_cred_id, llm_cred_name = _lookup_llm_credential_info(
|
||||
tenant_id=node_data.get("tenant_id"),
|
||||
provider=node_data.get("model_provider"),
|
||||
model=node_data.get("model_name"),
|
||||
model_type="llm",
|
||||
)
|
||||
if llm_cred_id:
|
||||
credential_id = llm_cred_id
|
||||
credential_name = llm_cred_name
|
||||
else:
|
||||
credential_name = ""
|
||||
metadata: dict[str, Any] = {
|
||||
"tenant_id": node_data.get("tenant_id"),
|
||||
"app_id": node_data.get("app_id"),
|
||||
"app_name": app_name,
|
||||
"workspace_name": workspace_name,
|
||||
"user_id": node_data.get("user_id"),
|
||||
"invoke_from": node_data.get("invoke_from"),
|
||||
"credential_id": credential_id,
|
||||
"credential_name": credential_name,
|
||||
"dataset_ids": node_data.get("dataset_ids"),
|
||||
"dataset_names": node_data.get("dataset_names"),
|
||||
"plugin_name": node_data.get("plugin_name"),
|
||||
}
|
||||
|
||||
parent_trace_context = node_data.get("parent_trace_context")
|
||||
if parent_trace_context:
|
||||
metadata["parent_trace_context"] = parent_trace_context
|
||||
|
||||
message_id: str | None = None
|
||||
conversation_id = node_data.get("conversation_id")
|
||||
workflow_execution_id = node_data.get("workflow_execution_id")
|
||||
if conversation_id and workflow_execution_id and not parent_trace_context:
|
||||
with Session(db.engine) as session:
|
||||
msg_id = session.scalar(
|
||||
select(Message.id).where(
|
||||
Message.conversation_id == conversation_id,
|
||||
Message.workflow_run_id == workflow_execution_id,
|
||||
)
|
||||
)
|
||||
if msg_id:
|
||||
message_id = str(msg_id)
|
||||
metadata["message_id"] = message_id
|
||||
if conversation_id:
|
||||
metadata["conversation_id"] = conversation_id
|
||||
|
||||
return WorkflowNodeTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
message_id=message_id,
|
||||
start_time=node_data.get("created_at"),
|
||||
end_time=node_data.get("finished_at"),
|
||||
metadata=metadata,
|
||||
workflow_id=node_data.get("workflow_id", ""),
|
||||
workflow_run_id=node_data.get("workflow_execution_id", ""),
|
||||
tenant_id=node_data.get("tenant_id", ""),
|
||||
node_execution_id=node_data.get("node_execution_id", ""),
|
||||
node_id=node_data.get("node_id", ""),
|
||||
node_type=node_data.get("node_type", ""),
|
||||
title=node_data.get("title", ""),
|
||||
status=node_data.get("status", ""),
|
||||
error=node_data.get("error"),
|
||||
elapsed_time=node_data.get("elapsed_time", 0.0),
|
||||
index=node_data.get("index", 0),
|
||||
predecessor_node_id=node_data.get("predecessor_node_id"),
|
||||
total_tokens=node_data.get("total_tokens", 0),
|
||||
total_price=node_data.get("total_price", 0.0),
|
||||
currency=node_data.get("currency"),
|
||||
model_provider=node_data.get("model_provider"),
|
||||
model_name=node_data.get("model_name"),
|
||||
prompt_tokens=node_data.get("prompt_tokens"),
|
||||
completion_tokens=node_data.get("completion_tokens"),
|
||||
tool_name=node_data.get("tool_name"),
|
||||
iteration_id=node_data.get("iteration_id"),
|
||||
iteration_index=node_data.get("iteration_index"),
|
||||
loop_id=node_data.get("loop_id"),
|
||||
loop_index=node_data.get("loop_index"),
|
||||
parallel_id=node_data.get("parallel_id"),
|
||||
node_inputs=node_data.get("node_inputs"),
|
||||
node_outputs=node_data.get("node_outputs"),
|
||||
process_data=node_data.get("process_data"),
|
||||
invoked_by=self._get_user_id_from_metadata(metadata),
|
||||
)
|
||||
|
||||
def draft_node_execution_trace(self, **kwargs) -> DraftNodeExecutionTrace | dict:
    """Build a draft-node trace on top of the regular node execution trace.

    All keyword arguments are forwarded to ``node_execution_trace``. When that
    call yields a ``WorkflowNodeTraceInfo`` its dumped fields are re-wrapped in
    a ``DraftNodeExecutionTrace``; any other result is returned unchanged.
    """
    base_trace = self.node_execution_trace(**kwargs)
    if isinstance(base_trace, WorkflowNodeTraceInfo):
        return DraftNodeExecutionTrace(**base_trace.model_dump())
    # Not a trace-info object: hand the value back untouched.
    return base_trace
|
||||
|
||||
def _extract_streaming_metrics(self, message_data) -> dict:
|
||||
if not message_data.message_metadata:
|
||||
return {}
|
||||
|
|
@ -937,13 +1421,17 @@ class TraceQueueManager:
|
|||
self.user_id = user_id
|
||||
self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
|
||||
self.flask_app = current_app._get_current_object() # type: ignore
|
||||
|
||||
from core.telemetry.gateway import is_enterprise_telemetry_enabled
|
||||
|
||||
self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled()
|
||||
if trace_manager_timer is None:
|
||||
self.start_timer()
|
||||
|
||||
def add_trace_task(self, trace_task: TraceTask):
|
||||
global trace_manager_timer, trace_manager_queue
|
||||
try:
|
||||
if self.trace_instance:
|
||||
if self._enterprise_telemetry_enabled or self.trace_instance:
|
||||
trace_task.app_id = self.app_id
|
||||
trace_manager_queue.put(trace_task)
|
||||
except Exception:
|
||||
|
|
@ -979,20 +1467,27 @@ class TraceQueueManager:
|
|||
def send_to_celery(self, tasks: list[TraceTask]):
|
||||
with self.flask_app.app_context():
|
||||
for task in tasks:
|
||||
if task.app_id is None:
|
||||
continue
|
||||
storage_id = task.app_id
|
||||
if storage_id is None:
|
||||
tenant_id = task.kwargs.get("tenant_id")
|
||||
if tenant_id:
|
||||
storage_id = f"tenant-{tenant_id}"
|
||||
else:
|
||||
logger.warning("Skipping trace without app_id or tenant_id, trace_type: %s", task.trace_type)
|
||||
continue
|
||||
|
||||
file_id = uuid4().hex
|
||||
trace_info = task.execute()
|
||||
|
||||
task_data = TaskData(
|
||||
app_id=task.app_id,
|
||||
app_id=storage_id,
|
||||
trace_info_type=type(trace_info).__name__,
|
||||
trace_info=trace_info.model_dump() if trace_info else None,
|
||||
)
|
||||
file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
|
||||
file_path = f"{OPS_FILE_PATH}{storage_id}/{file_id}.json"
|
||||
storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
|
||||
file_info = {
|
||||
"file_id": file_id,
|
||||
"app_id": task.app_id,
|
||||
"app_id": storage_id,
|
||||
}
|
||||
process_trace_tasks.delay(file_info) # type: ignore
|
||||
|
|
|
|||
|
|
@ -918,11 +918,11 @@ class ProviderManager:
|
|||
|
||||
trail_pool = CreditPoolService.get_pool(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=ProviderQuotaType.TRIAL.value,
|
||||
pool_type=ProviderQuotaType.TRIAL,
|
||||
)
|
||||
paid_pool = CreditPoolService.get_pool(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=ProviderQuotaType.PAID.value,
|
||||
pool_type=ProviderQuotaType.PAID,
|
||||
)
|
||||
else:
|
||||
trail_pool = None
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from pymochow.exception import ServerError # type: ignore
|
|||
from pymochow.model.database import Database
|
||||
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore
|
||||
from pymochow.model.schema import (
|
||||
AutoBuildRowCountIncrement,
|
||||
Field,
|
||||
FilteringIndex,
|
||||
HNSWParams,
|
||||
|
|
@ -51,6 +52,9 @@ class BaiduConfig(BaseModel):
|
|||
replicas: int = 3
|
||||
inverted_index_analyzer: str = "DEFAULT_ANALYZER"
|
||||
inverted_index_parser_mode: str = "COARSE_MODE"
|
||||
auto_build_row_count_increment: int = 500
|
||||
auto_build_row_count_increment_ratio: float = 0.05
|
||||
rebuild_index_timeout_in_seconds: int = 300
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
|
@ -107,18 +111,6 @@ class BaiduVector(BaseVector):
|
|||
rows.append(row)
|
||||
table.upsert(rows=rows)
|
||||
|
||||
# rebuild vector index after upsert finished
|
||||
table.rebuild_index(self.vector_index)
|
||||
timeout = 3600 # 1 hour timeout
|
||||
start_time = time.time()
|
||||
while True:
|
||||
time.sleep(1)
|
||||
index = table.describe_index(self.vector_index)
|
||||
if index.state == IndexState.NORMAL:
|
||||
break
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError(f"Index rebuild timeout after {timeout} seconds")
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id})
|
||||
if res and res.code == 0:
|
||||
|
|
@ -232,8 +224,14 @@ class BaiduVector(BaseVector):
|
|||
return self._client.database(self._client_config.database)
|
||||
|
||||
def _table_existed(self) -> bool:
|
||||
tables = self._db.list_table()
|
||||
return any(table.table_name == self._collection_name for table in tables)
|
||||
try:
|
||||
table = self._db.table(self._collection_name)
|
||||
except ServerError as e:
|
||||
if e.code == ServerErrCode.TABLE_NOT_EXIST:
|
||||
return False
|
||||
else:
|
||||
raise
|
||||
return True
|
||||
|
||||
def _create_table(self, dimension: int):
|
||||
# Try to grab distributed lock and create table
|
||||
|
|
@ -287,6 +285,11 @@ class BaiduVector(BaseVector):
|
|||
field=VDBField.VECTOR,
|
||||
metric_type=metric_type,
|
||||
params=HNSWParams(m=16, efconstruction=200),
|
||||
auto_build=True,
|
||||
auto_build_index_policy=AutoBuildRowCountIncrement(
|
||||
row_count_increment=self._client_config.auto_build_row_count_increment,
|
||||
row_count_increment_ratio=self._client_config.auto_build_row_count_increment_ratio,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -335,7 +338,7 @@ class BaiduVector(BaseVector):
|
|||
)
|
||||
|
||||
# Wait for table created
|
||||
timeout = 300 # 5 minutes timeout
|
||||
timeout = self._client_config.rebuild_index_timeout_in_seconds # default 5 minutes timeout
|
||||
start_time = time.time()
|
||||
while True:
|
||||
time.sleep(1)
|
||||
|
|
@ -345,6 +348,20 @@ class BaiduVector(BaseVector):
|
|||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError(f"Table creation timeout after {timeout} seconds")
|
||||
redis_client.set(table_exist_cache_key, 1, ex=3600)
|
||||
# rebuild vector index immediately after table created, make sure index is ready
|
||||
table.rebuild_index(self.vector_index)
|
||||
timeout = 3600 # 1 hour timeout
|
||||
self._wait_for_index_ready(table, timeout)
|
||||
|
||||
def _wait_for_index_ready(self, table, timeout: int = 3600):
    """Block until the vector index of ``table`` reports ``IndexState.NORMAL``.

    Polls ``table.describe_index(self.vector_index)`` once per second and
    returns as soon as the reported state is ``IndexState.NORMAL``.

    Args:
        table: Table object exposing ``describe_index``.
        timeout: Maximum number of seconds to wait.

    Raises:
        TimeoutError: If the index has not reached the NORMAL state after
            ``timeout`` seconds.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments
    # (NTP corrections, manual clock changes), which could otherwise shorten
    # or extend the wait.
    deadline = time.monotonic() + timeout
    while True:
        # Sleep first, then poll (same order as before).
        # NOTE(review): presumably this gives the server time to register the
        # rebuild request before the first state check — confirm.
        time.sleep(1)
        index = table.describe_index(self.vector_index)
        if index.state == IndexState.NORMAL:
            return
        if time.monotonic() > deadline:
            raise TimeoutError(f"Index rebuild timeout after {timeout} seconds")
|
||||
|
||||
|
||||
class BaiduVectorFactory(AbstractVectorFactory):
|
||||
|
|
@ -369,5 +386,8 @@ class BaiduVectorFactory(AbstractVectorFactory):
|
|||
replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
|
||||
inverted_index_analyzer=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER,
|
||||
inverted_index_parser_mode=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE,
|
||||
auto_build_row_count_increment=dify_config.BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT,
|
||||
auto_build_row_count_increment_ratio=dify_config.BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO,
|
||||
rebuild_index_timeout_in_seconds=dify_config.BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS,
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ from core.rag.models.document import Document
|
|||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset, TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client import grpc # noqa
|
||||
|
|
@ -452,7 +453,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
|||
password=new_cluster["password"],
|
||||
tenant_id=dataset.tenant_id,
|
||||
active=True,
|
||||
status="ACTIVE",
|
||||
status=TidbAuthBindingStatus.ACTIVE,
|
||||
)
|
||||
db.session.add(new_tidb_auth_binding)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from configs import dify_config
|
|||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
|
||||
class TidbService:
|
||||
|
|
@ -170,7 +171,7 @@ class TidbService:
|
|||
userPrefix = item["userPrefix"]
|
||||
if state == "ACTIVE" and len(userPrefix) > 0:
|
||||
cluster_info = tidb_serverless_list_map[item["clusterId"]]
|
||||
cluster_info.status = "ACTIVE"
|
||||
cluster_info.status = TidbAuthBindingStatus.ACTIVE
|
||||
cluster_info.account = f"{userPrefix}.root"
|
||||
db.session.add(cluster_info)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -95,15 +95,11 @@ class FirecrawlApp:
|
|||
if response.status_code == 200:
|
||||
crawl_status_response = response.json()
|
||||
if crawl_status_response.get("status") == "completed":
|
||||
total = crawl_status_response.get("total", 0)
|
||||
if total == 0:
|
||||
# Normalize to avoid None bypassing the zero-guard when the API returns null.
|
||||
total = crawl_status_response.get("total") or 0
|
||||
if total <= 0:
|
||||
raise Exception("Failed to check crawl status. Error: No page found")
|
||||
data = crawl_status_response.get("data", [])
|
||||
url_data_list: list[FirecrawlDocumentData] = []
|
||||
for item in data:
|
||||
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
|
||||
url_data = self._extract_common_fields(item)
|
||||
url_data_list.append(url_data)
|
||||
url_data_list = self._collect_all_crawl_pages(crawl_status_response, headers)
|
||||
if url_data_list:
|
||||
file_key = "website_files/" + job_id + ".txt"
|
||||
try:
|
||||
|
|
@ -120,6 +116,36 @@ class FirecrawlApp:
|
|||
self._handle_error(response, "check crawl status")
|
||||
raise RuntimeError("unreachable: _handle_error always raises")
|
||||
|
||||
def _collect_all_crawl_pages(
    self, first_page: dict[str, Any], headers: dict[str, str]
) -> list[FirecrawlDocumentData]:
    """Collect all crawl result pages by following pagination links.

    Starting from ``first_page``, every item that is a dict containing both
    ``metadata`` and ``markdown`` is converted with
    ``_extract_common_fields``; the ``next`` URL of each response is then
    fetched until no further link is returned.

    Raises an exception if any paginated request fails, to avoid returning
    partial data that is inconsistent with the reported total.

    The number of pages processed is capped at ``total`` (the
    server-reported page count) to guard against infinite loops caused by
    a misbehaving server that keeps returning a ``next`` URL.
    """
    total: int = first_page.get("total") or 0
    url_data_list: list[FirecrawlDocumentData] = []
    current_page = first_page
    pages_processed = 0
    while True:
        # ``data`` may be present but null; normalise it the same way ``total``
        # is normalised so a null payload is treated as an empty page.
        for item in current_page.get("data") or []:
            if isinstance(item, dict) and "metadata" in item and "markdown" in item:
                url_data_list.append(self._extract_common_fields(item))
        next_url: str | None = current_page.get("next")
        pages_processed += 1
        if not next_url or pages_processed >= total:
            break
        response = self._get_request(next_url, headers)
        if response.status_code != 200:
            self._handle_error(response, "fetch next crawl page")
            # Same guard as in the crawl-status check: never fall through and
            # parse an error response as if it were a result page.
            raise RuntimeError("unreachable: _handle_error always raises")
        current_page = response.json()
    return url_data_list
|
||||
|
||||
def _format_crawl_status_response(
|
||||
self,
|
||||
status: str,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,43 @@
|
|||
"""Telemetry facade.
|
||||
|
||||
Thin public API for emitting telemetry events. All routing logic
|
||||
lives in ``core.telemetry.gateway`` which is shared by both CE and EE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.telemetry.events import TelemetryContext, TelemetryEvent
|
||||
from core.telemetry.gateway import emit as gateway_emit
|
||||
from core.telemetry.gateway import get_trace_task_to_case
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
|
||||
|
||||
def emit(event: TelemetryEvent, trace_manager: TraceQueueManager | None = None) -> None:
    """Emit a telemetry event.

    Looks up the ``TelemetryCase`` registered for ``event.name`` (a
    ``TraceTaskName``) and forwards the event to
    ``core.telemetry.gateway.emit()``. Events whose name has no registered
    case are dropped without any further action.
    """
    case = get_trace_task_to_case().get(event.name)
    if case is not None:
        event_context = event.context
        # The gateway expects the context as a plain dict.
        context: dict[str, object] = {
            "tenant_id": event_context.tenant_id,
            "user_id": event_context.user_id,
            "app_id": event_context.app_id,
        }
        gateway_emit(case, context, event.payload, trace_manager)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TelemetryContext",
|
||||
"TelemetryEvent",
|
||||
"TraceTaskName",
|
||||
"emit",
|
||||
]
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TelemetryContext:
    """Immutable set of identifiers attached to a telemetry event.

    Every field is optional and defaults to ``None``.
    """

    tenant_id: str | None = None
    user_id: str | None = None
    app_id: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TelemetryEvent:
    """Immutable telemetry event: a trace task name, its context and a payload dict."""

    # Trace task name identifying the kind of event.
    name: TraceTaskName
    # Tenant / user / app identifiers for the event.
    context: TelemetryContext
    # Event-specific data, keyed by string.
    payload: dict[str, Any]
|
||||
|
|
@ -0,0 +1,239 @@
|
|||
"""Telemetry gateway — single routing layer for all editions.
|
||||
|
||||
Maps ``TelemetryCase`` → ``CaseRoute`` and dispatches events to either
|
||||
the CE/EE trace pipeline (``TraceQueueManager``) or the enterprise-only
|
||||
metric/log Celery queue.
|
||||
|
||||
This module lives in ``core/`` so both CE and EE share one routing table
|
||||
and one ``emit()`` entry point. No separate enterprise gateway module is
|
||||
needed — enterprise-specific dispatch (Celery task, payload offloading)
|
||||
is handled here behind lazy imports that no-op in CE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from enterprise.telemetry.contracts import SignalType
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from enterprise.telemetry.contracts import TelemetryCase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PAYLOAD_SIZE_THRESHOLD_BYTES = 1 * 1024 * 1024
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routing table — authoritative mapping for all editions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_case_to_trace_task: dict | None = None
|
||||
_case_routing: dict | None = None
|
||||
|
||||
|
||||
def _get_case_to_trace_task() -> dict:
    """Return the ``TelemetryCase`` → ``TraceTaskName`` table, building it on first use.

    The table is stored in the module-level ``_case_to_trace_task`` so later
    calls return the same dict object.
    """
    global _case_to_trace_task
    if _case_to_trace_task is not None:
        return _case_to_trace_task

    # Imported at first call rather than at module import time.
    from enterprise.telemetry.contracts import TelemetryCase

    table = {
        TelemetryCase.WORKFLOW_RUN: TraceTaskName.WORKFLOW_TRACE,
        TelemetryCase.MESSAGE_RUN: TraceTaskName.MESSAGE_TRACE,
        TelemetryCase.NODE_EXECUTION: TraceTaskName.NODE_EXECUTION_TRACE,
        TelemetryCase.DRAFT_NODE_EXECUTION: TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
        TelemetryCase.PROMPT_GENERATION: TraceTaskName.PROMPT_GENERATION_TRACE,
        TelemetryCase.TOOL_EXECUTION: TraceTaskName.TOOL_TRACE,
        TelemetryCase.MODERATION_CHECK: TraceTaskName.MODERATION_TRACE,
        TelemetryCase.SUGGESTED_QUESTION: TraceTaskName.SUGGESTED_QUESTION_TRACE,
        TelemetryCase.DATASET_RETRIEVAL: TraceTaskName.DATASET_RETRIEVAL_TRACE,
        TelemetryCase.GENERATE_NAME: TraceTaskName.GENERATE_NAME_TRACE,
    }
    _case_to_trace_task = table
    return _case_to_trace_task
|
||||
|
||||
|
||||
def get_trace_task_to_case() -> dict:
    """Return TraceTaskName → TelemetryCase (inverse of _get_case_to_trace_task).

    A new dict is built on every call.
    """
    inverse = {}
    for case, task_name in _get_case_to_trace_task().items():
        inverse[task_name] = case
    return inverse
|
||||
|
||||
|
||||
def _get_case_routing() -> dict:
    """Return the ``TelemetryCase`` → ``CaseRoute`` table, building it on first use.

    The table is stored in the module-level ``_case_routing`` so later calls
    return the same dict object.
    """
    global _case_routing
    if _case_routing is not None:
        return _case_routing

    # Imported at first call rather than at module import time.
    from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase

    # (cases, signal type, CE-eligible) groups, listed in table order.
    route_groups = (
        # TRACE — CE-eligible (flow in both CE and EE)
        (
            (
                TelemetryCase.WORKFLOW_RUN,
                TelemetryCase.MESSAGE_RUN,
                TelemetryCase.TOOL_EXECUTION,
                TelemetryCase.MODERATION_CHECK,
                TelemetryCase.SUGGESTED_QUESTION,
                TelemetryCase.DATASET_RETRIEVAL,
                TelemetryCase.GENERATE_NAME,
            ),
            SignalType.TRACE,
            True,
        ),
        # TRACE — enterprise-only
        (
            (
                TelemetryCase.NODE_EXECUTION,
                TelemetryCase.DRAFT_NODE_EXECUTION,
                TelemetryCase.PROMPT_GENERATION,
            ),
            SignalType.TRACE,
            False,
        ),
        # METRIC_LOG — enterprise-only (signal-driven, not trace)
        (
            (
                TelemetryCase.APP_CREATED,
                TelemetryCase.APP_UPDATED,
                TelemetryCase.APP_DELETED,
                TelemetryCase.FEEDBACK_CREATED,
            ),
            SignalType.METRIC_LOG,
            False,
        ),
    )

    routing = {}
    for cases, signal_type, ce_eligible in route_groups:
        for case in cases:
            # One CaseRoute instance per case.
            routing[case] = CaseRoute(signal_type=signal_type, ce_eligible=ce_eligible)
    _case_routing = routing
    return _case_routing
|
||||
|
||||
|
||||
def __getattr__(name: str) -> dict:
    """Lazy module-level access to routing tables.

    ``CASE_ROUTING`` and ``CASE_TO_TRACE_TASK`` resolve to the lazily built
    tables; any other name raises ``AttributeError``.
    """
    if name == "CASE_TO_TRACE_TASK":
        return _get_case_to_trace_task()
    if name != "CASE_ROUTING":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return _get_case_routing()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def is_enterprise_telemetry_enabled() -> bool:
    """Report whether enterprise telemetry is enabled.

    Delegates to ``enterprise.telemetry.exporter.is_enterprise_telemetry_enabled``.
    Any exception raised while importing or calling it results in ``False``.
    """
    try:
        from enterprise.telemetry.exporter import is_enterprise_telemetry_enabled as _exporter_enabled

        enabled = _exporter_enabled()
    except Exception:
        enabled = False
    return enabled
|
||||
|
||||
|
||||
def _handle_payload_sizing(
    payload: dict[str, Any],
    tenant_id: str,
    event_id: str,
) -> tuple[dict[str, Any], str | None]:
    """Inline or offload payload based on size.

    Returns ``(payload_for_envelope, storage_key | None)``. Payloads whose
    UTF-8 encoded JSON exceeds ``PAYLOAD_SIZE_THRESHOLD_BYTES`` are written to
    object storage under ``telemetry/<tenant_id>/<event_id>.json`` and
    replaced with an empty dict in the envelope. If the payload cannot be
    serialized, or the storage write fails, the payload is returned inline
    with no storage key.
    """
    try:
        encoded = json.dumps(payload).encode("utf-8")
    except (TypeError, ValueError):
        logger.warning("Failed to serialize payload for sizing: event_id=%s", event_id)
        return payload, None

    encoded_size = len(encoded)
    if encoded_size > PAYLOAD_SIZE_THRESHOLD_BYTES:
        storage_key = f"telemetry/{tenant_id}/{event_id}.json"
        try:
            storage.save(storage_key, encoded)
            logger.debug("Stored large payload to storage: key=%s, size=%d", storage_key, encoded_size)
            return {}, storage_key
        except Exception:
            # Best effort: fall through and keep the payload inline.
            logger.warning("Failed to store large payload, inlining instead: event_id=%s", event_id, exc_info=True)

    return payload, None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def emit(
    case: TelemetryCase,
    context: dict[str, Any],
    payload: dict[str, Any],
    trace_manager: TraceQueueManager | None = None,
) -> None:
    """Route a telemetry event to the correct pipeline.

    Cases without a routing entry are logged and dropped. Cases that are not
    CE-eligible are dropped when enterprise telemetry is disabled.

    TRACE events are enqueued through ``_emit_trace`` (works in both CE and
    EE); every other signal type is dispatched through ``_emit_metric_log``.
    """
    route = _get_case_routing().get(case)
    if route is None:
        logger.warning("Unknown telemetry case: %s, dropping event", case)
        return

    # The enterprise check only runs for cases that are not CE-eligible.
    allowed = route.ce_eligible or is_enterprise_telemetry_enabled()
    if not allowed:
        logger.debug("Dropping EE-only event: case=%s (EE disabled)", case)
        return

    is_trace = route.signal_type == SignalType.TRACE
    if not is_trace:
        _emit_metric_log(case, context, payload)
        return
    _emit_trace(case, context, payload, trace_manager)
|
||||
|
||||
|
||||
def _emit_trace(
    case: TelemetryCase,
    context: dict[str, Any],
    payload: dict[str, Any],
    trace_manager: TraceQueueManager | None,
) -> None:
    """Enqueue a trace task for ``case`` on a trace queue manager.

    Uses ``trace_manager`` when one is supplied; otherwise a new
    ``TraceQueueManager`` is created from the ``app_id`` and ``user_id`` found
    in ``context``. Cases without a ``TraceTaskName`` mapping are logged and
    dropped.
    """
    from core.ops.ops_trace_manager import TraceQueueManager as LocalTraceQueueManager
    from core.ops.ops_trace_manager import TraceTask

    trace_task_name = _get_case_to_trace_task().get(case)
    if trace_task_name is None:
        logger.warning("No TraceTaskName mapping for case: %s", case)
        return

    app_id = context.get("app_id")
    user_id = context.get("user_id")

    queue_manager = trace_manager
    if not queue_manager:
        queue_manager = LocalTraceQueueManager(app_id=app_id, user_id=user_id)

    task = TraceTask(trace_task_name, user_id=user_id, **payload)
    queue_manager.add_trace_task(task)
    logger.debug("Enqueued trace task: case=%s, app_id=%s", case, app_id)
|
||||
|
||||
|
||||
def _emit_metric_log(
    case: TelemetryCase,
    context: dict[str, Any],
    payload: dict[str, Any],
) -> None:
    """Build envelope and dispatch to enterprise Celery queue.

    No-ops when the enterprise telemetry task is not importable (CE mode).
    Each event gets a new UUID as its id. Oversized payloads are offloaded via
    ``_handle_payload_sizing``, and the resulting storage key, if any, is
    carried in the envelope metadata as ``payload_ref``.
    """
    try:
        from tasks.enterprise_telemetry_task import process_enterprise_telemetry
    except ImportError:
        logger.debug("Enterprise metric/log dispatch unavailable, dropping: case=%s", case)
        return

    # A missing or empty tenant id is normalised to an empty string.
    tenant_id = context.get("tenant_id") or ""
    event_id = str(uuid.uuid4())

    envelope_payload, payload_ref = _handle_payload_sizing(payload, tenant_id, event_id)

    from enterprise.telemetry.contracts import TelemetryEnvelope

    envelope_metadata = {"payload_ref": payload_ref} if payload_ref else None
    envelope = TelemetryEnvelope(
        case=case,
        tenant_id=tenant_id,
        event_id=event_id,
        payload=envelope_payload,
        metadata=envelope_metadata,
    )

    serialized_envelope = envelope.model_dump_json()
    process_enterprise_telemetry.delay(serialized_envelope)
    logger.debug(
        "Enqueued metric/log event: case=%s, tenant_id=%s, event_id=%s",
        case,
        tenant_id,
        event_id,
    )
|
||||
|
|
@ -50,7 +50,7 @@ class BuiltinTool(Tool):
|
|||
return ModelInvocationUtils.invoke(
|
||||
user_id=user_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
tool_type="builtin",
|
||||
tool_type=ToolProviderType.BUILT_IN,
|
||||
tool_name=self.entity.identity.name,
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ class ToolLabelManager:
|
|||
db.session.add(
|
||||
ToolLabelBinding(
|
||||
tool_id=provider_id,
|
||||
tool_type=controller.provider_type.value,
|
||||
tool_type=controller.provider_type,
|
||||
label_name=label,
|
||||
)
|
||||
)
|
||||
|
|
@ -58,7 +58,7 @@ class ToolLabelManager:
|
|||
raise ValueError("Unsupported tool type")
|
||||
stmt = select(ToolLabelBinding.label_name).where(
|
||||
ToolLabelBinding.tool_id == provider_id,
|
||||
ToolLabelBinding.tool_type == controller.provider_type.value,
|
||||
ToolLabelBinding.tool_type == controller.provider_type,
|
||||
)
|
||||
labels = db.session.scalars(stmt).all()
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from decimal import Decimal
|
|||
from typing import cast
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
|
|
@ -78,7 +79,7 @@ class ModelInvocationUtils:
|
|||
|
||||
@staticmethod
|
||||
def invoke(
|
||||
user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
|
||||
user_id: str, tenant_id: str, tool_type: ToolProviderType, tool_name: str, prompt_messages: list[PromptMessage]
|
||||
) -> LLMResult:
|
||||
"""
|
||||
invoke model with parameters in user's own context
|
||||
|
|
|
|||
|
|
@ -0,0 +1,525 @@
|
|||
# Dify Enterprise Telemetry Data Dictionary
|
||||
|
||||
Quick reference for all telemetry signals emitted by Dify Enterprise. For configuration and architecture details, see [README.md](./README.md).
|
||||
|
||||
## Resource Attributes
|
||||
|
||||
Attached to every signal (Span, Metric, Log).
|
||||
|
||||
| Attribute | Type | Example |
|
||||
|-----------|------|---------|
|
||||
| `service.name` | string | `dify` |
|
||||
| `host.name` | string | `dify-api-7f8b` |
|
||||
|
||||
## Traces (Spans)
|
||||
|
||||
### `dify.workflow.run`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.trace_id` | string | Business trace ID (Workflow Run ID) |
|
||||
| `dify.tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.workflow.id` | string | Workflow definition ID |
|
||||
| `dify.workflow.run_id` | string | Unique ID for this run |
|
||||
| `dify.workflow.status` | string | `succeeded`, `failed`, `stopped`, etc. |
|
||||
| `dify.workflow.error` | string | Error message if failed |
|
||||
| `dify.workflow.elapsed_time` | float | Total execution time (seconds) |
|
||||
| `dify.invoke_from` | string | `service-api`, `web-app`, `debugger`, `explore` |
|
||||
| `dify.conversation.id` | string | Conversation ID (optional) |
|
||||
| `dify.message.id` | string | Message ID (optional) |
|
||||
| `dify.invoked_by` | string | User ID who triggered the run |
|
||||
| `gen_ai.usage.total_tokens` | int | Total tokens across all nodes (optional) |
|
||||
| `gen_ai.user.id` | string | End-user identifier (optional) |
|
||||
| `dify.parent.trace_id` | string | Parent workflow trace ID (optional) |
|
||||
| `dify.parent.workflow.run_id` | string | Parent workflow run ID (optional) |
|
||||
| `dify.parent.node.execution_id` | string | Parent node execution ID (optional) |
|
||||
| `dify.parent.app.id` | string | Parent app ID (optional) |
|
||||
|
||||
### `dify.node.execution`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.trace_id` | string | Business trace ID |
|
||||
| `dify.tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.workflow.id` | string | Workflow definition ID |
|
||||
| `dify.workflow.run_id` | string | Workflow Run ID |
|
||||
| `dify.message.id` | string | Message ID (optional) |
|
||||
| `dify.conversation.id` | string | Conversation ID (optional) |
|
||||
| `dify.node.execution_id` | string | Unique node execution ID |
|
||||
| `dify.node.id` | string | Node ID in workflow graph |
|
||||
| `dify.node.type` | string | Node type (see appendix) |
|
||||
| `dify.node.title` | string | Display title |
|
||||
| `dify.node.status` | string | `succeeded`, `failed` |
|
||||
| `dify.node.error` | string | Error message if failed |
|
||||
| `dify.node.elapsed_time` | float | Execution time (seconds) |
|
||||
| `dify.node.index` | int | Execution order index |
|
||||
| `dify.node.predecessor_node_id` | string | Triggering node ID |
|
||||
| `dify.node.iteration_id` | string | Iteration ID (optional) |
|
||||
| `dify.node.loop_id` | string | Loop ID (optional) |
|
||||
| `dify.node.parallel_id` | string | Parallel branch ID (optional) |
|
||||
| `dify.node.invoked_by` | string | User ID who triggered execution |
|
||||
| `gen_ai.usage.input_tokens` | int | Prompt tokens (LLM nodes only) |
|
||||
| `gen_ai.usage.output_tokens` | int | Completion tokens (LLM nodes only) |
|
||||
| `gen_ai.usage.total_tokens` | int | Total tokens (LLM nodes only) |
|
||||
| `gen_ai.request.model` | string | LLM model name (LLM nodes only) |
|
||||
| `gen_ai.provider.name` | string | LLM provider name (LLM nodes only) |
|
||||
| `gen_ai.user.id` | string | End-user identifier (optional) |
|
||||
|
||||
### `dify.node.execution.draft`
|
||||
|
||||
Same attributes as `dify.node.execution`. Emitted during Preview/Debug runs.
|
||||
|
||||
## Counters
|
||||
|
||||
All counters are cumulative and are never sampled: every event is counted, regardless of the trace sampling rate.
|
||||
|
||||
### Token Counters
|
||||
|
||||
| Metric | Unit | Description |
|
||||
|--------|------|-------------|
|
||||
| `dify.tokens.total` | `{token}` | Total tokens consumed |
|
||||
| `dify.tokens.input` | `{token}` | Input (prompt) tokens |
|
||||
| `dify.tokens.output` | `{token}` | Output (completion) tokens |
|
||||
|
||||
**Labels:**
|
||||
|
||||
- `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type` (if node_execution)
|
||||
|
||||
⚠️ **Warning:** `dify.tokens.total` at workflow level includes all node tokens. Filter by `operation_type` to avoid double-counting.
|
||||
|
||||
#### Token Hierarchy & Query Patterns
|
||||
|
||||
Token metrics are emitted at multiple layers. Understanding the hierarchy prevents double-counting:
|
||||
|
||||
```
|
||||
App-level total
|
||||
├── workflow ← sum of all node_execution tokens (DO NOT add both)
|
||||
│ └── node_execution ← per-node breakdown
|
||||
├── message ← independent (non-workflow chat apps only)
|
||||
├── rule_generate ← independent helper LLM call
|
||||
├── code_generate ← independent helper LLM call
|
||||
├── structured_output ← independent helper LLM call
|
||||
└── instruction_modify ← independent helper LLM call
|
||||
```
|
||||
|
||||
**Key rule:** `workflow` tokens already include all `node_execution` tokens. Never sum both.
|
||||
|
||||
**Available labels on token metrics:** `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type`.
|
||||
App name is only available on span attributes (`dify.app.name`), not metric labels — use `app_id` for metric queries.
|
||||
|
||||
**Common queries** (PromQL):
|
||||
|
||||
```promql
|
||||
# ── Totals ──────────────────────────────────────────────────
|
||||
# App-level total (exclude node_execution to avoid double-counting)
|
||||
sum by (app_id) (dify_tokens_total{operation_type!="node_execution"})
|
||||
|
||||
# Single app total
|
||||
sum (dify_tokens_total{app_id="<app_id>", operation_type!="node_execution"})
|
||||
|
||||
# Per-tenant totals
|
||||
sum by (tenant_id) (dify_tokens_total{operation_type!="node_execution"})
|
||||
|
||||
# ── Drill-down ──────────────────────────────────────────────
|
||||
# Workflow-level tokens for an app
|
||||
sum (dify_tokens_total{app_id="<app_id>", operation_type="workflow"})
|
||||
|
||||
# Node-level breakdown within an app
|
||||
sum by (node_type) (dify_tokens_total{app_id="<app_id>", operation_type="node_execution"})
|
||||
|
||||
# Model breakdown for an app
|
||||
sum by (model_provider, model_name) (dify_tokens_total{app_id="<app_id>"})
|
||||
|
||||
# Input vs output per model
|
||||
sum by (model_name) (dify_tokens_input_total{app_id="<app_id>"})
|
||||
sum by (model_name) (dify_tokens_output_total{app_id="<app_id>"})
|
||||
|
||||
# ── Rates ───────────────────────────────────────────────────
|
||||
# Token consumption rate (tokens per second, averaged over the last hour)
|
||||
sum(rate(dify_tokens_total{operation_type!="node_execution"}[1h]))
|
||||
|
||||
# Per-app consumption rate
|
||||
sum by (app_id) (rate(dify_tokens_total{operation_type!="node_execution"}[1h]))
|
||||
```
|
||||
|
||||
**Finding `app_id` from app name** (TraceQL query — Grafana Tempo):
|
||||
|
||||
```
|
||||
{ resource.dify.app.name = "My Chatbot" } | select(resource.dify.app.id)
|
||||
```
|
||||
|
||||
### Request Counters
|
||||
|
||||
| Metric | Unit | Description |
|
||||
|--------|------|-------------|
|
||||
| `dify.requests.total` | `{request}` | Total operations count |
|
||||
|
||||
**Labels by type:**
|
||||
|
||||
| `type` | Additional Labels |
|
||||
|--------|-------------------|
|
||||
| `workflow` | `tenant_id`, `app_id`, `status`, `invoke_from` |
|
||||
| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` |
|
||||
| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` |
|
||||
| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name`, `status`, `invoke_from` |
|
||||
| `tool` | `tenant_id`, `app_id`, `tool_name` |
|
||||
| `moderation` | `tenant_id`, `app_id` |
|
||||
| `suggested_question` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
|
||||
| `dataset_retrieval` | `tenant_id`, `app_id` |
|
||||
| `generate_name` | `tenant_id`, `app_id` |
|
||||
| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `status` |
|
||||
|
||||
### Error Counters
|
||||
|
||||
| Metric | Unit | Description |
|
||||
|--------|------|-------------|
|
||||
| `dify.errors.total` | `{error}` | Total failed operations |
|
||||
|
||||
**Labels by type:**
|
||||
|
||||
| `type` | Additional Labels |
|
||||
|--------|-------------------|
|
||||
| `workflow` | `tenant_id`, `app_id` |
|
||||
| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` |
|
||||
| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` |
|
||||
| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
|
||||
| `tool` | `tenant_id`, `app_id`, `tool_name` |
|
||||
| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` |
|
||||
|
||||
### Other Counters
|
||||
|
||||
| Metric | Unit | Labels |
|
||||
|--------|------|--------|
|
||||
| `dify.feedback.total` | `{feedback}` | `tenant_id`, `app_id`, `rating` |
|
||||
| `dify.dataset.retrievals.total` | `{retrieval}` | `tenant_id`, `app_id`, `dataset_id`, `embedding_model_provider`, `embedding_model`, `rerank_model_provider`, `rerank_model` |
|
||||
| `dify.app.created.total` | `{app}` | `tenant_id`, `app_id`, `mode` |
|
||||
| `dify.app.updated.total` | `{app}` | `tenant_id`, `app_id` |
|
||||
| `dify.app.deleted.total` | `{app}` | `tenant_id`, `app_id` |
|
||||
|
||||
## Histograms
|
||||
|
||||
| Metric | Unit | Labels |
|
||||
|--------|------|--------|
|
||||
| `dify.workflow.duration` | `s` | `tenant_id`, `app_id`, `status` |
|
||||
| `dify.node.duration` | `s` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `plugin_name` |
|
||||
| `dify.message.duration` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
|
||||
| `dify.message.time_to_first_token` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
|
||||
| `dify.tool.duration` | `s` | `tenant_id`, `app_id`, `tool_name` |
|
||||
| `dify.prompt_generation.duration` | `s` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` |
|
||||
|
||||
## Structured Logs
|
||||
|
||||
### Span Companion Logs
|
||||
|
||||
Logs that accompany spans. Signal type: `span_detail`
|
||||
|
||||
#### `dify.workflow.run` Companion Log
|
||||
|
||||
**Common attributes:** All span attributes (see Traces section) plus:
|
||||
|
||||
| Additional Attribute | Type | Always Present | Description |
|
||||
|---------------------|------|----------------|-------------|
|
||||
| `dify.app.name` | string | No | Application display name |
|
||||
| `dify.workspace.name` | string | No | Workspace display name |
|
||||
| `dify.workflow.version` | string | Yes | Workflow definition version |
|
||||
| `dify.workflow.inputs` | string/JSON | Yes | Input parameters (content-gated) |
|
||||
| `dify.workflow.outputs` | string/JSON | Yes | Output results (content-gated) |
|
||||
| `dify.workflow.query` | string | No | User query text (content-gated) |
|
||||
|
||||
**Event attributes:**
|
||||
|
||||
- `dify.event.name`: `"dify.workflow.run"`
|
||||
- `dify.event.signal`: `"span_detail"`
|
||||
- `trace_id`, `span_id`, `tenant_id`, `user_id`
|
||||
|
||||
#### `dify.node.execution` and `dify.node.execution.draft` Companion Logs
|
||||
|
||||
**Common attributes:** All span attributes (see Traces section) plus:
|
||||
|
||||
| Additional Attribute | Type | Always Present | Description |
|
||||
|---------------------|------|----------------|-------------|
|
||||
| `dify.app.name` | string | No | Application display name |
|
||||
| `dify.workspace.name` | string | No | Workspace display name |
|
||||
| `dify.invoke_from` | string | No | Invocation source |
|
||||
| `gen_ai.tool.name` | string | No | Tool name (tool nodes only) |
|
||||
| `dify.node.total_price` | float | No | Cost (LLM nodes only) |
|
||||
| `dify.node.currency` | string | No | Currency code (LLM nodes only) |
|
||||
| `dify.node.iteration_index` | int | No | Iteration index (iteration nodes) |
|
||||
| `dify.node.loop_index` | int | No | Loop index (loop nodes) |
|
||||
| `dify.plugin.name` | string | No | Plugin name (tool/knowledge nodes) |
|
||||
| `dify.credential.name` | string | No | Credential name (plugin nodes) |
|
||||
| `dify.credential.id` | string | No | Credential ID (plugin nodes) |
|
||||
| `dify.dataset.ids` | JSON array | No | Dataset IDs (knowledge nodes) |
|
||||
| `dify.dataset.names` | JSON array | No | Dataset names (knowledge nodes) |
|
||||
| `dify.node.inputs` | string/JSON | Yes | Node inputs (content-gated) |
|
||||
| `dify.node.outputs` | string/JSON | Yes | Node outputs (content-gated) |
|
||||
| `dify.node.process_data` | string/JSON | No | Processing data (content-gated) |
|
||||
|
||||
**Event attributes:**
|
||||
|
||||
- `dify.event.name`: `"dify.node.execution"` or `"dify.node.execution.draft"`
|
||||
- `dify.event.signal`: `"span_detail"`
|
||||
- `trace_id`, `span_id`, `tenant_id`, `user_id`
|
||||
|
||||
### Standalone Logs
|
||||
|
||||
Logs emitted without an accompanying span. Signal type: `metric_only`
|
||||
|
||||
#### `dify.message.run`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.message.run"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `trace_id` | string | OTEL trace ID (32-char hex) |
|
||||
| `span_id` | string | OTEL span ID (16-char hex) |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `user_id` | string | User identifier (optional) |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.message.id` | string | Message identifier |
|
||||
| `dify.conversation.id` | string | Conversation ID (optional) |
|
||||
| `dify.workflow.run_id` | string | Workflow run ID (optional) |
|
||||
| `dify.invoke_from` | string | `service-api`, `web-app`, `debugger`, `explore` |
|
||||
| `gen_ai.provider.name` | string | LLM provider |
|
||||
| `gen_ai.request.model` | string | LLM model |
|
||||
| `gen_ai.usage.input_tokens` | int | Input tokens |
|
||||
| `gen_ai.usage.output_tokens` | int | Output tokens |
|
||||
| `gen_ai.usage.total_tokens` | int | Total tokens |
|
||||
| `dify.message.status` | string | `succeeded`, `failed` |
|
||||
| `dify.message.error` | string | Error message (if failed) |
|
||||
| `dify.message.duration` | float | Duration (seconds) |
|
||||
| `dify.message.time_to_first_token` | float | TTFT (seconds) |
|
||||
| `dify.message.inputs` | string/JSON | Inputs (content-gated) |
|
||||
| `dify.message.outputs` | string/JSON | Outputs (content-gated) |
|
||||
|
||||
#### `dify.tool.execution`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.tool.execution"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `trace_id` | string | OTEL trace ID |
|
||||
| `span_id` | string | OTEL span ID |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.message.id` | string | Message identifier |
|
||||
| `dify.tool.name` | string | Tool name |
|
||||
| `dify.tool.duration` | float | Duration (seconds) |
|
||||
| `dify.tool.status` | string | `succeeded`, `failed` |
|
||||
| `dify.tool.error` | string | Error message (if failed) |
|
||||
| `dify.tool.inputs` | string/JSON | Inputs (content-gated) |
|
||||
| `dify.tool.outputs` | string/JSON | Outputs (content-gated) |
|
||||
| `dify.tool.parameters` | string/JSON | Parameters (content-gated) |
|
||||
| `dify.tool.config` | string/JSON | Configuration (content-gated) |
|
||||
|
||||
#### `dify.moderation.check`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.moderation.check"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `trace_id` | string | OTEL trace ID |
|
||||
| `span_id` | string | OTEL span ID |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.message.id` | string | Message identifier |
|
||||
| `dify.moderation.type` | string | `input`, `output` |
|
||||
| `dify.moderation.action` | string | `pass`, `block`, `flag` |
|
||||
| `dify.moderation.flagged` | boolean | Whether flagged |
|
||||
| `dify.moderation.categories` | JSON array | Flagged categories |
|
||||
| `dify.moderation.query` | string | Content (content-gated) |
|
||||
|
||||
#### `dify.suggested_question.generation`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.suggested_question.generation"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `trace_id` | string | OTEL trace ID |
|
||||
| `span_id` | string | OTEL span ID |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.message.id` | string | Message identifier |
|
||||
| `dify.suggested_question.count` | int | Number of questions |
|
||||
| `dify.suggested_question.duration` | float | Duration (seconds) |
|
||||
| `dify.suggested_question.status` | string | `succeeded`, `failed` |
|
||||
| `dify.suggested_question.error` | string | Error message (if failed) |
|
||||
| `dify.suggested_question.questions` | JSON array | Questions (content-gated) |
|
||||
|
||||
#### `dify.dataset.retrieval`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.dataset.retrieval"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `trace_id` | string | OTEL trace ID |
|
||||
| `span_id` | string | OTEL span ID |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.message.id` | string | Message identifier |
|
||||
| `dify.dataset.id` | string | Dataset identifier |
|
||||
| `dify.dataset.name` | string | Dataset name |
|
||||
| `dify.dataset.embedding_providers` | JSON array | Embedding model providers (one per dataset) |
|
||||
| `dify.dataset.embedding_models` | JSON array | Embedding models (one per dataset) |
|
||||
| `dify.retrieval.rerank_provider` | string | Rerank model provider |
|
||||
| `dify.retrieval.rerank_model` | string | Rerank model name |
|
||||
| `dify.retrieval.query` | string | Search query (content-gated) |
|
||||
| `dify.retrieval.document_count` | int | Documents retrieved |
|
||||
| `dify.retrieval.duration` | float | Duration (seconds) |
|
||||
| `dify.retrieval.status` | string | `succeeded`, `failed` |
|
||||
| `dify.retrieval.error` | string | Error message (if failed) |
|
||||
| `dify.dataset.documents` | JSON array | Documents (content-gated) |
|
||||
|
||||
#### `dify.generate_name.execution`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.generate_name.execution"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `trace_id` | string | OTEL trace ID |
|
||||
| `span_id` | string | OTEL span ID |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.conversation.id` | string | Conversation identifier |
|
||||
| `dify.generate_name.duration` | float | Duration (seconds) |
|
||||
| `dify.generate_name.status` | string | `succeeded`, `failed` |
|
||||
| `dify.generate_name.error` | string | Error message (if failed) |
|
||||
| `dify.generate_name.inputs` | string/JSON | Inputs (content-gated) |
|
||||
| `dify.generate_name.outputs` | string | Generated name (content-gated) |
|
||||
|
||||
#### `dify.prompt_generation.execution`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.prompt_generation.execution"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `trace_id` | string | OTEL trace ID |
|
||||
| `span_id` | string | OTEL span ID |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.prompt_generation.operation_type` | string | Operation type (see appendix) |
|
||||
| `gen_ai.provider.name` | string | LLM provider |
|
||||
| `gen_ai.request.model` | string | LLM model |
|
||||
| `gen_ai.usage.input_tokens` | int | Input tokens |
|
||||
| `gen_ai.usage.output_tokens` | int | Output tokens |
|
||||
| `gen_ai.usage.total_tokens` | int | Total tokens |
|
||||
| `dify.prompt_generation.duration` | float | Duration (seconds) |
|
||||
| `dify.prompt_generation.status` | string | `succeeded`, `failed` |
|
||||
| `dify.prompt_generation.error` | string | Error message (if failed) |
|
||||
| `dify.prompt_generation.instruction` | string | Instruction (content-gated) |
|
||||
| `dify.prompt_generation.output` | string/JSON | Output (content-gated) |
|
||||
|
||||
#### `dify.app.created`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.app.created"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.app.mode` | string | `chat`, `completion`, `agent-chat`, `workflow` |
|
||||
| `dify.app.created_at` | string | Timestamp (ISO 8601) |
|
||||
|
||||
#### `dify.app.updated`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.app.updated"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.app.updated_at` | string | Timestamp (ISO 8601) |
|
||||
|
||||
#### `dify.app.deleted`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.app.deleted"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.app.deleted_at` | string | Timestamp (ISO 8601) |
|
||||
|
||||
#### `dify.feedback.created`
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.feedback.created"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `trace_id` | string | OTEL trace ID |
|
||||
| `span_id` | string | OTEL span ID |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.app_id` | string | Application identifier |
|
||||
| `dify.message.id` | string | Message identifier |
|
||||
| `dify.feedback.rating` | string | `like`, `dislike`, `null` |
|
||||
| `dify.feedback.content` | string | Feedback text (content-gated) |
|
||||
| `dify.feedback.created_at` | string | Timestamp (ISO 8601) |
|
||||
|
||||
#### `dify.telemetry.rehydration_failed`
|
||||
|
||||
Diagnostic event for telemetry system health monitoring.
|
||||
|
||||
| Attribute | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `dify.event.name` | string | `"dify.telemetry.rehydration_failed"` |
|
||||
| `dify.event.signal` | string | `"metric_only"` |
|
||||
| `tenant_id` | string | Tenant identifier |
|
||||
| `dify.telemetry.error` | string | Error message |
|
||||
| `dify.telemetry.payload_type` | string | Payload type (see appendix) |
|
||||
| `dify.telemetry.correlation_id` | string | Correlation ID |
|
||||
|
||||
## Content-Gated Attributes
|
||||
|
||||
When `ENTERPRISE_INCLUDE_CONTENT=false`, these attributes are replaced with reference strings (`ref:{id_type}={uuid}`).
|
||||
|
||||
| Attribute | Signal |
|
||||
|-----------|--------|
|
||||
| `dify.workflow.inputs` | `dify.workflow.run` |
|
||||
| `dify.workflow.outputs` | `dify.workflow.run` |
|
||||
| `dify.workflow.query` | `dify.workflow.run` |
|
||||
| `dify.node.inputs` | `dify.node.execution` |
|
||||
| `dify.node.outputs` | `dify.node.execution` |
|
||||
| `dify.node.process_data` | `dify.node.execution` |
|
||||
| `dify.message.inputs` | `dify.message.run` |
|
||||
| `dify.message.outputs` | `dify.message.run` |
|
||||
| `dify.tool.inputs` | `dify.tool.execution` |
|
||||
| `dify.tool.outputs` | `dify.tool.execution` |
|
||||
| `dify.tool.parameters` | `dify.tool.execution` |
|
||||
| `dify.tool.config` | `dify.tool.execution` |
|
||||
| `dify.moderation.query` | `dify.moderation.check` |
|
||||
| `dify.suggested_question.questions` | `dify.suggested_question.generation` |
|
||||
| `dify.retrieval.query` | `dify.dataset.retrieval` |
|
||||
| `dify.dataset.documents` | `dify.dataset.retrieval` |
|
||||
| `dify.generate_name.inputs` | `dify.generate_name.execution` |
|
||||
| `dify.generate_name.outputs` | `dify.generate_name.execution` |
|
||||
| `dify.prompt_generation.instruction` | `dify.prompt_generation.execution` |
|
||||
| `dify.prompt_generation.output` | `dify.prompt_generation.execution` |
|
||||
| `dify.feedback.content` | `dify.feedback.created` |
|
||||
|
||||
## Appendix
|
||||
|
||||
### Operation Types
|
||||
|
||||
- `workflow`, `node_execution`, `message`, `rule_generate`, `code_generate`, `structured_output`, `instruction_modify`
|
||||
|
||||
### Node Types
|
||||
|
||||
- `start`, `end`, `answer`, `llm`, `knowledge-retrieval`, `knowledge-index`, `if-else`, `code`, `template-transform`, `question-classifier`, `http-request`, `tool`, `datasource`, `variable-aggregator`, `loop`, `iteration`, `parameter-extractor`, `assigner`, `document-extractor`, `list-operator`, `agent`, `trigger-webhook`, `trigger-schedule`, `trigger-plugin`, `human-input`
|
||||
|
||||
### Workflow Statuses
|
||||
|
||||
- `running`, `succeeded`, `failed`, `stopped`, `partial-succeeded`, `paused`
|
||||
|
||||
### Payload Types
|
||||
|
||||
- `workflow`, `node`, `message`, `tool`, `moderation`, `suggested_question`, `dataset_retrieval`, `generate_name`, `prompt_generation`, `app`, `feedback`
|
||||
|
||||
### Null Value Behavior
|
||||
|
||||
**Spans:** Attributes with `null` values are omitted.
|
||||
|
||||
**Logs:** Attributes with `null` values appear as `null` in JSON.
|
||||
|
||||
**Content-Gated:** Replaced with reference strings, not set to `null`.
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
# Dify Enterprise Telemetry
|
||||
|
||||
This document provides an overview of the Dify Enterprise OpenTelemetry (OTEL) exporter and how to configure it for integration with observability stacks like Prometheus, Grafana, Jaeger, or Honeycomb.
|
||||
|
||||
## Overview
|
||||
|
||||
Dify Enterprise uses a "slim span + rich companion log" architecture to provide high-fidelity observability without overwhelming trace storage.
|
||||
|
||||
- **Traces (Spans)**: Capture the structure, identity, and timing of high-level operations (Workflows and Nodes).
|
||||
- **Structured Logs**: Provide deep context (inputs, outputs, metadata) for every event, correlated to spans via `trace_id` and `span_id`.
|
||||
- **Metrics**: Provide 100% accurate counters and histograms for usage, performance, and error tracking.
|
||||
|
||||
### Signal Architecture
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[Workflow Run] -->|Span| B(dify.workflow.run)
|
||||
A -->|Log| C(dify.workflow.run detail)
|
||||
B ---|trace_id| C
|
||||
|
||||
D[Node Execution] -->|Span| E(dify.node.execution)
|
||||
D -->|Log| F(dify.node.execution detail)
|
||||
E ---|span_id| F
|
||||
|
||||
G[Message/Tool/etc] -->|Log| H(dify.* event)
|
||||
G -->|Metric| I(dify.* counter/histogram)
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
The Enterprise OTEL exporter is configured via environment variables.
|
||||
|
||||
| Variable | Description | Default |
|
||||
|----------|-------------|---------|
|
||||
| `ENTERPRISE_ENABLED` | Master switch for all enterprise features. | `false` |
|
||||
| `ENTERPRISE_TELEMETRY_ENABLED` | Master switch for enterprise telemetry. | `false` |
|
||||
| `ENTERPRISE_OTLP_ENDPOINT` | OTLP collector endpoint (e.g., `http://otel-collector:4318`). | - |
|
||||
| `ENTERPRISE_OTLP_HEADERS` | Custom headers for OTLP requests (e.g., `x-scope-orgid=tenant1`). | - |
|
||||
| `ENTERPRISE_OTLP_PROTOCOL` | OTLP transport protocol (`http` or `grpc`). | `http` |
|
||||
| `ENTERPRISE_OTLP_API_KEY` | Bearer token for authentication. | - |
|
||||
| `ENTERPRISE_INCLUDE_CONTENT` | Whether to include sensitive content (inputs/outputs) in logs. | `true` |
|
||||
| `ENTERPRISE_SERVICE_NAME` | Service name reported to OTEL. | `dify` |
|
||||
| `ENTERPRISE_OTEL_SAMPLING_RATE` | Sampling rate for traces (0.0 to 1.0). Metrics are always 100%. | `1.0` |
|
||||
|
||||
## Correlation Model
|
||||
|
||||
Dify uses deterministic ID generation to ensure signals are correlated across different services and asynchronous tasks.
|
||||
|
||||
### ID Generation Rules
|
||||
|
||||
- `trace_id`: Derived from the correlation ID (workflow_run_id or node_execution_id for drafts) using `int(UUID(correlation_id))`
|
||||
- `span_id`: Derived from the source ID using the lower 64 bits of `UUID(source_id)` (abbreviated as `hash(...)` in the scenario diagrams below)
|
||||
|
||||
### Scenario A: Simple Workflow
|
||||
|
||||
A single workflow run with multiple nodes. All spans and logs share the same `trace_id` (derived from `workflow_run_id`).
|
||||
|
||||
```
|
||||
trace_id = UUID(workflow_run_id)
|
||||
├── [root span] dify.workflow.run (span_id = hash(workflow_run_id))
|
||||
│ ├── [child] dify.node.execution - "Start" (span_id = hash(node_exec_id_1))
|
||||
│ ├── [child] dify.node.execution - "LLM" (span_id = hash(node_exec_id_2))
|
||||
│ └── [child] dify.node.execution - "End" (span_id = hash(node_exec_id_3))
|
||||
```
|
||||
|
||||
### Scenario B: Nested Sub-Workflow
|
||||
|
||||
A workflow calling another workflow via a Tool or Sub-workflow node. The child workflow's spans are linked to the parent via `parent_span_id`, and both workflows share the same `trace_id`.
|
||||
|
||||
```
|
||||
trace_id = UUID(outer_workflow_run_id) ← shared across both workflows
|
||||
├── [root] dify.workflow.run (outer) (span_id = hash(outer_workflow_run_id))
|
||||
│ ├── dify.node.execution - "Start Node"
|
||||
│ ├── dify.node.execution - "Tool Node" (triggers sub-workflow)
|
||||
│ │ └── [child] dify.workflow.run (inner) (span_id = hash(inner_workflow_run_id))
|
||||
│ │ ├── dify.node.execution - "Inner Start"
|
||||
│ │ └── dify.node.execution - "Inner End"
|
||||
│ └── dify.node.execution - "End Node"
|
||||
```
|
||||
|
||||
**Key attributes for nested workflows:**
|
||||
|
||||
- Inner workflow's `dify.parent.trace_id` = outer `workflow_run_id`
|
||||
- Inner workflow's `dify.parent.node.execution_id` = tool node's `execution_id`
|
||||
- Inner workflow's `dify.parent.workflow.run_id` = outer `workflow_run_id`
|
||||
- Inner workflow's `dify.parent.app.id` = outer `app_id`
|
||||
|
||||
### Scenario C: Draft Node Execution
|
||||
|
||||
A single node run in isolation (debugger/preview mode). It creates its own trace where the node span is the root.
|
||||
|
||||
```
|
||||
trace_id = UUID(node_execution_id) ← own trace, NOT part of any workflow
|
||||
└── dify.node.execution.draft (span_id = hash(node_execution_id))
|
||||
```
|
||||
|
||||
**Key difference:** Draft executions use `node_execution_id` as the correlation_id, so they are NOT children of any workflow trace.
|
||||
|
||||
## Content Gating
|
||||
|
||||
When `ENTERPRISE_INCLUDE_CONTENT` is set to `false`, sensitive content attributes (inputs, outputs, queries) are replaced with reference strings (e.g., `ref:workflow_run_id=...`) to prevent data leakage to the OTEL collector.
|
||||
|
||||
**Reference String Format:**
|
||||
|
||||
```
|
||||
ref:{id_type}={uuid}
|
||||
```
|
||||
|
||||
**Examples:**
|
||||
|
||||
```
|
||||
ref:workflow_run_id=550e8400-e29b-41d4-a716-446655440000
|
||||
ref:node_execution_id=660e8400-e29b-41d4-a716-446655440001
|
||||
ref:message_id=770e8400-e29b-41d4-a716-446655440002
|
||||
```
|
||||
|
||||
To retrieve the actual content when content gating is active (`ENTERPRISE_INCLUDE_CONTENT=false`), query the Dify database using the UUID from the reference string.
|
||||
|
||||
## Reference
|
||||
|
||||
For a complete list of telemetry signals, attributes, and data structures, see [DATA_DICTIONARY.md](./DATA_DICTIONARY.md).
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
"""Telemetry gateway contracts and data structures.
|
||||
|
||||
This module defines the envelope format for telemetry events and the routing
|
||||
configuration that determines how each event type is processed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class TelemetryCase(StrEnum):
|
||||
"""Enumeration of all known telemetry event cases."""
|
||||
|
||||
WORKFLOW_RUN = "workflow_run"
|
||||
NODE_EXECUTION = "node_execution"
|
||||
DRAFT_NODE_EXECUTION = "draft_node_execution"
|
||||
MESSAGE_RUN = "message_run"
|
||||
TOOL_EXECUTION = "tool_execution"
|
||||
MODERATION_CHECK = "moderation_check"
|
||||
SUGGESTED_QUESTION = "suggested_question"
|
||||
DATASET_RETRIEVAL = "dataset_retrieval"
|
||||
GENERATE_NAME = "generate_name"
|
||||
PROMPT_GENERATION = "prompt_generation"
|
||||
APP_CREATED = "app_created"
|
||||
APP_UPDATED = "app_updated"
|
||||
APP_DELETED = "app_deleted"
|
||||
FEEDBACK_CREATED = "feedback_created"
|
||||
|
||||
|
||||
class SignalType(StrEnum):
|
||||
"""Signal routing type for telemetry cases."""
|
||||
|
||||
TRACE = "trace"
|
||||
METRIC_LOG = "metric_log"
|
||||
|
||||
|
||||
class CaseRoute(BaseModel):
    """How a single telemetry case is routed.

    Attributes:
        signal_type: Signal the case is emitted as (trace or metric_log).
        ce_eligible: True when the case may also be traced in the
            community edition.
    """

    signal_type: SignalType
    ce_eligible: bool
|
||||
|
||||
|
||||
class TelemetryEnvelope(BaseModel):
    """Wire format wrapping one telemetry event.

    Unknown fields are rejected (``extra="forbid"``) and enum fields keep
    their enum members rather than being collapsed to raw values.

    Attributes:
        case: Which telemetry case this event belongs to.
        tenant_id: Identifier of the tenant the event is attributed to.
        event_id: Unique identifier of the event, used for deduplication.
        payload: Event body. Carried inline for small payloads; empty when
            the body was offloaded to storage via ``payload_ref``.
        metadata: Optional extra data. When the gateway offloads a large
            payload to object storage this holds
            ``{"payload_ref": "<storage_key>"}``.
    """

    model_config = ConfigDict(extra="forbid", use_enum_values=False)

    case: TelemetryCase
    tenant_id: str
    event_id: str
    payload: dict[str, Any]
    metadata: dict[str, Any] | None = None
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
|
||||
def enqueue_draft_node_execution_trace(
    *,
    execution: WorkflowNodeExecutionModel,
    outputs: Mapping[str, Any] | None,
    workflow_execution_id: str | None,
    user_id: str,
) -> None:
    """Emit a draft-node-execution telemetry event for ``execution``.

    Args:
        execution: Persisted node execution record to describe.
        outputs: Outputs to report; ``None`` falls back to the outputs
            stored on ``execution``.
        workflow_execution_id: Preferred workflow execution id; when falsy
            the id is derived from ``execution``.
        user_id: User recorded in the telemetry context.
    """
    # Flatten the record first, then wrap it in the event envelope.
    payload = {
        "node_execution_data": _build_node_execution_data(
            execution=execution,
            outputs=outputs,
            workflow_execution_id=workflow_execution_id,
        )
    }
    context = TelemetryContext(
        tenant_id=execution.tenant_id,
        user_id=user_id,
        app_id=execution.app_id,
    )
    event = TelemetryEvent(
        name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
        context=context,
        payload=payload,
    )
    telemetry_emit(event)
|
||||
|
||||
|
||||
def _build_node_execution_data(
    *,
    execution: WorkflowNodeExecutionModel,
    outputs: Mapping[str, Any] | None,
    workflow_execution_id: str | None,
) -> dict[str, Any]:
    """Flatten a node execution record into a plain telemetry dict.

    Args:
        execution: Node execution record supplying identity, status, timing,
            inputs and metadata.
        outputs: Explicit outputs; ``None`` means use ``execution.outputs_dict``.
        workflow_execution_id: Preferred workflow execution id. When falsy,
            ``execution.workflow_run_id`` is used, then ``execution.id``.

    Returns:
        Dict of node execution fields keyed by telemetry field name.
    """
    meta = execution.execution_metadata_dict
    keys = WorkflowNodeExecutionMetadataKey

    # The tool name is only available when TOOL_INFO is actually a dict.
    tool_info = meta.get(keys.TOOL_INFO)
    tool_name = tool_info.get("tool_name") if isinstance(tool_info, dict) else None

    if outputs is not None:
        node_outputs = outputs
    else:
        node_outputs = execution.outputs_dict

    resolved_execution_id = workflow_execution_id or execution.workflow_run_id or execution.id

    return {
        "workflow_id": execution.workflow_id,
        "workflow_execution_id": resolved_execution_id,
        "tenant_id": execution.tenant_id,
        "app_id": execution.app_id,
        "node_execution_id": execution.id,
        "node_id": execution.node_id,
        "node_type": execution.node_type,
        "title": execution.title,
        "status": execution.status,
        "error": execution.error,
        "elapsed_time": execution.elapsed_time,
        "index": execution.index,
        "predecessor_node_id": execution.predecessor_node_id,
        "created_at": execution.created_at,
        "finished_at": execution.finished_at,
        "total_tokens": meta.get(keys.TOTAL_TOKENS, 0),
        "total_price": meta.get(keys.TOTAL_PRICE, 0.0),
        "currency": meta.get(keys.CURRENCY),
        "tool_name": tool_name,
        "iteration_id": meta.get(keys.ITERATION_ID),
        "iteration_index": meta.get(keys.ITERATION_INDEX),
        "loop_id": meta.get(keys.LOOP_ID),
        "loop_index": meta.get(keys.LOOP_INDEX),
        "parallel_id": meta.get(keys.PARALLEL_ID),
        "node_inputs": execution.inputs_dict,
        "node_outputs": node_outputs,
        "process_data": execution.process_data_dict,
    }
|
||||
|
|
@ -0,0 +1,955 @@
|
|||
"""Enterprise trace handler — duck-typed, NOT a BaseTraceInstance subclass.
|
||||
|
||||
Invoked directly in the Celery task, not through OpsTraceManager dispatch.
|
||||
Only requires a matching ``trace(trace_info)`` method signature.
|
||||
|
||||
Signal strategy:
|
||||
- **Traces (spans)**: workflow run, node execution, draft node execution only.
|
||||
- **Metrics + structured logs**: all other event types.
|
||||
|
||||
Token metric labels (unified structure):
|
||||
All token metrics (dify.tokens.input, dify.tokens.output, dify.tokens.total) use the
|
||||
same label set for consistent filtering and aggregation:
|
||||
- tenant_id: Tenant identifier
|
||||
- app_id: Application identifier
|
||||
- operation_type: Source of token usage (workflow | node_execution | message | rule_generate | etc.)
|
||||
- model_provider: LLM provider name (empty string if not applicable)
|
||||
- model_name: LLM model name (empty string if not applicable)
|
||||
- node_type: Workflow node type (empty string if not node_execution)
|
||||
|
||||
This unified structure allows filtering by operation_type to separate:
|
||||
- Workflow-level aggregates (operation_type=workflow)
|
||||
- Individual node executions (operation_type=node_execution)
|
||||
- Direct message calls (operation_type=message)
|
||||
- Prompt generation operations (operation_type=rule_generate, code_generate, etc.)
|
||||
|
||||
Without this, tokens are double-counted when querying totals (workflow totals include
|
||||
node totals, since workflow.total_tokens is the sum of all node tokens).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
DraftNodeExecutionTrace,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
OperationType,
|
||||
PromptGenerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowNodeTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from enterprise.telemetry.entities import (
|
||||
EnterpriseTelemetryCounter,
|
||||
EnterpriseTelemetryEvent,
|
||||
EnterpriseTelemetryHistogram,
|
||||
EnterpriseTelemetrySpan,
|
||||
TokenMetricLabels,
|
||||
)
|
||||
from enterprise.telemetry.telemetry_log import emit_metric_only_event, emit_telemetry_log
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnterpriseOtelTrace:
|
||||
"""Duck-typed enterprise trace handler.
|
||||
|
||||
``*_trace`` methods emit spans (workflow/node only) or structured logs
|
||||
(all other events), plus metrics at 100 % accuracy.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
    """Bind the handler to the enterprise exporter.

    Raises:
        RuntimeError: If ``get_enterprise_exporter()`` returns ``None``.
    """
    # Imported here rather than at module level, as in the original code.
    from extensions.ext_enterprise_telemetry import get_enterprise_exporter

    self._exporter = get_enterprise_exporter()
    if self._exporter is None:
        raise RuntimeError("EnterpriseOtelTrace instantiated but exporter is not initialized")
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo) -> None:
    """Route ``trace_info`` to the handler for its concrete type.

    The first matching entry wins, so the order below is significant and
    mirrors the original ``isinstance`` chain (``DraftNodeExecutionTrace``
    is tested before ``WorkflowNodeTraceInfo``). Unrecognised types are
    ignored.
    """
    # Handler names are resolved lazily, only for the matching entry.
    routes: tuple[tuple[type, str], ...] = (
        (WorkflowTraceInfo, "_workflow_trace"),
        (MessageTraceInfo, "_message_trace"),
        (ToolTraceInfo, "_tool_trace"),
        (DraftNodeExecutionTrace, "_draft_node_execution_trace"),
        (WorkflowNodeTraceInfo, "_node_execution_trace"),
        (ModerationTraceInfo, "_moderation_trace"),
        (SuggestedQuestionTraceInfo, "_suggested_question_trace"),
        (DatasetRetrievalTraceInfo, "_dataset_retrieval_trace"),
        (GenerateNameTraceInfo, "_generate_name_trace"),
        (PromptGenerationTraceInfo, "_prompt_generation_trace"),
    )
    for info_type, handler_name in routes:
        if isinstance(trace_info, info_type):
            getattr(self, handler_name)(trace_info)
            return
|
||||
|
||||
def _common_attrs(self, trace_info: BaseTraceInfo) -> dict[str, Any]:
|
||||
metadata = self._metadata(trace_info)
|
||||
tenant_id, app_id, user_id = self._context_ids(trace_info, metadata)
|
||||
return {
|
||||
"dify.trace_id": trace_info.resolved_trace_id,
|
||||
"dify.tenant_id": tenant_id,
|
||||
"dify.app_id": app_id,
|
||||
"dify.app.name": metadata.get("app_name"),
|
||||
"dify.workspace.name": metadata.get("workspace_name"),
|
||||
"gen_ai.user.id": user_id,
|
||||
"dify.message.id": trace_info.message_id,
|
||||
}
|
||||
|
||||
def _metadata(self, trace_info: BaseTraceInfo) -> dict[str, Any]:
|
||||
return trace_info.metadata
|
||||
|
||||
def _context_ids(
|
||||
self,
|
||||
trace_info: BaseTraceInfo,
|
||||
metadata: dict[str, Any],
|
||||
) -> tuple[str | None, str | None, str | None]:
|
||||
tenant_id = getattr(trace_info, "tenant_id", None) or metadata.get("tenant_id")
|
||||
app_id = getattr(trace_info, "app_id", None) or metadata.get("app_id")
|
||||
user_id = getattr(trace_info, "user_id", None) or metadata.get("user_id")
|
||||
return tenant_id, app_id, user_id
|
||||
|
||||
def _labels(self, **values: AttributeValue) -> dict[str, AttributeValue]:
|
||||
return dict(values)
|
||||
|
||||
def _safe_payload_value(self, value: Any) -> str | dict[str, Any] | list[object] | None:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, dict):
|
||||
return cast(dict[str, Any], value)
|
||||
if isinstance(value, list):
|
||||
items: list[object] = []
|
||||
for item in cast(list[object], value):
|
||||
items.append(item)
|
||||
return items
|
||||
return None
|
||||
|
||||
def _content_or_ref(self, value: Any, ref: str) -> Any:
|
||||
if self._exporter.include_content:
|
||||
return self._maybe_json(value)
|
||||
return ref
|
||||
|
||||
def _maybe_json(self, value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
return json.dumps(value, default=str)
|
||||
except (TypeError, ValueError):
|
||||
return str(value)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# SPAN-emitting handlers (workflow, node execution, draft node)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _workflow_trace(self, info: WorkflowTraceInfo) -> None:
    """Export a workflow-run span, its companion detail log, and metrics.

    Steps, in order:
      1. Export a ``WORKFLOW_RUN`` span keyed on ``workflow_run_id``, with
         parent-context attributes when ``parent_trace_context`` metadata
         is a dict.
      2. Emit a ``span_detail`` log carrying the span attributes plus app
         and workspace names, workflow version, and the inputs, outputs
         and query (or a ``ref:workflow_run_id=...`` string when content
         is gated).
      3. Record token, request, duration and error metrics.

    Args:
        info: Workflow trace info to report.
    """
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)

    # -- Span attrs: identity + structure + status + timing + gen_ai scalars --
    span_attrs: dict[str, Any] = {
        "dify.trace_id": info.resolved_trace_id,
        "dify.tenant_id": tenant_id,
        "dify.app_id": app_id,
        "dify.workflow.id": info.workflow_id,
        "dify.workflow.run_id": info.workflow_run_id,
        "dify.workflow.status": info.workflow_run_status,
        "dify.workflow.error": info.error,
        "dify.workflow.elapsed_time": info.workflow_run_elapsed_time,
        "dify.invoke_from": metadata.get("triggered_from"),
        "dify.conversation.id": info.conversation_id,
        "dify.message.id": info.message_id,
        "dify.invoked_by": info.invoked_by,
        "gen_ai.usage.total_tokens": info.total_tokens,
        "gen_ai.user.id": user_id,
    }

    trace_correlation_override, parent_span_id_source = info.resolved_parent_context

    parent_ctx = metadata.get("parent_trace_context")
    if isinstance(parent_ctx, dict):
        parent_ctx_dict = cast(dict[str, Any], parent_ctx)
        span_attrs["dify.parent.trace_id"] = parent_ctx_dict.get("trace_id")
        span_attrs["dify.parent.node.execution_id"] = parent_ctx_dict.get("parent_node_execution_id")
        span_attrs["dify.parent.workflow.run_id"] = parent_ctx_dict.get("parent_workflow_run_id")
        span_attrs["dify.parent.app.id"] = parent_ctx_dict.get("parent_app_id")

    self._exporter.export_span(
        EnterpriseTelemetrySpan.WORKFLOW_RUN,
        span_attrs,
        correlation_id=info.workflow_run_id,
        span_id_source=info.workflow_run_id,
        start_time=info.start_time,
        end_time=info.end_time,
        trace_correlation_override=trace_correlation_override,
        parent_span_id_source=parent_span_id_source,
    )

    # -- Companion log: ALL attrs (span + detail) for full picture --
    # "gen_ai.user.id" and "gen_ai.usage.total_tokens" are already present
    # via span_attrs with the same values, so they are not set again.
    log_attrs: dict[str, Any] = {
        **span_attrs,
        "dify.app.name": metadata.get("app_name"),
        "dify.workspace.name": metadata.get("workspace_name"),
        "dify.workflow.version": info.workflow_run_version,
    }

    ref = f"ref:workflow_run_id={info.workflow_run_id}"
    log_attrs["dify.workflow.inputs"] = self._content_or_ref(info.workflow_run_inputs, ref)
    log_attrs["dify.workflow.outputs"] = self._content_or_ref(info.workflow_run_outputs, ref)
    log_attrs["dify.workflow.query"] = self._content_or_ref(info.query, ref)

    emit_telemetry_log(
        event_name=EnterpriseTelemetryEvent.WORKFLOW_RUN,
        attributes=log_attrs,
        signal="span_detail",
        trace_id_source=info.workflow_run_id,
        span_id_source=info.workflow_run_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )

    # -- Metrics --
    labels = self._labels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
    )
    token_labels = TokenMetricLabels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
        operation_type=OperationType.WORKFLOW,
        model_provider="",
        model_name="",
        node_type="",
    ).to_dict()
    self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
    # Input/output counters are only bumped for known, positive amounts.
    for counter, amount in (
        (EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens),
        (EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens),
    ):
        if amount is not None and amount > 0:
            self._exporter.increment_counter(counter, amount, token_labels)

    # ``or ""`` (not a ``.get`` default) so a key that is present with a
    # None value still yields a string label, matching the other labels.
    invoke_from = metadata.get("triggered_from") or ""
    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(
            **labels,
            type="workflow",
            status=info.workflow_run_status,
            invoke_from=invoke_from,
        ),
    )

    # Prefer wall-clock timestamps over the elapsed_time field: elapsed_time defaults
    # to 0 in the DB and can be stale if the Celery write races with the trace task.
    # start_time = workflow_run.created_at, end_time = workflow_run.finished_at.
    if info.start_time and info.end_time:
        workflow_duration = (info.end_time - info.start_time).total_seconds()
    elif info.workflow_run_elapsed_time:
        workflow_duration = float(info.workflow_run_elapsed_time)
    else:
        workflow_duration = 0.0
    self._exporter.record_histogram(
        EnterpriseTelemetryHistogram.WORKFLOW_DURATION,
        workflow_duration,
        self._labels(
            **labels,
            status=info.workflow_run_status,
        ),
    )

    if info.error:
        self._exporter.increment_counter(
            EnterpriseTelemetryCounter.ERRORS,
            1,
            self._labels(
                **labels,
                type="workflow",
            ),
        )
|
||||
|
||||
def _node_execution_trace(self, info: WorkflowNodeTraceInfo) -> None:
    """Report a regular node execution using the shared node emitter."""
    self._emit_node_execution_trace(
        info,
        span_name=EnterpriseTelemetrySpan.NODE_EXECUTION,
        request_type="node",
    )
|
||||
|
||||
def _draft_node_execution_trace(self, info: DraftNodeExecutionTrace) -> None:
    """Report a draft node execution using the shared node emitter.

    The span is correlated on the node's own ``node_execution_id`` and
    ``workflow_run_id`` is passed as the trace correlation override.
    """
    self._emit_node_execution_trace(
        info,
        span_name=EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION,
        request_type="draft_node",
        correlation_id_override=info.node_execution_id,
        trace_correlation_override_param=info.workflow_run_id,
    )
|
||||
|
||||
def _emit_node_execution_trace(
    self,
    info: WorkflowNodeTraceInfo,
    span_name: EnterpriseTelemetrySpan,
    request_type: str,
    correlation_id_override: str | None = None,
    trace_correlation_override_param: str | None = None,
) -> None:
    """Export a node span, its companion detail log, and node metrics.

    Args:
        info: Node trace info to report.
        span_name: Span to export; its ``.value`` is also the log event name.
        request_type: Value of the ``type`` label on request/error counters.
        correlation_id_override: Correlation id for the span; falls back to
            ``info.workflow_run_id`` when falsy.
        trace_correlation_override_param: Trace correlation override for the
            span; falls back to the first element of
            ``info.resolved_parent_context`` when falsy.
    """
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)
    exporter = self._exporter

    # -- Span attrs: identity + structure + status + timing + gen_ai scalars --
    span_attrs: dict[str, Any] = {
        "dify.trace_id": info.resolved_trace_id,
        "dify.tenant_id": tenant_id,
        "dify.app_id": app_id,
        "dify.workflow.id": info.workflow_id,
        "dify.workflow.run_id": info.workflow_run_id,
        "dify.message.id": info.message_id,
        "dify.conversation.id": metadata.get("conversation_id"),
        "dify.node.execution_id": info.node_execution_id,
        "dify.node.id": info.node_id,
        "dify.node.type": info.node_type,
        "dify.node.title": info.title,
        "dify.node.status": info.status,
        "dify.node.error": info.error,
        "dify.node.elapsed_time": info.elapsed_time,
        "dify.node.index": info.index,
        "dify.node.predecessor_node_id": info.predecessor_node_id,
        "dify.node.iteration_id": info.iteration_id,
        "dify.node.loop_id": info.loop_id,
        "dify.node.parallel_id": info.parallel_id,
        "dify.node.invoked_by": info.invoked_by,
        "gen_ai.usage.input_tokens": info.prompt_tokens,
        "gen_ai.usage.output_tokens": info.completion_tokens,
        "gen_ai.usage.total_tokens": info.total_tokens,
        "gen_ai.request.model": info.model_name,
        "gen_ai.provider.name": info.model_provider,
        "gen_ai.user.id": user_id,
    }

    parent_override, _ = info.resolved_parent_context
    exporter.export_span(
        span_name,
        span_attrs,
        correlation_id=correlation_id_override or info.workflow_run_id,
        span_id_source=info.node_execution_id,
        start_time=info.start_time,
        end_time=info.end_time,
        trace_correlation_override=trace_correlation_override_param or parent_override,
    )

    # -- Companion log: ALL attrs (span + detail) --
    # Keys already carried by span_attrs with identical values (user id,
    # total tokens, provider, model) are inherited rather than re-set.
    detail_ref = f"ref:node_execution_id={info.node_execution_id}"
    log_attrs: dict[str, Any] = {
        **span_attrs,
        "dify.app.name": metadata.get("app_name"),
        "dify.workspace.name": metadata.get("workspace_name"),
        "dify.invoke_from": metadata.get("invoke_from"),
        "dify.node.total_price": info.total_price,
        "dify.node.currency": info.currency,
        "gen_ai.tool.name": info.tool_name,
        "dify.node.iteration_index": info.iteration_index,
        "dify.node.loop_index": info.loop_index,
        "dify.plugin.name": metadata.get("plugin_name"),
        "dify.credential.name": metadata.get("credential_name"),
        "dify.credential.id": metadata.get("credential_id"),
        "dify.dataset.ids": self._maybe_json(metadata.get("dataset_ids")),
        "dify.dataset.names": self._maybe_json(metadata.get("dataset_names")),
        "dify.node.inputs": self._content_or_ref(info.node_inputs, detail_ref),
        "dify.node.outputs": self._content_or_ref(info.node_outputs, detail_ref),
        "dify.node.process_data": self._content_or_ref(info.process_data, detail_ref),
    }

    emit_telemetry_log(
        event_name=span_name.value,
        attributes=log_attrs,
        signal="span_detail",
        trace_id_source=info.workflow_run_id,
        span_id_source=info.node_execution_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )

    # -- Metrics --
    tenant_label = tenant_id or ""
    app_label = app_id or ""
    provider_label = info.model_provider or ""
    model_label = info.model_name or ""

    base_labels = self._labels(
        tenant_id=tenant_label,
        app_id=app_label,
        node_type=info.node_type,
        model_provider=provider_label,
    )

    # Token counters are skipped entirely when the node reports no tokens.
    if info.total_tokens:
        token_labels = TokenMetricLabels(
            tenant_id=tenant_label,
            app_id=app_label,
            operation_type=OperationType.NODE_EXECUTION,
            model_provider=provider_label,
            model_name=model_label,
            node_type=info.node_type,
        ).to_dict()
        exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
        for counter, amount in (
            (EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens),
            (EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens),
        ):
            if amount is not None and amount > 0:
                exporter.increment_counter(counter, amount, token_labels)

    exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(
            **base_labels,
            type=request_type,
            status=info.status,
            model_name=model_label,
        ),
    )

    duration_labels = {**base_labels, "model_name": model_label}
    plugin_name = metadata.get("plugin_name")
    # plugin_name is only attached for tool and knowledge-retrieval nodes.
    if plugin_name and info.node_type in {"tool", "knowledge-retrieval"}:
        duration_labels["plugin_name"] = plugin_name
    exporter.record_histogram(EnterpriseTelemetryHistogram.NODE_DURATION, info.elapsed_time, duration_labels)

    if info.error:
        exporter.increment_counter(
            EnterpriseTelemetryCounter.ERRORS,
            1,
            self._labels(
                **base_labels,
                type=request_type,
                model_name=model_label,
            ),
        )
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# METRIC-ONLY handlers (structured log + counters/histograms)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _message_trace(self, info: MessageTraceInfo) -> None:
    """Emit the message-run structured event and its metrics.

    No span is exported for messages. The event carries the common
    identity attributes, message/model details and the inputs/outputs
    (or a ``ref:message_id=...`` string when content is gated). Token,
    request, duration, time-to-first-token and error metrics follow.

    Args:
        info: Message trace info to report.
    """
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)
    provider = metadata.get("ls_provider")
    model_name = metadata.get("ls_model_name")
    workflow_run_id = metadata.get("workflow_run_id")

    attrs = self._common_attrs(info)
    attrs.update(
        {
            "dify.invoke_from": metadata.get("from_source"),
            "dify.conversation.id": metadata.get("conversation_id"),
            "dify.conversation.mode": info.conversation_mode,
            "gen_ai.provider.name": provider,
            "gen_ai.request.model": model_name,
            "gen_ai.usage.input_tokens": info.message_tokens,
            "gen_ai.usage.output_tokens": info.answer_tokens,
            "gen_ai.usage.total_tokens": info.total_tokens,
            "dify.message.status": metadata.get("status"),
            "dify.message.error": info.error,
            "dify.message.from_source": metadata.get("from_source"),
            "dify.message.from_end_user_id": metadata.get("from_end_user_id"),
            "dify.message.from_account_id": metadata.get("from_account_id"),
            "dify.streaming": info.is_streaming_request,
            "dify.message.time_to_first_token": info.gen_ai_server_time_to_first_token,
            "dify.message.streaming_duration": info.llm_streaming_time_to_generate,
            "dify.workflow.run_id": workflow_run_id,
        }
    )
    node_execution_id = metadata.get("node_execution_id")
    if node_execution_id:
        attrs["dify.node.execution_id"] = node_execution_id

    ref = f"ref:message_id={info.message_id}"
    inputs = self._safe_payload_value(info.inputs)
    outputs = self._safe_payload_value(info.outputs)
    attrs["dify.message.inputs"] = self._content_or_ref(inputs, ref)
    attrs["dify.message.outputs"] = self._content_or_ref(outputs, ref)

    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.MESSAGE_RUN,
        attributes=attrs,
        # Correlate on the workflow run when there is one, else the message.
        trace_id_source=workflow_run_id or (str(info.message_id) if info.message_id else None),
        span_id_source=node_execution_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )

    labels = self._labels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
        model_provider=provider or "",
        model_name=model_name or "",
    )
    token_labels = TokenMetricLabels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
        operation_type=OperationType.MESSAGE,
        model_provider=provider or "",
        model_name=model_name or "",
        node_type="",
    ).to_dict()
    self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
    # Same None-safe guard as the workflow and node handlers.
    for counter, amount in (
        (EnterpriseTelemetryCounter.INPUT_TOKENS, info.message_tokens),
        (EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.answer_tokens),
    ):
        if amount is not None and amount > 0:
            self._exporter.increment_counter(counter, amount, token_labels)

    # ``or ""`` (not a ``.get`` default) so keys that are present with a
    # None value still yield string labels, matching the other labels.
    invoke_from = metadata.get("from_source") or ""
    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(
            **labels,
            type="message",
            status=metadata.get("status") or "",
            invoke_from=invoke_from,
        ),
    )

    if info.start_time and info.end_time:
        duration = (info.end_time - info.start_time).total_seconds()
        self._exporter.record_histogram(EnterpriseTelemetryHistogram.MESSAGE_DURATION, duration, labels)

    if info.gen_ai_server_time_to_first_token is not None:
        self._exporter.record_histogram(
            EnterpriseTelemetryHistogram.MESSAGE_TTFT, info.gen_ai_server_time_to_first_token, labels
        )

    if info.error:
        self._exporter.increment_counter(
            EnterpriseTelemetryCounter.ERRORS,
            1,
            self._labels(
                **labels,
                type="message",
            ),
        )
|
||||
|
||||
def _tool_trace(self, info: ToolTraceInfo) -> None:
    """Emit the tool-execution structured event and its metrics.

    The event carries the common identity attributes, tool name, time
    cost and error, plus tool inputs, outputs, parameters and config (or a
    ``ref:message_id=...`` string when content is gated). A request
    counter and a duration histogram are always recorded; an error
    counter is added when ``info.error`` is truthy.
    """
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)
    exporter = self._exporter

    attrs = self._common_attrs(info)
    attrs["gen_ai.tool.name"] = info.tool_name
    attrs["dify.tool.time_cost"] = info.time_cost
    attrs["dify.tool.error"] = info.error
    attrs["dify.workflow.run_id"] = metadata.get("workflow_run_id")

    node_execution_id = metadata.get("node_execution_id")
    if node_execution_id:
        attrs["dify.node.execution_id"] = node_execution_id

    content_ref = f"ref:message_id={info.message_id}"
    for attr_key, content in (
        ("dify.tool.inputs", info.tool_inputs),
        ("dify.tool.outputs", info.tool_outputs),
        ("dify.tool.parameters", info.tool_parameters),
        ("dify.tool.config", info.tool_config),
    ):
        attrs[attr_key] = self._content_or_ref(content, content_ref)

    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.TOOL_EXECUTION,
        attributes=attrs,
        trace_id_source=info.resolved_trace_id,
        span_id_source=node_execution_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )

    base_labels = self._labels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
        tool_name=info.tool_name,
    )
    exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(**base_labels, type="tool"),
    )
    exporter.record_histogram(EnterpriseTelemetryHistogram.TOOL_DURATION, float(info.time_cost), base_labels)

    if not info.error:
        return
    exporter.increment_counter(
        EnterpriseTelemetryCounter.ERRORS,
        1,
        self._labels(**base_labels, type="tool"),
    )
|
||||
|
||||
def _moderation_trace(self, info: ModerationTraceInfo) -> None:
    """Emit the moderation-check structured event and a request counter.

    The event carries the common identity attributes, the moderation
    verdict fields, and the moderated query (or a ``ref:message_id=...``
    string when content is gated). The moderation type is reported as
    ``"unknown"`` and the categories as an empty JSON list.
    """
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)

    attrs = self._common_attrs(info)
    attrs["dify.moderation.flagged"] = info.flagged
    attrs["dify.moderation.action"] = info.action
    attrs["dify.moderation.preset_response"] = info.preset_response
    attrs["dify.moderation.type"] = "unknown"
    attrs["dify.moderation.categories"] = self._maybe_json([])
    attrs["dify.workflow.run_id"] = metadata.get("workflow_run_id")

    node_execution_id = metadata.get("node_execution_id")
    if node_execution_id:
        attrs["dify.node.execution_id"] = node_execution_id

    query_ref = f"ref:message_id={info.message_id}"
    attrs["dify.moderation.query"] = self._content_or_ref(info.query, query_ref)

    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.MODERATION_CHECK,
        attributes=attrs,
        trace_id_source=info.resolved_trace_id,
        span_id_source=node_execution_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )

    base_labels = self._labels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
    )
    request_labels = self._labels(**base_labels, type="moderation")
    self._exporter.increment_counter(EnterpriseTelemetryCounter.REQUESTS, 1, request_labels)
|
||||
|
||||
def _suggested_question_trace(self, info: SuggestedQuestionTraceInfo) -> None:
    """Emit the suggested-question generation event and count the request.

    The status is ``"failed"`` whenever an error is found on ``info`` or in
    ``info.metadata``; otherwise it is ``info.status`` with ``"succeeded"``
    as the fallback.

    Args:
        info: Suggested-question trace payload.
    """
    trace_meta = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, trace_meta)
    event_attrs = self._common_attrs(info)

    # Duration is only known when both timestamps are present.
    elapsed: float | None = (
        (info.end_time - info.start_time).total_seconds()
        if info.start_time is not None and info.end_time is not None
        else None
    )

    # Prefer the explicit error, then fall back to the one stored in the trace metadata.
    failure = info.error
    if not failure:
        failure = info.metadata.get("error") if info.metadata else None

    if failure:
        outcome = "failed"
    else:
        outcome = info.status or "succeeded"

    event_attrs["gen_ai.usage.total_tokens"] = info.total_tokens
    event_attrs["dify.suggested_question.status"] = outcome
    event_attrs["dify.suggested_question.error"] = failure
    event_attrs["dify.suggested_question.duration"] = elapsed
    event_attrs["gen_ai.provider.name"] = info.model_provider
    event_attrs["gen_ai.request.model"] = info.model_id
    event_attrs["dify.suggested_question.count"] = len(info.suggested_question)
    event_attrs["dify.workflow.run_id"] = trace_meta.get("workflow_run_id")

    exec_id = trace_meta.get("node_execution_id")
    if exec_id:
        event_attrs["dify.node.execution_id"] = exec_id

    message_ref = f"ref:message_id={info.message_id}"
    event_attrs["dify.suggested_question.questions"] = self._content_or_ref(info.suggested_question, message_ref)

    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.SUGGESTED_QUESTION_GENERATION,
        attributes=event_attrs,
        trace_id_source=info.resolved_trace_id,
        span_id_source=exec_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )

    base_labels = self._labels(tenant_id=tenant_id or "", app_id=app_id or "")
    request_labels = self._labels(
        **base_labels,
        type="suggested_question",
        model_provider=info.model_provider or "",
        model_name=info.model_id or "",
    )
    self._exporter.increment_counter(EnterpriseTelemetryCounter.REQUESTS, 1, request_labels)
|
||||
|
||||
def _dataset_retrieval_trace(self, info: DatasetRetrievalTraceInfo) -> None:
    """Emit the dataset-retrieval event and the related counters.

    Collects the distinct dataset ids/names from the retrieved documents,
    attaches embedding and rerank model information found in the trace
    metadata, emits a ``DATASET_RETRIEVAL`` event, then increments
    ``REQUESTS`` once and ``DATASET_RETRIEVALS`` once per distinct dataset.

    Args:
        info: Dataset-retrieval trace payload. ``info.documents`` is only
            used when it is a list; non-dict entries are ignored.
    """
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)
    attrs = self._common_attrs(info)
    attrs["dify.dataset.error"] = info.error
    attrs["dify.workflow.run_id"] = metadata.get("workflow_run_id")
    node_execution_id = metadata.get("node_execution_id")
    if node_execution_id:
        attrs["dify.node.execution_id"] = node_execution_id

    # Keep only dict-shaped documents; anything else cannot carry metadata.
    documents_any: Any = info.documents
    documents_list: list[Any] = cast(list[Any], documents_any) if isinstance(documents_any, list) else []
    docs: list[dict[str, Any]] = [cast(dict[str, Any], entry) for entry in documents_list if isinstance(entry, dict)]

    dataset_ids: list[str] = []
    dataset_names: list[str] = []
    structured_docs: list[dict[str, Any]] = []
    for doc in docs:
        meta_raw = doc.get("metadata")
        meta: dict[str, Any] = cast(dict[str, Any], meta_raw) if isinstance(meta_raw, dict) else {}
        did = meta.get("dataset_id")
        dname = meta.get("dataset_name")
        # Lists (not sets) keep first-seen order for the exported attributes.
        if did and did not in dataset_ids:
            dataset_ids.append(did)
        if dname and dname not in dataset_names:
            dataset_names.append(dname)
        structured_docs.append(
            {
                "dataset_id": did,
                "document_id": meta.get("document_id"),
                "segment_id": meta.get("segment_id"),
                "score": meta.get("score"),
            }
        )

    attrs["dify.dataset.ids"] = self._maybe_json(dataset_ids)
    attrs["dify.dataset.names"] = self._maybe_json(dataset_names)
    attrs["dify.retrieval.document_count"] = len(docs)

    embedding_models_raw: Any = metadata.get("embedding_models")
    embedding_models: dict[str, Any] = (
        cast(dict[str, Any], embedding_models_raw) if isinstance(embedding_models_raw, dict) else {}
    )
    if embedding_models:
        providers: list[str] = []
        models: list[str] = []
        for ds_info in embedding_models.values():
            if isinstance(ds_info, dict):
                ds_info_dict: dict[str, Any] = cast(dict[str, Any], ds_info)
                p = ds_info_dict.get("embedding_model_provider", "")
                m = ds_info_dict.get("embedding_model", "")
                if p and p not in providers:
                    providers.append(p)
                if m and m not in models:
                    models.append(m)
        attrs["dify.dataset.embedding_providers"] = self._maybe_json(providers)
        attrs["dify.dataset.embedding_models"] = self._maybe_json(models)

    # Rerank model is the same for every dataset in this retrieval: read it once
    # and reuse it for both the log attributes and the per-dataset counters.
    rerank_provider = metadata.get("rerank_model_provider", "")
    rerank_model = metadata.get("rerank_model_name", "")
    if rerank_provider or rerank_model:
        attrs["dify.retrieval.rerank_provider"] = rerank_provider
        attrs["dify.retrieval.rerank_model"] = rerank_model

    ref = f"ref:message_id={info.message_id}"
    retrieval_inputs = self._safe_payload_value(info.inputs)
    attrs["dify.retrieval.query"] = self._content_or_ref(retrieval_inputs, ref)
    attrs["dify.dataset.documents"] = self._content_or_ref(structured_docs, ref)

    # Fall back to the message id when no workflow run / node execution id exists.
    message_id_source = str(info.message_id) if info.message_id else None
    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.DATASET_RETRIEVAL,
        attributes=attrs,
        trace_id_source=metadata.get("workflow_run_id") or message_id_source,
        span_id_source=node_execution_id or message_id_source,
        tenant_id=tenant_id,
        user_id=user_id,
    )

    labels = self._labels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
    )
    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(
            **labels,
            type="dataset_retrieval",
        ),
    )

    for did in dataset_ids:
        # Embedding model for this specific dataset. Entries that are not dicts
        # are treated as "no information", matching the isinstance guard used
        # when the providers/models lists are built above.
        ds_info_raw = embedding_models.get(did)
        ds_embedding_info: dict[str, Any] = cast(dict[str, Any], ds_info_raw) if isinstance(ds_info_raw, dict) else {}

        self._exporter.increment_counter(
            EnterpriseTelemetryCounter.DATASET_RETRIEVALS,
            1,
            self._labels(
                **labels,
                dataset_id=did,
                embedding_model_provider=ds_embedding_info.get("embedding_model_provider", ""),
                embedding_model=ds_embedding_info.get("embedding_model", ""),
                rerank_model_provider=rerank_provider,
                rerank_model=rerank_model,
            ),
        )
|
||||
|
||||
def _generate_name_trace(self, info: GenerateNameTraceInfo) -> None:
    """Emit the generate-name event and count the request.

    The status is ``"failed"`` when the trace metadata carries an ``error``
    entry and ``"succeeded"`` otherwise.

    Args:
        info: Conversation-name generation trace payload.
    """
    trace_meta = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, trace_meta)

    event_attrs = self._common_attrs(info)
    event_attrs["dify.conversation.id"] = info.conversation_id
    exec_id = trace_meta.get("node_execution_id")
    if exec_id:
        event_attrs["dify.node.execution_id"] = exec_id

    # Duration is only known when both timestamps are present.
    elapsed: float | None = (
        (info.end_time - info.start_time).total_seconds()
        if info.start_time is not None and info.end_time is not None
        else None
    )
    failure: str | None = trace_meta.get("error") if trace_meta else None
    if failure:
        outcome = "failed"
    else:
        outcome = "succeeded"

    event_attrs["dify.generate_name.duration"] = elapsed
    event_attrs["dify.generate_name.status"] = outcome
    event_attrs["dify.generate_name.error"] = failure

    conversation_ref = f"ref:conversation_id={info.conversation_id}"
    safe_inputs = self._safe_payload_value(info.inputs)
    safe_outputs = self._safe_payload_value(info.outputs)
    event_attrs["dify.generate_name.inputs"] = self._content_or_ref(safe_inputs, conversation_ref)
    event_attrs["dify.generate_name.outputs"] = self._content_or_ref(safe_outputs, conversation_ref)

    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.GENERATE_NAME_EXECUTION,
        attributes=event_attrs,
        trace_id_source=info.resolved_trace_id,
        span_id_source=exec_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )

    base_labels = self._labels(tenant_id=tenant_id or "", app_id=app_id or "")
    request_labels = self._labels(**base_labels, type="generate_name")
    self._exporter.increment_counter(EnterpriseTelemetryCounter.REQUESTS, 1, request_labels)
|
||||
|
||||
def _prompt_generation_trace(self, info: PromptGenerationTraceInfo) -> None:
    """Emit the prompt-generation event plus token, request, latency and error metrics.

    Metric emission order: total tokens, input tokens (when > 0), output
    tokens (when > 0), request count, latency histogram, and finally the
    error count when ``info.error`` is set.

    Args:
        info: Prompt-generation trace payload.
    """
    trace_meta = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, trace_meta)
    tenant_label = tenant_id or ""
    app_label = app_id or ""

    # This handler assembles its attribute dict directly instead of starting from _common_attrs.
    event_attrs = {
        "dify.trace_id": info.resolved_trace_id,
        "dify.tenant_id": tenant_id,
        "gen_ai.user.id": user_id,
        "dify.app.id": app_label,
        "dify.app.name": trace_meta.get("app_name"),
        "dify.workspace.name": trace_meta.get("workspace_name"),
        "dify.prompt_generation.operation_type": info.operation_type,
        "gen_ai.provider.name": info.model_provider,
        "gen_ai.request.model": info.model_name,
        "gen_ai.usage.input_tokens": info.prompt_tokens,
        "gen_ai.usage.output_tokens": info.completion_tokens,
        "gen_ai.usage.total_tokens": info.total_tokens,
        "dify.prompt_generation.latency": info.latency,
        "dify.prompt_generation.error": info.error,
    }
    exec_id = trace_meta.get("node_execution_id")
    if exec_id:
        event_attrs["dify.node.execution_id"] = exec_id

    # Price and currency are only reported together, and only when a price exists.
    if info.total_price is not None:
        event_attrs["dify.prompt_generation.total_price"] = info.total_price
        event_attrs["dify.prompt_generation.currency"] = info.currency

    trace_ref = f"ref:trace_id={info.trace_id}"
    safe_outputs = self._safe_payload_value(info.outputs)
    event_attrs["dify.prompt_generation.instruction"] = self._content_or_ref(info.instruction, trace_ref)
    event_attrs["dify.prompt_generation.output"] = self._content_or_ref(safe_outputs, trace_ref)

    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.PROMPT_GENERATION_EXECUTION,
        attributes=event_attrs,
        trace_id_source=info.resolved_trace_id,
        span_id_source=exec_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )

    # Token counters use the unified TokenMetricLabels set; node_type is empty here.
    token_labels = TokenMetricLabels(
        tenant_id=tenant_label,
        app_id=app_label,
        operation_type=info.operation_type,
        model_provider=info.model_provider,
        model_name=info.model_name,
        node_type="",
    ).to_dict()

    base_labels = self._labels(
        tenant_id=tenant_label,
        app_id=app_label,
        operation_type=info.operation_type,
        model_provider=info.model_provider,
        model_name=info.model_name,
    )

    exporter = self._exporter
    exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
    if info.prompt_tokens > 0:
        exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels)
    if info.completion_tokens > 0:
        exporter.increment_counter(EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels)

    if info.error:
        outcome = "failed"
    else:
        outcome = "success"
    request_labels = self._labels(**base_labels, type="prompt_generation", status=outcome)
    exporter.increment_counter(EnterpriseTelemetryCounter.REQUESTS, 1, request_labels)

    exporter.record_histogram(EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION, info.latency, base_labels)

    if not info.error:
        return
    error_labels = self._labels(**base_labels, type="prompt_generation")
    exporter.increment_counter(EnterpriseTelemetryCounter.ERRORS, 1, error_labels)
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
from enum import StrEnum
|
||||
from typing import cast
|
||||
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class EnterpriseTelemetrySpan(StrEnum):
    """String identifiers for enterprise telemetry spans.

    NOTE(review): presumably used as the span name handed to the exporter;
    confirm at the call sites.
    """

    WORKFLOW_RUN = "dify.workflow.run"
    NODE_EXECUTION = "dify.node.execution"
    DRAFT_NODE_EXECUTION = "dify.node.execution.draft"
|
||||
|
||||
|
||||
class EnterpriseTelemetryEvent(StrEnum):
    """Event names for enterprise telemetry logs."""

    # App lifecycle and feedback events.
    APP_CREATED = "dify.app.created"
    APP_UPDATED = "dify.app.updated"
    APP_DELETED = "dify.app.deleted"
    FEEDBACK_CREATED = "dify.feedback.created"
    # Execution events, one per traced operation type.
    WORKFLOW_RUN = "dify.workflow.run"
    MESSAGE_RUN = "dify.message.run"
    TOOL_EXECUTION = "dify.tool.execution"
    MODERATION_CHECK = "dify.moderation.check"
    SUGGESTED_QUESTION_GENERATION = "dify.suggested_question.generation"
    DATASET_RETRIEVAL = "dify.dataset.retrieval"
    GENERATE_NAME_EXECUTION = "dify.generate_name.execution"
    PROMPT_GENERATION_EXECUTION = "dify.prompt_generation.execution"
    # NOTE(review): name suggests an internal telemetry failure signal; confirm where it is emitted.
    REHYDRATION_FAILED = "dify.telemetry.rehydration_failed"
|
||||
|
||||
|
||||
class EnterpriseTelemetryCounter(StrEnum):
    """Keys identifying the counter instruments of the enterprise exporter.

    ``EnterpriseExporter`` maps each member to an OTEL counter (for example
    ``TOKENS`` -> ``dify.tokens.total``); the values here are lookup keys,
    not the exported metric names.
    """

    TOKENS = "tokens"
    INPUT_TOKENS = "input_tokens"
    OUTPUT_TOKENS = "output_tokens"
    REQUESTS = "requests"
    ERRORS = "errors"
    FEEDBACK = "feedback"
    DATASET_RETRIEVALS = "dataset_retrievals"
    APP_CREATED = "app_created"
    APP_UPDATED = "app_updated"
    APP_DELETED = "app_deleted"
|
||||
|
||||
|
||||
class EnterpriseTelemetryHistogram(StrEnum):
    """Keys identifying the histogram instruments of the enterprise exporter.

    ``EnterpriseExporter`` maps each member to an OTEL histogram created with
    unit ``"s"`` (for example ``TOOL_DURATION`` -> ``dify.tool.duration``).
    """

    WORKFLOW_DURATION = "workflow_duration"
    NODE_DURATION = "node_duration"
    MESSAGE_DURATION = "message_duration"
    MESSAGE_TTFT = "message_ttft"
    TOOL_DURATION = "tool_duration"
    PROMPT_GENERATION_DURATION = "prompt_generation_duration"
|
||||
|
||||
|
||||
class TokenMetricLabels(BaseModel):
    """Single, fixed label set shared by every ``dify.tokens.*`` metric.

    The three token counters (``dify.tokens.input``, ``dify.tokens.output``,
    ``dify.tokens.total``) MUST all be recorded with exactly these labels so
    that filtering and aggregation behave the same for every operation type.

    Attributes:
        tenant_id: Tenant identifier.
        app_id: Application identifier.
        operation_type: Source of token usage (workflow | node_execution | message |
            rule_generate | code_generate | structured_output | instruction_modify).
        model_provider: LLM provider name. Empty string if not applicable (e.g., workflow-level).
        model_name: LLM model name. Empty string if not applicable (e.g., workflow-level).
        node_type: Workflow node type. Empty string unless operation_type=node_execution.

    Usage:
        labels = TokenMetricLabels(
            tenant_id="tenant-123",
            app_id="app-456",
            operation_type=OperationType.WORKFLOW,
            model_provider="",
            model_name="",
            node_type="",
        )
        exporter.increment_counter(
            EnterpriseTelemetryCounter.INPUT_TOKENS,
            100,
            labels.to_dict()
        )

    Design rationale:
        workflow.total_tokens is already the sum of all node tokens, so summing
        token metrics without a common structure double-counts. The
        operation_type label lets queries separate workflow-level aggregates
        from node-level detail while the label cardinality stays the same.
    """

    tenant_id: str
    app_id: str
    operation_type: str
    model_provider: str
    model_name: str
    node_type: str

    # Unknown fields are rejected and instances are immutable.
    model_config = ConfigDict(extra="forbid", frozen=True)

    def to_dict(self) -> dict[str, AttributeValue]:
        """Return the labels as a plain dict, in declaration order, for use as metric attributes."""
        label_names = (
            "tenant_id",
            "app_id",
            "operation_type",
            "model_provider",
            "model_name",
            "node_type",
        )
        label_values = {label: getattr(self, label) for label in label_names}
        return cast(dict[str, AttributeValue], label_values)
|
||||
|
||||
|
||||
# Public names of this module: the telemetry enums and the shared token label model.
__all__ = [
    "EnterpriseTelemetryCounter",
    "EnterpriseTelemetryEvent",
    "EnterpriseTelemetryHistogram",
    "EnterpriseTelemetrySpan",
    "TokenMetricLabels",
]
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
"""Blinker signal handlers for enterprise telemetry.
|
||||
|
||||
Registered at import time via ``@signal.connect`` decorators.
|
||||
Import must happen during ``ext_enterprise_telemetry.init_app()`` to
|
||||
ensure handlers fire. Each handler delegates to ``core.telemetry.gateway``
|
||||
which handles routing, EE-gating, and dispatch.
|
||||
|
||||
All handlers are best-effort: exceptions are caught and logged so that
|
||||
telemetry failures never break user-facing operations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from events.app_event import app_was_created
|
||||
|
||||
# Module-level logger; handlers report telemetry failures through it instead of raising.
logger = logging.getLogger(__name__)

# The handler is listed explicitly even though its name is underscore-prefixed.
__all__ = [
    "_handle_app_created",
]
|
||||
|
||||
|
||||
@app_was_created.connect
def _handle_app_created(sender: object, **kwargs: object) -> None:
    """Forward an ``app_was_created`` signal to the telemetry gateway.

    Best-effort: any exception, including a failing import of the gateway
    modules, is logged as a warning and swallowed.

    Args:
        sender: Signal sender; ``tenant_id``, ``id`` and ``mode`` are read
            with ``getattr`` so missing attributes do not raise.
        **kwargs: Extra signal arguments (unused).
    """
    try:
        # Imported lazily inside the try so that import errors are also contained.
        from core.telemetry.gateway import emit as gateway_emit
        from enterprise.telemetry.contracts import TelemetryCase

        tenant = getattr(sender, "tenant_id", "") or ""
        telemetry_context = {"tenant_id": str(tenant)}
        telemetry_payload = {
            "app_id": getattr(sender, "id", None),
            "mode": getattr(sender, "mode", None),
        }
        gateway_emit(
            case=TelemetryCase.APP_CREATED,
            context=telemetry_context,
            payload=telemetry_payload,
        )
    except Exception:
        logger.warning("Failed to emit app_created telemetry", exc_info=True)
|
||||
|
|
@ -0,0 +1,289 @@
|
|||
"""Enterprise OTEL exporter — shared by EnterpriseOtelTrace, event handlers, and direct instrumentation.
|
||||
|
||||
Uses dedicated TracerProvider and MeterProvider instances (configurable sampling,
|
||||
independent from ext_otel.py infrastructure).
|
||||
|
||||
Initialized once during Flask extension init (single-threaded via ext_enterprise_telemetry.py).
|
||||
Accessed via ``ext_enterprise_telemetry.get_enterprise_exporter()`` from any thread/process.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import socket
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, cast
|
||||
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.context import Context
|
||||
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
from opentelemetry.trace import SpanContext, TraceFlags
|
||||
from opentelemetry.util.types import Attributes, AttributeValue
|
||||
|
||||
from configs import dify_config
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram
|
||||
from enterprise.telemetry.id_generator import (
|
||||
CorrelationIdGenerator,
|
||||
compute_deterministic_span_id,
|
||||
set_correlation_id,
|
||||
set_span_id_source,
|
||||
)
|
||||
|
||||
# Module-level logger used for export failures and configuration warnings.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_enterprise_telemetry_enabled() -> bool:
    """Return True only when both the enterprise flag and the enterprise telemetry flag are set."""
    enabled = dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED
    return bool(enabled)
|
||||
|
||||
|
||||
def _parse_otlp_headers(raw: str) -> dict[str, str]:
|
||||
"""Parse ``key=value,key2=value2`` into a dict."""
|
||||
if not raw:
|
||||
return {}
|
||||
headers: dict[str, str] = {}
|
||||
for pair in raw.split(","):
|
||||
if "=" not in pair:
|
||||
continue
|
||||
k, v = pair.split("=", 1)
|
||||
headers[k.strip().lower()] = v.strip()
|
||||
return headers
|
||||
|
||||
|
||||
def _datetime_to_ns(dt: datetime) -> int:
|
||||
"""Convert a datetime to nanoseconds since epoch (OTEL convention)."""
|
||||
# Ensure we always interpret naive datetimes as UTC instead of local time.
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=UTC)
|
||||
else:
|
||||
dt = dt.astimezone(UTC)
|
||||
return int(dt.timestamp() * 1_000_000_000)
|
||||
|
||||
|
||||
class _ExporterFactory:
    """Create OTLP span/metric exporters for the configured protocol.

    ``protocol == "grpc"`` selects the gRPC exporters; any other value
    selects the HTTP exporters, whose endpoint gets the per-signal path
    (``v1/traces`` / ``v1/metrics``) appended.
    """

    def __init__(self, protocol: str, endpoint: str, headers: dict[str, str], insecure: bool):
        """Store the connection settings.

        Args:
            protocol: ``"grpc"`` for gRPC, anything else means HTTP.
            endpoint: Collector base endpoint; empty lets the exporter use its own default.
            headers: Extra request headers; empty means none are passed.
            insecure: Passed to the gRPC exporters only.
        """
        self._protocol = protocol
        self._endpoint = endpoint
        self._headers = headers
        # gRPC exporters receive headers as a tuple of pairs, HTTP exporters as a dict.
        self._grpc_headers = tuple(headers.items()) if headers else None
        self._http_headers = headers or None
        self._insecure = insecure

    def _http_signal_endpoint(self, signal_path: str) -> str | None:
        """Return ``<endpoint>/<signal_path>`` or None when no endpoint is configured.

        Trailing slashes on the configured endpoint are dropped so the result
        never contains a double slash before the signal path.
        """
        if not self._endpoint:
            return None
        return f"{self._endpoint.rstrip('/')}/{signal_path}"

    def create_trace_exporter(self) -> HTTPSpanExporter | GRPCSpanExporter:
        """Build the span exporter for the configured protocol."""
        if self._protocol == "grpc":
            return GRPCSpanExporter(
                endpoint=self._endpoint or None,
                headers=self._grpc_headers,
                insecure=self._insecure,
            )
        return HTTPSpanExporter(endpoint=self._http_signal_endpoint("v1/traces"), headers=self._http_headers)

    def create_metric_exporter(self) -> HTTPMetricExporter | GRPCMetricExporter:
        """Build the metric exporter for the configured protocol."""
        if self._protocol == "grpc":
            return GRPCMetricExporter(
                endpoint=self._endpoint or None,
                headers=self._grpc_headers,
                insecure=self._insecure,
            )
        return HTTPMetricExporter(endpoint=self._http_signal_endpoint("v1/metrics"), headers=self._http_headers)
|
||||
|
||||
|
||||
class EnterpriseExporter:
    """Shared OTEL exporter for all enterprise telemetry.

    ``export_span`` creates spans with optional real timestamps, deterministic
    span/trace IDs, and cross-workflow parent linking.
    ``increment_counter`` / ``record_histogram`` emit OTEL metrics at 100% accuracy
    (the sampler is attached to the tracer provider only).
    """

    def __init__(self, config: object) -> None:
        """Create the tracer/meter providers and all metric instruments.

        Args:
            config: Settings object read via ``getattr``. Every ``ENTERPRISE_*``
                attribute is optional and falls back to the default used below.
        """
        # ``or ""``: an attribute explicitly set to None must not break the string handling below.
        endpoint: str = getattr(config, "ENTERPRISE_OTLP_ENDPOINT", "") or ""
        headers_raw: str = getattr(config, "ENTERPRISE_OTLP_HEADERS", "") or ""
        protocol: str = (getattr(config, "ENTERPRISE_OTLP_PROTOCOL", "http") or "http").lower()
        service_name: str = getattr(config, "ENTERPRISE_SERVICE_NAME", "dify")
        sampling_rate: float = getattr(config, "ENTERPRISE_OTEL_SAMPLING_RATE", 1.0)
        self.include_content: bool = getattr(config, "ENTERPRISE_INCLUDE_CONTENT", True)
        api_key: str = getattr(config, "ENTERPRISE_OTLP_API_KEY", "")

        # Auto-detect TLS: https:// uses secure, everything else is insecure
        insecure = not endpoint.startswith("https://")

        resource = Resource(
            attributes={
                ResourceAttributes.SERVICE_NAME: service_name,
                ResourceAttributes.HOST_NAME: socket.gethostname(),
            }
        )
        sampler = ParentBasedTraceIdRatio(sampling_rate)
        id_generator = CorrelationIdGenerator()
        self._tracer_provider = TracerProvider(resource=resource, sampler=sampler, id_generator=id_generator)

        headers = _parse_otlp_headers(headers_raw)
        if api_key:
            if "authorization" in headers:
                logger.warning(
                    "ENTERPRISE_OTLP_API_KEY is set but ENTERPRISE_OTLP_HEADERS also contains "
                    "'authorization'; the API key will take precedence."
                )
            headers["authorization"] = f"Bearer {api_key}"
        factory = _ExporterFactory(protocol, endpoint, headers, insecure=insecure)

        trace_exporter = factory.create_trace_exporter()
        self._tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
        self._tracer = self._tracer_provider.get_tracer("dify.enterprise")

        metric_exporter = factory.create_metric_exporter()
        self._meter_provider = MeterProvider(
            resource=resource,
            metric_readers=[PeriodicExportingMetricReader(metric_exporter)],
        )
        meter = self._meter_provider.get_meter("dify.enterprise")
        self._counters = {
            EnterpriseTelemetryCounter.TOKENS: meter.create_counter("dify.tokens.total", unit="{token}"),
            EnterpriseTelemetryCounter.INPUT_TOKENS: meter.create_counter("dify.tokens.input", unit="{token}"),
            EnterpriseTelemetryCounter.OUTPUT_TOKENS: meter.create_counter("dify.tokens.output", unit="{token}"),
            EnterpriseTelemetryCounter.REQUESTS: meter.create_counter("dify.requests.total", unit="{request}"),
            EnterpriseTelemetryCounter.ERRORS: meter.create_counter("dify.errors.total", unit="{error}"),
            EnterpriseTelemetryCounter.FEEDBACK: meter.create_counter("dify.feedback.total", unit="{feedback}"),
            EnterpriseTelemetryCounter.DATASET_RETRIEVALS: meter.create_counter(
                "dify.dataset.retrievals.total", unit="{retrieval}"
            ),
            EnterpriseTelemetryCounter.APP_CREATED: meter.create_counter("dify.app.created.total", unit="{app}"),
            EnterpriseTelemetryCounter.APP_UPDATED: meter.create_counter("dify.app.updated.total", unit="{app}"),
            EnterpriseTelemetryCounter.APP_DELETED: meter.create_counter("dify.app.deleted.total", unit="{app}"),
        }
        self._histograms = {
            EnterpriseTelemetryHistogram.WORKFLOW_DURATION: meter.create_histogram("dify.workflow.duration", unit="s"),
            EnterpriseTelemetryHistogram.NODE_DURATION: meter.create_histogram("dify.node.duration", unit="s"),
            EnterpriseTelemetryHistogram.MESSAGE_DURATION: meter.create_histogram("dify.message.duration", unit="s"),
            EnterpriseTelemetryHistogram.MESSAGE_TTFT: meter.create_histogram(
                "dify.message.time_to_first_token", unit="s"
            ),
            EnterpriseTelemetryHistogram.TOOL_DURATION: meter.create_histogram("dify.tool.duration", unit="s"),
            EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION: meter.create_histogram(
                "dify.prompt_generation.duration", unit="s"
            ),
        }

    @staticmethod
    def _remote_parent_context(
        parent_source: str,
        trace_source: str | None,
        span_name: str,
        invalid_trace_message: str,
    ) -> Context | None:
        """Build a context whose current span is a sampled, remote parent.

        Args:
            parent_source: UUID string the parent span_id is derived from.
            trace_source: UUID string the parent trace_id is derived from.
            span_name: Name of the span being exported (for the warning only).
            invalid_trace_message: ``%s, %s`` log format used when ``trace_source``
                is not a valid UUID.

        Returns:
            The parent context, or None when no usable trace_id is available.

        Raises:
            ValueError: If ``parent_source`` is not a valid UUID string.
        """
        # Computed first on purpose: an invalid parent source must fail the export
        # regardless of whether a trace source is present.
        parent_span_id = compute_deterministic_span_id(parent_source)
        try:
            parent_trace_id = int(uuid.UUID(trace_source)) if trace_source else 0
        except (ValueError, AttributeError):
            logger.warning(invalid_trace_message, trace_source, span_name)
            parent_trace_id = 0
        if not parent_trace_id:
            return None
        parent_span_context = SpanContext(
            trace_id=parent_trace_id,
            span_id=parent_span_id,
            is_remote=True,
            trace_flags=TraceFlags(TraceFlags.SAMPLED),
        )
        return trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context))

    def export_span(
        self,
        name: str,
        attributes: dict[str, Any],
        correlation_id: str | None = None,
        span_id_source: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        trace_correlation_override: str | None = None,
        parent_span_id_source: str | None = None,
    ) -> None:
        """Export an OTEL span with optional deterministic IDs and real timestamps.

        Failures are logged and swallowed; this method never raises for export errors.

        Args:
            name: Span operation name.
            attributes: Span attributes dict. Entries whose value is None are skipped.
            correlation_id: Source for trace_id derivation (groups spans in one trace).
            span_id_source: Source for deterministic span_id (e.g. workflow_run_id or node_execution_id).
            start_time: Real span start time. When None, uses current time.
            end_time: Real span end time. When None, span ends immediately.
            trace_correlation_override: Override trace_id source (for cross-workflow linking).
                When set, trace_id is derived from this instead of ``correlation_id``.
            parent_span_id_source: Override parent span_id source (for cross-workflow linking).
                When set, parent span_id is derived from this value. When None and
                ``correlation_id`` is set, parent is the workflow root span.
        """
        effective_trace_correlation = trace_correlation_override or correlation_id
        set_correlation_id(effective_trace_correlation)
        set_span_id_source(span_id_source)

        try:
            parent_context: Context | None = None
            # A span is the "root" of its correlation group when span_id_source == correlation_id
            # (i.e. a workflow root span). All other spans are children.
            if parent_span_id_source:
                # Cross-workflow linking: parent is an explicit span (e.g. tool node in outer workflow)
                parent_context = self._remote_parent_context(
                    parent_span_id_source,
                    effective_trace_correlation,
                    name,
                    "Invalid trace correlation UUID for cross-workflow link: %s, span=%s",
                )
            elif correlation_id and correlation_id != span_id_source:
                # Child span: parent is the correlation-group root (workflow root span).
                # effective_trace_correlation is never empty here because correlation_id is set.
                parent_context = self._remote_parent_context(
                    correlation_id,
                    effective_trace_correlation,
                    name,
                    "Invalid trace correlation UUID for child span link: %s, span=%s",
                )

            span_start_time = _datetime_to_ns(start_time) if start_time is not None else None
            # With an explicit end_time the span is ended manually below, not by the context manager.
            span_end_on_exit = end_time is None

            with self._tracer.start_as_current_span(
                name,
                context=parent_context,
                start_time=span_start_time,
                end_on_exit=span_end_on_exit,
            ) as span:
                for key, value in attributes.items():
                    if value is not None:
                        span.set_attribute(key, value)
                if end_time is not None:
                    span.end(end_time=_datetime_to_ns(end_time))
        except Exception:
            logger.exception("Failed to export span %s", name)
        finally:
            # Always clear the per-context ID sources so they cannot leak into later spans.
            set_correlation_id(None)
            set_span_id_source(None)

    def increment_counter(
        self, name: EnterpriseTelemetryCounter, value: int, labels: dict[str, AttributeValue]
    ) -> None:
        """Add ``value`` to the counter registered for ``name``; unknown names are ignored."""
        counter = self._counters.get(name)
        if counter is not None:
            counter.add(value, cast(Attributes, labels))

    def record_histogram(
        self, name: EnterpriseTelemetryHistogram, value: float, labels: dict[str, AttributeValue]
    ) -> None:
        """Record ``value`` in the histogram registered for ``name``; unknown names are ignored."""
        histogram = self._histograms.get(name)
        if histogram is not None:
            histogram.record(value, cast(Attributes, labels))

    def shutdown(self) -> None:
        """Shut down both providers.

        The meter provider is shut down even when the tracer provider raises,
        so pending metrics are not stranded by a trace-side failure.
        """
        try:
            self._tracer_provider.shutdown()
        finally:
            self._meter_provider.shutdown()
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
"""Custom OTEL ID Generator for correlation-based trace/span ID derivation.
|
||||
|
||||
Uses contextvars for thread-safe correlation_id -> trace_id mapping.
|
||||
When a span_id_source is set, the span_id is derived deterministically
|
||||
from that value, enabling any span to reference another as parent
|
||||
without depending on span creation order.
|
||||
"""
|
||||
|
||||
import random
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
|
||||
from opentelemetry.sdk.trace.id_generator import IdGenerator
|
||||
|
||||
_correlation_id_context: ContextVar[str | None] = ContextVar("correlation_id", default=None)
|
||||
_span_id_source_context: ContextVar[str | None] = ContextVar("span_id_source", default=None)
|
||||
|
||||
|
||||
def set_correlation_id(correlation_id: str | None) -> None:
|
||||
_correlation_id_context.set(correlation_id)
|
||||
|
||||
|
||||
def get_correlation_id() -> str | None:
|
||||
return _correlation_id_context.get()
|
||||
|
||||
|
||||
def set_span_id_source(source_id: str | None) -> None:
|
||||
"""Set the source for deterministic span_id generation.
|
||||
|
||||
When set, ``generate_span_id()`` derives the span_id from this value
|
||||
(lower 64 bits of the UUID). Pass the ``workflow_run_id`` for workflow
|
||||
root spans or ``node_execution_id`` for node spans.
|
||||
"""
|
||||
_span_id_source_context.set(source_id)
|
||||
|
||||
|
||||
def compute_deterministic_span_id(source_id: str) -> int:
    """Map a UUID string onto a stable, non-zero 64-bit span_id.

    The result is the low 64 bits of the UUID's integer value. OTEL does
    not allow a span_id of 0, so that single case is replaced by 1.
    Parsing errors from ``uuid.UUID`` propagate to the caller.
    """
    low_64_mask = 0xFFFFFFFFFFFFFFFF
    low_bits = uuid.UUID(source_id).int & low_64_mask
    if low_bits == 0:
        return 1
    return low_bits
|
||||
|
||||
|
||||
class CorrelationIdGenerator(IdGenerator):
    """ID generator that derives trace_id and optionally span_id from context.

    - trace_id: derived from the context correlation_id when it parses as a
      UUID with a non-zero value (groups all spans in one trace), otherwise
      random
    - span_id: derived from span_id_source when set (enables deterministic
      parent-child linking), otherwise random

    Both methods always return a non-zero id, because an all-zero trace_id
    or span_id is invalid in OTEL.
    """

    @staticmethod
    def _random_nonzero(bits: int) -> int:
        """Return a random ``bits``-wide integer, retrying on the value 0."""
        value = random.getrandbits(bits)
        while value == 0:
            value = random.getrandbits(bits)
        return value

    def generate_trace_id(self) -> int:
        """Return the trace_id for a new trace.

        Uses the integer value of the correlation id UUID when one is set.
        Falls back to a random id when the correlation id is missing,
        unparsable, or the nil UUID (whose value 0 is not a valid trace_id).
        """
        correlation_id = _correlation_id_context.get()
        if correlation_id:
            try:
                trace_id = uuid.UUID(correlation_id).int
            except (ValueError, AttributeError):
                trace_id = 0
            if trace_id != 0:
                return trace_id
        return self._random_nonzero(128)

    def generate_span_id(self) -> int:
        """Return the span_id for a new span.

        Derived deterministically from the context span_id_source when it
        parses as a UUID; otherwise a random non-zero 64-bit value.
        """
        source = _span_id_source_context.get()
        if source:
            try:
                return compute_deterministic_span_id(source)
            except (ValueError, AttributeError):
                # Unparsable source: fall through to a random id.
                pass
        return self._random_nonzero(64)
|
||||
|
|
@ -0,0 +1,421 @@
|
|||
"""Enterprise metric/log event handler.
|
||||
|
||||
This module processes metric and log telemetry events after they've been
|
||||
dequeued from the enterprise_telemetry Celery queue. It handles case routing,
|
||||
idempotency checking, and payload rehydration.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnterpriseMetricHandler:
    """Handler for enterprise metric and log telemetry events.

    Processes envelopes from the enterprise_telemetry queue, routing each
    case to the appropriate handler method. Implements idempotency checking
    and payload rehydration with fallback.
    """

    def _increment_diagnostic_counter(self, counter_name: str, labels: dict[str, str] | None = None) -> None:
        """Report a diagnostic counter for operational monitoring.

        At present this only writes a debug log line (and only when an
        enterprise exporter is available); no metric instrument is touched.
        Any failure is logged at debug level and swallowed so diagnostics
        never break event processing.

        Args:
            counter_name: Name of the counter (e.g., 'processed_total', 'deduped_total').
            labels: Optional labels for the counter.
        """
        try:
            from extensions.ext_enterprise_telemetry import get_enterprise_exporter

            exporter = get_enterprise_exporter()
            if not exporter:
                return

            full_counter_name = f"enterprise_telemetry.handler.{counter_name}"
            logger.debug(
                "Diagnostic counter: %s, labels=%s",
                full_counter_name,
                labels or {},
            )
        except Exception:
            logger.debug("Failed to increment diagnostic counter: %s", counter_name, exc_info=True)

    def _routes(self) -> tuple[tuple[Any, Any, str], ...]:
        """Return ``(case, handler, diagnostic label)`` triples in routing order."""
        return (
            (TelemetryCase.APP_CREATED, self._on_app_created, "app_created"),
            (TelemetryCase.APP_UPDATED, self._on_app_updated, "app_updated"),
            (TelemetryCase.APP_DELETED, self._on_app_deleted, "app_deleted"),
            (TelemetryCase.FEEDBACK_CREATED, self._on_feedback_created, "feedback_created"),
            (TelemetryCase.MESSAGE_RUN, self._on_message_run, "message_run"),
            (TelemetryCase.TOOL_EXECUTION, self._on_tool_execution, "tool_execution"),
            (TelemetryCase.MODERATION_CHECK, self._on_moderation_check, "moderation_check"),
            (TelemetryCase.SUGGESTED_QUESTION, self._on_suggested_question, "suggested_question"),
            (TelemetryCase.DATASET_RETRIEVAL, self._on_dataset_retrieval, "dataset_retrieval"),
            (TelemetryCase.GENERATE_NAME, self._on_generate_name, "generate_name"),
            (TelemetryCase.PROMPT_GENERATION, self._on_prompt_generation, "prompt_generation"),
        )

    def handle(self, envelope: TelemetryEnvelope) -> None:
        """Main entry point for processing telemetry envelopes.

        Duplicate events are dropped (and counted as ``deduped_total``).
        Otherwise the envelope is passed to the handler registered for its
        case and ``processed_total`` is reported; unknown cases are logged
        as a warning.

        Args:
            envelope: The telemetry envelope to process.
        """
        if self._is_duplicate(envelope):
            logger.debug(
                "Skipping duplicate event: tenant_id=%s, event_id=%s",
                envelope.tenant_id,
                envelope.event_id,
            )
            self._increment_diagnostic_counter("deduped_total")
            return

        case = envelope.case
        # Compared with ``==`` (not a dict lookup) so matching is the same as
        # an if/elif chain, whatever type ``envelope.case`` arrives as.
        for known_case, handler, label in self._routes():
            if case == known_case:
                handler(envelope)
                self._increment_diagnostic_counter("processed_total", {"case": label})
                return

        logger.warning(
            "Unknown telemetry case: %s (tenant_id=%s, event_id=%s)",
            case,
            envelope.tenant_id,
            envelope.event_id,
        )

    def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool:
        """Check if this event has already been processed.

        Uses Redis with TTL for deduplication. Returns True if duplicate,
        False if first time seeing this event.

        Args:
            envelope: The telemetry envelope to check.

        Returns:
            True if this event_id has been seen before, False otherwise.
        """
        dedup_key = f"telemetry:dedup:{envelope.tenant_id}:{envelope.event_id}"

        try:
            # Atomic set-if-not-exists with 1h TTL. A truthy reply means the
            # key was set (first time); any falsy reply (None, or False from
            # a client that reports it that way) means it already existed.
            was_set = redis_client.set(dedup_key, b"1", nx=True, ex=3600)
            return not was_set
        except Exception:
            # Fail open: if Redis is unavailable, process the event
            # (prefer occasional duplicate over lost data)
            logger.warning(
                "Redis unavailable for deduplication check, processing event anyway: %s",
                envelope.event_id,
                exc_info=True,
            )
            return False

    def _load_payload_from_storage(self, payload_ref: str, event_id: str) -> dict[str, Any]:
        """Load an offloaded JSON payload from object storage; ``{}`` on any failure."""
        try:
            payload_bytes = storage.load(payload_ref)
            loaded = json.loads(payload_bytes.decode("utf-8"))
        except Exception:
            logger.warning(
                "Failed to load payload from storage: key=%s, event_id=%s",
                payload_ref,
                event_id,
                exc_info=True,
            )
            return {}

        if not isinstance(loaded, dict):
            # Handlers call ``payload.get(...)``; anything but a JSON object
            # would crash them, so treat it as a failed load.
            logger.warning(
                "Ignoring non-object payload from storage: key=%s, event_id=%s",
                payload_ref,
                event_id,
            )
            return {}

        logger.debug("Loaded payload from storage: key=%s", payload_ref)
        return loaded

    def _rehydrate(self, envelope: TelemetryEnvelope) -> dict[str, Any]:
        """Rehydrate payload from storage reference or inline data.

        If the envelope payload is empty and metadata contains a
        ``payload_ref``, the full payload is loaded from object storage
        (stored there as JSON). When both the inline payload and storage
        resolution fail, a degraded-event marker is emitted so the gap is
        observable.

        Args:
            envelope: The telemetry envelope containing payload data.

        Returns:
            The rehydrated payload dictionary, or ``{}`` on total failure.
        """
        payload = envelope.payload

        # Resolve from object storage when the inline payload is empty.
        if not payload and envelope.metadata:
            payload_ref = envelope.metadata.get("payload_ref")
            if payload_ref:
                payload = self._load_payload_from_storage(payload_ref, envelope.event_id)

        if payload:
            return payload

        # Storage resolution failed or no data available — emit degraded event.
        logger.error(
            "Payload rehydration failed for event_id=%s, tenant_id=%s, case=%s",
            envelope.event_id,
            envelope.tenant_id,
            envelope.case,
        )
        from enterprise.telemetry.entities import EnterpriseTelemetryEvent
        from enterprise.telemetry.telemetry_log import emit_metric_only_event

        emit_metric_only_event(
            event_name=EnterpriseTelemetryEvent.REHYDRATION_FAILED,
            attributes={
                "dify.tenant_id": envelope.tenant_id,
                "dify.event_id": envelope.event_id,
                "dify.case": envelope.case,
                "rehydration_failed": True,
            },
            tenant_id=envelope.tenant_id,
        )
        self._increment_diagnostic_counter("rehydration_failed_total")
        return {}

    # Per-case handlers. App lifecycle and feedback events are emitted here;
    # the remaining cases are intentional no-ops (see their docstrings).

    def _emit_app_lifecycle_event(
        self,
        envelope: TelemetryEnvelope,
        *,
        log_label: str,
        event_name: Any,
        counter: Any,
        timestamp_key: str,
        include_mode: bool = False,
    ) -> None:
        """Emit the structured log and counter shared by the app created/updated/deleted events.

        Args:
            envelope: The telemetry envelope being processed.
            log_label: Case name used in the "no exporter" debug message.
            event_name: ``EnterpriseTelemetryEvent`` member for the log line.
            counter: ``EnterpriseTelemetryCounter`` member to increment.
            timestamp_key: Attribute key that receives the current UTC time.
            include_mode: Whether to add the app ``mode`` to attributes and labels.
        """
        from enterprise.telemetry.telemetry_log import emit_metric_only_event
        from extensions.ext_enterprise_telemetry import get_enterprise_exporter

        exporter = get_enterprise_exporter()
        if not exporter:
            logger.debug("No exporter available for %s: event_id=%s", log_label, envelope.event_id)
            return

        payload = self._rehydrate(envelope)
        if not payload:
            return

        attrs: dict[str, Any] = {
            "dify.app.id": payload.get("app_id"),
            "dify.tenant_id": envelope.tenant_id,
            "dify.event.id": envelope.event_id,
        }
        labels: dict[str, Any] = {
            "tenant_id": envelope.tenant_id,
            "app_id": str(payload.get("app_id", "")),
        }
        if include_mode:
            attrs["dify.app.mode"] = payload.get("mode")
            labels["mode"] = str(payload.get("mode", ""))
        attrs[timestamp_key] = datetime.now(UTC).isoformat()

        emit_metric_only_event(
            event_name=event_name,
            attributes=attrs,
            tenant_id=envelope.tenant_id,
        )
        exporter.increment_counter(counter, 1, labels)

    def _on_app_created(self, envelope: TelemetryEnvelope) -> None:
        """Handle app created event."""
        from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent

        self._emit_app_lifecycle_event(
            envelope,
            log_label="APP_CREATED",
            event_name=EnterpriseTelemetryEvent.APP_CREATED,
            counter=EnterpriseTelemetryCounter.APP_CREATED,
            timestamp_key="dify.app.created_at",
            include_mode=True,
        )

    def _on_app_updated(self, envelope: TelemetryEnvelope) -> None:
        """Handle app updated event."""
        from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent

        self._emit_app_lifecycle_event(
            envelope,
            log_label="APP_UPDATED",
            event_name=EnterpriseTelemetryEvent.APP_UPDATED,
            counter=EnterpriseTelemetryCounter.APP_UPDATED,
            timestamp_key="dify.app.updated_at",
        )

    def _on_app_deleted(self, envelope: TelemetryEnvelope) -> None:
        """Handle app deleted event."""
        from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent

        self._emit_app_lifecycle_event(
            envelope,
            log_label="APP_DELETED",
            event_name=EnterpriseTelemetryEvent.APP_DELETED,
            counter=EnterpriseTelemetryCounter.APP_DELETED,
            timestamp_key="dify.app.deleted_at",
        )

    def _on_feedback_created(self, envelope: TelemetryEnvelope) -> None:
        """Handle feedback created event.

        Emits a structured log line and increments the feedback counter.
        The feedback text is attached only when the exporter's
        ``include_content`` flag is set.
        """
        from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
        from enterprise.telemetry.telemetry_log import emit_metric_only_event
        from extensions.ext_enterprise_telemetry import get_enterprise_exporter

        exporter = get_enterprise_exporter()
        if not exporter:
            logger.debug("No exporter available for FEEDBACK_CREATED: event_id=%s", envelope.event_id)
            return

        payload = self._rehydrate(envelope)
        if not payload:
            return

        # End-user id wins over account id; computed once and reused below.
        user_id = payload.get("from_end_user_id") or payload.get("from_account_id")

        attrs: dict[str, Any] = {
            "dify.message.id": payload.get("message_id"),
            "dify.tenant_id": envelope.tenant_id,
            "dify.event.id": envelope.event_id,
            "dify.app_id": payload.get("app_id"),
            "dify.conversation.id": payload.get("conversation_id"),
            "gen_ai.user.id": user_id,
            "dify.feedback.rating": payload.get("rating"),
            "dify.feedback.from_source": payload.get("from_source"),
            "dify.feedback.created_at": datetime.now(UTC).isoformat(),
        }
        if exporter.include_content:
            attrs["dify.feedback.content"] = payload.get("content")

        emit_metric_only_event(
            event_name=EnterpriseTelemetryEvent.FEEDBACK_CREATED,
            attributes=attrs,
            tenant_id=envelope.tenant_id,
            user_id=str(user_id or ""),
        )
        exporter.increment_counter(
            EnterpriseTelemetryCounter.FEEDBACK,
            1,
            {
                "tenant_id": envelope.tenant_id,
                "app_id": str(payload.get("app_id", "")),
                "rating": str(payload.get("rating", "")),
            },
        )

    def _on_message_run(self, envelope: TelemetryEnvelope) -> None:
        """Handle message run event.

        Intentionally a no-op: metrics and structured logs for message runs are
        emitted directly by EnterpriseOtelTrace._message_trace at trace time,
        not through the metric handler queue path.
        """
        logger.debug("Processing MESSAGE_RUN: event_id=%s", envelope.event_id)

    def _on_tool_execution(self, envelope: TelemetryEnvelope) -> None:
        """Handle tool execution event.

        Intentionally a no-op: metrics and structured logs for tool executions
        are emitted directly by EnterpriseOtelTrace._tool_trace at trace time,
        not through the metric handler queue path.
        """
        logger.debug("Processing TOOL_EXECUTION: event_id=%s", envelope.event_id)

    def _on_moderation_check(self, envelope: TelemetryEnvelope) -> None:
        """Handle moderation check event.

        Intentionally a no-op: metrics and structured logs for moderation checks
        are emitted directly by EnterpriseOtelTrace._moderation_trace at trace time,
        not through the metric handler queue path.
        """
        logger.debug("Processing MODERATION_CHECK: event_id=%s", envelope.event_id)

    def _on_suggested_question(self, envelope: TelemetryEnvelope) -> None:
        """Handle suggested question event.

        Intentionally a no-op: metrics and structured logs for suggested questions
        are emitted directly by EnterpriseOtelTrace._suggested_question_trace at
        trace time, not through the metric handler queue path.
        """
        logger.debug("Processing SUGGESTED_QUESTION: event_id=%s", envelope.event_id)

    def _on_dataset_retrieval(self, envelope: TelemetryEnvelope) -> None:
        """Handle dataset retrieval event.

        Intentionally a no-op: metrics and structured logs for dataset retrievals
        are emitted directly by EnterpriseOtelTrace._dataset_retrieval_trace at
        trace time, not through the metric handler queue path.
        """
        logger.debug("Processing DATASET_RETRIEVAL: event_id=%s", envelope.event_id)

    def _on_generate_name(self, envelope: TelemetryEnvelope) -> None:
        """Handle generate name event.

        Intentionally a no-op: metrics and structured logs for generate name
        operations are emitted directly by EnterpriseOtelTrace._generate_name_trace
        at trace time, not through the metric handler queue path.
        """
        logger.debug("Processing GENERATE_NAME: event_id=%s", envelope.event_id)

    def _on_prompt_generation(self, envelope: TelemetryEnvelope) -> None:
        """Handle prompt generation event.

        Intentionally a no-op: metrics and structured logs for prompt generation
        operations are emitted directly by EnterpriseOtelTrace._prompt_generation_trace
        at trace time, not through the metric handler queue path.
        """
        logger.debug("Processing PROMPT_GENERATION: event_id=%s", envelope.event_id)
|
||||
|
|
@ -0,0 +1,122 @@
|
|||
"""Structured-log emitter for enterprise telemetry events.
|
||||
|
||||
Emits structured JSON log lines correlated with OTEL traces via trace_id.
|
||||
Picked up by ``StructuredJSONFormatter`` → stdout/Loki/Elastic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
|
||||
|
||||
logger = logging.getLogger("dify.telemetry")
|
||||
|
||||
|
||||
@lru_cache(maxsize=4096)
|
||||
def compute_trace_id_hex(uuid_str: str | None) -> str:
|
||||
"""Convert a business UUID string to a 32-hex OTEL-compatible trace_id.
|
||||
|
||||
Returns empty string when *uuid_str* is ``None`` or invalid.
|
||||
"""
|
||||
if not uuid_str:
|
||||
return ""
|
||||
normalized = uuid_str.strip().lower()
|
||||
if len(normalized) == 32 and all(ch in "0123456789abcdef" for ch in normalized):
|
||||
return normalized
|
||||
try:
|
||||
return f"{uuid.UUID(normalized).int:032x}"
|
||||
except (ValueError, AttributeError):
|
||||
return ""
|
||||
|
||||
|
||||
@lru_cache(maxsize=4096)
|
||||
def compute_span_id_hex(uuid_str: str | None) -> str:
|
||||
if not uuid_str:
|
||||
return ""
|
||||
normalized = uuid_str.strip().lower()
|
||||
if len(normalized) == 16 and all(ch in "0123456789abcdef" for ch in normalized):
|
||||
return normalized
|
||||
try:
|
||||
from enterprise.telemetry.id_generator import compute_deterministic_span_id
|
||||
|
||||
return f"{compute_deterministic_span_id(normalized):016x}"
|
||||
except (ValueError, AttributeError):
|
||||
return ""
|
||||
|
||||
|
||||
def emit_telemetry_log(
    *,
    event_name: str | EnterpriseTelemetryEvent,
    attributes: dict[str, Any],
    signal: str = "metric_only",
    trace_id_source: str | None = None,
    span_id_source: str | None = None,
    tenant_id: str | None = None,
    user_id: str | None = None,
) -> None:
    """Write one structured INFO log record describing a telemetry event.

    Nothing is emitted when INFO is disabled for the telemetry logger.

    Parameters
    ----------
    event_name:
        Canonical event name, e.g. ``"dify.workflow.run"``.
    attributes:
        Event-specific attributes, already built by the caller. They are
        merged after ``dify.event.name`` / ``dify.event.signal`` and so
        override those keys on collision.
    signal:
        ``"metric_only"`` for events with no span, ``"span_detail"``
        for detail logs accompanying a slim span.
    trace_id_source:
        A UUID string (e.g. ``workflow_run_id``) converted to a 32-hex
        trace_id for cross-signal correlation.
    span_id_source:
        A UUID string converted to a 16-hex span_id.
    tenant_id:
        Tenant identifier (for the ``IdentityContextFilter``).
    user_id:
        User identifier (for the ``IdentityContextFilter``).

    The ``trace_id``, ``span_id``, ``tenant_id`` and ``user_id`` fields are
    attached to the record only when they are non-empty.
    """
    if not logger.isEnabledFor(logging.INFO):
        return

    record_extra: dict[str, Any] = {
        "attributes": {
            "dify.event.name": event_name,
            "dify.event.signal": signal,
            **attributes,
        }
    }

    optional_fields = (
        ("trace_id", compute_trace_id_hex(trace_id_source)),
        ("span_id", compute_span_id_hex(span_id_source)),
        ("tenant_id", tenant_id),
        ("user_id", user_id),
    )
    for field_name, field_value in optional_fields:
        if field_value:
            record_extra[field_name] = field_value

    logger.info("telemetry.%s", signal, extra=record_extra)
|
||||
|
||||
|
||||
def emit_metric_only_event(
    *,
    event_name: str | EnterpriseTelemetryEvent,
    attributes: dict[str, Any],
    trace_id_source: str | None = None,
    span_id_source: str | None = None,
    tenant_id: str | None = None,
    user_id: str | None = None,
) -> None:
    """Emit a telemetry log with the signal fixed to ``"metric_only"``.

    All other arguments are forwarded unchanged to ``emit_telemetry_log``.
    """
    correlation = {
        "trace_id_source": trace_id_source,
        "span_id_source": span_id_source,
    }
    identity = {
        "tenant_id": tenant_id,
        "user_id": user_id,
    }
    emit_telemetry_log(
        event_name=event_name,
        attributes=attributes,
        signal="metric_only",
        **correlation,
        **identity,
    )
|
||||
|
|
@ -204,6 +204,8 @@ def init_app(app: DifyApp) -> Celery:
|
|||
"schedule": timedelta(minutes=dify_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL),
|
||||
}
|
||||
|
||||
if dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED:
|
||||
imports.append("tasks.enterprise_telemetry_task")
|
||||
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
|
||||
|
||||
return celery_app
|
||||
|
|
|
|||
|
|
@ -0,0 +1,50 @@
|
|||
"""Flask extension for enterprise telemetry lifecycle management.
|
||||
|
||||
Initializes the EnterpriseExporter singleton during ``create_app()``
|
||||
(single-threaded), registers blinker event handlers, and hooks atexit
|
||||
for graceful shutdown.
|
||||
|
||||
Skipped entirely unless both ``ENTERPRISE_ENABLED`` and ``ENTERPRISE_TELEMETRY_ENABLED``
|
||||
are true (``is_enabled()`` gate).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_app import DifyApp
|
||||
from enterprise.telemetry.exporter import EnterpriseExporter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_exporter: EnterpriseExporter | None = None
|
||||
|
||||
|
||||
def is_enabled() -> bool:
    """Return True only when enterprise mode and enterprise telemetry are both on."""
    if not dify_config.ENTERPRISE_ENABLED:
        return False
    return bool(dify_config.ENTERPRISE_TELEMETRY_ENABLED)
|
||||
|
||||
|
||||
def init_app(app: DifyApp) -> None:
    """Create the module-level exporter and wire up enterprise telemetry.

    Does nothing unless ``is_enabled()`` is true. Otherwise it builds the
    ``EnterpriseExporter`` from ``dify_config``, stores it for
    ``get_enterprise_exporter()``, registers its ``shutdown`` with
    ``atexit``, and imports the event-handler module for its import-time
    side effects.
    """
    global _exporter

    if is_enabled():
        from enterprise.telemetry.exporter import EnterpriseExporter

        exporter = EnterpriseExporter(dify_config)
        _exporter = exporter
        atexit.register(exporter.shutdown)

        # Import to trigger @signal.connect decorator registration
        import enterprise.telemetry.event_handlers  # noqa: F401 # type: ignore[reportUnusedImport]

        logger.info("Enterprise telemetry initialized")
|
||||
|
||||
|
||||
def get_enterprise_exporter() -> EnterpriseExporter | None:
    """Return the exporter created by ``init_app``, or ``None`` if none was created."""
    exporter = _exporter
    return exporter
|
||||
|
|
@ -78,16 +78,24 @@ def init_app(app: DifyApp):
|
|||
protocol = (dify_config.OTEL_EXPORTER_OTLP_PROTOCOL or "").lower()
|
||||
if dify_config.OTEL_EXPORTER_TYPE == "otlp":
|
||||
if protocol == "grpc":
|
||||
# Auto-detect TLS: https:// uses secure, everything else is insecure
|
||||
endpoint = dify_config.OTLP_BASE_ENDPOINT
|
||||
insecure = not endpoint.startswith("https://")
|
||||
|
||||
exporter = GRPCSpanExporter(
|
||||
endpoint=dify_config.OTLP_BASE_ENDPOINT,
|
||||
endpoint=endpoint,
|
||||
# Header field names must consist of lowercase letters, check RFC7540
|
||||
headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),),
|
||||
insecure=True,
|
||||
headers=(
|
||||
(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),) if dify_config.OTLP_API_KEY else None
|
||||
),
|
||||
insecure=insecure,
|
||||
)
|
||||
metric_exporter = GRPCMetricExporter(
|
||||
endpoint=dify_config.OTLP_BASE_ENDPOINT,
|
||||
headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),),
|
||||
insecure=True,
|
||||
endpoint=endpoint,
|
||||
headers=(
|
||||
(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),) if dify_config.OTLP_API_KEY else None
|
||||
),
|
||||
insecure=insecure,
|
||||
)
|
||||
else:
|
||||
headers = {"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"} if dify_config.OTLP_API_KEY else None
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ This module provides parsers that extract node-specific metadata and set
|
|||
OpenTelemetry span attributes according to semantic conventions.
|
||||
"""
|
||||
|
||||
from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps
|
||||
from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps, should_include_content
|
||||
from extensions.otel.parser.llm import LLMNodeOTelParser
|
||||
from extensions.otel.parser.retrieval import RetrievalNodeOTelParser
|
||||
from extensions.otel.parser.tool import ToolNodeOTelParser
|
||||
|
|
@ -17,4 +17,5 @@ __all__ = [
|
|||
"RetrievalNodeOTelParser",
|
||||
"ToolNodeOTelParser",
|
||||
"safe_json_dumps",
|
||||
"should_include_content",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,10 @@
|
|||
"""
|
||||
Base parser interface and utilities for OpenTelemetry node parsers.
|
||||
|
||||
Content gating: ``should_include_content()`` controls whether content-bearing
|
||||
span attributes (inputs, outputs, prompts, completions, documents) are written.
|
||||
Gate is only active in EE (``ENTERPRISE_ENABLED=True``) when
|
||||
``ENTERPRISE_INCLUDE_CONTENT=False``; CE behaviour is unchanged.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
|
@ -9,6 +14,7 @@ from opentelemetry.trace import Span
|
|||
from opentelemetry.trace.status import Status, StatusCode
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from dify_graph.enums import BuiltinNodeTypes
|
||||
from dify_graph.file.models import File
|
||||
from dify_graph.graph_events import GraphNodeEventBase
|
||||
|
|
@ -17,6 +23,17 @@ from dify_graph.variables import Segment
|
|||
from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes
|
||||
|
||||
|
||||
def should_include_content() -> bool:
    """Tell whether content-bearing attributes may be written to spans.

    EE (ENTERPRISE_ENABLED=True): returns ENTERPRISE_INCLUDE_CONTENT.
    CE (ENTERPRISE_ENABLED=False): always True, so CE behaviour is unchanged.
    """
    if dify_config.ENTERPRISE_ENABLED:
        return dify_config.ENTERPRISE_INCLUDE_CONTENT
    return True
|
||||
|
||||
|
||||
def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str:
|
||||
"""
|
||||
Safely serialize objects to JSON, handling non-serializable types.
|
||||
|
|
@ -101,10 +118,11 @@ class DefaultNodeOTelParser:
|
|||
# Extract inputs and outputs from result_event
|
||||
if result_event and result_event.node_run_result:
|
||||
node_run_result = result_event.node_run_result
|
||||
if node_run_result.inputs:
|
||||
span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs))
|
||||
if node_run_result.outputs:
|
||||
span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs))
|
||||
if should_include_content():
|
||||
if node_run_result.inputs:
|
||||
span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs))
|
||||
if node_run_result.outputs:
|
||||
span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs))
|
||||
|
||||
if error:
|
||||
span.record_exception(error)
|
||||
|
|
|
|||
|
|
@ -21,3 +21,15 @@ class DifySpanAttributes:
|
|||
|
||||
INVOKE_FROM = "dify.invoke_from"
|
||||
"""Invocation source, e.g. SERVICE_API, WEB_APP, DEBUGGER."""
|
||||
|
||||
INVOKED_BY = "dify.invoked_by"
|
||||
"""Invoked by, e.g. end_user, account, user."""
|
||||
|
||||
USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
|
||||
"""Number of input tokens (prompt tokens) used."""
|
||||
|
||||
USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
|
||||
"""Number of output tokens (completion tokens) generated."""
|
||||
|
||||
USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
|
||||
"""Total number of tokens used."""
|
||||
|
|
|
|||
|
|
@ -1,16 +1,19 @@
|
|||
import logging
|
||||
import sys
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
from typing import NotRequired
|
||||
|
||||
import httpx
|
||||
from pydantic import TypeAdapter
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
JsonObject = dict[str, object]
|
||||
JsonObjectList = list[JsonObject]
|
||||
|
||||
|
|
@ -30,8 +33,8 @@ class GitHubEmailRecord(TypedDict, total=False):
|
|||
class GitHubRawUserInfo(TypedDict):
|
||||
id: int | str
|
||||
login: str
|
||||
name: NotRequired[str]
|
||||
email: NotRequired[str]
|
||||
name: NotRequired[str | None]
|
||||
email: NotRequired[str | None]
|
||||
|
||||
|
||||
class GoogleRawUserInfo(TypedDict):
|
||||
|
|
@ -127,9 +130,14 @@ class GitHubOAuth(OAuth):
|
|||
response.raise_for_status()
|
||||
user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response))
|
||||
|
||||
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
|
||||
email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
|
||||
primary_email = next((email for email in email_info if email.get("primary") is True), None)
|
||||
try:
|
||||
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
|
||||
email_response.raise_for_status()
|
||||
email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
|
||||
primary_email = next((email for email in email_info if email.get("primary") is True), None)
|
||||
except (httpx.HTTPStatusError, ValidationError):
|
||||
logger.warning("Failed to retrieve email from GitHub /user/emails endpoint", exc_info=True)
|
||||
primary_email = None
|
||||
|
||||
return {**user_info, "email": primary_email.get("email", "") if primary_email else ""}
|
||||
|
||||
|
|
@ -137,8 +145,11 @@ class GitHubOAuth(OAuth):
|
|||
payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
|
||||
email = payload.get("email")
|
||||
if not email:
|
||||
email = f"{payload['id']}+{payload['login']}@users.noreply.github.com"
|
||||
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email)
|
||||
raise ValueError(
|
||||
'Dify currently not supports the "Keep my email addresses private" feature,'
|
||||
" please disable it and login again"
|
||||
)
|
||||
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name") or ""), email=email)
|
||||
|
||||
|
||||
class GoogleOAuth(OAuth):
|
||||
|
|
|
|||
|
|
@ -43,7 +43,9 @@ from .enums import (
|
|||
IndexingStatus,
|
||||
ProcessRuleMode,
|
||||
SegmentStatus,
|
||||
SegmentType,
|
||||
SummaryStatus,
|
||||
TidbAuthBindingStatus,
|
||||
)
|
||||
from .model import App, Tag, TagBinding, UploadFile
|
||||
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
|
||||
|
|
@ -494,7 +496,9 @@ class Document(Base):
|
|||
)
|
||||
doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True)
|
||||
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
|
||||
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
|
||||
doc_form: Mapped[IndexStructureType] = mapped_column(
|
||||
EnumText(IndexStructureType, length=255), nullable=False, server_default=sa.text("'text_model'")
|
||||
)
|
||||
doc_language = mapped_column(String(255), nullable=True)
|
||||
need_summary: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
|
||||
|
|
@ -998,7 +1002,9 @@ class ChildChunk(Base):
|
|||
# indexing fields
|
||||
index_node_id = mapped_column(String(255), nullable=True)
|
||||
index_node_hash = mapped_column(String(255), nullable=True)
|
||||
type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'"))
|
||||
type: Mapped[SegmentType] = mapped_column(
|
||||
EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'")
|
||||
)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
|
|
@ -1239,7 +1245,9 @@ class TidbAuthBinding(TypeBase):
|
|||
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'"))
|
||||
status: Mapped[TidbAuthBindingStatus] = mapped_column(
|
||||
EnumText(TidbAuthBindingStatus, length=255), nullable=False, server_default=sa.text("'CREATING'")
|
||||
)
|
||||
account: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
|
|
|||
|
|
@ -222,6 +222,13 @@ class DatasetMetadataType(StrEnum):
|
|||
TIME = "time"
|
||||
|
||||
|
||||
class SegmentType(StrEnum):
|
||||
"""Document segment type"""
|
||||
|
||||
AUTOMATIC = "automatic"
|
||||
CUSTOMIZED = "customized"
|
||||
|
||||
|
||||
class SegmentStatus(StrEnum):
|
||||
"""Document segment status"""
|
||||
|
||||
|
|
@ -323,3 +330,10 @@ class ProviderQuotaType(StrEnum):
|
|||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class ApiTokenType(StrEnum):
|
||||
"""API Token type"""
|
||||
|
||||
APP = "app"
|
||||
DATASET = "dataset"
|
||||
|
|
|
|||
|
|
@ -66,8 +66,8 @@ class HumanInputContent(ExecutionExtraContent):
|
|||
form_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
|
||||
|
||||
@classmethod
|
||||
def new(cls, form_id: str, message_id: str | None) -> "HumanInputContent":
|
||||
return cls(form_id=form_id, message_id=message_id)
|
||||
def new(cls, *, workflow_run_id: str, form_id: str, message_id: str | None) -> "HumanInputContent":
|
||||
return cls(workflow_run_id=workflow_run_id, form_id=form_id, message_id=message_id)
|
||||
|
||||
form: Mapped["HumanInputForm"] = relationship(
|
||||
"HumanInputForm",
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from configs import dify_config
|
|||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
from core.tools.signature import sign_tool_file
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.helper import generate_string # type: ignore[import-not-found]
|
||||
|
|
@ -31,6 +31,7 @@ from .account import Account, Tenant
|
|||
from .base import Base, TypeBase, gen_uuidv4_string
|
||||
from .engine import db
|
||||
from .enums import (
|
||||
ApiTokenType,
|
||||
AppMCPServerStatus,
|
||||
AppStatus,
|
||||
BannerStatus,
|
||||
|
|
@ -43,6 +44,8 @@ from .enums import (
|
|||
MessageChainType,
|
||||
MessageFileBelongsTo,
|
||||
MessageStatus,
|
||||
ProviderQuotaType,
|
||||
TagType,
|
||||
)
|
||||
from .provider_ids import GenericProviderID
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
|
@ -1782,7 +1785,7 @@ class MessageFile(TypeBase):
|
|||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||
)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
type: Mapped[FileType] = mapped_column(EnumText(FileType, length=255), nullable=False)
|
||||
transfer_method: Mapped[FileTransferMethod] = mapped_column(
|
||||
EnumText(FileTransferMethod, length=255), nullable=False
|
||||
)
|
||||
|
|
@ -2094,7 +2097,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field.
|
|||
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
app_id = mapped_column(StringUUID, nullable=True)
|
||||
tenant_id = mapped_column(StringUUID, nullable=True)
|
||||
type = mapped_column(String(16), nullable=False)
|
||||
type: Mapped[ApiTokenType] = mapped_column(EnumText(ApiTokenType, length=16), nullable=False)
|
||||
token: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
last_used_at = mapped_column(sa.DateTime, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
@ -2404,7 +2407,7 @@ class Tag(TypeBase):
|
|||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||
)
|
||||
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
type: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
type: Mapped[TagType] = mapped_column(EnumText(TagType, length=16), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
|
@ -2489,7 +2492,9 @@ class TenantCreditPool(TypeBase):
|
|||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||
)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
|
||||
pool_type: Mapped[ProviderQuotaType] = mapped_column(
|
||||
EnumText(ProviderQuotaType, length=40), nullable=False, default=ProviderQuotaType.TRIAL, server_default="trial"
|
||||
)
|
||||
quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
|
|
|||
|
|
@ -13,12 +13,16 @@ from sqlalchemy.orm import Mapped, mapped_column
|
|||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderSchemaType,
|
||||
ToolProviderType,
|
||||
WorkflowToolParameterConfiguration,
|
||||
)
|
||||
|
||||
from .base import TypeBase
|
||||
from .engine import db
|
||||
from .model import Account, App, Tenant
|
||||
from .types import LongText, StringUUID
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
|
|
@ -141,7 +145,9 @@ class ApiToolProvider(TypeBase):
|
|||
icon: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# original schema
|
||||
schema: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
schema_type_str: Mapped[ApiProviderSchemaType] = mapped_column(
|
||||
EnumText(ApiProviderSchemaType, length=40), nullable=False
|
||||
)
|
||||
# who created this tool
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# tenant id
|
||||
|
|
@ -208,7 +214,7 @@ class ToolLabelBinding(TypeBase):
|
|||
# tool id
|
||||
tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
# tool type
|
||||
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
|
||||
# label name
|
||||
label_name: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
|
||||
|
|
@ -386,7 +392,7 @@ class ToolModelInvoke(TypeBase):
|
|||
# provider
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# type
|
||||
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
|
||||
# tool name
|
||||
tool_name: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
# invoke parameters
|
||||
|
|
|
|||
|
|
@ -1221,7 +1221,9 @@ class WorkflowAppLog(TypeBase):
|
|||
app_id: Mapped[str] = mapped_column(StringUUID)
|
||||
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
workflow_run_id: Mapped[str] = mapped_column(StringUUID)
|
||||
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_from: Mapped[WorkflowAppLogCreatedFrom] = mapped_column(
|
||||
EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=False
|
||||
)
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
|
@ -1301,10 +1303,14 @@ class WorkflowArchiveLog(TypeBase):
|
|||
|
||||
log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
log_created_from: Mapped[WorkflowAppLogCreatedFrom | None] = mapped_column(
|
||||
EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=True
|
||||
)
|
||||
|
||||
run_version: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
run_status: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
run_status: Mapped[WorkflowExecutionStatus] = mapped_column(
|
||||
EnumText(WorkflowExecutionStatus, length=255), nullable=False
|
||||
)
|
||||
run_triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column(
|
||||
EnumText(WorkflowRunTriggeredFrom, length=255), nullable=False
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ dependencies = [
|
|||
"arize-phoenix-otel~=0.15.0",
|
||||
"azure-identity==1.25.3",
|
||||
"beautifulsoup4==4.14.3",
|
||||
"boto3==1.42.68",
|
||||
"boto3==1.42.73",
|
||||
"bs4~=0.0.1",
|
||||
"cachetools~=5.3.0",
|
||||
"celery~=5.6.2",
|
||||
|
|
@ -23,7 +23,7 @@ dependencies = [
|
|||
"gevent~=25.9.1",
|
||||
"gmpy2~=2.3.0",
|
||||
"google-api-core>=2.19.1",
|
||||
"google-api-python-client==2.192.0",
|
||||
"google-api-python-client==2.193.0",
|
||||
"google-auth>=2.47.0",
|
||||
"google-auth-httplib2==0.3.0",
|
||||
"google-cloud-aiplatform>=1.123.0",
|
||||
|
|
@ -40,7 +40,7 @@ dependencies = [
|
|||
"numpy~=1.26.4",
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.10.37",
|
||||
"litellm==1.82.2", # Pinned to avoid madoka dependency issue
|
||||
"litellm==1.82.6", # Pinned to avoid madoka dependency issue
|
||||
"opentelemetry-api==1.28.0",
|
||||
"opentelemetry-distro==0.49b0",
|
||||
"opentelemetry-exporter-otlp==1.28.0",
|
||||
|
|
@ -72,13 +72,14 @@ dependencies = [
|
|||
"pyyaml~=6.0.1",
|
||||
"readabilipy~=0.3.0",
|
||||
"redis[hiredis]~=7.3.0",
|
||||
"resend~=2.23.0",
|
||||
"sentry-sdk[flask]~=2.54.0",
|
||||
"resend~=2.26.0",
|
||||
"sentry-sdk[flask]~=2.55.0",
|
||||
"sqlalchemy~=2.0.29",
|
||||
"starlette==0.52.1",
|
||||
"starlette==1.0.0",
|
||||
"tiktoken~=0.12.0",
|
||||
"transformers~=5.3.0",
|
||||
"unstructured[docx,epub,md,ppt,pptx]~=0.21.5",
|
||||
"pypandoc~=1.13",
|
||||
"yarl~=1.23.0",
|
||||
"webvtt-py~=0.5.1",
|
||||
"sseclient-py~=1.9.0",
|
||||
|
|
@ -91,7 +92,7 @@ dependencies = [
|
|||
"apscheduler>=3.11.0",
|
||||
"weave>=0.52.16",
|
||||
"fastopenapi[flask]>=0.7.0",
|
||||
"bleach~=6.2.0",
|
||||
"bleach~=6.3.0",
|
||||
]
|
||||
# Before adding new dependency, consider place it in
|
||||
# alphabet order (a-z) and suitable group.
|
||||
|
|
@ -118,7 +119,7 @@ dev = [
|
|||
"ruff~=0.15.5",
|
||||
"pytest~=9.0.2",
|
||||
"pytest-benchmark~=5.2.3",
|
||||
"pytest-cov~=7.0.0",
|
||||
"pytest-cov~=7.1.0",
|
||||
"pytest-env~=1.6.0",
|
||||
"pytest-mock~=3.15.1",
|
||||
"testcontainers~=4.14.1",
|
||||
|
|
@ -202,7 +203,7 @@ tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"]
|
|||
# Required by vector store clients
|
||||
############################################################
|
||||
vdb = [
|
||||
"alibabacloud_gpdb20160503~=3.8.0",
|
||||
"alibabacloud_gpdb20160503~=5.1.0",
|
||||
"alibabacloud_tea_openapi~=0.4.3",
|
||||
"chromadb==0.5.20",
|
||||
"clickhouse-connect~=0.14.1",
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from configs import dify_config
|
|||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
|
||||
@app.celery.task(queue="dataset")
|
||||
|
|
@ -57,7 +58,7 @@ def create_clusters(batch_size):
|
|||
account=new_cluster["account"],
|
||||
password=new_cluster["password"],
|
||||
active=False,
|
||||
status="CREATING",
|
||||
status=TidbAuthBindingStatus.CREATING,
|
||||
)
|
||||
db.session.add(tidb_auth_binding)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from configs import dify_config
|
|||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
|
||||
@app.celery.task(queue="dataset")
|
||||
|
|
@ -18,7 +19,10 @@ def update_tidb_serverless_status_task():
|
|||
try:
|
||||
# check the number of idle tidb serverless
|
||||
tidb_serverless_list = db.session.scalars(
|
||||
select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
|
||||
select(TidbAuthBinding).where(
|
||||
TidbAuthBinding.active == False,
|
||||
TidbAuthBinding.status == TidbAuthBindingStatus.CREATING,
|
||||
)
|
||||
).all()
|
||||
if len(tidb_serverless_list) == 0:
|
||||
return
|
||||
|
|
|
|||
|
|
@ -241,7 +241,7 @@ class AppService:
|
|||
class ArgsDict(TypedDict):
|
||||
name: str
|
||||
description: str
|
||||
icon_type: str
|
||||
icon_type: IconType | str | None
|
||||
icon: str
|
||||
icon_background: str
|
||||
use_icon_as_answer_icon: bool
|
||||
|
|
@ -257,7 +257,13 @@ class AppService:
|
|||
assert current_user is not None
|
||||
app.name = args["name"]
|
||||
app.description = args["description"]
|
||||
app.icon_type = IconType(args["icon_type"]) if args["icon_type"] else None
|
||||
icon_type = args.get("icon_type")
|
||||
if icon_type is None:
|
||||
resolved_icon_type = app.icon_type
|
||||
else:
|
||||
resolved_icon_type = IconType(icon_type)
|
||||
|
||||
app.icon_type = resolved_icon_type
|
||||
app.icon = args["icon"]
|
||||
app.icon_background = args["icon_background"]
|
||||
app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,16 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class AuthCredentials(TypedDict):
|
||||
auth_type: str
|
||||
config: dict[str, Any]
|
||||
|
||||
|
||||
class ApiKeyAuthBase(ABC):
|
||||
def __init__(self, credentials: dict):
|
||||
def __init__(self, credentials: AuthCredentials):
|
||||
self.credentials = credentials
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
|
||||
from services.auth.auth_type import AuthType
|
||||
|
||||
|
||||
class ApiKeyAuthFactory:
|
||||
def __init__(self, provider: str, credentials: dict):
|
||||
def __init__(self, provider: str, credentials: AuthCredentials):
|
||||
auth_factory = self.get_apikey_auth_factory(provider)
|
||||
self.auth = auth_factory(credentials)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ import json
|
|||
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
|
||||
|
||||
|
||||
class FirecrawlAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
def __init__(self, credentials: AuthCredentials):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ import json
|
|||
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
|
||||
|
||||
|
||||
class JinaAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
def __init__(self, credentials: AuthCredentials):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ import json
|
|||
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
|
||||
|
||||
|
||||
class JinaAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
def __init__(self, credentials: AuthCredentials):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
|
|
|
|||
|
|
@ -3,11 +3,11 @@ from urllib.parse import urljoin
|
|||
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
|
||||
|
||||
|
||||
class WatercrawlAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
def __init__(self, credentials: AuthCredentials):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "x-api-key":
|
||||
|
|
|
|||
|
|
@ -335,7 +335,11 @@ class BillingService:
|
|||
# Redis returns bytes, decode to string and parse JSON
|
||||
json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value
|
||||
plan_dict = json.loads(json_str)
|
||||
# NOTE (hj24): New billing versions may return timestamp as str, and validate_python
|
||||
# in non-strict mode will coerce it to the expected int type.
|
||||
# To preserve compatibility, always keep non-strict mode here and avoid strict mode.
|
||||
subscription_plan = subscription_adapter.validate_python(plan_dict)
|
||||
# NOTE END
|
||||
tenant_plans[tenant_id] = subscription_plan
|
||||
except Exception:
|
||||
logger.exception(
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from configs import dify_config
|
|||
from core.errors.error import QuotaExceededError
|
||||
from extensions.ext_database import db
|
||||
from models import TenantCreditPool
|
||||
from models.enums import ProviderQuotaType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -16,7 +17,10 @@ class CreditPoolService:
|
|||
def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
|
||||
"""create default credit pool for new tenant"""
|
||||
credit_pool = TenantCreditPool(
|
||||
tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial"
|
||||
tenant_id=tenant_id,
|
||||
quota_limit=dify_config.HOSTED_POOL_CREDITS,
|
||||
quota_used=0,
|
||||
pool_type=ProviderQuotaType.TRIAL,
|
||||
)
|
||||
db.session.add(credit_pool)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ from models.enums import (
|
|||
IndexingStatus,
|
||||
ProcessRuleMode,
|
||||
SegmentStatus,
|
||||
SegmentType,
|
||||
)
|
||||
from models.model import UploadFile
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
|
@ -1439,7 +1440,7 @@ class DocumentService:
|
|||
.filter(
|
||||
Document.id.in_(document_id_list),
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.doc_form != "qa_model", # Skip qa_model documents
|
||||
Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
|
||||
)
|
||||
.update({Document.need_summary: need_summary}, synchronize_session=False)
|
||||
)
|
||||
|
|
@ -2039,7 +2040,7 @@ class DocumentService:
|
|||
document.dataset_process_rule_id = dataset_process_rule.id
|
||||
document.updated_at = naive_utc_now()
|
||||
document.created_from = created_from
|
||||
document.doc_form = knowledge_config.doc_form
|
||||
document.doc_form = IndexStructureType(knowledge_config.doc_form)
|
||||
document.doc_language = knowledge_config.doc_language
|
||||
document.data_source_info = json.dumps(data_source_info)
|
||||
document.batch = batch
|
||||
|
|
@ -2639,7 +2640,7 @@ class DocumentService:
|
|||
document.splitting_completed_at = None
|
||||
document.updated_at = naive_utc_now()
|
||||
document.created_from = created_from
|
||||
document.doc_form = document_data.doc_form
|
||||
document.doc_form = IndexStructureType(document_data.doc_form)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
# update document segment
|
||||
|
|
@ -3100,7 +3101,7 @@ class DocumentService:
|
|||
class SegmentService:
|
||||
@classmethod
|
||||
def segment_create_args_validate(cls, args: dict, document: Document):
|
||||
if document.doc_form == "qa_model":
|
||||
if document.doc_form == IndexStructureType.QA_INDEX:
|
||||
if "answer" not in args or not args["answer"]:
|
||||
raise ValueError("Answer is required")
|
||||
if not args["answer"].strip():
|
||||
|
|
@ -3157,7 +3158,7 @@ class SegmentService:
|
|||
completed_at=naive_utc_now(),
|
||||
created_by=current_user.id,
|
||||
)
|
||||
if document.doc_form == "qa_model":
|
||||
if document.doc_form == IndexStructureType.QA_INDEX:
|
||||
segment_document.word_count += len(args["answer"])
|
||||
segment_document.answer = args["answer"]
|
||||
|
||||
|
|
@ -3231,7 +3232,7 @@ class SegmentService:
|
|||
tokens = 0
|
||||
if dataset.indexing_technique == "high_quality" and embedding_model:
|
||||
# calc embedding use tokens
|
||||
if document.doc_form == "qa_model":
|
||||
if document.doc_form == IndexStructureType.QA_INDEX:
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(
|
||||
texts=[content + segment_item["answer"]]
|
||||
)[0]
|
||||
|
|
@ -3254,7 +3255,7 @@ class SegmentService:
|
|||
completed_at=naive_utc_now(),
|
||||
created_by=current_user.id,
|
||||
)
|
||||
if document.doc_form == "qa_model":
|
||||
if document.doc_form == IndexStructureType.QA_INDEX:
|
||||
segment_document.answer = segment_item["answer"]
|
||||
segment_document.word_count += len(segment_item["answer"])
|
||||
increment_word_count += segment_document.word_count
|
||||
|
|
@ -3321,7 +3322,7 @@ class SegmentService:
|
|||
content = args.content or segment.content
|
||||
if segment.content == content:
|
||||
segment.word_count = len(content)
|
||||
if document.doc_form == "qa_model":
|
||||
if document.doc_form == IndexStructureType.QA_INDEX:
|
||||
segment.answer = args.answer
|
||||
segment.word_count += len(args.answer) if args.answer else 0
|
||||
word_count_change = segment.word_count - word_count_change
|
||||
|
|
@ -3418,7 +3419,7 @@ class SegmentService:
|
|||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
if document.doc_form == "qa_model":
|
||||
if document.doc_form == IndexStructureType.QA_INDEX:
|
||||
segment.answer = args.answer
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0] # type: ignore
|
||||
else:
|
||||
|
|
@ -3435,7 +3436,7 @@ class SegmentService:
|
|||
segment.enabled = True
|
||||
segment.disabled_at = None
|
||||
segment.disabled_by = None
|
||||
if document.doc_form == "qa_model":
|
||||
if document.doc_form == IndexStructureType.QA_INDEX:
|
||||
segment.answer = args.answer
|
||||
segment.word_count += len(args.answer) if args.answer else 0
|
||||
word_count_change = segment.word_count - word_count_change
|
||||
|
|
@ -3786,7 +3787,7 @@ class SegmentService:
|
|||
child_chunk.word_count = len(child_chunk.content)
|
||||
child_chunk.updated_by = current_user.id
|
||||
child_chunk.updated_at = naive_utc_now()
|
||||
child_chunk.type = "customized"
|
||||
child_chunk.type = SegmentType.CUSTOMIZED
|
||||
update_child_chunks.append(child_chunk)
|
||||
else:
|
||||
new_child_chunks_args.append(child_chunk_update_args)
|
||||
|
|
@ -3845,7 +3846,7 @@ class SegmentService:
|
|||
child_chunk.word_count = len(content)
|
||||
child_chunk.updated_by = current_user.id
|
||||
child_chunk.updated_at = naive_utc_now()
|
||||
child_chunk.type = "customized"
|
||||
child_chunk.type = SegmentType.CUSTOMIZED
|
||||
db.session.add(child_chunk)
|
||||
VectorService.update_child_chunk_vector([], [child_chunk], [], dataset)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from flask_login import current_user
|
|||
|
||||
from constants import DOCUMENT_EXTENSIONS
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
|
|
@ -79,9 +80,9 @@ class RagPipelineTransformService:
|
|||
pipeline = self._create_pipeline(pipeline_yaml)
|
||||
|
||||
# save chunk structure to dataset
|
||||
if doc_form == "hierarchical_model":
|
||||
if doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
dataset.chunk_structure = "hierarchical_model"
|
||||
elif doc_form == "text_model":
|
||||
elif doc_form == IndexStructureType.PARAGRAPH_INDEX:
|
||||
dataset.chunk_structure = "text_model"
|
||||
else:
|
||||
raise ValueError("Unsupported doc form")
|
||||
|
|
@ -101,7 +102,7 @@ class RagPipelineTransformService:
|
|||
|
||||
def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None):
|
||||
pipeline_yaml = {}
|
||||
if doc_form == "text_model":
|
||||
if doc_form == IndexStructureType.PARAGRAPH_INDEX:
|
||||
match datasource_type:
|
||||
case DataSourceType.UPLOAD_FILE:
|
||||
if indexing_technique == "high_quality":
|
||||
|
|
@ -132,7 +133,7 @@ class RagPipelineTransformService:
|
|||
pipeline_yaml = yaml.safe_load(f)
|
||||
case _:
|
||||
raise ValueError("Unsupported datasource type")
|
||||
elif doc_form == "hierarchical_model":
|
||||
elif doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
match datasource_type:
|
||||
case DataSourceType.UPLOAD_FILE:
|
||||
# get graph from transform.file-parentchild.yml
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
|
|||
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.enums import TagType
|
||||
from models.model import App, Tag, TagBinding
|
||||
|
||||
|
||||
|
|
@ -83,7 +84,7 @@ class TagService:
|
|||
raise ValueError("Tag name already exists")
|
||||
tag = Tag(
|
||||
name=args["name"],
|
||||
type=args["type"],
|
||||
type=TagType(args["type"]),
|
||||
created_by=current_user.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from sqlalchemy import func
|
|||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
|
|
@ -109,7 +110,7 @@ def batch_create_segment_to_index_task(
|
|||
df = pd.read_csv(file_path)
|
||||
content = []
|
||||
for _, row in df.iterrows():
|
||||
if document_config["doc_form"] == "qa_model":
|
||||
if document_config["doc_form"] == IndexStructureType.QA_INDEX:
|
||||
data = {"content": row.iloc[0], "answer": row.iloc[1]}
|
||||
else:
|
||||
data = {"content": row.iloc[0]}
|
||||
|
|
@ -159,7 +160,7 @@ def batch_create_segment_to_index_task(
|
|||
status="completed",
|
||||
completed_at=naive_utc_now(),
|
||||
)
|
||||
if document_config["doc_form"] == "qa_model":
|
||||
if document_config["doc_form"] == IndexStructureType.QA_INDEX:
|
||||
segment_document.answer = segment["answer"]
|
||||
segment_document.word_count += len(segment["answer"])
|
||||
word_count_change += segment_document.word_count
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from configs import dify_config
|
|||
from core.db.session_factory import session_factory
|
||||
from core.entities.document_task import DocumentTask
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
|
@ -150,7 +151,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
|||
)
|
||||
if (
|
||||
document.indexing_status == IndexingStatus.COMPLETED
|
||||
and document.doc_form != "qa_model"
|
||||
and document.doc_form != IndexStructureType.QA_INDEX
|
||||
and document.need_summary is True
|
||||
):
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,52 @@
|
|||
"""Celery worker for enterprise metric/log telemetry events.
|
||||
|
||||
This module defines the Celery task that processes telemetry envelopes
|
||||
from the enterprise_telemetry queue. It deserializes envelopes and
|
||||
dispatches them to the EnterpriseMetricHandler.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from celery import shared_task
|
||||
|
||||
from enterprise.telemetry.contracts import TelemetryEnvelope
|
||||
from enterprise.telemetry.metric_handler import EnterpriseMetricHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="enterprise_telemetry")
|
||||
def process_enterprise_telemetry(envelope_json: str) -> None:
|
||||
"""Process enterprise metric/log telemetry envelope.
|
||||
|
||||
This task is enqueued by the TelemetryGateway for metric/log-only
|
||||
events. It deserializes the envelope and dispatches to the handler.
|
||||
|
||||
Best-effort processing: logs errors but never raises, to avoid
|
||||
failing user requests due to telemetry issues.
|
||||
|
||||
Args:
|
||||
envelope_json: JSON-serialized TelemetryEnvelope.
|
||||
"""
|
||||
try:
|
||||
# Deserialize envelope
|
||||
envelope_dict = json.loads(envelope_json)
|
||||
envelope = TelemetryEnvelope.model_validate(envelope_dict)
|
||||
|
||||
# Process through handler
|
||||
handler = EnterpriseMetricHandler()
|
||||
handler.handle(envelope)
|
||||
|
||||
logger.debug(
|
||||
"Successfully processed telemetry envelope: tenant_id=%s, event_id=%s, case=%s",
|
||||
envelope.tenant_id,
|
||||
envelope.event_id,
|
||||
envelope.case,
|
||||
)
|
||||
except Exception:
|
||||
# Best-effort: log and drop on error, never fail user request
|
||||
logger.warning(
|
||||
"Failed to process enterprise telemetry envelope, dropping event",
|
||||
exc_info=True,
|
||||
)
|
||||
|
|
@ -39,17 +39,36 @@ def process_trace_tasks(file_info):
|
|||
trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]]
|
||||
|
||||
try:
|
||||
trace_type = trace_info_info_map.get(trace_info_type)
|
||||
if trace_type:
|
||||
trace_info = trace_type(**trace_info)
|
||||
|
||||
from extensions.ext_enterprise_telemetry import is_enabled as is_ee_telemetry_enabled
|
||||
|
||||
if is_ee_telemetry_enabled():
|
||||
from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace
|
||||
|
||||
try:
|
||||
EnterpriseOtelTrace().trace(trace_info)
|
||||
except Exception:
|
||||
logger.exception("Enterprise trace failed for app_id: %s", app_id)
|
||||
|
||||
if trace_instance:
|
||||
with current_app.app_context():
|
||||
trace_type = trace_info_info_map.get(trace_info_type)
|
||||
if trace_type:
|
||||
trace_info = trace_type(**trace_info)
|
||||
trace_instance.trace(trace_info)
|
||||
|
||||
logger.info("Processing trace tasks success, app_id: %s", app_id)
|
||||
except Exception as e:
|
||||
logger.info("error:\n\n\n%s\n\n\n\n", e)
|
||||
logger.exception("Processing trace tasks failed, app_id: %s", app_id)
|
||||
failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
|
||||
redis_client.incr(failed_key)
|
||||
logger.info("Processing trace tasks failed, app_id: %s", app_id)
|
||||
finally:
|
||||
storage.delete(file_path)
|
||||
try:
|
||||
storage.delete(file_path)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to delete trace file %s for app_id %s: %s",
|
||||
file_path,
|
||||
app_id,
|
||||
e,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from celery import shared_task
|
|||
from sqlalchemy import or_, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
|
@ -106,7 +107,7 @@ def regenerate_summary_index_task(
|
|||
),
|
||||
DatasetDocument.enabled == True, # Document must be enabled
|
||||
DatasetDocument.archived == False, # Document must not be archived
|
||||
DatasetDocument.doc_form != "qa_model", # Skip qa_model documents
|
||||
DatasetDocument.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
|
||||
)
|
||||
.order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc())
|
||||
.all()
|
||||
|
|
@ -209,7 +210,7 @@ def regenerate_summary_index_task(
|
|||
|
||||
for dataset_document in dataset_documents:
|
||||
# Skip qa_model documents
|
||||
if dataset_document.doc_form == "qa_model":
|
||||
if dataset_document.doc_form == IndexStructureType.QA_INDEX:
|
||||
continue
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -179,7 +179,7 @@ def _record_trigger_failure_log(
|
|||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value,
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=created_by_role,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from unittest.mock import patch
|
|||
import pytest
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.enums import ApiTokenType
|
||||
from models.model import ApiToken
|
||||
from services.api_token_service import ApiTokenCache, CachedApiToken
|
||||
|
||||
|
|
@ -279,7 +280,7 @@ class TestEndToEndCacheFlow:
|
|||
test_token = ApiToken()
|
||||
test_token.id = "test-e2e-id"
|
||||
test_token.token = test_token_value
|
||||
test_token.type = test_scope
|
||||
test_token.type = ApiTokenType.APP
|
||||
test_token.app_id = "test-app"
|
||||
test_token.tenant_id = "test-tenant"
|
||||
test_token.last_used_at = None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,342 @@
|
|||
"""Authenticated controller integration tests for console message APIs."""
|
||||
|
||||
from datetime import timedelta
|
||||
from decimal import Decimal
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console.app.message import ChatMessagesQuery, FeedbackExportQuery, MessageFeedbackPayload
|
||||
from controllers.console.app.message import attach_message_extra_contents as _attach_message_extra_contents
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import ConversationFromSource, FeedbackRating
|
||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
create_console_account_and_tenant,
|
||||
create_console_app,
|
||||
)
|
||||
|
||||
|
||||
def _create_conversation(db_session: Session, app_id: str, account_id: str, mode: AppMode) -> Conversation:
|
||||
conversation = Conversation(
|
||||
app_id=app_id,
|
||||
app_model_config_id=None,
|
||||
model_provider=None,
|
||||
model_id="",
|
||||
override_model_configs=None,
|
||||
mode=mode,
|
||||
name="Test Conversation",
|
||||
inputs={},
|
||||
introduction="",
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_account_id=account_id,
|
||||
)
|
||||
db_session.add(conversation)
|
||||
db_session.commit()
|
||||
return conversation
|
||||
|
||||
|
||||
def _create_message(
|
||||
db_session: Session,
|
||||
app_id: str,
|
||||
conversation_id: str,
|
||||
account_id: str,
|
||||
*,
|
||||
created_at_offset_seconds: int = 0,
|
||||
) -> Message:
|
||||
created_at = naive_utc_now() + timedelta(seconds=created_at_offset_seconds)
|
||||
message = Message(
|
||||
app_id=app_id,
|
||||
model_provider=None,
|
||||
model_id="",
|
||||
override_model_configs=None,
|
||||
conversation_id=conversation_id,
|
||||
inputs={},
|
||||
query="Hello",
|
||||
message={"type": "text", "content": "Hello"},
|
||||
message_tokens=1,
|
||||
message_unit_price=Decimal("0.0001"),
|
||||
message_price_unit=Decimal("0.001"),
|
||||
answer="Hi there",
|
||||
answer_tokens=1,
|
||||
answer_unit_price=Decimal("0.0001"),
|
||||
answer_price_unit=Decimal("0.001"),
|
||||
parent_message_id=None,
|
||||
provider_response_latency=0,
|
||||
total_price=Decimal("0.0002"),
|
||||
currency="USD",
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_account_id=account_id,
|
||||
created_at=created_at,
|
||||
updated_at=created_at,
|
||||
app_mode=AppMode.CHAT,
|
||||
)
|
||||
db_session.add(message)
|
||||
db_session.commit()
|
||||
return message
|
||||
|
||||
|
||||
class TestMessageValidators:
|
||||
def test_chat_messages_query_validators(self) -> None:
|
||||
assert ChatMessagesQuery.empty_to_none("") is None
|
||||
assert ChatMessagesQuery.empty_to_none("val") == "val"
|
||||
assert ChatMessagesQuery.validate_uuid(None) is None
|
||||
assert (
|
||||
ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000")
|
||||
== "123e4567-e89b-12d3-a456-426614174000"
|
||||
)
|
||||
|
||||
def test_message_feedback_validators(self) -> None:
|
||||
assert (
|
||||
MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000")
|
||||
== "123e4567-e89b-12d3-a456-426614174000"
|
||||
)
|
||||
|
||||
def test_feedback_export_validators(self) -> None:
|
||||
assert FeedbackExportQuery.parse_bool(None) is None
|
||||
assert FeedbackExportQuery.parse_bool(True) is True
|
||||
assert FeedbackExportQuery.parse_bool("1") is True
|
||||
assert FeedbackExportQuery.parse_bool("0") is False
|
||||
assert FeedbackExportQuery.parse_bool("off") is False
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
FeedbackExportQuery.parse_bool("invalid")
|
||||
|
||||
|
||||
def test_chat_message_list_not_found(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/chat-messages",
|
||||
query_string={"conversation_id": str(uuid4())},
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["code"] == "not_found"
|
||||
|
||||
|
||||
def test_chat_message_list_success(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode)
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, account.id, created_at_offset_seconds=0)
|
||||
second = _create_message(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
conversation.id,
|
||||
account.id,
|
||||
created_at_offset_seconds=1,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"controllers.console.app.message.attach_message_extra_contents",
|
||||
side_effect=_attach_message_extra_contents,
|
||||
):
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/chat-messages",
|
||||
query_string={"conversation_id": conversation.id, "limit": 1},
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["limit"] == 1
|
||||
assert payload["has_more"] is True
|
||||
assert len(payload["data"]) == 1
|
||||
assert payload["data"][0]["id"] == second.id
|
||||
|
||||
|
||||
def test_message_feedback_not_found(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
|
||||
response = test_client_with_containers.post(
|
||||
f"/console/api/apps/{app.id}/feedbacks",
|
||||
json={"message_id": str(uuid4()), "rating": "like"},
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["code"] == "not_found"
|
||||
|
||||
|
||||
def test_message_feedback_success(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode)
|
||||
message = _create_message(db_session_with_containers, app.id, conversation.id, account.id)
|
||||
|
||||
response = test_client_with_containers.post(
|
||||
f"/console/api/apps/{app.id}/feedbacks",
|
||||
json={"message_id": message.id, "rating": "like"},
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"result": "success"}
|
||||
|
||||
feedback = db_session_with_containers.scalar(
|
||||
select(MessageFeedback).where(MessageFeedback.message_id == message.id)
|
||||
)
|
||||
assert feedback is not None
|
||||
assert feedback.rating == FeedbackRating.LIKE
|
||||
assert feedback.from_account_id == account.id
|
||||
|
||||
|
||||
def test_message_annotation_count(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode)
|
||||
message = _create_message(db_session_with_containers, app.id, conversation.id, account.id)
|
||||
db_session_with_containers.add(
|
||||
MessageAnnotation(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=message.id,
|
||||
question="Q",
|
||||
content="A",
|
||||
account_id=account.id,
|
||||
)
|
||||
)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/annotations/count",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"count": 1}
|
||||
|
||||
|
||||
def test_message_suggested_questions_success(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
message_id = str(uuid4())
|
||||
|
||||
with patch(
|
||||
"controllers.console.app.message.MessageService.get_suggested_questions_after_answer",
|
||||
return_value=["q1", "q2"],
|
||||
):
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/chat-messages/{message_id}/suggested-questions",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"data": ["q1", "q2"]}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exc", "expected_status", "expected_code"),
|
||||
[
|
||||
(MessageNotExistsError(), 404, "not_found"),
|
||||
(ConversationNotExistsError(), 404, "not_found"),
|
||||
(ProviderTokenNotInitError(), 400, "provider_not_initialize"),
|
||||
(QuotaExceededError(), 400, "provider_quota_exceeded"),
|
||||
(ModelCurrentlyNotSupportError(), 400, "model_currently_not_support"),
|
||||
(SuggestedQuestionsAfterAnswerDisabledError(), 403, "app_suggested_questions_after_answer_disabled"),
|
||||
(Exception(), 500, "internal_server_error"),
|
||||
],
|
||||
)
|
||||
def test_message_suggested_questions_errors(
|
||||
exc: Exception,
|
||||
expected_status: int,
|
||||
expected_code: str,
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
message_id = str(uuid4())
|
||||
|
||||
with patch(
|
||||
"controllers.console.app.message.MessageService.get_suggested_questions_after_answer",
|
||||
side_effect=exc,
|
||||
):
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/chat-messages/{message_id}/suggested-questions",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == expected_status
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["code"] == expected_code
|
||||
|
||||
|
||||
def test_message_feedback_export_success(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
|
||||
with patch("services.feedback_service.FeedbackService.export_feedbacks", return_value={"exported": True}):
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/feedbacks/export",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"exported": True}
|
||||
|
||||
|
||||
def test_message_api_get_success(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode)
|
||||
message = _create_message(db_session_with_containers, app.id, conversation.id, account.id)
|
||||
|
||||
with patch(
|
||||
"controllers.console.app.message.attach_message_extra_contents",
|
||||
side_effect=_attach_message_extra_contents,
|
||||
):
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/messages/{message.id}",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["id"] == message.id
|
||||
|
|
@ -0,0 +1,334 @@
|
|||
"""Controller integration tests for console statistic routes."""
|
||||
|
||||
from datetime import timedelta
|
||||
from decimal import Decimal
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import ConversationFromSource, FeedbackFromSource, FeedbackRating
|
||||
from models.model import AppMode, Conversation, Message, MessageFeedback
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
create_console_account_and_tenant,
|
||||
create_console_app,
|
||||
)
|
||||
|
||||
|
||||
def _create_conversation(
|
||||
db_session: Session,
|
||||
app_id: str,
|
||||
account_id: str,
|
||||
*,
|
||||
mode: AppMode,
|
||||
created_at_offset_days: int = 0,
|
||||
) -> Conversation:
|
||||
created_at = naive_utc_now() + timedelta(days=created_at_offset_days)
|
||||
conversation = Conversation(
|
||||
app_id=app_id,
|
||||
app_model_config_id=None,
|
||||
model_provider=None,
|
||||
model_id="",
|
||||
override_model_configs=None,
|
||||
mode=mode,
|
||||
name="Stats Conversation",
|
||||
inputs={},
|
||||
introduction="",
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_account_id=account_id,
|
||||
created_at=created_at,
|
||||
updated_at=created_at,
|
||||
)
|
||||
db_session.add(conversation)
|
||||
db_session.commit()
|
||||
return conversation
|
||||
|
||||
|
||||
def _create_message(
|
||||
db_session: Session,
|
||||
app_id: str,
|
||||
conversation_id: str,
|
||||
*,
|
||||
from_account_id: str | None,
|
||||
from_end_user_id: str | None = None,
|
||||
message_tokens: int = 1,
|
||||
answer_tokens: int = 1,
|
||||
total_price: Decimal = Decimal("0.01"),
|
||||
provider_response_latency: float = 1.0,
|
||||
created_at_offset_days: int = 0,
|
||||
) -> Message:
|
||||
created_at = naive_utc_now() + timedelta(days=created_at_offset_days)
|
||||
message = Message(
|
||||
app_id=app_id,
|
||||
model_provider=None,
|
||||
model_id="",
|
||||
override_model_configs=None,
|
||||
conversation_id=conversation_id,
|
||||
inputs={},
|
||||
query="Hello",
|
||||
message={"type": "text", "content": "Hello"},
|
||||
message_tokens=message_tokens,
|
||||
message_unit_price=Decimal("0.001"),
|
||||
message_price_unit=Decimal("0.001"),
|
||||
answer="Hi there",
|
||||
answer_tokens=answer_tokens,
|
||||
answer_unit_price=Decimal("0.001"),
|
||||
answer_price_unit=Decimal("0.001"),
|
||||
parent_message_id=None,
|
||||
provider_response_latency=provider_response_latency,
|
||||
total_price=total_price,
|
||||
currency="USD",
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
from_source=ConversationFromSource.CONSOLE,
|
||||
from_end_user_id=from_end_user_id,
|
||||
from_account_id=from_account_id,
|
||||
created_at=created_at,
|
||||
updated_at=created_at,
|
||||
app_mode=AppMode.CHAT,
|
||||
)
|
||||
db_session.add(message)
|
||||
db_session.commit()
|
||||
return message
|
||||
|
||||
|
||||
def _create_like_feedback(
|
||||
db_session: Session,
|
||||
app_id: str,
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
account_id: str,
|
||||
) -> None:
|
||||
db_session.add(
|
||||
MessageFeedback(
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
message_id=message_id,
|
||||
rating=FeedbackRating.LIKE,
|
||||
from_source=FeedbackFromSource.ADMIN,
|
||||
from_account_id=account_id,
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def test_daily_message_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/daily-messages",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["data"][0]["message_count"] == 1
|
||||
|
||||
|
||||
def test_daily_conversation_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/daily-conversations",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["data"][0]["conversation_count"] == 1
|
||||
|
||||
|
||||
def test_daily_terminals_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
_create_message(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
conversation.id,
|
||||
from_account_id=None,
|
||||
from_end_user_id=str(uuid4()),
|
||||
)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/daily-end-users",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["data"][0]["terminal_count"] == 1
|
||||
|
||||
|
||||
def test_daily_token_cost_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
_create_message(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
conversation.id,
|
||||
from_account_id=account.id,
|
||||
message_tokens=40,
|
||||
answer_tokens=60,
|
||||
total_price=Decimal("0.02"),
|
||||
)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/token-costs",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload["data"][0]["token_count"] == 100
|
||||
assert Decimal(payload["data"][0]["total_price"]) == Decimal("0.02")
|
||||
|
||||
|
||||
def test_average_session_interaction_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/average-session-interactions",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["data"][0]["interactions"] == 2.0
|
||||
|
||||
|
||||
def test_user_satisfaction_rate_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
first = _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
for _ in range(9):
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
|
||||
_create_like_feedback(db_session_with_containers, app.id, conversation.id, first.id, account.id)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/user-satisfaction-rate",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["data"][0]["rate"] == 100.0
|
||||
|
||||
|
||||
def test_average_response_time_statistic(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.COMPLETION)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
|
||||
_create_message(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
conversation.id,
|
||||
from_account_id=account.id,
|
||||
provider_response_latency=1.234,
|
||||
)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/statistics/average-response-time",
|
||||
headers=authenticate_console_client(test_client_with_containers, account),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json()["data"][0]["latency"] == 1234.0
|
||||
|
||||
|
||||
def test_tokens_per_second_statistic(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """A message with 31 answer tokens and latency 2.0 must yield a tps value of 15.5."""
    session = db_session_with_containers
    client = test_client_with_containers

    account, tenant = create_console_account_and_tenant(session)
    app = create_console_app(session, tenant.id, account.id, AppMode.CHAT)
    conversation = _create_conversation(session, app.id, account.id, mode=app.mode)
    # Seed a single message with known token count and latency.
    _create_message(
        session,
        app.id,
        conversation.id,
        from_account_id=account.id,
        answer_tokens=31,
        provider_response_latency=2.0,
    )

    url = f"/console/api/apps/{app.id}/statistics/tokens-per-second"
    auth_headers = authenticate_console_client(client, account)
    response = client.get(url, headers=auth_headers)

    assert response.status_code == 200
    first_row = response.get_json()["data"][0]
    assert first_row["tps"] == 15.5
|
||||
|
||||
|
||||
def test_invalid_time_range(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """When parse_time_range raises ValueError, the daily-messages route must answer 400 with its message."""
    session = db_session_with_containers
    client = test_client_with_containers

    account, tenant = create_console_account_and_tenant(session)
    app = create_console_app(session, tenant.id, account.id, AppMode.CHAT)

    # Force the time-range parser used by the statistic controller to fail.
    failing_parser = patch(
        "controllers.console.app.statistic.parse_time_range",
        side_effect=ValueError("Invalid time"),
    )
    with failing_parser:
        url = f"/console/api/apps/{app.id}/statistics/daily-messages?start=invalid&end=invalid"
        auth_headers = authenticate_console_client(client, account)
        response = client.get(url, headers=auth_headers)

    assert response.status_code == 400
    body = response.get_json()
    assert body["message"] == "Invalid time"
|
||||
|
||||
|
||||
def test_time_range_params_passed(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """The start/end query values must be forwarded to parse_time_range together with "UTC"."""
    import datetime

    session = db_session_with_containers
    client = test_client_with_containers

    account, tenant = create_console_account_and_tenant(session)
    app = create_console_app(session, tenant.id, account.id, AppMode.CHAT)
    range_start = datetime.datetime.now()
    range_end = datetime.datetime.now()

    # Replace the parser so the test can inspect the arguments it receives.
    parser_patch = patch(
        "controllers.console.app.statistic.parse_time_range",
        return_value=(range_start, range_end),
    )
    with parser_patch as mock_parse:
        url = f"/console/api/apps/{app.id}/statistics/daily-messages?start=something&end=something"
        auth_headers = authenticate_console_client(client, account)
        response = client.get(url, headers=auth_headers)

    assert response.status_code == 200
    mock_parse.assert_called_once_with("something", "something", "UTC")
|
||||
|
|
@ -0,0 +1,415 @@
|
|||
"""Authenticated controller integration tests for workflow draft variable APIs."""
|
||||
|
||||
import uuid
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
from factories.variable_factory import segment_to_variable
|
||||
from models import Workflow
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
create_console_account_and_tenant,
|
||||
create_console_app,
|
||||
)
|
||||
|
||||
|
||||
def _create_draft_workflow(
    db_session: Session,
    app_id: str,
    tenant_id: str,
    account_id: str,
    *,
    environment_variables: list | None = None,
    conversation_variables: list | None = None,
) -> Workflow:
    """Persist and return a draft workflow with an empty graph for the given app.

    Missing (or empty) variable lists are stored as empty lists.
    """
    env_vars = environment_variables if environment_variables else []
    conv_vars = conversation_variables if conversation_variables else []

    draft = Workflow.new(
        tenant_id=tenant_id,
        app_id=app_id,
        type="workflow",
        version=Workflow.VERSION_DRAFT,
        graph='{"nodes": [], "edges": []}',
        features="{}",
        created_by=account_id,
        environment_variables=env_vars,
        conversation_variables=conv_vars,
        rag_pipeline_variables=[],
    )
    db_session.add(draft)
    db_session.commit()
    return draft
|
||||
|
||||
|
||||
def _create_node_variable(
    db_session: Session,
    app_id: str,
    user_id: str,
    *,
    node_id: str = "node_1",
    name: str = "test_var",
) -> WorkflowDraftVariable:
    """Persist and return a visible, editable node draft variable holding the string "test_value"."""
    execution_id = str(uuid.uuid4())
    segment = StringSegment(value="test_value")

    draft_variable = WorkflowDraftVariable.new_node_variable(
        app_id=app_id,
        user_id=user_id,
        node_id=node_id,
        name=name,
        value=segment,
        node_execution_id=execution_id,
        visible=True,
        editable=True,
    )
    db_session.add(draft_variable)
    db_session.commit()
    return draft_variable
|
||||
|
||||
|
||||
def _create_system_variable(
    db_session: Session, app_id: str, user_id: str, name: str = "query"
) -> WorkflowDraftVariable:
    """Persist and return an editable system draft variable holding the string "system-value"."""
    execution_id = str(uuid.uuid4())
    segment = StringSegment(value="system-value")

    draft_variable = WorkflowDraftVariable.new_sys_variable(
        app_id=app_id,
        user_id=user_id,
        name=name,
        value=segment,
        node_execution_id=execution_id,
        editable=True,
    )
    db_session.add(draft_variable)
    db_session.commit()
    return draft_variable
|
||||
|
||||
|
||||
def _build_environment_variable(name: str, value: str):
    """Build a string variable addressed under the environment-variable node id."""
    segment = StringSegment(value=value)
    selector = [ENVIRONMENT_VARIABLE_NODE_ID, name]
    description = f"Environment variable {name}"
    return segment_to_variable(
        segment=segment,
        selector=selector,
        name=name,
        description=description,
    )
|
||||
|
||||
|
||||
def _build_conversation_variable(name: str, value: str):
    """Build a string variable addressed under the conversation-variable node id."""
    segment = StringSegment(value=value)
    selector = [CONVERSATION_VARIABLE_NODE_ID, name]
    description = f"Conversation variable {name}"
    return segment_to_variable(
        segment=segment,
        selector=selector,
        name=name,
        description=description,
    )
|
||||
|
||||
|
||||
def test_workflow_variable_collection_get_success(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """Listing draft variables for a draft workflow without variables returns an empty page."""
    session = db_session_with_containers
    client = test_client_with_containers

    account, tenant = create_console_account_and_tenant(session)
    app = create_console_app(session, tenant.id, account.id, AppMode.WORKFLOW)
    _create_draft_workflow(session, app.id, tenant.id, account.id)

    url = f"/console/api/apps/{app.id}/workflows/draft/variables?page=1&limit=20"
    auth_headers = authenticate_console_client(client, account)
    response = client.get(url, headers=auth_headers)

    assert response.status_code == 200
    expected_body = {"items": [], "total": 0}
    assert response.get_json() == expected_body
|
||||
|
||||
|
||||
def test_workflow_variable_collection_get_not_exist(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """Listing draft variables when no draft workflow was created yields 404 / draft_workflow_not_exist."""
    session = db_session_with_containers
    client = test_client_with_containers

    account, tenant = create_console_account_and_tenant(session)
    app = create_console_app(session, tenant.id, account.id, AppMode.WORKFLOW)

    url = f"/console/api/apps/{app.id}/workflows/draft/variables"
    auth_headers = authenticate_console_client(client, account)
    response = client.get(url, headers=auth_headers)

    assert response.status_code == 404
    payload = response.get_json()
    assert payload is not None
    error_code = payload["code"]
    assert error_code == "draft_workflow_not_exist"
|
||||
|
||||
|
||||
def test_workflow_variable_collection_delete(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """DELETE on the draft variable collection answers 204 and leaves no variables for the app/user."""
    session = db_session_with_containers
    client = test_client_with_containers

    account, tenant = create_console_account_and_tenant(session)
    app = create_console_app(session, tenant.id, account.id, AppMode.WORKFLOW)
    # Two variables on different nodes, both expected to be removed.
    _create_node_variable(session, app.id, account.id)
    _create_node_variable(session, app.id, account.id, node_id="node_2", name="other_var")

    url = f"/console/api/apps/{app.id}/workflows/draft/variables"
    auth_headers = authenticate_console_client(client, account)
    response = client.delete(url, headers=auth_headers)

    assert response.status_code == 204
    leftover_query = select(WorkflowDraftVariable).where(
        WorkflowDraftVariable.app_id == app.id,
        WorkflowDraftVariable.user_id == account.id,
    )
    remaining = session.scalars(leftover_query).all()
    assert remaining == []
|
||||
|
||||
|
||||
def test_node_variable_collection_get_success(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """Listing variables of node_123 returns only the variable created for that node."""
    session = db_session_with_containers
    client = test_client_with_containers

    account, tenant = create_console_account_and_tenant(session)
    app = create_console_app(session, tenant.id, account.id, AppMode.WORKFLOW)
    node_variable = _create_node_variable(session, app.id, account.id, node_id="node_123")
    # A variable on another node that must not appear in the response.
    _create_node_variable(session, app.id, account.id, node_id="node_456", name="other")

    url = f"/console/api/apps/{app.id}/workflows/draft/nodes/node_123/variables"
    auth_headers = authenticate_console_client(client, account)
    response = client.get(url, headers=auth_headers)

    assert response.status_code == 200
    payload = response.get_json()
    assert payload is not None
    returned_ids = [entry["id"] for entry in payload["items"]]
    assert returned_ids == [node_variable.id]
|
||||
|
||||
|
||||
def test_node_variable_collection_get_invalid_node_id(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """Requesting node variables with node id "sys" is rejected with 400 / invalid_param."""
    session = db_session_with_containers
    client = test_client_with_containers

    account, tenant = create_console_account_and_tenant(session)
    app = create_console_app(session, tenant.id, account.id, AppMode.WORKFLOW)

    url = f"/console/api/apps/{app.id}/workflows/draft/nodes/sys/variables"
    auth_headers = authenticate_console_client(client, account)
    response = client.get(url, headers=auth_headers)

    assert response.status_code == 400
    payload = response.get_json()
    assert payload is not None
    error_code = payload["code"]
    assert error_code == "invalid_param"
|
||||
|
||||
|
||||
def test_node_variable_collection_delete(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """DELETE on node_123's variables answers 204, removes that node's variable, and keeps node_456's."""
    session = db_session_with_containers
    client = test_client_with_containers

    def _find_variable(variable_id):
        # Look the row up again by primary key after the request.
        return session.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable_id))

    account, tenant = create_console_account_and_tenant(session)
    app = create_console_app(session, tenant.id, account.id, AppMode.WORKFLOW)
    target = _create_node_variable(session, app.id, account.id, node_id="node_123")
    untouched = _create_node_variable(session, app.id, account.id, node_id="node_456")
    # Ids are read before the request is issued, as in the original test.
    target_id = target.id
    untouched_id = untouched.id

    url = f"/console/api/apps/{app.id}/workflows/draft/nodes/node_123/variables"
    auth_headers = authenticate_console_client(client, account)
    response = client.delete(url, headers=auth_headers)

    assert response.status_code == 204
    assert _find_variable(target_id) is None
    assert _find_variable(untouched_id) is not None
|
||||
|
||||
|
||||
def test_variable_api_get_success(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """Fetching an existing draft variable by id returns its id and name."""
    session = db_session_with_containers
    client = test_client_with_containers

    account, tenant = create_console_account_and_tenant(session)
    app = create_console_app(session, tenant.id, account.id, AppMode.WORKFLOW)
    _create_draft_workflow(session, app.id, tenant.id, account.id)
    variable = _create_node_variable(session, app.id, account.id)

    url = f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}"
    auth_headers = authenticate_console_client(client, account)
    response = client.get(url, headers=auth_headers)

    assert response.status_code == 200
    payload = response.get_json()
    assert payload is not None
    assert payload["id"] == variable.id
    assert payload["name"] == "test_var"
|
||||
|
||||
|
||||
def test_variable_api_get_not_found(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """Fetching a draft variable with a random, unknown id yields 404 / not_found."""
    session = db_session_with_containers
    client = test_client_with_containers

    account, tenant = create_console_account_and_tenant(session)
    app = create_console_app(session, tenant.id, account.id, AppMode.WORKFLOW)
    _create_draft_workflow(session, app.id, tenant.id, account.id)

    url = f"/console/api/apps/{app.id}/workflows/draft/variables/{uuid.uuid4()}"
    auth_headers = authenticate_console_client(client, account)
    response = client.get(url, headers=auth_headers)

    assert response.status_code == 404
    payload = response.get_json()
    assert payload is not None
    error_code = payload["code"]
    assert error_code == "not_found"
|
||||
|
||||
|
||||
def test_variable_api_patch_success(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """PATCHing a new name returns the renamed variable and the stored row carries the new name."""
    session = db_session_with_containers
    client = test_client_with_containers

    account, tenant = create_console_account_and_tenant(session)
    app = create_console_app(session, tenant.id, account.id, AppMode.WORKFLOW)
    _create_draft_workflow(session, app.id, tenant.id, account.id)
    variable = _create_node_variable(session, app.id, account.id)

    url = f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}"
    auth_headers = authenticate_console_client(client, account)
    rename_request = {"name": "renamed_var"}
    response = client.patch(url, headers=auth_headers, json=rename_request)

    assert response.status_code == 200
    payload = response.get_json()
    assert payload is not None
    assert payload["id"] == variable.id
    assert payload["name"] == "renamed_var"

    # Query the row again through the test session and check the persisted name.
    lookup = select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id)
    refreshed = session.scalar(lookup)
    assert refreshed is not None
    assert refreshed.name == "renamed_var"
|
||||
|
||||
|
||||
def test_variable_api_delete_success(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """DELETE on a single draft variable answers 204 and the row can no longer be found."""
    session = db_session_with_containers
    client = test_client_with_containers

    account, tenant = create_console_account_and_tenant(session)
    app = create_console_app(session, tenant.id, account.id, AppMode.WORKFLOW)
    _create_draft_workflow(session, app.id, tenant.id, account.id)
    variable = _create_node_variable(session, app.id, account.id)

    url = f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}"
    auth_headers = authenticate_console_client(client, account)
    response = client.delete(url, headers=auth_headers)

    assert response.status_code == 204
    lookup = select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id)
    assert session.scalar(lookup) is None
|
||||
|
||||
|
||||
def test_variable_reset_api_put_success_returns_no_content_without_execution(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """PUT on the reset route answers 204 and the variable row is gone afterwards."""
    session = db_session_with_containers
    client = test_client_with_containers

    account, tenant = create_console_account_and_tenant(session)
    app = create_console_app(session, tenant.id, account.id, AppMode.WORKFLOW)
    _create_draft_workflow(session, app.id, tenant.id, account.id)
    variable = _create_node_variable(session, app.id, account.id)

    url = f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}/reset"
    auth_headers = authenticate_console_client(client, account)
    response = client.put(url, headers=auth_headers)

    assert response.status_code == 204
    lookup = select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id)
    assert session.scalar(lookup) is None
|
||||
|
||||
|
||||
def test_conversation_variable_collection_get(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """Listing conversation variables returns the workflow's variable and leaves one matching draft row."""
    session = db_session_with_containers
    client = test_client_with_containers

    account, tenant = create_console_account_and_tenant(session)
    app = create_console_app(session, tenant.id, account.id, AppMode.WORKFLOW)
    conversation_vars = [_build_conversation_variable("session_name", "Alice")]
    _create_draft_workflow(
        session,
        app.id,
        tenant.id,
        account.id,
        conversation_variables=conversation_vars,
    )

    url = f"/console/api/apps/{app.id}/workflows/draft/conversation-variables"
    auth_headers = authenticate_console_client(client, account)
    response = client.get(url, headers=auth_headers)

    assert response.status_code == 200
    payload = response.get_json()
    assert payload is not None
    returned_names = [entry["name"] for entry in payload["items"]]
    assert returned_names == ["session_name"]

    # Exactly one draft variable row under the conversation node id is expected after the call.
    draft_rows_query = select(WorkflowDraftVariable).where(
        WorkflowDraftVariable.app_id == app.id,
        WorkflowDraftVariable.user_id == account.id,
        WorkflowDraftVariable.node_id == CONVERSATION_VARIABLE_NODE_ID,
    )
    created = session.scalars(draft_rows_query).all()
    assert len(created) == 1
|
||||
|
||||
|
||||
def test_system_variable_collection_get(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """Listing system variables returns exactly the seeded system variable."""
    session = db_session_with_containers
    client = test_client_with_containers

    account, tenant = create_console_account_and_tenant(session)
    app = create_console_app(session, tenant.id, account.id, AppMode.WORKFLOW)
    variable = _create_system_variable(session, app.id, account.id)

    url = f"/console/api/apps/{app.id}/workflows/draft/system-variables"
    auth_headers = authenticate_console_client(client, account)
    response = client.get(url, headers=auth_headers)

    assert response.status_code == 200
    payload = response.get_json()
    assert payload is not None
    returned_ids = [entry["id"] for entry in payload["items"]]
    assert returned_ids == [variable.id]
|
||||
|
||||
|
||||
def test_environment_variable_collection_get(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """Listing environment variables returns the workflow's variable with its name and value."""
    session = db_session_with_containers
    client = test_client_with_containers

    account, tenant = create_console_account_and_tenant(session)
    app = create_console_app(session, tenant.id, account.id, AppMode.WORKFLOW)
    environment_vars = [_build_environment_variable("api_key", "secret-value")]
    _create_draft_workflow(
        session,
        app.id,
        tenant.id,
        account.id,
        environment_variables=environment_vars,
    )

    url = f"/console/api/apps/{app.id}/workflows/draft/environment-variables"
    auth_headers = authenticate_console_client(client, account)
    response = client.get(url, headers=auth_headers)

    assert response.status_code == 200
    payload = response.get_json()
    assert payload is not None
    assert payload["items"][0]["name"] == "api_key"
    assert payload["items"][0]["value"] == "secret-value"
|
||||
|
|
@ -0,0 +1,131 @@
|
|||
"""Controller integration tests for API key data source auth routes."""
|
||||
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.source import DataSourceApiKeyAuthBinding
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
create_console_account_and_tenant,
|
||||
)
|
||||
|
||||
|
||||
def test_get_api_key_auth_data_source(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """Listing API-key data sources returns the single binding stored for the tenant."""
    session = db_session_with_containers
    client = test_client_with_containers

    account, tenant = create_console_account_and_tenant(session)
    stored_credentials = json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}})
    binding = DataSourceApiKeyAuthBinding(
        tenant_id=tenant.id,
        category="api_key",
        provider="custom_provider",
        credentials=stored_credentials,
        disabled=False,
    )
    session.add(binding)
    session.commit()

    auth_headers = authenticate_console_client(client, account)
    response = client.get("/console/api/api-key-auth/data-source", headers=auth_headers)

    assert response.status_code == 200
    payload = response.get_json()
    assert payload is not None
    sources = payload["sources"]
    assert len(sources) == 1
    assert payload["sources"][0]["provider"] == "custom_provider"
|
||||
|
||||
|
||||
def test_get_api_key_auth_data_source_empty(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """Listing API-key data sources with no bindings returns an empty sources list."""
    client = test_client_with_containers
    account, _tenant = create_console_account_and_tenant(db_session_with_containers)

    auth_headers = authenticate_console_client(client, account)
    response = client.get("/console/api/api-key-auth/data-source", headers=auth_headers)

    assert response.status_code == 200
    expected_body = {"sources": []}
    assert response.get_json() == expected_body
|
||||
|
||||
|
||||
def test_create_binding_successful(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """Creating a binding with the auth service calls patched out answers 200 with a success result."""
    client = test_client_with_containers
    account, _tenant = create_console_account_and_tenant(db_session_with_containers)

    request_body = {"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}
    # Both service entry points are replaced by mocks, so no real validation or persistence runs.
    with (
        patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"),
        patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth"),
    ):
        auth_headers = authenticate_console_client(client, account)
        response = client.post(
            "/console/api/api-key-auth/data-source/binding",
            json=request_body,
            headers=auth_headers,
        )

    assert response.status_code == 200
    expected_body = {"result": "success"}
    assert response.get_json() == expected_body
|
||||
|
||||
|
||||
def test_create_binding_failure(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """A ValueError from create_provider_auth surfaces as 500 / auth_failed carrying the error text."""
    client = test_client_with_containers
    account, _tenant = create_console_account_and_tenant(db_session_with_containers)

    request_body = {"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}
    with (
        patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"),
        patch(
            "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth",
            side_effect=ValueError("Invalid structure"),
        ),
    ):
        auth_headers = authenticate_console_client(client, account)
        response = client.post(
            "/console/api/api-key-auth/data-source/binding",
            json=request_body,
            headers=auth_headers,
        )

    assert response.status_code == 500
    payload = response.get_json()
    assert payload is not None
    assert payload["code"] == "auth_failed"
    assert payload["message"] == "Invalid structure"
|
||||
|
||||
|
||||
def test_delete_binding_successful(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """Deleting an API-key binding answers 204 and the row can no longer be found."""
    session = db_session_with_containers
    client = test_client_with_containers

    account, tenant = create_console_account_and_tenant(session)
    stored_credentials = json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}})
    binding = DataSourceApiKeyAuthBinding(
        tenant_id=tenant.id,
        category="api_key",
        provider="custom_provider",
        credentials=stored_credentials,
        disabled=False,
    )
    session.add(binding)
    session.commit()

    url = f"/console/api/api-key-auth/data-source/{binding.id}"
    auth_headers = authenticate_console_client(client, account)
    response = client.delete(url, headers=auth_headers)

    assert response.status_code == 204
    lookup = select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.id == binding.id)
    assert session.scalar(lookup) is None
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
"""Controller integration tests for console OAuth data source routes."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.source import DataSourceOauthBinding
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
create_console_account_and_tenant,
|
||||
)
|
||||
|
||||
|
||||
def test_get_oauth_url_successful(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """The notion OAuth route returns the authorization URL produced by the (mocked) provider."""
    client = test_client_with_containers
    account, tenant = create_console_account_and_tenant(db_session_with_containers)

    provider = MagicMock()
    provider.get_authorization_url.return_value = "http://oauth.provider/auth"
    providers = {"notion": provider}

    with (
        patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value=providers),
        patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None),
    ):
        auth_headers = authenticate_console_client(client, account)
        response = client.get("/console/api/oauth/data-source/notion", headers=auth_headers)

    assert tenant.id == account.current_tenant_id
    assert response.status_code == 200
    expected_body = {"data": "http://oauth.provider/auth"}
    assert response.get_json() == expected_body
    provider.get_authorization_url.assert_called_once()
|
||||
|
||||
|
||||
def test_get_oauth_url_invalid_provider(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """Requesting an OAuth URL for a provider that is not registered yields 400 / Invalid provider."""
    client = test_client_with_containers
    account, _tenant = create_console_account_and_tenant(db_session_with_containers)

    providers = {"notion": MagicMock()}
    with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value=providers):
        auth_headers = authenticate_console_client(client, account)
        response = client.get("/console/api/oauth/data-source/unknown_provider", headers=auth_headers)

    assert response.status_code == 400
    expected_body = {"error": "Invalid provider"}
    assert response.get_json() == expected_body
|
||||
|
||||
|
||||
def test_oauth_callback_successful(test_client_with_containers: FlaskClient) -> None:
    """The notion callback with a code redirects (302) to a location that carries that code."""
    providers = {"notion": MagicMock()}
    callback_url = "/console/api/oauth/data-source/callback/notion?code=mock_code"

    with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value=providers):
        response = test_client_with_containers.get(callback_url)

    assert response.status_code == 302
    redirect_target = response.location
    assert "code=mock_code" in redirect_target
|
||||
|
||||
|
||||
def test_oauth_callback_missing_code(test_client_with_containers: FlaskClient) -> None:
    """The notion callback without a code redirects (302) with an access-denied error in the location."""
    providers = {"notion": MagicMock()}
    callback_url = "/console/api/oauth/data-source/callback/notion"

    with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value=providers):
        response = test_client_with_containers.get(callback_url)

    assert response.status_code == 302
    redirect_target = response.location
    assert "error=Access%20denied" in redirect_target
|
||||
|
||||
|
||||
def test_oauth_callback_invalid_provider(test_client_with_containers: FlaskClient) -> None:
    """The callback for a provider that is not registered yields 400 / Invalid provider."""
    providers = {"notion": MagicMock()}
    callback_url = "/console/api/oauth/data-source/callback/invalid?code=mock_code"

    with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value=providers):
        response = test_client_with_containers.get(callback_url)

    assert response.status_code == 400
    expected_body = {"error": "Invalid provider"}
    assert response.get_json() == expected_body
|
||||
|
||||
|
||||
def test_get_binding_successful(test_client_with_containers: FlaskClient) -> None:
    """The binding route passes the supplied code to the provider's get_access_token and reports success."""
    provider = MagicMock()
    providers = {"notion": provider}
    binding_url = "/console/api/oauth/data-source/binding/notion?code=auth_code_123"

    with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value=providers):
        response = test_client_with_containers.get(binding_url)

    assert response.status_code == 200
    expected_body = {"result": "success"}
    assert response.get_json() == expected_body
    provider.get_access_token.assert_called_once_with("auth_code_123")
|
||||
|
||||
|
||||
def test_get_binding_missing_code(test_client_with_containers: FlaskClient) -> None:
    """The binding route with an empty code parameter yields 400 / Invalid code."""
    providers = {"notion": MagicMock()}
    binding_url = "/console/api/oauth/data-source/binding/notion?code="

    with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value=providers):
        response = test_client_with_containers.get(binding_url)

    assert response.status_code == 400
    expected_body = {"error": "Invalid code"}
    assert response.get_json() == expected_body
|
||||
|
||||
|
||||
def test_sync_successful(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """The sync route calls the provider's sync_data_source with the binding id and reports success."""
    session = db_session_with_containers
    client = test_client_with_containers

    account, tenant = create_console_account_and_tenant(session)
    workspace_info = {"workspace_name": "Workspace", "workspace_icon": None, "workspace_id": tenant.id, "pages": []}
    binding = DataSourceOauthBinding(
        tenant_id=tenant.id,
        access_token="test-access-token",
        provider="notion",
        source_info=workspace_info,
        disabled=False,
    )
    session.add(binding)
    session.commit()

    provider = MagicMock()
    providers = {"notion": provider}
    with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value=providers):
        url = f"/console/api/oauth/data-source/notion/{binding.id}/sync"
        auth_headers = authenticate_console_client(client, account)
        response = client.get(url, headers=auth_headers)

    assert response.status_code == 200
    expected_body = {"result": "success"}
    assert response.get_json() == expected_body
    provider.sync_data_source.assert_called_once_with(binding.id)
|
||||
|
|
@ -0,0 +1,365 @@
|
|||
"""Controller integration tests for console OAuth server routes."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
create_console_account_and_tenant,
|
||||
ensure_dify_setup,
|
||||
)
|
||||
|
||||
|
||||
def _build_oauth_provider_app() -> OAuthProviderApp:
    """Return an unsaved OAuthProviderApp with fixed test values (client id, secret, redirect URI, scope)."""
    label_by_locale = {"en-US": "Test App"}
    allowed_redirects = ["http://localhost/callback"]
    provider_app = OAuthProviderApp(
        app_icon="icon_url",
        client_id="test_client_id",
        client_secret="test_secret",
        app_label=label_by_locale,
        redirect_uris=allowed_redirects,
        scope="read,write",
    )
    return provider_app
|
||||
|
||||
|
||||
def test_oauth_provider_successful_post(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """POST /oauth/provider with the fixture's client_id and redirect_uri returns the app details."""
    ensure_dify_setup(db_session_with_containers)

    lookup_target = "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app"
    request_body = {"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"}

    # The provider lookup is patched, so no OAuthProviderApp row is needed in the database.
    with patch(lookup_target, return_value=_build_oauth_provider_app()):
        resp = test_client_with_containers.post("/console/api/oauth/provider", json=request_body)

    assert resp.status_code == 200
    body = resp.get_json()
    assert body is not None
    assert body["app_icon"] == "icon_url"
    assert body["app_label"] == {"en-US": "Test App"}
    assert body["scope"] == "read,write"
|
||||
|
||||
|
||||
def test_oauth_provider_invalid_redirect_uri(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """A redirect_uri that differs from the fixture's registered one is rejected with HTTP 400."""
    ensure_dify_setup(db_session_with_containers)

    lookup_target = "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app"
    request_body = {"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"}

    with patch(lookup_target, return_value=_build_oauth_provider_app()):
        resp = test_client_with_containers.post("/console/api/oauth/provider", json=request_body)

    assert resp.status_code == 400
    body = resp.get_json()
    assert body is not None
    assert "redirect_uri is invalid" in body["message"]
|
||||
|
||||
|
||||
def test_oauth_provider_invalid_client_id(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """An unregistered client_id is answered with HTTP 404.

    Unlike the sibling tests, the provider lookup is not patched here, so the
    request goes through the real lookup.
    """
    ensure_dify_setup(db_session_with_containers)

    request_body = {"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"}
    resp = test_client_with_containers.post("/console/api/oauth/provider", json=request_body)

    assert resp.status_code == 404
    body = resp.get_json()
    assert body is not None
    assert "client_id is invalid" in body["message"]
|
||||
|
||||
|
||||
def test_oauth_authorize_successful(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """An authenticated console user receives the signed authorization code for the client."""
    account, _tenant = create_console_account_and_tenant(db_session_with_containers)

    lookup_target = "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app"
    sign_code_target = "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code"

    with (
        patch(lookup_target, return_value=_build_oauth_provider_app()),
        patch(sign_code_target, return_value="auth_code_123") as mock_sign,
    ):
        # Build the auth headers while the patches are active, as the request itself is.
        auth_headers = authenticate_console_client(test_client_with_containers, account)
        resp = test_client_with_containers.post(
            "/console/api/oauth/provider/authorize",
            json={"client_id": "test_client_id"},
            headers=auth_headers,
        )

    assert resp.status_code == 200
    assert resp.get_json() == {"code": "auth_code_123"}
    mock_sign.assert_called_once_with("test_client_id", account.id)
|
||||
|
||||
|
||||
def test_oauth_token_authorization_code_grant(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """The authorization_code grant returns a bearer payload carrying the signed token pair."""
    ensure_dify_setup(db_session_with_containers)

    lookup_target = "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app"
    sign_token_target = "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token"
    token_request = {
        "client_id": "test_client_id",
        "grant_type": "authorization_code",
        "code": "auth_code",
        "client_secret": "test_secret",
        "redirect_uri": "http://localhost/callback",
    }

    with (
        patch(lookup_target, return_value=_build_oauth_provider_app()),
        patch(sign_token_target, return_value=("access_123", "refresh_123")),
    ):
        resp = test_client_with_containers.post("/console/api/oauth/provider/token", json=token_request)

    assert resp.status_code == 200
    expected_payload = {
        "access_token": "access_123",
        "token_type": "Bearer",
        "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
        "refresh_token": "refresh_123",
    }
    assert resp.get_json() == expected_payload
|
||||
|
||||
|
||||
def test_oauth_token_authorization_code_grant_missing_code(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """Omitting `code` from an authorization_code request yields HTTP 400 'code is required'."""
    ensure_dify_setup(db_session_with_containers)

    lookup_target = "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app"
    # Deliberately has no "code" entry.
    token_request = {
        "client_id": "test_client_id",
        "grant_type": "authorization_code",
        "client_secret": "test_secret",
        "redirect_uri": "http://localhost/callback",
    }

    with patch(lookup_target, return_value=_build_oauth_provider_app()):
        resp = test_client_with_containers.post("/console/api/oauth/provider/token", json=token_request)

    assert resp.status_code == 400
    assert resp.get_json()["message"] == "code is required"
|
||||
|
||||
|
||||
def test_oauth_token_authorization_code_grant_invalid_secret(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """A client_secret that differs from the fixture's secret yields HTTP 400."""
    ensure_dify_setup(db_session_with_containers)

    lookup_target = "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app"
    token_request = {
        "client_id": "test_client_id",
        "grant_type": "authorization_code",
        "code": "auth_code",
        "client_secret": "invalid_secret",
        "redirect_uri": "http://localhost/callback",
    }

    with patch(lookup_target, return_value=_build_oauth_provider_app()):
        resp = test_client_with_containers.post("/console/api/oauth/provider/token", json=token_request)

    assert resp.status_code == 400
    assert resp.get_json()["message"] == "client_secret is invalid"
|
||||
|
||||
|
||||
def test_oauth_token_authorization_code_grant_invalid_redirect_uri(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """A redirect_uri that differs from the fixture's registered one yields HTTP 400 on token exchange."""
    ensure_dify_setup(db_session_with_containers)

    lookup_target = "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app"
    token_request = {
        "client_id": "test_client_id",
        "grant_type": "authorization_code",
        "code": "auth_code",
        "client_secret": "test_secret",
        "redirect_uri": "http://invalid/callback",
    }

    with patch(lookup_target, return_value=_build_oauth_provider_app()):
        resp = test_client_with_containers.post("/console/api/oauth/provider/token", json=token_request)

    assert resp.status_code == 400
    assert resp.get_json()["message"] == "redirect_uri is invalid"
|
||||
|
||||
|
||||
def test_oauth_token_refresh_token_grant(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """The refresh_token grant returns a bearer payload carrying the newly signed token pair."""
    ensure_dify_setup(db_session_with_containers)

    lookup_target = "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app"
    sign_token_target = "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token"
    token_request = {"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"}

    with (
        patch(lookup_target, return_value=_build_oauth_provider_app()),
        patch(sign_token_target, return_value=("new_access", "new_refresh")),
    ):
        resp = test_client_with_containers.post("/console/api/oauth/provider/token", json=token_request)

    assert resp.status_code == 200
    expected_payload = {
        "access_token": "new_access",
        "token_type": "Bearer",
        "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
        "refresh_token": "new_refresh",
    }
    assert resp.get_json() == expected_payload
|
||||
|
||||
|
||||
def test_oauth_token_refresh_token_grant_missing_token(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """Omitting `refresh_token` from a refresh_token request yields HTTP 400."""
    ensure_dify_setup(db_session_with_containers)

    lookup_target = "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app"
    token_request = {"client_id": "test_client_id", "grant_type": "refresh_token"}

    with patch(lookup_target, return_value=_build_oauth_provider_app()):
        resp = test_client_with_containers.post("/console/api/oauth/provider/token", json=token_request)

    assert resp.status_code == 400
    assert resp.get_json()["message"] == "refresh_token is required"
|
||||
|
||||
|
||||
def test_oauth_token_invalid_grant_type(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """The grant_type value "invalid_grant" yields HTTP 400 'invalid grant_type'."""
    ensure_dify_setup(db_session_with_containers)

    lookup_target = "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app"
    token_request = {"client_id": "test_client_id", "grant_type": "invalid_grant"}

    with patch(lookup_target, return_value=_build_oauth_provider_app()):
        resp = test_client_with_containers.post("/console/api/oauth/provider/token", json=token_request)

    assert resp.status_code == 400
    assert resp.get_json()["message"] == "invalid grant_type"
|
||||
|
||||
|
||||
def test_oauth_account_successful_retrieval(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """A bearer token that validates to an account returns that account's profile fields."""
    ensure_dify_setup(db_session_with_containers)
    account, _tenant = create_console_account_and_tenant(db_session_with_containers)
    account.avatar = "avatar_url"
    db_session_with_containers.commit()

    lookup_target = "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app"
    validate_target = "controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token"

    # Token validation is patched to hand back the account created above.
    with (
        patch(lookup_target, return_value=_build_oauth_provider_app()),
        patch(validate_target, return_value=account),
    ):
        resp = test_client_with_containers.post(
            "/console/api/oauth/provider/account",
            json={"client_id": "test_client_id"},
            headers={"Authorization": "Bearer valid_access_token"},
        )

    assert resp.status_code == 200
    expected_profile = {
        "name": "Test User",
        "email": account.email,
        "avatar": "avatar_url",
        "interface_language": "en-US",
        "timezone": "UTC",
    }
    assert resp.get_json() == expected_profile
|
||||
|
||||
|
||||
def test_oauth_account_missing_authorization_header(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """A request without an Authorization header is answered with HTTP 401."""
    ensure_dify_setup(db_session_with_containers)

    lookup_target = "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app"

    with patch(lookup_target, return_value=_build_oauth_provider_app()):
        resp = test_client_with_containers.post(
            "/console/api/oauth/provider/account",
            json={"client_id": "test_client_id"},
        )

    assert resp.status_code == 401
    assert resp.get_json() == {"error": "Authorization header is required"}
|
||||
|
||||
|
||||
def test_oauth_account_invalid_authorization_header_format(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
) -> None:
    """The Authorization header value "InvalidFormat" is answered with HTTP 401."""
    ensure_dify_setup(db_session_with_containers)

    lookup_target = "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app"
    malformed_headers = {"Authorization": "InvalidFormat"}

    with patch(lookup_target, return_value=_build_oauth_provider_app()):
        resp = test_client_with_containers.post(
            "/console/api/oauth/provider/account",
            json={"client_id": "test_client_id"},
            headers=malformed_headers,
        )

    assert resp.status_code == 401
    assert resp.get_json() == {"error": "Invalid Authorization header format"}
|
||||
|
|
@ -1,17 +1,10 @@
|
|||
"""
|
||||
Test suite for password reset authentication flows.
|
||||
"""Testcontainers integration tests for password reset authentication flows."""
|
||||
|
||||
This module tests the password reset mechanism including:
|
||||
- Password reset email sending
|
||||
- Verification code validation
|
||||
- Password reset with token
|
||||
- Rate limiting and security checks
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
|
|
@ -28,31 +21,12 @@ from controllers.console.auth.forgot_password import (
|
|||
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_forgot_password_session():
|
||||
with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls:
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
mock_session_cls.return_value.__exit__.return_value = None
|
||||
yield mock_session
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_forgot_password_db():
|
||||
with patch("controllers.console.auth.forgot_password.db") as mock_db:
|
||||
mock_db.engine = MagicMock()
|
||||
yield mock_db
|
||||
|
||||
|
||||
class TestForgotPasswordSendEmailApi:
|
||||
"""Test cases for sending password reset emails."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
|
|
@ -62,7 +36,6 @@ class TestForgotPasswordSendEmailApi:
|
|||
account.name = "Test User"
|
||||
return account
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
|
|
@ -73,20 +46,10 @@ class TestForgotPasswordSendEmailApi:
|
|||
mock_send_email,
|
||||
mock_get_account,
|
||||
mock_is_ip_limit,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
):
|
||||
"""
|
||||
Test successful password reset email sending.
|
||||
|
||||
Verifies that:
|
||||
- Email is sent to valid account
|
||||
- Reset token is generated and returned
|
||||
- IP rate limiting is checked
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_email.return_value = "reset_token_123"
|
||||
|
|
@ -104,9 +67,8 @@ class TestForgotPasswordSendEmailApi:
|
|||
assert response["data"] == "reset_token_123"
|
||||
mock_send_email.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, mock_db, app):
|
||||
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app):
|
||||
"""
|
||||
Test password reset email blocked by IP rate limit.
|
||||
|
||||
|
|
@ -115,7 +77,6 @@ class TestForgotPasswordSendEmailApi:
|
|||
- No email is sent when rate limited
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
|
|
@ -133,7 +94,6 @@ class TestForgotPasswordSendEmailApi:
|
|||
(None, "en-US"), # Defaults to en-US when not provided
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
|
|
@ -144,7 +104,6 @@ class TestForgotPasswordSendEmailApi:
|
|||
mock_send_email,
|
||||
mock_get_account,
|
||||
mock_is_ip_limit,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
language_input,
|
||||
|
|
@ -158,7 +117,6 @@ class TestForgotPasswordSendEmailApi:
|
|||
- Unsupported languages default to en-US
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_email.return_value = "token"
|
||||
|
|
@ -180,13 +138,9 @@ class TestForgotPasswordCheckApi:
|
|||
"""Test cases for verifying password reset codes."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
|
|
@ -199,7 +153,6 @@ class TestForgotPasswordCheckApi:
|
|||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_is_rate_limit,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
"""
|
||||
|
|
@ -212,7 +165,6 @@ class TestForgotPasswordCheckApi:
|
|||
- Rate limit is reset on success
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_generate_token.return_value = (None, "new_token")
|
||||
|
|
@ -236,7 +188,6 @@ class TestForgotPasswordCheckApi:
|
|||
)
|
||||
mock_reset_rate_limit.assert_called_once_with("test@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
|
|
@ -249,10 +200,8 @@ class TestForgotPasswordCheckApi:
|
|||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_is_rate_limit,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"}
|
||||
mock_generate_token.return_value = (None, "fresh-token")
|
||||
|
|
@ -271,9 +220,8 @@ class TestForgotPasswordCheckApi:
|
|||
mock_revoke_token.assert_called_once_with("upper_token")
|
||||
mock_reset_rate_limit.assert_called_once_with("user@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app):
|
||||
def test_verify_code_rate_limited(self, mock_is_rate_limit, app):
|
||||
"""
|
||||
Test code verification blocked by rate limit.
|
||||
|
||||
|
|
@ -282,7 +230,6 @@ class TestForgotPasswordCheckApi:
|
|||
- Prevents brute force attacks on verification codes
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
|
|
@ -295,10 +242,9 @@ class TestForgotPasswordCheckApi:
|
|||
with pytest.raises(EmailPasswordResetLimitError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, mock_db, app):
|
||||
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app):
|
||||
"""
|
||||
Test code verification with invalid token.
|
||||
|
||||
|
|
@ -306,7 +252,6 @@ class TestForgotPasswordCheckApi:
|
|||
- InvalidTokenError is raised for invalid/expired tokens
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = None
|
||||
|
||||
|
|
@ -320,10 +265,9 @@ class TestForgotPasswordCheckApi:
|
|||
with pytest.raises(InvalidTokenError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, mock_db, app):
|
||||
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app):
|
||||
"""
|
||||
Test code verification with mismatched email.
|
||||
|
||||
|
|
@ -332,7 +276,6 @@ class TestForgotPasswordCheckApi:
|
|||
- Prevents token abuse
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}
|
||||
|
||||
|
|
@ -346,11 +289,10 @@ class TestForgotPasswordCheckApi:
|
|||
with pytest.raises(InvalidEmailError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit")
|
||||
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, mock_db, app):
|
||||
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app):
|
||||
"""
|
||||
Test code verification with incorrect code.
|
||||
|
||||
|
|
@ -359,7 +301,6 @@ class TestForgotPasswordCheckApi:
|
|||
- Rate limit counter is incremented
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
|
||||
|
|
@ -380,11 +321,8 @@ class TestForgotPasswordResetApi:
|
|||
"""Test cases for resetting password with verified token."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
|
|
@ -394,7 +332,6 @@ class TestForgotPasswordResetApi:
|
|||
account.name = "Test User"
|
||||
return account
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
|
|
@ -405,7 +342,6 @@ class TestForgotPasswordResetApi:
|
|||
mock_get_account,
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
):
|
||||
|
|
@ -418,7 +354,6 @@ class TestForgotPasswordResetApi:
|
|||
- Success response is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_get_tenants.return_value = [MagicMock()]
|
||||
|
|
@ -436,9 +371,8 @@ class TestForgotPasswordResetApi:
|
|||
assert response["result"] == "success"
|
||||
mock_revoke_token.assert_called_once_with("valid_token")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_mismatch(self, mock_get_data, mock_db, app):
|
||||
def test_reset_password_mismatch(self, mock_get_data, app):
|
||||
"""
|
||||
Test password reset with mismatched passwords.
|
||||
|
||||
|
|
@ -447,7 +381,6 @@ class TestForgotPasswordResetApi:
|
|||
- No password update occurs
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
|
||||
|
||||
# Act & Assert
|
||||
|
|
@ -460,9 +393,8 @@ class TestForgotPasswordResetApi:
|
|||
with pytest.raises(PasswordMismatchError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_invalid_token(self, mock_get_data, mock_db, app):
|
||||
def test_reset_password_invalid_token(self, mock_get_data, app):
|
||||
"""
|
||||
Test password reset with invalid token.
|
||||
|
||||
|
|
@ -470,7 +402,6 @@ class TestForgotPasswordResetApi:
|
|||
- InvalidTokenError is raised for invalid/expired tokens
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
|
|
@ -483,9 +414,8 @@ class TestForgotPasswordResetApi:
|
|||
with pytest.raises(InvalidTokenError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_wrong_phase(self, mock_get_data, mock_db, app):
|
||||
def test_reset_password_wrong_phase(self, mock_get_data, app):
|
||||
"""
|
||||
Test password reset with token not in reset phase.
|
||||
|
||||
|
|
@ -494,7 +424,6 @@ class TestForgotPasswordResetApi:
|
|||
- Prevents use of verification-phase tokens for reset
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "verify"}
|
||||
|
||||
# Act & Assert
|
||||
|
|
@ -507,13 +436,10 @@ class TestForgotPasswordResetApi:
|
|||
with pytest.raises(InvalidTokenError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
def test_reset_password_account_not_found(
|
||||
self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app
|
||||
):
|
||||
def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app):
|
||||
"""
|
||||
Test password reset for non-existent account.
|
||||
|
||||
|
|
@ -521,7 +447,6 @@ class TestForgotPasswordResetApi:
|
|||
- AccountNotFound is raised when account doesn't exist
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"}
|
||||
mock_get_account.return_value = None
|
||||
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
"""Shared helpers for authenticated console controller integration tests."""
|
||||
|
||||
import uuid
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HEADER_NAME_CSRF_TOKEN
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.token import _real_cookie_name, generate_csrf_token
|
||||
from models import Account, DifySetup, Tenant, TenantAccountJoin
|
||||
from models.account import AccountStatus, TenantAccountRole
|
||||
from models.model import App, AppMode
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
def ensure_dify_setup(db_session: Session) -> None:
    """Insert a DifySetup row if none exists yet, so setup-protected console routes can be exercised.

    Does nothing when a row is already present; otherwise adds one with the
    configured project version and commits.
    """
    existing_marker = db_session.scalar(select(DifySetup).limit(1))
    if existing_marker is None:
        db_session.add(DifySetup(version=dify_config.project.version))
        db_session.commit()
|
||||
|
||||
|
||||
def create_console_account_and_tenant(db_session: Session) -> tuple[Account, Tenant]:
    """Persist an active, initialized account joined as owner of a new tenant.

    Each row is committed before the next one is built, so the account and
    tenant ids are available when the join row is built. Also ensures the
    DifySetup marker exists before returning.

    Returns:
        The ``(account, tenant)`` pair that was created.
    """
    unique_email = f"test-{uuid.uuid4()}@example.com"
    owner = Account(
        email=unique_email,
        name="Test User",
        interface_language="en-US",
        status=AccountStatus.ACTIVE,
    )
    owner.initialized_at = naive_utc_now()
    db_session.add(owner)
    db_session.commit()

    workspace = Tenant(name="Test Tenant", status="normal")
    db_session.add(workspace)
    db_session.commit()

    membership = TenantAccountJoin(
        tenant_id=workspace.id,
        account_id=owner.id,
        role=TenantAccountRole.OWNER,
        current=True,
    )
    db_session.add(membership)
    db_session.commit()

    owner.set_tenant_id(workspace.id)
    owner.timezone = "UTC"
    db_session.commit()

    ensure_dify_setup(db_session)
    return owner, workspace
|
||||
|
||||
|
||||
def create_console_app(db_session: Session, tenant_id: str, account_id: str, mode: AppMode) -> App:
    """Persist and return a minimal App row with site and API access enabled.

    Args:
        db_session: Session used to add and commit the row.
        tenant_id: Tenant that owns the app.
        account_id: Account recorded as the app's creator.
        mode: Mode stored on the app.
    """
    app_fields = {
        "tenant_id": tenant_id,
        "name": "Test App",
        "mode": mode,
        "enable_site": True,
        "enable_api": True,
        "created_by": account_id,
    }
    app_row = App(**app_fields)
    db_session.add(app_row)
    db_session.commit()
    return app_row
|
||||
|
||||
|
||||
def authenticate_console_client(test_client: FlaskClient, account: Account) -> dict[str, str]:
    """Set the CSRF cookie on the client and return matching auth headers for the account.

    Returns:
        Headers holding a bearer JWT for ``account`` and the same CSRF token
        that was stored in the client's cookie.
    """
    jwt_token = AccountService.get_account_jwt_token(account)
    csrf_value = generate_csrf_token(account.id)

    # The cookie and the header carry the same CSRF token value.
    csrf_cookie_name = _real_cookie_name("csrf_token")
    test_client.set_cookie(csrf_cookie_name, csrf_value, domain="localhost")

    auth_headers = {"Authorization": f"Bearer {jwt_token}"}
    auth_headers[HEADER_NAME_CSRF_TOKEN] = csrf_value
    return auth_headers
|
||||
|
|
@ -4,6 +4,7 @@ from unittest.mock import patch
|
|||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
|
||||
from models.dataset import Dataset, Document
|
||||
|
|
@ -55,7 +56,7 @@ class TestGetAvailableDatasetsIntegration:
|
|||
name=f"Document {i}",
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
doc_language="en",
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
|
|
@ -112,7 +113,7 @@ class TestGetAvailableDatasetsIntegration:
|
|||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=f"Archived Document {i}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=True, # Archived
|
||||
|
|
@ -165,7 +166,7 @@ class TestGetAvailableDatasetsIntegration:
|
|||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=f"Disabled Document {i}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=False, # Disabled
|
||||
archived=False,
|
||||
|
|
@ -218,7 +219,7 @@ class TestGetAvailableDatasetsIntegration:
|
|||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=f"Document {status}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
indexing_status=status, # Not completed
|
||||
enabled=True,
|
||||
archived=False,
|
||||
|
|
@ -336,7 +337,7 @@ class TestGetAvailableDatasetsIntegration:
|
|||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=f"Document for {dataset.name}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
|
|
@ -416,7 +417,7 @@ class TestGetAvailableDatasetsIntegration:
|
|||
created_from=DocumentCreatedFrom.WEB,
|
||||
name=f"Document {i}",
|
||||
created_by=account.id,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
|
|
@ -476,7 +477,7 @@ class TestKnowledgeRetrievalIntegration:
|
|||
indexing_status=IndexingStatus.COMPLETED,
|
||||
enabled=True,
|
||||
archived=False,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
|
|
|||
|
|
@ -1,27 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository
|
||||
from tests.test_containers_integration_tests.helpers.execution_extra_content import (
|
||||
create_human_input_message_fixture,
|
||||
)
|
||||
|
||||
|
||||
def test_get_by_message_ids_returns_human_input_content(db_session_with_containers):
    """The fixture's submitted form comes back as the only content item of its message."""
    fixture = create_human_input_message_fixture(db_session_with_containers)
    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
    repository = SQLAlchemyExecutionExtraContentRepository(session_maker=session_factory)

    per_message = repository.get_by_message_ids([fixture.message.id])

    # One entry per requested message, holding exactly one content item.
    assert len(per_message) == 1
    assert len(per_message[0]) == 1

    content = per_message[0][0]
    assert content.submitted is True
    assert content.form_submission_data is not None

    submission = content.form_submission_data
    assert submission.action_id == fixture.action_id
    assert submission.action_text == fixture.action_text
    assert submission.rendered_content == fixture.form.rendered_content
|
||||
|
|
@ -27,7 +27,7 @@ from models.human_input import (
|
|||
HumanInputFormRecipient,
|
||||
RecipientType,
|
||||
)
|
||||
from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
|
||||
from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import (
|
||||
DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
|
|
@ -218,7 +218,7 @@ class TestDeleteRunsWithRelated:
|
|||
app_id=test_scope.app_id,
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=test_scope.user_id,
|
||||
)
|
||||
|
|
@ -278,7 +278,7 @@ class TestCountRunsWithRelated:
|
|||
app_id=test_scope.app_id,
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=test_scope.user_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,407 @@
|
|||
"""Integration tests for SQLAlchemyExecutionExtraContentRepository using Testcontainers.
|
||||
|
||||
Part of #32454 — replaces the mock-based unit tests with real database interactions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine, delete, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from dify_graph.nodes.human_input.entities import FormDefinition, UserAction
|
||||
from dify_graph.nodes.human_input.enums import HumanInputFormStatus
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import ConversationFromSource, InvokeFrom
|
||||
from models.execution_extra_content import ExecutionExtraContent, HumanInputContent
|
||||
from models.human_input import (
|
||||
ConsoleRecipientPayload,
|
||||
HumanInputDelivery,
|
||||
HumanInputForm,
|
||||
HumanInputFormRecipient,
|
||||
RecipientType,
|
||||
)
|
||||
from models.model import App, Conversation, Message
|
||||
from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository
|
||||
|
||||
|
||||
@dataclass
class _TestScope:
    """Identifiers of the rows one test owns, used to keep DB data isolated.

    Every field starts out as an empty string and is filled in once the base
    entities have been flushed to the database.
    """

    # Tenant that owns everything created for the test.
    tenant_id: str = ""
    # App the conversations and messages hang off.
    app_id: str = ""
    # Account acting as creator / owner.
    user_id: str = ""
|
||||
|
||||
|
||||
def _cleanup_scope_data(session: Session, scope: _TestScope) -> None:
    """Delete every row created for *scope* and commit.

    Statements run in the listed order: rows that are looked up through the
    tenant's forms are removed before the forms themselves, and the tenant is
    removed last.
    """
    owned_by_tenant = HumanInputForm.tenant_id == scope.tenant_id
    tenant_form_ids = select(HumanInputForm.id).where(owned_by_tenant)
    tenant_run_ids = select(HumanInputForm.workflow_run_id).where(owned_by_tenant)

    statements = (
        delete(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(tenant_form_ids)),
        delete(HumanInputDelivery).where(HumanInputDelivery.form_id.in_(tenant_form_ids)),
        delete(ExecutionExtraContent).where(ExecutionExtraContent.workflow_run_id.in_(tenant_run_ids)),
        delete(HumanInputForm).where(owned_by_tenant),
        delete(Message).where(Message.app_id == scope.app_id),
        delete(Conversation).where(Conversation.app_id == scope.app_id),
        delete(App).where(App.id == scope.app_id),
        delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == scope.tenant_id),
        delete(Account).where(Account.id == scope.user_id),
        delete(Tenant).where(Tenant.id == scope.tenant_id),
    )
    for statement in statements:
        session.execute(statement)
    session.commit()
|
||||
|
||||
|
||||
def _seed_base_entities(session: Session, scope: _TestScope) -> None:
    """Insert a tenant, an owner account and a chat app, recording their IDs on *scope*.

    Rows are flushed (not committed) so that generated IDs are available; the
    caller decides when to commit.
    """
    new_tenant = Tenant(name="Test Tenant")
    session.add(new_tenant)
    session.flush()
    scope.tenant_id = new_tenant.id

    owner = Account(
        name="Test Account",
        email=f"test_{uuid4()}@example.com",
        password="hashed-password",
        password_salt="salt",
        interface_language="en-US",
        timezone="UTC",
    )
    session.add(owner)
    session.flush()
    scope.user_id = owner.id

    # Link the account to the tenant as its current owner.
    session.add(
        TenantAccountJoin(
            tenant_id=scope.tenant_id,
            account_id=scope.user_id,
            role=TenantAccountRole.OWNER,
            current=True,
        )
    )

    chat_app = App(
        tenant_id=scope.tenant_id,
        name="Test App",
        description="",
        mode="chat",
        icon_type="emoji",
        icon="bot",
        icon_background="#FFFFFF",
        enable_site=False,
        enable_api=True,
        api_rpm=100,
        api_rph=100,
        is_demo=False,
        is_public=False,
        is_universal=False,
        created_by=scope.user_id,
        updated_by=scope.user_id,
    )
    session.add(chat_app)
    session.flush()
    scope.app_id = chat_app.id
|
||||
|
||||
|
||||
def _create_conversation(session: Session, scope: _TestScope) -> Conversation:
    """Insert and flush a chat conversation for the scope's app, started by the scope's account."""
    conversation_fields = {
        "app_id": scope.app_id,
        "mode": "chat",
        "name": "Test Conversation",
        "summary": "",
        "introduction": "",
        "system_instruction": "",
        "status": "normal",
        "invoke_from": InvokeFrom.EXPLORE,
        "from_source": ConversationFromSource.CONSOLE,
        "from_account_id": scope.user_id,
        "from_end_user_id": None,
    }
    new_conversation = Conversation(**conversation_fields)
    # inputs is assigned after construction, as in the original helper.
    new_conversation.inputs = {}
    session.add(new_conversation)
    session.flush()
    return new_conversation
|
||||
|
||||
|
||||
def _create_message(
    session: Session,
    scope: _TestScope,
    conversation_id: str,
    workflow_run_id: str,
) -> Message:
    """Insert and flush a message in the given conversation, tied to *workflow_run_id*."""
    unit_price = Decimal("0.001")
    message_fields = {
        "app_id": scope.app_id,
        "conversation_id": conversation_id,
        "inputs": {},
        "query": "test query",
        "message": {"messages": []},
        "answer": "test answer",
        "message_tokens": 50,
        "message_unit_price": unit_price,
        "answer_tokens": 80,
        "answer_unit_price": unit_price,
        "provider_response_latency": 0.5,
        "currency": "USD",
        "from_source": ConversationFromSource.CONSOLE,
        "from_account_id": scope.user_id,
        "workflow_run_id": workflow_run_id,
    }
    new_message = Message(**message_fields)
    session.add(new_message)
    session.flush()
    return new_message
|
||||
|
||||
|
||||
def _create_submitted_form(
    session: Session,
    scope: _TestScope,
    *,
    workflow_run_id: str,
    action_id: str = "approve",
    action_title: str = "Approve",
    node_title: str = "Approval",
) -> HumanInputForm:
    """Insert and flush a SUBMITTED human-input form owned by *scope*.

    The form definition offers a single user action and the form records that
    same action as the selected one. It expires one day from now.

    Args:
        session: Session used to add and flush the form.
        scope: Supplies the tenant and app IDs the form belongs to.
        workflow_run_id: Workflow run the form is attached to.
        action_id: ID of the single user action, also stored as the selected action.
        action_title: Title of that action; also used in the form's rendered content.
        node_title: Node title stored in the form definition.

    Returns:
        The flushed form.
    """
    # Imported locally so the helper does not depend on a module-level import change.
    from datetime import timezone

    # datetime.utcnow() is deprecated since Python 3.12. Build the same naive
    # UTC value from an aware "now" so the stored value keeps its old shape.
    expiration_time = datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(days=1)

    form_definition = FormDefinition(
        form_content="content",
        inputs=[],
        user_actions=[UserAction(id=action_id, title=action_title)],
        rendered_content="rendered",
        expiration_time=expiration_time,
        node_title=node_title,
        display_in_ui=True,
    )
    form = HumanInputForm(
        tenant_id=scope.tenant_id,
        app_id=scope.app_id,
        workflow_run_id=workflow_run_id,
        node_id="node-id",
        form_definition=form_definition.model_dump_json(),
        rendered_content=f"Rendered {action_title}",
        status=HumanInputFormStatus.SUBMITTED,
        expiration_time=expiration_time,
        selected_action_id=action_id,
    )
    session.add(form)
    session.flush()
    return form
|
||||
|
||||
|
||||
def _create_waiting_form(
    session: Session,
    scope: _TestScope,
    *,
    workflow_run_id: str,
    default_values: dict | None = None,
) -> HumanInputForm:
    """Insert and flush a WAITING human-input form owned by *scope*.

    The form definition offers a single "approve" action and expires one day
    from now.

    Args:
        session: Session used to add and flush the form.
        scope: Supplies the tenant and app IDs the form belongs to.
        workflow_run_id: Workflow run the form is attached to.
        default_values: Default values stored in the form definition. When
            omitted (``None``), ``{"name": "John"}`` is used. An explicitly
            passed empty dict is kept as-is.

    Returns:
        The flushed form.
    """
    # Imported locally so the helper does not depend on a module-level import change.
    from datetime import timezone

    # datetime.utcnow() is deprecated since Python 3.12. Build the same naive
    # UTC value from an aware "now" so the stored value keeps its old shape.
    expiration_time = datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(days=1)

    # Only fall back when the caller passed nothing; `or` would also discard
    # a deliberately empty dict.
    resolved_defaults = default_values if default_values is not None else {"name": "John"}

    form_definition = FormDefinition(
        form_content="content",
        inputs=[],
        user_actions=[UserAction(id="approve", title="Approve")],
        rendered_content="rendered",
        expiration_time=expiration_time,
        default_values=resolved_defaults,
        node_title="Approval",
        display_in_ui=True,
    )
    form = HumanInputForm(
        tenant_id=scope.tenant_id,
        app_id=scope.app_id,
        workflow_run_id=workflow_run_id,
        node_id="node-id",
        form_definition=form_definition.model_dump_json(),
        rendered_content="Rendered block",
        status=HumanInputFormStatus.WAITING,
        expiration_time=expiration_time,
    )
    session.add(form)
    session.flush()
    return form
|
||||
|
||||
|
||||
def _create_human_input_content(
    session: Session,
    *,
    workflow_run_id: str,
    message_id: str,
    form_id: str,
) -> HumanInputContent:
    """Build a HumanInputContent linking a message to a form and add it to the session.

    The row is only added, not flushed; callers flush or commit themselves.
    """
    link = HumanInputContent.new(workflow_run_id=workflow_run_id, message_id=message_id, form_id=form_id)
    session.add(link)
    return link
|
||||
|
||||
|
||||
def _create_recipient(
    session: Session,
    *,
    form_id: str,
    delivery_id: str,
    recipient_type: RecipientType = RecipientType.CONSOLE,
    access_token: str = "token-1",
) -> HumanInputFormRecipient:
    """Add a form recipient carrying a console payload with no account ID.

    The row is only added to the session, not flushed.
    """
    serialized_payload = ConsoleRecipientPayload(account_id=None).model_dump_json()
    new_recipient = HumanInputFormRecipient(
        form_id=form_id,
        delivery_id=delivery_id,
        recipient_type=recipient_type,
        recipient_payload=serialized_payload,
        access_token=access_token,
    )
    session.add(new_recipient)
    return new_recipient
|
||||
|
||||
|
||||
def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery:
    """Insert and flush a WEBAPP delivery row for *form_id* with a default console payload."""
    from dify_graph.nodes.human_input.enums import DeliveryMethodType
    from models.human_input import ConsoleDeliveryPayload

    channel_payload = ConsoleDeliveryPayload().model_dump_json()
    new_delivery = HumanInputDelivery(
        form_id=form_id,
        delivery_method_type=DeliveryMethodType.WEBAPP,
        channel_payload=channel_payload,
    )
    session.add(new_delivery)
    session.flush()
    return new_delivery
|
||||
|
||||
|
||||
@pytest.fixture
def repository(db_session_with_containers: Session) -> SQLAlchemyExecutionExtraContentRepository:
    """Return a repository whose sessions bind to the engine behind the testcontainers session."""
    bound_engine = db_session_with_containers.get_bind()
    # get_bind() may return a Connection; the repository needs a real Engine.
    assert isinstance(bound_engine, Engine)
    session_factory = sessionmaker(bind=bound_engine, expire_on_commit=False)
    return SQLAlchemyExecutionExtraContentRepository(session_factory)
|
||||
|
||||
|
||||
@pytest.fixture
def test_scope(db_session_with_containers: Session) -> Generator[_TestScope]:
    """Yield a freshly seeded, committed scope and delete its rows once the test is done."""
    session = db_session_with_containers
    fresh_scope = _TestScope()
    _seed_base_entities(session, fresh_scope)
    session.commit()

    yield fresh_scope

    _cleanup_scope_data(session, fresh_scope)
|
||||
|
||||
|
||||
class TestGetByMessageIds:
    """Exercises SQLAlchemyExecutionExtraContentRepository.get_by_message_ids against a real database."""

    def test_groups_contents_by_message(
        self,
        db_session_with_containers: Session,
        repository: SQLAlchemyExecutionExtraContentRepository,
        test_scope: _TestScope,
    ) -> None:
        """A submitted form is mapped onto its message; a message without content yields an empty list."""
        session = db_session_with_containers
        run_id = str(uuid4())
        conversation = _create_conversation(session, test_scope)
        first_msg = _create_message(session, test_scope, conversation.id, run_id)
        second_msg = _create_message(session, test_scope, conversation.id, run_id)

        submitted_form = _create_submitted_form(
            session,
            test_scope,
            workflow_run_id=run_id,
            action_id="approve",
            action_title="Approve",
        )
        # Only the first message gets content attached.
        _create_human_input_content(
            session,
            workflow_run_id=run_id,
            message_id=first_msg.id,
            form_id=submitted_form.id,
        )
        session.commit()

        grouped = repository.get_by_message_ids([first_msg.id, second_msg.id])

        assert len(grouped) == 2

        # First message: exactly one submitted content item.
        assert len(grouped[0]) == 1
        content = grouped[0][0]
        assert content.submitted is True
        assert content.workflow_run_id == run_id
        assert content.form_submission_data is not None
        submission = content.form_submission_data
        assert submission.action_id == "approve"
        assert submission.action_text == "Approve"
        assert submission.rendered_content == "Rendered Approve"
        assert submission.node_id == "node-id"
        assert submission.node_title == "Approval"

        # Second message: nothing attached.
        assert grouped[1] == []

    def test_returns_unsubmitted_form_definition(
        self,
        db_session_with_containers: Session,
        repository: SQLAlchemyExecutionExtraContentRepository,
        test_scope: _TestScope,
    ) -> None:
        """A waiting form comes back with its form definition, recipient token and default values."""
        session = db_session_with_containers
        run_id = str(uuid4())
        conversation = _create_conversation(session, test_scope)
        message = _create_message(session, test_scope, conversation.id, run_id)

        waiting_form = _create_waiting_form(
            session,
            test_scope,
            workflow_run_id=run_id,
            default_values={"name": "John"},
        )
        delivery = _create_delivery(session, form_id=waiting_form.id)
        _create_recipient(
            session,
            form_id=waiting_form.id,
            delivery_id=delivery.id,
            access_token="token-1",
        )
        _create_human_input_content(
            session,
            workflow_run_id=run_id,
            message_id=message.id,
            form_id=waiting_form.id,
        )
        session.commit()

        grouped = repository.get_by_message_ids([message.id])

        assert len(grouped) == 1
        assert len(grouped[0]) == 1
        pending = grouped[0][0]
        assert pending.submitted is False
        assert pending.workflow_run_id == run_id
        assert pending.form_definition is not None

        definition = pending.form_definition
        assert definition.form_id == waiting_form.id
        assert definition.node_id == "node-id"
        assert definition.node_title == "Approval"
        assert definition.form_content == "Rendered block"
        assert definition.display_in_ui is True
        assert definition.form_token == "token-1"
        assert definition.resolved_default_values == {"name": "John"}
        assert definition.expiration_time == int(waiting_form.expiration_time.timestamp())

    def test_empty_message_ids_returns_empty_list(
        self,
        repository: SQLAlchemyExecutionExtraContentRepository,
    ) -> None:
        """An empty list of message IDs produces an empty result."""
        outcome = repository.get_by_message_ids([])
        assert outcome == []
|
||||
|
|
@ -13,6 +13,7 @@ from uuid import uuid4
|
|||
|
||||
import pytest
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models import Account
|
||||
from models.dataset import Dataset, Document
|
||||
|
|
@ -91,7 +92,7 @@ class DocumentStatusTestDataFactory:
|
|||
name=name,
|
||||
created_from=DocumentCreatedFrom.WEB,
|
||||
created_by=created_by,
|
||||
doc_form="text_model",
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
document.id = document_id
|
||||
document.indexing_status = indexing_status
|
||||
|
|
|
|||
|
|
@ -525,3 +525,147 @@ class TestAPIBasedExtensionService:
|
|||
# Try to get extension with wrong tenant ID
|
||||
with pytest.raises(ValueError, match="API based extension is not found"):
|
||||
APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id)
|
||||
|
||||
def test_save_extension_api_key_exactly_four_chars_rejected(
    self, db_session_with_containers: Session, mock_external_service_dependencies
):
    """Boundary check: saving with a four-character API key raises the minimum-length error."""
    faker = Faker()
    _, tenant = self._create_test_account_and_tenant(db_session_with_containers, mock_external_service_dependencies)
    assert tenant is not None

    too_short = APIBasedExtension(
        tenant_id=tenant.id,
        name=faker.company(),
        api_endpoint=f"https://{faker.domain_name()}/api",
        api_key="1234",
    )

    with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
        APIBasedExtensionService.save(too_short)
|
||||
|
||||
def test_save_extension_api_key_exactly_five_chars_accepted(
    self, db_session_with_containers: Session, mock_external_service_dependencies
):
    """Boundary check: a five-character API key is saved and the result has an ID."""
    faker = Faker()
    _, tenant = self._create_test_account_and_tenant(db_session_with_containers, mock_external_service_dependencies)
    assert tenant is not None

    just_long_enough = APIBasedExtension(
        tenant_id=tenant.id,
        name=faker.company(),
        api_endpoint=f"https://{faker.domain_name()}/api",
        api_key="12345",
    )

    persisted = APIBasedExtensionService.save(just_long_enough)
    assert persisted.id is not None
|
||||
|
||||
def test_save_extension_requestor_constructor_error(
    self, db_session_with_containers: Session, mock_external_service_dependencies
):
    """A failure while constructing the requestor surfaces as a ValueError with a connection-error message."""
    faker = Faker()
    _, tenant = self._create_test_account_and_tenant(db_session_with_containers, mock_external_service_dependencies)
    assert tenant is not None

    # Make the mocked requestor class itself blow up when instantiated.
    requestor_cls = mock_external_service_dependencies["requestor"]
    requestor_cls.side_effect = RuntimeError("bad config")

    candidate = APIBasedExtension(
        tenant_id=tenant.id,
        name=faker.company(),
        api_endpoint=f"https://{faker.domain_name()}/api",
        api_key=faker.password(length=20),
    )

    with pytest.raises(ValueError, match="connection error: bad config"):
        APIBasedExtensionService.save(candidate)
|
||||
|
||||
def test_save_extension_network_exception(
    self, db_session_with_containers: Session, mock_external_service_dependencies
):
    """A ConnectionError raised by the requestor's request call surfaces as a ValueError."""
    faker = Faker()
    _, tenant = self._create_test_account_and_tenant(db_session_with_containers, mock_external_service_dependencies)
    assert tenant is not None

    # Simulate the network dropping when the mocked requestor sends its request.
    requestor = mock_external_service_dependencies["requestor_instance"]
    requestor.request.side_effect = ConnectionError("network failure")

    candidate = APIBasedExtension(
        tenant_id=tenant.id,
        name=faker.company(),
        api_endpoint=f"https://{faker.domain_name()}/api",
        api_key=faker.password(length=20),
    )

    with pytest.raises(ValueError, match="connection error: network failure"):
        APIBasedExtensionService.save(candidate)
|
||||
|
||||
def test_save_extension_update_duplicate_name_rejected(
    self, db_session_with_containers: Session, mock_external_service_dependencies
):
    """Renaming a saved extension to a name already used by another one is refused."""
    faker = Faker()
    _, tenant = self._create_test_account_and_tenant(db_session_with_containers, mock_external_service_dependencies)
    assert tenant is not None

    # First extension owns the name "Extension Alpha".
    APIBasedExtensionService.save(
        APIBasedExtension(
            tenant_id=tenant.id,
            name="Extension Alpha",
            api_endpoint=f"https://{faker.domain_name()}/api",
            api_key=faker.password(length=20),
        )
    )
    beta = APIBasedExtensionService.save(
        APIBasedExtension(
            tenant_id=tenant.id,
            name="Extension Beta",
            api_endpoint=f"https://{faker.domain_name()}/api",
            api_key=faker.password(length=20),
        )
    )

    # Re-saving the second extension under the first one's name must fail.
    beta.name = "Extension Alpha"
    with pytest.raises(ValueError, match="name must be unique, it is already existed"):
        APIBasedExtensionService.save(beta)
|
||||
|
||||
def test_get_all_returns_empty_for_different_tenant(
    self, db_session_with_containers: Session, mock_external_service_dependencies
):
    """An extension saved for one tenant does not show up when listing another tenant's extensions."""
    faker = Faker()
    _, owning_tenant = self._create_test_account_and_tenant(
        db_session_with_containers, mock_external_service_dependencies
    )
    _, other_tenant = self._create_test_account_and_tenant(
        db_session_with_containers, mock_external_service_dependencies
    )
    assert owning_tenant is not None

    owned_extension = APIBasedExtension(
        tenant_id=owning_tenant.id,
        name=faker.company(),
        api_endpoint=f"https://{faker.domain_name()}/api",
        api_key=faker.password(length=20),
    )
    APIBasedExtensionService.save(owned_extension)

    assert other_tenant is not None
    visible_to_other = APIBasedExtensionService.get_all_by_tenant_id(other_tenant.id)
    assert visible_to_other == []
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
|||
|
||||
from constants.model_template import default_app_templates
|
||||
from models import Account
|
||||
from models.model import App, Site
|
||||
from models.model import App, IconType, Site
|
||||
from services.account_service import AccountService, TenantService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
|
@ -463,6 +463,109 @@ class TestAppService:
|
|||
assert updated_app.tenant_id == app.tenant_id
|
||||
assert updated_app.created_by == app.created_by
|
||||
|
||||
def test_update_app_should_preserve_icon_type_when_omitted(
    self, db_session_with_containers: Session, mock_external_service_dependencies
):
    """
    Test update_app keeps the persisted icon_type when the update payload carries icon_type=None.
    """
    faker = Faker()

    owner = AccountService.create_account(
        email=faker.email(),
        name=faker.name(),
        interface_language="en-US",
        password=generate_valid_password(faker),
    )
    TenantService.create_owner_tenant_if_not_exist(owner, name=faker.company())
    tenant = owner.current_tenant

    from services.app_service import AppService

    service = AppService()
    creation_args = {
        "name": faker.company(),
        "description": faker.text(max_nb_chars=100),
        "mode": "chat",
        "icon_type": "emoji",
        "icon": "🎯",
        "icon_background": "#45B7D1",
    }
    app = service.create_app(tenant.id, creation_args, owner)

    # update_app reads services.app_service.current_user, so stand in for it.
    acting_user = create_autospec(Account, instance=True)
    acting_user.id = owner.id
    acting_user.current_tenant_id = owner.current_tenant_id

    update_args = {
        "name": "Updated App Name",
        "description": "Updated app description",
        "icon_type": None,
        "icon": "🔄",
        "icon_background": "#FF8C42",
        "use_icon_as_answer_icon": True,
    }
    with patch("services.app_service.current_user", acting_user):
        updated_app = service.update_app(app, update_args)

    assert updated_app.icon_type == IconType.EMOJI
|
||||
|
||||
def test_update_app_should_reject_empty_icon_type(
    self, db_session_with_containers: Session, mock_external_service_dependencies
):
    """
    Test update_app raises ValueError when icon_type is an explicit empty string.
    """
    faker = Faker()

    owner = AccountService.create_account(
        email=faker.email(),
        name=faker.name(),
        interface_language="en-US",
        password=generate_valid_password(faker),
    )
    TenantService.create_owner_tenant_if_not_exist(owner, name=faker.company())
    tenant = owner.current_tenant

    from services.app_service import AppService

    service = AppService()
    creation_args = {
        "name": faker.company(),
        "description": faker.text(max_nb_chars=100),
        "mode": "chat",
        "icon_type": "emoji",
        "icon": "🎯",
        "icon_background": "#45B7D1",
    }
    app = service.create_app(tenant.id, creation_args, owner)

    # update_app reads services.app_service.current_user, so stand in for it.
    acting_user = create_autospec(Account, instance=True)
    acting_user.id = owner.id
    acting_user.current_tenant_id = owner.current_tenant_id

    update_args = {
        "name": "Updated App Name",
        "description": "Updated app description",
        "icon_type": "",
        "icon": "🔄",
        "icon_background": "#FF8C42",
        "use_icon_as_answer_icon": True,
    }
    with patch("services.app_service.current_user", acting_user), pytest.raises(ValueError):
        service.update_app(app, update_args)
|
||||
|
||||
def test_update_app_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app name update.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,80 @@
|
|||
"""Testcontainers integration tests for AttachmentService."""
|
||||
|
||||
import base64
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services.attachment_service as attachment_service_module
|
||||
from extensions.ext_database import db
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
from services.attachment_service import AttachmentService
|
||||
|
||||
|
||||
class TestAttachmentService:
    """Covers AttachmentService construction and get_file_base64 with storage patched out."""

    def _create_upload_file(self, db_session_with_containers, *, tenant_id: str | None = None) -> UploadFile:
        """Commit and return a small text UploadFile; a random tenant ID is used when none is given."""
        record = UploadFile(
            tenant_id=tenant_id or str(uuid4()),
            storage_type=StorageType.OPENDAL,
            key=f"upload/{uuid4()}.txt",
            name="test-file.txt",
            size=100,
            extension="txt",
            mime_type="text/plain",
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=str(uuid4()),
            created_at=datetime.now(UTC),
            used=False,
        )
        db_session_with_containers.add(record)
        db_session_with_containers.commit()
        return record

    def test_should_initialize_with_sessionmaker(self):
        """A sessionmaker passed in is stored as-is."""
        factory = sessionmaker()

        service = AttachmentService(session_factory=factory)

        assert service._session_maker is factory

    def test_should_initialize_with_engine(self):
        """An Engine passed in ends up as the bind of sessions the service creates."""
        in_memory_engine = create_engine("sqlite:///:memory:")

        service = AttachmentService(session_factory=in_memory_engine)
        opened_session = service._session_maker()
        try:
            assert opened_session.bind == in_memory_engine
        finally:
            opened_session.close()
            in_memory_engine.dispose()

    @pytest.mark.parametrize("invalid_session_factory", [None, "not-a-session-factory", 1])
    def test_should_raise_assertion_error_for_invalid_session_factory(self, invalid_session_factory):
        """Anything that is neither a sessionmaker nor an Engine trips the constructor's assertion."""
        with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."):
            AttachmentService(session_factory=invalid_session_factory)

    def test_should_return_base64_when_file_exists(self, db_session_with_containers):
        """The stored bytes are loaded by key exactly once and returned base64-encoded."""
        stored_file = self._create_upload_file(db_session_with_containers)
        service = AttachmentService(session_factory=sessionmaker(bind=db.engine))
        raw_bytes = b"binary-content"

        with patch.object(attachment_service_module.storage, "load_once", return_value=raw_bytes) as mock_load:
            encoded = service.get_file_base64(stored_file.id)

        assert encoded == base64.b64encode(b"binary-content").decode()
        mock_load.assert_called_once_with(stored_file.key)

    def test_should_raise_not_found_when_file_missing(self, db_session_with_containers):
        """An unknown file ID raises NotFound and storage is never touched."""
        service = AttachmentService(session_factory=sessionmaker(bind=db.engine))
        unknown_id = str(uuid4())

        with patch.object(attachment_service_module.storage, "load_once") as mock_load:
            with pytest.raises(NotFound, match="File not found"):
                service.get_file_base64(unknown_id)

        mock_load.assert_not_called()
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
"""Testcontainers integration tests for ConversationVariableUpdater."""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from dify_graph.variables import StringVariable
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import ConversationVariable
|
||||
from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater
|
||||
|
||||
|
||||
class TestConversationVariableUpdater:
    """Integration tests for ConversationVariableUpdater.update and .flush."""

    def _create_conversation_variable(
        self, db_session_with_containers, *, conversation_id: str, variable: StringVariable, app_id: str | None = None
    ) -> ConversationVariable:
        """Insert and commit a ConversationVariable row holding ``variable`` serialized as JSON.

        When ``app_id`` is omitted (or falsy) a random UUID string is used instead.
        """
        owning_app = app_id or str(uuid4())
        record = ConversationVariable(
            id=variable.id,
            conversation_id=conversation_id,
            app_id=owning_app,
            data=variable.model_dump_json(),
        )

        session = db_session_with_containers
        session.add(record)
        session.commit()
        return record

    def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers):
        """update() overwrites the stored JSON so a fresh read sees the new value."""
        conversation_id = str(uuid4())
        original = StringVariable(id=str(uuid4()), name="topic", value="old value")
        self._create_conversation_variable(
            db_session_with_containers,
            conversation_id=conversation_id,
            variable=original,
        )
        replacement = StringVariable(id=original.id, name="topic", value="new value")

        session_factory = sessionmaker(bind=db.engine)
        ConversationVariableUpdater(session_factory).update(conversation_id=conversation_id, variable=replacement)

        # Drop cached state so the lookup below re-reads the row from the database.
        db_session_with_containers.expire_all()
        primary_key = (original.id, conversation_id)
        persisted = db_session_with_containers.get(ConversationVariable, primary_key)
        assert persisted is not None
        assert persisted.data == replacement.model_dump_json()

    def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers):
        """update() raises ConversationVariableNotFoundError when no matching row exists."""
        conversation_id = str(uuid4())
        orphan = StringVariable(id=str(uuid4()), name="topic", value="value")
        updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))

        expected_failure = pytest.raises(
            ConversationVariableNotFoundError, match="conversation variable not found in the database"
        )
        with expected_failure:
            updater.update(conversation_id=conversation_id, variable=orphan)

    def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers):
        """flush() returns None."""
        updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))

        assert updater.flush() is None
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue