Compare commits


77 Commits

Author SHA1 Message Date
Xiyuan Chen 085af2fdfd
Merge bf422dfd13 into d14635625c 2026-03-24 09:23:41 +00:00
yyh d14635625c
feat(web): refactor pricing modal scrolling and accessibility (#34011)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-24 17:18:36 +08:00
Stephen Zhou 0c3d11f920
refactor: lazy load large modules (#33888) 2026-03-24 15:29:42 +08:00
GareArc bf422dfd13
feat(telemetry): add enterprise OTEL telemetry with gateway, traces, metrics, and logs
Enterprise-only observability layer that emits OpenTelemetry traces, metrics,
and logs for workflow executions, LLM calls, message processing, and node runs.

Key components:
- core/telemetry: CE-safe gateway and event facade (no-ops when EE disabled)
- enterprise/telemetry: full trace/metric/log pipeline with OTLP exporter
- extensions/ext_enterprise_telemetry: Flask extension for lifecycle management
- tasks/enterprise_telemetry_task: Celery task for async metric dispatch

Wiring changes:
- app_factory: register ext_enterprise_telemetry extension
- ext_celery: conditionally import enterprise_telemetry_task
- configs/enterprise: add EnterpriseTelemetryConfig with OTLP settings
- ops_trace_manager: add enterprise telemetry hooks with lazy imports
- trace_entity: add new trace types (WorkflowNode, PromptGeneration, DraftNode)
- ext_otel: improve gRPC TLS auto-detection and optional auth headers

All enterprise imports are guarded behind ENTERPRISE_ENABLED +
ENTERPRISE_TELEMETRY_ENABLED feature flags with graceful fallbacks.
2026-03-24 00:24:20 -07:00
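The flag-guarded facade this commit message describes can be sketched as follows. This is a minimal illustration of the pattern only, not Dify's actual `core/telemetry` API; all names here are hypothetical.

```python
# Minimal sketch of a CE-safe telemetry facade: emit() silently no-ops unless
# both feature flags are enabled. All names are illustrative, not Dify's API.
ENTERPRISE_ENABLED = False            # stand-ins for the real config flags
ENTERPRISE_TELEMETRY_ENABLED = False

class NoopEmitter:
    """CE build: drop every telemetry event."""
    def emit(self, event: dict) -> None:
        pass

class OtlpEmitter:
    """EE build: forward events to the OTLP pipeline (stubbed here)."""
    def __init__(self) -> None:
        self.sent: list[dict] = []

    def emit(self, event: dict) -> None:
        self.sent.append(event)

def get_emitter():
    # Enterprise code paths run only when both flags are set, so a
    # community-edition deployment never touches the EE pipeline.
    if ENTERPRISE_ENABLED and ENTERPRISE_TELEMETRY_ENABLED:
        return OtlpEmitter()
    return NoopEmitter()
```

With both flags left at their defaults, `get_emitter()` returns the no-op implementation, mirroring the "graceful fallbacks" the commit describes.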
QuantumGhost 1674f8c2fb
fix: fix omitted app icon_type updates (#33988)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-24 15:10:05 +08:00
Zhanyuan Guo 7fe25f1365
fix(rate_limit): flush redis cache when __init__ is triggered by changing max_active_requests (#33830)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-24 15:08:55 +08:00
Wu Tianwei 508350ec6a
test: enhance useChat hook tests with additional scenarios (#33928)
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-24 14:19:32 +08:00
yyh b0920ecd17
refactor(web): migrate plugin toast usage to new UI toast API and update tests (#34001)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-24 14:02:52 +08:00
tmimmanuel 8b634a9bee
refactor: use EnumText for ApiToolProvider.schema_type_str and Docume… (#33983) 2026-03-24 13:27:50 +09:00
BitToby ecd3a964c1
refactor(api): type auth service credentials with TypedDict (#33867) 2026-03-24 13:22:17 +09:00
yyh 0589fa423b
fix(sdk): patch flatted vulnerability in nodejs client lockfile (#33996) 2026-03-24 11:24:31 +08:00
Stephen Zhou 27c4faad4f
ci: update actions version, fix cache (#33950) 2026-03-24 10:52:27 +08:00
wangxiaolei fbd558762d
fix: fix chunk not display in indexed document (#33942) 2026-03-24 10:36:48 +08:00
yyh 075b8bf1ae
fix(web): update account settings header (#33965) 2026-03-24 10:04:08 +08:00
Desel72 49a1fae555
test: migrate password reset tests to testcontainers (#33974)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-24 06:04:34 +09:00
tmimmanuel cc17c8e883
refactor: use EnumText for TidbAuthBinding.status and MessageFile.type (#33975) 2026-03-24 05:38:29 +09:00
tmimmanuel 5d2cb3cd80
refactor: use EnumText for DocumentSegment.type (#33979) 2026-03-24 05:37:51 +09:00
Desel72 f2c71f3668
test: migrate oauth server service tests to testcontainers (#33958) 2026-03-24 03:15:22 +09:00
Desel72 0492ed7034
test: migrate api tools manage service tests to testcontainers (#33956)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-24 02:54:33 +09:00
Renzo dd4f504b39
refactor: select in remaining console app controllers (#33969)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-24 02:53:05 +09:00
tmimmanuel 75c3ef82d9
refactor: use EnumText for TenantCreditPool.pool_type (#33959) 2026-03-24 02:51:10 +09:00
Desel72 8ca1ebb96d
test: migrate workflow tools manage service tests to testcontainers (#33955)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-24 02:50:10 +09:00
Desel72 3f086b97b6
test: remove mock tests superseded by testcontainers (#33957) 2026-03-24 02:46:54 +09:00
tmimmanuel 4a2e9633db
refactor: use EnumText for ApiToken.type (#33961) 2026-03-24 02:46:06 +09:00
tmimmanuel 20fc69ae7f
refactor: use EnumText for WorkflowAppLog.created_from and WorkflowArchiveLog columns (#33954) 2026-03-24 02:44:46 +09:00
Desel72 f5cc1c8b75
test: migrate saved message service tests to testcontainers (#33949)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-23 22:26:31 +09:00
Desel72 6698b42f97
test: migrate api based extension service tests to testcontainers (#33952) 2026-03-23 22:20:53 +09:00
Desel72 848a041c25
test: migrate dataset service create dataset tests to testcontainers (#33945) 2026-03-23 22:20:25 +09:00
Baki Burak Öğün 29cff809b9
fix(i18n): comprehensive Turkish (tr-TR) translation fixes and missing keys (#33885)
Co-authored-by: bakiburakogun <bakiburakogun@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Baki Burak Öğün <b.burak.ogun@goc.local>
2026-03-23 21:19:53 +08:00
kurokobo 30deeb6f1c
feat(firecrawl): follow pagination when crawl status is completed (#33864)
Co-authored-by: Crazywoola <100913391+crazywoola@users.noreply.github.com>
2026-03-23 21:19:32 +08:00
Desel72 30dd36505c
test: migrate batch update document status tests to testcontainers (#33951) 2026-03-23 21:57:01 +09:00
Desel72 65223c8092
test: remove mock-based tests superseded by testcontainers (#33946) 2026-03-23 21:55:50 +09:00
Desel72 72e3fcd25f
test: migrate end user service batch tests to testcontainers (#33947) 2026-03-23 21:54:37 +09:00
Desel72 4b4a5c058e
test: migrate file service zip and lookup tests to testcontainers (#33944) 2026-03-23 21:52:31 +09:00
letterbeezps 56e0907548
fix: do not block upsert for baidu vdb (#33280)
Co-authored-by: zhangping24 <zhangping24@baidu.com>
Co-authored-by: Crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-23 20:42:57 +08:00
Asuka Minato d956b919a0
ci: fix AttributeError: 'Flask' object has no attribute 'login_manager' FAILED #33891 (#33896)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-23 20:27:14 +08:00
Coding On Star 8b6fc07019
test(workflow): improve dataset item tests with edit and remove functionality (#33937) 2026-03-23 20:16:59 +08:00
wangxiaolei 1b1df37d23
feat: squid force ipv4 (#33556)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-23 17:56:19 +08:00
Desel72 6be7ba2928
refactor(web): replace MediaType enum with const object (#33834) 2026-03-23 17:53:55 +08:00
Coding On Star 2c8322c7b9
feat: enhance banner tracking with impression and click events (#33926)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-03-23 17:29:50 +08:00
Coding On Star fdc880bc67
test(workflow): add unit tests for workflow components (#33910)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-03-23 16:37:03 +08:00
Desel72 abda859075
refactor: migrate execution extra content repository tests from mocks to testcontainers (#33852) 2026-03-23 17:32:11 +09:00
yyh dc1a68661c
refactor(web): migrate members invite overlays to base ui (#33922)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-23 16:31:41 +08:00
dependabot[bot] edb261bc90
chore(deps-dev): bump the dev group across 1 directory with 12 updates (#33919)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-23 17:26:47 +09:00
dependabot[bot] 407f5f0cde
chore(deps-dev): bump alibabacloud-gpdb20160503 from 3.8.3 to 5.1.0 in /api in the vdb group (#33879)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-23 17:25:44 +09:00
Bowen Liang d7cafc6296
chore(dep): move hono and @hono/node-server to devDependencies (#33742) 2026-03-23 16:22:33 +08:00
dependabot[bot] 9336935295
chore(deps-dev): bump the storage group across 1 directory with 2 updates (#33915)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-23 16:57:17 +09:00
Mahmoud Hamdy e5e8c0711c
refactor: rewrite docker/dify-env-sync.sh in Python for better maintainability (#33466)
Co-authored-by: 99 <wh2099@pm.me>
2026-03-23 15:56:00 +08:00
Renzo 02e13e6d05
refactor: select in console app message controller (#33893)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-23 16:38:04 +09:00
dependabot[bot] a942d4c926
chore(deps): bump the python-packages group in /api with 4 updates (#33873)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-23 16:33:31 +09:00
dependabot[bot] df69997d8e
chore(deps): bump google-cloud-aiplatform from 1.141.0 to 1.142.0 in /api in the google group across 1 directory (#33917)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-23 16:32:05 +09:00
dependabot[bot] 4ab7ba4f2e
chore(deps): bump the llm group across 1 directory with 2 updates (#33916)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-23 16:31:23 +09:00
Copilot 76a23deba7
fix: crash when dataset icon_info is undefined in Knowledge Retrieval node (#33907)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2026-03-23 15:29:03 +08:00
yyh 25a83065d2
refactor(web): remove legacy data-source settings (#33905)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-23 15:19:20 +08:00
Desel72 82b094a2d5
refactor: migrate attachment service tests to testcontainers (#33900)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-23 16:18:46 +09:00
wangxiaolei 3c672703bc
chore: remove log level reset (#33914) 2026-03-23 16:17:15 +09:00
dependabot[bot] 33000d1c60
chore(deps): bump pydantic-extra-types from 2.11.0 to 2.11.1 in /api in the pydantic group (#33876)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-23 16:13:45 +09:00
dependabot[bot] 2809e4cc40
chore(deps-dev): update pytest-cov requirement from ~=7.0.0 to ~=7.1.0 in /api in the dev group (#33872)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-23 16:12:23 +09:00
dependabot[bot] 3f8f1fa003
chore(deps): bump google-api-python-client from 2.192.0 to 2.193.0 in /api in the google group (#33868)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-23 16:11:32 +09:00
dependabot[bot] 6604f8d506
chore(deps): bump litellm from 1.82.2 to 1.82.6 in /api in the llm group (#33870)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-23 16:10:41 +09:00
dependabot[bot] 368fc0bbe5
chore(deps): bump boto3 from 1.42.68 to 1.42.73 in /api in the storage group (#33871)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-23 16:10:02 +09:00
Desel72 6014853d45
test: migrate dataset permission tests to testcontainers (#33906) 2026-03-23 16:07:51 +09:00
Desel72 a71b7909fd
refactor: migrate conversation variable updater tests to testcontainers (#33903) 2026-03-23 16:06:08 +09:00
Desel72 1bf296982b
refactor: migrate workflow deletion tests to testcontainers (#33904)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-23 16:04:47 +09:00
tmimmanuel 2b6f761dfe
refactor: use EnumText for Conversation/Message invoke_from and from_source (#33901) 2026-03-23 16:03:35 +09:00
Desel72 6ecf89e262
refactor: migrate credit pool service tests to testcontainers (#33898)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-23 15:59:16 +09:00
Bipin Rimal e844edcf26
docs: EU AI Act compliance guide for Dify deployers (#33838) 2026-03-23 14:58:51 +08:00
Copilot 244f9e0c11
fix: handle null email/name from GitHub API for private-email users (#33882)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
2026-03-23 14:53:03 +08:00
github-actions[bot] abd68d2ea6
chore(i18n): sync translations with en-US (#33894)
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
2026-03-23 14:05:47 +08:00
wangxiaolei 01d97fa2cf
fix: type object 'str' has no attribute 'LLM' (#33899) 2026-03-23 14:51:56 +09:00
yyh 0478023900
refactor(web): migrate dataset-related toast callsites to base/ui/toast and update tests (#33892) 2026-03-23 13:13:52 +08:00
enci 110b8c925e
fix: remove contradictory optional chain in chat/utils.ts (#33841)
Co-authored-by: yoloni <yoloni@tencent.com>
2026-03-23 10:58:10 +08:00
Stephen Zhou eae821d645
chore: update deps (#33862) 2026-03-23 10:54:01 +08:00
Bowen Liang 282e76b1ee
feat(build): set root directory for turbopack configuration (#33812) 2026-03-23 10:04:53 +08:00
Bowen Liang 8384a836b4
fix(i18n): standardize datetime display to 24-hour format on /apps page (#33847) 2026-03-23 10:04:01 +08:00
hj24 886854eff8
chore: add guard tests for billing (#33831)
Co-authored-by: 非法操作 <hjlarry@163.com>
2026-03-23 09:45:32 +08:00
dependabot[bot] 6a8fa7b54e
chore(deps): bump anthropics/claude-code-action from 1.0.75 to 1.0.76 in the github-actions-dependencies group (#33875)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-23 10:22:44 +09:00
466 changed files with 30469 additions and 14053 deletions

View File

@@ -4,10 +4,9 @@ runs:
using: composite
steps:
- name: Setup Vite+
uses: voidzero-dev/setup-vp@4a524139920f87f9f7080d3b8545acac019e1852 # v1.0.0
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
with:
node-version-file: web/.nvmrc
working-directory: web
node-version-file: .nvmrc
cache: true
cache-dependency-path: web/pnpm-lock.yaml
run-install: |
cwd: ./web
run-install: true

View File

@@ -84,20 +84,20 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
uses: ./.github/actions/setup-web
- name: Restore ESLint cache
if: steps.changed-files.outputs.any_changed == 'true'
id: eslint-cache-restore
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
with:
path: web/.eslintcache
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
restore-keys: |
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: |
vp run lint:ci
# pnpm run lint:report
# continue-on-error: true
# - name: Annotate Code
# if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request'
# uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae
# with:
# eslint-report: web/eslint_report.json
# github-token: ${{ secrets.GITHUB_TOKEN }}
run: vp run lint:ci
- name: Web tsslint
if: steps.changed-files.outputs.any_changed == 'true'
@@ -114,6 +114,13 @@ jobs:
working-directory: ./web
run: vp run knip
- name: Save ESLint cache
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
with:
path: web/.eslintcache
key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
superlinter:
name: SuperLinter
runs-on: ubuntu-latest

View File

@@ -120,7 +120,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.detect_changes.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@df37d2f0760a4b5683a6e617c9325bc1a36443f6 # v1.0.75
uses: anthropics/claude-code-action@ff9acae5886d41a99ed4ec14b7dc147d55834722 # v1.0.77
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -353,6 +353,9 @@ BAIDU_VECTOR_DB_SHARD=1
BAIDU_VECTOR_DB_REPLICAS=3
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT=500
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO=0.05
BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS=300
# Upstash configuration
UPSTASH_VECTOR_URL=your-server-url

View File

@@ -143,6 +143,7 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_compress,
ext_database,
ext_enterprise_telemetry,
ext_fastopenapi,
ext_forward_refs,
ext_hosting_provider,
@@ -193,6 +194,7 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_fastopenapi,
ext_otel,
ext_enterprise_telemetry,
ext_request_logging,
ext_session_factory,
]

View File

@@ -10,6 +10,7 @@ from configs import dify_config
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
@@ -269,7 +270,7 @@ def migrate_knowledge_vector_database():
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == "hierarchical_model":
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []

View File

@@ -8,7 +8,7 @@ from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, Settings
from libs.file_utils import search_file_upwards
from .deploy import DeploymentConfig
from .enterprise import EnterpriseFeatureConfig
from .enterprise import EnterpriseFeatureConfig, EnterpriseTelemetryConfig
from .extra import ExtraServiceConfig
from .feature import FeatureConfig
from .middleware import MiddlewareConfig
@@ -73,6 +73,8 @@ class DifyConfig(
# Enterprise feature configs
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
EnterpriseFeatureConfig,
# Enterprise telemetry configs
EnterpriseTelemetryConfig,
):
model_config = SettingsConfigDict(
# read from dotenv format config file

View File

@@ -22,3 +22,49 @@ class EnterpriseFeatureConfig(BaseSettings):
ENTERPRISE_REQUEST_TIMEOUT: int = Field(
ge=1, description="Maximum timeout in seconds for enterprise requests", default=5
)
class EnterpriseTelemetryConfig(BaseSettings):
"""
Configuration for enterprise telemetry.
"""
ENTERPRISE_TELEMETRY_ENABLED: bool = Field(
description="Enable enterprise telemetry collection (also requires ENTERPRISE_ENABLED=true).",
default=False,
)
ENTERPRISE_OTLP_ENDPOINT: str = Field(
description="Enterprise OTEL collector endpoint.",
default="",
)
ENTERPRISE_OTLP_HEADERS: str = Field(
description="Auth headers for OTLP export (key=value,key2=value2).",
default="",
)
ENTERPRISE_OTLP_PROTOCOL: str = Field(
description="OTLP protocol: 'http' or 'grpc' (default: http).",
default="http",
)
ENTERPRISE_OTLP_API_KEY: str = Field(
description="Bearer token for enterprise OTLP export authentication.",
default="",
)
ENTERPRISE_INCLUDE_CONTENT: bool = Field(
description="Include input/output content in traces (privacy toggle).",
default=True,
)
ENTERPRISE_SERVICE_NAME: str = Field(
description="Service name for OTEL resource.",
default="dify",
)
ENTERPRISE_OTEL_SAMPLING_RATE: float = Field(
description="Sampling rate for enterprise traces (0.0 to 1.0, default 1.0 = 100%).",
default=1.0,
)
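The `ENTERPRISE_OTLP_HEADERS` field above documents a `key=value,key2=value2` format. A parser for that format might look like the following sketch; the helper name is hypothetical, and the actual Dify implementation may differ.

```python
# Sketch: turn the comma-separated ENTERPRISE_OTLP_HEADERS string into the
# header dict an OTLP exporter expects. Malformed or empty pairs are skipped.
def parse_otlp_headers(raw: str) -> dict[str, str]:
    """Parse 'key=value,key2=value2' into a header dict."""
    headers: dict[str, str] = {}
    for pair in raw.split(","):
        pair = pair.strip()
        if not pair or "=" not in pair:
            continue  # tolerate trailing commas and junk entries
        key, _, value = pair.partition("=")
        headers[key.strip()] = value.strip()
    return headers

demo = parse_otlp_headers("authorization=Bearer abc123,x-tenant=acme")
```

Splitting on the first `=` via `partition` keeps values that themselves contain `=` intact.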

View File

@@ -51,3 +51,18 @@ class BaiduVectorDBConfig(BaseSettings):
description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)",
default="COARSE_MODE",
)
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT: int = Field(
description="Auto build row count increment threshold (default is 500)",
default=500,
)
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO: float = Field(
description="Auto build row count increment ratio threshold (default is 0.05)",
default=0.05,
)
BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS: int = Field(
description="Timeout in seconds for rebuilding the index in Baidu Vector Database (default is 300 seconds)",
default=300,
)

View File

@@ -9,6 +9,7 @@ from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.enums import ApiTokenType
from models.model import ApiToken, App
from services.api_token_service import ApiTokenCache
@@ -47,7 +48,7 @@ def _get_resource(resource_id, tenant_id, resource_model):
class BaseApiKeyListResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required]
resource_type: str | None = None
resource_type: ApiTokenType | None = None
resource_model: type | None = None
resource_id_field: str | None = None
token_prefix: str | None = None
@@ -91,6 +92,7 @@
)
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
assert self.resource_type is not None, "resource_type must be set"
api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_tenant_id
@@ -104,7 +106,7 @@
class BaseApiKeyResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required]
resource_type: str | None = None
resource_type: ApiTokenType | None = None
resource_model: type | None = None
resource_id_field: str | None = None
@@ -159,7 +161,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for an app"""
return super().post(resource_id)
resource_type = "app"
resource_type = ApiTokenType.APP
resource_model = App
resource_id_field = "app_id"
token_prefix = "app-"
@@ -175,7 +177,7 @@ class AppApiKeyResource(BaseApiKeyResource):
"""Delete an API key for an app"""
return super().delete(resource_id, api_key_id)
resource_type = "app"
resource_type = ApiTokenType.APP
resource_model = App
resource_id_field = "app_id"
@@ -199,7 +201,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for a dataset"""
return super().post(resource_id)
resource_type = "dataset"
resource_type = ApiTokenType.DATASET
resource_model = Dataset
resource_id_field = "dataset_id"
token_prefix = "ds-"
@@ -215,6 +217,6 @@ class DatasetApiKeyResource(BaseApiKeyResource):
"""Delete an API key for a dataset"""
return super().delete(resource_id, api_key_id)
resource_type = "dataset"
resource_type = ApiTokenType.DATASET
resource_model = Dataset
resource_id_field = "dataset_id"

View File

@@ -95,7 +95,7 @@ class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
icon_type: str | None = Field(default=None, description="Icon type")
icon_type: IconType | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@@ -103,7 +103,7 @@
class UpdateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
icon_type: str | None = Field(default=None, description="Icon type")
icon_type: IconType | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
@@ -113,7 +113,7 @@
class CopyAppPayload(BaseModel):
name: str | None = Field(default=None, description="Name for the copied app")
description: str | None = Field(default=None, description="Description for the copied app", max_length=400)
icon_type: str | None = Field(default=None, description="Icon type")
icon_type: IconType | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@@ -594,7 +594,7 @@ class AppApi(Resource):
args_dict: AppService.ArgsDict = {
"name": args.name,
"description": args.description or "",
"icon_type": args.icon_type or "",
"icon_type": args.icon_type,
"icon": args.icon or "",
"icon_background": args.icon_background or "",
"use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,

View File

@@ -458,9 +458,7 @@ class ChatConversationApi(Resource):
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
subquery = (
db.session.query(
Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")
)
sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id"))
.outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
.subquery()
)
@@ -595,10 +593,8 @@ class ChatConversationDetailApi(Resource):
def _get_conversation(app_model, conversation_id):
current_user, _ = current_account_with_tenant()
conversation = (
db.session.query(Conversation)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
conversation = db.session.scalar(
sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1)
)
if not conversation:

View File

@@ -168,7 +168,7 @@ class InstructionGenerateApi(Resource):
try:
# Generate from nothing for a workflow node
if (args.current in (code_template, "")) and args.node_id != "":
app = db.session.query(App).where(App.id == args.flow_id).first()
app = db.session.get(App, args.flow_id)
if not app:
return {"error": f"app {args.flow_id} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app)

View File

@@ -2,6 +2,7 @@ import json
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
@@ -47,7 +48,7 @@ class AppMCPServerController(Resource):
@get_app_model
@marshal_with(app_server_model)
def get(self, app_model):
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
return server
@console_ns.doc("create_app_mcp_server")
@@ -98,7 +99,7 @@ class AppMCPServerController(Resource):
@edit_permission_required
def put(self, app_model):
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
server = db.session.get(AppMCPServer, payload.id)
if not server:
raise NotFound()
@@ -135,11 +136,10 @@ class AppMCPServerRefreshController(Resource):
@edit_permission_required
def get(self, server_id):
_, current_tenant_id = current_account_with_tenant()
server = (
db.session.query(AppMCPServer)
.where(AppMCPServer.id == server_id)
.where(AppMCPServer.tenant_id == current_tenant_id)
.first()
server = db.session.scalar(
select(AppMCPServer)
.where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id)
.limit(1)
)
if not server:
raise NotFound()

View File

@@ -4,7 +4,7 @@ from typing import Literal
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, select
from sqlalchemy import exists, func, select
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
@@ -244,27 +244,25 @@ class ChatMessageListApi(Resource):
def get(self, app_model):
args = ChatMessagesQuery.model_validate(request.args.to_dict())
conversation = (
db.session.query(Conversation)
conversation = db.session.scalar(
select(Conversation)
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
.first()
.limit(1)
)
if not conversation:
raise NotFound("Conversation Not Exists.")
if args.first_id:
first_message = (
db.session.query(Message)
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
.first()
first_message = db.session.scalar(
select(Message).where(Message.conversation_id == conversation.id, Message.id == args.first_id).limit(1)
)
if not first_message:
raise NotFound("First message not found")
history_messages = (
db.session.query(Message)
history_messages = db.session.scalars(
select(Message)
.where(
Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at,
@@ -272,16 +270,14 @@
)
.order_by(Message.created_at.desc())
.limit(args.limit)
.all()
)
).all()
else:
history_messages = (
db.session.query(Message)
history_messages = db.session.scalars(
select(Message)
.where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(args.limit)
.all()
)
).all()
# Initialize has_more based on whether we have a full page
if len(history_messages) == args.limit:
@@ -326,7 +322,9 @@ class MessageFeedbackApi(Resource):
message_id = str(args.message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
message = db.session.scalar(
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
)
if not message:
raise NotFound("Message Not Exists.")
@@ -375,7 +373,9 @@ class MessageAnnotationCountApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, app_model):
-        count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()
+        count = db.session.scalar(
+            select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id)
+        )

         return {"count": count}
@@ -479,7 +479,9 @@ class MessageApi(Resource):
     def get(self, app_model, message_id: str):
         message_id = str(message_id)

-        message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
+        message = db.session.scalar(
+            select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
+        )
         if not message:
             raise NotFound("Message Not Exists.")
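The recurring refactor in this file swaps the legacy `Query.first()` API for SQLAlchemy 2.0-style `session.scalar(select(...).limit(1))`. Both compile to a `... LIMIT 1` query and return the first row or `None`. A minimal stand-alone sketch of that SQL-level behavior (plain stdlib `sqlite3`, not Dify code; table and values are illustrative):

```python
import sqlite3

# In-memory table standing in for the Message model.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE message (id TEXT, conversation_id TEXT)")
conn.executemany(
    "INSERT INTO message VALUES (?, ?)",
    [("m1", "c1"), ("m2", "c1")],
)

# LIMIT 1 + fetchone(): first matching row...
row = conn.execute(
    "SELECT id FROM message WHERE conversation_id = ? LIMIT 1", ("c1",)
).fetchone()
print(row[0] if row else None)  # -> m1

# ...or None when nothing matches, mirroring scalar()/first() semantics.
missing = conn.execute(
    "SELECT id FROM message WHERE conversation_id = ? LIMIT 1", ("c9",)
).fetchone()
print(missing)  # -> None
```

The explicit `.limit(1)` keeps the generated SQL identical to what `first()` produced, while `scalar()` returns the mapped object directly instead of a `Row`.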


@@ -69,9 +69,7 @@ class ModelConfigResource(Resource):
         if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
             # get original app model config
-            original_app_model_config = (
-                db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first()
-            )
+            original_app_model_config = db.session.get(AppModelConfig, app_model.app_model_config_id)
             if original_app_model_config is None:
                 raise ValueError("Original app model config not found")
             agent_mode = original_app_model_config.agent_mode_dict


@@ -2,6 +2,7 @@ from typing import Literal

 from flask_restx import Resource, marshal_with
 from pydantic import BaseModel, Field, field_validator
+from sqlalchemy import select
 from werkzeug.exceptions import NotFound

 from constants.languages import supported_language
@@ -75,7 +76,7 @@ class AppSite(Resource):
     def post(self, app_model):
         args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
         current_user, _ = current_account_with_tenant()
-        site = db.session.query(Site).where(Site.app_id == app_model.id).first()
+        site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
         if not site:
             raise NotFound
@@ -124,7 +125,7 @@ class AppSiteAccessTokenReset(Resource):
     @marshal_with(app_site_model)
     def post(self, app_model):
         current_user, _ = current_account_with_tenant()
-        site = db.session.query(Site).where(Site.app_id == app_model.id).first()
+        site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
         if not site:
             raise NotFound


@@ -2,6 +2,8 @@ from collections.abc import Callable
 from functools import wraps
 from typing import ParamSpec, TypeVar, Union

+from sqlalchemy import select
+
 from controllers.console.app.error import AppNotFoundError
 from extensions.ext_database import db
 from libs.login import current_account_with_tenant
@@ -15,16 +17,14 @@ R1 = TypeVar("R1")

 def _load_app_model(app_id: str) -> App | None:
     _, current_tenant_id = current_account_with_tenant()
-    app_model = (
-        db.session.query(App)
-        .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
-        .first()
+    app_model = db.session.scalar(
+        select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
     )
     return app_model


 def _load_app_model_with_trial(app_id: str) -> App | None:
-    app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
+    app_model = db.session.scalar(select(App).where(App.id == app_id, App.status == "normal").limit(1))
     return app_model


@@ -1,4 +1,5 @@
 import logging
+import urllib.parse

 import httpx
 from flask import current_app, redirect, request
@@ -112,6 +113,9 @@ class OAuthCallback(Resource):
                 error_text = e.response.text
             logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
             return {"error": "OAuth process failed"}, 400
+        except ValueError as e:
+            logger.warning("OAuth error with %s", provider, exc_info=True)
+            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={urllib.parse.quote(str(e))}")

         if invite_token and RegisterService.is_valid_invite_token(invite_token):
             invitation = RegisterService.get_invitation_by_token(token=invite_token)


@@ -54,7 +54,7 @@ from fields.document_fields import document_status_fields
 from libs.login import current_account_with_tenant, login_required
 from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
 from models.dataset import DatasetPermission, DatasetPermissionEnum
-from models.enums import SegmentStatus
+from models.enums import ApiTokenType, SegmentStatus
 from models.provider_ids import ModelProviderID
 from services.api_token_service import ApiTokenCache
 from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@@ -777,7 +777,7 @@ class DatasetIndexingStatusApi(Resource):
 class DatasetApiKeyApi(Resource):
     max_keys = 10
     token_prefix = "dataset-"
-    resource_type = "dataset"
+    resource_type = ApiTokenType.DATASET

     @console_ns.doc("get_dataset_api_keys")
     @console_ns.doc(description="Get dataset API keys")
@@ -826,7 +826,7 @@ class DatasetApiKeyApi(Resource):
 @console_ns.route("/datasets/api-keys/<uuid:api_key_id>")
 class DatasetApiDeleteApi(Resource):
-    resource_type = "dataset"
+    resource_type = ApiTokenType.DATASET

     @console_ns.doc("delete_dataset_api_key")
     @console_ns.doc(description="Delete dataset API key")


@@ -705,7 +705,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
                 app_id=self._application_generate_entity.app_config.app_id,
                 workflow_id=self._workflow.id,
                 workflow_run_id=workflow_run_id,
-                created_from=created_from.value,
+                created_from=created_from,
                 created_by_role=self._created_by_role,
                 created_by=self._user_id,
             )


@@ -19,6 +19,7 @@ class RateLimit:
     _REQUEST_MAX_ALIVE_TIME = 10 * 60  # 10 minutes
     _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60  # recalculate request_count from request_detail every 5 minutes
     _instance_dict: dict[str, "RateLimit"] = {}
+    max_active_requests: int

     def __new__(cls, client_id: str, max_active_requests: int):
         if client_id not in cls._instance_dict:
@@ -27,7 +28,13 @@ class RateLimit:
         return cls._instance_dict[client_id]

     def __init__(self, client_id: str, max_active_requests: int):
+        flush_cache = hasattr(self, "max_active_requests") and self.max_active_requests != max_active_requests
         self.max_active_requests = max_active_requests
+        # Only flush here if this instance has already been fully initialized,
+        # i.e. the Redis key attributes exist. Otherwise, rely on the flush at
+        # the end of initialization below.
+        if flush_cache and hasattr(self, "active_requests_key") and hasattr(self, "max_active_requests_key"):
+            self.flush_cache(use_local_value=True)
         # must be called after max_active_requests is set
         if self.disabled():
             return
@@ -41,8 +48,6 @@ class RateLimit:
         self.flush_cache(use_local_value=True)

     def flush_cache(self, use_local_value=False):
-        if self.disabled():
-            return
         self.last_recalculate_time = time.time()
         # flush max active requests
         if use_local_value or not redis_client.exists(self.max_active_requests_key):
@@ -50,7 +55,8 @@ class RateLimit:
         else:
             self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8"))
             redis_client.expire(self.max_active_requests_key, timedelta(days=1))
-
+        if self.disabled():
+            return
         # flush max active requests (in-transit request list)
         if not redis_client.exists(self.active_requests_key):
             return
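The bug this hunk fixes comes from the per-`client_id` instance cache in `__new__`: a later `RateLimit(client_id, new_max)` returns the cached object, but Python still re-runs `__init__` on it, so a changed limit previously never reached Redis. A toy model of the pattern (not the Dify class; `flushed` stands in for `flush_cache(use_local_value=True)`):

```python
class RateLimitSketch:
    """Per-client singleton: one cached instance per client_id."""

    _instances: dict[str, "RateLimitSketch"] = {}

    def __new__(cls, client_id: str, max_active_requests: int):
        if client_id not in cls._instances:
            cls._instances[client_id] = super().__new__(cls)
        return cls._instances[client_id]

    def __init__(self, client_id: str, max_active_requests: int):
        # hasattr() distinguishes first construction (attribute absent)
        # from re-initialization of the cached instance.
        changed = hasattr(self, "max_active_requests") and self.max_active_requests != max_active_requests
        self.max_active_requests = max_active_requests
        self.flushed = changed  # real code would push the new limit to Redis here


a = RateLimitSketch("client-1", 10)
b = RateLimitSketch("client-1", 20)  # cached instance, __init__ runs again
print(a is b, b.max_active_requests, b.flushed)  # -> True 20 True
```

The `hasattr` guard is what lets `__init__` tell a fresh object from a cached one without adding extra state to `__new__`.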


@@ -181,10 +181,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
         arize_phoenix_config: ArizeConfig | PhoenixConfig,
     ):
         super().__init__(arize_phoenix_config)
-        import logging
-
-        logging.basicConfig()
-        logging.getLogger().setLevel(logging.DEBUG)
         self.arize_phoenix_config = arize_phoenix_config
         self.tracer, self.processor = setup_tracer(arize_phoenix_config)
         self.project = arize_phoenix_config.project


@@ -9,8 +9,8 @@ from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
 class BaseTraceInfo(BaseModel):
     message_id: str | None = None
     message_data: Any | None = None
-    inputs: Union[str, dict[str, Any], list] | None = None
-    outputs: Union[str, dict[str, Any], list] | None = None
+    inputs: Union[str, dict[str, Any], list[Any]] | None = None
+    outputs: Union[str, dict[str, Any], list[Any]] | None = None
     start_time: datetime | None = None
     end_time: datetime | None = None
     metadata: dict[str, Any]
@@ -18,7 +18,7 @@ class BaseTraceInfo(BaseModel):
     @field_validator("inputs", "outputs")
     @classmethod
-    def ensure_type(cls, v):
+    def ensure_type(cls, v: str | dict[str, Any] | list[Any] | None) -> str | dict[str, Any] | list[Any] | None:
         if v is None:
             return None
         if isinstance(v, str | dict | list):
@@ -27,6 +27,48 @@
model_config = ConfigDict(protected_namespaces=())
@property
def resolved_trace_id(self) -> str | None:
"""Get trace_id with intelligent fallback.
Priority:
1. External trace_id (from X-Trace-Id header)
2. workflow_run_id (if this trace type has it)
3. message_id (as final fallback)
"""
if self.trace_id:
return self.trace_id
# Try workflow_run_id (only exists on workflow-related traces)
workflow_run_id = getattr(self, "workflow_run_id", None)
if workflow_run_id:
return workflow_run_id
# Final fallback to message_id
return str(self.message_id) if self.message_id else None
@property
def resolved_parent_context(self) -> tuple[str | None, str | None]:
"""Resolve cross-workflow parent linking from metadata.
Extracts typed parent IDs from the untyped ``parent_trace_context``
metadata dict (set by tool_node when invoking nested workflows).
Returns:
(trace_correlation_override, parent_span_id_source) where
trace_correlation_override is the outer workflow_run_id and
parent_span_id_source is the outer node_execution_id.
"""
parent_ctx = self.metadata.get("parent_trace_context")
if not isinstance(parent_ctx, dict):
return None, None
trace_override = parent_ctx.get("parent_workflow_run_id")
parent_span = parent_ctx.get("parent_node_execution_id")
return (
trace_override if isinstance(trace_override, str) else None,
parent_span if isinstance(parent_span, str) else None,
)
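The fallback chain in `resolved_trace_id` above can be exercised with a small stand-in class (illustrative only; the real model is a pydantic `BaseTraceInfo`, and `workflow_run_id` only exists on workflow-related subclasses, hence the `getattr`):

```python
class TraceSketch:
    """Toy carrier for the three candidate trace identifiers."""

    def __init__(self, trace_id=None, workflow_run_id=None, message_id=None):
        self.trace_id = trace_id
        self.workflow_run_id = workflow_run_id
        self.message_id = message_id

    @property
    def resolved_trace_id(self):
        # 1. external trace_id wins
        if self.trace_id:
            return self.trace_id
        # 2. then workflow_run_id, if this trace type carries one
        if getattr(self, "workflow_run_id", None):
            return self.workflow_run_id
        # 3. message_id as the final fallback
        return str(self.message_id) if self.message_id else None


print(TraceSketch(trace_id="t1", workflow_run_id="w1").resolved_trace_id)  # -> t1
print(TraceSketch(workflow_run_id="w1", message_id="m1").resolved_trace_id)  # -> w1
print(TraceSketch(message_id="m1").resolved_trace_id)  # -> m1
```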
@field_serializer("start_time", "end_time")
def serialize_datetime(self, dt: datetime | None) -> str | None:
if dt is None:
@@ -48,7 +90,10 @@ class WorkflowTraceInfo(BaseTraceInfo):
     workflow_run_version: str
     error: str | None = None
     total_tokens: int
+    prompt_tokens: int | None = None
+    completion_tokens: int | None = None
     file_list: list[str]
+    invoked_by: str | None = None
     query: str
     metadata: dict[str, Any]
@@ -59,7 +104,7 @@ class MessageTraceInfo(BaseTraceInfo):
     answer_tokens: int
     total_tokens: int
     error: str | None = None
-    file_list: Union[str, dict[str, Any], list] | None = None
+    file_list: Union[str, dict[str, Any], list[Any]] | None = None
     message_file_data: Any | None = None
     conversation_mode: str
     gen_ai_server_time_to_first_token: float | None = None
@@ -106,7 +151,7 @@ class ToolTraceInfo(BaseTraceInfo):
     tool_config: dict[str, Any]
     time_cost: Union[int, float]
     tool_parameters: dict[str, Any]
-    file_url: Union[str, None, list] = None
+    file_url: Union[str, None, list[str]] = None


 class GenerateNameTraceInfo(BaseTraceInfo):
@@ -114,6 +159,79 @@ class GenerateNameTraceInfo(BaseTraceInfo):
tenant_id: str
class PromptGenerationTraceInfo(BaseTraceInfo):
"""Trace information for prompt generation operations (rule-generate, code-generate, etc.)."""
tenant_id: str
user_id: str
app_id: str | None = None
operation_type: str
instruction: str
prompt_tokens: int
completion_tokens: int
total_tokens: int
model_provider: str
model_name: str
latency: float
total_price: float | None = None
currency: str | None = None
error: str | None = None
model_config = ConfigDict(protected_namespaces=())
class WorkflowNodeTraceInfo(BaseTraceInfo):
workflow_id: str
workflow_run_id: str
tenant_id: str
node_execution_id: str
node_id: str
node_type: str
title: str
status: str
error: str | None = None
elapsed_time: float
index: int
predecessor_node_id: str | None = None
total_tokens: int = 0
total_price: float = 0.0
currency: str | None = None
model_provider: str | None = None
model_name: str | None = None
prompt_tokens: int | None = None
completion_tokens: int | None = None
tool_name: str | None = None
iteration_id: str | None = None
iteration_index: int | None = None
loop_id: str | None = None
loop_index: int | None = None
parallel_id: str | None = None
node_inputs: Mapping[str, Any] | None = None
node_outputs: Mapping[str, Any] | None = None
process_data: Mapping[str, Any] | None = None
invoked_by: str | None = None
model_config = ConfigDict(protected_namespaces=())
class DraftNodeExecutionTrace(WorkflowNodeTraceInfo):
pass
class TaskData(BaseModel):
app_id: str
trace_info_type: str
@@ -128,11 +246,31 @@ trace_info_info_map = {
     "DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo,
     "ToolTraceInfo": ToolTraceInfo,
     "GenerateNameTraceInfo": GenerateNameTraceInfo,
+    "PromptGenerationTraceInfo": PromptGenerationTraceInfo,
+    "WorkflowNodeTraceInfo": WorkflowNodeTraceInfo,
+    "DraftNodeExecutionTrace": DraftNodeExecutionTrace,
 }
class OperationType(StrEnum):
"""Operation type for token metric labels.
Used as a metric attribute on ``dify.tokens.input`` / ``dify.tokens.output``
counters so consumers can break down token usage by operation.
"""
WORKFLOW = "workflow"
NODE_EXECUTION = "node_execution"
MESSAGE = "message"
RULE_GENERATE = "rule_generate"
CODE_GENERATE = "code_generate"
STRUCTURED_OUTPUT = "structured_output"
INSTRUCTION_MODIFY = "instruction_modify"
class TraceTaskName(StrEnum):
CONVERSATION_TRACE = "conversation"
DRAFT_NODE_EXECUTION_TRACE = "draft_node_execution"
WORKFLOW_TRACE = "workflow"
MESSAGE_TRACE = "message"
MODERATION_TRACE = "moderation"
@@ -140,4 +278,6 @@ class TraceTaskName(StrEnum):
     DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
     TOOL_TRACE = "tool"
     GENERATE_NAME_TRACE = "generate_conversation_name"
+    PROMPT_GENERATION_TRACE = "prompt_generation"
+    NODE_EXECUTION_TRACE = "node_execution"
     DATASOURCE_TRACE = "datasource"


@@ -15,22 +15,32 @@ from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker

 from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
-from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum
+from core.ops.entities.config_entity import (
+    OPS_FILE_PATH,
+    TracingProviderEnum,
+)
 from core.ops.entities.trace_entity import (
     DatasetRetrievalTraceInfo,
+    DraftNodeExecutionTrace,
     GenerateNameTraceInfo,
     MessageTraceInfo,
     ModerationTraceInfo,
+    PromptGenerationTraceInfo,
     SuggestedQuestionTraceInfo,
     TaskData,
     ToolTraceInfo,
     TraceTaskName,
+    WorkflowNodeTraceInfo,
     WorkflowTraceInfo,
 )
 from core.ops.utils import get_message_data
-from extensions.ext_database import db
 from extensions.ext_storage import storage
+from models.engine import db
+from models.account import Tenant
+from models.dataset import Dataset
 from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
+from models.provider import Provider, ProviderCredential, ProviderModel, ProviderModelCredential, ProviderType
+from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
+from models.workflow import WorkflowAppLog
 from tasks.ops_trace_task import process_trace_tasks
@@ -40,9 +50,142 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]:
"""Return (app_name, workspace_name) for the given IDs. Falls back to empty strings."""
app_name = ""
workspace_name = ""
if not app_id and not tenant_id:
return app_name, workspace_name
with Session(db.engine) as session:
if app_id:
name = session.scalar(select(App.name).where(App.id == app_id))
if name:
app_name = name
if tenant_id:
name = session.scalar(select(Tenant.name).where(Tenant.id == tenant_id))
if name:
workspace_name = name
return app_name, workspace_name
_PROVIDER_TYPE_TO_MODEL: dict[str, type] = {
"builtin": BuiltinToolProvider,
"plugin": BuiltinToolProvider,
"api": ApiToolProvider,
"workflow": WorkflowToolProvider,
"mcp": MCPToolProvider,
}
def _lookup_credential_name(credential_id: str | None, provider_type: str | None) -> str:
if not credential_id:
return ""
model_cls = _PROVIDER_TYPE_TO_MODEL.get(provider_type or "")
if not model_cls:
return ""
with Session(db.engine) as session:
name = session.scalar(select(model_cls.name).where(model_cls.id == credential_id)) # type: ignore[attr-defined]
return str(name) if name else ""
def _lookup_llm_credential_info(
tenant_id: str | None, provider: str | None, model: str | None, model_type: str | None = "llm"
) -> tuple[str | None, str]:
"""
Lookup LLM credential ID and name for the given provider and model.
Returns (credential_id, credential_name).
Handles async timing issues gracefully - if credential is deleted between lookups,
returns the ID but empty name rather than failing.
"""
if not tenant_id or not provider:
return None, ""
try:
with Session(db.engine) as session:
# Try to find provider-level or model-level configuration
provider_record = session.scalar(
select(Provider).where(
Provider.tenant_id == tenant_id,
Provider.provider_name == provider,
Provider.provider_type == ProviderType.CUSTOM,
)
)
if not provider_record:
return None, ""
# Check if there's a model-specific config
credential_id = None
credential_name = ""
is_model_level = False
if model:
# Try model-level first
model_record = session.scalar(
select(ProviderModel).where(
ProviderModel.tenant_id == tenant_id,
ProviderModel.provider_name == provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type,
)
)
if model_record and model_record.credential_id:
credential_id = model_record.credential_id
is_model_level = True
if not credential_id and provider_record.credential_id:
# Fall back to provider-level credential
credential_id = provider_record.credential_id
is_model_level = False
# Lookup credential_name if we have credential_id
if credential_id:
try:
if is_model_level:
# Query ProviderModelCredential
cred_name = session.scalar(
select(ProviderModelCredential.credential_name).where(
ProviderModelCredential.id == credential_id
)
)
else:
# Query ProviderCredential
cred_name = session.scalar(
select(ProviderCredential.credential_name).where(ProviderCredential.id == credential_id)
)
if cred_name:
credential_name = str(cred_name)
except Exception as e:
# Credential might have been deleted between lookups (async timing)
# Return ID but empty name rather than failing
logger.warning(
"Failed to lookup credential name for credential_id=%s (provider=%s, model=%s): %s",
credential_id,
provider,
model,
str(e),
)
return credential_id, credential_name
except Exception as e:
# Database query failed or other unexpected error
# Return empty rather than propagating error to telemetry emission
logger.warning(
"Failed to lookup LLM credential info for tenant_id=%s, provider=%s, model=%s: %s",
tenant_id,
provider,
model,
str(e),
)
return None, ""
 class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
-    def __getitem__(self, key: str) -> dict[str, Any]:
-        match key:
+    def __getitem__(self, provider: str) -> dict[str, Any]:
+        match provider:
             case TracingProviderEnum.LANGFUSE:
                 from core.ops.entities.config_entity import LangfuseConfig
                 from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
@@ -149,7 +292,7 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
                 }
             case _:
-                raise KeyError(f"Unsupported tracing provider: {key}")
+                raise KeyError(f"Unsupported tracing provider: {provider}")


 provider_config_map = OpsTraceProviderConfigMap()
@@ -314,6 +457,10 @@ class OpsTraceManager:
         if app_id is None:
             return None

+        # Handle storage_id format (tenant-{uuid}) - not a real app_id
+        if isinstance(app_id, str) and app_id.startswith("tenant-"):
+            return None
+
         app: App | None = db.session.query(App).where(App.id == app_id).first()
         if app is None:
@@ -466,8 +613,6 @@ class TraceTask:
@classmethod
def _get_workflow_run_repo(cls):
from repositories.factory import DifyAPIRepositoryFactory
if cls._workflow_run_repo is None:
with cls._repo_lock:
if cls._workflow_run_repo is None:
@@ -478,6 +623,56 @@
cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
return cls._workflow_run_repo
@classmethod
def _get_user_id_from_metadata(cls, metadata: dict[str, Any]) -> str:
"""Extract user ID from metadata, prioritizing end_user over account.
Returns the actual user ID (end_user or account) who invoked the workflow,
regardless of invoke_from context.
"""
# Priority 1: End user (external users via API/WebApp)
if user_id := metadata.get("from_end_user_id"):
return f"end_user:{user_id}"
# Priority 2: Account user (internal users via console/debugger)
if user_id := metadata.get("from_account_id"):
return f"account:{user_id}"
# Priority 3: User (internal users via console/debugger)
if user_id := metadata.get("user_id"):
return f"user:{user_id}"
return "anonymous"
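`_get_user_id_from_metadata` above is a straight priority chain over optional metadata keys, using the walrus operator so each key is read once. A standalone restatement (toy function, not the classmethod; the metadata keys mirror the ones in the trace metadata dicts):

```python
def invoked_by(metadata: dict) -> str:
    """Pick the most specific caller identity available, tagged by kind."""
    # end users (API / WebApp) outrank console accounts
    if user_id := metadata.get("from_end_user_id"):
        return f"end_user:{user_id}"
    if user_id := metadata.get("from_account_id"):
        return f"account:{user_id}"
    if user_id := metadata.get("user_id"):
        return f"user:{user_id}"
    return "anonymous"


# end_user wins even when an account id is also present
print(invoked_by({"from_end_user_id": "eu1", "from_account_id": "a1"}))  # -> end_user:eu1
print(invoked_by({"from_account_id": "a1"}))  # -> account:a1
print(invoked_by({}))  # -> anonymous
```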
@classmethod
def _calculate_workflow_token_split(cls, workflow_run_id: str, tenant_id: str) -> tuple[int, int]:
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
from models.workflow import WorkflowNodeExecutionModel
with Session(db.engine) as session:
node_executions = session.scalars(
select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
)
).all()
total_prompt = 0
total_completion = 0
for node_exec in node_executions:
metadata = node_exec.execution_metadata_dict
prompt = metadata.get(WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS)
if prompt is not None:
total_prompt += prompt
completion = metadata.get(WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS)
if completion is not None:
total_completion += completion
return (total_prompt, total_completion)
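Stripped of the ORM query, `_calculate_workflow_token_split` reduces to summing optional per-node token counts, treating absent metadata as zero. A self-contained sketch under that assumption (plain dicts stand in for `execution_metadata_dict`; the string keys are illustrative stand-ins for `WorkflowNodeExecutionMetadataKey` members):

```python
def token_split(node_metadata: list[dict]) -> tuple[int, int]:
    """Sum prompt/completion tokens across node executions, skipping
    nodes that recorded no token metadata (e.g. non-LLM nodes)."""
    total_prompt = 0
    total_completion = 0
    for metadata in node_metadata:
        prompt = metadata.get("prompt_tokens")
        if prompt is not None:
            total_prompt += prompt
        completion = metadata.get("completion_tokens")
        if completion is not None:
            total_completion += completion
    return total_prompt, total_completion


runs = [
    {"prompt_tokens": 120, "completion_tokens": 30},
    {},  # a non-LLM node contributes nothing
    {"prompt_tokens": 80, "completion_tokens": 20},
]
print(token_split(runs))  # -> (200, 50)
```

The explicit `is not None` checks matter: a recorded count of `0` still participates in the sum, while a missing key is simply skipped.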
def __init__(
self,
trace_type: Any,
@@ -498,6 +693,8 @@ class TraceTask:
         self.app_id = None
         self.trace_id = None
         self.kwargs = kwargs
+        if user_id is not None and "user_id" not in self.kwargs:
+            self.kwargs["user_id"] = user_id
         external_trace_id = kwargs.get("external_trace_id")
         if external_trace_id:
             self.trace_id = external_trace_id
@@ -511,7 +708,7 @@ class TraceTask:
             TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
                 workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
             ),
-            TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
+            TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id, **self.kwargs),
             TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
                 message_id=self.message_id, timer=self.timer, **self.kwargs
             ),
@@ -527,6 +724,9 @@ class TraceTask:
             TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
                 conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
             ),
+            TraceTaskName.PROMPT_GENERATION_TRACE: lambda: self.prompt_generation_trace(**self.kwargs),
+            TraceTaskName.NODE_EXECUTION_TRACE: lambda: self.node_execution_trace(**self.kwargs),
+            TraceTaskName.DRAFT_NODE_EXECUTION_TRACE: lambda: self.draft_node_execution_trace(**self.kwargs),
         }

         return preprocess_map.get(self.trace_type, lambda: None)()
@@ -562,6 +762,10 @@ class TraceTask:
         total_tokens = workflow_run.total_tokens
+        prompt_tokens, completion_tokens = self._calculate_workflow_token_split(
+            workflow_run_id=workflow_run_id, tenant_id=tenant_id
+        )
+
         file_list = workflow_run_inputs.get("sys.file") or []
         query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
@@ -582,7 +786,14 @@ class TraceTask:
             )
             message_id = session.scalar(message_data_stmt)

-        metadata = {
+        from core.telemetry.gateway import is_enterprise_telemetry_enabled
+
+        if is_enterprise_telemetry_enabled():
+            app_name, workspace_name = _lookup_app_and_workspace_names(workflow_run.app_id, tenant_id)
+        else:
+            app_name, workspace_name = "", ""
+
+        metadata: dict[str, Any] = {
             "workflow_id": workflow_id,
             "conversation_id": conversation_id,
             "workflow_run_id": workflow_run_id,
@@ -595,8 +806,14 @@ class TraceTask:
             "triggered_from": workflow_run.triggered_from,
             "user_id": user_id,
             "app_id": workflow_run.app_id,
+            "app_name": app_name,
+            "workspace_name": workspace_name,
         }

+        parent_trace_context = self.kwargs.get("parent_trace_context")
+        if parent_trace_context:
+            metadata["parent_trace_context"] = parent_trace_context
+
         workflow_trace_info = WorkflowTraceInfo(
             trace_id=self.trace_id,
             workflow_data=workflow_run.to_dict(),
@@ -611,6 +828,8 @@ class TraceTask:
             workflow_run_version=workflow_run_version,
             error=error,
             total_tokens=total_tokens,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
             file_list=file_list,
             query=query,
             metadata=metadata,
@@ -618,10 +837,11 @@ class TraceTask:
             message_id=message_id,
             start_time=workflow_run.created_at,
             end_time=workflow_run.finished_at,
+            invoked_by=self._get_user_id_from_metadata(metadata),
         )

         return workflow_trace_info

-    def message_trace(self, message_id: str | None):
+    def message_trace(self, message_id: str | None, **kwargs):
         if not message_id:
             return {}

         message_data = get_message_data(message_id)
@@ -644,6 +864,19 @@ class TraceTask:
         streaming_metrics = self._extract_streaming_metrics(message_data)

+        tenant_id = ""
+        with Session(db.engine) as session:
+            tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id))
+            if tid:
+                tenant_id = str(tid)
+
+        from core.telemetry.gateway import is_enterprise_telemetry_enabled
+
+        if is_enterprise_telemetry_enabled():
+            app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
+        else:
+            app_name, workspace_name = "", ""
+
         metadata = {
             "conversation_id": message_data.conversation_id,
             "ls_provider": message_data.model_provider,
@@ -655,7 +888,14 @@ class TraceTask:
             "workflow_run_id": message_data.workflow_run_id,
             "from_source": message_data.from_source,
             "message_id": message_id,
+            "tenant_id": tenant_id,
+            "app_id": message_data.app_id,
+            "user_id": message_data.from_end_user_id or message_data.from_account_id,
+            "app_name": app_name,
+            "workspace_name": workspace_name,
         }
+        if node_execution_id := kwargs.get("node_execution_id"):
+            metadata["node_execution_id"] = node_execution_id

         message_tokens = message_data.message_tokens
@@ -672,7 +912,9 @@ class TraceTask:
             outputs=message_data.answer,
             file_list=file_list,
             start_time=created_at,
-            end_time=created_at + timedelta(seconds=message_data.provider_response_latency),
+            end_time=message_data.updated_at
+            if message_data.updated_at and message_data.updated_at > created_at
+            else created_at + timedelta(seconds=message_data.provider_response_latency),
             metadata=metadata,
             message_file_data=message_file_data,
             conversation_mode=conversation_mode,
@@ -697,6 +939,8 @@ class TraceTask:
             "preset_response": moderation_result.preset_response,
             "query": moderation_result.query,
         }
+        if node_execution_id := kwargs.get("node_execution_id"):
+            metadata["node_execution_id"] = node_execution_id

         # get workflow_app_log_id
         workflow_app_log_id = None
@@ -738,6 +982,8 @@ class TraceTask:
             "workflow_run_id": message_data.workflow_run_id,
             "from_source": message_data.from_source,
         }
+        if node_execution_id := kwargs.get("node_execution_id"):
+            metadata["node_execution_id"] = node_execution_id

         # get workflow_app_log_id
         workflow_app_log_id = None
@@ -777,6 +1023,52 @@
if not message_data:
return {}
tenant_id = ""
with Session(db.engine) as session:
tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id))
if tid:
tenant_id = str(tid)
from core.telemetry.gateway import is_enterprise_telemetry_enabled
if is_enterprise_telemetry_enabled():
app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
else:
app_name, workspace_name = "", ""
doc_list = [doc.model_dump() for doc in documents] if documents else []
dataset_ids: set[str] = set()
for doc in doc_list:
doc_meta = doc.get("metadata") or {}
did = doc_meta.get("dataset_id")
if did:
dataset_ids.add(did)
embedding_models: dict[str, dict[str, str]] = {}
if dataset_ids:
with Session(db.engine) as session:
rows = session.execute(
select(Dataset.id, Dataset.embedding_model, Dataset.embedding_model_provider).where(
Dataset.id.in_(list(dataset_ids))
)
).all()
for row in rows:
embedding_models[str(row[0])] = {
"embedding_model": row[1] or "",
"embedding_model_provider": row[2] or "",
}
# Extract rerank model info from retrieval_model kwargs
rerank_model_provider = ""
rerank_model_name = ""
if "retrieval_model" in kwargs:
retrieval_model = kwargs["retrieval_model"]
if isinstance(retrieval_model, dict):
reranking_model = retrieval_model.get("reranking_model")
if isinstance(reranking_model, dict):
rerank_model_provider = reranking_model.get("reranking_provider_name", "")
rerank_model_name = reranking_model.get("reranking_model_name", "")
metadata = {
"message_id": message_id,
"ls_provider": message_data.model_provider,
@@ -787,13 +1079,23 @@ class TraceTask:
             "agent_based": message_data.agent_based,
             "workflow_run_id": message_data.workflow_run_id,
             "from_source": message_data.from_source,
+            "tenant_id": tenant_id,
+            "app_id": message_data.app_id,
+            "user_id": message_data.from_end_user_id or message_data.from_account_id,
+            "app_name": app_name,
+            "workspace_name": workspace_name,
+            "embedding_models": embedding_models,
+            "rerank_model_provider": rerank_model_provider,
+            "rerank_model_name": rerank_model_name,
         }
+        if node_execution_id := kwargs.get("node_execution_id"):
+            metadata["node_execution_id"] = node_execution_id

         dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
             trace_id=self.trace_id,
             message_id=message_id,
             inputs=message_data.query or message_data.inputs,
-            documents=[doc.model_dump() for doc in documents] if documents else [],
+            documents=doc_list,
             start_time=timer.get("start"),
             end_time=timer.get("end"),
             metadata=metadata,
@@ -836,6 +1138,10 @@
"error": error,
"tool_parameters": tool_parameters,
}
if message_data.workflow_run_id:
metadata["workflow_run_id"] = message_data.workflow_run_id
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
file_url = ""
message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
@@ -890,6 +1196,8 @@
"conversation_id": conversation_id,
"tenant_id": tenant_id,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
generate_name_trace_info = GenerateNameTraceInfo(
trace_id=self.trace_id,
@@ -904,6 +1212,182 @@
return generate_name_trace_info
def prompt_generation_trace(self, **kwargs) -> PromptGenerationTraceInfo | dict:
tenant_id = kwargs.get("tenant_id", "")
user_id = kwargs.get("user_id", "")
app_id = kwargs.get("app_id")
operation_type = kwargs.get("operation_type", "")
instruction = kwargs.get("instruction", "")
generated_output = kwargs.get("generated_output", "")
prompt_tokens = kwargs.get("prompt_tokens", 0)
completion_tokens = kwargs.get("completion_tokens", 0)
total_tokens = kwargs.get("total_tokens", 0)
model_provider = kwargs.get("model_provider", "")
model_name = kwargs.get("model_name", "")
latency = kwargs.get("latency", 0.0)
timer = kwargs.get("timer")
start_time = timer.get("start") if timer else None
end_time = timer.get("end") if timer else None
total_price = kwargs.get("total_price")
currency = kwargs.get("currency")
error = kwargs.get("error")
app_name = None
workspace_name = None
if app_id:
app_name, workspace_name = _lookup_app_and_workspace_names(app_id, tenant_id)
metadata = {
"tenant_id": tenant_id,
"user_id": user_id,
"app_id": app_id or "",
"app_name": app_name,
"workspace_name": workspace_name,
"operation_type": operation_type,
"model_provider": model_provider,
"model_name": model_name,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
return PromptGenerationTraceInfo(
trace_id=self.trace_id,
inputs=instruction,
outputs=generated_output,
start_time=start_time,
end_time=end_time,
metadata=metadata,
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
operation_type=operation_type,
instruction=instruction,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
model_provider=model_provider,
model_name=model_name,
latency=latency,
total_price=total_price,
currency=currency,
error=error,
)
def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict:
node_data: dict = kwargs.get("node_execution_data", {})
if not node_data:
return {}
from core.telemetry.gateway import is_enterprise_telemetry_enabled
if is_enterprise_telemetry_enabled():
app_name, workspace_name = _lookup_app_and_workspace_names(
node_data.get("app_id"), node_data.get("tenant_id")
)
else:
app_name, workspace_name = "", ""
# Try tool credential lookup first
credential_id = node_data.get("credential_id")
if is_enterprise_telemetry_enabled():
credential_name = _lookup_credential_name(credential_id, node_data.get("credential_provider_type"))
# If no credential_id found (e.g., LLM nodes), try LLM credential lookup
if not credential_id:
llm_cred_id, llm_cred_name = _lookup_llm_credential_info(
tenant_id=node_data.get("tenant_id"),
provider=node_data.get("model_provider"),
model=node_data.get("model_name"),
model_type="llm",
)
if llm_cred_id:
credential_id = llm_cred_id
credential_name = llm_cred_name
else:
credential_name = ""
metadata: dict[str, Any] = {
"tenant_id": node_data.get("tenant_id"),
"app_id": node_data.get("app_id"),
"app_name": app_name,
"workspace_name": workspace_name,
"user_id": node_data.get("user_id"),
"invoke_from": node_data.get("invoke_from"),
"credential_id": credential_id,
"credential_name": credential_name,
"dataset_ids": node_data.get("dataset_ids"),
"dataset_names": node_data.get("dataset_names"),
"plugin_name": node_data.get("plugin_name"),
}
parent_trace_context = node_data.get("parent_trace_context")
if parent_trace_context:
metadata["parent_trace_context"] = parent_trace_context
message_id: str | None = None
conversation_id = node_data.get("conversation_id")
workflow_execution_id = node_data.get("workflow_execution_id")
if conversation_id and workflow_execution_id and not parent_trace_context:
with Session(db.engine) as session:
msg_id = session.scalar(
select(Message.id).where(
Message.conversation_id == conversation_id,
Message.workflow_run_id == workflow_execution_id,
)
)
if msg_id:
message_id = str(msg_id)
metadata["message_id"] = message_id
if conversation_id:
metadata["conversation_id"] = conversation_id
return WorkflowNodeTraceInfo(
trace_id=self.trace_id,
message_id=message_id,
start_time=node_data.get("created_at"),
end_time=node_data.get("finished_at"),
metadata=metadata,
workflow_id=node_data.get("workflow_id", ""),
workflow_run_id=node_data.get("workflow_execution_id", ""),
tenant_id=node_data.get("tenant_id", ""),
node_execution_id=node_data.get("node_execution_id", ""),
node_id=node_data.get("node_id", ""),
node_type=node_data.get("node_type", ""),
title=node_data.get("title", ""),
status=node_data.get("status", ""),
error=node_data.get("error"),
elapsed_time=node_data.get("elapsed_time", 0.0),
index=node_data.get("index", 0),
predecessor_node_id=node_data.get("predecessor_node_id"),
total_tokens=node_data.get("total_tokens", 0),
total_price=node_data.get("total_price", 0.0),
currency=node_data.get("currency"),
model_provider=node_data.get("model_provider"),
model_name=node_data.get("model_name"),
prompt_tokens=node_data.get("prompt_tokens"),
completion_tokens=node_data.get("completion_tokens"),
tool_name=node_data.get("tool_name"),
iteration_id=node_data.get("iteration_id"),
iteration_index=node_data.get("iteration_index"),
loop_id=node_data.get("loop_id"),
loop_index=node_data.get("loop_index"),
parallel_id=node_data.get("parallel_id"),
node_inputs=node_data.get("node_inputs"),
node_outputs=node_data.get("node_outputs"),
process_data=node_data.get("process_data"),
invoked_by=self._get_user_id_from_metadata(metadata),
)
def draft_node_execution_trace(self, **kwargs) -> DraftNodeExecutionTrace | dict:
node_trace = self.node_execution_trace(**kwargs)
if not isinstance(node_trace, WorkflowNodeTraceInfo):
return node_trace
return DraftNodeExecutionTrace(**node_trace.model_dump())
def _extract_streaming_metrics(self, message_data) -> dict:
if not message_data.message_metadata:
return {}
@@ -937,13 +1421,17 @@ class TraceQueueManager:
self.user_id = user_id
self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
self.flask_app = current_app._get_current_object() # type: ignore
from core.telemetry.gateway import is_enterprise_telemetry_enabled
self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled()
if trace_manager_timer is None:
self.start_timer()
def add_trace_task(self, trace_task: TraceTask):
global trace_manager_timer, trace_manager_queue
try:
if self.trace_instance:
if self._enterprise_telemetry_enabled or self.trace_instance:
trace_task.app_id = self.app_id
trace_manager_queue.put(trace_task)
except Exception:
@@ -979,20 +1467,27 @@ class TraceQueueManager:
def send_to_celery(self, tasks: list[TraceTask]):
with self.flask_app.app_context():
for task in tasks:
if task.app_id is None:
continue
storage_id = task.app_id
if storage_id is None:
tenant_id = task.kwargs.get("tenant_id")
if tenant_id:
storage_id = f"tenant-{tenant_id}"
else:
logger.warning("Skipping trace without app_id or tenant_id, trace_type: %s", task.trace_type)
continue
file_id = uuid4().hex
trace_info = task.execute()
task_data = TaskData(
app_id=task.app_id,
app_id=storage_id,
trace_info_type=type(trace_info).__name__,
trace_info=trace_info.model_dump() if trace_info else None,
)
file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
file_path = f"{OPS_FILE_PATH}{storage_id}/{file_id}.json"
storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
file_info = {
"file_id": file_id,
"app_id": task.app_id,
"app_id": storage_id,
}
process_trace_tasks.delay(file_info) # type: ignore

View File

@@ -918,11 +918,11 @@ class ProviderManager:
trail_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type=ProviderQuotaType.TRIAL.value,
pool_type=ProviderQuotaType.TRIAL,
)
paid_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type=ProviderQuotaType.PAID.value,
pool_type=ProviderQuotaType.PAID,
)
else:
trail_pool = None

View File

@@ -13,6 +13,7 @@ from pymochow.exception import ServerError # type: ignore
from pymochow.model.database import Database
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore
from pymochow.model.schema import (
AutoBuildRowCountIncrement,
Field,
FilteringIndex,
HNSWParams,
@@ -51,6 +52,9 @@ class BaiduConfig(BaseModel):
replicas: int = 3
inverted_index_analyzer: str = "DEFAULT_ANALYZER"
inverted_index_parser_mode: str = "COARSE_MODE"
auto_build_row_count_increment: int = 500
auto_build_row_count_increment_ratio: float = 0.05
rebuild_index_timeout_in_seconds: int = 300
@model_validator(mode="before")
@classmethod
@@ -107,18 +111,6 @@ class BaiduVector(BaseVector):
rows.append(row)
table.upsert(rows=rows)
# rebuild vector index after upsert finished
table.rebuild_index(self.vector_index)
timeout = 3600 # 1 hour timeout
start_time = time.time()
while True:
time.sleep(1)
index = table.describe_index(self.vector_index)
if index.state == IndexState.NORMAL:
break
if time.time() - start_time > timeout:
raise TimeoutError(f"Index rebuild timeout after {timeout} seconds")
def text_exists(self, id: str) -> bool:
res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id})
if res and res.code == 0:
@@ -232,8 +224,14 @@ class BaiduVector(BaseVector):
return self._client.database(self._client_config.database)
def _table_existed(self) -> bool:
tables = self._db.list_table()
return any(table.table_name == self._collection_name for table in tables)
try:
table = self._db.table(self._collection_name)
except ServerError as e:
if e.code == ServerErrCode.TABLE_NOT_EXIST:
return False
else:
raise
return True
def _create_table(self, dimension: int):
# Try to grab distributed lock and create table
@@ -287,6 +285,11 @@ class BaiduVector(BaseVector):
field=VDBField.VECTOR,
metric_type=metric_type,
params=HNSWParams(m=16, efconstruction=200),
auto_build=True,
auto_build_index_policy=AutoBuildRowCountIncrement(
row_count_increment=self._client_config.auto_build_row_count_increment,
row_count_increment_ratio=self._client_config.auto_build_row_count_increment_ratio,
),
)
)
@@ -335,7 +338,7 @@ class BaiduVector(BaseVector):
)
# Wait for table created
timeout = 300 # 5 minutes timeout
timeout = self._client_config.rebuild_index_timeout_in_seconds # default 5 minutes timeout
start_time = time.time()
while True:
time.sleep(1)
@@ -345,6 +348,20 @@ class BaiduVector(BaseVector):
if time.time() - start_time > timeout:
raise TimeoutError(f"Table creation timeout after {timeout} seconds")
redis_client.set(table_exist_cache_key, 1, ex=3600)
# rebuild vector index immediately after the table is created, to make sure the index is ready
table.rebuild_index(self.vector_index)
timeout = 3600 # 1 hour timeout
self._wait_for_index_ready(table, timeout)
def _wait_for_index_ready(self, table, timeout: int = 3600):
start_time = time.time()
while True:
time.sleep(1)
index = table.describe_index(self.vector_index)
if index.state == IndexState.NORMAL:
break
if time.time() - start_time > timeout:
raise TimeoutError(f"Index rebuild timeout after {timeout} seconds")
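The poll-until-ready pattern used by `_wait_for_index_ready` can be sketched standalone. This sketch uses a hypothetical `wait_until` helper and a `check` callback; unlike the real code it returns `False` on timeout rather than raising `TimeoutError`, and the one-second sleep between probes is a parameter:

```python
import time


def wait_until(check, timeout: float, interval: float = 0.0) -> bool:
    """Poll check() until it returns True or `timeout` seconds elapse."""
    start = time.time()
    while True:
        if check():
            return True
        if time.time() - start > timeout:
            return False
        time.sleep(interval)


# Simulated readiness check: reports ready on the third probe.
state = {"calls": 0}


def ready() -> bool:
    state["calls"] += 1
    return state["calls"] >= 3


ok = wait_until(ready, timeout=5)
```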
class BaiduVectorFactory(AbstractVectorFactory):
@@ -369,5 +386,8 @@ class BaiduVectorFactory(AbstractVectorFactory):
replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
inverted_index_analyzer=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER,
inverted_index_parser_mode=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE,
auto_build_row_count_increment=dify_config.BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT,
auto_build_row_count_increment_ratio=dify_config.BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO,
rebuild_index_timeout_in_seconds=dify_config.BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS,
),
)

View File

@@ -33,6 +33,7 @@ from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, TidbAuthBinding
from models.enums import TidbAuthBindingStatus
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
@@ -452,7 +453,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
password=new_cluster["password"],
tenant_id=dataset.tenant_id,
active=True,
status="ACTIVE",
status=TidbAuthBindingStatus.ACTIVE,
)
db.session.add(new_tidb_auth_binding)
db.session.commit()

View File

@@ -9,6 +9,7 @@ from configs import dify_config
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import TidbAuthBinding
from models.enums import TidbAuthBindingStatus
class TidbService:
@@ -170,7 +171,7 @@ class TidbService:
userPrefix = item["userPrefix"]
if state == "ACTIVE" and len(userPrefix) > 0:
cluster_info = tidb_serverless_list_map[item["clusterId"]]
cluster_info.status = "ACTIVE"
cluster_info.status = TidbAuthBindingStatus.ACTIVE
cluster_info.account = f"{userPrefix}.root"
db.session.add(cluster_info)
db.session.commit()

View File

@@ -95,15 +95,11 @@ class FirecrawlApp:
if response.status_code == 200:
crawl_status_response = response.json()
if crawl_status_response.get("status") == "completed":
total = crawl_status_response.get("total", 0)
if total == 0:
# Normalize to avoid None bypassing the zero-guard when the API returns null.
total = crawl_status_response.get("total") or 0
if total <= 0:
raise Exception("Failed to check crawl status. Error: No page found")
data = crawl_status_response.get("data", [])
url_data_list: list[FirecrawlDocumentData] = []
for item in data:
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
url_data = self._extract_common_fields(item)
url_data_list.append(url_data)
url_data_list = self._collect_all_crawl_pages(crawl_status_response, headers)
if url_data_list:
file_key = "website_files/" + job_id + ".txt"
try:
@@ -120,6 +116,36 @@
self._handle_error(response, "check crawl status")
raise RuntimeError("unreachable: _handle_error always raises")
def _collect_all_crawl_pages(
self, first_page: dict[str, Any], headers: dict[str, str]
) -> list[FirecrawlDocumentData]:
"""Collect all crawl result pages by following pagination links.
Raises an exception if any paginated request fails, to avoid returning
partial data that is inconsistent with the reported total.
The number of pages processed is capped at ``total`` (the
server-reported page count) to guard against infinite loops caused by
a misbehaving server that keeps returning a ``next`` URL.
"""
total: int = first_page.get("total") or 0
url_data_list: list[FirecrawlDocumentData] = []
current_page = first_page
pages_processed = 0
while True:
for item in current_page.get("data", []):
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
url_data_list.append(self._extract_common_fields(item))
next_url: str | None = current_page.get("next")
pages_processed += 1
if not next_url or pages_processed >= total:
break
response = self._get_request(next_url, headers)
if response.status_code != 200:
self._handle_error(response, "fetch next crawl page")
current_page = response.json()
return url_data_list
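The pagination loop above can be sketched in isolation. This is a minimal version with hypothetical `collect_pages` and `fetch` names (error handling for non-200 responses omitted; `fetch` stands in for `_get_request` plus `.json()`):

```python
def collect_pages(first_page: dict, fetch) -> list[dict]:
    """Follow `next` links, capping pages processed at the reported total."""
    total = first_page.get("total") or 0
    items: list[dict] = []
    page = first_page
    pages_processed = 0
    while True:
        for item in page.get("data", []):
            # Keep only well-formed crawl results, as in the real loop.
            if isinstance(item, dict) and "metadata" in item and "markdown" in item:
                items.append(item)
        next_url = page.get("next")
        pages_processed += 1
        if not next_url or pages_processed >= total:
            break
        page = fetch(next_url)
    return items


# Two simulated result pages linked by a `next` key.
pages = {"p2": {"total": 2, "data": [{"metadata": {}, "markdown": "b"}], "next": None}}
first = {"total": 2, "data": [{"metadata": {}, "markdown": "a"}], "next": "p2"}
result = collect_pages(first, lambda url: pages[url])
```

The `pages_processed >= total` cap is what stops a server that keeps returning a `next` URL from looping forever.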
def _format_crawl_status_response(
self,
status: str,

View File

@@ -0,0 +1,43 @@
"""Telemetry facade.
Thin public API for emitting telemetry events. All routing logic
lives in ``core.telemetry.gateway`` which is shared by both CE and EE.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from core.ops.entities.trace_entity import TraceTaskName
from core.telemetry.events import TelemetryContext, TelemetryEvent
from core.telemetry.gateway import emit as gateway_emit
from core.telemetry.gateway import get_trace_task_to_case
if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
def emit(event: TelemetryEvent, trace_manager: TraceQueueManager | None = None) -> None:
"""Emit a telemetry event.
Translates the ``TelemetryEvent`` (keyed by ``TraceTaskName``) into a
``TelemetryCase`` and delegates to ``core.telemetry.gateway.emit()``.
"""
case = get_trace_task_to_case().get(event.name)
if case is None:
return
context: dict[str, object] = {
"tenant_id": event.context.tenant_id,
"user_id": event.context.user_id,
"app_id": event.context.app_id,
}
gateway_emit(case, context, event.payload, trace_manager)
__all__ = [
"TelemetryContext",
"TelemetryEvent",
"TraceTaskName",
"emit",
]

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from core.ops.entities.trace_entity import TraceTaskName
@dataclass(frozen=True)
class TelemetryContext:
tenant_id: str | None = None
user_id: str | None = None
app_id: str | None = None
@dataclass(frozen=True)
class TelemetryEvent:
name: TraceTaskName
context: TelemetryContext
payload: dict[str, Any]

View File

@@ -0,0 +1,239 @@
"""Telemetry gateway — single routing layer for all editions.
Maps ``TelemetryCase`` → ``CaseRoute`` and dispatches events to either
the CE/EE trace pipeline (``TraceQueueManager``) or the enterprise-only
metric/log Celery queue.
This module lives in ``core/`` so both CE and EE share one routing table
and one ``emit()`` entry point. No separate enterprise gateway module is
needed; enterprise-specific dispatch (Celery task, payload offloading)
is handled here behind lazy imports that no-op in CE.
"""
from __future__ import annotations
import json
import logging
import uuid
from typing import TYPE_CHECKING, Any
from core.ops.entities.trace_entity import TraceTaskName
from enterprise.telemetry.contracts import SignalType
from extensions.ext_storage import storage
if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
from enterprise.telemetry.contracts import TelemetryCase
logger = logging.getLogger(__name__)
PAYLOAD_SIZE_THRESHOLD_BYTES = 1 * 1024 * 1024
# ---------------------------------------------------------------------------
# Routing table — authoritative mapping for all editions
# ---------------------------------------------------------------------------
_case_to_trace_task: dict | None = None
_case_routing: dict | None = None
def _get_case_to_trace_task() -> dict:
global _case_to_trace_task
if _case_to_trace_task is None:
from enterprise.telemetry.contracts import TelemetryCase
_case_to_trace_task = {
TelemetryCase.WORKFLOW_RUN: TraceTaskName.WORKFLOW_TRACE,
TelemetryCase.MESSAGE_RUN: TraceTaskName.MESSAGE_TRACE,
TelemetryCase.NODE_EXECUTION: TraceTaskName.NODE_EXECUTION_TRACE,
TelemetryCase.DRAFT_NODE_EXECUTION: TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
TelemetryCase.PROMPT_GENERATION: TraceTaskName.PROMPT_GENERATION_TRACE,
TelemetryCase.TOOL_EXECUTION: TraceTaskName.TOOL_TRACE,
TelemetryCase.MODERATION_CHECK: TraceTaskName.MODERATION_TRACE,
TelemetryCase.SUGGESTED_QUESTION: TraceTaskName.SUGGESTED_QUESTION_TRACE,
TelemetryCase.DATASET_RETRIEVAL: TraceTaskName.DATASET_RETRIEVAL_TRACE,
TelemetryCase.GENERATE_NAME: TraceTaskName.GENERATE_NAME_TRACE,
}
return _case_to_trace_task
def get_trace_task_to_case() -> dict:
"""Return TraceTaskName → TelemetryCase (inverse of _get_case_to_trace_task)."""
return {v: k for k, v in _get_case_to_trace_task().items()}
def _get_case_routing() -> dict:
global _case_routing
if _case_routing is None:
from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase
_case_routing = {
# TRACE — CE-eligible (flow in both CE and EE)
TelemetryCase.WORKFLOW_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
TelemetryCase.MESSAGE_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
TelemetryCase.TOOL_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
TelemetryCase.MODERATION_CHECK: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
TelemetryCase.SUGGESTED_QUESTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
TelemetryCase.DATASET_RETRIEVAL: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
TelemetryCase.GENERATE_NAME: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
# TRACE — enterprise-only
TelemetryCase.NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
TelemetryCase.DRAFT_NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
TelemetryCase.PROMPT_GENERATION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
# METRIC_LOG — enterprise-only (signal-driven, not trace)
TelemetryCase.APP_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
TelemetryCase.APP_UPDATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
TelemetryCase.APP_DELETED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
TelemetryCase.FEEDBACK_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
}
return _case_routing
def __getattr__(name: str) -> dict:
"""Lazy module-level access to routing tables."""
if name == "CASE_ROUTING":
return _get_case_routing()
if name == "CASE_TO_TRACE_TASK":
return _get_case_to_trace_task()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
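The routing decision that `emit()` makes over this table reduces to a small lookup plus an edition gate. A standalone sketch, with hypothetical `CaseRoute`/`SignalType` stand-ins mirroring `enterprise.telemetry.contracts` and string keys in place of `TelemetryCase`:

```python
from dataclasses import dataclass
from enum import Enum


class SignalType(Enum):
    TRACE = "trace"
    METRIC_LOG = "metric_log"


@dataclass(frozen=True)
class CaseRoute:
    signal_type: SignalType
    ce_eligible: bool


ROUTING = {
    "workflow_run": CaseRoute(SignalType.TRACE, ce_eligible=True),
    "node_execution": CaseRoute(SignalType.TRACE, ce_eligible=False),
    "app_created": CaseRoute(SignalType.METRIC_LOG, ce_eligible=False),
}


def route(case: str, ee_enabled: bool) -> str:
    """Return which pipeline handles the event, or 'drop'."""
    r = ROUTING.get(case)
    if r is None:
        return "drop"  # unknown case
    if not r.ce_eligible and not ee_enabled:
        return "drop"  # EE-only event on a CE deployment
    return "trace_queue" if r.signal_type is SignalType.TRACE else "metric_log_queue"
```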
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def is_enterprise_telemetry_enabled() -> bool:
try:
from enterprise.telemetry.exporter import is_enterprise_telemetry_enabled
return is_enterprise_telemetry_enabled()
except Exception:
return False
def _handle_payload_sizing(
payload: dict[str, Any],
tenant_id: str,
event_id: str,
) -> tuple[dict[str, Any], str | None]:
"""Inline or offload payload based on size.
Returns ``(payload_for_envelope, storage_key | None)``. Payloads
exceeding ``PAYLOAD_SIZE_THRESHOLD_BYTES`` are written to object
storage and replaced with an empty dict in the envelope.
"""
try:
payload_json = json.dumps(payload)
payload_size = len(payload_json.encode("utf-8"))
except (TypeError, ValueError):
logger.warning("Failed to serialize payload for sizing: event_id=%s", event_id)
return payload, None
if payload_size <= PAYLOAD_SIZE_THRESHOLD_BYTES:
return payload, None
storage_key = f"telemetry/{tenant_id}/{event_id}.json"
try:
storage.save(storage_key, payload_json.encode("utf-8"))
logger.debug("Stored large payload to storage: key=%s, size=%d", storage_key, payload_size)
return {}, storage_key
except Exception:
logger.warning("Failed to store large payload, inlining instead: event_id=%s", event_id, exc_info=True)
return payload, None
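The inline-or-offload decision can be exercised standalone with an in-memory dict in place of `extensions.ext_storage` (hypothetical `size_payload`/`store` names; serialization-failure handling omitted):

```python
import json

THRESHOLD = 1 * 1024 * 1024  # mirrors PAYLOAD_SIZE_THRESHOLD_BYTES
store: dict[str, bytes] = {}  # stands in for object storage


def size_payload(payload: dict, tenant_id: str, event_id: str):
    data = json.dumps(payload).encode("utf-8")
    if len(data) <= THRESHOLD:
        return payload, None  # small enough: inline in the envelope
    key = f"telemetry/{tenant_id}/{event_id}.json"
    store[key] = data  # large: offload, ship only the reference
    return {}, key


small, ref_small = size_payload({"k": "v"}, "t1", "e1")
big, ref_big = size_payload({"blob": "x" * (2 * 1024 * 1024)}, "t1", "e2")
```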
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def emit(
case: TelemetryCase,
context: dict[str, Any],
payload: dict[str, Any],
trace_manager: TraceQueueManager | None = None,
) -> None:
"""Route a telemetry event to the correct pipeline.
TRACE events are enqueued into ``TraceQueueManager`` (works in both CE
and EE). Enterprise-only traces are silently dropped when EE is
disabled.
METRIC_LOG events are dispatched to the enterprise Celery queue;
silently dropped when enterprise telemetry is unavailable.
"""
route = _get_case_routing().get(case)
if route is None:
logger.warning("Unknown telemetry case: %s, dropping event", case)
return
if not route.ce_eligible and not is_enterprise_telemetry_enabled():
logger.debug("Dropping EE-only event: case=%s (EE disabled)", case)
return
if route.signal_type == SignalType.TRACE:
_emit_trace(case, context, payload, trace_manager)
else:
_emit_metric_log(case, context, payload)
def _emit_trace(
case: TelemetryCase,
context: dict[str, Any],
payload: dict[str, Any],
trace_manager: TraceQueueManager | None,
) -> None:
from core.ops.ops_trace_manager import TraceQueueManager as LocalTraceQueueManager
from core.ops.ops_trace_manager import TraceTask
trace_task_name = _get_case_to_trace_task().get(case)
if trace_task_name is None:
logger.warning("No TraceTaskName mapping for case: %s", case)
return
queue_manager = trace_manager or LocalTraceQueueManager(
app_id=context.get("app_id"),
user_id=context.get("user_id"),
)
queue_manager.add_trace_task(TraceTask(trace_task_name, user_id=context.get("user_id"), **payload))
logger.debug("Enqueued trace task: case=%s, app_id=%s", case, context.get("app_id"))
def _emit_metric_log(
case: TelemetryCase,
context: dict[str, Any],
payload: dict[str, Any],
) -> None:
"""Build envelope and dispatch to enterprise Celery queue.
No-ops when the enterprise telemetry task is not importable (CE mode).
"""
try:
from tasks.enterprise_telemetry_task import process_enterprise_telemetry
except ImportError:
logger.debug("Enterprise metric/log dispatch unavailable, dropping: case=%s", case)
return
tenant_id = context.get("tenant_id") or ""
event_id = str(uuid.uuid4())
payload_for_envelope, payload_ref = _handle_payload_sizing(payload, tenant_id, event_id)
from enterprise.telemetry.contracts import TelemetryEnvelope
envelope = TelemetryEnvelope(
case=case,
tenant_id=tenant_id,
event_id=event_id,
payload=payload_for_envelope,
metadata={"payload_ref": payload_ref} if payload_ref else None,
)
process_enterprise_telemetry.delay(envelope.model_dump_json())
logger.debug(
"Enqueued metric/log event: case=%s, tenant_id=%s, event_id=%s",
case,
tenant_id,
event_id,
)

View File

@@ -50,7 +50,7 @@ class BuiltinTool(Tool):
return ModelInvocationUtils.invoke(
user_id=user_id,
tenant_id=self.runtime.tenant_id or "",
tool_type="builtin",
tool_type=ToolProviderType.BUILT_IN,
tool_name=self.entity.identity.name,
prompt_messages=prompt_messages,
)

View File

@@ -38,7 +38,7 @@ class ToolLabelManager:
db.session.add(
ToolLabelBinding(
tool_id=provider_id,
tool_type=controller.provider_type.value,
tool_type=controller.provider_type,
label_name=label,
)
)
@@ -58,7 +58,7 @@ class ToolLabelManager:
raise ValueError("Unsupported tool type")
stmt = select(ToolLabelBinding.label_name).where(
ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value,
ToolLabelBinding.tool_type == controller.provider_type,
)
labels = db.session.scalars(stmt).all()

View File

@@ -9,6 +9,7 @@ from decimal import Decimal
from typing import cast
from core.model_manager import ModelManager
from core.tools.entities.tool_entities import ToolProviderType
from dify_graph.model_runtime.entities.llm_entities import LLMResult
from dify_graph.model_runtime.entities.message_entities import PromptMessage
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
@@ -78,7 +79,7 @@ class ModelInvocationUtils:
@staticmethod
def invoke(
user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
user_id: str, tenant_id: str, tool_type: ToolProviderType, tool_name: str, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
invoke model with parameters in user's own context

View File

@@ -0,0 +1,525 @@
# Dify Enterprise Telemetry Data Dictionary
Quick reference for all telemetry signals emitted by Dify Enterprise. For configuration and architecture details, see [README.md](./README.md).
## Resource Attributes
Attached to every signal (Span, Metric, Log).
| Attribute | Type | Example |
|-----------|------|---------|
| `service.name` | string | `dify` |
| `host.name` | string | `dify-api-7f8b` |
## Traces (Spans)
### `dify.workflow.run`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.trace_id` | string | Business trace ID (Workflow Run ID) |
| `dify.tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.workflow.id` | string | Workflow definition ID |
| `dify.workflow.run_id` | string | Unique ID for this run |
| `dify.workflow.status` | string | `succeeded`, `failed`, `stopped`, etc. |
| `dify.workflow.error` | string | Error message if failed |
| `dify.workflow.elapsed_time` | float | Total execution time (seconds) |
| `dify.invoke_from` | string | `api`, `webapp`, `debug` |
| `dify.conversation.id` | string | Conversation ID (optional) |
| `dify.message.id` | string | Message ID (optional) |
| `dify.invoked_by` | string | User ID who triggered the run |
| `gen_ai.usage.total_tokens` | int | Total tokens across all nodes (optional) |
| `gen_ai.user.id` | string | End-user identifier (optional) |
| `dify.parent.trace_id` | string | Parent workflow trace ID (optional) |
| `dify.parent.workflow.run_id` | string | Parent workflow run ID (optional) |
| `dify.parent.node.execution_id` | string | Parent node execution ID (optional) |
| `dify.parent.app.id` | string | Parent app ID (optional) |
### `dify.node.execution`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.trace_id` | string | Business trace ID |
| `dify.tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.workflow.id` | string | Workflow definition ID |
| `dify.workflow.run_id` | string | Workflow Run ID |
| `dify.message.id` | string | Message ID (optional) |
| `dify.conversation.id` | string | Conversation ID (optional) |
| `dify.node.execution_id` | string | Unique node execution ID |
| `dify.node.id` | string | Node ID in workflow graph |
| `dify.node.type` | string | Node type (see appendix) |
| `dify.node.title` | string | Display title |
| `dify.node.status` | string | `succeeded`, `failed` |
| `dify.node.error` | string | Error message if failed |
| `dify.node.elapsed_time` | float | Execution time (seconds) |
| `dify.node.index` | int | Execution order index |
| `dify.node.predecessor_node_id` | string | Triggering node ID |
| `dify.node.iteration_id` | string | Iteration ID (optional) |
| `dify.node.loop_id` | string | Loop ID (optional) |
| `dify.node.parallel_id` | string | Parallel branch ID (optional) |
| `dify.node.invoked_by` | string | User ID who triggered execution |
| `gen_ai.usage.input_tokens` | int | Prompt tokens (LLM nodes only) |
| `gen_ai.usage.output_tokens` | int | Completion tokens (LLM nodes only) |
| `gen_ai.usage.total_tokens` | int | Total tokens (LLM nodes only) |
| `gen_ai.request.model` | string | LLM model name (LLM nodes only) |
| `gen_ai.provider.name` | string | LLM provider name (LLM nodes only) |
| `gen_ai.user.id` | string | End-user identifier (optional) |
### `dify.node.execution.draft`
Same attributes as `dify.node.execution`. Emitted during Preview/Debug runs.
## Counters
All counters are cumulative and recorded without sampling (exact counts).
### Token Counters
| Metric | Unit | Description |
|--------|------|-------------|
| `dify.tokens.total` | `{token}` | Total tokens consumed |
| `dify.tokens.input` | `{token}` | Input (prompt) tokens |
| `dify.tokens.output` | `{token}` | Output (completion) tokens |
**Labels:**
- `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type` (if node_execution)
⚠️ **Warning:** `dify.tokens.total` at workflow level includes all node tokens. Filter by `operation_type` to avoid double-counting.
#### Token Hierarchy & Query Patterns
Token metrics are emitted at multiple layers. Understanding the hierarchy prevents double-counting:
```
App-level total
├── workflow            ← sum of all node_execution tokens (DO NOT add both)
│   └── node_execution  ← per-node breakdown
├── message             ← independent (non-workflow chat apps only)
├── rule_generate       ← independent helper LLM call
├── code_generate       ← independent helper LLM call
├── structured_output   ← independent helper LLM call
└── instruction_modify  ← independent helper LLM call
```
**Key rule:** `workflow` tokens already include all `node_execution` tokens. Never sum both.
**Available labels on token metrics:** `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type`.
App name is only available on span attributes (`dify.app.name`), not metric labels — use `app_id` for metric queries.
**Common queries** (PromQL):
```promql
# ── Totals ──────────────────────────────────────────────────
# App-level total (exclude node_execution to avoid double-counting)
sum by (app_id) (dify_tokens_total{operation_type!="node_execution"})
# Single app total
sum (dify_tokens_total{app_id="<app_id>", operation_type!="node_execution"})
# Per-tenant totals
sum by (tenant_id) (dify_tokens_total{operation_type!="node_execution"})
# ── Drill-down ──────────────────────────────────────────────
# Workflow-level tokens for an app
sum (dify_tokens_total{app_id="<app_id>", operation_type="workflow"})
# Node-level breakdown within an app
sum by (node_type) (dify_tokens_total{app_id="<app_id>", operation_type="node_execution"})
# Model breakdown for an app
sum by (model_provider, model_name) (dify_tokens_total{app_id="<app_id>"})
# Input vs output per model
sum by (model_name) (dify_tokens_input_total{app_id="<app_id>"})
sum by (model_name) (dify_tokens_output_total{app_id="<app_id>"})
# ── Rates ───────────────────────────────────────────────────
# Token consumption rate (per hour)
sum(rate(dify_tokens_total{operation_type!="node_execution"}[1h]))
# Per-app consumption rate
sum by (app_id) (rate(dify_tokens_total{operation_type!="node_execution"}[1h]))
```
**Finding `app_id` from app name** (trace query — Tempo / Jaeger):
```
{ resource.dify.app.name = "My Chatbot" } | select(resource.dify.app.id)
```
### Request Counters
| Metric | Unit | Description |
|--------|------|-------------|
| `dify.requests.total` | `{request}` | Total operations count |
**Labels by type:**
| `type` | Additional Labels |
|--------|-------------------|
| `workflow` | `tenant_id`, `app_id`, `status`, `invoke_from` |
| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` |
| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` |
| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name`, `status`, `invoke_from` |
| `tool` | `tenant_id`, `app_id`, `tool_name` |
| `moderation` | `tenant_id`, `app_id` |
| `suggested_question` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
| `dataset_retrieval` | `tenant_id`, `app_id` |
| `generate_name` | `tenant_id`, `app_id` |
| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `status` |
### Error Counters
| Metric | Unit | Description |
|--------|------|-------------|
| `dify.errors.total` | `{error}` | Total failed operations |
**Labels by type:**
| `type` | Additional Labels |
|--------|-------------------|
| `workflow` | `tenant_id`, `app_id` |
| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` |
| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` |
| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
| `tool` | `tenant_id`, `app_id`, `tool_name` |
| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` |
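Together, the request and error counters support a simple reliability ratio. A sketch in PromQL, assuming the exporter flattens metric-name dots to underscores as in the token queries earlier (`dify.errors.total` → `dify_errors_total`):
```
# Per-app error ratio over the last hour
sum by (app_id) (rate(dify_errors_total[1h]))
  /
sum by (app_id) (rate(dify_requests_total[1h]))
```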
### Other Counters
| Metric | Unit | Labels |
|--------|------|--------|
| `dify.feedback.total` | `{feedback}` | `tenant_id`, `app_id`, `rating` |
| `dify.dataset.retrievals.total` | `{retrieval}` | `tenant_id`, `app_id`, `dataset_id`, `embedding_model_provider`, `embedding_model`, `rerank_model_provider`, `rerank_model` |
| `dify.app.created.total` | `{app}` | `tenant_id`, `app_id`, `mode` |
| `dify.app.updated.total` | `{app}` | `tenant_id`, `app_id` |
| `dify.app.deleted.total` | `{app}` | `tenant_id`, `app_id` |
## Histograms
| Metric | Unit | Labels |
|--------|------|--------|
| `dify.workflow.duration` | `s` | `tenant_id`, `app_id`, `status` |
| `dify.node.duration` | `s` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `plugin_name` |
| `dify.message.duration` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
| `dify.message.time_to_first_token` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
| `dify.tool.duration` | `s` | `tenant_id`, `app_id`, `tool_name` |
| `dify.prompt_generation.duration` | `s` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` |
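These histograms can be queried with `histogram_quantile`. A sketch, assuming the standard Prometheus `_bucket` series is produced for each histogram:
```
# p95 workflow duration per app over 5-minute windows
histogram_quantile(
  0.95,
  sum by (app_id, le) (rate(dify_workflow_duration_bucket[5m]))
)
```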
## Structured Logs
### Span Companion Logs
Logs that accompany spans. Signal type: `span_detail`
#### `dify.workflow.run` Companion Log
**Common attributes:** All span attributes (see Traces section) plus:
| Additional Attribute | Type | Always Present | Description |
|---------------------|------|----------------|-------------|
| `dify.app.name` | string | No | Application display name |
| `dify.workspace.name` | string | No | Workspace display name |
| `dify.workflow.version` | string | Yes | Workflow definition version |
| `dify.workflow.inputs` | string/JSON | Yes | Input parameters (content-gated) |
| `dify.workflow.outputs` | string/JSON | Yes | Output results (content-gated) |
| `dify.workflow.query` | string | No | User query text (content-gated) |
**Event attributes:**
- `dify.event.name`: `"dify.workflow.run"`
- `dify.event.signal`: `"span_detail"`
- `trace_id`, `span_id`, `tenant_id`, `user_id`
#### `dify.node.execution` and `dify.node.execution.draft` Companion Logs
**Common attributes:** All span attributes (see Traces section) plus:
| Additional Attribute | Type | Always Present | Description |
|---------------------|------|----------------|-------------|
| `dify.app.name` | string | No | Application display name |
| `dify.workspace.name` | string | No | Workspace display name |
| `dify.invoke_from` | string | No | Invocation source |
| `gen_ai.tool.name` | string | No | Tool name (tool nodes only) |
| `dify.node.total_price` | float | No | Cost (LLM nodes only) |
| `dify.node.currency` | string | No | Currency code (LLM nodes only) |
| `dify.node.iteration_index` | int | No | Iteration index (iteration nodes) |
| `dify.node.loop_index` | int | No | Loop index (loop nodes) |
| `dify.plugin.name` | string | No | Plugin name (tool/knowledge nodes) |
| `dify.credential.name` | string | No | Credential name (plugin nodes) |
| `dify.credential.id` | string | No | Credential ID (plugin nodes) |
| `dify.dataset.ids` | JSON array | No | Dataset IDs (knowledge nodes) |
| `dify.dataset.names` | JSON array | No | Dataset names (knowledge nodes) |
| `dify.node.inputs` | string/JSON | Yes | Node inputs (content-gated) |
| `dify.node.outputs` | string/JSON | Yes | Node outputs (content-gated) |
| `dify.node.process_data` | string/JSON | No | Processing data (content-gated) |
**Event attributes:**
- `dify.event.name`: `"dify.node.execution"` or `"dify.node.execution.draft"`
- `dify.event.signal`: `"span_detail"`
- `trace_id`, `span_id`, `tenant_id`, `user_id`
### Standalone Logs
Logs for events that have no corresponding span. Signal type: `metric_only`
#### `dify.message.run`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.message.run"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID (32-char hex) |
| `span_id` | string | OTEL span ID (16-char hex) |
| `tenant_id` | string | Tenant identifier |
| `user_id` | string | User identifier (optional) |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.conversation.id` | string | Conversation ID (optional) |
| `dify.workflow.run_id` | string | Workflow run ID (optional) |
| `dify.invoke_from` | string | `service-api`, `web-app`, `debugger`, `explore` |
| `gen_ai.provider.name` | string | LLM provider |
| `gen_ai.request.model` | string | LLM model |
| `gen_ai.usage.input_tokens` | int | Input tokens |
| `gen_ai.usage.output_tokens` | int | Output tokens |
| `gen_ai.usage.total_tokens` | int | Total tokens |
| `dify.message.status` | string | `succeeded`, `failed` |
| `dify.message.error` | string | Error message (if failed) |
| `dify.message.duration` | float | Duration (seconds) |
| `dify.message.time_to_first_token` | float | TTFT (seconds) |
| `dify.message.inputs` | string/JSON | Inputs (content-gated) |
| `dify.message.outputs` | string/JSON | Outputs (content-gated) |
#### `dify.tool.execution`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.tool.execution"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.tool.name` | string | Tool name |
| `dify.tool.duration` | float | Duration (seconds) |
| `dify.tool.status` | string | `succeeded`, `failed` |
| `dify.tool.error` | string | Error message (if failed) |
| `dify.tool.inputs` | string/JSON | Inputs (content-gated) |
| `dify.tool.outputs` | string/JSON | Outputs (content-gated) |
| `dify.tool.parameters` | string/JSON | Parameters (content-gated) |
| `dify.tool.config` | string/JSON | Configuration (content-gated) |
#### `dify.moderation.check`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.moderation.check"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.moderation.type` | string | `input`, `output` |
| `dify.moderation.action` | string | `pass`, `block`, `flag` |
| `dify.moderation.flagged` | boolean | Whether flagged |
| `dify.moderation.categories` | JSON array | Flagged categories |
| `dify.moderation.query` | string | Content (content-gated) |
#### `dify.suggested_question.generation`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.suggested_question.generation"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.suggested_question.count` | int | Number of questions |
| `dify.suggested_question.duration` | float | Duration (seconds) |
| `dify.suggested_question.status` | string | `succeeded`, `failed` |
| `dify.suggested_question.error` | string | Error message (if failed) |
| `dify.suggested_question.questions` | JSON array | Questions (content-gated) |
#### `dify.dataset.retrieval`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.dataset.retrieval"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.dataset.id` | string | Dataset identifier |
| `dify.dataset.name` | string | Dataset name |
| `dify.dataset.embedding_providers` | JSON array | Embedding model providers (one per dataset) |
| `dify.dataset.embedding_models` | JSON array | Embedding models (one per dataset) |
| `dify.retrieval.rerank_provider` | string | Rerank model provider |
| `dify.retrieval.rerank_model` | string | Rerank model name |
| `dify.retrieval.query` | string | Search query (content-gated) |
| `dify.retrieval.document_count` | int | Documents retrieved |
| `dify.retrieval.duration` | float | Duration (seconds) |
| `dify.retrieval.status` | string | `succeeded`, `failed` |
| `dify.retrieval.error` | string | Error message (if failed) |
| `dify.dataset.documents` | JSON array | Documents (content-gated) |
#### `dify.generate_name.execution`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.generate_name.execution"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.conversation.id` | string | Conversation identifier |
| `dify.generate_name.duration` | float | Duration (seconds) |
| `dify.generate_name.status` | string | `succeeded`, `failed` |
| `dify.generate_name.error` | string | Error message (if failed) |
| `dify.generate_name.inputs` | string/JSON | Inputs (content-gated) |
| `dify.generate_name.outputs` | string | Generated name (content-gated) |
#### `dify.prompt_generation.execution`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.prompt_generation.execution"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.prompt_generation.operation_type` | string | Operation type (see appendix) |
| `gen_ai.provider.name` | string | LLM provider |
| `gen_ai.request.model` | string | LLM model |
| `gen_ai.usage.input_tokens` | int | Input tokens |
| `gen_ai.usage.output_tokens` | int | Output tokens |
| `gen_ai.usage.total_tokens` | int | Total tokens |
| `dify.prompt_generation.duration` | float | Duration (seconds) |
| `dify.prompt_generation.status` | string | `succeeded`, `failed` |
| `dify.prompt_generation.error` | string | Error message (if failed) |
| `dify.prompt_generation.instruction` | string | Instruction (content-gated) |
| `dify.prompt_generation.output` | string/JSON | Output (content-gated) |
#### `dify.app.created`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.app.created"` |
| `dify.event.signal` | string | `"metric_only"` |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.app.mode` | string | `chat`, `completion`, `agent-chat`, `workflow` |
| `dify.app.created_at` | string | Timestamp (ISO 8601) |
#### `dify.app.updated`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.app.updated"` |
| `dify.event.signal` | string | `"metric_only"` |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.app.updated_at` | string | Timestamp (ISO 8601) |
#### `dify.app.deleted`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.app.deleted"` |
| `dify.event.signal` | string | `"metric_only"` |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.app.deleted_at` | string | Timestamp (ISO 8601) |
#### `dify.feedback.created`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.feedback.created"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.feedback.rating` | string | `like`, `dislike`, `null` |
| `dify.feedback.content` | string | Feedback text (content-gated) |
| `dify.feedback.created_at` | string | Timestamp (ISO 8601) |
#### `dify.telemetry.rehydration_failed`
Diagnostic event for telemetry system health monitoring.
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.telemetry.rehydration_failed"` |
| `dify.event.signal` | string | `"metric_only"` |
| `tenant_id` | string | Tenant identifier |
| `dify.telemetry.error` | string | Error message |
| `dify.telemetry.payload_type` | string | Payload type (see appendix) |
| `dify.telemetry.correlation_id` | string | Correlation ID |
## Content-Gated Attributes
When `ENTERPRISE_INCLUDE_CONTENT=false`, these attributes are replaced with reference strings (`ref:{id_type}={uuid}`).
| Attribute | Signal |
|-----------|--------|
| `dify.workflow.inputs` | `dify.workflow.run` |
| `dify.workflow.outputs` | `dify.workflow.run` |
| `dify.workflow.query` | `dify.workflow.run` |
| `dify.node.inputs` | `dify.node.execution` |
| `dify.node.outputs` | `dify.node.execution` |
| `dify.node.process_data` | `dify.node.execution` |
| `dify.message.inputs` | `dify.message.run` |
| `dify.message.outputs` | `dify.message.run` |
| `dify.tool.inputs` | `dify.tool.execution` |
| `dify.tool.outputs` | `dify.tool.execution` |
| `dify.tool.parameters` | `dify.tool.execution` |
| `dify.tool.config` | `dify.tool.execution` |
| `dify.moderation.query` | `dify.moderation.check` |
| `dify.suggested_question.questions` | `dify.suggested_question.generation` |
| `dify.retrieval.query` | `dify.dataset.retrieval` |
| `dify.dataset.documents` | `dify.dataset.retrieval` |
| `dify.generate_name.inputs` | `dify.generate_name.execution` |
| `dify.generate_name.outputs` | `dify.generate_name.execution` |
| `dify.prompt_generation.instruction` | `dify.prompt_generation.execution` |
| `dify.prompt_generation.output` | `dify.prompt_generation.execution` |
| `dify.feedback.content` | `dify.feedback.created` |
## Appendix
### Operation Types
- `workflow`, `node_execution`, `message`, `rule_generate`, `code_generate`, `structured_output`, `instruction_modify`
### Node Types
- `start`, `end`, `answer`, `llm`, `knowledge-retrieval`, `knowledge-index`, `if-else`, `code`, `template-transform`, `question-classifier`, `http-request`, `tool`, `datasource`, `variable-aggregator`, `loop`, `iteration`, `parameter-extractor`, `assigner`, `document-extractor`, `list-operator`, `agent`, `trigger-webhook`, `trigger-schedule`, `trigger-plugin`, `human-input`
### Workflow Statuses
- `running`, `succeeded`, `failed`, `stopped`, `partial-succeeded`, `paused`
### Payload Types
- `workflow`, `node`, `message`, `tool`, `moderation`, `suggested_question`, `dataset_retrieval`, `generate_name`, `prompt_generation`, `app`, `feedback`
### Null Value Behavior
**Spans:** Attributes with `null` values are omitted.
**Logs:** Attributes with `null` values appear as `null` in JSON.
**Content-Gated:** Replaced with reference strings, not set to `null`.

# Dify Enterprise Telemetry
This document provides an overview of the Dify Enterprise OpenTelemetry (OTEL) exporter and how to configure it for integration with observability stacks like Prometheus, Grafana, Jaeger, or Honeycomb.
## Overview
Dify Enterprise uses a "slim span + rich companion log" architecture to provide high-fidelity observability without overwhelming trace storage.
- **Traces (Spans)**: Capture the structure, identity, and timing of high-level operations (Workflows and Nodes).
- **Structured Logs**: Provide deep context (inputs, outputs, metadata) for every event, correlated to spans via `trace_id` and `span_id`.
- **Metrics**: Provide 100% accurate counters and histograms for usage, performance, and error tracking.
### Signal Architecture
```mermaid
graph TD
A[Workflow Run] -->|Span| B(dify.workflow.run)
A -->|Log| C(dify.workflow.run detail)
B ---|trace_id| C
D[Node Execution] -->|Span| E(dify.node.execution)
D -->|Log| F(dify.node.execution detail)
E ---|span_id| F
G[Message/Tool/etc] -->|Log| H(dify.* event)
G -->|Metric| I(dify.* counter/histogram)
```
## Configuration
The Enterprise OTEL exporter is configured via environment variables.
| Variable | Description | Default |
|----------|-------------|---------|
| `ENTERPRISE_ENABLED` | Master switch for all enterprise features. | `false` |
| `ENTERPRISE_TELEMETRY_ENABLED` | Master switch for enterprise telemetry. | `false` |
| `ENTERPRISE_OTLP_ENDPOINT` | OTLP collector endpoint (e.g., `http://otel-collector:4318`). | - |
| `ENTERPRISE_OTLP_HEADERS` | Custom headers for OTLP requests (e.g., `x-scope-orgid=tenant1`). | - |
| `ENTERPRISE_OTLP_PROTOCOL` | OTLP transport protocol (`http` or `grpc`). | `http` |
| `ENTERPRISE_OTLP_API_KEY` | Bearer token for authentication. | - |
| `ENTERPRISE_INCLUDE_CONTENT` | Whether to include sensitive content (inputs/outputs) in logs. | `true` |
| `ENTERPRISE_SERVICE_NAME` | Service name reported to OTEL. | `dify` |
| `ENTERPRISE_OTEL_SAMPLING_RATE` | Sampling rate for traces (0.0 to 1.0). Metrics are always 100%. | `1.0` |
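As a minimal example, a deployment pointing at a local collector with content gating enabled might set (endpoint and sampling values are illustrative):
```
ENTERPRISE_ENABLED=true
ENTERPRISE_TELEMETRY_ENABLED=true
ENTERPRISE_OTLP_ENDPOINT=http://otel-collector:4318
ENTERPRISE_OTLP_PROTOCOL=http
ENTERPRISE_INCLUDE_CONTENT=false
ENTERPRISE_OTEL_SAMPLING_RATE=0.25
```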
## Correlation Model
Dify uses deterministic ID generation to ensure signals are correlated across different services and asynchronous tasks.
### ID Generation Rules
- `trace_id`: Derived from the correlation ID (workflow_run_id or node_execution_id for drafts) using `int(UUID(correlation_id))`
- `span_id`: Derived from the source ID using the lower 64 bits of `UUID(source_id)`
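The two rules above can be sketched in Python (the hex rendering is for illustration; internally the exporter works with the integer values):

```python
from uuid import UUID

def derive_trace_id(correlation_id: str) -> str:
    # Full 128-bit UUID integer, rendered as the 32-char hex OTEL trace ID
    return f"{int(UUID(correlation_id)):032x}"

def derive_span_id(source_id: str) -> str:
    # Lower 64 bits of the UUID integer, rendered as the 16-char hex span ID
    return f"{int(UUID(source_id)) & (2**64 - 1):016x}"

run_id = "550e8400-e29b-41d4-a716-446655440000"
print(derive_trace_id(run_id))  # 550e8400e29b41d4a716446655440000
print(derive_span_id(run_id))   # a716446655440000
```

Because both IDs are pure functions of the correlation UUID, any service that knows the `workflow_run_id` can reconstruct the trace context without passing span handles around.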
### Scenario A: Simple Workflow
A single workflow run with multiple nodes. All spans and logs share the same `trace_id` (derived from `workflow_run_id`).
```
trace_id = UUID(workflow_run_id)
├── [root span] dify.workflow.run (span_id = hash(workflow_run_id))
│ ├── [child] dify.node.execution - "Start" (span_id = hash(node_exec_id_1))
│ ├── [child] dify.node.execution - "LLM" (span_id = hash(node_exec_id_2))
│ └── [child] dify.node.execution - "End" (span_id = hash(node_exec_id_3))
```
### Scenario B: Nested Sub-Workflow
A workflow calling another workflow via a Tool or Sub-workflow node. The child workflow's spans are linked to the parent via `parent_span_id`. Both workflows share the same trace_id.
```
trace_id = UUID(outer_workflow_run_id) ← shared across both workflows
├── [root] dify.workflow.run (outer) (span_id = hash(outer_workflow_run_id))
│ ├── dify.node.execution - "Start Node"
│ ├── dify.node.execution - "Tool Node" (triggers sub-workflow)
│ │ └── [child] dify.workflow.run (inner) (span_id = hash(inner_workflow_run_id))
│ │ ├── dify.node.execution - "Inner Start"
│ │ └── dify.node.execution - "Inner End"
│ └── dify.node.execution - "End Node"
```
**Key attributes for nested workflows:**
- Inner workflow's `dify.parent.trace_id` = outer `workflow_run_id`
- Inner workflow's `dify.parent.node.execution_id` = tool node's `execution_id`
- Inner workflow's `dify.parent.workflow.run_id` = outer `workflow_run_id`
- Inner workflow's `dify.parent.app.id` = outer `app_id`
### Scenario C: Draft Node Execution
A single node run in isolation (debugger/preview mode). It creates its own trace where the node span is the root.
```
trace_id = UUID(node_execution_id) ← own trace, NOT part of any workflow
└── dify.node.execution.draft (span_id = hash(node_execution_id))
```
**Key difference:** Draft executions use `node_execution_id` as the correlation_id, so they are NOT children of any workflow trace.
## Content Gating
When `ENTERPRISE_INCLUDE_CONTENT` is set to `false`, sensitive content attributes (inputs, outputs, queries) are replaced with reference strings (e.g., `ref:workflow_run_id=...`) to prevent data leakage to the OTEL collector.
**Reference String Format:**
```
ref:{id_type}={uuid}
```
**Examples:**
```
ref:workflow_run_id=550e8400-e29b-41d4-a716-446655440000
ref:node_execution_id=660e8400-e29b-41d4-a716-446655440001
ref:message_id=770e8400-e29b-41d4-a716-446655440002
```
To retrieve actual content when gating is enabled, query the Dify database using the provided UUID.
## Reference
For a complete list of telemetry signals, attributes, and data structures, see [DATA_DICTIONARY.md](./DATA_DICTIONARY.md).

"""Telemetry gateway contracts and data structures.

This module defines the envelope format for telemetry events and the routing
configuration that determines how each event type is processed.
"""

from __future__ import annotations

from enum import StrEnum
from typing import Any

from pydantic import BaseModel, ConfigDict


class TelemetryCase(StrEnum):
    """Enumeration of all known telemetry event cases."""

    WORKFLOW_RUN = "workflow_run"
    NODE_EXECUTION = "node_execution"
    DRAFT_NODE_EXECUTION = "draft_node_execution"
    MESSAGE_RUN = "message_run"
    TOOL_EXECUTION = "tool_execution"
    MODERATION_CHECK = "moderation_check"
    SUGGESTED_QUESTION = "suggested_question"
    DATASET_RETRIEVAL = "dataset_retrieval"
    GENERATE_NAME = "generate_name"
    PROMPT_GENERATION = "prompt_generation"
    APP_CREATED = "app_created"
    APP_UPDATED = "app_updated"
    APP_DELETED = "app_deleted"
    FEEDBACK_CREATED = "feedback_created"


class SignalType(StrEnum):
    """Signal routing type for telemetry cases."""

    TRACE = "trace"
    METRIC_LOG = "metric_log"


class CaseRoute(BaseModel):
    """Routing configuration for a telemetry case.

    Attributes:
        signal_type: The type of signal (trace or metric_log).
        ce_eligible: Whether this case is eligible for community edition tracing.
    """

    signal_type: SignalType
    ce_eligible: bool


class TelemetryEnvelope(BaseModel):
    """Envelope for telemetry events.

    Attributes:
        case: The telemetry case type.
        tenant_id: The tenant identifier.
        event_id: Unique event identifier for deduplication.
        payload: The main event payload (inline for small payloads,
            empty when offloaded to storage via ``payload_ref``).
        metadata: Optional metadata dictionary. When the gateway
            offloads a large payload to object storage, this contains
            ``{"payload_ref": "<storage_key>"}``.
    """

    model_config = ConfigDict(extra="forbid", use_enum_values=False)

    case: TelemetryCase
    tenant_id: str
    event_id: str
    payload: dict[str, Any]
    metadata: dict[str, Any] | None = None

from __future__ import annotations

from collections.abc import Mapping
from typing import Any

from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
from models.workflow import WorkflowNodeExecutionModel


def enqueue_draft_node_execution_trace(
    *,
    execution: WorkflowNodeExecutionModel,
    outputs: Mapping[str, Any] | None,
    workflow_execution_id: str | None,
    user_id: str,
) -> None:
    node_data = _build_node_execution_data(
        execution=execution,
        outputs=outputs,
        workflow_execution_id=workflow_execution_id,
    )
    telemetry_emit(
        TelemetryEvent(
            name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
            context=TelemetryContext(
                tenant_id=execution.tenant_id,
                user_id=user_id,
                app_id=execution.app_id,
            ),
            payload={"node_execution_data": node_data},
        )
    )


def _build_node_execution_data(
    *,
    execution: WorkflowNodeExecutionModel,
    outputs: Mapping[str, Any] | None,
    workflow_execution_id: str | None,
) -> dict[str, Any]:
    metadata = execution.execution_metadata_dict
    node_outputs = outputs if outputs is not None else execution.outputs_dict
    execution_id = workflow_execution_id or execution.workflow_run_id or execution.id
    return {
        "workflow_id": execution.workflow_id,
        "workflow_execution_id": execution_id,
        "tenant_id": execution.tenant_id,
        "app_id": execution.app_id,
        "node_execution_id": execution.id,
        "node_id": execution.node_id,
        "node_type": execution.node_type,
        "title": execution.title,
        "status": execution.status,
        "error": execution.error,
        "elapsed_time": execution.elapsed_time,
        "index": execution.index,
        "predecessor_node_id": execution.predecessor_node_id,
        "created_at": execution.created_at,
        "finished_at": execution.finished_at,
        "total_tokens": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0),
        "total_price": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0),
        "currency": metadata.get(WorkflowNodeExecutionMetadataKey.CURRENCY),
        "tool_name": (metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) or {}).get("tool_name")
        if isinstance(metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO), dict)
        else None,
        "iteration_id": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID),
        "iteration_index": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_INDEX),
        "loop_id": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID),
        "loop_index": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_INDEX),
        "parallel_id": metadata.get(WorkflowNodeExecutionMetadataKey.PARALLEL_ID),
        "node_inputs": execution.inputs_dict,
        "node_outputs": node_outputs,
        "process_data": execution.process_data_dict,
    }

"""Enterprise trace handler — duck-typed, NOT a BaseTraceInstance subclass.

Invoked directly in the Celery task, not through OpsTraceManager dispatch.
Only requires a matching ``trace(trace_info)`` method signature.

Signal strategy:

- **Traces (spans)**: workflow run, node execution, draft node execution only.
- **Metrics + structured logs**: all other event types.

Token metric labels (unified structure):

All token metrics (dify.tokens.input, dify.tokens.output, dify.tokens.total) use the
same label set for consistent filtering and aggregation:

- tenant_id: Tenant identifier
- app_id: Application identifier
- operation_type: Source of token usage (workflow | node_execution | message | rule_generate | etc.)
- model_provider: LLM provider name (empty string if not applicable)
- model_name: LLM model name (empty string if not applicable)
- node_type: Workflow node type (empty string if not node_execution)

This unified structure allows filtering by operation_type to separate:

- Workflow-level aggregates (operation_type=workflow)
- Individual node executions (operation_type=node_execution)
- Direct message calls (operation_type=message)
- Prompt generation operations (operation_type=rule_generate, code_generate, etc.)

Without this, tokens are double-counted when querying totals (workflow totals include
node totals, since workflow.total_tokens is the sum of all node tokens).
"""
from __future__ import annotations

import json
import logging
from typing import Any, cast

from opentelemetry.util.types import AttributeValue

from core.ops.entities.trace_entity import (
    BaseTraceInfo,
    DatasetRetrievalTraceInfo,
    DraftNodeExecutionTrace,
    GenerateNameTraceInfo,
    MessageTraceInfo,
    ModerationTraceInfo,
    OperationType,
    PromptGenerationTraceInfo,
    SuggestedQuestionTraceInfo,
    ToolTraceInfo,
    WorkflowNodeTraceInfo,
    WorkflowTraceInfo,
)
from enterprise.telemetry.entities import (
    EnterpriseTelemetryCounter,
    EnterpriseTelemetryEvent,
    EnterpriseTelemetryHistogram,
    EnterpriseTelemetrySpan,
    TokenMetricLabels,
)
from enterprise.telemetry.telemetry_log import emit_metric_only_event, emit_telemetry_log

logger = logging.getLogger(__name__)
class EnterpriseOtelTrace:
"""Duck-typed enterprise trace handler.
``*_trace`` methods emit spans (workflow/node only) or structured logs
(all other events), plus metrics at 100 % accuracy.
"""
def __init__(self) -> None:
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if exporter is None:
raise RuntimeError("EnterpriseOtelTrace instantiated but exporter is not initialized")
self._exporter = exporter
def trace(self, trace_info: BaseTraceInfo) -> None:
if isinstance(trace_info, WorkflowTraceInfo):
self._workflow_trace(trace_info)
elif isinstance(trace_info, MessageTraceInfo):
self._message_trace(trace_info)
elif isinstance(trace_info, ToolTraceInfo):
self._tool_trace(trace_info)
elif isinstance(trace_info, DraftNodeExecutionTrace):
self._draft_node_execution_trace(trace_info)
elif isinstance(trace_info, WorkflowNodeTraceInfo):
self._node_execution_trace(trace_info)
elif isinstance(trace_info, ModerationTraceInfo):
self._moderation_trace(trace_info)
elif isinstance(trace_info, SuggestedQuestionTraceInfo):
self._suggested_question_trace(trace_info)
elif isinstance(trace_info, DatasetRetrievalTraceInfo):
self._dataset_retrieval_trace(trace_info)
elif isinstance(trace_info, GenerateNameTraceInfo):
self._generate_name_trace(trace_info)
elif isinstance(trace_info, PromptGenerationTraceInfo):
self._prompt_generation_trace(trace_info)
def _common_attrs(self, trace_info: BaseTraceInfo) -> dict[str, Any]:
metadata = self._metadata(trace_info)
tenant_id, app_id, user_id = self._context_ids(trace_info, metadata)
return {
"dify.trace_id": trace_info.resolved_trace_id,
"dify.tenant_id": tenant_id,
"dify.app_id": app_id,
"dify.app.name": metadata.get("app_name"),
"dify.workspace.name": metadata.get("workspace_name"),
"gen_ai.user.id": user_id,
"dify.message.id": trace_info.message_id,
}
def _metadata(self, trace_info: BaseTraceInfo) -> dict[str, Any]:
return trace_info.metadata
def _context_ids(
self,
trace_info: BaseTraceInfo,
metadata: dict[str, Any],
) -> tuple[str | None, str | None, str | None]:
tenant_id = getattr(trace_info, "tenant_id", None) or metadata.get("tenant_id")
app_id = getattr(trace_info, "app_id", None) or metadata.get("app_id")
user_id = getattr(trace_info, "user_id", None) or metadata.get("user_id")
return tenant_id, app_id, user_id
def _labels(self, **values: AttributeValue) -> dict[str, AttributeValue]:
return dict(values)
def _safe_payload_value(self, value: Any) -> str | dict[str, Any] | list[object] | None:
if isinstance(value, str):
return value
if isinstance(value, dict):
return cast(dict[str, Any], value)
if isinstance(value, list):
items: list[object] = []
for item in cast(list[object], value):
items.append(item)
return items
return None
def _content_or_ref(self, value: Any, ref: str) -> Any:
if self._exporter.include_content:
return self._maybe_json(value)
return ref
def _maybe_json(self, value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
try:
return json.dumps(value, default=str)
except (TypeError, ValueError):
return str(value)
# ------------------------------------------------------------------
# SPAN-emitting handlers (workflow, node execution, draft node)
# ------------------------------------------------------------------
def _workflow_trace(self, info: WorkflowTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
# -- Span attrs: identity + structure + status + timing + gen_ai scalars --
span_attrs: dict[str, Any] = {
"dify.trace_id": info.resolved_trace_id,
"dify.tenant_id": tenant_id,
"dify.app_id": app_id,
"dify.workflow.id": info.workflow_id,
"dify.workflow.run_id": info.workflow_run_id,
"dify.workflow.status": info.workflow_run_status,
"dify.workflow.error": info.error,
"dify.workflow.elapsed_time": info.workflow_run_elapsed_time,
"dify.invoke_from": metadata.get("triggered_from"),
"dify.conversation.id": info.conversation_id,
"dify.message.id": info.message_id,
"dify.invoked_by": info.invoked_by,
"gen_ai.usage.total_tokens": info.total_tokens,
"gen_ai.user.id": user_id,
}
trace_correlation_override, parent_span_id_source = info.resolved_parent_context
parent_ctx = metadata.get("parent_trace_context")
if isinstance(parent_ctx, dict):
parent_ctx_dict = cast(dict[str, Any], parent_ctx)
span_attrs["dify.parent.trace_id"] = parent_ctx_dict.get("trace_id")
span_attrs["dify.parent.node.execution_id"] = parent_ctx_dict.get("parent_node_execution_id")
span_attrs["dify.parent.workflow.run_id"] = parent_ctx_dict.get("parent_workflow_run_id")
span_attrs["dify.parent.app.id"] = parent_ctx_dict.get("parent_app_id")
self._exporter.export_span(
EnterpriseTelemetrySpan.WORKFLOW_RUN,
span_attrs,
correlation_id=info.workflow_run_id,
span_id_source=info.workflow_run_id,
start_time=info.start_time,
end_time=info.end_time,
trace_correlation_override=trace_correlation_override,
parent_span_id_source=parent_span_id_source,
)
# -- Companion log: ALL attrs (span + detail) for full picture --
log_attrs: dict[str, Any] = {**span_attrs}
log_attrs.update(
{
"dify.app.name": metadata.get("app_name"),
"dify.workspace.name": metadata.get("workspace_name"),
"dify.workflow.version": info.workflow_run_version,
}
)
ref = f"ref:workflow_run_id={info.workflow_run_id}"
log_attrs["dify.workflow.inputs"] = self._content_or_ref(info.workflow_run_inputs, ref)
log_attrs["dify.workflow.outputs"] = self._content_or_ref(info.workflow_run_outputs, ref)
log_attrs["dify.workflow.query"] = self._content_or_ref(info.query, ref)
emit_telemetry_log(
event_name=EnterpriseTelemetryEvent.WORKFLOW_RUN,
attributes=log_attrs,
signal="span_detail",
trace_id_source=info.workflow_run_id,
span_id_source=info.workflow_run_id,
tenant_id=tenant_id,
user_id=user_id,
)
# -- Metrics --
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
)
token_labels = TokenMetricLabels(
tenant_id=tenant_id or "",
app_id=app_id or "",
operation_type=OperationType.WORKFLOW,
model_provider="",
model_name="",
node_type="",
).to_dict()
self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
if info.prompt_tokens is not None and info.prompt_tokens > 0:
self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels)
if info.completion_tokens is not None and info.completion_tokens > 0:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels
)
invoke_from = metadata.get("triggered_from", "")
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="workflow",
status=info.workflow_run_status,
invoke_from=invoke_from,
),
)
# Prefer wall-clock timestamps over the elapsed_time field: elapsed_time defaults
# to 0 in the DB and can be stale if the Celery write races with the trace task.
# start_time = workflow_run.created_at, end_time = workflow_run.finished_at.
if info.start_time and info.end_time:
workflow_duration = (info.end_time - info.start_time).total_seconds()
elif info.workflow_run_elapsed_time:
workflow_duration = float(info.workflow_run_elapsed_time)
else:
workflow_duration = 0.0
self._exporter.record_histogram(
EnterpriseTelemetryHistogram.WORKFLOW_DURATION,
workflow_duration,
self._labels(
**labels,
status=info.workflow_run_status,
),
)
if info.error:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.ERRORS,
1,
self._labels(
**labels,
type="workflow",
),
)
def _node_execution_trace(self, info: WorkflowNodeTraceInfo) -> None:
self._emit_node_execution_trace(info, EnterpriseTelemetrySpan.NODE_EXECUTION, "node")
def _draft_node_execution_trace(self, info: DraftNodeExecutionTrace) -> None:
self._emit_node_execution_trace(
info,
EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION,
"draft_node",
correlation_id_override=info.node_execution_id,
trace_correlation_override_param=info.workflow_run_id,
)
def _emit_node_execution_trace(
self,
info: WorkflowNodeTraceInfo,
span_name: EnterpriseTelemetrySpan,
request_type: str,
correlation_id_override: str | None = None,
trace_correlation_override_param: str | None = None,
) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
# -- Span attrs: identity + structure + status + timing + gen_ai scalars --
span_attrs: dict[str, Any] = {
"dify.trace_id": info.resolved_trace_id,
"dify.tenant_id": tenant_id,
"dify.app_id": app_id,
"dify.workflow.id": info.workflow_id,
"dify.workflow.run_id": info.workflow_run_id,
"dify.message.id": info.message_id,
"dify.conversation.id": metadata.get("conversation_id"),
"dify.node.execution_id": info.node_execution_id,
"dify.node.id": info.node_id,
"dify.node.type": info.node_type,
"dify.node.title": info.title,
"dify.node.status": info.status,
"dify.node.error": info.error,
"dify.node.elapsed_time": info.elapsed_time,
"dify.node.index": info.index,
"dify.node.predecessor_node_id": info.predecessor_node_id,
"dify.node.iteration_id": info.iteration_id,
"dify.node.loop_id": info.loop_id,
"dify.node.parallel_id": info.parallel_id,
"dify.node.invoked_by": info.invoked_by,
"gen_ai.usage.input_tokens": info.prompt_tokens,
"gen_ai.usage.output_tokens": info.completion_tokens,
"gen_ai.usage.total_tokens": info.total_tokens,
"gen_ai.request.model": info.model_name,
"gen_ai.provider.name": info.model_provider,
"gen_ai.user.id": user_id,
}
resolved_override, _ = info.resolved_parent_context
trace_correlation_override = trace_correlation_override_param or resolved_override
effective_correlation_id = correlation_id_override or info.workflow_run_id
self._exporter.export_span(
span_name,
span_attrs,
correlation_id=effective_correlation_id,
span_id_source=info.node_execution_id,
start_time=info.start_time,
end_time=info.end_time,
trace_correlation_override=trace_correlation_override,
)
# -- Companion log: ALL attrs (span + detail) --
log_attrs: dict[str, Any] = {**span_attrs}
log_attrs.update(
{
"dify.app.name": metadata.get("app_name"),
"dify.workspace.name": metadata.get("workspace_name"),
"dify.invoke_from": metadata.get("invoke_from"),
"dify.node.total_price": info.total_price,
"dify.node.currency": info.currency,
"gen_ai.tool.name": info.tool_name,
"dify.node.iteration_index": info.iteration_index,
"dify.node.loop_index": info.loop_index,
"dify.plugin.name": metadata.get("plugin_name"),
"dify.credential.name": metadata.get("credential_name"),
"dify.credential.id": metadata.get("credential_id"),
"dify.dataset.ids": self._maybe_json(metadata.get("dataset_ids")),
"dify.dataset.names": self._maybe_json(metadata.get("dataset_names")),
}
)
ref = f"ref:node_execution_id={info.node_execution_id}"
log_attrs["dify.node.inputs"] = self._content_or_ref(info.node_inputs, ref)
log_attrs["dify.node.outputs"] = self._content_or_ref(info.node_outputs, ref)
log_attrs["dify.node.process_data"] = self._content_or_ref(info.process_data, ref)
emit_telemetry_log(
event_name=span_name.value,
attributes=log_attrs,
signal="span_detail",
trace_id_source=info.workflow_run_id,
span_id_source=info.node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
# -- Metrics --
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
node_type=info.node_type,
model_provider=info.model_provider or "",
)
if info.total_tokens:
token_labels = TokenMetricLabels(
tenant_id=tenant_id or "",
app_id=app_id or "",
operation_type=OperationType.NODE_EXECUTION,
model_provider=info.model_provider or "",
model_name=info.model_name or "",
node_type=info.node_type,
).to_dict()
self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
if info.prompt_tokens is not None and info.prompt_tokens > 0:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels
)
if info.completion_tokens is not None and info.completion_tokens > 0:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type=request_type,
status=info.status,
model_name=info.model_name or "",
),
)
duration_labels = dict(labels)
duration_labels["model_name"] = info.model_name or ""
plugin_name = metadata.get("plugin_name")
if plugin_name and info.node_type in {"tool", "knowledge-retrieval"}:
duration_labels["plugin_name"] = plugin_name
self._exporter.record_histogram(EnterpriseTelemetryHistogram.NODE_DURATION, info.elapsed_time, duration_labels)
if info.error:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.ERRORS,
1,
self._labels(
**labels,
type=request_type,
model_name=info.model_name or "",
),
)
# ------------------------------------------------------------------
# METRIC-ONLY handlers (structured log + counters/histograms)
# ------------------------------------------------------------------
def _message_trace(self, info: MessageTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs.update(
{
"dify.invoke_from": metadata.get("from_source"),
"dify.conversation.id": metadata.get("conversation_id"),
"dify.conversation.mode": info.conversation_mode,
"gen_ai.provider.name": metadata.get("ls_provider"),
"gen_ai.request.model": metadata.get("ls_model_name"),
"gen_ai.usage.input_tokens": info.message_tokens,
"gen_ai.usage.output_tokens": info.answer_tokens,
"gen_ai.usage.total_tokens": info.total_tokens,
"dify.message.status": metadata.get("status"),
"dify.message.error": info.error,
"dify.message.from_source": metadata.get("from_source"),
"dify.message.from_end_user_id": metadata.get("from_end_user_id"),
"dify.message.from_account_id": metadata.get("from_account_id"),
"dify.streaming": info.is_streaming_request,
"dify.message.time_to_first_token": info.gen_ai_server_time_to_first_token,
"dify.message.streaming_duration": info.llm_streaming_time_to_generate,
"dify.workflow.run_id": metadata.get("workflow_run_id"),
}
)
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
ref = f"ref:message_id={info.message_id}"
inputs = self._safe_payload_value(info.inputs)
outputs = self._safe_payload_value(info.outputs)
attrs["dify.message.inputs"] = self._content_or_ref(inputs, ref)
attrs["dify.message.outputs"] = self._content_or_ref(outputs, ref)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.MESSAGE_RUN,
attributes=attrs,
trace_id_source=metadata.get("workflow_run_id") or (str(info.message_id) if info.message_id else None),
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
model_provider=metadata.get("ls_provider") or "",
model_name=metadata.get("ls_model_name") or "",
)
token_labels = TokenMetricLabels(
tenant_id=tenant_id or "",
app_id=app_id or "",
operation_type=OperationType.MESSAGE,
model_provider=metadata.get("ls_provider") or "",
model_name=metadata.get("ls_model_name") or "",
node_type="",
).to_dict()
self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
if info.message_tokens > 0:
self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.message_tokens, token_labels)
if info.answer_tokens > 0:
self._exporter.increment_counter(EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.answer_tokens, token_labels)
invoke_from = metadata.get("from_source", "")
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="message",
status=metadata.get("status", ""),
invoke_from=invoke_from,
),
)
if info.start_time and info.end_time:
duration = (info.end_time - info.start_time).total_seconds()
self._exporter.record_histogram(EnterpriseTelemetryHistogram.MESSAGE_DURATION, duration, labels)
if info.gen_ai_server_time_to_first_token is not None:
self._exporter.record_histogram(
EnterpriseTelemetryHistogram.MESSAGE_TTFT, info.gen_ai_server_time_to_first_token, labels
)
if info.error:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.ERRORS,
1,
self._labels(
**labels,
type="message",
),
)
def _tool_trace(self, info: ToolTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs.update(
{
"gen_ai.tool.name": info.tool_name,
"dify.tool.time_cost": info.time_cost,
"dify.tool.error": info.error,
"dify.workflow.run_id": metadata.get("workflow_run_id"),
}
)
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
ref = f"ref:message_id={info.message_id}"
attrs["dify.tool.inputs"] = self._content_or_ref(info.tool_inputs, ref)
attrs["dify.tool.outputs"] = self._content_or_ref(info.tool_outputs, ref)
attrs["dify.tool.parameters"] = self._content_or_ref(info.tool_parameters, ref)
attrs["dify.tool.config"] = self._content_or_ref(info.tool_config, ref)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.TOOL_EXECUTION,
attributes=attrs,
trace_id_source=info.resolved_trace_id,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
tool_name=info.tool_name,
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="tool",
),
)
self._exporter.record_histogram(EnterpriseTelemetryHistogram.TOOL_DURATION, float(info.time_cost), labels)
if info.error:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.ERRORS,
1,
self._labels(
**labels,
type="tool",
),
)
def _moderation_trace(self, info: ModerationTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs.update(
{
"dify.moderation.flagged": info.flagged,
"dify.moderation.action": info.action,
"dify.moderation.preset_response": info.preset_response,
"dify.moderation.type": "unknown",
"dify.moderation.categories": self._maybe_json([]),
"dify.workflow.run_id": metadata.get("workflow_run_id"),
}
)
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
attrs["dify.moderation.query"] = self._content_or_ref(
info.query,
f"ref:message_id={info.message_id}",
)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.MODERATION_CHECK,
attributes=attrs,
trace_id_source=info.resolved_trace_id,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="moderation",
),
)
def _suggested_question_trace(self, info: SuggestedQuestionTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
duration: float | None = None
if info.start_time is not None and info.end_time is not None:
duration = (info.end_time - info.start_time).total_seconds()
error = info.error or (info.metadata.get("error") if info.metadata else None)
status = "failed" if error else (info.status or "succeeded")
attrs.update(
{
"gen_ai.usage.total_tokens": info.total_tokens,
"dify.suggested_question.status": status,
"dify.suggested_question.error": error,
"dify.suggested_question.duration": duration,
"gen_ai.provider.name": info.model_provider,
"gen_ai.request.model": info.model_id,
"dify.suggested_question.count": len(info.suggested_question),
"dify.workflow.run_id": metadata.get("workflow_run_id"),
}
)
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
attrs["dify.suggested_question.questions"] = self._content_or_ref(
info.suggested_question,
f"ref:message_id={info.message_id}",
)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.SUGGESTED_QUESTION_GENERATION,
attributes=attrs,
trace_id_source=info.resolved_trace_id,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="suggested_question",
model_provider=info.model_provider or "",
model_name=info.model_id or "",
),
)
def _dataset_retrieval_trace(self, info: DatasetRetrievalTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs["dify.dataset.error"] = info.error
attrs["dify.workflow.run_id"] = metadata.get("workflow_run_id")
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
documents_any: Any = info.documents
documents_list: list[Any] = cast(list[Any], documents_any) if isinstance(documents_any, list) else []
docs: list[dict[str, Any]] = [cast(dict[str, Any], entry) for entry in documents_list if isinstance(entry, dict)]
dataset_ids: list[str] = []
dataset_names: list[str] = []
structured_docs: list[dict[str, Any]] = []
for doc in docs:
meta_raw = doc.get("metadata")
meta: dict[str, Any] = cast(dict[str, Any], meta_raw) if isinstance(meta_raw, dict) else {}
did = meta.get("dataset_id")
dname = meta.get("dataset_name")
if did and did not in dataset_ids:
dataset_ids.append(did)
if dname and dname not in dataset_names:
dataset_names.append(dname)
structured_docs.append(
{
"dataset_id": did,
"document_id": meta.get("document_id"),
"segment_id": meta.get("segment_id"),
"score": meta.get("score"),
}
)
attrs["dify.dataset.ids"] = self._maybe_json(dataset_ids)
attrs["dify.dataset.names"] = self._maybe_json(dataset_names)
attrs["dify.retrieval.document_count"] = len(docs)
embedding_models_raw: Any = metadata.get("embedding_models")
embedding_models: dict[str, Any] = (
cast(dict[str, Any], embedding_models_raw) if isinstance(embedding_models_raw, dict) else {}
)
if embedding_models:
providers: list[str] = []
models: list[str] = []
for ds_info in embedding_models.values():
if isinstance(ds_info, dict):
ds_info_dict: dict[str, Any] = cast(dict[str, Any], ds_info)
p = ds_info_dict.get("embedding_model_provider", "")
m = ds_info_dict.get("embedding_model", "")
if p and p not in providers:
providers.append(p)
if m and m not in models:
models.append(m)
attrs["dify.dataset.embedding_providers"] = self._maybe_json(providers)
attrs["dify.dataset.embedding_models"] = self._maybe_json(models)
# Add rerank model to logs
rerank_provider = metadata.get("rerank_model_provider", "")
rerank_model = metadata.get("rerank_model_name", "")
if rerank_provider or rerank_model:
attrs["dify.retrieval.rerank_provider"] = rerank_provider
attrs["dify.retrieval.rerank_model"] = rerank_model
ref = f"ref:message_id={info.message_id}"
retrieval_inputs = self._safe_payload_value(info.inputs)
attrs["dify.retrieval.query"] = self._content_or_ref(retrieval_inputs, ref)
attrs["dify.dataset.documents"] = self._content_or_ref(structured_docs, ref)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.DATASET_RETRIEVAL,
attributes=attrs,
trace_id_source=metadata.get("workflow_run_id") or (str(info.message_id) if info.message_id else None),
span_id_source=node_execution_id or (str(info.message_id) if info.message_id else None),
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="dataset_retrieval",
),
)
for did in dataset_ids:
# Embedding model for this specific dataset; the rerank model
# (computed above) is shared across all datasets in this retrieval.
ds_embedding_info = embedding_models.get(did, {})
embedding_provider = ds_embedding_info.get("embedding_model_provider", "")
embedding_model = ds_embedding_info.get("embedding_model", "")
self._exporter.increment_counter(
EnterpriseTelemetryCounter.DATASET_RETRIEVALS,
1,
self._labels(
**labels,
dataset_id=did,
embedding_model_provider=embedding_provider,
embedding_model=embedding_model,
rerank_model_provider=rerank_provider,
rerank_model=rerank_model,
),
)
def _generate_name_trace(self, info: GenerateNameTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs["dify.conversation.id"] = info.conversation_id
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
duration: float | None = None
if info.start_time is not None and info.end_time is not None:
duration = (info.end_time - info.start_time).total_seconds()
error: str | None = metadata.get("error") if metadata else None
status = "failed" if error else "succeeded"
attrs["dify.generate_name.duration"] = duration
attrs["dify.generate_name.status"] = status
attrs["dify.generate_name.error"] = error
ref = f"ref:conversation_id={info.conversation_id}"
inputs = self._safe_payload_value(info.inputs)
outputs = self._safe_payload_value(info.outputs)
attrs["dify.generate_name.inputs"] = self._content_or_ref(inputs, ref)
attrs["dify.generate_name.outputs"] = self._content_or_ref(outputs, ref)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.GENERATE_NAME_EXECUTION,
attributes=attrs,
trace_id_source=info.resolved_trace_id,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="generate_name",
),
)
def _prompt_generation_trace(self, info: PromptGenerationTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = {
"dify.trace_id": info.resolved_trace_id,
"dify.tenant_id": tenant_id,
"gen_ai.user.id": user_id,
"dify.app.id": app_id or "",
"dify.app.name": metadata.get("app_name"),
"dify.workspace.name": metadata.get("workspace_name"),
"dify.prompt_generation.operation_type": info.operation_type,
"gen_ai.provider.name": info.model_provider,
"gen_ai.request.model": info.model_name,
"gen_ai.usage.input_tokens": info.prompt_tokens,
"gen_ai.usage.output_tokens": info.completion_tokens,
"gen_ai.usage.total_tokens": info.total_tokens,
"dify.prompt_generation.latency": info.latency,
"dify.prompt_generation.error": info.error,
}
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
if info.total_price is not None:
attrs["dify.prompt_generation.total_price"] = info.total_price
attrs["dify.prompt_generation.currency"] = info.currency
ref = f"ref:trace_id={info.trace_id}"
outputs = self._safe_payload_value(info.outputs)
attrs["dify.prompt_generation.instruction"] = self._content_or_ref(info.instruction, ref)
attrs["dify.prompt_generation.output"] = self._content_or_ref(outputs, ref)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.PROMPT_GENERATION_EXECUTION,
attributes=attrs,
trace_id_source=info.resolved_trace_id,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
token_labels = TokenMetricLabels(
tenant_id=tenant_id or "",
app_id=app_id or "",
operation_type=info.operation_type,
model_provider=info.model_provider,
model_name=info.model_name,
node_type="",
).to_dict()
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
operation_type=info.operation_type,
model_provider=info.model_provider,
model_name=info.model_name,
)
self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
if info.prompt_tokens > 0:
self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels)
if info.completion_tokens > 0:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels
)
status = "failed" if info.error else "success"
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="prompt_generation",
status=status,
),
)
self._exporter.record_histogram(
EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION,
info.latency,
labels,
)
if info.error:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.ERRORS,
1,
self._labels(
**labels,
type="prompt_generation",
),
)

@@ -0,0 +1,121 @@
from enum import StrEnum
from typing import cast
from opentelemetry.util.types import AttributeValue
from pydantic import BaseModel, ConfigDict
class EnterpriseTelemetrySpan(StrEnum):
WORKFLOW_RUN = "dify.workflow.run"
NODE_EXECUTION = "dify.node.execution"
DRAFT_NODE_EXECUTION = "dify.node.execution.draft"
class EnterpriseTelemetryEvent(StrEnum):
"""Event names for enterprise telemetry logs."""
APP_CREATED = "dify.app.created"
APP_UPDATED = "dify.app.updated"
APP_DELETED = "dify.app.deleted"
FEEDBACK_CREATED = "dify.feedback.created"
WORKFLOW_RUN = "dify.workflow.run"
MESSAGE_RUN = "dify.message.run"
TOOL_EXECUTION = "dify.tool.execution"
MODERATION_CHECK = "dify.moderation.check"
SUGGESTED_QUESTION_GENERATION = "dify.suggested_question.generation"
DATASET_RETRIEVAL = "dify.dataset.retrieval"
GENERATE_NAME_EXECUTION = "dify.generate_name.execution"
PROMPT_GENERATION_EXECUTION = "dify.prompt_generation.execution"
REHYDRATION_FAILED = "dify.telemetry.rehydration_failed"
class EnterpriseTelemetryCounter(StrEnum):
TOKENS = "tokens"
INPUT_TOKENS = "input_tokens"
OUTPUT_TOKENS = "output_tokens"
REQUESTS = "requests"
ERRORS = "errors"
FEEDBACK = "feedback"
DATASET_RETRIEVALS = "dataset_retrievals"
APP_CREATED = "app_created"
APP_UPDATED = "app_updated"
APP_DELETED = "app_deleted"
class EnterpriseTelemetryHistogram(StrEnum):
WORKFLOW_DURATION = "workflow_duration"
NODE_DURATION = "node_duration"
MESSAGE_DURATION = "message_duration"
MESSAGE_TTFT = "message_ttft"
TOOL_DURATION = "tool_duration"
PROMPT_GENERATION_DURATION = "prompt_generation_duration"
class TokenMetricLabels(BaseModel):
"""Unified label structure for all dify.token.* metrics.
All token counters (dify.tokens.input, dify.tokens.output, dify.tokens.total) MUST
use this exact label set to ensure consistent filtering and aggregation across
different operation types.
Attributes:
tenant_id: Tenant identifier.
app_id: Application identifier.
operation_type: Source of token usage (workflow | node_execution | message |
rule_generate | code_generate | structured_output | instruction_modify).
model_provider: LLM provider name. Empty string if not applicable (e.g., workflow-level).
model_name: LLM model name. Empty string if not applicable (e.g., workflow-level).
node_type: Workflow node type. Empty string unless operation_type=node_execution.
Usage:
labels = TokenMetricLabels(
tenant_id="tenant-123",
app_id="app-456",
operation_type=OperationType.WORKFLOW,
model_provider="",
model_name="",
node_type="",
)
exporter.increment_counter(
EnterpriseTelemetryCounter.INPUT_TOKENS,
100,
labels.to_dict()
)
Design rationale:
Without this unified structure, tokens get double-counted when querying totals
because workflow.total_tokens is already the sum of all node tokens. The
operation_type label allows filtering to separate workflow-level aggregates from
node-level detail, while keeping the same label cardinality for consistent queries.
"""
tenant_id: str
app_id: str
operation_type: str
model_provider: str
model_name: str
node_type: str
model_config = ConfigDict(extra="forbid", frozen=True)
def to_dict(self) -> dict[str, AttributeValue]:
return cast(
dict[str, AttributeValue],
{
"tenant_id": self.tenant_id,
"app_id": self.app_id,
"operation_type": self.operation_type,
"model_provider": self.model_provider,
"model_name": self.model_name,
"node_type": self.node_type,
},
)
__all__ = [
"EnterpriseTelemetryCounter",
"EnterpriseTelemetryEvent",
"EnterpriseTelemetryHistogram",
"EnterpriseTelemetrySpan",
"TokenMetricLabels",
]
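The unified label contract above can be sketched without pydantic. This is an illustrative stand-in (the `TokenLabels` dataclass below is hypothetical, not part of the module) showing the fixed six-label schema that keeps cardinality identical across all token counters:

```python
from dataclasses import asdict, dataclass

@dataclass(frozen=True)
class TokenLabels:
    # Stand-in for TokenMetricLabels: every token counter carries
    # exactly these six labels, with empty strings where not applicable.
    tenant_id: str
    app_id: str
    operation_type: str
    model_provider: str
    model_name: str
    node_type: str

workflow_labels = TokenLabels(
    tenant_id="tenant-123",
    app_id="app-456",
    operation_type="workflow",  # workflow-level aggregate: no model/node labels
    model_provider="",
    model_name="",
    node_type="",
)
print(sorted(asdict(workflow_labels)))
# -> ['app_id', 'model_name', 'model_provider', 'node_type', 'operation_type', 'tenant_id']
```

Because the label set never varies, dashboards can filter on `operation_type` to separate workflow-level aggregates from node-level detail without double-counting tokens.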

@@ -0,0 +1,40 @@
"""Blinker signal handlers for enterprise telemetry.
Registered at import time via ``@signal.connect`` decorators.
Import must happen during ``ext_enterprise_telemetry.init_app()`` to
ensure handlers fire. Each handler delegates to ``core.telemetry.gateway``
which handles routing, EE-gating, and dispatch.
All handlers are best-effort: exceptions are caught and logged so that
telemetry failures never break user-facing operations.
"""
from __future__ import annotations
import logging
from events.app_event import app_was_created
logger = logging.getLogger(__name__)
__all__ = [
"_handle_app_created",
]
@app_was_created.connect
def _handle_app_created(sender: object, **kwargs: object) -> None:
try:
from core.telemetry.gateway import emit as gateway_emit
from enterprise.telemetry.contracts import TelemetryCase
gateway_emit(
case=TelemetryCase.APP_CREATED,
context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")},
payload={
"app_id": getattr(sender, "id", None),
"mode": getattr(sender, "mode", None),
},
)
except Exception:
logger.warning("Failed to emit app_created telemetry", exc_info=True)

@@ -0,0 +1,289 @@
"""Enterprise OTEL exporter — shared by EnterpriseOtelTrace, event handlers, and direct instrumentation.
Uses dedicated TracerProvider and MeterProvider instances (configurable sampling,
independent from ext_otel.py infrastructure).
Initialized once during Flask extension init (single-threaded via ext_enterprise_telemetry.py).
Accessed via ``ext_enterprise_telemetry.get_enterprise_exporter()`` from any thread/process.
"""
import logging
import socket
import uuid
from datetime import UTC, datetime
from typing import Any, cast
from opentelemetry import trace
from opentelemetry.context import Context
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace import SpanContext, TraceFlags
from opentelemetry.util.types import Attributes, AttributeValue
from configs import dify_config
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram
from enterprise.telemetry.id_generator import (
CorrelationIdGenerator,
compute_deterministic_span_id,
set_correlation_id,
set_span_id_source,
)
logger = logging.getLogger(__name__)
def is_enterprise_telemetry_enabled() -> bool:
return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED)
def _parse_otlp_headers(raw: str) -> dict[str, str]:
"""Parse ``key=value,key2=value2`` into a dict."""
if not raw:
return {}
headers: dict[str, str] = {}
for pair in raw.split(","):
if "=" not in pair:
continue
k, v = pair.split("=", 1)
headers[k.strip().lower()] = v.strip()
return headers
def _datetime_to_ns(dt: datetime) -> int:
"""Convert a datetime to nanoseconds since epoch (OTEL convention)."""
# Ensure we always interpret naive datetimes as UTC instead of local time.
if dt.tzinfo is None:
dt = dt.replace(tzinfo=UTC)
else:
dt = dt.astimezone(UTC)
return int(dt.timestamp() * 1_000_000_000)
class _ExporterFactory:
def __init__(self, protocol: str, endpoint: str, headers: dict[str, str], insecure: bool):
self._protocol = protocol
self._endpoint = endpoint
self._headers = headers
self._grpc_headers = tuple(headers.items()) if headers else None
self._http_headers = headers or None
self._insecure = insecure
def create_trace_exporter(self) -> HTTPSpanExporter | GRPCSpanExporter:
if self._protocol == "grpc":
return GRPCSpanExporter(
endpoint=self._endpoint or None,
headers=self._grpc_headers,
insecure=self._insecure,
)
trace_endpoint = f"{self._endpoint}/v1/traces" if self._endpoint else ""
return HTTPSpanExporter(endpoint=trace_endpoint or None, headers=self._http_headers)
def create_metric_exporter(self) -> HTTPMetricExporter | GRPCMetricExporter:
if self._protocol == "grpc":
return GRPCMetricExporter(
endpoint=self._endpoint or None,
headers=self._grpc_headers,
insecure=self._insecure,
)
metric_endpoint = f"{self._endpoint}/v1/metrics" if self._endpoint else ""
return HTTPMetricExporter(endpoint=metric_endpoint or None, headers=self._http_headers)
class EnterpriseExporter:
"""Shared OTEL exporter for all enterprise telemetry.
``export_span`` creates spans with optional real timestamps, deterministic
span/trace IDs, and cross-workflow parent linking.
``increment_counter`` / ``record_histogram`` emit OTEL metrics unconditionally;
metrics bypass trace sampling, so counts remain exact.
"""
def __init__(self, config: object) -> None:
endpoint: str = getattr(config, "ENTERPRISE_OTLP_ENDPOINT", "")
headers_raw: str = getattr(config, "ENTERPRISE_OTLP_HEADERS", "")
protocol: str = (getattr(config, "ENTERPRISE_OTLP_PROTOCOL", "http") or "http").lower()
service_name: str = getattr(config, "ENTERPRISE_SERVICE_NAME", "dify")
sampling_rate: float = getattr(config, "ENTERPRISE_OTEL_SAMPLING_RATE", 1.0)
self.include_content: bool = getattr(config, "ENTERPRISE_INCLUDE_CONTENT", True)
api_key: str = getattr(config, "ENTERPRISE_OTLP_API_KEY", "")
# Auto-detect TLS: https:// uses secure, everything else is insecure
insecure = not endpoint.startswith("https://")
resource = Resource(
attributes={
ResourceAttributes.SERVICE_NAME: service_name,
ResourceAttributes.HOST_NAME: socket.gethostname(),
}
)
sampler = ParentBasedTraceIdRatio(sampling_rate)
id_generator = CorrelationIdGenerator()
self._tracer_provider = TracerProvider(resource=resource, sampler=sampler, id_generator=id_generator)
headers = _parse_otlp_headers(headers_raw)
if api_key:
if "authorization" in headers:
logger.warning(
"ENTERPRISE_OTLP_API_KEY is set but ENTERPRISE_OTLP_HEADERS also contains "
"'authorization'; the API key will take precedence."
)
headers["authorization"] = f"Bearer {api_key}"
factory = _ExporterFactory(protocol, endpoint, headers, insecure=insecure)
trace_exporter = factory.create_trace_exporter()
self._tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
self._tracer = self._tracer_provider.get_tracer("dify.enterprise")
metric_exporter = factory.create_metric_exporter()
self._meter_provider = MeterProvider(
resource=resource,
metric_readers=[PeriodicExportingMetricReader(metric_exporter)],
)
meter = self._meter_provider.get_meter("dify.enterprise")
self._counters = {
EnterpriseTelemetryCounter.TOKENS: meter.create_counter("dify.tokens.total", unit="{token}"),
EnterpriseTelemetryCounter.INPUT_TOKENS: meter.create_counter("dify.tokens.input", unit="{token}"),
EnterpriseTelemetryCounter.OUTPUT_TOKENS: meter.create_counter("dify.tokens.output", unit="{token}"),
EnterpriseTelemetryCounter.REQUESTS: meter.create_counter("dify.requests.total", unit="{request}"),
EnterpriseTelemetryCounter.ERRORS: meter.create_counter("dify.errors.total", unit="{error}"),
EnterpriseTelemetryCounter.FEEDBACK: meter.create_counter("dify.feedback.total", unit="{feedback}"),
EnterpriseTelemetryCounter.DATASET_RETRIEVALS: meter.create_counter(
"dify.dataset.retrievals.total", unit="{retrieval}"
),
EnterpriseTelemetryCounter.APP_CREATED: meter.create_counter("dify.app.created.total", unit="{app}"),
EnterpriseTelemetryCounter.APP_UPDATED: meter.create_counter("dify.app.updated.total", unit="{app}"),
EnterpriseTelemetryCounter.APP_DELETED: meter.create_counter("dify.app.deleted.total", unit="{app}"),
}
self._histograms = {
EnterpriseTelemetryHistogram.WORKFLOW_DURATION: meter.create_histogram("dify.workflow.duration", unit="s"),
EnterpriseTelemetryHistogram.NODE_DURATION: meter.create_histogram("dify.node.duration", unit="s"),
EnterpriseTelemetryHistogram.MESSAGE_DURATION: meter.create_histogram("dify.message.duration", unit="s"),
EnterpriseTelemetryHistogram.MESSAGE_TTFT: meter.create_histogram(
"dify.message.time_to_first_token", unit="s"
),
EnterpriseTelemetryHistogram.TOOL_DURATION: meter.create_histogram("dify.tool.duration", unit="s"),
EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION: meter.create_histogram(
"dify.prompt_generation.duration", unit="s"
),
}
def export_span(
self,
name: str,
attributes: dict[str, Any],
correlation_id: str | None = None,
span_id_source: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
trace_correlation_override: str | None = None,
parent_span_id_source: str | None = None,
) -> None:
"""Export an OTEL span with optional deterministic IDs and real timestamps.
Args:
name: Span operation name.
attributes: Span attributes dict.
correlation_id: Source for trace_id derivation (groups spans in one trace).
span_id_source: Source for deterministic span_id (e.g. workflow_run_id or node_execution_id).
start_time: Real span start time. When None, uses current time.
end_time: Real span end time. When None, the span ends when the context manager exits.
trace_correlation_override: Override trace_id source (for cross-workflow linking).
When set, trace_id is derived from this instead of ``correlation_id``.
parent_span_id_source: Override parent span_id source (for cross-workflow linking).
When set, parent span_id is derived from this value. When None and
``correlation_id`` is set, parent is the workflow root span.
"""
effective_trace_correlation = trace_correlation_override or correlation_id
set_correlation_id(effective_trace_correlation)
set_span_id_source(span_id_source)
try:
parent_context: Context | None = None
# A span is the "root" of its correlation group when span_id_source == correlation_id
# (i.e. a workflow root span). All other spans are children.
if parent_span_id_source:
# Cross-workflow linking: parent is an explicit span (e.g. tool node in outer workflow)
parent_span_id = compute_deterministic_span_id(parent_span_id_source)
try:
parent_trace_id = int(uuid.UUID(effective_trace_correlation)) if effective_trace_correlation else 0
except (ValueError, AttributeError):
logger.warning(
"Invalid trace correlation UUID for cross-workflow link: %s, span=%s",
effective_trace_correlation,
name,
)
parent_trace_id = 0
if parent_trace_id:
parent_span_context = SpanContext(
trace_id=parent_trace_id,
span_id=parent_span_id,
is_remote=True,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context))
elif correlation_id and correlation_id != span_id_source:
# Child span: parent is the correlation-group root (workflow root span)
parent_span_id = compute_deterministic_span_id(correlation_id)
try:
parent_trace_id = int(uuid.UUID(effective_trace_correlation or correlation_id))
except (ValueError, AttributeError):
logger.warning(
"Invalid trace correlation UUID for child span link: %s, span=%s",
effective_trace_correlation or correlation_id,
name,
)
parent_trace_id = 0
if parent_trace_id:
parent_span_context = SpanContext(
trace_id=parent_trace_id,
span_id=parent_span_id,
is_remote=True,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context))
span_start_time = _datetime_to_ns(start_time) if start_time is not None else None
span_end_on_exit = end_time is None
with self._tracer.start_as_current_span(
name,
context=parent_context,
start_time=span_start_time,
end_on_exit=span_end_on_exit,
) as span:
for key, value in attributes.items():
if value is not None:
span.set_attribute(key, value)
if end_time is not None:
span.end(end_time=_datetime_to_ns(end_time))
except Exception:
logger.exception("Failed to export span %s", name)
finally:
set_correlation_id(None)
set_span_id_source(None)
def increment_counter(
self, name: EnterpriseTelemetryCounter, value: int, labels: dict[str, AttributeValue]
) -> None:
counter = self._counters.get(name)
if counter:
counter.add(value, cast(Attributes, labels))
def record_histogram(
self, name: EnterpriseTelemetryHistogram, value: float, labels: dict[str, AttributeValue]
) -> None:
histogram = self._histograms.get(name)
if histogram:
histogram.record(value, cast(Attributes, labels))
def shutdown(self) -> None:
self._tracer_provider.shutdown()
self._meter_provider.shutdown()

View File

@ -0,0 +1,75 @@
"""Custom OTEL ID Generator for correlation-based trace/span ID derivation.
Uses contextvars for thread-safe correlation_id -> trace_id mapping.
When a span_id_source is set, the span_id is derived deterministically
from that value, enabling any span to reference another as parent
without depending on span creation order.
"""
import random
import uuid
from contextvars import ContextVar
from opentelemetry.sdk.trace.id_generator import IdGenerator
_correlation_id_context: ContextVar[str | None] = ContextVar("correlation_id", default=None)
_span_id_source_context: ContextVar[str | None] = ContextVar("span_id_source", default=None)
def set_correlation_id(correlation_id: str | None) -> None:
_correlation_id_context.set(correlation_id)
def get_correlation_id() -> str | None:
return _correlation_id_context.get()
def set_span_id_source(source_id: str | None) -> None:
"""Set the source for deterministic span_id generation.
When set, ``generate_span_id()`` derives the span_id from this value
(lower 64 bits of the UUID). Pass the ``workflow_run_id`` for workflow
root spans or ``node_execution_id`` for node spans.
"""
_span_id_source_context.set(source_id)
def compute_deterministic_span_id(source_id: str) -> int:
"""Derive a deterministic span_id from any UUID string.
Uses the lower 64 bits of the UUID, guaranteeing non-zero output
(OTEL requires span_id != 0).
"""
span_id = uuid.UUID(source_id).int & ((1 << 64) - 1)
return span_id if span_id != 0 else 1
class CorrelationIdGenerator(IdGenerator):
"""ID generator that derives trace_id and optionally span_id from context.
- trace_id: always derived from correlation_id (groups all spans in one trace)
- span_id: derived from span_id_source when set (enables deterministic
parent-child linking), otherwise random
"""
def generate_trace_id(self) -> int:
correlation_id = _correlation_id_context.get()
if correlation_id:
try:
return uuid.UUID(correlation_id).int
except (ValueError, AttributeError):
pass
return random.getrandbits(128)
def generate_span_id(self) -> int:
source = _span_id_source_context.get()
if source:
try:
return compute_deterministic_span_id(source)
except (ValueError, AttributeError):
pass
span_id = random.getrandbits(64)
while span_id == 0:
span_id = random.getrandbits(64)
return span_id

View File
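The deterministic derivation above can be checked in isolation: the trace_id is the full 128-bit UUID integer, and the span_id is its lower 64 bits, coerced to 1 when zero because OTEL forbids a zero span_id. A minimal reproduction:

```python
import uuid


def deterministic_span_id(source_id: str) -> int:
    # Lower 64 bits of the UUID; OTEL requires span_id != 0.
    span_id = uuid.UUID(source_id).int & ((1 << 64) - 1)
    return span_id if span_id != 0 else 1


run_id = "00000000-0000-0000-0000-0000000000ff"
print(hex(deterministic_span_id(run_id)))  # 0xff

# The same input always yields the same span_id, so a child span can
# reference its parent without depending on span creation order.
assert deterministic_span_id(run_id) == deterministic_span_id(run_id)

# An all-zero lower half is coerced to 1:
print(deterministic_span_id("ffffffff-ffff-ffff-0000-000000000000"))  # 1
```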

@ -0,0 +1,421 @@
"""Enterprise metric/log event handler.
This module processes metric and log telemetry events after they've been
dequeued from the enterprise_telemetry Celery queue. It handles case routing,
idempotency checking, and payload rehydration.
"""
from __future__ import annotations
import json
import logging
from datetime import UTC, datetime
from typing import Any
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
logger = logging.getLogger(__name__)
class EnterpriseMetricHandler:
"""Handler for enterprise metric and log telemetry events.
Processes envelopes from the enterprise_telemetry queue, routing each
case to the appropriate handler method. Implements idempotency checking
and payload rehydration with fallback.
"""
def _increment_diagnostic_counter(self, counter_name: str, labels: dict[str, str] | None = None) -> None:
"""Increment a diagnostic counter for operational monitoring.
Args:
counter_name: Name of the counter (e.g., 'processed_total', 'deduped_total').
labels: Optional labels for the counter.
"""
try:
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
return
# Log-only for now: the exporter check above confirms telemetry is active.
full_counter_name = f"enterprise_telemetry.handler.{counter_name}"
logger.debug(
"Diagnostic counter: %s, labels=%s",
full_counter_name,
labels or {},
)
except Exception:
logger.debug("Failed to increment diagnostic counter: %s", counter_name, exc_info=True)
def handle(self, envelope: TelemetryEnvelope) -> None:
"""Main entry point for processing telemetry envelopes.
Args:
envelope: The telemetry envelope to process.
"""
# Check for duplicate events
if self._is_duplicate(envelope):
logger.debug(
"Skipping duplicate event: tenant_id=%s, event_id=%s",
envelope.tenant_id,
envelope.event_id,
)
self._increment_diagnostic_counter("deduped_total")
return
# Route to appropriate handler based on case
case = envelope.case
if case == TelemetryCase.APP_CREATED:
self._on_app_created(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_created"})
elif case == TelemetryCase.APP_UPDATED:
self._on_app_updated(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_updated"})
elif case == TelemetryCase.APP_DELETED:
self._on_app_deleted(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"})
elif case == TelemetryCase.FEEDBACK_CREATED:
self._on_feedback_created(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"})
elif case == TelemetryCase.MESSAGE_RUN:
self._on_message_run(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "message_run"})
elif case == TelemetryCase.TOOL_EXECUTION:
self._on_tool_execution(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"})
elif case == TelemetryCase.MODERATION_CHECK:
self._on_moderation_check(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"})
elif case == TelemetryCase.SUGGESTED_QUESTION:
self._on_suggested_question(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"})
elif case == TelemetryCase.DATASET_RETRIEVAL:
self._on_dataset_retrieval(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"})
elif case == TelemetryCase.GENERATE_NAME:
self._on_generate_name(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "generate_name"})
elif case == TelemetryCase.PROMPT_GENERATION:
self._on_prompt_generation(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"})
else:
logger.warning(
"Unknown telemetry case: %s (tenant_id=%s, event_id=%s)",
case,
envelope.tenant_id,
envelope.event_id,
)
def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool:
"""Check if this event has already been processed.
Uses Redis with TTL for deduplication. Returns True if duplicate,
False if first time seeing this event.
Args:
envelope: The telemetry envelope to check.
Returns:
True if this event_id has been seen before, False otherwise.
"""
dedup_key = f"telemetry:dedup:{envelope.tenant_id}:{envelope.event_id}"
try:
# Atomic set-if-not-exists with 1h TTL
# Returns True if key was set (first time), None if already exists (duplicate)
was_set = redis_client.set(dedup_key, b"1", nx=True, ex=3600)
return was_set is None
except Exception:
# Fail open: if Redis is unavailable, process the event
# (prefer occasional duplicate over lost data)
logger.warning(
"Redis unavailable for deduplication check, processing event anyway: %s",
envelope.event_id,
exc_info=True,
)
return False
def _rehydrate(self, envelope: TelemetryEnvelope) -> dict[str, Any]:
"""Rehydrate payload from storage reference or inline data.
If the envelope payload is empty and metadata contains a
``payload_ref``, the full payload is loaded from object storage
(where the gateway wrote it as JSON). When both the inline
payload and storage resolution fail, a degraded-event marker
is emitted so the gap is observable.
Args:
envelope: The telemetry envelope containing payload data.
Returns:
The rehydrated payload dictionary, or ``{}`` on total failure.
"""
payload = envelope.payload
# Resolve from object storage when the gateway offloaded a large payload.
if not payload and envelope.metadata:
payload_ref = envelope.metadata.get("payload_ref")
if payload_ref:
try:
payload_bytes = storage.load(payload_ref)
payload = json.loads(payload_bytes.decode("utf-8"))
logger.debug("Loaded payload from storage: key=%s", payload_ref)
except Exception:
logger.warning(
"Failed to load payload from storage: key=%s, event_id=%s",
payload_ref,
envelope.event_id,
exc_info=True,
)
if not payload:
# Storage resolution failed or no data available — emit degraded event.
logger.error(
"Payload rehydration failed for event_id=%s, tenant_id=%s, case=%s",
envelope.event_id,
envelope.tenant_id,
envelope.case,
)
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
from enterprise.telemetry.telemetry_log import emit_metric_only_event
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.REHYDRATION_FAILED,
attributes={
"dify.tenant_id": envelope.tenant_id,
"dify.event_id": envelope.event_id,
"dify.case": envelope.case,
"rehydration_failed": True,
},
tenant_id=envelope.tenant_id,
)
self._increment_diagnostic_counter("rehydration_failed_total")
return {}
return payload
# Per-case handlers. The app and feedback cases emit metrics and structured
# logs here; the remaining cases are intentional no-ops because their signals
# are emitted at trace time by EnterpriseOtelTrace.
def _on_app_created(self, envelope: TelemetryEnvelope) -> None:
"""Handle app created event."""
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
from enterprise.telemetry.telemetry_log import emit_metric_only_event
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
logger.debug("No exporter available for APP_CREATED: event_id=%s", envelope.event_id)
return
payload = self._rehydrate(envelope)
if not payload:
return
attrs = {
"dify.app.id": payload.get("app_id"),
"dify.tenant_id": envelope.tenant_id,
"dify.event.id": envelope.event_id,
"dify.app.mode": payload.get("mode"),
"dify.app.created_at": datetime.now(UTC).isoformat(),
}
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.APP_CREATED,
attributes=attrs,
tenant_id=envelope.tenant_id,
)
exporter.increment_counter(
EnterpriseTelemetryCounter.APP_CREATED,
1,
{
"tenant_id": envelope.tenant_id,
"app_id": str(payload.get("app_id", "")),
"mode": str(payload.get("mode", "")),
},
)
def _on_app_updated(self, envelope: TelemetryEnvelope) -> None:
"""Handle app updated event."""
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
from enterprise.telemetry.telemetry_log import emit_metric_only_event
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
logger.debug("No exporter available for APP_UPDATED: event_id=%s", envelope.event_id)
return
payload = self._rehydrate(envelope)
if not payload:
return
attrs = {
"dify.app.id": payload.get("app_id"),
"dify.tenant_id": envelope.tenant_id,
"dify.event.id": envelope.event_id,
"dify.app.updated_at": datetime.now(UTC).isoformat(),
}
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.APP_UPDATED,
attributes=attrs,
tenant_id=envelope.tenant_id,
)
exporter.increment_counter(
EnterpriseTelemetryCounter.APP_UPDATED,
1,
{
"tenant_id": envelope.tenant_id,
"app_id": str(payload.get("app_id", "")),
},
)
def _on_app_deleted(self, envelope: TelemetryEnvelope) -> None:
"""Handle app deleted event."""
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
from enterprise.telemetry.telemetry_log import emit_metric_only_event
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
logger.debug("No exporter available for APP_DELETED: event_id=%s", envelope.event_id)
return
payload = self._rehydrate(envelope)
if not payload:
return
attrs = {
"dify.app.id": payload.get("app_id"),
"dify.tenant_id": envelope.tenant_id,
"dify.event.id": envelope.event_id,
"dify.app.deleted_at": datetime.now(UTC).isoformat(),
}
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.APP_DELETED,
attributes=attrs,
tenant_id=envelope.tenant_id,
)
exporter.increment_counter(
EnterpriseTelemetryCounter.APP_DELETED,
1,
{
"tenant_id": envelope.tenant_id,
"app_id": str(payload.get("app_id", "")),
},
)
def _on_feedback_created(self, envelope: TelemetryEnvelope) -> None:
"""Handle feedback created event."""
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
from enterprise.telemetry.telemetry_log import emit_metric_only_event
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
logger.debug("No exporter available for FEEDBACK_CREATED: event_id=%s", envelope.event_id)
return
payload = self._rehydrate(envelope)
if not payload:
return
include_content = exporter.include_content
attrs: dict = {
"dify.message.id": payload.get("message_id"),
"dify.tenant_id": envelope.tenant_id,
"dify.event.id": envelope.event_id,
"dify.app_id": payload.get("app_id"),
"dify.conversation.id": payload.get("conversation_id"),
"gen_ai.user.id": payload.get("from_end_user_id") or payload.get("from_account_id"),
"dify.feedback.rating": payload.get("rating"),
"dify.feedback.from_source": payload.get("from_source"),
"dify.feedback.created_at": datetime.now(UTC).isoformat(),
}
if include_content:
attrs["dify.feedback.content"] = payload.get("content")
user_id = payload.get("from_end_user_id") or payload.get("from_account_id")
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.FEEDBACK_CREATED,
attributes=attrs,
tenant_id=envelope.tenant_id,
user_id=str(user_id or ""),
)
exporter.increment_counter(
EnterpriseTelemetryCounter.FEEDBACK,
1,
{
"tenant_id": envelope.tenant_id,
"app_id": str(payload.get("app_id", "")),
"rating": str(payload.get("rating", "")),
},
)
def _on_message_run(self, envelope: TelemetryEnvelope) -> None:
"""Handle message run event.
Intentionally a no-op: metrics and structured logs for message runs are
emitted directly by EnterpriseOtelTrace._message_trace at trace time,
not through the metric handler queue path.
"""
logger.debug("Processing MESSAGE_RUN: event_id=%s", envelope.event_id)
def _on_tool_execution(self, envelope: TelemetryEnvelope) -> None:
"""Handle tool execution event.
Intentionally a no-op: metrics and structured logs for tool executions
are emitted directly by EnterpriseOtelTrace._tool_trace at trace time,
not through the metric handler queue path.
"""
logger.debug("Processing TOOL_EXECUTION: event_id=%s", envelope.event_id)
def _on_moderation_check(self, envelope: TelemetryEnvelope) -> None:
"""Handle moderation check event.
Intentionally a no-op: metrics and structured logs for moderation checks
are emitted directly by EnterpriseOtelTrace._moderation_trace at trace time,
not through the metric handler queue path.
"""
logger.debug("Processing MODERATION_CHECK: event_id=%s", envelope.event_id)
def _on_suggested_question(self, envelope: TelemetryEnvelope) -> None:
"""Handle suggested question event.
Intentionally a no-op: metrics and structured logs for suggested questions
are emitted directly by EnterpriseOtelTrace._suggested_question_trace at
trace time, not through the metric handler queue path.
"""
logger.debug("Processing SUGGESTED_QUESTION: event_id=%s", envelope.event_id)
def _on_dataset_retrieval(self, envelope: TelemetryEnvelope) -> None:
"""Handle dataset retrieval event.
Intentionally a no-op: metrics and structured logs for dataset retrievals
are emitted directly by EnterpriseOtelTrace._dataset_retrieval_trace at
trace time, not through the metric handler queue path.
"""
logger.debug("Processing DATASET_RETRIEVAL: event_id=%s", envelope.event_id)
def _on_generate_name(self, envelope: TelemetryEnvelope) -> None:
"""Handle generate name event.
Intentionally a no-op: metrics and structured logs for generate name
operations are emitted directly by EnterpriseOtelTrace._generate_name_trace
at trace time, not through the metric handler queue path.
"""
logger.debug("Processing GENERATE_NAME: event_id=%s", envelope.event_id)
def _on_prompt_generation(self, envelope: TelemetryEnvelope) -> None:
"""Handle prompt generation event.
Intentionally a no-op: metrics and structured logs for prompt generation
operations are emitted directly by EnterpriseOtelTrace._prompt_generation_trace
at trace time, not through the metric handler queue path.
"""
logger.debug("Processing PROMPT_GENERATION: event_id=%s", envelope.event_id)

View File
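The deduplication check above is a single atomic ``SET key NX EX 3600``. The pattern can be sketched against any client exposing that call — here with a minimal in-memory stand-in (``FakeRedis`` is hypothetical, for illustration only):

```python
class FakeRedis:
    """In-memory stand-in mimicking redis-py's set(nx=True, ex=...) return values."""

    def __init__(self):
        self._data = {}

    def set(self, key, value, nx=False, ex=None):
        if nx and key in self._data:
            return None  # key already exists: redis-py returns None
        self._data[key] = value
        return True  # key was set


def is_duplicate(client, tenant_id: str, event_id: str) -> bool:
    dedup_key = f"telemetry:dedup:{tenant_id}:{event_id}"
    try:
        # True on first sight, None on replay; the TTL bounds memory use.
        return client.set(dedup_key, b"1", nx=True, ex=3600) is None
    except Exception:
        return False  # fail open: prefer a duplicate over lost data


r = FakeRedis()
print(is_duplicate(r, "t1", "evt-1"))  # False (first delivery is processed)
print(is_duplicate(r, "t1", "evt-1"))  # True  (replay is dropped)
```

The fail-open branch matters operationally: a Redis outage degrades to at-least-once delivery rather than silently discarding telemetry.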

@ -0,0 +1,122 @@
"""Structured-log emitter for enterprise telemetry events.
Emits structured JSON log lines correlated with OTEL traces via trace_id.
Picked up by ``StructuredJSONFormatter`` and shipped via stdout to Loki/Elastic.
"""
from __future__ import annotations
import logging
import uuid
from functools import lru_cache
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
logger = logging.getLogger("dify.telemetry")
@lru_cache(maxsize=4096)
def compute_trace_id_hex(uuid_str: str | None) -> str:
"""Convert a business UUID string to a 32-hex OTEL-compatible trace_id.
Returns empty string when *uuid_str* is ``None`` or invalid.
"""
if not uuid_str:
return ""
normalized = uuid_str.strip().lower()
if len(normalized) == 32 and all(ch in "0123456789abcdef" for ch in normalized):
return normalized
try:
return f"{uuid.UUID(normalized).int:032x}"
except (ValueError, AttributeError):
return ""
@lru_cache(maxsize=4096)
def compute_span_id_hex(uuid_str: str | None) -> str:
if not uuid_str:
return ""
normalized = uuid_str.strip().lower()
if len(normalized) == 16 and all(ch in "0123456789abcdef" for ch in normalized):
return normalized
try:
from enterprise.telemetry.id_generator import compute_deterministic_span_id
return f"{compute_deterministic_span_id(normalized):016x}"
except (ValueError, AttributeError):
return ""
def emit_telemetry_log(
*,
event_name: str | EnterpriseTelemetryEvent,
attributes: dict[str, Any],
signal: str = "metric_only",
trace_id_source: str | None = None,
span_id_source: str | None = None,
tenant_id: str | None = None,
user_id: str | None = None,
) -> None:
"""Emit a structured log line for a telemetry event.
Parameters
----------
event_name:
Canonical event name, e.g. ``"dify.workflow.run"``.
attributes:
All event-specific attributes (already built by the caller).
signal:
``"metric_only"`` for events with no span, ``"span_detail"``
for detail logs accompanying a slim span.
trace_id_source:
A UUID string (e.g. ``workflow_run_id``) used to derive a 32-hex
trace_id for cross-signal correlation.
span_id_source:
A UUID string used to derive a deterministic 16-hex span_id
(see ``compute_span_id_hex``).
tenant_id:
Tenant identifier (for the ``IdentityContextFilter``).
user_id:
User identifier (for the ``IdentityContextFilter``).
"""
if not logger.isEnabledFor(logging.INFO):
return
attrs = {
"dify.event.name": event_name,
"dify.event.signal": signal,
**attributes,
}
extra: dict[str, Any] = {"attributes": attrs}
trace_id_hex = compute_trace_id_hex(trace_id_source)
if trace_id_hex:
extra["trace_id"] = trace_id_hex
span_id_hex = compute_span_id_hex(span_id_source)
if span_id_hex:
extra["span_id"] = span_id_hex
if tenant_id:
extra["tenant_id"] = tenant_id
if user_id:
extra["user_id"] = user_id
logger.info("telemetry.%s", signal, extra=extra)
def emit_metric_only_event(
*,
event_name: str | EnterpriseTelemetryEvent,
attributes: dict[str, Any],
trace_id_source: str | None = None,
span_id_source: str | None = None,
tenant_id: str | None = None,
user_id: str | None = None,
) -> None:
emit_telemetry_log(
event_name=event_name,
attributes=attributes,
signal="metric_only",
trace_id_source=trace_id_source,
span_id_source=span_id_source,
tenant_id=tenant_id,
user_id=user_id,
)

View File

@ -204,6 +204,8 @@ def init_app(app: DifyApp) -> Celery:
"schedule": timedelta(minutes=dify_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL),
}
if dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED:
imports.append("tasks.enterprise_telemetry_task")
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
return celery_app

View File

@ -0,0 +1,50 @@
"""Flask extension for enterprise telemetry lifecycle management.
Initializes the EnterpriseExporter singleton during ``create_app()``
(single-threaded), registers blinker event handlers, and hooks atexit
for graceful shutdown.
Skipped entirely unless both ``ENTERPRISE_ENABLED`` and ``ENTERPRISE_TELEMETRY_ENABLED``
are true (``is_enabled()`` gate).
"""
from __future__ import annotations
import atexit
import logging
from typing import TYPE_CHECKING
from configs import dify_config
if TYPE_CHECKING:
from dify_app import DifyApp
from enterprise.telemetry.exporter import EnterpriseExporter
logger = logging.getLogger(__name__)
_exporter: EnterpriseExporter | None = None
def is_enabled() -> bool:
return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED)
def init_app(app: DifyApp) -> None:
global _exporter
if not is_enabled():
return
from enterprise.telemetry.exporter import EnterpriseExporter
_exporter = EnterpriseExporter(dify_config)
atexit.register(_exporter.shutdown)
# Import to trigger @signal.connect decorator registration
import enterprise.telemetry.event_handlers # noqa: F401 # type: ignore[reportUnusedImport]
logger.info("Enterprise telemetry initialized")
def get_enterprise_exporter() -> EnterpriseExporter | None:
return _exporter

View File
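The extension above follows a common Flask lifecycle pattern: build a module-level singleton during single-threaded app creation, register its shutdown with ``atexit``, and expose a getter that returns ``None`` when the feature is off. A generic sketch of the same shape (names are illustrative, not the real module):

```python
import atexit


class Exporter:
    def __init__(self):
        self.closed = False

    def shutdown(self):
        # In the real extension this flushes OTLP batch processors.
        self.closed = True


_exporter = None


def init_app(enabled: bool) -> None:
    global _exporter
    if not enabled:
        return  # feature off: get_exporter() stays None
    _exporter = Exporter()
    atexit.register(_exporter.shutdown)  # graceful flush on process exit


def get_exporter():
    return _exporter


init_app(enabled=True)
print(get_exporter() is not None)  # True
```

Because initialization happens once in ``create_app()`` before workers handle requests, the bare module global needs no locking; callers on any thread only ever read it.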

@ -78,16 +78,24 @@ def init_app(app: DifyApp):
protocol = (dify_config.OTEL_EXPORTER_OTLP_PROTOCOL or "").lower()
if dify_config.OTEL_EXPORTER_TYPE == "otlp":
if protocol == "grpc":
+# Auto-detect TLS: https:// uses secure, everything else is insecure
+endpoint = dify_config.OTLP_BASE_ENDPOINT
+insecure = not endpoint.startswith("https://")
exporter = GRPCSpanExporter(
-endpoint=dify_config.OTLP_BASE_ENDPOINT,
+endpoint=endpoint,
# Header field names must consist of lowercase letters, check RFC7540
-headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),),
-insecure=True,
+headers=(
+(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),) if dify_config.OTLP_API_KEY else None
+),
+insecure=insecure,
)
metric_exporter = GRPCMetricExporter(
-endpoint=dify_config.OTLP_BASE_ENDPOINT,
-headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),),
-insecure=True,
+endpoint=endpoint,
+headers=(
+(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),) if dify_config.OTLP_API_KEY else None
+),
+insecure=insecure,
)
else:
headers = {"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"} if dify_config.OTLP_API_KEY else None
View File
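The TLS auto-detection in the hunk above reduces to one predicate: a gRPC endpoint is treated as secure only when it starts with ``https://``. A sketch:

```python
def grpc_insecure(endpoint: str) -> bool:
    # https:// => TLS on; http://, bare host:port, or empty => plaintext
    return not endpoint.startswith("https://")


print(grpc_insecure("https://otel.example.com:4317"))  # False
print(grpc_insecure("http://localhost:4317"))          # True
print(grpc_insecure("localhost:4317"))                 # True
```

This replaces the previous hard-coded ``insecure=True``, so operators pointing at a TLS collector no longer need an extra flag, while local plaintext setups keep working unchanged.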

@ -5,7 +5,7 @@ This module provides parsers that extract node-specific metadata and set
OpenTelemetry span attributes according to semantic conventions.
"""
-from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps
+from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps, should_include_content
from extensions.otel.parser.llm import LLMNodeOTelParser
from extensions.otel.parser.retrieval import RetrievalNodeOTelParser
from extensions.otel.parser.tool import ToolNodeOTelParser
@ -17,4 +17,5 @@ __all__ = [
"RetrievalNodeOTelParser",
"ToolNodeOTelParser",
"safe_json_dumps",
"should_include_content",
]

View File

@ -1,5 +1,10 @@
"""
Base parser interface and utilities for OpenTelemetry node parsers.
Content gating: ``should_include_content()`` controls whether content-bearing
span attributes (inputs, outputs, prompts, completions, documents) are written.
The gate is active only in EE (``ENTERPRISE_ENABLED=True``) with
``ENTERPRISE_INCLUDE_CONTENT=False``; CE behaviour is unchanged.
"""
import json
@ -9,6 +14,7 @@ from opentelemetry.trace import Span
from opentelemetry.trace.status import Status, StatusCode
from pydantic import BaseModel
from configs import dify_config
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.file.models import File
from dify_graph.graph_events import GraphNodeEventBase
@ -17,6 +23,17 @@ from dify_graph.variables import Segment
from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes
def should_include_content() -> bool:
"""Return True if content should be written to spans.
CE (ENTERPRISE_ENABLED=False): always True; no behaviour change.
EE: follows ENTERPRISE_INCLUDE_CONTENT (default True).
"""
if not dify_config.ENTERPRISE_ENABLED:
return True
return dify_config.ENTERPRISE_INCLUDE_CONTENT
def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str:
"""
Safely serialize objects to JSON, handling non-serializable types.
@ -101,10 +118,11 @@ class DefaultNodeOTelParser:
# Extract inputs and outputs from result_event
if result_event and result_event.node_run_result:
node_run_result = result_event.node_run_result
if node_run_result.inputs:
span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs))
if node_run_result.outputs:
span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs))
if should_include_content():
if node_run_result.inputs:
span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs))
if node_run_result.outputs:
span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs))
if error:
span.record_exception(error)

View File

@ -21,3 +21,15 @@ class DifySpanAttributes:
INVOKE_FROM = "dify.invoke_from"
"""Invocation source, e.g. SERVICE_API, WEB_APP, DEBUGGER."""
INVOKED_BY = "dify.invoked_by"
"""Invoked by, e.g. end_user, account, user."""
USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
"""Number of input tokens (prompt tokens) used."""
USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
"""Number of output tokens (completion tokens) generated."""
USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
"""Total number of tokens used."""

View File

@ -1,16 +1,19 @@
import logging
import sys
import urllib.parse
from dataclasses import dataclass
from typing import NotRequired
import httpx
from pydantic import TypeAdapter
from pydantic import TypeAdapter, ValidationError
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
logger = logging.getLogger(__name__)
JsonObject = dict[str, object]
JsonObjectList = list[JsonObject]
@ -30,8 +33,8 @@ class GitHubEmailRecord(TypedDict, total=False):
class GitHubRawUserInfo(TypedDict):
id: int | str
login: str
name: NotRequired[str]
email: NotRequired[str]
name: NotRequired[str | None]
email: NotRequired[str | None]
class GoogleRawUserInfo(TypedDict):
@ -127,9 +130,14 @@ class GitHubOAuth(OAuth):
response.raise_for_status()
user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response))
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
primary_email = next((email for email in email_info if email.get("primary") is True), None)
try:
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
email_response.raise_for_status()
email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
primary_email = next((email for email in email_info if email.get("primary") is True), None)
except (httpx.HTTPStatusError, ValidationError):
logger.warning("Failed to retrieve email from GitHub /user/emails endpoint", exc_info=True)
primary_email = None
return {**user_info, "email": primary_email.get("email", "") if primary_email else ""}
@ -137,8 +145,11 @@ class GitHubOAuth(OAuth):
payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
email = payload.get("email")
if not email:
email = f"{payload['id']}+{payload['login']}@users.noreply.github.com"
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email)
raise ValueError(
'Dify does not currently support the "Keep my email addresses private" feature;'
" please disable it and log in again"
)
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name") or ""), email=email)
class GoogleOAuth(OAuth):

View File
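The hardened email lookup above makes the `/user/emails` call best-effort: an HTTP or validation error now logs a warning and falls back to no primary email instead of failing the login. A sketch of that fallback, with `fetch_emails` standing in for the `httpx` call in the diff:

```python
import logging

logger = logging.getLogger(__name__)


def pick_primary_email(fetch_emails) -> str:
    # Best-effort: any error while fetching or validating the email list
    # downgrades to "no primary email" rather than raising.
    try:
        records = fetch_emails()
        primary = next((r for r in records if r.get("primary") is True), None)
    except Exception:
        logger.warning(
            "Failed to retrieve email from GitHub /user/emails endpoint",
            exc_info=True,
        )
        primary = None
    return primary.get("email", "") if primary else ""
```

The stricter handling of a still-missing email (raising instead of synthesizing a `users.noreply.github.com` address) happens later, in `get_user_info`, and is not part of this sketch.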

@ -43,7 +43,9 @@ from .enums import (
IndexingStatus,
ProcessRuleMode,
SegmentStatus,
SegmentType,
SummaryStatus,
TidbAuthBindingStatus,
)
from .model import App, Tag, TagBinding, UploadFile
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
@ -494,7 +496,9 @@ class Document(Base):
)
doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True)
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
doc_form: Mapped[IndexStructureType] = mapped_column(
EnumText(IndexStructureType, length=255), nullable=False, server_default=sa.text("'text_model'")
)
doc_language = mapped_column(String(255), nullable=True)
need_summary: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
@ -998,7 +1002,9 @@ class ChildChunk(Base):
# indexing fields
index_node_id = mapped_column(String(255), nullable=True)
index_node_hash = mapped_column(String(255), nullable=True)
type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'"))
type: Mapped[SegmentType] = mapped_column(
EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'")
)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
@ -1239,7 +1245,9 @@ class TidbAuthBinding(TypeBase):
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'"))
status: Mapped[TidbAuthBindingStatus] = mapped_column(
EnumText(TidbAuthBindingStatus, length=255), nullable=False, server_default=sa.text("'CREATING'")
)
account: Mapped[str] = mapped_column(String(255), nullable=False)
password: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(

View File

@ -222,6 +222,13 @@ class DatasetMetadataType(StrEnum):
TIME = "time"
class SegmentType(StrEnum):
"""Document segment type"""
AUTOMATIC = "automatic"
CUSTOMIZED = "customized"
class SegmentStatus(StrEnum):
"""Document segment status"""
@ -323,3 +330,10 @@ class ProviderQuotaType(StrEnum):
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class ApiTokenType(StrEnum):
"""API Token type"""
APP = "app"
DATASET = "dataset"

View File

@ -66,8 +66,8 @@ class HumanInputContent(ExecutionExtraContent):
form_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
@classmethod
def new(cls, form_id: str, message_id: str | None) -> "HumanInputContent":
return cls(form_id=form_id, message_id=message_id)
def new(cls, *, workflow_run_id: str, form_id: str, message_id: str | None) -> "HumanInputContent":
return cls(workflow_run_id=workflow_run_id, form_id=form_id, message_id=message_id)
form: Mapped["HumanInputForm"] = relationship(
"HumanInputForm",

View File

@ -21,7 +21,7 @@ from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.tools.signature import sign_tool_file
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from dify_graph.file import helpers as file_helpers
from extensions.storage.storage_type import StorageType
from libs.helper import generate_string # type: ignore[import-not-found]
@ -31,6 +31,7 @@ from .account import Account, Tenant
from .base import Base, TypeBase, gen_uuidv4_string
from .engine import db
from .enums import (
ApiTokenType,
AppMCPServerStatus,
AppStatus,
BannerStatus,
@ -43,6 +44,8 @@ from .enums import (
MessageChainType,
MessageFileBelongsTo,
MessageStatus,
ProviderQuotaType,
TagType,
)
from .provider_ids import GenericProviderID
from .types import EnumText, LongText, StringUUID
@ -1782,7 +1785,7 @@ class MessageFile(TypeBase):
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[FileType] = mapped_column(EnumText(FileType, length=255), nullable=False)
transfer_method: Mapped[FileTransferMethod] = mapped_column(
EnumText(FileTransferMethod, length=255), nullable=False
)
@ -2094,7 +2097,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field.
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
app_id = mapped_column(StringUUID, nullable=True)
tenant_id = mapped_column(StringUUID, nullable=True)
type = mapped_column(String(16), nullable=False)
type: Mapped[ApiTokenType] = mapped_column(EnumText(ApiTokenType, length=16), nullable=False)
token: Mapped[str] = mapped_column(String(255), nullable=False)
last_used_at = mapped_column(sa.DateTime, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@ -2404,7 +2407,7 @@ class Tag(TypeBase):
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
type: Mapped[str] = mapped_column(String(16), nullable=False)
type: Mapped[TagType] = mapped_column(EnumText(TagType, length=16), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
@ -2489,7 +2492,9 @@ class TenantCreditPool(TypeBase):
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
pool_type: Mapped[ProviderQuotaType] = mapped_column(
EnumText(ProviderQuotaType, length=40), nullable=False, default=ProviderQuotaType.TRIAL, server_default="trial"
)
quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
created_at: Mapped[datetime] = mapped_column(

View File

@ -13,12 +13,16 @@ from sqlalchemy.orm import Mapped, mapped_column
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
from core.tools.entities.tool_entities import (
ApiProviderSchemaType,
ToolProviderType,
WorkflowToolParameterConfiguration,
)
from .base import TypeBase
from .engine import db
from .model import Account, App, Tenant
from .types import LongText, StringUUID
from .types import EnumText, LongText, StringUUID
if TYPE_CHECKING:
from core.entities.mcp_provider import MCPProviderEntity
@ -141,7 +145,9 @@ class ApiToolProvider(TypeBase):
icon: Mapped[str] = mapped_column(String(255), nullable=False)
# original schema
schema: Mapped[str] = mapped_column(LongText, nullable=False)
schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False)
schema_type_str: Mapped[ApiProviderSchemaType] = mapped_column(
EnumText(ApiProviderSchemaType, length=40), nullable=False
)
# who created this tool
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# tenant id
@ -208,7 +214,7 @@ class ToolLabelBinding(TypeBase):
# tool id
tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
# tool type
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
# label name
label_name: Mapped[str] = mapped_column(String(40), nullable=False)
@ -386,7 +392,7 @@ class ToolModelInvoke(TypeBase):
# provider
provider: Mapped[str] = mapped_column(String(255), nullable=False)
# type
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
# tool name
tool_name: Mapped[str] = mapped_column(String(128), nullable=False)
# invoke parameters

View File

@ -1221,7 +1221,9 @@ class WorkflowAppLog(TypeBase):
app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_run_id: Mapped[str] = mapped_column(StringUUID)
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
created_from: Mapped[WorkflowAppLogCreatedFrom] = mapped_column(
EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=False
)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
@ -1301,10 +1303,14 @@ class WorkflowArchiveLog(TypeBase):
log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
log_created_from: Mapped[WorkflowAppLogCreatedFrom | None] = mapped_column(
EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=True
)
run_version: Mapped[str] = mapped_column(String(255), nullable=False)
run_status: Mapped[str] = mapped_column(String(255), nullable=False)
run_status: Mapped[WorkflowExecutionStatus] = mapped_column(
EnumText(WorkflowExecutionStatus, length=255), nullable=False
)
run_triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column(
EnumText(WorkflowRunTriggeredFrom, length=255), nullable=False
)

View File

@ -8,7 +8,7 @@ dependencies = [
"arize-phoenix-otel~=0.15.0",
"azure-identity==1.25.3",
"beautifulsoup4==4.14.3",
"boto3==1.42.68",
"boto3==1.42.73",
"bs4~=0.0.1",
"cachetools~=5.3.0",
"celery~=5.6.2",
@ -23,7 +23,7 @@ dependencies = [
"gevent~=25.9.1",
"gmpy2~=2.3.0",
"google-api-core>=2.19.1",
"google-api-python-client==2.192.0",
"google-api-python-client==2.193.0",
"google-auth>=2.47.0",
"google-auth-httplib2==0.3.0",
"google-cloud-aiplatform>=1.123.0",
@ -40,7 +40,7 @@ dependencies = [
"numpy~=1.26.4",
"openpyxl~=3.1.5",
"opik~=1.10.37",
"litellm==1.82.2", # Pinned to avoid madoka dependency issue
"litellm==1.82.6", # Pinned to avoid madoka dependency issue
"opentelemetry-api==1.28.0",
"opentelemetry-distro==0.49b0",
"opentelemetry-exporter-otlp==1.28.0",
@ -72,13 +72,14 @@ dependencies = [
"pyyaml~=6.0.1",
"readabilipy~=0.3.0",
"redis[hiredis]~=7.3.0",
"resend~=2.23.0",
"sentry-sdk[flask]~=2.54.0",
"resend~=2.26.0",
"sentry-sdk[flask]~=2.55.0",
"sqlalchemy~=2.0.29",
"starlette==0.52.1",
"starlette==1.0.0",
"tiktoken~=0.12.0",
"transformers~=5.3.0",
"unstructured[docx,epub,md,ppt,pptx]~=0.21.5",
"pypandoc~=1.13",
"yarl~=1.23.0",
"webvtt-py~=0.5.1",
"sseclient-py~=1.9.0",
@ -91,7 +92,7 @@ dependencies = [
"apscheduler>=3.11.0",
"weave>=0.52.16",
"fastopenapi[flask]>=0.7.0",
"bleach~=6.2.0",
"bleach~=6.3.0",
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.
@ -118,7 +119,7 @@ dev = [
"ruff~=0.15.5",
"pytest~=9.0.2",
"pytest-benchmark~=5.2.3",
"pytest-cov~=7.0.0",
"pytest-cov~=7.1.0",
"pytest-env~=1.6.0",
"pytest-mock~=3.15.1",
"testcontainers~=4.14.1",
@ -202,7 +203,7 @@ tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"]
# Required by vector store clients
############################################################
vdb = [
"alibabacloud_gpdb20160503~=3.8.0",
"alibabacloud_gpdb20160503~=5.1.0",
"alibabacloud_tea_openapi~=0.4.3",
"chromadb==0.5.20",
"clickhouse-connect~=0.14.1",

View File

@ -8,6 +8,7 @@ from configs import dify_config
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from extensions.ext_database import db
from models.dataset import TidbAuthBinding
from models.enums import TidbAuthBindingStatus
@app.celery.task(queue="dataset")
@ -57,7 +58,7 @@ def create_clusters(batch_size):
account=new_cluster["account"],
password=new_cluster["password"],
active=False,
status="CREATING",
status=TidbAuthBindingStatus.CREATING,
)
db.session.add(tidb_auth_binding)
db.session.commit()

View File

@ -9,6 +9,7 @@ from configs import dify_config
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from extensions.ext_database import db
from models.dataset import TidbAuthBinding
from models.enums import TidbAuthBindingStatus
@app.celery.task(queue="dataset")
@ -18,7 +19,10 @@ def update_tidb_serverless_status_task():
try:
# check the number of idle tidb serverless
tidb_serverless_list = db.session.scalars(
select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
select(TidbAuthBinding).where(
TidbAuthBinding.active == False,
TidbAuthBinding.status == TidbAuthBindingStatus.CREATING,
)
).all()
if len(tidb_serverless_list) == 0:
return

View File

@ -241,7 +241,7 @@ class AppService:
class ArgsDict(TypedDict):
name: str
description: str
icon_type: str
icon_type: IconType | str | None
icon: str
icon_background: str
use_icon_as_answer_icon: bool
@ -257,7 +257,13 @@ class AppService:
assert current_user is not None
app.name = args["name"]
app.description = args["description"]
app.icon_type = IconType(args["icon_type"]) if args["icon_type"] else None
icon_type = args.get("icon_type")
if icon_type is None:
resolved_icon_type = app.icon_type
else:
resolved_icon_type = IconType(icon_type)
app.icon_type = resolved_icon_type
app.icon = args["icon"]
app.icon_background = args["icon_background"]
app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False)

View File
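The `icon_type` fix above changes a partial update so that a missing or `None` value preserves the app's stored icon type instead of clearing it. A sketch of the resolution rule (the `IconType` enum values here are illustrative stand-ins):

```python
from enum import Enum


class IconType(Enum):  # stand-in for the real IconType enum
    EMOJI = "emoji"
    IMAGE = "image"


def resolve_icon_type(current, incoming):
    # None in the update payload means "not provided": keep the stored
    # value, so partial updates no longer wipe the field.
    if incoming is None:
        return current
    return IconType(incoming)
```

Any non-`None` value is still validated through the enum constructor, so an unknown icon type raises `ValueError` rather than being stored silently.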

@ -1,8 +1,16 @@
from abc import ABC, abstractmethod
from typing import Any
from typing_extensions import TypedDict
class AuthCredentials(TypedDict):
auth_type: str
config: dict[str, Any]
class ApiKeyAuthBase(ABC):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
self.credentials = credentials
@abstractmethod

View File

@ -1,9 +1,9 @@
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
from services.auth.auth_type import AuthType
class ApiKeyAuthFactory:
def __init__(self, provider: str, credentials: dict):
def __init__(self, provider: str, credentials: AuthCredentials):
auth_factory = self.get_apikey_auth_factory(provider)
self.auth = auth_factory(credentials)

View File

@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class FirecrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@ -3,11 +3,11 @@ from urllib.parse import urljoin
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials
class WatercrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: AuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "x-api-key":

View File

@ -335,7 +335,11 @@ class BillingService:
# Redis returns bytes, decode to string and parse JSON
json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value
plan_dict = json.loads(json_str)
# NOTE (hj24): New billing versions may return timestamp as str, and validate_python
# in non-strict mode will coerce it to the expected int type.
# To preserve compatibility, always keep non-strict mode here and avoid strict mode.
subscription_plan = subscription_adapter.validate_python(plan_dict)
# NOTE END
tenant_plans[tenant_id] = subscription_plan
except Exception:
logger.exception(

View File

@ -7,6 +7,7 @@ from configs import dify_config
from core.errors.error import QuotaExceededError
from extensions.ext_database import db
from models import TenantCreditPool
from models.enums import ProviderQuotaType
logger = logging.getLogger(__name__)
@ -16,7 +17,10 @@ class CreditPoolService:
def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
"""create default credit pool for new tenant"""
credit_pool = TenantCreditPool(
tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial"
tenant_id=tenant_id,
quota_limit=dify_config.HOSTED_POOL_CREDITS,
quota_used=0,
pool_type=ProviderQuotaType.TRIAL,
)
db.session.add(credit_pool)
db.session.commit()

View File

@ -58,6 +58,7 @@ from models.enums import (
IndexingStatus,
ProcessRuleMode,
SegmentStatus,
SegmentType,
)
from models.model import UploadFile
from models.provider_ids import ModelProviderID
@ -1439,7 +1440,7 @@ class DocumentService:
.filter(
Document.id.in_(document_id_list),
Document.dataset_id == dataset_id,
Document.doc_form != "qa_model", # Skip qa_model documents
Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
)
.update({Document.need_summary: need_summary}, synchronize_session=False)
)
@ -2039,7 +2040,7 @@ class DocumentService:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_form = IndexStructureType(knowledge_config.doc_form)
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
@ -2639,7 +2640,7 @@ class DocumentService:
document.splitting_completed_at = None
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = document_data.doc_form
document.doc_form = IndexStructureType(document_data.doc_form)
db.session.add(document)
db.session.commit()
# update document segment
@ -3100,7 +3101,7 @@ class DocumentService:
class SegmentService:
@classmethod
def segment_create_args_validate(cls, args: dict, document: Document):
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
if "answer" not in args or not args["answer"]:
raise ValueError("Answer is required")
if not args["answer"].strip():
@ -3157,7 +3158,7 @@ class SegmentService:
completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment_document.word_count += len(args["answer"])
segment_document.answer = args["answer"]
@ -3231,7 +3232,7 @@ class SegmentService:
tokens = 0
if dataset.indexing_technique == "high_quality" and embedding_model:
# calc embedding use tokens
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
tokens = embedding_model.get_text_embedding_num_tokens(
texts=[content + segment_item["answer"]]
)[0]
@ -3254,7 +3255,7 @@ class SegmentService:
completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment_document.answer = segment_item["answer"]
segment_document.word_count += len(segment_item["answer"])
increment_word_count += segment_document.word_count
@ -3321,7 +3322,7 @@ class SegmentService:
content = args.content or segment.content
if segment.content == content:
segment.word_count = len(content)
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment.answer = args.answer
segment.word_count += len(args.answer) if args.answer else 0
word_count_change = segment.word_count - word_count_change
@ -3418,7 +3419,7 @@ class SegmentService:
)
# calc embedding use tokens
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment.answer = args.answer
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0] # type: ignore
else:
@ -3435,7 +3436,7 @@ class SegmentService:
segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
if document.doc_form == "qa_model":
if document.doc_form == IndexStructureType.QA_INDEX:
segment.answer = args.answer
segment.word_count += len(args.answer) if args.answer else 0
word_count_change = segment.word_count - word_count_change
@ -3786,7 +3787,7 @@ class SegmentService:
child_chunk.word_count = len(child_chunk.content)
child_chunk.updated_by = current_user.id
child_chunk.updated_at = naive_utc_now()
child_chunk.type = "customized"
child_chunk.type = SegmentType.CUSTOMIZED
update_child_chunks.append(child_chunk)
else:
new_child_chunks_args.append(child_chunk_update_args)
@ -3845,7 +3846,7 @@ class SegmentService:
child_chunk.word_count = len(content)
child_chunk.updated_by = current_user.id
child_chunk.updated_at = naive_utc_now()
child_chunk.type = "customized"
child_chunk.type = SegmentType.CUSTOMIZED
db.session.add(child_chunk)
VectorService.update_child_chunk_vector([], [child_chunk], [], dataset)
db.session.commit()

View File

@ -9,6 +9,7 @@ from flask_login import current_user
from constants import DOCUMENT_EXTENSIONS
from core.plugin.impl.plugin import PluginInstaller
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from factories import variable_factory
@ -79,9 +80,9 @@ class RagPipelineTransformService:
pipeline = self._create_pipeline(pipeline_yaml)
# save chunk structure to dataset
if doc_form == "hierarchical_model":
if doc_form == IndexStructureType.PARENT_CHILD_INDEX:
dataset.chunk_structure = "hierarchical_model"
elif doc_form == "text_model":
elif doc_form == IndexStructureType.PARAGRAPH_INDEX:
dataset.chunk_structure = "text_model"
else:
raise ValueError("Unsupported doc form")
@ -101,7 +102,7 @@ class RagPipelineTransformService:
def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None):
pipeline_yaml = {}
if doc_form == "text_model":
if doc_form == IndexStructureType.PARAGRAPH_INDEX:
match datasource_type:
case DataSourceType.UPLOAD_FILE:
if indexing_technique == "high_quality":
@ -132,7 +133,7 @@ class RagPipelineTransformService:
pipeline_yaml = yaml.safe_load(f)
case _:
raise ValueError("Unsupported datasource type")
elif doc_form == "hierarchical_model":
elif doc_form == IndexStructureType.PARENT_CHILD_INDEX:
match datasource_type:
case DataSourceType.UPLOAD_FILE:
# get graph from transform.file-parentchild.yml

View File

@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from models.dataset import Dataset
from models.enums import TagType
from models.model import App, Tag, TagBinding
@ -83,7 +84,7 @@ class TagService:
raise ValueError("Tag name already exists")
tag = Tag(
name=args["name"],
type=args["type"],
type=TagType(args["type"]),
created_by=current_user.id,
tenant_id=current_user.current_tenant_id,
)

View File

@ -11,6 +11,7 @@ from sqlalchemy import func
from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.rag.index_processor.constant.index_type import IndexStructureType
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@ -109,7 +110,7 @@ def batch_create_segment_to_index_task(
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if document_config["doc_form"] == "qa_model":
if document_config["doc_form"] == IndexStructureType.QA_INDEX:
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row.iloc[0]}
@ -159,7 +160,7 @@ def batch_create_segment_to_index_task(
status="completed",
completed_at=naive_utc_now(),
)
if document_config["doc_form"] == "qa_model":
if document_config["doc_form"] == IndexStructureType.QA_INDEX:
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count

View File

@ -10,6 +10,7 @@ from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from libs.datetime_utils import naive_utc_now
@ -150,7 +151,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
)
if (
document.indexing_status == IndexingStatus.COMPLETED
and document.doc_form != "qa_model"
and document.doc_form != IndexStructureType.QA_INDEX
and document.need_summary is True
):
try:

View File

@ -0,0 +1,52 @@
"""Celery worker for enterprise metric/log telemetry events.
This module defines the Celery task that processes telemetry envelopes
from the enterprise_telemetry queue. It deserializes envelopes and
dispatches them to the EnterpriseMetricHandler.
"""
import json
import logging
from celery import shared_task
from enterprise.telemetry.contracts import TelemetryEnvelope
from enterprise.telemetry.metric_handler import EnterpriseMetricHandler
logger = logging.getLogger(__name__)
@shared_task(queue="enterprise_telemetry")
def process_enterprise_telemetry(envelope_json: str) -> None:
"""Process enterprise metric/log telemetry envelope.
This task is enqueued by the TelemetryGateway for metric/log-only
events. It deserializes the envelope and dispatches to the handler.
Best-effort processing: logs errors but never raises, to avoid
failing user requests due to telemetry issues.
Args:
envelope_json: JSON-serialized TelemetryEnvelope.
"""
try:
# Deserialize envelope
envelope_dict = json.loads(envelope_json)
envelope = TelemetryEnvelope.model_validate(envelope_dict)
# Process through handler
handler = EnterpriseMetricHandler()
handler.handle(envelope)
logger.debug(
"Successfully processed telemetry envelope: tenant_id=%s, event_id=%s, case=%s",
envelope.tenant_id,
envelope.event_id,
envelope.case,
)
except Exception:
# Best-effort: log and drop on error, never fail user request
logger.warning(
"Failed to process enterprise telemetry envelope, dropping event",
exc_info=True,
)

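The task above follows a swallow-and-log pattern: deserialize, dispatch, and on any failure log and drop rather than raise, so telemetry problems never fail a user request. A stripped-down, framework-free sketch of the same shape — the envelope fields and the inline "dispatch" step are placeholders, not the enterprise contracts:

```python
import json
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class Envelope:
    tenant_id: str
    event_id: str
    case: str


def process_envelope(envelope_json: str) -> bool:
    """Best-effort: return True on success, never raise."""
    try:
        data = json.loads(envelope_json)
        envelope = Envelope(**data)
        # ... dispatch to a metric handler here ...
        logger.debug("processed event %s", envelope.event_id)
        return True
    except Exception:
        # Log and drop malformed or unprocessable events
        logger.warning("dropping telemetry envelope", exc_info=True)
        return False


print(process_envelope('{"tenant_id": "t1", "event_id": "e1", "case": "llm"}'))  # True
print(process_envelope("not json"))  # False
```

The real task returns `None` either way; the boolean here only makes the success/drop paths visible in a standalone script.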
View File

@@ -39,17 +39,36 @@ def process_trace_tasks(file_info):
trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]]
try:
trace_type = trace_info_info_map.get(trace_info_type)
if trace_type:
trace_info = trace_type(**trace_info)
from extensions.ext_enterprise_telemetry import is_enabled as is_ee_telemetry_enabled
if is_ee_telemetry_enabled():
from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace
try:
EnterpriseOtelTrace().trace(trace_info)
except Exception:
logger.exception("Enterprise trace failed for app_id: %s", app_id)
if trace_instance:
with current_app.app_context():
trace_type = trace_info_info_map.get(trace_info_type)
if trace_type:
trace_info = trace_type(**trace_info)
trace_instance.trace(trace_info)
logger.info("Processing trace tasks success, app_id: %s", app_id)
except Exception as e:
logger.info("error:\n\n\n%s\n\n\n\n", e)
logger.exception("Processing trace tasks failed, app_id: %s", app_id)
failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
redis_client.incr(failed_key)
logger.info("Processing trace tasks failed, app_id: %s", app_id)
finally:
storage.delete(file_path)
try:
storage.delete(file_path)
except Exception as e:
logger.warning(
"Failed to delete trace file %s for app_id %s: %s",
file_path,
app_id,
e,
)

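The change to the `finally` block above wraps `storage.delete` in its own try/except, so a cleanup failure can no longer raise out of `finally` and replace the original exception. The general shape, with a hypothetical `flaky_delete` standing in for the storage call:

```python
import logging

logger = logging.getLogger(__name__)


def cleanup(delete, path: str) -> None:
    # Guarded cleanup: a delete failure is logged, never propagated
    try:
        delete(path)
    except Exception as e:
        logger.warning("Failed to delete trace file %s: %s", path, e)


def flaky_delete(path: str) -> None:
    raise OSError("storage unavailable")


cleanup(flaky_delete, "/tmp/trace.json")  # logs a warning, does not raise
print("survived cleanup failure")
```

An unguarded `storage.delete(file_path)` in `finally` would otherwise mask the exception being handled in the `except` branch, which is exactly what the diff fixes.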
View File

@@ -9,6 +9,7 @@ from celery import shared_task
from sqlalchemy import or_, select
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
from models.dataset import Document as DatasetDocument
from services.summary_index_service import SummaryIndexService
@@ -106,7 +107,7 @@ def regenerate_summary_index_task(
),
DatasetDocument.enabled == True, # Document must be enabled
DatasetDocument.archived == False, # Document must not be archived
DatasetDocument.doc_form != "qa_model", # Skip qa_model documents
DatasetDocument.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
)
.order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc())
.all()
@@ -209,7 +210,7 @@ def regenerate_summary_index_task(
for dataset_document in dataset_documents:
# Skip qa_model documents
if dataset_document.doc_form == "qa_model":
if dataset_document.doc_form == IndexStructureType.QA_INDEX:
continue
try:

View File

@@ -179,7 +179,7 @@ def _record_trigger_failure_log(
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value,
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=created_by_role,
created_by=created_by,
)

View File

@@ -13,6 +13,7 @@ from unittest.mock import patch
import pytest
from extensions.ext_redis import redis_client
from models.enums import ApiTokenType
from models.model import ApiToken
from services.api_token_service import ApiTokenCache, CachedApiToken
@@ -279,7 +280,7 @@ class TestEndToEndCacheFlow:
test_token = ApiToken()
test_token.id = "test-e2e-id"
test_token.token = test_token_value
test_token.type = test_scope
test_token.type = ApiTokenType.APP
test_token.app_id = "test-app"
test_token.tenant_id = "test-tenant"
test_token.last_used_at = None

View File

@@ -0,0 +1,342 @@
"""Authenticated controller integration tests for console message APIs."""
from datetime import timedelta
from decimal import Decimal
from unittest.mock import patch
from uuid import uuid4
import pytest
from flask.testing import FlaskClient
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console.app.message import ChatMessagesQuery, FeedbackExportQuery, MessageFeedbackPayload
from controllers.console.app.message import attach_message_extra_contents as _attach_message_extra_contents
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from libs.datetime_utils import naive_utc_now
from models.enums import ConversationFromSource, FeedbackRating
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from tests.test_containers_integration_tests.controllers.console.helpers import (
authenticate_console_client,
create_console_account_and_tenant,
create_console_app,
)
def _create_conversation(db_session: Session, app_id: str, account_id: str, mode: AppMode) -> Conversation:
conversation = Conversation(
app_id=app_id,
app_model_config_id=None,
model_provider=None,
model_id="",
override_model_configs=None,
mode=mode,
name="Test Conversation",
inputs={},
introduction="",
system_instruction="",
system_instruction_tokens=0,
status="normal",
from_source=ConversationFromSource.CONSOLE,
from_account_id=account_id,
)
db_session.add(conversation)
db_session.commit()
return conversation
def _create_message(
db_session: Session,
app_id: str,
conversation_id: str,
account_id: str,
*,
created_at_offset_seconds: int = 0,
) -> Message:
created_at = naive_utc_now() + timedelta(seconds=created_at_offset_seconds)
message = Message(
app_id=app_id,
model_provider=None,
model_id="",
override_model_configs=None,
conversation_id=conversation_id,
inputs={},
query="Hello",
message={"type": "text", "content": "Hello"},
message_tokens=1,
message_unit_price=Decimal("0.0001"),
message_price_unit=Decimal("0.001"),
answer="Hi there",
answer_tokens=1,
answer_unit_price=Decimal("0.0001"),
answer_price_unit=Decimal("0.001"),
parent_message_id=None,
provider_response_latency=0,
total_price=Decimal("0.0002"),
currency="USD",
from_source=ConversationFromSource.CONSOLE,
from_account_id=account_id,
created_at=created_at,
updated_at=created_at,
app_mode=AppMode.CHAT,
)
db_session.add(message)
db_session.commit()
return message
class TestMessageValidators:
def test_chat_messages_query_validators(self) -> None:
assert ChatMessagesQuery.empty_to_none("") is None
assert ChatMessagesQuery.empty_to_none("val") == "val"
assert ChatMessagesQuery.validate_uuid(None) is None
assert (
ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000")
== "123e4567-e89b-12d3-a456-426614174000"
)
def test_message_feedback_validators(self) -> None:
assert (
MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000")
== "123e4567-e89b-12d3-a456-426614174000"
)
def test_feedback_export_validators(self) -> None:
assert FeedbackExportQuery.parse_bool(None) is None
assert FeedbackExportQuery.parse_bool(True) is True
assert FeedbackExportQuery.parse_bool("1") is True
assert FeedbackExportQuery.parse_bool("0") is False
assert FeedbackExportQuery.parse_bool("off") is False
with pytest.raises(ValueError):
FeedbackExportQuery.parse_bool("invalid")
def test_chat_message_list_not_found(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/chat-messages",
query_string={"conversation_id": str(uuid4())},
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 404
payload = response.get_json()
assert payload is not None
assert payload["code"] == "not_found"
def test_chat_message_list_success(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode)
_create_message(db_session_with_containers, app.id, conversation.id, account.id, created_at_offset_seconds=0)
second = _create_message(
db_session_with_containers,
app.id,
conversation.id,
account.id,
created_at_offset_seconds=1,
)
with patch(
"controllers.console.app.message.attach_message_extra_contents",
side_effect=_attach_message_extra_contents,
):
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/chat-messages",
query_string={"conversation_id": conversation.id, "limit": 1},
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert payload["limit"] == 1
assert payload["has_more"] is True
assert len(payload["data"]) == 1
assert payload["data"][0]["id"] == second.id
def test_message_feedback_not_found(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
response = test_client_with_containers.post(
f"/console/api/apps/{app.id}/feedbacks",
json={"message_id": str(uuid4()), "rating": "like"},
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 404
payload = response.get_json()
assert payload is not None
assert payload["code"] == "not_found"
def test_message_feedback_success(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode)
message = _create_message(db_session_with_containers, app.id, conversation.id, account.id)
response = test_client_with_containers.post(
f"/console/api/apps/{app.id}/feedbacks",
json={"message_id": message.id, "rating": "like"},
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json() == {"result": "success"}
feedback = db_session_with_containers.scalar(
select(MessageFeedback).where(MessageFeedback.message_id == message.id)
)
assert feedback is not None
assert feedback.rating == FeedbackRating.LIKE
assert feedback.from_account_id == account.id
def test_message_annotation_count(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode)
message = _create_message(db_session_with_containers, app.id, conversation.id, account.id)
db_session_with_containers.add(
MessageAnnotation(
app_id=app.id,
conversation_id=conversation.id,
message_id=message.id,
question="Q",
content="A",
account_id=account.id,
)
)
db_session_with_containers.commit()
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/annotations/count",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json() == {"count": 1}
def test_message_suggested_questions_success(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
message_id = str(uuid4())
with patch(
"controllers.console.app.message.MessageService.get_suggested_questions_after_answer",
return_value=["q1", "q2"],
):
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/chat-messages/{message_id}/suggested-questions",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json() == {"data": ["q1", "q2"]}
@pytest.mark.parametrize(
("exc", "expected_status", "expected_code"),
[
(MessageNotExistsError(), 404, "not_found"),
(ConversationNotExistsError(), 404, "not_found"),
(ProviderTokenNotInitError(), 400, "provider_not_initialize"),
(QuotaExceededError(), 400, "provider_quota_exceeded"),
(ModelCurrentlyNotSupportError(), 400, "model_currently_not_support"),
(SuggestedQuestionsAfterAnswerDisabledError(), 403, "app_suggested_questions_after_answer_disabled"),
(Exception(), 500, "internal_server_error"),
],
)
def test_message_suggested_questions_errors(
exc: Exception,
expected_status: int,
expected_code: str,
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
message_id = str(uuid4())
with patch(
"controllers.console.app.message.MessageService.get_suggested_questions_after_answer",
side_effect=exc,
):
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/chat-messages/{message_id}/suggested-questions",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == expected_status
payload = response.get_json()
assert payload is not None
assert payload["code"] == expected_code
def test_message_feedback_export_success(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
with patch("services.feedback_service.FeedbackService.export_feedbacks", return_value={"exported": True}):
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/feedbacks/export",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json() == {"exported": True}
def test_message_api_get_success(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode)
message = _create_message(db_session_with_containers, app.id, conversation.id, account.id)
with patch(
"controllers.console.app.message.attach_message_extra_contents",
side_effect=_attach_message_extra_contents,
):
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/messages/{message.id}",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert payload["id"] == message.id

View File

@@ -0,0 +1,334 @@
"""Controller integration tests for console statistic routes."""
from datetime import timedelta
from decimal import Decimal
from unittest.mock import patch
from uuid import uuid4
from flask.testing import FlaskClient
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from libs.datetime_utils import naive_utc_now
from models.enums import ConversationFromSource, FeedbackFromSource, FeedbackRating
from models.model import AppMode, Conversation, Message, MessageFeedback
from tests.test_containers_integration_tests.controllers.console.helpers import (
authenticate_console_client,
create_console_account_and_tenant,
create_console_app,
)
def _create_conversation(
db_session: Session,
app_id: str,
account_id: str,
*,
mode: AppMode,
created_at_offset_days: int = 0,
) -> Conversation:
created_at = naive_utc_now() + timedelta(days=created_at_offset_days)
conversation = Conversation(
app_id=app_id,
app_model_config_id=None,
model_provider=None,
model_id="",
override_model_configs=None,
mode=mode,
name="Stats Conversation",
inputs={},
introduction="",
system_instruction="",
system_instruction_tokens=0,
status="normal",
from_source=ConversationFromSource.CONSOLE,
from_account_id=account_id,
created_at=created_at,
updated_at=created_at,
)
db_session.add(conversation)
db_session.commit()
return conversation
def _create_message(
db_session: Session,
app_id: str,
conversation_id: str,
*,
from_account_id: str | None,
from_end_user_id: str | None = None,
message_tokens: int = 1,
answer_tokens: int = 1,
total_price: Decimal = Decimal("0.01"),
provider_response_latency: float = 1.0,
created_at_offset_days: int = 0,
) -> Message:
created_at = naive_utc_now() + timedelta(days=created_at_offset_days)
message = Message(
app_id=app_id,
model_provider=None,
model_id="",
override_model_configs=None,
conversation_id=conversation_id,
inputs={},
query="Hello",
message={"type": "text", "content": "Hello"},
message_tokens=message_tokens,
message_unit_price=Decimal("0.001"),
message_price_unit=Decimal("0.001"),
answer="Hi there",
answer_tokens=answer_tokens,
answer_unit_price=Decimal("0.001"),
answer_price_unit=Decimal("0.001"),
parent_message_id=None,
provider_response_latency=provider_response_latency,
total_price=total_price,
currency="USD",
invoke_from=InvokeFrom.EXPLORE,
from_source=ConversationFromSource.CONSOLE,
from_end_user_id=from_end_user_id,
from_account_id=from_account_id,
created_at=created_at,
updated_at=created_at,
app_mode=AppMode.CHAT,
)
db_session.add(message)
db_session.commit()
return message
def _create_like_feedback(
db_session: Session,
app_id: str,
conversation_id: str,
message_id: str,
account_id: str,
) -> None:
db_session.add(
MessageFeedback(
app_id=app_id,
conversation_id=conversation_id,
message_id=message_id,
rating=FeedbackRating.LIKE,
from_source=FeedbackFromSource.ADMIN,
from_account_id=account_id,
)
)
db_session.commit()
def test_daily_message_statistic(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/statistics/daily-messages",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json()["data"][0]["message_count"] == 1
def test_daily_conversation_statistic(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/statistics/daily-conversations",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json()["data"][0]["conversation_count"] == 1
def test_daily_terminals_statistic(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
_create_message(
db_session_with_containers,
app.id,
conversation.id,
from_account_id=None,
from_end_user_id=str(uuid4()),
)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/statistics/daily-end-users",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json()["data"][0]["terminal_count"] == 1
def test_daily_token_cost_statistic(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
_create_message(
db_session_with_containers,
app.id,
conversation.id,
from_account_id=account.id,
message_tokens=40,
answer_tokens=60,
total_price=Decimal("0.02"),
)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/statistics/token-costs",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
payload = response.get_json()
assert payload["data"][0]["token_count"] == 100
assert Decimal(payload["data"][0]["total_price"]) == Decimal("0.02")
def test_average_session_interaction_statistic(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/statistics/average-session-interactions",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json()["data"][0]["interactions"] == 2.0
def test_user_satisfaction_rate_statistic(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
first = _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
for _ in range(9):
_create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id)
_create_like_feedback(db_session_with_containers, app.id, conversation.id, first.id, account.id)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/statistics/user-satisfaction-rate",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json()["data"][0]["rate"] == 100.0
def test_average_response_time_statistic(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.COMPLETION)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
_create_message(
db_session_with_containers,
app.id,
conversation.id,
from_account_id=account.id,
provider_response_latency=1.234,
)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/statistics/average-response-time",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json()["data"][0]["latency"] == 1234.0
def test_tokens_per_second_statistic(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode)
_create_message(
db_session_with_containers,
app.id,
conversation.id,
from_account_id=account.id,
answer_tokens=31,
provider_response_latency=2.0,
)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/statistics/tokens-per-second",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json()["data"][0]["tps"] == 15.5
def test_invalid_time_range(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
with patch("controllers.console.app.statistic.parse_time_range", side_effect=ValueError("Invalid time")):
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/statistics/daily-messages?start=invalid&end=invalid",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 400
assert response.get_json()["message"] == "Invalid time"
def test_time_range_params_passed(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
import datetime
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
start = datetime.datetime.now()
end = datetime.datetime.now()
with patch("controllers.console.app.statistic.parse_time_range", return_value=(start, end)) as mock_parse:
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/statistics/daily-messages?start=something&end=something",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
mock_parse.assert_called_once_with("something", "something", "UTC")

View File

@@ -0,0 +1,415 @@
"""Authenticated controller integration tests for workflow draft variable APIs."""
import uuid
from flask.testing import FlaskClient
from sqlalchemy import select
from sqlalchemy.orm import Session
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
from dify_graph.variables.segments import StringSegment
from factories.variable_factory import segment_to_variable
from models import Workflow
from models.model import AppMode
from models.workflow import WorkflowDraftVariable
from tests.test_containers_integration_tests.controllers.console.helpers import (
authenticate_console_client,
create_console_account_and_tenant,
create_console_app,
)
def _create_draft_workflow(
db_session: Session,
app_id: str,
tenant_id: str,
account_id: str,
*,
environment_variables: list | None = None,
conversation_variables: list | None = None,
) -> Workflow:
workflow = Workflow.new(
tenant_id=tenant_id,
app_id=app_id,
type="workflow",
version=Workflow.VERSION_DRAFT,
graph='{"nodes": [], "edges": []}',
features="{}",
created_by=account_id,
environment_variables=environment_variables or [],
conversation_variables=conversation_variables or [],
rag_pipeline_variables=[],
)
db_session.add(workflow)
db_session.commit()
return workflow
def _create_node_variable(
db_session: Session,
app_id: str,
user_id: str,
*,
node_id: str = "node_1",
name: str = "test_var",
) -> WorkflowDraftVariable:
variable = WorkflowDraftVariable.new_node_variable(
app_id=app_id,
user_id=user_id,
node_id=node_id,
name=name,
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
visible=True,
editable=True,
)
db_session.add(variable)
db_session.commit()
return variable
def _create_system_variable(
db_session: Session, app_id: str, user_id: str, name: str = "query"
) -> WorkflowDraftVariable:
variable = WorkflowDraftVariable.new_sys_variable(
app_id=app_id,
user_id=user_id,
name=name,
value=StringSegment(value="system-value"),
node_execution_id=str(uuid.uuid4()),
editable=True,
)
db_session.add(variable)
db_session.commit()
return variable
def _build_environment_variable(name: str, value: str):
return segment_to_variable(
segment=StringSegment(value=value),
selector=[ENVIRONMENT_VARIABLE_NODE_ID, name],
name=name,
description=f"Environment variable {name}",
)
def _build_conversation_variable(name: str, value: str):
return segment_to_variable(
segment=StringSegment(value=value),
selector=[CONVERSATION_VARIABLE_NODE_ID, name],
name=name,
description=f"Conversation variable {name}",
)
def test_workflow_variable_collection_get_success(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
_create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/workflows/draft/variables?page=1&limit=20",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json() == {"items": [], "total": 0}
def test_workflow_variable_collection_get_not_exist(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/workflows/draft/variables",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 404
payload = response.get_json()
assert payload is not None
assert payload["code"] == "draft_workflow_not_exist"
def test_workflow_variable_collection_delete(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
_create_node_variable(db_session_with_containers, app.id, account.id)
_create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_2", name="other_var")
response = test_client_with_containers.delete(
f"/console/api/apps/{app.id}/workflows/draft/variables",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 204
remaining = db_session_with_containers.scalars(
select(WorkflowDraftVariable).where(
WorkflowDraftVariable.app_id == app.id,
WorkflowDraftVariable.user_id == account.id,
)
).all()
assert remaining == []
def test_node_variable_collection_get_success(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
node_variable = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_123")
_create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_456", name="other")
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/workflows/draft/nodes/node_123/variables",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert [item["id"] for item in payload["items"]] == [node_variable.id]
def test_node_variable_collection_get_invalid_node_id(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/workflows/draft/nodes/sys/variables",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 400
payload = response.get_json()
assert payload is not None
assert payload["code"] == "invalid_param"
def test_node_variable_collection_delete(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
target = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_123")
untouched = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_456")
target_id = target.id
untouched_id = untouched.id
response = test_client_with_containers.delete(
f"/console/api/apps/{app.id}/workflows/draft/nodes/node_123/variables",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 204
assert (
db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == target_id))
is None
)
assert (
db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == untouched_id))
is not None
)
def test_variable_api_get_success(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
_create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id)
variable = _create_node_variable(db_session_with_containers, app.id, account.id)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert payload["id"] == variable.id
assert payload["name"] == "test_var"
def test_variable_api_get_not_found(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
_create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/workflows/draft/variables/{uuid.uuid4()}",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 404
payload = response.get_json()
assert payload is not None
assert payload["code"] == "not_found"
def test_variable_api_patch_success(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
_create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id)
variable = _create_node_variable(db_session_with_containers, app.id, account.id)
response = test_client_with_containers.patch(
f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}",
headers=authenticate_console_client(test_client_with_containers, account),
json={"name": "renamed_var"},
)
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert payload["id"] == variable.id
assert payload["name"] == "renamed_var"
refreshed = db_session_with_containers.scalar(
select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id)
)
assert refreshed is not None
assert refreshed.name == "renamed_var"
def test_variable_api_delete_success(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
_create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id)
variable = _create_node_variable(db_session_with_containers, app.id, account.id)
response = test_client_with_containers.delete(
f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 204
assert (
db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id))
is None
)
def test_variable_reset_api_put_success_returns_no_content_without_execution(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
_create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id)
variable = _create_node_variable(db_session_with_containers, app.id, account.id)
response = test_client_with_containers.put(
f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}/reset",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 204
assert (
db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id))
is None
)
def test_conversation_variable_collection_get(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
_create_draft_workflow(
db_session_with_containers,
app.id,
tenant.id,
account.id,
conversation_variables=[_build_conversation_variable("session_name", "Alice")],
)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/workflows/draft/conversation-variables",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert [item["name"] for item in payload["items"]] == ["session_name"]
created = db_session_with_containers.scalars(
select(WorkflowDraftVariable).where(
WorkflowDraftVariable.app_id == app.id,
WorkflowDraftVariable.user_id == account.id,
WorkflowDraftVariable.node_id == CONVERSATION_VARIABLE_NODE_ID,
)
).all()
assert len(created) == 1
def test_system_variable_collection_get(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
variable = _create_system_variable(db_session_with_containers, app.id, account.id)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/workflows/draft/system-variables",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert [item["id"] for item in payload["items"]] == [variable.id]
def test_environment_variable_collection_get(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
_create_draft_workflow(
db_session_with_containers,
app.id,
tenant.id,
account.id,
environment_variables=[_build_environment_variable("api_key", "secret-value")],
)
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/workflows/draft/environment-variables",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert payload["items"][0]["name"] == "api_key"
assert payload["items"][0]["value"] == "secret-value"

@@ -0,0 +1,131 @@
"""Controller integration tests for API key data source auth routes."""
import json
from unittest.mock import patch
from flask.testing import FlaskClient
from sqlalchemy import select
from sqlalchemy.orm import Session
from models.source import DataSourceApiKeyAuthBinding
from tests.test_containers_integration_tests.controllers.console.helpers import (
authenticate_console_client,
create_console_account_and_tenant,
)
def test_get_api_key_auth_data_source(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
binding = DataSourceApiKeyAuthBinding(
tenant_id=tenant.id,
category="api_key",
provider="custom_provider",
credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}),
disabled=False,
)
db_session_with_containers.add(binding)
db_session_with_containers.commit()
response = test_client_with_containers.get(
"/console/api/api-key-auth/data-source",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert len(payload["sources"]) == 1
assert payload["sources"][0]["provider"] == "custom_provider"
def test_get_api_key_auth_data_source_empty(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
response = test_client_with_containers.get(
"/console/api/api-key-auth/data-source",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json() == {"sources": []}
def test_create_binding_successful(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
with (
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"),
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth"),
):
response = test_client_with_containers.post(
"/console/api/api-key-auth/data-source/binding",
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json() == {"result": "success"}
def test_create_binding_failure(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
with (
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"),
patch(
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth",
side_effect=ValueError("Invalid structure"),
),
):
response = test_client_with_containers.post(
"/console/api/api-key-auth/data-source/binding",
json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}},
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 500
payload = response.get_json()
assert payload is not None
assert payload["code"] == "auth_failed"
assert payload["message"] == "Invalid structure"
def test_delete_binding_successful(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
binding = DataSourceApiKeyAuthBinding(
tenant_id=tenant.id,
category="api_key",
provider="custom_provider",
credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}),
disabled=False,
)
db_session_with_containers.add(binding)
db_session_with_containers.commit()
response = test_client_with_containers.delete(
f"/console/api/api-key-auth/data-source/{binding.id}",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 204
assert (
db_session_with_containers.scalar(
select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.id == binding.id)
)
is None
)

@@ -0,0 +1,120 @@
"""Controller integration tests for console OAuth data source routes."""
from unittest.mock import MagicMock, patch
from flask.testing import FlaskClient
from sqlalchemy.orm import Session
from models.source import DataSourceOauthBinding
from tests.test_containers_integration_tests.controllers.console.helpers import (
authenticate_console_client,
create_console_account_and_tenant,
)
def test_get_oauth_url_successful(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
provider = MagicMock()
provider.get_authorization_url.return_value = "http://oauth.provider/auth"
with (
patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}),
patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None),
):
response = test_client_with_containers.get(
"/console/api/oauth/data-source/notion",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert tenant.id == account.current_tenant_id
assert response.status_code == 200
assert response.get_json() == {"data": "http://oauth.provider/auth"}
provider.get_authorization_url.assert_called_once()
def test_get_oauth_url_invalid_provider(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}):
response = test_client_with_containers.get(
"/console/api/oauth/data-source/unknown_provider",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 400
assert response.get_json() == {"error": "Invalid provider"}
def test_oauth_callback_successful(test_client_with_containers: FlaskClient) -> None:
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}):
response = test_client_with_containers.get("/console/api/oauth/data-source/callback/notion?code=mock_code")
assert response.status_code == 302
assert "code=mock_code" in response.location
def test_oauth_callback_missing_code(test_client_with_containers: FlaskClient) -> None:
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}):
response = test_client_with_containers.get("/console/api/oauth/data-source/callback/notion")
assert response.status_code == 302
assert "error=Access%20denied" in response.location
def test_oauth_callback_invalid_provider(test_client_with_containers: FlaskClient) -> None:
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}):
response = test_client_with_containers.get("/console/api/oauth/data-source/callback/invalid?code=mock_code")
assert response.status_code == 400
assert response.get_json() == {"error": "Invalid provider"}
def test_get_binding_successful(test_client_with_containers: FlaskClient) -> None:
provider = MagicMock()
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}):
response = test_client_with_containers.get("/console/api/oauth/data-source/binding/notion?code=auth_code_123")
assert response.status_code == 200
assert response.get_json() == {"result": "success"}
provider.get_access_token.assert_called_once_with("auth_code_123")
def test_get_binding_missing_code(test_client_with_containers: FlaskClient) -> None:
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}):
response = test_client_with_containers.get("/console/api/oauth/data-source/binding/notion?code=")
assert response.status_code == 400
assert response.get_json() == {"error": "Invalid code"}
def test_sync_successful(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
binding = DataSourceOauthBinding(
tenant_id=tenant.id,
access_token="test-access-token",
provider="notion",
source_info={"workspace_name": "Workspace", "workspace_icon": None, "workspace_id": tenant.id, "pages": []},
disabled=False,
)
db_session_with_containers.add(binding)
db_session_with_containers.commit()
provider = MagicMock()
with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}):
response = test_client_with_containers.get(
f"/console/api/oauth/data-source/notion/{binding.id}/sync",
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json() == {"result": "success"}
provider.sync_data_source.assert_called_once_with(binding.id)

@@ -0,0 +1,365 @@
"""Controller integration tests for console OAuth server routes."""
from unittest.mock import patch
from flask.testing import FlaskClient
from sqlalchemy.orm import Session
from models.model import OAuthProviderApp
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN
from tests.test_containers_integration_tests.controllers.console.helpers import (
authenticate_console_client,
create_console_account_and_tenant,
ensure_dify_setup,
)
def _build_oauth_provider_app() -> OAuthProviderApp:
return OAuthProviderApp(
app_icon="icon_url",
client_id="test_client_id",
client_secret="test_secret",
app_label={"en-US": "Test App"},
redirect_uris=["http://localhost/callback"],
scope="read,write",
)
def test_oauth_provider_successful_post(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider",
json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"},
)
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert payload["app_icon"] == "icon_url"
assert payload["app_label"] == {"en-US": "Test App"}
assert payload["scope"] == "read,write"
def test_oauth_provider_invalid_redirect_uri(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider",
json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"},
)
assert response.status_code == 400
payload = response.get_json()
assert payload is not None
assert "redirect_uri is invalid" in payload["message"]
def test_oauth_provider_invalid_client_id(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
response = test_client_with_containers.post(
"/console/api/oauth/provider",
json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"},
)
assert response.status_code == 404
payload = response.get_json()
assert payload is not None
assert "client_id is invalid" in payload["message"]
def test_oauth_authorize_successful(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
with (
patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
),
patch(
"controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code",
return_value="auth_code_123",
) as mock_sign,
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/authorize",
json={"client_id": "test_client_id"},
headers=authenticate_console_client(test_client_with_containers, account),
)
assert response.status_code == 200
assert response.get_json() == {"code": "auth_code_123"}
mock_sign.assert_called_once_with("test_client_id", account.id)
def test_oauth_token_authorization_code_grant(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with (
patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
),
patch(
"controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token",
return_value=("access_123", "refresh_123"),
),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/token",
json={
"client_id": "test_client_id",
"grant_type": "authorization_code",
"code": "auth_code",
"client_secret": "test_secret",
"redirect_uri": "http://localhost/callback",
},
)
assert response.status_code == 200
assert response.get_json() == {
"access_token": "access_123",
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": "refresh_123",
}
def test_oauth_token_authorization_code_grant_missing_code(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/token",
json={
"client_id": "test_client_id",
"grant_type": "authorization_code",
"client_secret": "test_secret",
"redirect_uri": "http://localhost/callback",
},
)
assert response.status_code == 400
assert response.get_json()["message"] == "code is required"
def test_oauth_token_authorization_code_grant_invalid_secret(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/token",
json={
"client_id": "test_client_id",
"grant_type": "authorization_code",
"code": "auth_code",
"client_secret": "invalid_secret",
"redirect_uri": "http://localhost/callback",
},
)
assert response.status_code == 400
assert response.get_json()["message"] == "client_secret is invalid"
def test_oauth_token_authorization_code_grant_invalid_redirect_uri(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/token",
json={
"client_id": "test_client_id",
"grant_type": "authorization_code",
"code": "auth_code",
"client_secret": "test_secret",
"redirect_uri": "http://invalid/callback",
},
)
assert response.status_code == 400
assert response.get_json()["message"] == "redirect_uri is invalid"
def test_oauth_token_refresh_token_grant(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with (
patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
),
patch(
"controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token",
return_value=("new_access", "new_refresh"),
),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/token",
json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"},
)
assert response.status_code == 200
assert response.get_json() == {
"access_token": "new_access",
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": "new_refresh",
}
def test_oauth_token_refresh_token_grant_missing_token(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/token",
json={"client_id": "test_client_id", "grant_type": "refresh_token"},
)
assert response.status_code == 400
assert response.get_json()["message"] == "refresh_token is required"
def test_oauth_token_invalid_grant_type(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/token",
json={"client_id": "test_client_id", "grant_type": "invalid_grant"},
)
assert response.status_code == 400
assert response.get_json()["message"] == "invalid grant_type"
def test_oauth_account_successful_retrieval(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
account.avatar = "avatar_url"
db_session_with_containers.commit()
with (
patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
),
patch(
"controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token",
return_value=account,
),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/account",
json={"client_id": "test_client_id"},
headers={"Authorization": "Bearer valid_access_token"},
)
assert response.status_code == 200
assert response.get_json() == {
"name": "Test User",
"email": account.email,
"avatar": "avatar_url",
"interface_language": "en-US",
"timezone": "UTC",
}
def test_oauth_account_missing_authorization_header(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/account",
json={"client_id": "test_client_id"},
)
assert response.status_code == 401
assert response.get_json() == {"error": "Authorization header is required"}
def test_oauth_account_invalid_authorization_header_format(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
ensure_dify_setup(db_session_with_containers)
with patch(
"controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app",
return_value=_build_oauth_provider_app(),
):
response = test_client_with_containers.post(
"/console/api/oauth/provider/account",
json={"client_id": "test_client_id"},
headers={"Authorization": "InvalidFormat"},
)
assert response.status_code == 401
assert response.get_json() == {"error": "Invalid Authorization header format"}

@@ -1,17 +1,10 @@
-"""
-Test suite for password reset authentication flows.
+"""Testcontainers integration tests for password reset authentication flows."""
-This module tests the password reset mechanism including:
-- Password reset email sending
-- Verification code validation
-- Password reset with token
-- Rate limiting and security checks
-"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.error import (
EmailCodeError,
@@ -28,31 +21,12 @@ from controllers.console.auth.forgot_password import (
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
-@pytest.fixture(autouse=True)
-def _mock_forgot_password_session():
-with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls:
-mock_session = MagicMock()
-mock_session_cls.return_value.__enter__.return_value = mock_session
-mock_session_cls.return_value.__exit__.return_value = None
-yield mock_session
-@pytest.fixture(autouse=True)
-def _mock_forgot_password_db():
-with patch("controllers.console.auth.forgot_password.db") as mock_db:
-mock_db.engine = MagicMock()
-yield mock_db
class TestForgotPasswordSendEmailApi:
"""Test cases for sending password reset emails."""
@pytest.fixture
-def app(self):
-"""Create Flask test application."""
-app = Flask(__name__)
-app.config["TESTING"] = True
-return app
+def app(self, flask_app_with_containers):
+return flask_app_with_containers
@pytest.fixture
def mock_account(self):
@ -62,7 +36,6 @@ class TestForgotPasswordSendEmailApi:
account.name = "Test User"
return account
-@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
@ -73,20 +46,10 @@ class TestForgotPasswordSendEmailApi:
mock_send_email,
mock_get_account,
mock_is_ip_limit,
-mock_wraps_db,
app,
mock_account,
):
-"""
-Test successful password reset email sending.
-Verifies that:
-- Email is sent to valid account
-- Reset token is generated and returned
-- IP rate limiting is checked
-"""
# Arrange
-mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_account.return_value = mock_account
mock_send_email.return_value = "reset_token_123"
@ -104,9 +67,8 @@ class TestForgotPasswordSendEmailApi:
assert response["data"] == "reset_token_123"
mock_send_email.assert_called_once()
-@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
-def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, mock_db, app):
+def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app):
"""
Test password reset email blocked by IP rate limit.
@ -115,7 +77,6 @@ class TestForgotPasswordSendEmailApi:
- No email is sent when rate limited
"""
# Arrange
-mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = True
# Act & Assert
@ -133,7 +94,6 @@ class TestForgotPasswordSendEmailApi:
(None, "en-US"), # Defaults to en-US when not provided
],
)
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
@@ -144,7 +104,6 @@ class TestForgotPasswordSendEmailApi:
mock_send_email,
mock_get_account,
mock_is_ip_limit,
mock_wraps_db,
app,
mock_account,
language_input,
@@ -158,7 +117,6 @@ class TestForgotPasswordSendEmailApi:
- Unsupported languages default to en-US
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_account.return_value = mock_account
mock_send_email.return_value = "token"
@@ -180,13 +138,9 @@ class TestForgotPasswordCheckApi:
"""Test cases for verifying password reset codes."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
def app(self, flask_app_with_containers):
return flask_app_with_containers
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@@ -199,7 +153,6 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_get_data,
mock_is_rate_limit,
mock_db,
app,
):
"""
@@ -212,7 +165,6 @@ class TestForgotPasswordCheckApi:
- Rate limit is reset on success
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_generate_token.return_value = (None, "new_token")
@@ -236,7 +188,6 @@ class TestForgotPasswordCheckApi:
)
mock_reset_rate_limit.assert_called_once_with("test@example.com")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@@ -249,10 +200,8 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_get_data,
mock_is_rate_limit,
mock_db,
app,
):
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"}
mock_generate_token.return_value = (None, "fresh-token")
@@ -271,9 +220,8 @@ class TestForgotPasswordCheckApi:
mock_revoke_token.assert_called_once_with("upper_token")
mock_reset_rate_limit.assert_called_once_with("user@example.com")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app):
def test_verify_code_rate_limited(self, mock_is_rate_limit, app):
"""
Test code verification blocked by rate limit.
@@ -282,7 +230,6 @@ class TestForgotPasswordCheckApi:
- Prevents brute force attacks on verification codes
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = True
# Act & Assert
@@ -295,10 +242,9 @@ class TestForgotPasswordCheckApi:
with pytest.raises(EmailPasswordResetLimitError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, mock_db, app):
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app):
"""
Test code verification with invalid token.
@@ -306,7 +252,6 @@ class TestForgotPasswordCheckApi:
- InvalidTokenError is raised for invalid/expired tokens
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = None
@@ -320,10 +265,9 @@ class TestForgotPasswordCheckApi:
with pytest.raises(InvalidTokenError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, mock_db, app):
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app):
"""
Test code verification with mismatched email.
@@ -332,7 +276,6 @@ class TestForgotPasswordCheckApi:
- Prevents token abuse
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}
@@ -346,11 +289,10 @@ class TestForgotPasswordCheckApi:
with pytest.raises(InvalidEmailError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit")
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, mock_db, app):
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app):
"""
Test code verification with incorrect code.
@@ -359,7 +301,6 @@ class TestForgotPasswordCheckApi:
- Rate limit counter is incremented
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
@@ -380,11 +321,8 @@ class TestForgotPasswordResetApi:
"""Test cases for resetting password with verified token."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
def app(self, flask_app_with_containers):
return flask_app_with_containers
@pytest.fixture
def mock_account(self):
@@ -394,7 +332,6 @@ class TestForgotPasswordResetApi:
account.name = "Test User"
return account
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@@ -405,7 +342,6 @@ class TestForgotPasswordResetApi:
mock_get_account,
mock_revoke_token,
mock_get_data,
mock_wraps_db,
app,
mock_account,
):
@@ -418,7 +354,6 @@ class TestForgotPasswordResetApi:
- Success response is returned
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
mock_get_account.return_value = mock_account
mock_get_tenants.return_value = [MagicMock()]
@@ -436,9 +371,8 @@ class TestForgotPasswordResetApi:
assert response["result"] == "success"
mock_revoke_token.assert_called_once_with("valid_token")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_mismatch(self, mock_get_data, mock_db, app):
def test_reset_password_mismatch(self, mock_get_data, app):
"""
Test password reset with mismatched passwords.
@@ -447,7 +381,6 @@ class TestForgotPasswordResetApi:
- No password update occurs
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
# Act & Assert
@@ -460,9 +393,8 @@ class TestForgotPasswordResetApi:
with pytest.raises(PasswordMismatchError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_invalid_token(self, mock_get_data, mock_db, app):
def test_reset_password_invalid_token(self, mock_get_data, app):
"""
Test password reset with invalid token.
@@ -470,7 +402,6 @@ class TestForgotPasswordResetApi:
- InvalidTokenError is raised for invalid/expired tokens
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = None
# Act & Assert
@@ -483,9 +414,8 @@ class TestForgotPasswordResetApi:
with pytest.raises(InvalidTokenError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_wrong_phase(self, mock_get_data, mock_db, app):
def test_reset_password_wrong_phase(self, mock_get_data, app):
"""
Test password reset with token not in reset phase.
@@ -494,7 +424,6 @@ class TestForgotPasswordResetApi:
- Prevents use of verification-phase tokens for reset
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "phase": "verify"}
# Act & Assert
@@ -507,13 +436,10 @@ class TestForgotPasswordResetApi:
with pytest.raises(InvalidTokenError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
def test_reset_password_account_not_found(
self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app
):
def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app):
"""
Test password reset for non-existent account.
@@ -521,7 +447,6 @@ class TestForgotPasswordResetApi:
- AccountNotFound is raised when account doesn't exist
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"}
mock_get_account.return_value = None

@@ -0,0 +1,85 @@
"""Shared helpers for authenticated console controller integration tests."""
import uuid
from flask.testing import FlaskClient
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from constants import HEADER_NAME_CSRF_TOKEN
from libs.datetime_utils import naive_utc_now
from libs.token import _real_cookie_name, generate_csrf_token
from models import Account, DifySetup, Tenant, TenantAccountJoin
from models.account import AccountStatus, TenantAccountRole
from models.model import App, AppMode
from services.account_service import AccountService
def ensure_dify_setup(db_session: Session) -> None:
"""Create a setup marker once so setup-protected console routes can be exercised."""
if db_session.scalar(select(DifySetup).limit(1)) is not None:
return
db_session.add(DifySetup(version=dify_config.project.version))
db_session.commit()
def create_console_account_and_tenant(db_session: Session) -> tuple[Account, Tenant]:
"""Create an initialized owner account with a current tenant."""
account = Account(
email=f"test-{uuid.uuid4()}@example.com",
name="Test User",
interface_language="en-US",
status=AccountStatus.ACTIVE,
)
account.initialized_at = naive_utc_now()
db_session.add(account)
db_session.commit()
tenant = Tenant(name="Test Tenant", status="normal")
db_session.add(tenant)
db_session.commit()
db_session.add(
TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
)
db_session.commit()
account.set_tenant_id(tenant.id)
account.timezone = "UTC"
db_session.commit()
ensure_dify_setup(db_session)
return account, tenant
def create_console_app(db_session: Session, tenant_id: str, account_id: str, mode: AppMode) -> App:
"""Create a minimal app row that can be loaded by get_app_model."""
app = App(
tenant_id=tenant_id,
name="Test App",
mode=mode,
enable_site=True,
enable_api=True,
created_by=account_id,
)
db_session.add(app)
db_session.commit()
return app
def authenticate_console_client(test_client: FlaskClient, account: Account) -> dict[str, str]:
"""Attach console auth cookies/headers for endpoints guarded by login_required."""
access_token = AccountService.get_account_jwt_token(account)
csrf_token = generate_csrf_token(account.id)
test_client.set_cookie(_real_cookie_name("csrf_token"), csrf_token, domain="localhost")
return {
"Authorization": f"Bearer {access_token}",
HEADER_NAME_CSRF_TOKEN: csrf_token,
}
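The header contract of `authenticate_console_client` can be sketched in isolation. This is a minimal, hypothetical stand-in: the real helper builds the JWT via `AccountService.get_account_jwt_token` and sets the CSRF cookie on the test client, and the actual CSRF header name comes from `HEADER_NAME_CSRF_TOKEN` ("X-CSRF-Token" here is an assumption).

```python
def build_console_auth_headers(access_token: str, csrf_token: str,
                               csrf_header_name: str = "X-CSRF-Token") -> dict[str, str]:
    # Mirrors authenticate_console_client: bearer JWT in Authorization,
    # CSRF token echoed in its own header (the matching cookie is set
    # separately on the Flask test client).
    return {
        "Authorization": f"Bearer {access_token}",
        csrf_header_name: csrf_token,
    }

headers = build_console_auth_headers("jwt-abc", "csrf-xyz")
print(headers["Authorization"])  # Bearer jwt-abc
```

A test would pass the returned dict as `headers=` on each console request so both the bearer check and the cookie/header CSRF comparison succeed.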

@@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
from models.dataset import Dataset, Document
@@ -55,7 +56,7 @@ class TestGetAvailableDatasetsIntegration:
name=f"Document {i}",
created_from=DocumentCreatedFrom.WEB,
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
doc_language="en",
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
@@ -112,7 +113,7 @@ class TestGetAvailableDatasetsIntegration:
created_from=DocumentCreatedFrom.WEB,
name=f"Archived Document {i}",
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=True, # Archived
@@ -165,7 +166,7 @@ class TestGetAvailableDatasetsIntegration:
created_from=DocumentCreatedFrom.WEB,
name=f"Disabled Document {i}",
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
indexing_status=IndexingStatus.COMPLETED,
enabled=False, # Disabled
archived=False,
@@ -218,7 +219,7 @@ class TestGetAvailableDatasetsIntegration:
created_from=DocumentCreatedFrom.WEB,
name=f"Document {status}",
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
indexing_status=status, # Not completed
enabled=True,
archived=False,
@@ -336,7 +337,7 @@ class TestGetAvailableDatasetsIntegration:
created_from=DocumentCreatedFrom.WEB,
name=f"Document for {dataset.name}",
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
@@ -416,7 +417,7 @@ class TestGetAvailableDatasetsIntegration:
created_from=DocumentCreatedFrom.WEB,
name=f"Document {i}",
created_by=account.id,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
@@ -476,7 +477,7 @@ class TestKnowledgeRetrievalIntegration:
indexing_status=IndexingStatus.COMPLETED,
enabled=True,
archived=False,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db_session_with_containers.add(document)
db_session_with_containers.commit()

@@ -1,27 +0,0 @@
from __future__ import annotations
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository
from tests.test_containers_integration_tests.helpers.execution_extra_content import (
create_human_input_message_fixture,
)
def test_get_by_message_ids_returns_human_input_content(db_session_with_containers):
fixture = create_human_input_message_fixture(db_session_with_containers)
repository = SQLAlchemyExecutionExtraContentRepository(
session_maker=sessionmaker(bind=db.engine, expire_on_commit=False)
)
results = repository.get_by_message_ids([fixture.message.id])
assert len(results) == 1
assert len(results[0]) == 1
content = results[0][0]
assert content.submitted is True
assert content.form_submission_data is not None
assert content.form_submission_data.action_id == fixture.action_id
assert content.form_submission_data.action_text == fixture.action_text
assert content.form_submission_data.rendered_content == fixture.form.rendered_content

@@ -27,7 +27,7 @@ from models.human_input import (
HumanInputFormRecipient,
RecipientType,
)
from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.sqlalchemy_api_workflow_run_repository import (
DifyAPISQLAlchemyWorkflowRunRepository,
@@ -218,7 +218,7 @@ class TestDeleteRunsWithRelated:
app_id=test_scope.app_id,
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=test_scope.user_id,
)
@@ -278,7 +278,7 @@ class TestCountRunsWithRelated:
app_id=test_scope.app_id,
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_from=WorkflowAppLogCreatedFrom.SERVICE_API,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=test_scope.user_id,
)

@@ -0,0 +1,407 @@
"""Integration tests for SQLAlchemyExecutionExtraContentRepository using Testcontainers.
Part of #32454 — replaces the mock-based unit tests with real database interactions.
"""
from __future__ import annotations
from collections.abc import Generator
from dataclasses import dataclass
from datetime import datetime, timedelta
from decimal import Decimal
from uuid import uuid4
import pytest
from sqlalchemy import Engine, delete, select
from sqlalchemy.orm import Session, sessionmaker
from dify_graph.nodes.human_input.entities import FormDefinition, UserAction
from dify_graph.nodes.human_input.enums import HumanInputFormStatus
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import ConversationFromSource, InvokeFrom
from models.execution_extra_content import ExecutionExtraContent, HumanInputContent
from models.human_input import (
ConsoleRecipientPayload,
HumanInputDelivery,
HumanInputForm,
HumanInputFormRecipient,
RecipientType,
)
from models.model import App, Conversation, Message
from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository
@dataclass
class _TestScope:
"""Per-test data scope used to isolate DB rows.
IDs are populated after flushing the base entities to the database.
"""
tenant_id: str = ""
app_id: str = ""
user_id: str = ""
def _cleanup_scope_data(session: Session, scope: _TestScope) -> None:
"""Remove test-created DB rows for a test scope."""
form_ids_subquery = select(HumanInputForm.id).where(
HumanInputForm.tenant_id == scope.tenant_id,
)
session.execute(delete(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids_subquery)))
session.execute(delete(HumanInputDelivery).where(HumanInputDelivery.form_id.in_(form_ids_subquery)))
session.execute(
delete(ExecutionExtraContent).where(
ExecutionExtraContent.workflow_run_id.in_(
select(HumanInputForm.workflow_run_id).where(HumanInputForm.tenant_id == scope.tenant_id)
)
)
)
session.execute(delete(HumanInputForm).where(HumanInputForm.tenant_id == scope.tenant_id))
session.execute(delete(Message).where(Message.app_id == scope.app_id))
session.execute(delete(Conversation).where(Conversation.app_id == scope.app_id))
session.execute(delete(App).where(App.id == scope.app_id))
session.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == scope.tenant_id))
session.execute(delete(Account).where(Account.id == scope.user_id))
session.execute(delete(Tenant).where(Tenant.id == scope.tenant_id))
session.commit()
def _seed_base_entities(session: Session, scope: _TestScope) -> None:
"""Create the base tenant, account, and app needed by tests."""
tenant = Tenant(name="Test Tenant")
session.add(tenant)
session.flush()
scope.tenant_id = tenant.id
account = Account(
name="Test Account",
email=f"test_{uuid4()}@example.com",
password="hashed-password",
password_salt="salt",
interface_language="en-US",
timezone="UTC",
)
session.add(account)
session.flush()
scope.user_id = account.id
tenant_join = TenantAccountJoin(
tenant_id=scope.tenant_id,
account_id=scope.user_id,
role=TenantAccountRole.OWNER,
current=True,
)
session.add(tenant_join)
app = App(
tenant_id=scope.tenant_id,
name="Test App",
description="",
mode="chat",
icon_type="emoji",
icon="bot",
icon_background="#FFFFFF",
enable_site=False,
enable_api=True,
api_rpm=100,
api_rph=100,
is_demo=False,
is_public=False,
is_universal=False,
created_by=scope.user_id,
updated_by=scope.user_id,
)
session.add(app)
session.flush()
scope.app_id = app.id
def _create_conversation(session: Session, scope: _TestScope) -> Conversation:
conversation = Conversation(
app_id=scope.app_id,
mode="chat",
name="Test Conversation",
summary="",
introduction="",
system_instruction="",
status="normal",
invoke_from=InvokeFrom.EXPLORE,
from_source=ConversationFromSource.CONSOLE,
from_account_id=scope.user_id,
from_end_user_id=None,
)
conversation.inputs = {}
session.add(conversation)
session.flush()
return conversation
def _create_message(
session: Session,
scope: _TestScope,
conversation_id: str,
workflow_run_id: str,
) -> Message:
message = Message(
app_id=scope.app_id,
conversation_id=conversation_id,
inputs={},
query="test query",
message={"messages": []},
answer="test answer",
message_tokens=50,
message_unit_price=Decimal("0.001"),
answer_tokens=80,
answer_unit_price=Decimal("0.001"),
provider_response_latency=0.5,
currency="USD",
from_source=ConversationFromSource.CONSOLE,
from_account_id=scope.user_id,
workflow_run_id=workflow_run_id,
)
session.add(message)
session.flush()
return message
def _create_submitted_form(
session: Session,
scope: _TestScope,
*,
workflow_run_id: str,
action_id: str = "approve",
action_title: str = "Approve",
node_title: str = "Approval",
) -> HumanInputForm:
expiration_time = datetime.utcnow() + timedelta(days=1)
form_definition = FormDefinition(
form_content="content",
inputs=[],
user_actions=[UserAction(id=action_id, title=action_title)],
rendered_content="rendered",
expiration_time=expiration_time,
node_title=node_title,
display_in_ui=True,
)
form = HumanInputForm(
tenant_id=scope.tenant_id,
app_id=scope.app_id,
workflow_run_id=workflow_run_id,
node_id="node-id",
form_definition=form_definition.model_dump_json(),
rendered_content=f"Rendered {action_title}",
status=HumanInputFormStatus.SUBMITTED,
expiration_time=expiration_time,
selected_action_id=action_id,
)
session.add(form)
session.flush()
return form
def _create_waiting_form(
session: Session,
scope: _TestScope,
*,
workflow_run_id: str,
default_values: dict | None = None,
) -> HumanInputForm:
expiration_time = datetime.utcnow() + timedelta(days=1)
form_definition = FormDefinition(
form_content="content",
inputs=[],
user_actions=[UserAction(id="approve", title="Approve")],
rendered_content="rendered",
expiration_time=expiration_time,
default_values=default_values or {"name": "John"},
node_title="Approval",
display_in_ui=True,
)
form = HumanInputForm(
tenant_id=scope.tenant_id,
app_id=scope.app_id,
workflow_run_id=workflow_run_id,
node_id="node-id",
form_definition=form_definition.model_dump_json(),
rendered_content="Rendered block",
status=HumanInputFormStatus.WAITING,
expiration_time=expiration_time,
)
session.add(form)
session.flush()
return form
def _create_human_input_content(
session: Session,
*,
workflow_run_id: str,
message_id: str,
form_id: str,
) -> HumanInputContent:
content = HumanInputContent.new(
workflow_run_id=workflow_run_id,
message_id=message_id,
form_id=form_id,
)
session.add(content)
return content
def _create_recipient(
session: Session,
*,
form_id: str,
delivery_id: str,
recipient_type: RecipientType = RecipientType.CONSOLE,
access_token: str = "token-1",
) -> HumanInputFormRecipient:
payload = ConsoleRecipientPayload(account_id=None)
recipient = HumanInputFormRecipient(
form_id=form_id,
delivery_id=delivery_id,
recipient_type=recipient_type,
recipient_payload=payload.model_dump_json(),
access_token=access_token,
)
session.add(recipient)
return recipient
def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery:
from dify_graph.nodes.human_input.enums import DeliveryMethodType
from models.human_input import ConsoleDeliveryPayload
delivery = HumanInputDelivery(
form_id=form_id,
delivery_method_type=DeliveryMethodType.WEBAPP,
channel_payload=ConsoleDeliveryPayload().model_dump_json(),
)
session.add(delivery)
session.flush()
return delivery
@pytest.fixture
def repository(db_session_with_containers: Session) -> SQLAlchemyExecutionExtraContentRepository:
"""Build a repository backed by the testcontainers database engine."""
engine = db_session_with_containers.get_bind()
assert isinstance(engine, Engine)
return SQLAlchemyExecutionExtraContentRepository(sessionmaker(bind=engine, expire_on_commit=False))
@pytest.fixture
def test_scope(db_session_with_containers: Session) -> Generator[_TestScope]:
"""Provide an isolated scope and clean related data after each test."""
scope = _TestScope()
_seed_base_entities(db_session_with_containers, scope)
db_session_with_containers.commit()
yield scope
_cleanup_scope_data(db_session_with_containers, scope)
class TestGetByMessageIds:
"""Tests for SQLAlchemyExecutionExtraContentRepository.get_by_message_ids."""
def test_groups_contents_by_message(
self,
db_session_with_containers: Session,
repository: SQLAlchemyExecutionExtraContentRepository,
test_scope: _TestScope,
) -> None:
"""Submitted forms are correctly mapped and grouped by message ID."""
workflow_run_id = str(uuid4())
conversation = _create_conversation(db_session_with_containers, test_scope)
msg1 = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id)
msg2 = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id)
form = _create_submitted_form(
db_session_with_containers,
test_scope,
workflow_run_id=workflow_run_id,
action_id="approve",
action_title="Approve",
)
_create_human_input_content(
db_session_with_containers,
workflow_run_id=workflow_run_id,
message_id=msg1.id,
form_id=form.id,
)
db_session_with_containers.commit()
result = repository.get_by_message_ids([msg1.id, msg2.id])
assert len(result) == 2
# msg1 has one submitted content
assert len(result[0]) == 1
content = result[0][0]
assert content.submitted is True
assert content.workflow_run_id == workflow_run_id
assert content.form_submission_data is not None
assert content.form_submission_data.action_id == "approve"
assert content.form_submission_data.action_text == "Approve"
assert content.form_submission_data.rendered_content == "Rendered Approve"
assert content.form_submission_data.node_id == "node-id"
assert content.form_submission_data.node_title == "Approval"
# msg2 has no content
assert result[1] == []
def test_returns_unsubmitted_form_definition(
self,
db_session_with_containers: Session,
repository: SQLAlchemyExecutionExtraContentRepository,
test_scope: _TestScope,
) -> None:
"""Waiting forms return full form_definition with resolved token and defaults."""
workflow_run_id = str(uuid4())
conversation = _create_conversation(db_session_with_containers, test_scope)
msg = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id)
form = _create_waiting_form(
db_session_with_containers,
test_scope,
workflow_run_id=workflow_run_id,
default_values={"name": "John"},
)
delivery = _create_delivery(db_session_with_containers, form_id=form.id)
_create_recipient(
db_session_with_containers,
form_id=form.id,
delivery_id=delivery.id,
access_token="token-1",
)
_create_human_input_content(
db_session_with_containers,
workflow_run_id=workflow_run_id,
message_id=msg.id,
form_id=form.id,
)
db_session_with_containers.commit()
result = repository.get_by_message_ids([msg.id])
assert len(result) == 1
assert len(result[0]) == 1
domain_content = result[0][0]
assert domain_content.submitted is False
assert domain_content.workflow_run_id == workflow_run_id
assert domain_content.form_definition is not None
form_def = domain_content.form_definition
assert form_def.form_id == form.id
assert form_def.node_id == "node-id"
assert form_def.node_title == "Approval"
assert form_def.form_content == "Rendered block"
assert form_def.display_in_ui is True
assert form_def.form_token == "token-1"
assert form_def.resolved_default_values == {"name": "John"}
assert form_def.expiration_time == int(form.expiration_time.timestamp())
def test_empty_message_ids_returns_empty_list(
self,
repository: SQLAlchemyExecutionExtraContentRepository,
) -> None:
"""Passing no message IDs returns an empty list without hitting the DB."""
result = repository.get_by_message_ids([])
assert result == []
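The three tests above pin down a simple contract for `get_by_message_ids`: one result bucket per requested message ID, in input order, empty when a message has no extra-content rows. A pure-Python sketch of that grouping (hypothetical function and row shape, not the repository's actual query code):

```python
from collections import defaultdict

def group_by_message_ids(message_ids: list[str],
                         rows: list[tuple[str, str]]) -> list[list[str]]:
    # rows are (message_id, content) pairs; bucket them by message id,
    # then emit one bucket per requested id, preserving request order.
    buckets: dict[str, list[str]] = defaultdict(list)
    for message_id, content in rows:
        buckets[message_id].append(content)
    return [buckets.get(mid, []) for mid in message_ids]

rows = [("msg-1", "human-input-content")]
print(group_by_message_ids(["msg-1", "msg-2"], rows))
# [['human-input-content'], []]
```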

@@ -13,6 +13,7 @@ from uuid import uuid4
import pytest
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from models import Account
from models.dataset import Dataset, Document
@@ -91,7 +92,7 @@ class DocumentStatusTestDataFactory:
name=name,
created_from=DocumentCreatedFrom.WEB,
created_by=created_by,
doc_form="text_model",
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
document.id = document_id
document.indexing_status = indexing_status

@@ -525,3 +525,147 @@ class TestAPIBasedExtensionService:
# Try to get extension with wrong tenant ID
with pytest.raises(ValueError, match="API based extension is not found"):
APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id)
def test_save_extension_api_key_exactly_four_chars_rejected(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""API key with exactly 4 characters should be rejected (boundary)."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key="1234",
)
with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
APIBasedExtensionService.save(extension_data)
def test_save_extension_api_key_exactly_five_chars_accepted(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""API key with exactly 5 characters should be accepted (boundary)."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key="12345",
)
saved = APIBasedExtensionService.save(extension_data)
assert saved.id is not None
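The boundary pair above fixes the rule at length 5: four characters are rejected, five pass. A minimal, hypothetical sketch of that check (the real validation lives inside `APIBasedExtensionService.save`; the function name here is illustrative):

```python
def validate_api_key(api_key: str) -> None:
    # Reject keys shorter than 5 characters, matching the error text the
    # boundary tests assert against.
    if len(api_key) < 5:
        raise ValueError("api_key must be at least 5 characters")

validate_api_key("12345")  # 5 chars: accepted, no exception
try:
    validate_api_key("1234")  # 4 chars: rejected
except ValueError as exc:
    print(exc)  # api_key must be at least 5 characters
```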
def test_save_extension_requestor_constructor_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Exception raised by requestor constructor is wrapped in ValueError."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
mock_external_service_dependencies["requestor"].side_effect = RuntimeError("bad config")
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
with pytest.raises(ValueError, match="connection error: bad config"):
APIBasedExtensionService.save(extension_data)
def test_save_extension_network_exception(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Network exceptions during ping are wrapped in ValueError."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
mock_external_service_dependencies["requestor_instance"].request.side_effect = ConnectionError(
"network failure"
)
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
with pytest.raises(ValueError, match="connection error: network failure"):
APIBasedExtensionService.save(extension_data)
def test_save_extension_update_duplicate_name_rejected(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Updating an existing extension to use another extension's name should fail."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
ext1 = APIBasedExtensionService.save(
APIBasedExtension(
tenant_id=tenant.id,
name="Extension Alpha",
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
)
ext2 = APIBasedExtensionService.save(
APIBasedExtension(
tenant_id=tenant.id,
name="Extension Beta",
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
)
# Try to rename ext2 to ext1's name
ext2.name = "Extension Alpha"
with pytest.raises(ValueError, match="name must be unique, it is already existed"):
APIBasedExtensionService.save(ext2)
def test_get_all_returns_empty_for_different_tenant(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Extensions from one tenant should not be visible to another."""
fake = Faker()
_, tenant1 = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
_, tenant2 = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant1 is not None
APIBasedExtensionService.save(
APIBasedExtension(
tenant_id=tenant1.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
)
assert tenant2 is not None
result = APIBasedExtensionService.get_all_by_tenant_id(tenant2.id)
assert result == []
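The boundary tests above pin down the api_key length rule (4 characters rejected, 5 accepted). A minimal standalone sketch of that check — `validate_api_key` and `MIN_API_KEY_LENGTH` are hypothetical names for illustration, not the actual service internals, which additionally ping the endpoint and enforce per-tenant name uniqueness:

```python
# Hypothetical sketch of the api_key length rule exercised by the boundary
# tests; the real APIBasedExtensionService.save also pings the endpoint and
# checks that the extension name is unique within the tenant.
MIN_API_KEY_LENGTH = 5  # assumed constant: "1234" fails, "12345" passes


def validate_api_key(api_key: str) -> None:
    # Mirror the error message the tests match with pytest.raises.
    if len(api_key) < MIN_API_KEY_LENGTH:
        raise ValueError("api_key must be at least 5 characters")
```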



@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from constants.model_template import default_app_templates
from models import Account
-from models.model import App, Site
+from models.model import App, IconType, Site
from services.account_service import AccountService, TenantService
from tests.test_containers_integration_tests.helpers import generate_valid_password
@@ -463,6 +463,109 @@ class TestAppService:
assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by
def test_update_app_should_preserve_icon_type_when_omitted(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test update_app keeps the persisted icon_type when the update payload omits it.
"""
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(
tenant.id,
{
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🎯",
"icon_background": "#45B7D1",
},
account,
)
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app(
app,
{
"name": "Updated App Name",
"description": "Updated app description",
"icon_type": None,
"icon": "🔄",
"icon_background": "#FF8C42",
"use_icon_as_answer_icon": True,
},
)
assert updated_app.icon_type == IconType.EMOJI
def test_update_app_should_reject_empty_icon_type(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test update_app rejects an explicit empty icon_type.
"""
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(
tenant.id,
{
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🎯",
"icon_background": "#45B7D1",
},
account,
)
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
with pytest.raises(ValueError):
app_service.update_app(
app,
{
"name": "Updated App Name",
"description": "Updated app description",
"icon_type": "",
"icon": "🔄",
"icon_background": "#FF8C42",
"use_icon_as_answer_icon": True,
},
)
def test_update_app_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful app name update.


@@ -0,0 +1,80 @@
"""Testcontainers integration tests for AttachmentService."""
import base64
from datetime import UTC, datetime
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
import services.attachment_service as attachment_service_module
from extensions.ext_database import db
from extensions.storage.storage_type import StorageType
from models.enums import CreatorUserRole
from models.model import UploadFile
from services.attachment_service import AttachmentService
class TestAttachmentService:
def _create_upload_file(self, db_session_with_containers, *, tenant_id: str | None = None) -> UploadFile:
upload_file = UploadFile(
tenant_id=tenant_id or str(uuid4()),
storage_type=StorageType.OPENDAL,
key=f"upload/{uuid4()}.txt",
name="test-file.txt",
size=100,
extension="txt",
mime_type="text/plain",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
created_at=datetime.now(UTC),
used=False,
)
db_session_with_containers.add(upload_file)
db_session_with_containers.commit()
return upload_file
def test_should_initialize_with_sessionmaker(self):
session_factory = sessionmaker()
service = AttachmentService(session_factory=session_factory)
assert service._session_maker is session_factory
def test_should_initialize_with_engine(self):
engine = create_engine("sqlite:///:memory:")
service = AttachmentService(session_factory=engine)
session = service._session_maker()
try:
assert session.bind == engine
finally:
session.close()
engine.dispose()
@pytest.mark.parametrize("invalid_session_factory", [None, "not-a-session-factory", 1])
def test_should_raise_assertion_error_for_invalid_session_factory(self, invalid_session_factory):
with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."):
AttachmentService(session_factory=invalid_session_factory)
def test_should_return_base64_when_file_exists(self, db_session_with_containers):
upload_file = self._create_upload_file(db_session_with_containers)
service = AttachmentService(session_factory=sessionmaker(bind=db.engine))
with patch.object(attachment_service_module.storage, "load_once", return_value=b"binary-content") as mock_load:
result = service.get_file_base64(upload_file.id)
assert result == base64.b64encode(b"binary-content").decode()
mock_load.assert_called_once_with(upload_file.key)
def test_should_raise_not_found_when_file_missing(self, db_session_with_containers):
service = AttachmentService(session_factory=sessionmaker(bind=db.engine))
with patch.object(attachment_service_module.storage, "load_once") as mock_load:
with pytest.raises(NotFound, match="File not found"):
service.get_file_base64(str(uuid4()))
mock_load.assert_not_called()
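The two behaviors tested above — encode when the file exists, fail fast without touching storage when it does not — can be sketched without SQLAlchemy or the storage layer. Here a plain dict stands in for the database plus storage backend, and `LookupError` stands in for werkzeug's `NotFound`; both substitutions are assumptions for illustration:

```python
import base64


def get_file_base64(file_id: str, files: dict[str, bytes]) -> str:
    """Hypothetical sketch: look up the blob, then base64-encode it."""
    blob = files.get(file_id)
    if blob is None:
        # The real service raises werkzeug.exceptions.NotFound("File not found")
        # before ever calling storage.load_once.
        raise LookupError("File not found")
    return base64.b64encode(blob).decode()
```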


@@ -0,0 +1,58 @@
"""Testcontainers integration tests for ConversationVariableUpdater."""
from uuid import uuid4
import pytest
from sqlalchemy.orm import sessionmaker
from dify_graph.variables import StringVariable
from extensions.ext_database import db
from models.workflow import ConversationVariable
from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater
class TestConversationVariableUpdater:
def _create_conversation_variable(
self, db_session_with_containers, *, conversation_id: str, variable: StringVariable, app_id: str | None = None
) -> ConversationVariable:
row = ConversationVariable(
id=variable.id,
conversation_id=conversation_id,
app_id=app_id or str(uuid4()),
data=variable.model_dump_json(),
)
db_session_with_containers.add(row)
db_session_with_containers.commit()
return row
def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers):
conversation_id = str(uuid4())
variable = StringVariable(id=str(uuid4()), name="topic", value="old value")
self._create_conversation_variable(
db_session_with_containers, conversation_id=conversation_id, variable=variable
)
updated_variable = StringVariable(id=variable.id, name="topic", value="new value")
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
updater.update(conversation_id=conversation_id, variable=updated_variable)
db_session_with_containers.expire_all()
row = db_session_with_containers.get(ConversationVariable, (variable.id, conversation_id))
assert row is not None
assert row.data == updated_variable.model_dump_json()
def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers):
conversation_id = str(uuid4())
variable = StringVariable(id=str(uuid4()), name="topic", value="value")
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"):
updater.update(conversation_id=conversation_id, variable=variable)
def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers):
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
result = updater.flush()
assert result is None
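A dictionary-backed sketch of the update-or-raise contract asserted above; the real `ConversationVariableUpdater` keys rows by `(id, conversation_id)`, serializes the variable to JSON, and commits through SQLAlchemy, while `flush()` is deliberately a no-op. `update_variable` and the in-memory `store` are hypothetical stand-ins:

```python
class ConversationVariableNotFoundError(Exception):
    pass


def update_variable(
    store: dict[tuple[str, str], str],
    conversation_id: str,
    variable_id: str,
    data: str,
) -> None:
    """Overwrite the stored payload, or raise if the row does not exist."""
    key = (variable_id, conversation_id)
    if key not in store:
        # Mirror the message the test matches with pytest.raises.
        raise ConversationVariableNotFoundError(
            "conversation variable not found in the database"
        )
    store[key] = data
```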

Some files were not shown because too many files have changed in this diff.