Merge remote-tracking branch 'origin/main' into review-myscale-sqli

This commit is contained in:
-LAN- 2026-03-18 14:32:57 +08:00
commit f8ad55212e
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
1414 changed files with 95397 additions and 16940 deletions

View File

@ -187,53 +187,12 @@ const Template = useMemo(() => {
**When**: Component directly handles API calls, data transformation, or complex async operations.
**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks.
```typescript
// ❌ Before: API logic in component
const MCPServiceCard = () => {
const [basicAppConfig, setBasicAppConfig] = useState({})
useEffect(() => {
if (isBasicApp && appId) {
(async () => {
const res = await fetchAppDetail({ url: '/apps', id: appId })
setBasicAppConfig(res?.model_config || {})
})()
}
}, [appId, isBasicApp])
// More API-related logic...
}
// ✅ After: Extract to data hook using React Query
// use-app-config.ts
import { useQuery } from '@tanstack/react-query'
import { get } from '@/service/base'
const NAME_SPACE = 'appConfig'
export const useAppConfig = (appId: string, isBasicApp: boolean) => {
return useQuery({
enabled: isBasicApp && !!appId,
queryKey: [NAME_SPACE, 'detail', appId],
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
select: data => data?.model_config || {},
})
}
// Component becomes cleaner
const MCPServiceCard = () => {
const { data: config, isLoading } = useAppConfig(appId, isBasicApp)
// UI only
}
```
**React Query Best Practices in Dify**:
- Define `NAME_SPACE` for query key organization
- Use `enabled` option for conditional fetching
- Use `select` for data transformation
- Export invalidation hooks: `useInvalidXxx`
**Dify Convention**:
- This skill is for component decomposition, not query/mutation design.
- When refactoring data fetching, follow `web/AGENTS.md`.
- Use `frontend-query-mutation` for contracts, query shape, data-fetching wrappers, query/mutation call-site patterns, conditional queries, invalidation, and mutation error handling.
- Do not introduce deprecated `useInvalid` / `useReset`.
- Do not add thin passthrough `useQuery` wrappers during refactoring; only extract a custom hook when it truly orchestrates multiple queries/mutations or shared derived state.
**Dify Examples**:
- `web/service/use-workflow.ts`

View File

@ -155,48 +155,14 @@ const Configuration: FC = () => {
## Common Hook Patterns in Dify
### 1. Data Fetching Hook (React Query)
### 1. Data Fetching / Mutation Hooks
```typescript
// Pattern: Use @tanstack/react-query for data fetching
import { useQuery, useQueryClient } from '@tanstack/react-query'
import { get } from '@/service/base'
import { useInvalid } from '@/service/use-base'
When hook extraction touches query or mutation code, do not use this reference as the source of truth for data-layer patterns.
const NAME_SPACE = 'appConfig'
// Query keys for cache management
export const appConfigQueryKeys = {
detail: (appId: string) => [NAME_SPACE, 'detail', appId] as const,
}
// Main data hook
export const useAppConfig = (appId: string) => {
return useQuery({
enabled: !!appId,
queryKey: appConfigQueryKeys.detail(appId),
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
select: data => data?.model_config || null,
})
}
// Invalidation hook for refreshing data
export const useInvalidAppConfig = () => {
return useInvalid([NAME_SPACE])
}
// Usage in component
const Component = () => {
const { data: config, isLoading, error, refetch } = useAppConfig(appId)
const invalidAppConfig = useInvalidAppConfig()
const handleRefresh = () => {
invalidAppConfig() // Invalidates cache and triggers refetch
}
return <div>...</div>
}
```
- Follow `web/AGENTS.md` first.
- Use `frontend-query-mutation` for contracts, query shape, data-fetching wrappers, query/mutation call-site patterns, conditional queries, invalidation, and mutation error handling.
- Do not introduce deprecated `useInvalid` / `useReset`.
- Do not extract thin passthrough `useQuery` hooks; only extract orchestration hooks.
### 2. Form State Hook

View File

@ -0,0 +1,44 @@
---
name: frontend-query-mutation
description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions() directly or extract a helper or use-* hook, handling conditional queries, cache invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers.
---
# Frontend Query & Mutation
## Intent
- Keep contract as the single source of truth in `web/contract/*`.
- Prefer contract-shaped `queryOptions()` and `mutationOptions()`.
- Keep invalidation and mutation flow knowledge in the service layer.
- Keep abstractions minimal to preserve TypeScript inference.
## Workflow
1. Identify the change surface.
- Read `references/contract-patterns.md` for contract files, router composition, client helpers, and query or mutation call-site shape.
- Read `references/runtime-rules.md` for conditional queries, invalidation, error handling, and legacy migrations.
- Read both references when a task spans contract shape and runtime behavior.
2. Implement the smallest abstraction that fits the task.
- Default to direct `useQuery(...)` or `useMutation(...)` calls with oRPC helpers at the call site.
- Extract a small shared query helper only when multiple call sites share the same extra options.
- Create `web/service/use-{domain}.ts` only for orchestration or shared domain behavior.
3. Preserve Dify conventions.
- Keep contract inputs in `{ params, query?, body? }` shape.
- Bind invalidation in the service-layer mutation definition.
- Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required.
## Files Commonly Touched
- `web/contract/console/*.ts`
- `web/contract/marketplace.ts`
- `web/contract/router.ts`
- `web/service/client.ts`
- `web/service/use-*.ts`
- component and hook call sites using `consoleQuery` or `marketplaceQuery`
## References
- Use `references/contract-patterns.md` for contract shape, router registration, query and mutation helpers, and anti-patterns that degrade inference.
- Use `references/runtime-rules.md` for conditional queries, invalidation, `mutate` versus `mutateAsync`, and legacy migration rules.
Treat this skill as the single query and mutation entry point for Dify frontend work. Keep detailed rules in the reference files instead of duplicating them in project docs.

View File

@ -0,0 +1,4 @@
interface:
display_name: "Frontend Query & Mutation"
short_description: "Dify TanStack Query and oRPC patterns"
default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, conditional queries, invalidation, or legacy query/mutation migrations."

View File

@ -0,0 +1,98 @@
# Contract Patterns
## Table of Contents
- Intent
- Minimal structure
- Core workflow
- Query usage decision rule
- Mutation usage decision rule
- Anti-patterns
- Contract rules
- Type export
## Intent
- Keep contract as the single source of truth in `web/contract/*`.
- Default query usage to call-site `useQuery(consoleQuery|marketplaceQuery.xxx.queryOptions(...))` when endpoint behavior maps 1:1 to the contract.
- Keep abstractions minimal and preserve TypeScript inference.
## Minimal Structure
```text
web/contract/
├── base.ts
├── router.ts
├── marketplace.ts
└── console/
├── billing.ts
└── ...other domains
web/service/client.ts
```
## Core Workflow
1. Define contract in `web/contract/console/{domain}.ts` or `web/contract/marketplace.ts`.
- Use `base.route({...}).output(type<...>())` as the baseline.
- Add `.input(type<...>())` only when the request has `params`, `query`, or `body`.
- For `GET` without input, omit `.input(...)`; do not use `.input(type<unknown>())`.
2. Register contract in `web/contract/router.ts`.
- Import directly from domain files and nest by API prefix.
3. Consume from UI call sites via oRPC query utilities.
```typescript
import { useQuery } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client'
const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
staleTime: 5 * 60 * 1000,
throwOnError: true,
select: invoice => invoice.url,
}))
```
## Query Usage Decision Rule
1. Default to direct `*.queryOptions(...)` usage at the call site.
2. If 3 or more call sites share the same extra options, extract a small query helper, not a `use-*` passthrough hook.
3. Create `web/service/use-{domain}.ts` only for orchestration.
- Combine multiple queries or mutations.
- Share domain-level derived state or invalidation helpers.
```typescript
const invoicesBaseQueryOptions = () =>
consoleQuery.billing.invoices.queryOptions({ retry: false })
const invoiceQuery = useQuery({
...invoicesBaseQueryOptions(),
throwOnError: true,
})
```
## Mutation Usage Decision Rule
1. Default to mutation helpers from `consoleQuery` or `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
2. If the mutation flow is heavily custom, use oRPC clients as `mutationFn`, for example `consoleClient.xxx` or `marketplaceClient.xxx`, instead of handwritten non-oRPC mutation logic.
## Anti-Patterns
- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>`.
- Do not split local `queryKey` and `queryFn` when oRPC `queryOptions` already exists and fits the use case.
- Do not create thin `use-*` passthrough hooks for a single endpoint.
- These patterns can degrade inference, especially around `throwOnError` and `select`, and add unnecessary indirection.
## Contract Rules
- Input structure: always use `{ params, query?, body? }`.
- No-input `GET`: omit `.input(...)`; do not use `.input(type<unknown>())`.
- Path params: use `{paramName}` in the path and match it in the `params` object.
- Router nesting: group by API prefix, for example `/billing/*` becomes `billing: {}`.
- No barrel files: import directly from specific files.
- Types: import from `@/types/` and use the `type<T>()` helper.
- Mutations: prefer `mutationOptions`; use explicit `mutationKey` mainly for defaults, filtering, and devtools.
## Type Export
```typescript
export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract>
```

View File

@ -0,0 +1,133 @@
# Runtime Rules
## Table of Contents
- Conditional queries
- Cache invalidation
- Key API guide
- `mutate` vs `mutateAsync`
- Legacy migration
## Conditional Queries
Prefer contract-shaped `queryOptions(...)`.
When required input is missing, prefer `input: skipToken` instead of placeholder params or non-null assertions.
Use `enabled` only for extra business gating after the input itself is already valid.
```typescript
import { skipToken, useQuery } from '@tanstack/react-query'
// Disable the query by skipping input construction.
function useAccessMode(appId: string | undefined) {
return useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({
input: appId
? { params: { appId } }
: skipToken,
}))
}
// Avoid runtime-only guards that bypass type checking.
function useBadAccessMode(appId: string | undefined) {
return useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({
input: { params: { appId: appId! } },
enabled: !!appId,
}))
}
```
## Cache Invalidation
Bind invalidation in the service-layer mutation definition.
Components may add UI feedback in call-site callbacks, but they should not decide which queries to invalidate.
Use:
- `.key()` for namespace or prefix invalidation
- `.queryKey(...)` only for exact cache reads or writes such as `getQueryData` and `setQueryData`
- `queryClient.invalidateQueries(...)` in mutation `onSuccess`
Do not use deprecated `useInvalid` from `use-base.ts`.
```typescript
// Service layer owns cache invalidation.
export const useUpdateAccessMode = () => {
const queryClient = useQueryClient()
return useMutation(consoleQuery.accessControl.updateAccessMode.mutationOptions({
onSuccess: () => {
queryClient.invalidateQueries({
queryKey: consoleQuery.accessControl.appWhitelistSubjects.key(),
})
},
}))
}
// Component only adds UI behavior.
updateAccessMode({ appId, mode }, {
onSuccess: () => Toast.notify({ type: 'success', message: '...' }),
})
// Avoid putting invalidation knowledge in the component.
mutate({ appId, mode }, {
onSuccess: () => {
queryClient.invalidateQueries({
queryKey: consoleQuery.accessControl.appWhitelistSubjects.key(),
})
},
})
```
## Key API Guide
- `.key(...)`
- Use for partial matching operations.
- Prefer it for invalidation, refetch, and cancel patterns.
- Example: `queryClient.invalidateQueries({ queryKey: consoleQuery.billing.key() })`
- `.queryKey(...)`
- Use for a specific query's full key.
- Prefer it for exact cache addressing and direct reads or writes.
- `.mutationKey(...)`
- Use for a specific mutation's full key.
- Prefer it for mutation defaults registration, mutation-status filtering, and devtools grouping.
## `mutate` vs `mutateAsync`
Prefer `mutate` by default.
Use `mutateAsync` only when Promise semantics are truly required, such as parallel mutations or sequential steps with result dependencies.
Rules:
- Event handlers should usually call `mutate(...)` with `onSuccess` or `onError`.
- Every `await mutateAsync(...)` must be wrapped in `try/catch`.
- Do not use `mutateAsync` when callbacks already express the flow clearly.
```typescript
// Default case.
mutation.mutate(data, {
onSuccess: result => router.push(result.url),
})
// Promise semantics are required.
try {
const order = await createOrder.mutateAsync(orderData)
await confirmPayment.mutateAsync({ orderId: order.id, token })
router.push(`/orders/${order.id}`)
}
catch (error) {
Toast.notify({
type: 'error',
message: error instanceof Error ? error.message : 'Unknown error',
})
}
```
## Legacy Migration
When touching old code, migrate it toward these rules:
| Old pattern | New pattern |
|---|---|
| `useInvalid(key)` in service layer | `queryClient.invalidateQueries(...)` inside mutation `onSuccess` |
| component-triggered invalidation after mutation | move invalidation into the service-layer mutation definition |
| imperative fetch plus manual invalidation | wrap it in `useMutation(...mutationOptions(...))` |
| `await mutateAsync()` without `try/catch` | switch to `mutate(...)` or add `try/catch` |

View File

@ -63,7 +63,8 @@ pnpm analyze-component <path> --review
### File Naming
- Test files: `ComponentName.spec.tsx` (same directory as component)
- Test files: `ComponentName.spec.tsx` inside a same-level `__tests__/` directory
- Placement rule: Component, hook, and utility tests must live in a sibling `__tests__/` folder at the same level as the source under test. For example, `foo/index.tsx` maps to `foo/__tests__/index.spec.tsx`, and `foo/bar.ts` maps to `foo/__tests__/bar.spec.ts`.
- Integration tests: `web/__tests__/` directory
## Test Structure Template

View File

@ -41,7 +41,7 @@ import userEvent from '@testing-library/user-event'
// Router (if component uses useRouter, usePathname, useSearchParams)
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
// const mockPush = vi.fn()
// vi.mock('next/navigation', () => ({
// vi.mock('@/next/navigation', () => ({
// useRouter: () => ({ push: mockPush }),
// usePathname: () => '/test-path',
// }))

View File

@ -1,103 +0,0 @@
---
name: orpc-contract-first
description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Trigger when creating or updating contracts in web/contract, wiring router composition, integrating TanStack Query with typed contracts, migrating legacy service calls to oRPC, or deciding whether to call queryOptions directly vs extracting a helper or use-* hook in web/service.
---
# oRPC Contract-First Development
## Intent
- Keep contract as single source of truth in `web/contract/*`.
- Default query usage: call-site `useQuery(consoleQuery|marketplaceQuery.xxx.queryOptions(...))` when endpoint behavior maps 1:1 to the contract.
- Keep abstractions minimal and preserve TypeScript inference.
## Minimal Structure
```text
web/contract/
├── base.ts
├── router.ts
├── marketplace.ts
└── console/
├── billing.ts
└── ...other domains
web/service/client.ts
```
## Core Workflow
1. Define contract in `web/contract/console/{domain}.ts` or `web/contract/marketplace.ts`
- Use `base.route({...}).output(type<...>())` as baseline.
- Add `.input(type<...>())` only when request has `params/query/body`.
- For `GET` without input, omit `.input(...)` (do not use `.input(type<unknown>())`).
2. Register contract in `web/contract/router.ts`
- Import directly from domain files and nest by API prefix.
3. Consume from UI call sites via oRPC query utils.
```typescript
import { useQuery } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client'
const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
staleTime: 5 * 60 * 1000,
throwOnError: true,
select: invoice => invoice.url,
}))
```
## Query Usage Decision Rule
1. Default: call site directly uses `*.queryOptions(...)`.
2. If 3+ call sites share the same extra options (for example `retry: false`), extract a small queryOptions helper, not a `use-*` passthrough hook.
3. Create `web/service/use-{domain}.ts` only for orchestration:
- Combine multiple queries/mutations.
- Share domain-level derived state or invalidation helpers.
```typescript
const invoicesBaseQueryOptions = () =>
consoleQuery.billing.invoices.queryOptions({ retry: false })
const invoiceQuery = useQuery({
...invoicesBaseQueryOptions(),
throwOnError: true,
})
```
## Mutation Usage Decision Rule
1. Default: call mutation helpers from `consoleQuery` / `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
2. If mutation flow is heavily custom, use oRPC clients as `mutationFn` (for example `consoleClient.xxx` / `marketplaceClient.xxx`), instead of generic handwritten non-oRPC mutation logic.
## Key API Guide (`.key` vs `.queryKey` vs `.mutationKey`)
- `.key(...)`:
- Use for partial matching operations (recommended for invalidation/refetch/cancel patterns).
- Example: `queryClient.invalidateQueries({ queryKey: consoleQuery.billing.key() })`
- `.queryKey(...)`:
- Use for a specific query's full key (exact query identity / direct cache addressing).
- `.mutationKey(...)`:
- Use for a specific mutation's full key.
- Typical use cases: mutation defaults registration, mutation-status filtering (`useIsMutating`, `queryClient.isMutating`), or explicit devtools grouping.
## Anti-Patterns
- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>`.
- Do not split local `queryKey/queryFn` when oRPC `queryOptions` already exists and fits the use case.
- Do not create thin `use-*` passthrough hooks for a single endpoint.
- Reason: these patterns can degrade inference (`data` may become `unknown`, especially around `throwOnError`/`select`) and add unnecessary indirection.
## Contract Rules
- **Input structure**: Always use `{ params, query?, body? }` format
- **No-input GET**: Omit `.input(...)`; do not use `.input(type<unknown>())`
- **Path params**: Use `{paramName}` in path, match in `params` object
- **Router nesting**: Group by API prefix (e.g., `/billing/*` -> `billing: {}`)
- **No barrel files**: Import directly from specific files
- **Types**: Import from `@/types/`, use `type<T>()` helper
- **Mutations**: Prefer `mutationOptions`; use explicit `mutationKey` mainly for defaults/filtering/devtools
## Type Export
```typescript
export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract>
```

View File

@ -0,0 +1 @@
../../.agents/skills/frontend-query-mutation

View File

@ -1 +0,0 @@
../../.agents/skills/orpc-contract-first

View File

@ -27,7 +27,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}

View File

@ -39,7 +39,7 @@ jobs:
with:
python-version: "3.11"
- uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
- uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
- name: Generate Docker Compose
if: steps.docker-compose-changes.outputs.any_changed == 'true'

View File

@ -113,7 +113,7 @@ jobs:
context: "web"
steps:
- name: Download digests
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
with:
path: /tmp/digests
pattern: digests-${{ matrix.context }}-*

View File

@ -19,7 +19,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
with:
enable-cache: true
python-version: "3.12"
@ -69,7 +69,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
with:
enable-cache: true
python-version: "3.12"

View File

@ -28,7 +28,7 @@ jobs:
migration-changed: ${{ steps.changes.outputs.migration }}
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
- uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
id: changes
with:
filters: |
@ -63,8 +63,9 @@ jobs:
if: needs.check-changes.outputs.web-changed == 'true'
uses: ./.github/workflows/web-tests.yml
with:
base_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.base.sha || github.event.before }}
head_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
base_sha: ${{ github.event.before || github.event.pull_request.base.sha }}
diff_range_mode: ${{ github.event.before && 'exact' || 'merge-base' }}
head_sha: ${{ github.event.after || github.event.pull_request.head.sha || github.sha }}
style-check:
name: Style Check

View File

@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
with:
enable-cache: true

View File

@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
with:
enable-cache: false
python-version: "3.12"

View File

@ -120,7 +120,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.detect_changes.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@26ec041249acb0a944c0a47b6c0c13f05dbc5b44 # v1.0.70
uses: anthropics/claude-code-action@cd77b50d2b0808657f8e6774085c8bf54484351c # v1.0.72
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -31,7 +31,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}

View File

@ -6,6 +6,9 @@ on:
base_sha:
required: false
type: string
diff_range_mode:
required: false
type: string
head_sha:
required: false
type: string
@ -26,8 +29,8 @@ jobs:
strategy:
fail-fast: false
matrix:
shardIndex: [1, 2, 3, 4]
shardTotal: [4]
shardIndex: [1, 2, 3, 4, 5, 6]
shardTotal: [6]
defaults:
run:
shell: bash
@ -77,7 +80,7 @@ jobs:
uses: ./.github/actions/setup-web
- name: Download blob reports
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
with:
path: web/.vitest-reports
pattern: blob-report-*
@ -86,13 +89,24 @@ jobs:
- name: Merge reports
run: vp test --merge-reports --reporter=json --reporter=agent --coverage
- name: Check app/components diff coverage
- name: Report app/components baseline coverage
run: node ./scripts/report-components-coverage-baseline.mjs
- name: Report app/components test touch
env:
BASE_SHA: ${{ inputs.base_sha }}
DIFF_RANGE_MODE: ${{ inputs.diff_range_mode }}
HEAD_SHA: ${{ inputs.head_sha }}
run: node ./scripts/report-components-test-touch.mjs
- name: Check app/components pure diff coverage
env:
BASE_SHA: ${{ inputs.base_sha }}
DIFF_RANGE_MODE: ${{ inputs.diff_range_mode }}
HEAD_SHA: ${{ inputs.head_sha }}
run: node ./scripts/check-components-diff-coverage.mjs
- name: Coverage Summary
- name: Check Coverage Summary
if: always()
id: coverage-summary
run: |
@ -101,313 +115,15 @@ jobs:
COVERAGE_FILE="coverage/coverage-final.json"
COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
echo "Coverage data not found. Ensure Vitest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
if [ -f "$COVERAGE_FILE" ] || [ -f "$COVERAGE_SUMMARY_FILE" ]; then
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
exit 0
fi
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
node <<'NODE' >> "$GITHUB_STEP_SUMMARY"
const fs = require('fs');
const path = require('path');
let libCoverage = null;
try {
libCoverage = require('istanbul-lib-coverage');
} catch (error) {
libCoverage = null;
}
const summaryPath = path.join('coverage', 'coverage-summary.json');
const finalPath = path.join('coverage', 'coverage-final.json');
const hasSummary = fs.existsSync(summaryPath);
const hasFinal = fs.existsSync(finalPath);
if (!hasSummary && !hasFinal) {
console.log('### Test Coverage Summary :test_tube:');
console.log('');
console.log('No coverage data found.');
process.exit(0);
}
const summary = hasSummary
? JSON.parse(fs.readFileSync(summaryPath, 'utf8'))
: null;
const coverage = hasFinal
? JSON.parse(fs.readFileSync(finalPath, 'utf8'))
: null;
const getLineCoverageFromStatements = (statementMap, statementHits) => {
const lineHits = {};
if (!statementMap || !statementHits) {
return lineHits;
}
Object.entries(statementMap).forEach(([key, statement]) => {
const line = statement?.start?.line;
if (!line) {
return;
}
const hits = statementHits[key] ?? 0;
const previous = lineHits[line];
lineHits[line] = previous === undefined ? hits : Math.max(previous, hits);
});
return lineHits;
};
const getFileCoverage = (entry) => (
libCoverage ? libCoverage.createFileCoverage(entry) : null
);
const getLineHits = (entry, fileCoverage) => {
const lineHits = entry.l ?? {};
if (Object.keys(lineHits).length > 0) {
return lineHits;
}
if (fileCoverage) {
return fileCoverage.getLineCoverage();
}
return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {});
};
const getUncoveredLines = (entry, fileCoverage, lineHits) => {
if (lineHits && Object.keys(lineHits).length > 0) {
return Object.entries(lineHits)
.filter(([, count]) => count === 0)
.map(([line]) => Number(line))
.sort((a, b) => a - b);
}
if (fileCoverage) {
return fileCoverage.getUncoveredLines();
}
return [];
};
const totals = {
lines: { covered: 0, total: 0 },
statements: { covered: 0, total: 0 },
branches: { covered: 0, total: 0 },
functions: { covered: 0, total: 0 },
};
const fileSummaries = [];
if (summary) {
const totalEntry = summary.total ?? {};
['lines', 'statements', 'branches', 'functions'].forEach((key) => {
if (totalEntry[key]) {
totals[key].covered = totalEntry[key].covered ?? 0;
totals[key].total = totalEntry[key].total ?? 0;
}
});
Object.entries(summary)
.filter(([file]) => file !== 'total')
.forEach(([file, data]) => {
fileSummaries.push({
file,
pct: data.lines?.pct ?? data.statements?.pct ?? 0,
lines: {
covered: data.lines?.covered ?? 0,
total: data.lines?.total ?? 0,
},
});
});
} else if (coverage) {
Object.entries(coverage).forEach(([file, entry]) => {
const fileCoverage = getFileCoverage(entry);
const lineHits = getLineHits(entry, fileCoverage);
const statementHits = entry.s ?? {};
const branchHits = entry.b ?? {};
const functionHits = entry.f ?? {};
const lineTotal = Object.keys(lineHits).length;
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
const statementTotal = Object.keys(statementHits).length;
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
const branchCovered = Object.values(branchHits).reduce(
(acc, branches) => acc + branches.filter((n) => n > 0).length,
0,
);
const functionTotal = Object.keys(functionHits).length;
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
totals.lines.total += lineTotal;
totals.lines.covered += lineCovered;
totals.statements.total += statementTotal;
totals.statements.covered += statementCovered;
totals.branches.total += branchTotal;
totals.branches.covered += branchCovered;
totals.functions.total += functionTotal;
totals.functions.covered += functionCovered;
const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0);
fileSummaries.push({
file,
pct: pct(lineCovered || statementCovered, lineTotal || statementTotal),
lines: {
covered: lineCovered || statementCovered,
total: lineTotal || statementTotal,
},
});
});
}
const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00');
console.log('### Test Coverage Summary :test_tube:');
console.log('');
console.log('| Metric | Coverage | Covered / Total |');
console.log('|--------|----------|-----------------|');
console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`);
console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`);
console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`);
console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`);
console.log('');
console.log('<details><summary>File coverage (lowest lines first)</summary>');
console.log('');
console.log('```');
fileSummaries
.sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total))
.slice(0, 25)
.forEach(({ file, pct, lines }) => {
console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`);
});
console.log('```');
console.log('</details>');
if (coverage) {
const pctValue = (covered, tot) => {
if (tot === 0) {
return '0';
}
return ((covered / tot) * 100)
.toFixed(2)
.replace(/\.?0+$/, '');
};
const formatLineRanges = (lines) => {
if (lines.length === 0) {
return '';
}
const ranges = [];
let start = lines[0];
let end = lines[0];
for (let i = 1; i < lines.length; i += 1) {
const current = lines[i];
if (current === end + 1) {
end = current;
continue;
}
ranges.push(start === end ? `${start}` : `${start}-${end}`);
start = current;
end = current;
}
ranges.push(start === end ? `${start}` : `${start}-${end}`);
return ranges.join(',');
};
const tableTotals = {
statements: { covered: 0, total: 0 },
branches: { covered: 0, total: 0 },
functions: { covered: 0, total: 0 },
lines: { covered: 0, total: 0 },
};
const tableRows = Object.entries(coverage)
.map(([file, entry]) => {
const fileCoverage = getFileCoverage(entry);
const lineHits = getLineHits(entry, fileCoverage);
const statementHits = entry.s ?? {};
const branchHits = entry.b ?? {};
const functionHits = entry.f ?? {};
const lineTotal = Object.keys(lineHits).length;
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
const statementTotal = Object.keys(statementHits).length;
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
const branchCovered = Object.values(branchHits).reduce(
(acc, branches) => acc + branches.filter((n) => n > 0).length,
0,
);
const functionTotal = Object.keys(functionHits).length;
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
tableTotals.lines.total += lineTotal;
tableTotals.lines.covered += lineCovered;
tableTotals.statements.total += statementTotal;
tableTotals.statements.covered += statementCovered;
tableTotals.branches.total += branchTotal;
tableTotals.branches.covered += branchCovered;
tableTotals.functions.total += functionTotal;
tableTotals.functions.covered += functionCovered;
const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits);
const filePath = entry.path ?? file;
const relativePath = path.isAbsolute(filePath)
? path.relative(process.cwd(), filePath)
: filePath;
return {
file: relativePath || file,
statements: pctValue(statementCovered, statementTotal),
branches: pctValue(branchCovered, branchTotal),
functions: pctValue(functionCovered, functionTotal),
lines: pctValue(lineCovered, lineTotal),
uncovered: formatLineRanges(uncoveredLines),
};
})
.sort((a, b) => a.file.localeCompare(b.file));
const columns = [
{ key: 'file', header: 'File', align: 'left' },
{ key: 'statements', header: '% Stmts', align: 'right' },
{ key: 'branches', header: '% Branch', align: 'right' },
{ key: 'functions', header: '% Funcs', align: 'right' },
{ key: 'lines', header: '% Lines', align: 'right' },
{ key: 'uncovered', header: 'Uncovered Line #s', align: 'left' },
];
const allFilesRow = {
file: 'All files',
statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total),
branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total),
functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total),
lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total),
uncovered: '',
};
const rowsForOutput = [allFilesRow, ...tableRows];
const formatRow = (row) => `| ${columns
.map(({ key }) => String(row[key] ?? ''))
.join(' | ')} |`;
const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`;
const dividerRow = `| ${columns
.map(({ align }) => (align === 'right' ? '---:' : ':---'))
.join(' | ')} |`;
console.log('');
console.log('<details><summary>Vitest coverage table</summary>');
console.log('');
console.log(headerRow);
console.log(dividerRow);
rowsForOutput.forEach((row) => console.log(formatRow(row)));
console.log('</details>');
}
NODE
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
echo "### 🚨 app/components Diff Coverage" >> "$GITHUB_STEP_SUMMARY"
echo "" >> "$GITHUB_STEP_SUMMARY"
echo "Coverage artifacts not found. Ensure Vitest merge reports ran with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
- name: Upload Coverage Artifact
if: steps.coverage-summary.outputs.has_coverage == 'true'

3
.gitignore vendored
View File

@ -237,3 +237,6 @@ scripts/stress-test/reports/
# settings
*.local.json
*.local.md
# Code Agent Folder
.qoder/*

View File

@ -22,10 +22,10 @@ APP_WEB_URL=http://localhost:3000
# Files URL
FILES_URL=http://localhost:5001
# INTERNAL_FILES_URL is used for plugin daemon communication within Docker network.
# Set this to the internal Docker service URL for proper plugin file access.
# Example: INTERNAL_FILES_URL=http://api:5001
INTERNAL_FILES_URL=http://127.0.0.1:5001
# INTERNAL_FILES_URL is used by services running in Docker to reach the API file endpoints.
# For Docker Desktop (Mac/Windows), use http://host.docker.internal:5001 when the API runs on the host.
# For Docker Compose on Linux, use http://api:5001 when the API runs inside the Docker network.
INTERNAL_FILES_URL=http://host.docker.internal:5001
# TRIGGER URL
TRIGGER_URL=http://localhost:5001
@ -180,7 +180,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
COOKIE_DOMAIN=
# Vector database configuration
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `hologres`.
VECTOR_STORE=weaviate
# Prefix used to create collection name in vector database
VECTOR_INDEX_NAME_PREFIX=Vector_index
@ -217,6 +217,20 @@ COUCHBASE_PASSWORD=password
COUCHBASE_BUCKET_NAME=Embeddings
COUCHBASE_SCOPE_NAME=_default
# Hologres configuration
# access_key_id is used as the PG username, access_key_secret is used as the PG password
HOLOGRES_HOST=
HOLOGRES_PORT=80
HOLOGRES_DATABASE=
HOLOGRES_ACCESS_KEY_ID=
HOLOGRES_ACCESS_KEY_SECRET=
HOLOGRES_SCHEMA=public
HOLOGRES_TOKENIZER=jieba
HOLOGRES_DISTANCE_METHOD=Cosine
HOLOGRES_BASE_QUANTIZATION_TYPE=rabitq
HOLOGRES_MAX_DEGREE=64
HOLOGRES_EF_CONSTRUCTION=400
# Milvus configuration
MILVUS_URI=http://127.0.0.1:19530
MILVUS_TOKEN=
@ -723,24 +737,25 @@ SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000
# Redis URL used for PubSub between API and
# Redis URL used for event bus between API and
# celery worker
# defaults to url constructed from `REDIS_*`
# configurations
PUBSUB_REDIS_URL=
# Pub/sub channel type for streaming events.
# valid options are:
EVENT_BUS_REDIS_URL=
# Event transport type. Options are:
#
# - pubsub: for normal Pub/Sub
# - sharded: for sharded Pub/Sub
# - pubsub: normal Pub/Sub (at-most-once)
# - sharded: sharded Pub/Sub (at-most-once)
# - streams: Redis Streams (at-least-once, recommended to avoid subscriber races)
#
# It's highly recommended to use sharded Pub/Sub AND redis cluster
# for large deployments.
PUBSUB_REDIS_CHANNEL_TYPE=pubsub
# Whether to use Redis cluster mode while running
# PubSub.
# Note: Before enabling 'streams' in production, estimate your expected event volume and retention needs.
# Configure Redis memory limits and stream trimming appropriately (e.g., MAXLEN and key expiry) to reduce
# the risk of data loss from Redis auto-eviction under memory pressure.
# Also accepts ENV: EVENT_BUS_REDIS_CHANNEL_TYPE.
EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub
# Whether to use Redis cluster mode while use redis as event bus.
# It's highly recommended to enable this for large deployments.
PUBSUB_REDIS_USE_CLUSTERS=false
EVENT_BUS_REDIS_USE_CLUSTERS=false
# Whether to Enable human input timeout check task
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true

View File

@ -96,7 +96,6 @@ ignore_imports =
dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
dify_graph.nodes.tool.tool_node -> core.tools.tool_engine
dify_graph.nodes.tool.tool_node -> core.tools.tool_manager
dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> dify_graph.model_runtime.model_providers.__base.large_language_model
@ -104,7 +103,6 @@ ignore_imports =
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
dify_graph.nodes.llm.node -> core.helper.code_executor
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
dify_graph.nodes.llm.node -> core.model_manager
@ -116,7 +114,6 @@ ignore_imports =
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
dify_graph.nodes.llm.node -> models.dataset
dify_graph.nodes.llm.file_saver -> core.tools.signature
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager

View File

@ -78,7 +78,7 @@ class UserProfile(TypedDict):
nickname: NotRequired[str]
```
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
- For classes, declare all member variables explicitly with types at the top of the class body (before `__init__`), even when the class is not a dataclass or Pydantic model, so the class shape is obvious at a glance:
```python
from datetime import datetime

View File

@ -97,7 +97,7 @@ ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data
RUN mkdir -p /usr/local/share/nltk_data \
&& NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; from unstructured.nlp.tokenize import download_nltk_packages; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords'); download_nltk_packages()" \
&& NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \
&& chmod -R 755 /usr/local/share/nltk_data
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache

View File

@ -1,16 +1,45 @@
import logging
import time
from flask import request
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from controllers.console.error import UnauthorizedAndForceLogout
from core.logging.context import init_request_context
from dify_app import DifyApp
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import LicenseStatus
logger = logging.getLogger(__name__)
# Console bootstrap APIs exempt from license check.
# Defined at module level to avoid per-request tuple construction.
# - system-features: license status for expiry UI (GlobalPublicStoreProvider)
# - setup: install/setup status check (AppInitializer)
# - init: init password validation for fresh install (InitPasswordPopup)
# - login: auto-login after setup completion (InstallForm)
# - features: billing/plan features (ProviderContextProvider)
# - account/profile: login check + user profile (AppContextProvider, useIsLogin)
# - workspaces/current: workspace + model providers (AppContextProvider)
# - version: version check (AppContextProvider)
# - activate/check: invitation link validation (signin page)
# Without these exemptions, the signin page triggers location.reload()
# on unauthorized_and_force_logout, causing an infinite loop.
_CONSOLE_EXEMPT_PREFIXES = (
"/console/api/system-features",
"/console/api/setup",
"/console/api/init",
"/console/api/login",
"/console/api/features",
"/console/api/account/profile",
"/console/api/workspaces/current",
"/console/api/version",
"/console/api/activate/check",
)
# ----------------------------
# Application Factory Function
@ -31,6 +60,39 @@ def create_flask_app_with_configs() -> DifyApp:
init_request_context()
RecyclableContextVar.increment_thread_recycles()
# Enterprise license validation for API endpoints (both console and webapp)
# When license expires, block all API access except bootstrap endpoints needed
# for the frontend to load the license expiration page without infinite reloads.
if dify_config.ENTERPRISE_ENABLED:
is_console_api = request.path.startswith("/console/api/")
is_webapp_api = request.path.startswith("/api/")
if is_console_api or is_webapp_api:
if is_console_api:
is_exempt = any(request.path.startswith(p) for p in _CONSOLE_EXEMPT_PREFIXES)
else: # webapp API
is_exempt = request.path.startswith("/api/system-features")
if not is_exempt:
try:
# Check license status (cached — see EnterpriseService for TTL details)
license_status = EnterpriseService.get_cached_license_status()
if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST):
raise UnauthorizedAndForceLogout(
f"Enterprise license is {license_status}. Please contact your administrator."
)
if license_status is None:
raise UnauthorizedAndForceLogout(
"Unable to verify enterprise license. Please contact your administrator."
)
except UnauthorizedAndForceLogout:
raise
except Exception:
logger.exception("Failed to check enterprise license status")
raise UnauthorizedAndForceLogout(
"Unable to verify enterprise license. Please contact your administrator."
)
# add after request hook for injecting trace headers from OpenTelemetry span context
# Only adds headers when OTEL is enabled and has valid context
@dify_app.after_request

View File

@ -88,6 +88,8 @@ def clean_workflow_runs(
"""
Clean workflow runs and related workflow data for free tenants.
"""
from extensions.otel.runtime import flush_telemetry
if (start_from is None) ^ (end_before is None):
raise click.UsageError("--start-from and --end-before must be provided together.")
@ -104,16 +106,27 @@ def clean_workflow_runs(
end_before = now - datetime.timedelta(days=to_days_ago)
before_days = 0
if from_days_ago is not None and to_days_ago is not None:
task_label = f"{from_days_ago}to{to_days_ago}"
elif start_from is None:
task_label = f"before-{before_days}"
else:
task_label = "custom"
start_time = datetime.datetime.now(datetime.UTC)
click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
WorkflowRunCleanup(
days=before_days,
batch_size=batch_size,
start_from=start_from,
end_before=end_before,
dry_run=dry_run,
).run()
try:
WorkflowRunCleanup(
days=before_days,
batch_size=batch_size,
start_from=start_from,
end_before=end_before,
dry_run=dry_run,
task_label=task_label,
).run()
finally:
flush_telemetry()
end_time = datetime.datetime.now(datetime.UTC)
elapsed = end_time - start_time
@ -659,6 +672,8 @@ def clean_expired_messages(
"""
Clean expired messages and related data for tenants based on clean policy.
"""
from extensions.otel.runtime import flush_telemetry
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
start_at = time.perf_counter()
@ -698,6 +713,13 @@ def clean_expired_messages(
# NOTE: graceful_period will be ignored when billing is disabled.
policy = create_message_clean_policy(graceful_period_days=graceful_period)
if from_days_ago is not None and before_days is not None:
task_label = f"{from_days_ago}to{before_days}"
elif start_from is None and before_days is not None:
task_label = f"before-{before_days}"
else:
task_label = "custom"
# Create and run the cleanup service
if abs_mode:
assert start_from is not None
@ -708,6 +730,7 @@ def clean_expired_messages(
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
task_label=task_label,
)
elif from_days_ago is None:
assert before_days is not None
@ -716,6 +739,7 @@ def clean_expired_messages(
days=before_days,
batch_size=batch_size,
dry_run=dry_run,
task_label=task_label,
)
else:
assert before_days is not None
@ -727,6 +751,7 @@ def clean_expired_messages(
end_before=now - datetime.timedelta(days=before_days),
batch_size=batch_size,
dry_run=dry_run,
task_label=task_label,
)
stats = service.run()
@ -752,6 +777,8 @@ def clean_expired_messages(
)
)
raise
finally:
flush_telemetry()
click.echo(click.style("messages cleanup completed.", fg="green"))

View File

@ -14,6 +14,7 @@ from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.enums import DatasetMetadataType, IndexingStatus, SegmentStatus
from models.model import App, AppAnnotationSetting, MessageAnnotation
@ -160,6 +161,7 @@ def migrate_knowledge_vector_database():
}
lower_collection_vector_types = {
VectorType.ANALYTICDB,
VectorType.HOLOGRES,
VectorType.CHROMA,
VectorType.MYSCALE,
VectorType.PGVECTO_RS,
@ -241,7 +243,7 @@ def migrate_knowledge_vector_database():
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.indexing_status == IndexingStatus.COMPLETED,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
@ -253,7 +255,7 @@ def migrate_knowledge_vector_database():
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
DocumentSegment.status == SegmentStatus.COMPLETED,
DocumentSegment.enabled == True,
)
).all()
@ -429,7 +431,7 @@ def old_metadata_migration():
tenant_id=document.tenant_id,
dataset_id=document.dataset_id,
name=key,
type="string",
type=DatasetMetadataType.STRING,
created_by=document.created_by,
)
db.session.add(dataset_metadata)

View File

@ -26,6 +26,7 @@ from .vdb.chroma_config import ChromaConfig
from .vdb.clickzetta_config import ClickzettaConfig
from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.hologres_config import HologresConfig
from .vdb.huawei_cloud_config import HuaweiCloudConfig
from .vdb.iris_config import IrisVectorConfig
from .vdb.lindorm_config import LindormConfig
@ -347,6 +348,7 @@ class MiddlewareConfig(
AnalyticdbConfig,
ChromaConfig,
ClickzettaConfig,
HologresConfig,
HuaweiCloudConfig,
IrisVectorConfig,
MilvusConfig,

View File

@ -1,4 +1,4 @@
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, field_validator
from pydantic_settings import BaseSettings
@ -116,3 +116,13 @@ class RedisConfig(BaseSettings):
description="Maximum connections in the Redis connection pool (unset for library default)",
default=None,
)
@field_validator("REDIS_MAX_CONNECTIONS", mode="before")
@classmethod
def _empty_string_to_none_for_max_conns(cls, v):
"""Allow empty string in env/.env to mean 'unset' (None)."""
if v is None:
return None
if isinstance(v, str) and v.strip() == "":
return None
return v

View File

@ -1,4 +1,4 @@
from typing import Literal, Protocol
from typing import Literal, Protocol, cast
from urllib.parse import quote_plus, urlunparse
from pydantic import AliasChoices, Field
@ -12,16 +12,13 @@ class RedisConfigDefaults(Protocol):
REDIS_PASSWORD: str | None
REDIS_DB: int
REDIS_USE_SSL: bool
REDIS_USE_SENTINEL: bool | None
REDIS_USE_CLUSTERS: bool
class RedisConfigDefaultsMixin:
def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
return self
def _redis_defaults(config: object) -> RedisConfigDefaults:
return cast(RedisConfigDefaults, config)
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
class RedisPubSubConfig(BaseSettings):
"""
Configuration settings for event transport between API and workers.
@ -41,10 +38,10 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
)
PUBSUB_REDIS_USE_CLUSTERS: bool = Field(
validation_alias=AliasChoices("EVENT_BUS_REDIS_CLUSTERS", "PUBSUB_REDIS_USE_CLUSTERS"),
validation_alias=AliasChoices("EVENT_BUS_REDIS_USE_CLUSTERS", "PUBSUB_REDIS_USE_CLUSTERS"),
description=(
"Enable Redis Cluster mode for pub/sub or streams transport. Recommended for large deployments. "
"Also accepts ENV: EVENT_BUS_REDIS_CLUSTERS."
"Also accepts ENV: EVENT_BUS_REDIS_USE_CLUSTERS."
),
default=False,
)
@ -74,7 +71,7 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
)
def _build_default_pubsub_url(self) -> str:
defaults = self._redis_defaults()
defaults = _redis_defaults(self)
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
@ -91,11 +88,9 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
if userinfo:
userinfo = f"{userinfo}@"
host = defaults.REDIS_HOST
port = defaults.REDIS_PORT
db = defaults.REDIS_DB
netloc = f"{userinfo}{host}:{port}"
netloc = f"{userinfo}{defaults.REDIS_HOST}:{defaults.REDIS_PORT}"
return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
@property

View File

@ -0,0 +1,68 @@
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
from pydantic import Field
from pydantic_settings import BaseSettings
class HologresConfig(BaseSettings):
"""
Configuration settings for Hologres vector database.
Hologres is compatible with PostgreSQL protocol.
access_key_id is used as the PostgreSQL username,
and access_key_secret is used as the PostgreSQL password.
"""
HOLOGRES_HOST: str | None = Field(
description="Hostname or IP address of the Hologres instance.",
default=None,
)
HOLOGRES_PORT: int = Field(
description="Port number for connecting to the Hologres instance.",
default=80,
)
HOLOGRES_DATABASE: str | None = Field(
description="Name of the Hologres database to connect to.",
default=None,
)
HOLOGRES_ACCESS_KEY_ID: str | None = Field(
description="Alibaba Cloud AccessKey ID, also used as the PostgreSQL username.",
default=None,
)
HOLOGRES_ACCESS_KEY_SECRET: str | None = Field(
description="Alibaba Cloud AccessKey Secret, also used as the PostgreSQL password.",
default=None,
)
HOLOGRES_SCHEMA: str = Field(
description="Schema name in the Hologres database.",
default="public",
)
HOLOGRES_TOKENIZER: TokenizerType = Field(
description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
default="jieba",
)
HOLOGRES_DISTANCE_METHOD: DistanceType = Field(
description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
default="Cosine",
)
HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field(
description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').",
default="rabitq",
)
HOLOGRES_MAX_DEGREE: int = Field(
description="Max degree (M) parameter for HNSW vector index.",
default=64,
)
HOLOGRES_EF_CONSTRUCTION: int = Field(
description="ef_construction parameter for HNSW vector index.",
default=400,
)

View File

@ -25,7 +25,8 @@ from controllers.console.wraps import (
)
from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.enums import NodeType, WorkflowExecutionStatus
from core.trigger.constants import TRIGGER_NODE_TYPES
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.file import helpers as file_helpers
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
@ -508,11 +509,7 @@ class AppListApi(Resource):
.scalars()
.all()
)
trigger_node_types = {
NodeType.TRIGGER_WEBHOOK,
NodeType.TRIGGER_SCHEDULE,
NodeType.TRIGGER_PLUGIN,
}
trigger_node_types = TRIGGER_NODE_TYPES
for workflow in draft_workflows:
node_id = None
try:

View File

@ -22,6 +22,7 @@ from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.trace_id_helper import get_external_trace_id
from core.plugin.impl.exc import PluginInvokeError
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
from core.trigger.debug.event_selectors import (
TriggerDebugEvent,
TriggerDebugEventPoller,
@ -1209,7 +1210,7 @@ class DraftWorkflowTriggerNodeApi(Resource):
node_type: NodeType = draft_workflow.get_node_type_from_node_config(node_config)
event: TriggerDebugEvent | None = None
# for schedule trigger, when run single node, just execute directly
if node_type == NodeType.TRIGGER_SCHEDULE:
if node_type == TRIGGER_SCHEDULE_NODE_TYPE:
event = TriggerDebugEvent(
workflow_args={},
node_id=node_id,

View File

@ -23,7 +23,7 @@ from dify_graph.variables.types import SegmentType
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import login_required
from libs.login import current_user, login_required
from models import App, AppMode
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
@ -100,6 +100,18 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
}
def _ensure_variable_access(
variable: WorkflowDraftVariable | None,
app_id: str,
variable_id: str,
) -> WorkflowDraftVariable:
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_id or variable.user_id != current_user.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
return variable
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
"id": fields.String,
"type": fields.String(attribute=lambda model: model.get_variable_type()),
@ -238,6 +250,7 @@ class WorkflowVariableCollectionApi(Resource):
app_id=app_model.id,
page=args.page,
limit=args.limit,
user_id=current_user.id,
)
return workflow_vars
@ -250,7 +263,7 @@ class WorkflowVariableCollectionApi(Resource):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
draft_var_srv.delete_workflow_variables(app_model.id)
draft_var_srv.delete_user_workflow_variables(app_model.id, user_id=current_user.id)
db.session.commit()
return Response("", 204)
@ -287,7 +300,7 @@ class NodeVariableCollectionApi(Resource):
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id, user_id=current_user.id)
return node_vars
@ -298,7 +311,7 @@ class NodeVariableCollectionApi(Resource):
def delete(self, app_model: App, node_id: str):
validate_node_id(node_id)
srv = WorkflowDraftVariableService(db.session())
srv.delete_node_variables(app_model.id, node_id)
srv.delete_node_variables(app_model.id, node_id, user_id=current_user.id)
db.session.commit()
return Response("", 204)
@ -319,11 +332,11 @@ class VariableApi(Resource):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=app_model.id,
variable_id=variable_id,
)
return variable
@console_ns.doc("update_variable")
@ -360,11 +373,11 @@ class VariableApi(Resource):
)
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=app_model.id,
variable_id=variable_id,
)
new_name = args_model.name
raw_value = args_model.value
@ -397,11 +410,11 @@ class VariableApi(Resource):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=app_model.id,
variable_id=variable_id,
)
draft_var_srv.delete_variable(variable)
db.session.commit()
return Response("", 204)
@ -427,11 +440,11 @@ class VariableResetApi(Resource):
raise NotFoundError(
f"Draft workflow not found, app_id={app_model.id}",
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=app_model.id,
variable_id=variable_id,
)
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
db.session.commit()
@ -447,11 +460,15 @@ def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
session=session,
)
if node_id == CONVERSATION_VARIABLE_NODE_ID:
draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
draft_vars = draft_var_srv.list_conversation_variables(app_model.id, user_id=current_user.id)
elif node_id == SYSTEM_VARIABLE_NODE_ID:
draft_vars = draft_var_srv.list_system_variables(app_model.id)
draft_vars = draft_var_srv.list_system_variables(app_model.id, user_id=current_user.id)
else:
draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
draft_vars = draft_var_srv.list_node_variables(
app_id=app_model.id,
node_id=node_id,
user_id=current_user.id,
)
return draft_vars
@ -472,7 +489,7 @@ class ConversationVariableCollectionApi(Resource):
if draft_workflow is None:
raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow, user_id=current_user.id)
db.session.commit()
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)

View File

@ -54,6 +54,7 @@ from fields.document_fields import document_status_fields
from libs.login import current_account_with_tenant, login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermission, DatasetPermissionEnum
from models.enums import SegmentStatus
from models.provider_ids import ModelProviderID
from services.api_token_service import ApiTokenCache
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@ -263,6 +264,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
VectorType.BAIDU,
VectorType.ALIBABACLOUD_MYSQL,
VectorType.IRIS,
VectorType.HOLOGRES,
}
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
@ -740,13 +742,15 @@ class DatasetIndexingStatusApi(Resource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
)
.count()
)
# Create a dictionary with document attributes and additional fields

View File

@ -42,6 +42,7 @@ from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from models.enums import IndexingStatus, SegmentStatus
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.file_service import FileService
@ -332,13 +333,16 @@ class DatasetDocumentListApi(Resource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
document.completed_segments = completed_segments
@ -503,7 +507,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
if document.indexing_status in {"completed", "error"}:
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
raise DocumentAlreadyFinishedError()
data_process_rule = document.dataset_process_rule
@ -573,7 +577,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
extract_settings = []
for document in documents:
if document.indexing_status in {"completed", "error"}:
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
raise DocumentAlreadyFinishedError()
data_source_info = document.data_source_info_dict
match document.data_source_type:
@ -671,19 +675,21 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
)
.count()
)
# Create a dictionary with document attributes and additional fields
document_dict = {
"id": document.id,
"indexing_status": "paused" if document.is_paused else document.indexing_status,
"indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status,
"processing_started_at": document.processing_started_at,
"parsing_completed_at": document.parsing_completed_at,
"cleaning_completed_at": document.cleaning_completed_at,
@ -720,20 +726,20 @@ class DocumentIndexingStatusApi(DocumentResource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != "re_segment",
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT)
.count()
)
# Create a dictionary with document attributes and additional fields
document_dict = {
"id": document.id,
"indexing_status": "paused" if document.is_paused else document.indexing_status,
"indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status,
"processing_started_at": document.processing_started_at,
"parsing_completed_at": document.parsing_completed_at,
"cleaning_completed_at": document.cleaning_completed_at,
@ -955,7 +961,7 @@ class DocumentProcessingApi(DocumentResource):
match action:
case "pause":
if document.indexing_status != "indexing":
if document.indexing_status != IndexingStatus.INDEXING:
raise InvalidActionError("Document not in indexing state.")
document.paused_by = current_user.id
@ -964,7 +970,7 @@ class DocumentProcessingApi(DocumentResource):
db.session.commit()
case "resume":
if document.indexing_status not in {"paused", "error"}:
if document.indexing_status not in {IndexingStatus.PAUSED, IndexingStatus.ERROR}:
raise InvalidActionError("Document not in paused or error state.")
document.paused_by = None
@ -1169,7 +1175,7 @@ class DocumentRetryApi(DocumentResource):
raise ArchivedDocumentImmutableError()
# 400 if document is completed
if document.indexing_status == "completed":
if document.indexing_status == IndexingStatus.COMPLETED:
raise DocumentAlreadyFinishedError()
retry_documents.append(document)
except Exception:

View File

@ -46,6 +46,8 @@ class PipelineTemplateDetailApi(Resource):
type = request.args.get("type", default="built-in", type=str)
rag_pipeline_service = RagPipelineService()
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
if pipeline_template is None:
return {"error": "Pipeline template not found from upstream service."}, 404
return pipeline_template, 200

View File

@ -102,6 +102,7 @@ class RagPipelineVariableCollectionApi(Resource):
app_id=pipeline.id,
page=query.page,
limit=query.limit,
user_id=current_user.id,
)
return workflow_vars
@ -111,7 +112,7 @@ class RagPipelineVariableCollectionApi(Resource):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
draft_var_srv.delete_workflow_variables(pipeline.id)
draft_var_srv.delete_user_workflow_variables(pipeline.id, user_id=current_user.id)
db.session.commit()
return Response("", 204)
@ -144,7 +145,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
node_vars = draft_var_srv.list_node_variables(pipeline.id, node_id)
node_vars = draft_var_srv.list_node_variables(pipeline.id, node_id, user_id=current_user.id)
return node_vars
@ -152,7 +153,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
def delete(self, pipeline: Pipeline, node_id: str):
validate_node_id(node_id)
srv = WorkflowDraftVariableService(db.session())
srv.delete_node_variables(pipeline.id, node_id)
srv.delete_node_variables(pipeline.id, node_id, user_id=current_user.id)
db.session.commit()
return Response("", 204)
@ -283,11 +284,11 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList
session=session,
)
if node_id == CONVERSATION_VARIABLE_NODE_ID:
draft_vars = draft_var_srv.list_conversation_variables(pipeline.id)
draft_vars = draft_var_srv.list_conversation_variables(pipeline.id, user_id=current_user.id)
elif node_id == SYSTEM_VARIABLE_NODE_ID:
draft_vars = draft_var_srv.list_system_variables(pipeline.id)
draft_vars = draft_var_srv.list_system_variables(pipeline.id, user_id=current_user.id)
else:
draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id)
draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id, user_id=current_user.id)
return draft_vars

View File

@ -36,6 +36,7 @@ from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import SegmentStatus
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import (
KnowledgeConfig,
@ -622,13 +623,15 @@ class DocumentIndexingStatusApi(DatasetApiResource):
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
)
.count()
)
# Create a dictionary with document attributes and additional fields

View File

@ -3,7 +3,7 @@ import time
from collections.abc import Callable
from enum import StrEnum, auto
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar, cast
from typing import Concatenate, ParamSpec, TypeVar, cast, overload
from flask import current_app, request
from flask_login import user_logged_in
@ -44,10 +44,22 @@ class FetchUserArg(BaseModel):
required: bool = False
def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None):
def decorator(view_func: Callable[P, R]):
@overload
def validate_app_token(view: Callable[P, R]) -> Callable[P, R]: ...
@overload
def validate_app_token(
view: None = None, *, fetch_user_arg: FetchUserArg | None = None
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def validate_app_token(
view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
api_token = validate_and_get_api_token("app")
app_model = db.session.query(App).where(App.id == api_token.app_id).first()
@ -213,10 +225,20 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
return interceptor
def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
def decorator(view: Callable[Concatenate[T, P], R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
@overload
def validate_dataset_token(view: Callable[Concatenate[T, P], R]) -> Callable[P, R]: ...
@overload
def validate_dataset_token(view: None = None) -> Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]: ...
def validate_dataset_token(
view: Callable[Concatenate[T, P], R] | None = None,
) -> Callable[P, R] | Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]:
def decorator(view_func: Callable[Concatenate[T, P], R]) -> Callable[P, R]:
@wraps(view_func)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
api_token = validate_and_get_api_token("dataset")
# get url path dataset_id from positional args or kwargs
@ -287,7 +309,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
raise Unauthorized("Tenant owner account does not exist.")
else:
raise Unauthorized("Tenant does not exist.")
return view(api_token.tenant_id, *args, **kwargs)
return view_func(api_token.tenant_id, *args, **kwargs) # type: ignore[arg-type]
return decorated

View File

@ -70,7 +70,14 @@ def handle_webhook(webhook_id: str):
@bp.route("/webhook-debug/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
def handle_webhook_debug(webhook_id: str):
"""Handle webhook debug calls without triggering production workflow execution."""
"""Handle webhook debug calls without triggering production workflow execution.
The debug webhook endpoint is only for draft inspection flows. It never enqueues
Celery work for the published workflow; instead it dispatches an in-memory debug
event to an active Variable Inspector listener. Returning a clear error when no
listener is registered prevents a misleading 200 response for requests that are
effectively dropped.
"""
try:
webhook_trigger, _, node_config, webhook_data, error = _prepare_webhook_execution(webhook_id, is_debug=True)
if error:
@ -94,11 +101,32 @@ def handle_webhook_debug(webhook_id: str):
"method": webhook_data.get("method"),
},
)
TriggerDebugEventBus.dispatch(
dispatch_count = TriggerDebugEventBus.dispatch(
tenant_id=webhook_trigger.tenant_id,
event=event,
pool_key=pool_key,
)
if dispatch_count == 0:
logger.warning(
"Webhook debug request dropped without an active listener for webhook %s (tenant=%s, app=%s, node=%s)",
webhook_trigger.webhook_id,
webhook_trigger.tenant_id,
webhook_trigger.app_id,
webhook_trigger.node_id,
)
return (
jsonify(
{
"error": "No active debug listener",
"message": (
"The webhook debug URL only works while the Variable Inspector is listening. "
"Use the published webhook URL to execute the workflow in Celery."
),
"execution_url": webhook_trigger.webhook_url,
}
),
409,
)
response_data, status_code = WebhookService.generate_webhook_response(node_config)
return jsonify(response_data), status_code

View File

@ -441,7 +441,7 @@ class BaseAgentRunner(AppRunner):
continue
result.append(self.organize_agent_user_prompt(message))
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
agent_thoughts = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
tool_names_raw = agent_thought.tool

View File

@ -1,13 +1,36 @@
from collections.abc import Mapping
from typing import Any
from typing import Any, TypedDict
from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
class SystemParametersDict(TypedDict):
image_file_size_limit: int
video_file_size_limit: int
audio_file_size_limit: int
file_size_limit: int
workflow_file_upload_limit: int
class AppParametersDict(TypedDict):
opening_statement: str | None
suggested_questions: list[str]
suggested_questions_after_answer: dict[str, Any]
speech_to_text: dict[str, Any]
text_to_speech: dict[str, Any]
retriever_resource: dict[str, Any]
annotation_reply: dict[str, Any]
more_like_this: dict[str, Any]
user_input_form: list[dict[str, Any]]
sensitive_word_avoidance: dict[str, Any]
file_upload: dict[str, Any]
system_parameters: SystemParametersDict
def get_parameters_from_feature_dict(
*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]
) -> Mapping[str, Any]:
) -> AppParametersDict:
"""
Mapping from feature dict to webapp parameters
"""

View File

@ -8,6 +8,7 @@ from core.app.app_config.entities import (
ModelConfig,
)
from core.entities.agent_entities import PlanningStrategy
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from models.model import AppMode, AppModelConfigDict
from services.dataset_service import DatasetService
@ -117,8 +118,10 @@ class DatasetConfigManager:
score_threshold=float(score_threshold_val)
if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None
else None,
reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None,
weights=weights_val if isinstance(weights_val, dict) else None,
reranking_model=cast(RerankingModelDict, reranking_model_val)
if isinstance(reranking_model_val, dict)
else None,
weights=cast(WeightsDict, weights_val) if isinstance(weights_val, dict) else None,
reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)),
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
metadata_filtering_mode=cast(

View File

@ -4,6 +4,7 @@ from typing import Any, Literal
from pydantic import BaseModel, Field
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from dify_graph.file import FileUploadConfig
from dify_graph.model_runtime.entities.llm_entities import LLMMode
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
@ -194,8 +195,8 @@ class DatasetRetrieveConfigEntity(BaseModel):
top_k: int | None = None
score_threshold: float | None = 0.0
rerank_mode: str | None = "reranking_model"
reranking_model: dict | None = None
weights: dict | None = None
reranking_model: RerankingModelDict | None = None
weights: WeightsDict | None = None
reranking_enabled: bool | None = True
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
metadata_model_config: ModelConfig | None = None

View File

@ -330,9 +330,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
user_id=user.id,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
return self._generate(
workflow=workflow,
@ -413,9 +414,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
user_id=user.id,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
return self._generate(
workflow=workflow,

View File

@ -69,7 +69,7 @@ from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.nodes import NodeType
from dify_graph.nodes import BuiltinNodeTypes
from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory
from dify_graph.runtime import GraphRuntimeState
from dify_graph.system_variable import SystemVariable
@ -357,7 +357,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
) -> Generator[StreamResponse, None, None]:
"""Handle node succeeded events."""
# Record files if it's an answer node or end node
if event.node_type in [NodeType.ANSWER, NodeType.END, NodeType.LLM]:
if event.node_type in [BuiltinNodeTypes.ANSWER, BuiltinNodeTypes.END, BuiltinNodeTypes.LLM]:
self._recorded_files.extend(
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
)

View File

@ -3,7 +3,7 @@ import time
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime
from typing import Any, NewType, Union
from typing import Any, NewType, TypedDict, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -48,12 +48,13 @@ from core.app.entities.task_entities import (
from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.trigger_manager import TriggerManager
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
from dify_graph.enums import (
NodeType,
BuiltinNodeTypes,
SystemVariableKey,
WorkflowExecutionStatus,
WorkflowNodeExecutionMetadataKey,
@ -75,6 +76,20 @@ NodeExecutionId = NewType("NodeExecutionId", str)
logger = logging.getLogger(__name__)
class AccountCreatedByDict(TypedDict):
id: str
name: str
email: str
class EndUserCreatedByDict(TypedDict):
id: str
user: str
CreatedByDict = AccountCreatedByDict | EndUserCreatedByDict
@dataclass(slots=True)
class _NodeSnapshot:
"""In-memory cache for node metadata between start and completion events."""
@ -248,19 +263,19 @@ class WorkflowResponseConverter:
outputs_mapping = graph_runtime_state.outputs or {}
encoded_outputs = WorkflowRuntimeTypeConverter().to_json_encodable(outputs_mapping)
created_by: Mapping[str, object] | None
created_by: CreatedByDict | dict[str, object] = {}
user = self._user
if isinstance(user, Account):
created_by = {
"id": user.id,
"name": user.name,
"email": user.email,
}
else:
created_by = {
"id": user.id,
"user": user.session_id,
}
created_by = AccountCreatedByDict(
id=user.id,
name=user.name,
email=user.email,
)
elif isinstance(user, EndUser):
created_by = EndUserCreatedByDict(
id=user.id,
user=user.session_id,
)
return WorkflowFinishStreamResponse(
task_id=task_id,
@ -442,7 +457,7 @@ class WorkflowResponseConverter:
event: QueueNodeStartedEvent,
task_id: str,
) -> NodeStartStreamResponse | None:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
snapshot = self._store_snapshot(event)
@ -464,13 +479,13 @@ class WorkflowResponseConverter:
)
try:
if event.node_type == NodeType.TOOL:
if event.node_type == BuiltinNodeTypes.TOOL:
response.data.extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=ToolProviderType(event.provider_type),
provider_id=event.provider_id,
)
elif event.node_type == NodeType.DATASOURCE:
elif event.node_type == BuiltinNodeTypes.DATASOURCE:
manager = PluginDatasourceManager()
provider_entity = manager.fetch_datasource_provider(
self._application_generate_entity.app_config.tenant_id,
@ -479,7 +494,7 @@ class WorkflowResponseConverter:
response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
self._application_generate_entity.app_config.tenant_id
)
elif event.node_type == NodeType.TRIGGER_PLUGIN:
elif event.node_type == TRIGGER_PLUGIN_NODE_TYPE:
response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon(
self._application_generate_entity.app_config.tenant_id,
event.provider_id,
@ -496,7 +511,7 @@ class WorkflowResponseConverter:
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
task_id: str,
) -> NodeFinishStreamResponse | None:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
snapshot = self._pop_snapshot(event.node_execution_id)
@ -554,7 +569,7 @@ class WorkflowResponseConverter:
event: QueueNodeRetryEvent,
task_id: str,
) -> NodeRetryStreamResponse | None:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
@ -612,7 +627,7 @@ class WorkflowResponseConverter:
data=IterationNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
node_type=event.node_type,
title=event.node_title,
created_at=int(time.time()),
extras={},
@ -635,7 +650,7 @@ class WorkflowResponseConverter:
data=IterationNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
node_type=event.node_type,
title=event.node_title,
index=event.index,
created_at=int(time.time()),
@ -662,7 +677,7 @@ class WorkflowResponseConverter:
data=IterationNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
node_type=event.node_type,
title=event.node_title,
outputs=new_outputs,
outputs_truncated=outputs_truncated,
@ -692,7 +707,7 @@ class WorkflowResponseConverter:
data=LoopNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
node_type=event.node_type,
title=event.node_title,
created_at=int(time.time()),
extras={},
@ -715,7 +730,7 @@ class WorkflowResponseConverter:
data=LoopNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
node_type=event.node_type,
title=event.node_title,
index=event.index,
# The `pre_loop_output` field is not utilized by the frontend.
@ -744,7 +759,7 @@ class WorkflowResponseConverter:
data=LoopNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
node_type=event.node_type,
title=event.node_title,
outputs=new_outputs,
outputs_truncated=outputs_truncated,

View File

@ -419,11 +419,12 @@ class PipelineGenerator(BaseAppGenerator):
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
user_id=user.id,
)
return self._generate(
@ -514,11 +515,12 @@ class PipelineGenerator(BaseAppGenerator):
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
user_id=user.id,
)
return self._generate(

View File

@ -12,7 +12,7 @@ from core.app.entities.app_invoke_entities import (
build_dify_run_context,
)
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities.graph_init_params import GraphInitParams
from dify_graph.enums import WorkflowType
@ -274,6 +274,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
if start_node_id is None:
start_node_id = get_default_root_node_id(graph_config)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id)
if not graph:

View File

@ -414,11 +414,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
user_id=user.id,
)
return self._generate(
@ -497,11 +498,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
user_id=user.id,
)
return self._generate(
app_model=app_model,

View File

@ -32,8 +32,8 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.node_resolution import resolve_workflow_node_class
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDictAdapter
@ -140,6 +140,9 @@ class WorkflowBasedAppRunner:
graph_runtime_state=graph_runtime_state,
)
if root_node_id is None:
root_node_id = get_default_root_node_id(graph_config)
# init graph
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id)
@ -505,7 +508,9 @@ class WorkflowBasedAppRunner:
elif isinstance(event, NodeRunRetrieverResourceEvent):
self._publish_event(
QueueRetrieverResourcesEvent(
retriever_resources=event.retriever_resources,
retriever_resources=[
RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources
],
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)

View File

@ -9,9 +9,8 @@ from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from dify_graph.entities.pause_reason import PauseReason
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from dify_graph.nodes import NodeType
class QueueEvent(StrEnum):

View File

@ -6,6 +6,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from models.enums import CollectionBindingType
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.dataset_service import DatasetCollectionBindingService
@ -43,7 +44,7 @@ class AnnotationReplyFeature:
embedding_model_name = collection_binding_detail.model_name
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name, embedding_model_name, "annotation"
embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION
)
dataset = Dataset(

View File

@ -2,7 +2,7 @@ import logging
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
from dify_graph.conversation_variable_updater import ConversationVariableUpdater
from dify_graph.enums import NodeType
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_events import GraphEngineEvent, NodeRunSucceededEvent
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
@ -22,7 +22,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
def on_event(self, event: GraphEngineEvent) -> None:
if not isinstance(event, NodeRunSucceededEvent):
return
if event.node_type != NodeType.VARIABLE_ASSIGNER:
if event.node_type != BuiltinNodeTypes.VARIABLE_ASSIGNER:
return
if self.graph_runtime_state is None:
return

View File

@ -1,3 +1,5 @@
from typing import TypedDict
from core.tools.signature import sign_tool_file
from dify_graph.file import helpers as file_helpers
from dify_graph.file.enums import FileTransferMethod
@ -6,7 +8,20 @@ from models.model import MessageFile, UploadFile
MAX_TOOL_FILE_EXTENSION_LENGTH = 10
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict:
class MessageFileInfoDict(TypedDict):
related_id: str
extension: str
filename: str
size: int
mime_type: str
transfer_method: str
type: str
url: str
upload_file_id: str
remote_url: str | None
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> MessageFileInfoDict:
"""
Prepare file dictionary for message end stream response.

View File

@ -12,7 +12,7 @@ from typing_extensions import override
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from dify_graph.enums import NodeType
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase
@ -113,11 +113,11 @@ class LLMQuotaLayer(GraphEngineLayer):
def _extract_model_instance(node: Node) -> ModelInstance | None:
try:
match node.node_type:
case NodeType.LLM:
case BuiltinNodeTypes.LLM:
return cast("LLMNode", node).model_instance
case NodeType.PARAMETER_EXTRACTOR:
case BuiltinNodeTypes.PARAMETER_EXTRACTOR:
return cast("ParameterExtractorNode", node).model_instance
case NodeType.QUESTION_CLASSIFIER:
case BuiltinNodeTypes.QUESTION_CLASSIFIER:
return cast("QuestionClassifierNode", node).model_instance
case _:
return None

View File

@ -16,7 +16,7 @@ from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_
from typing_extensions import override
from configs import dify_config
from dify_graph.enums import NodeType
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_events import GraphNodeEventBase
from dify_graph.nodes.base.node import Node
@ -74,16 +74,13 @@ class ObservabilityLayer(GraphEngineLayer):
def _build_parser_registry(self) -> None:
"""Initialize parser registry for node types."""
self._parsers = {
NodeType.TOOL: ToolNodeOTelParser(),
NodeType.LLM: LLMNodeOTelParser(),
NodeType.KNOWLEDGE_RETRIEVAL: RetrievalNodeOTelParser(),
BuiltinNodeTypes.TOOL: ToolNodeOTelParser(),
BuiltinNodeTypes.LLM: LLMNodeOTelParser(),
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: RetrievalNodeOTelParser(),
}
def _get_parser(self, node: Node) -> NodeOTelParser:
node_type = getattr(node, "node_type", None)
if isinstance(node_type, NodeType):
return self._parsers.get(node_type, self._default_parser)
return self._default_parser
return self._parsers.get(node.node_type, self._default_parser)
@override
def on_graph_start(self) -> None:

View File

@ -12,7 +12,7 @@ from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, DatasetQuerySource
_logger = logging.getLogger(__name__)
@ -36,7 +36,7 @@ class DatasetIndexToolCallbackHandler:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=query,
source="app",
source=DatasetQuerySource.APP,
source_app_id=self._app_id,
created_by_role=(
CreatorUserRole.ACCOUNT

View File

@ -24,12 +24,12 @@ from core.datasource.utils.message_transformer import DatasourceFileMessageTrans
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
from core.db.session_factory import session_factory
from core.plugin.impl.datasource import PluginDatasourceManager
from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
from dify_graph.file import File
from dify_graph.file.enums import FileTransferMethod, FileType
from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from dify_graph.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam
from factories import file_factory
from models.model import UploadFile
from models.tools import ToolFile

View File

@ -30,6 +30,7 @@ from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from libs.datetime_utils import naive_utc_now
from models.engine import db
from models.enums import CredentialSourceType
from models.provider import (
LoadBalancingModelConfig,
Provider,
@ -473,9 +474,21 @@ class ProviderConfiguration(BaseModel):
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
else:
# some historical data may have a provider record but not be set as valid
provider_record.is_valid = True
if provider_record.credential_id is None:
provider_record.credential_id = new_record.id
provider_record.updated_at = naive_utc_now()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER,
)
provider_model_credentials_cache.delete()
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
session.commit()
except Exception:
session.rollback()
@ -534,7 +547,7 @@ class ProviderConfiguration(BaseModel):
self._update_load_balancing_configs_with_credential(
credential_id=credential_id,
credential_record=credential_record,
credential_source="provider",
credential_source=CredentialSourceType.PROVIDER,
session=session,
)
except Exception:
@ -611,7 +624,7 @@ class ProviderConfiguration(BaseModel):
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "provider",
LoadBalancingModelConfig.credential_source_type == CredentialSourceType.PROVIDER,
)
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
try:
@ -1031,7 +1044,7 @@ class ProviderConfiguration(BaseModel):
self._update_load_balancing_configs_with_credential(
credential_id=credential_id,
credential_record=credential_record,
credential_source="custom_model",
credential_source=CredentialSourceType.CUSTOM_MODEL,
session=session,
)
except Exception:
@ -1061,7 +1074,7 @@ class ProviderConfiguration(BaseModel):
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "custom_model",
LoadBalancingModelConfig.credential_source_type == CredentialSourceType.CUSTOM_MODEL,
)
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
@ -1699,7 +1712,7 @@ class ProviderConfiguration(BaseModel):
provider_model_lb_configs = [
config
for config in model_setting.load_balancing_configs
if config.credential_source_type != "custom_model"
if config.credential_source_type != CredentialSourceType.CUSTOM_MODEL
]
load_balancing_enabled = model_setting.load_balancing_enabled
@ -1757,7 +1770,7 @@ class ProviderConfiguration(BaseModel):
custom_model_lb_configs = [
config
for config in model_setting.load_balancing_configs
if config.credential_source_type != "provider"
if config.credential_source_type != CredentialSourceType.PROVIDER
]
load_balancing_enabled = model_setting.load_balancing_enabled

View File

@ -5,6 +5,7 @@ import re
import threading
import time
import uuid
from collections.abc import Mapping
from typing import Any
from flask import Flask, current_app
@ -37,8 +38,9 @@ from extensions.ext_storage import storage
from libs import helper
from libs.datetime_utils import naive_utc_now
from models import Account
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import AutomaticRulesConfig, ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.enums import DataSourceType, IndexingStatus, ProcessRuleMode, SegmentStatus
from models.model import UploadFile
from services.feature_service import FeatureService
@ -55,7 +57,7 @@ class IndexingRunner:
logger.exception("consume document failed")
document = db.session.get(DatasetDocument, document_id)
if document:
document.indexing_status = "error"
document.indexing_status = IndexingStatus.ERROR
error_message = getattr(error, "description", str(error))
document.error = str(error_message)
document.stopped_at = naive_utc_now()
@ -218,7 +220,7 @@ class IndexingRunner:
if document_segments:
for document_segment in document_segments:
# transform segment to node
if document_segment.status != "completed":
if document_segment.status != SegmentStatus.COMPLETED:
document = Document(
page_content=document_segment.content,
metadata={
@ -265,7 +267,7 @@ class IndexingRunner:
self,
tenant_id: str,
extract_settings: list[ExtractSetting],
tmp_processing_rule: dict,
tmp_processing_rule: Mapping[str, Any],
doc_form: str | None = None,
doc_language: str = "English",
dataset_id: str | None = None,
@ -376,12 +378,12 @@ class IndexingRunner:
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
def _extract(
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: Mapping[str, Any]
) -> list[Document]:
data_source_info = dataset_document.data_source_info_dict
text_docs = []
match dataset_document.data_source_type:
case "upload_file":
case DataSourceType.UPLOAD_FILE:
if not data_source_info or "upload_file_id" not in data_source_info:
raise ValueError("no upload file found")
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
@ -394,7 +396,7 @@ class IndexingRunner:
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
case "notion_import":
case DataSourceType.NOTION_IMPORT:
if (
not data_source_info
or "notion_workspace_id" not in data_source_info
@ -416,7 +418,7 @@ class IndexingRunner:
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
case "website_crawl":
case DataSourceType.WEBSITE_CRAWL:
if (
not data_source_info
or "provider" not in data_source_info
@ -444,7 +446,7 @@ class IndexingRunner:
# update document status to splitting
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="splitting",
after_indexing_status=IndexingStatus.SPLITTING,
extra_update_params={
DatasetDocument.parsing_completed_at: naive_utc_now(),
},
@ -543,7 +545,8 @@ class IndexingRunner:
"""
Clean the document text according to the processing rules.
"""
if processing_rule.mode == "automatic":
rules: AutomaticRulesConfig | dict[str, Any]
if processing_rule.mode == ProcessRuleMode.AUTOMATIC:
rules = DatasetProcessRule.AUTOMATIC_RULES
else:
rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
@ -634,7 +637,7 @@ class IndexingRunner:
# update document status to completed
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="completed",
after_indexing_status=IndexingStatus.COMPLETED,
extra_update_params={
DatasetDocument.tokens: tokens,
DatasetDocument.completed_at: naive_utc_now(),
@ -657,10 +660,10 @@ class IndexingRunner:
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == "indexing",
DocumentSegment.status == SegmentStatus.INDEXING,
).update(
{
DocumentSegment.status: "completed",
DocumentSegment.status: SegmentStatus.COMPLETED,
DocumentSegment.enabled: True,
DocumentSegment.completed_at: naive_utc_now(),
}
@ -701,10 +704,10 @@ class IndexingRunner:
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == "indexing",
DocumentSegment.status == SegmentStatus.INDEXING,
).update(
{
DocumentSegment.status: "completed",
DocumentSegment.status: SegmentStatus.COMPLETED,
DocumentSegment.enabled: True,
DocumentSegment.completed_at: naive_utc_now(),
}
@ -723,7 +726,7 @@ class IndexingRunner:
@staticmethod
def _update_document_index_status(
document_id: str, after_indexing_status: str, extra_update_params: dict | None = None
document_id: str, after_indexing_status: IndexingStatus, extra_update_params: dict | None = None
):
"""
Update the document indexing status.
@ -756,7 +759,7 @@ class IndexingRunner:
dataset: Dataset,
text_docs: list[Document],
doc_language: str,
process_rule: dict,
process_rule: Mapping[str, Any],
current_user: Account | None = None,
) -> list[Document]:
# get embedding model instance
@ -801,7 +804,7 @@ class IndexingRunner:
cur_time = naive_utc_now()
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="indexing",
after_indexing_status=IndexingStatus.INDEXING,
extra_update_params={
DatasetDocument.cleaning_completed_at: cur_time,
DatasetDocument.splitting_completed_at: cur_time,
@ -813,7 +816,7 @@ class IndexingRunner:
self._update_segments_by_document(
dataset_document_id=dataset_document.id,
update_params={
DocumentSegment.status: "indexing",
DocumentSegment.status: SegmentStatus.INDEXING,
DocumentSegment.indexing_at: naive_utc_now(),
},
)

View File

@ -55,15 +55,31 @@ def build_protected_resource_metadata_discovery_urls(
"""
urls = []
parsed_server_url = urlparse(server_url)
base_url = f"{parsed_server_url.scheme}://{parsed_server_url.netloc}"
path = parsed_server_url.path.rstrip("/")
# First priority: URL from WWW-Authenticate header
if www_auth_resource_metadata_url:
urls.append(www_auth_resource_metadata_url)
parsed_metadata_url = urlparse(www_auth_resource_metadata_url)
normalized_metadata_url = None
if parsed_metadata_url.scheme and parsed_metadata_url.netloc:
normalized_metadata_url = www_auth_resource_metadata_url
elif not parsed_metadata_url.scheme and parsed_metadata_url.netloc:
normalized_metadata_url = f"{parsed_server_url.scheme}:{www_auth_resource_metadata_url}"
elif (
not parsed_metadata_url.scheme
and not parsed_metadata_url.netloc
and parsed_metadata_url.path.startswith("/")
):
first_segment = parsed_metadata_url.path.lstrip("/").split("/", 1)[0]
if first_segment == ".well-known" or "." not in first_segment:
normalized_metadata_url = urljoin(base_url, parsed_metadata_url.path)
if normalized_metadata_url:
urls.append(normalized_metadata_url)
# Fallback: construct from server URL
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/")
# Priority 2: With path insertion (e.g., /.well-known/oauth-protected-resource/public/mcp)
if path:
path_url = f"{base_url}/.well-known/oauth-protected-resource{path}"

View File

@ -58,7 +58,7 @@ from core.ops.entities.trace_entity import (
)
from core.repositories import DifyCoreRepositoryFactory
from dify_graph.entities import WorkflowNodeExecution
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import WorkflowNodeExecutionTriggeredFrom
@ -302,11 +302,11 @@ class AliyunDataTrace(BaseTraceInstance):
self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata
):
try:
if node_execution.node_type == NodeType.LLM:
if node_execution.node_type == BuiltinNodeTypes.LLM:
node_span = self.build_workflow_llm_span(trace_info, node_execution, trace_metadata)
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
elif node_execution.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
node_span = self.build_workflow_retrieval_span(trace_info, node_execution, trace_metadata)
elif node_execution.node_type == NodeType.TOOL:
elif node_execution.node_type == BuiltinNodeTypes.TOOL:
node_span = self.build_workflow_tool_span(trace_info, node_execution, trace_metadata)
else:
node_span = self.build_workflow_task_span(trace_info, node_execution, trace_metadata)

View File

@ -155,8 +155,8 @@ def wrap_span_metadata(metadata, **kwargs):
return metadata
# Mapping from NodeType string values to OpenInference span kinds.
# NodeType values not listed here default to CHAIN.
# Mapping from built-in node type strings to OpenInference span kinds.
# Node types not listed here default to CHAIN.
_NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = {
"llm": OpenInferenceSpanKindValues.LLM,
"knowledge-retrieval": OpenInferenceSpanKindValues.RETRIEVER,
@ -168,7 +168,7 @@ _NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = {
def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues:
"""Return the OpenInference span kind for a given workflow node type.
Covers every ``NodeType`` enum value. Nodes that do not have a
Covers every built-in node type string. Nodes that do not have a
specialised span kind (e.g. ``start``, ``end``, ``if-else``,
``code``, ``loop``, ``iteration``, etc.) are mapped to ``CHAIN``.
"""

View File

@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
)
from core.ops.utils import filter_none_values
from core.repositories import DifyCoreRepositoryFactory
from dify_graph.enums import NodeType
from dify_graph.enums import BuiltinNodeTypes
from extensions.ext_database import db
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
from models.enums import MessageStatus
@ -141,7 +141,7 @@ class LangFuseDataTrace(BaseTraceInstance):
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == NodeType.LLM:
if node_type == BuiltinNodeTypes.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs or {}

View File

@ -28,7 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
)
from core.ops.utils import filter_none_values, generate_dotted_order
from core.repositories import DifyCoreRepositoryFactory
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@ -163,7 +163,7 @@ class LangSmithDataTrace(BaseTraceInstance):
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == NodeType.LLM:
if node_type == BuiltinNodeTypes.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs or {}
@ -197,7 +197,7 @@ class LangSmithDataTrace(BaseTraceInstance):
"ls_model_name": process_data.get("model_name", ""),
}
)
elif node_type == NodeType.KNOWLEDGE_RETRIEVAL:
elif node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
run_type = LangSmithRunType.retriever
else:
run_type = LangSmithRunType.tool

View File

@ -23,7 +23,7 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from dify_graph.enums import NodeType
from dify_graph.enums import BuiltinNodeTypes
from extensions.ext_database import db
from models import EndUser
from models.workflow import WorkflowNodeExecutionModel
@ -145,10 +145,10 @@ class MLflowDataTrace(BaseTraceInstance):
"app_name": node.title,
}
if node.node_type in (NodeType.LLM, NodeType.QUESTION_CLASSIFIER):
if node.node_type in (BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER):
inputs, llm_attributes = self._parse_llm_inputs_and_attributes(node)
attributes.update(llm_attributes)
elif node.node_type == NodeType.HTTP_REQUEST:
elif node.node_type == BuiltinNodeTypes.HTTP_REQUEST:
inputs = node.process_data # contains request URL
if not inputs:
@ -180,9 +180,9 @@ class MLflowDataTrace(BaseTraceInstance):
# End node span
finished_at = node.created_at + timedelta(seconds=node.elapsed_time)
outputs = json.loads(node.outputs) if node.outputs else {}
if node.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
if node.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
outputs = self._parse_knowledge_retrieval_outputs(outputs)
elif node.node_type == NodeType.LLM:
elif node.node_type == BuiltinNodeTypes.LLM:
outputs = outputs.get("text", outputs)
node_span.end(
outputs=outputs,
@ -471,13 +471,13 @@ class MLflowDataTrace(BaseTraceInstance):
def _get_node_span_type(self, node_type: str) -> str:
"""Map Dify node types to MLflow span types"""
node_type_mapping = {
NodeType.LLM: SpanType.LLM,
NodeType.QUESTION_CLASSIFIER: SpanType.LLM,
NodeType.KNOWLEDGE_RETRIEVAL: SpanType.RETRIEVER,
NodeType.TOOL: SpanType.TOOL,
NodeType.CODE: SpanType.TOOL,
NodeType.HTTP_REQUEST: SpanType.TOOL,
NodeType.AGENT: SpanType.AGENT,
BuiltinNodeTypes.LLM: SpanType.LLM,
BuiltinNodeTypes.QUESTION_CLASSIFIER: SpanType.LLM,
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: SpanType.RETRIEVER,
BuiltinNodeTypes.TOOL: SpanType.TOOL,
BuiltinNodeTypes.CODE: SpanType.TOOL,
BuiltinNodeTypes.HTTP_REQUEST: SpanType.TOOL,
BuiltinNodeTypes.AGENT: SpanType.AGENT,
}
return node_type_mapping.get(node_type, "CHAIN") # type: ignore[arg-type,call-overload]

View File

@ -23,7 +23,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.repositories import DifyCoreRepositoryFactory
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@ -187,7 +187,7 @@ class OpikDataTrace(BaseTraceInstance):
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == NodeType.LLM:
if node_type == BuiltinNodeTypes.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs or {}

View File

@ -27,7 +27,7 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from dify_graph.entities.workflow_node_execution import (
WorkflowNodeExecution,
)
from dify_graph.nodes import NodeType
from dify_graph.nodes import BuiltinNodeTypes
from extensions.ext_database import db
from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
@ -179,7 +179,7 @@ class TencentDataTrace(BaseTraceInstance):
if node_span:
self.trace_client.add_span(node_span)
if node_execution.node_type == NodeType.LLM:
if node_execution.node_type == BuiltinNodeTypes.LLM:
self._record_llm_metrics(node_execution)
except Exception:
logger.exception("[Tencent APM] Failed to process node execution: %s", node_execution.id)
@ -192,15 +192,15 @@ class TencentDataTrace(BaseTraceInstance):
) -> SpanData | None:
"""Build span for different node types"""
try:
if node_execution.node_type == NodeType.LLM:
if node_execution.node_type == BuiltinNodeTypes.LLM:
return TencentSpanBuilder.build_workflow_llm_span(
trace_id, workflow_span_id, trace_info, node_execution
)
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
elif node_execution.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
return TencentSpanBuilder.build_workflow_retrieval_span(
trace_id, workflow_span_id, trace_info, node_execution
)
elif node_execution.node_type == NodeType.TOOL:
elif node_execution.node_type == BuiltinNodeTypes.TOOL:
return TencentSpanBuilder.build_workflow_tool_span(
trace_id, workflow_span_id, trace_info, node_execution
)

View File

@ -31,7 +31,7 @@ from core.ops.entities.trace_entity import (
)
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from core.repositories import DifyCoreRepositoryFactory
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@ -175,7 +175,7 @@ class WeaveDataTrace(BaseTraceInstance):
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == NodeType.LLM:
if node_type == BuiltinNodeTypes.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs or {}

View File

@ -1,5 +1,5 @@
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from dify_graph.enums import NodeType
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.nodes.parameter_extractor.entities import (
ModelConfig as ParameterExtractorModelConfig,
)
@ -52,7 +52,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
instruction=instruction, # instruct with variables are not supported
)
node_data_dict = node_data.model_dump()
node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR
node_data_dict["type"] = BuiltinNodeTypes.PARAMETER_EXTRACTOR
execution = workflow_service.run_free_workflow_node(
node_data_dict,
tenant_id=tenant_id,

View File

@ -196,6 +196,8 @@ class ProviderManager:
if preferred_provider_type_record:
preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type)
elif dify_config.EDITION == "CLOUD" and system_configuration.enabled:
preferred_provider_type = ProviderType.SYSTEM
elif custom_configuration.provider or custom_configuration.models:
preferred_provider_type = ProviderType.CUSTOM
elif system_configuration.enabled:
@ -305,9 +307,7 @@ class ProviderManager:
available_models = provider_configurations.get_models(model_type=model_type, only_active=True)
if available_models:
available_model = next(
(model for model in available_models if model.model == "gpt-4"), available_models[0]
)
available_model = available_models[0]
default_model = TenantDefaultModel(
tenant_id=tenant_id,

View File

@ -1,3 +1,5 @@
from typing_extensions import TypedDict
from core.model_manager import ModelInstance, ModelManager
from core.rag.data_post_processor.reorder import ReorderRunner
from core.rag.index_processor.constant.query_type import QueryType
@ -10,6 +12,26 @@ from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
class RerankingModelDict(TypedDict):
reranking_provider_name: str
reranking_model_name: str
class VectorSettingDict(TypedDict):
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSettingDict(TypedDict):
keyword_weight: float
class WeightsDict(TypedDict):
vector_setting: VectorSettingDict
keyword_setting: KeywordSettingDict
class DataPostProcessor:
"""Interface for data post-processing document."""
@ -17,8 +39,8 @@ class DataPostProcessor:
self,
tenant_id: str,
reranking_mode: str,
reranking_model: dict | None = None,
weights: dict | None = None,
reranking_model: RerankingModelDict | None = None,
weights: WeightsDict | None = None,
reorder_enabled: bool = False,
):
self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights)
@ -45,8 +67,8 @@ class DataPostProcessor:
self,
reranking_mode: str,
tenant_id: str,
reranking_model: dict | None = None,
weights: dict | None = None,
reranking_model: RerankingModelDict | None = None,
weights: WeightsDict | None = None,
) -> BaseRerankRunner | None:
if reranking_mode == RerankMode.WEIGHTED_SCORE and weights:
runner = RerankRunnerFactory.create_rerank_runner(
@ -79,12 +101,14 @@ class DataPostProcessor:
return ReorderRunner()
return None
def _get_rerank_model_instance(self, tenant_id: str, reranking_model: dict | None) -> ModelInstance | None:
def _get_rerank_model_instance(
self, tenant_id: str, reranking_model: RerankingModelDict | None
) -> ModelInstance | None:
if reranking_model:
try:
model_manager = ModelManager()
reranking_provider_name = reranking_model.get("reranking_provider_name")
reranking_model_name = reranking_model.get("reranking_model_name")
reranking_provider_name = reranking_model["reranking_provider_name"]
reranking_model_name = reranking_model["reranking_model_name"]
if not reranking_provider_name or not reranking_model_name:
return None
rerank_model_instance = model_manager.get_model_instance(

View File

@ -1,19 +1,20 @@
import concurrent.futures
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from typing import Any, NotRequired
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from typing_extensions import TypedDict
from configs import dify_config
from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
@ -35,7 +36,46 @@ from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
class SegmentAttachmentResult(TypedDict):
attachment_info: AttachmentInfoDict
segment_id: str
class SegmentAttachmentInfoResult(TypedDict):
attachment_id: str
attachment_info: AttachmentInfoDict
segment_id: str
class ChildChunkDetail(TypedDict):
id: str
content: str
position: int
score: float
class SegmentChildMapDetail(TypedDict):
max_score: float
child_chunks: list[ChildChunkDetail]
class SegmentRecord(TypedDict):
segment: DocumentSegment
score: NotRequired[float]
child_chunks: NotRequired[list[ChildChunkDetail]]
files: NotRequired[list[AttachmentInfoDict]]
class DefaultRetrievalModelDict(TypedDict):
search_method: RetrievalMethod | str
reranking_enable: bool
reranking_model: RerankingModelDict
top_k: int
score_threshold_enabled: bool
default_retrieval_model: DefaultRetrievalModelDict = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@ -56,9 +96,9 @@ class RetrievalService:
query: str,
top_k: int = 4,
score_threshold: float | None = 0.0,
reranking_model: dict | None = None,
reranking_model: RerankingModelDict | None = None,
reranking_mode: str = "reranking_model",
weights: dict | None = None,
weights: WeightsDict | None = None,
document_ids_filter: list[str] | None = None,
attachment_ids: list | None = None,
):
@ -235,7 +275,7 @@ class RetrievalService:
query: str,
top_k: int,
score_threshold: float | None,
reranking_model: dict | None,
reranking_model: RerankingModelDict | None,
all_documents: list,
retrieval_method: RetrievalMethod,
exceptions: list,
@ -277,8 +317,8 @@ class RetrievalService:
if documents:
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and reranking_model["reranking_model_name"]
and reranking_model["reranking_provider_name"]
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH
):
data_post_processor = DataPostProcessor(
@ -288,8 +328,8 @@ class RetrievalService:
model_manager = ModelManager()
is_support_vision = model_manager.check_model_support_vision(
tenant_id=dataset.tenant_id,
provider=reranking_model.get("reranking_provider_name") or "",
model=reranking_model.get("reranking_model_name") or "",
provider=reranking_model["reranking_provider_name"],
model=reranking_model["reranking_model_name"],
model_type=ModelType.RERANK,
)
if is_support_vision:
@ -329,7 +369,7 @@ class RetrievalService:
query: str,
top_k: int,
score_threshold: float | None,
reranking_model: dict | None,
reranking_model: RerankingModelDict | None,
all_documents: list,
retrieval_method: str,
exceptions: list,
@ -349,8 +389,8 @@ class RetrievalService:
if documents:
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and reranking_model["reranking_model_name"]
and reranking_model["reranking_provider_name"]
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH
):
data_post_processor = DataPostProcessor(
@ -459,7 +499,7 @@ class RetrievalService:
segment_ids: list[str] = []
index_node_segments: list[DocumentSegment] = []
segments: list[DocumentSegment] = []
attachment_map: dict[str, list[dict[str, Any]]] = {}
attachment_map: dict[str, list[AttachmentInfoDict]] = {}
child_chunk_map: dict[str, list[ChildChunk]] = {}
doc_segment_map: dict[str, list[str]] = {}
segment_summary_map: dict[str, str] = {} # Map segment_id to summary content
@ -544,12 +584,12 @@ class RetrievalService:
segment_summary_map[summary.chunk_id] = summary.summary_content
include_segment_ids = set()
segment_child_map: dict[str, dict[str, Any]] = {}
records: list[dict[str, Any]] = []
segment_child_map: dict[str, SegmentChildMapDetail] = {}
records: list[SegmentRecord] = []
for segment in segments:
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
attachment_infos: list[AttachmentInfoDict] = attachment_map.get(segment.id, [])
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
@ -560,14 +600,14 @@ class RetrievalService:
max_score = summary_score_map.get(segment.id, 0.0)
if child_chunks or attachment_infos:
child_chunk_details = []
child_chunk_details: list[ChildChunkDetail] = []
for child_chunk in child_chunks:
child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id)
if child_document:
child_score = child_document.metadata.get("score", 0.0)
else:
child_score = 0.0
child_chunk_detail = {
child_chunk_detail: ChildChunkDetail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
@ -580,7 +620,7 @@ class RetrievalService:
if file_document:
max_score = max(max_score, file_document.metadata.get("score", 0.0))
map_detail = {
map_detail: SegmentChildMapDetail = {
"max_score": max_score,
"child_chunks": child_chunk_details,
}
@ -593,7 +633,7 @@ class RetrievalService:
"max_score": summary_score,
"child_chunks": [],
}
record: dict[str, Any] = {
record: SegmentRecord = {
"segment": segment,
}
records.append(record)
@ -617,19 +657,19 @@ class RetrievalService:
if file_doc:
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
record = {
another_record: SegmentRecord = {
"segment": segment,
"score": max_score,
}
records.append(record)
records.append(another_record)
# Add child chunks information to records
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
record["child_chunks"] = segment_child_map[record["segment"].id]["child_chunks"]
record["score"] = segment_child_map[record["segment"].id]["max_score"]
if record["segment"].id in attachment_map:
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
record["files"] = attachment_map[record["segment"].id]
result: list[RetrievalSegments] = []
for record in records:
@ -693,9 +733,9 @@ class RetrievalService:
query: str | None = None,
top_k: int = 4,
score_threshold: float | None = 0.0,
reranking_model: dict | None = None,
reranking_model: RerankingModelDict | None = None,
reranking_mode: str = "reranking_model",
weights: dict | None = None,
weights: WeightsDict | None = None,
document_ids_filter: list[str] | None = None,
attachment_id: str | None = None,
):
@ -807,7 +847,7 @@ class RetrievalService:
@classmethod
def get_segment_attachment_info(
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
) -> dict[str, Any] | None:
) -> SegmentAttachmentResult | None:
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
if upload_file:
attachment_binding = (
@ -816,7 +856,7 @@ class RetrievalService:
.first()
)
if attachment_binding:
attachment_info = {
attachment_info: AttachmentInfoDict = {
"id": upload_file.id,
"name": upload_file.name,
"extension": "." + upload_file.extension,
@ -828,8 +868,10 @@ class RetrievalService:
return None
@classmethod
def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
attachment_infos = []
def get_segment_attachment_infos(
cls, attachment_ids: list[str], session: Session
) -> list[SegmentAttachmentInfoResult]:
attachment_infos: list[SegmentAttachmentInfoResult] = []
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
if upload_files:
upload_file_ids = [upload_file.id for upload_file in upload_files]
@ -843,7 +885,7 @@ class RetrievalService:
if attachment_bindings:
for upload_file in upload_files:
attachment_binding = attachment_binding_map.get(upload_file.id)
attachment_info = {
info: AttachmentInfoDict = {
"id": upload_file.id,
"name": upload_file.name,
"extension": "." + upload_file.extension,
@ -855,7 +897,7 @@ class RetrievalService:
attachment_infos.append(
{
"attachment_id": attachment_binding.attachment_id,
"attachment_info": attachment_info,
"attachment_info": info,
"segment_id": attachment_binding.segment_id,
}
)

View File

@ -0,0 +1,361 @@
import json
import logging
import time
from typing import Any
import holo_search_sdk as holo # type: ignore
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
from psycopg import sql as psql
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class HologresVectorConfig(BaseModel):
"""
Configuration for Hologres vector database connection.
In Hologres, access_key_id is used as the PostgreSQL username,
and access_key_secret is used as the PostgreSQL password.
"""
host: str
port: int = 80
database: str
access_key_id: str
access_key_secret: str
schema_name: str = "public"
tokenizer: TokenizerType = "jieba"
distance_method: DistanceType = "Cosine"
base_quantization_type: BaseQuantizationType = "rabitq"
max_degree: int = 64
ef_construction: int = 400
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values.get("host"):
raise ValueError("config HOLOGRES_HOST is required")
if not values.get("database"):
raise ValueError("config HOLOGRES_DATABASE is required")
if not values.get("access_key_id"):
raise ValueError("config HOLOGRES_ACCESS_KEY_ID is required")
if not values.get("access_key_secret"):
raise ValueError("config HOLOGRES_ACCESS_KEY_SECRET is required")
return values
class HologresVector(BaseVector):
"""
Hologres vector storage implementation using holo-search-sdk.
Supports semantic search (vector), full-text search, and hybrid search.
"""
def __init__(self, collection_name: str, config: HologresVectorConfig):
super().__init__(collection_name)
self._config = config
self._client = self._init_client(config)
self.table_name = f"embedding_{collection_name}".lower()
def _init_client(self, config: HologresVectorConfig):
"""Initialize and return a holo-search-sdk client."""
client = holo.connect(
host=config.host,
port=config.port,
database=config.database,
access_key_id=config.access_key_id,
access_key_secret=config.access_key_secret,
schema=config.schema_name,
)
client.connect()
return client
def get_type(self) -> str:
return VectorType.HOLOGRES
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""Create collection table with vector and full-text indexes, then add texts."""
dimension = len(embeddings[0])
self._create_collection(dimension)
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""Add texts with embeddings to the collection using batch upsert."""
if not documents:
return []
pks: list[str] = []
batch_size = 100
for i in range(0, len(documents), batch_size):
batch_docs = documents[i : i + batch_size]
batch_embeddings = embeddings[i : i + batch_size]
values = []
column_names = ["id", "text", "meta", "embedding"]
for j, doc in enumerate(batch_docs):
doc_id = doc.metadata.get("doc_id", "") if doc.metadata else ""
pks.append(doc_id)
values.append(
[
doc_id,
doc.page_content,
json.dumps(doc.metadata or {}),
batch_embeddings[j],
]
)
table = self._client.open_table(self.table_name)
table.upsert_multi(
index_column="id",
values=values,
column_names=column_names,
update=True,
update_columns=["text", "meta", "embedding"],
)
return pks
def text_exists(self, id: str) -> bool:
"""Check if a text with the given doc_id exists in the collection."""
if not self._client.check_table_exist(self.table_name):
return False
result = self._client.execute(
psql.SQL("SELECT 1 FROM {} WHERE id = {} LIMIT 1").format(
psql.Identifier(self.table_name), psql.Literal(id)
),
fetch_result=True,
)
return bool(result)
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str] | None:
"""Get document IDs by metadata field key and value."""
result = self._client.execute(
psql.SQL("SELECT id FROM {} WHERE meta->>{} = {}").format(
psql.Identifier(self.table_name), psql.Literal(key), psql.Literal(value)
),
fetch_result=True,
)
if result:
return [row[0] for row in result]
return None
def delete_by_ids(self, ids: list[str]):
"""Delete documents by their doc_id list."""
if not ids:
return
if not self._client.check_table_exist(self.table_name):
return
self._client.execute(
psql.SQL("DELETE FROM {} WHERE id IN ({})").format(
psql.Identifier(self.table_name),
psql.SQL(", ").join(psql.Literal(id) for id in ids),
)
)
def delete_by_metadata_field(self, key: str, value: str):
"""Delete documents by metadata field key and value."""
if not self._client.check_table_exist(self.table_name):
return
self._client.execute(
psql.SQL("DELETE FROM {} WHERE meta->>{} = {}").format(
psql.Identifier(self.table_name), psql.Literal(key), psql.Literal(value)
)
)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Search for documents by vector similarity."""
if not self._client.check_table_exist(self.table_name):
return []
top_k = kwargs.get("top_k", 4)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
table = self._client.open_table(self.table_name)
query = (
table.search_vector(
vector=query_vector,
column="embedding",
distance_method=self._config.distance_method,
output_name="distance",
)
.select(["id", "text", "meta"])
.limit(top_k)
)
# Apply document_ids_filter if provided
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
filter_sql = psql.SQL("meta->>'document_id' IN ({})").format(
psql.SQL(", ").join(psql.Literal(id) for id in document_ids_filter)
)
query = query.where(filter_sql)
results = query.fetchall()
return self._process_vector_results(results, score_threshold)
def _process_vector_results(self, results: list, score_threshold: float) -> list[Document]:
"""Process vector search results into Document objects."""
docs = []
for row in results:
# row format: (distance, id, text, meta)
# distance is first because search_vector() adds the computed column before selected columns
distance = row[0]
text = row[2]
meta = row[3]
if isinstance(meta, str):
meta = json.loads(meta)
# Convert distance to similarity score (consistent with pgvector)
score = 1 - distance
meta["score"] = score
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=meta))
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Search for documents by full-text search."""
if not self._client.check_table_exist(self.table_name):
return []
top_k = kwargs.get("top_k", 4)
table = self._client.open_table(self.table_name)
search_query = table.search_text(
column="text",
expression=query,
return_score=True,
return_score_name="score",
return_all_columns=True,
).limit(top_k)
# Apply document_ids_filter if provided
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
filter_sql = psql.SQL("meta->>'document_id' IN ({})").format(
psql.SQL(", ").join(psql.Literal(id) for id in document_ids_filter)
)
search_query = search_query.where(filter_sql)
results = search_query.fetchall()
return self._process_full_text_results(results)
def _process_full_text_results(self, results: list) -> list[Document]:
"""Process full-text search results into Document objects."""
docs = []
for row in results:
# row format: (id, text, meta, embedding, score)
text = row[1]
meta = row[2]
score = row[-1] # score is the last column from return_score
if isinstance(meta, str):
meta = json.loads(meta)
meta["score"] = score
docs.append(Document(page_content=text, metadata=meta))
return docs
def delete(self):
"""Delete the entire collection table."""
if self._client.check_table_exist(self.table_name):
self._client.drop_table(self.table_name)
def _create_collection(self, dimension: int):
"""Create the collection table with vector and full-text indexes."""
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
if not self._client.check_table_exist(self.table_name):
# Create table via SQL with CHECK constraint for vector dimension
create_table_sql = psql.SQL("""
CREATE TABLE IF NOT EXISTS {} (
id TEXT PRIMARY KEY,
text TEXT NOT NULL,
meta JSONB NOT NULL,
embedding float4[] NOT NULL
CHECK (array_ndims(embedding) = 1
AND array_length(embedding, 1) = {})
);
""").format(psql.Identifier(self.table_name), psql.Literal(dimension))
self._client.execute(create_table_sql)
# Wait for table to be fully ready before creating indexes
max_wait_seconds = 30
poll_interval = 2
for _ in range(max_wait_seconds // poll_interval):
if self._client.check_table_exist(self.table_name):
break
time.sleep(poll_interval)
else:
raise RuntimeError(f"Table {self.table_name} was not ready after {max_wait_seconds}s")
# Open table and set vector index
table = self._client.open_table(self.table_name)
table.set_vector_index(
column="embedding",
distance_method=self._config.distance_method,
base_quantization_type=self._config.base_quantization_type,
max_degree=self._config.max_degree,
ef_construction=self._config.ef_construction,
use_reorder=self._config.base_quantization_type == "rabitq",
)
# Create full-text search index
table.create_text_index(
index_name=f"ft_idx_{self._collection_name}",
column="text",
tokenizer=self._config.tokenizer,
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class HologresVectorFactory(AbstractVectorFactory):
"""Factory class for creating HologresVector instances."""
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> HologresVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.HOLOGRES, collection_name))
return HologresVector(
collection_name=collection_name,
config=HologresVectorConfig(
host=dify_config.HOLOGRES_HOST or "",
port=dify_config.HOLOGRES_PORT,
database=dify_config.HOLOGRES_DATABASE or "",
access_key_id=dify_config.HOLOGRES_ACCESS_KEY_ID or "",
access_key_secret=dify_config.HOLOGRES_ACCESS_KEY_SECRET or "",
schema_name=dify_config.HOLOGRES_SCHEMA,
tokenizer=dify_config.HOLOGRES_TOKENIZER,
distance_method=dify_config.HOLOGRES_DISTANCE_METHOD,
base_quantization_type=dify_config.HOLOGRES_BASE_QUANTIZATION_TYPE,
max_degree=dify_config.HOLOGRES_MAX_DEGREE,
ef_construction=dify_config.HOLOGRES_EF_CONSTRUCTION,
),
)

View File

@ -135,8 +135,8 @@ class PGVectoRS(BaseVector):
def get_ids_by_metadata_field(self, key: str, value: str):
result = None
with Session(self._client) as session:
select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; ")
result = session.execute(select_statement).fetchall()
select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>:key = :value")
result = session.execute(select_statement, {"key": key, "value": value}).fetchall()
if result:
return [item[0] for item in result]
else:
@ -172,9 +172,9 @@ class PGVectoRS(BaseVector):
def text_exists(self, id: str) -> bool:
with Session(self._client) as session:
select_statement = sql_text(
f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = '{id}' limit 1; "
f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = :doc_id limit 1"
)
result = session.execute(select_statement).fetchall()
result = session.execute(select_statement, {"doc_id": id}).fetchall()
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:

View File

@ -154,10 +154,8 @@ class RelytVector(BaseVector):
def get_ids_by_metadata_field(self, key: str, value: str):
result = None
with Session(self.client) as session:
select_statement = sql_text(
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'{key}' = '{value}'; """
)
result = session.execute(select_statement).fetchall()
select_statement = sql_text(f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>:key = :value""")
result = session.execute(select_statement, {"key": key, "value": value}).fetchall()
if result:
return [item[0] for item in result]
else:
@ -201,11 +199,10 @@ class RelytVector(BaseVector):
def delete_by_ids(self, ids: list[str]):
with Session(self.client) as session:
ids_str = ",".join(f"'{doc_id}'" for doc_id in ids)
select_statement = sql_text(
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = ANY(:doc_ids)"""
)
result = session.execute(select_statement).fetchall()
result = session.execute(select_statement, {"doc_ids": ids}).fetchall()
if result:
ids = [item[0] for item in result]
self.delete_by_uuids(ids)
@ -218,9 +215,9 @@ class RelytVector(BaseVector):
def text_exists(self, id: str) -> bool:
with Session(self.client) as session:
select_statement = sql_text(
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = '{id}' limit 1; """
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = :doc_id limit 1"""
)
result = session.execute(select_statement).fetchall()
result = session.execute(select_statement, {"doc_id": id}).fetchall()
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:

View File

@ -38,7 +38,7 @@ class AbstractVectorFactory(ABC):
class Vector:
def __init__(self, dataset: Dataset, attributes: list | None = None):
if attributes is None:
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
self._dataset = dataset
self._embeddings = self._get_embeddings()
self._attributes = attributes
@ -191,6 +191,10 @@ class Vector:
from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory
return IrisVectorFactory
case VectorType.HOLOGRES:
from core.rag.datasource.vdb.hologres.hologres_vector import HologresVectorFactory
return HologresVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")

View File

@ -34,3 +34,4 @@ class VectorType(StrEnum):
MATRIXONE = "matrixone"
CLICKZETTA = "clickzetta"
IRIS = "iris"
HOLOGRES = "hologres"

View File

@ -8,6 +8,7 @@ document embeddings used in retrieval-augmented generation workflows.
import datetime
import json
import logging
import threading
import uuid as _uuid
from typing import Any
from urllib.parse import urlparse
@ -32,6 +33,9 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__)
_weaviate_client: weaviate.WeaviateClient | None = None
_weaviate_client_lock = threading.Lock()
class WeaviateConfig(BaseModel):
"""
@ -99,43 +103,52 @@ class WeaviateVector(BaseVector):
Configures both HTTP and gRPC connections with proper authentication.
"""
p = urlparse(config.endpoint)
host = p.hostname or config.endpoint.replace("https://", "").replace("http://", "")
http_secure = p.scheme == "https"
http_port = p.port or (443 if http_secure else 80)
global _weaviate_client
if _weaviate_client and _weaviate_client.is_ready():
return _weaviate_client
# Parse gRPC configuration
if config.grpc_endpoint:
# Urls without scheme won't be parsed correctly in some python versions,
# see https://bugs.python.org/issue27657
grpc_endpoint_with_scheme = (
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
with _weaviate_client_lock:
if _weaviate_client and _weaviate_client.is_ready():
return _weaviate_client
p = urlparse(config.endpoint)
host = p.hostname or config.endpoint.replace("https://", "").replace("http://", "")
http_secure = p.scheme == "https"
http_port = p.port or (443 if http_secure else 80)
# Parse gRPC configuration
if config.grpc_endpoint:
# Urls without scheme won't be parsed correctly in some python versions,
# see https://bugs.python.org/issue27657
grpc_endpoint_with_scheme = (
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
)
grpc_p = urlparse(grpc_endpoint_with_scheme)
grpc_host = grpc_p.hostname or "localhost"
grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051)
grpc_secure = grpc_p.scheme == "grpcs"
else:
# Infer from HTTP endpoint as fallback
grpc_host = host
grpc_secure = http_secure
grpc_port = 443 if grpc_secure else 50051
client = weaviate.connect_to_custom(
http_host=host,
http_port=http_port,
http_secure=http_secure,
grpc_host=grpc_host,
grpc_port=grpc_port,
grpc_secure=grpc_secure,
auth_credentials=Auth.api_key(config.api_key) if config.api_key else None,
skip_init_checks=True, # Skip PyPI version check to avoid unnecessary HTTP requests
)
grpc_p = urlparse(grpc_endpoint_with_scheme)
grpc_host = grpc_p.hostname or "localhost"
grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051)
grpc_secure = grpc_p.scheme == "grpcs"
else:
# Infer from HTTP endpoint as fallback
grpc_host = host
grpc_secure = http_secure
grpc_port = 443 if grpc_secure else 50051
client = weaviate.connect_to_custom(
http_host=host,
http_port=http_port,
http_secure=http_secure,
grpc_host=grpc_host,
grpc_port=grpc_port,
grpc_secure=grpc_secure,
auth_credentials=Auth.api_key(config.api_key) if config.api_key else None,
skip_init_checks=True, # Skip PyPI version check to avoid unnecessary HTTP requests
)
if not client.is_ready():
raise ConnectionError("Vector database is not ready")
if not client.is_ready():
raise ConnectionError("Vector database is not ready")
return client
_weaviate_client = client
return client
def get_type(self) -> str:
"""Returns the vector database type identifier."""
@ -196,6 +209,7 @@ class WeaviateVector(BaseVector):
),
wc.Property(name="document_id", data_type=wc.DataType.TEXT),
wc.Property(name="doc_id", data_type=wc.DataType.TEXT),
wc.Property(name="doc_type", data_type=wc.DataType.TEXT),
wc.Property(name="chunk_index", data_type=wc.DataType.INT),
],
vector_config=wc.Configure.Vectors.self_provided(),
@ -225,6 +239,8 @@ class WeaviateVector(BaseVector):
to_add.append(wc.Property(name="document_id", data_type=wc.DataType.TEXT))
if "doc_id" not in existing:
to_add.append(wc.Property(name="doc_id", data_type=wc.DataType.TEXT))
if "doc_type" not in existing:
to_add.append(wc.Property(name="doc_type", data_type=wc.DataType.TEXT))
if "chunk_index" not in existing:
to_add.append(wc.Property(name="chunk_index", data_type=wc.DataType.INT))

View File

@ -1,8 +1,18 @@
from pydantic import BaseModel
from typing_extensions import TypedDict
from models.dataset import DocumentSegment
class AttachmentInfoDict(TypedDict):
id: str
name: str
extension: str
mime_type: str
source_url: str
size: int
class RetrievalChildChunk(BaseModel):
"""Retrieval segments."""
@ -19,5 +29,5 @@ class RetrievalSegments(BaseModel):
segment: DocumentSegment
child_chunks: list[RetrievalChildChunk] | None = None
score: float | None = None
files: list[dict[str, str | int]] | None = None
files: list[AttachmentInfoDict] | None = None
summary: str | None = None # Summary content if retrieved via summary index

View File

@ -9,8 +9,9 @@ from flask import current_app
from sqlalchemy import delete, func, select
from core.db.session_factory import session_factory
from dify_graph.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from dify_graph.repositories.index_processor_protocol import Preview, PreviewItem, QaPreview
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview
from models.dataset import Dataset, Document, DocumentSegment
from .index_processor_factory import IndexProcessorFactory
@ -51,7 +52,7 @@ class IndexProcessor:
original_document_id: str,
chunks: Mapping[str, Any],
batch: Any,
summary_index_setting: dict | None = None,
summary_index_setting: SummaryIndexSettingDict | None = None,
):
with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first()
@ -131,7 +132,12 @@ class IndexProcessor:
}
def get_preview_output(
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
self,
chunks: Any,
dataset_id: str,
document_id: str,
chunk_structure: str,
summary_index_setting: SummaryIndexSettingDict | None,
) -> Preview:
doc_language = None
with session_factory.create_session() as session:

View File

@ -7,14 +7,16 @@ import os
import re
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, NotRequired, Optional
from urllib.parse import unquote, urlparse
import httpx
from typing_extensions import TypedDict
from configs import dify_config
from core.entities.knowledge_entities import PreviewDetail
from core.helper import ssrf_proxy
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import AttachmentDocument, Document
@ -35,6 +37,13 @@ if TYPE_CHECKING:
from core.model_manager import ModelInstance
class SummaryIndexSettingDict(TypedDict):
enable: bool
model_name: NotRequired[str]
model_provider_name: NotRequired[str]
summary_prompt: NotRequired[str]
class BaseIndexProcessor(ABC):
"""Interface for extract files."""
@ -51,7 +60,7 @@ class BaseIndexProcessor(ABC):
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
summary_index_setting: SummaryIndexSettingDict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""
@ -98,7 +107,7 @@ class BaseIndexProcessor(ABC):
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
reranking_model: RerankingModelDict,
) -> list[Document]:
raise NotImplementedError
@ -294,7 +303,7 @@ class BaseIndexProcessor(ABC):
logging.warning("Error downloading image from %s: %s", image_url, str(e))
return None
except Exception:
logging.exception("Unexpected error downloading image from %s", image_url)
logging.warning("Unexpected error downloading image from %s", image_url, exc_info=True)
return None
def _download_tool_file(self, tool_file_id: str, current_user: Account) -> str | None:

View File

@ -14,6 +14,7 @@ from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
from core.provider_manager import ProviderManager
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
@ -22,7 +23,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
@ -175,7 +176,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
reranking_model: RerankingModelDict,
) -> list[Document]:
# Set search parameters.
results = RetrievalService.retrieve(
@ -278,7 +279,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
summary_index_setting: SummaryIndexSettingDict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""
@ -362,7 +363,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
def generate_summary(
tenant_id: str,
text: str,
summary_index_setting: dict | None = None,
summary_index_setting: SummaryIndexSettingDict | None = None,
segment_id: str | None = None,
document_language: str | None = None,
) -> tuple[str, LLMUsage]:

View File

@ -11,6 +11,7 @@ from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
@ -18,7 +19,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
@ -215,7 +216,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
reranking_model: RerankingModelDict,
) -> list[Document]:
# Set search parameters.
results = RetrievalService.retrieve(
@ -361,7 +362,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
summary_index_setting: SummaryIndexSettingDict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""

View File

@ -15,13 +15,14 @@ from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.llm_generator import LLMGenerator
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
@ -185,7 +186,7 @@ class QAIndexProcessor(BaseIndexProcessor):
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
reranking_model: RerankingModelDict,
):
# Set search parameters.
results = RetrievalService.retrieve(
@ -244,7 +245,7 @@ class QAIndexProcessor(BaseIndexProcessor):
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
summary_index_setting: SummaryIndexSettingDict,
doc_language: str | None = None,
) -> list[PreviewDetail]:
"""

View File

@ -31,7 +31,7 @@ from core.ops.utils import measure_time
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
@ -56,18 +56,18 @@ from core.rag.retrieval.template_prompts import (
)
from core.tools.signature import sign_upload_file
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from dify_graph.file import File, FileTransferMethod, FileType
from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.nodes.knowledge_retrieval import exc
from dify_graph.repositories.rag_retrieval_protocol import (
from core.workflow.nodes.knowledge_retrieval import exc
from core.workflow.nodes.knowledge_retrieval.retrieval import (
KnowledgeRetrievalRequest,
Source,
SourceChildChunk,
SourceMetadata,
)
from dify_graph.file import File, FileTransferMethod, FileType
from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
@ -83,7 +83,7 @@ from models.dataset import (
)
from models.dataset import Document as DatasetDocument
from models.dataset import Document as DocumentModel
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, DatasetQuerySource
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureService
@ -727,8 +727,8 @@ class DatasetRetrieval:
top_k: int,
score_threshold: float,
reranking_mode: str,
reranking_model: dict | None = None,
weights: dict[str, Any] | None = None,
reranking_model: RerankingModelDict | None = None,
weights: WeightsDict | None = None,
reranking_enable: bool = True,
message_id: str | None = None,
metadata_filter_document_ids: dict[str, list[str]] | None = None,
@ -1008,7 +1008,7 @@ class DatasetRetrieval:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=json.dumps(contents),
source="app",
source=DatasetQuerySource.APP,
source_app_id=app_id,
created_by_role=CreatorUserRole(user_from),
created_by=user_id,
@ -1181,8 +1181,8 @@ class DatasetRetrieval:
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source(),
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
reranking_provider_name=retrieve_config.reranking_model["reranking_provider_name"],
reranking_model_name=retrieve_config.reranking_model["reranking_model_name"],
)
tools.append(tool)
@ -1685,8 +1685,8 @@ class DatasetRetrieval:
tenant_id: str,
reranking_enable: bool,
reranking_mode: str,
reranking_model: dict | None,
weights: dict[str, Any] | None,
reranking_model: RerankingModelDict | None,
weights: WeightsDict | None,
top_k: int,
score_threshold: float,
query: str | None,

View File

@ -2,6 +2,7 @@ import concurrent.futures
import logging
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
from services.summary_index_service import SummaryIndexService
from tasks.generate_summary_index_task import generate_summary_index_task
@ -11,7 +12,11 @@ logger = logging.getLogger(__name__)
class SummaryIndex:
def generate_and_vectorize_summary(
self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
self,
dataset_id: str,
document_id: str,
is_preview: bool,
summary_index_setting: SummaryIndexSettingDict | None = None,
) -> None:
if is_preview:
with session_factory.create_session() as session:

View File

@ -18,7 +18,7 @@ from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_att
from configs import dify_config
from dify_graph.entities import WorkflowNodeExecution
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter
@ -146,7 +146,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
index=db_model.index,
predecessor_node_id=db_model.predecessor_node_id,
node_id=db_model.node_id,
node_type=NodeType(db_model.node_type),
node_type=db_model.node_type,
title=db_model.title,
inputs=inputs,
process_data=process_data,

View File

@ -72,6 +72,11 @@ class ApiProviderControllerItem(TypedDict):
controller: ApiToolProviderController
class EmojiIconDict(TypedDict):
background: str
content: str
class ToolManager:
_builtin_provider_lock = Lock()
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
@ -916,7 +921,7 @@ class ToolManager:
)
@classmethod
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]:
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
try:
workflow_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
@ -933,7 +938,7 @@ class ToolManager:
return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]:
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
try:
api_provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
@ -950,7 +955,7 @@ class ToolManager:
return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str:
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str:
try:
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
@ -970,7 +975,7 @@ class ToolManager:
tenant_id: str,
provider_type: ToolProviderType,
provider_id: str,
) -> str | Mapping[str, str]:
) -> str | EmojiIconDict | dict[str, str]:
"""
get the tool icon

View File

@ -116,6 +116,7 @@ class ToolParameterConfigurationManager:
return a deep copy of parameters with decrypted values
"""
parameters = self._deep_copy(parameters)
cache = ToolParameterCache(
tenant_id=self.tenant_id,

View File

@ -1,5 +1,4 @@
import threading
from typing import Any
from flask import Flask, current_app
from pydantic import BaseModel, Field
@ -13,11 +12,12 @@ from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
default_retrieval_model: dict[str, Any] = {
default_retrieval_model: DefaultRetrievalModelDict = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},

View File

@ -1,9 +1,10 @@
from typing import Any, cast
from typing import NotRequired, TypedDict, cast
from pydantic import BaseModel, Field
from sqlalchemy import select
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
@ -16,7 +17,19 @@ from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model: dict[str, Any] = {
class DefaultRetrievalModelDict(TypedDict):
search_method: RetrievalMethod
reranking_enable: bool
reranking_model: RerankingModelDict
reranking_mode: NotRequired[str]
weights: NotRequired[WeightsDict | None]
score_threshold: NotRequired[float]
top_k: int
score_threshold_enabled: bool
default_retrieval_model: DefaultRetrievalModelDict = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@ -125,7 +138,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
if metadata_condition and not document_ids_filter:
return ""
# get retrieval model , if the model is not setting , using default
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
retrieval_model = dataset.retrieval_model or default_retrieval_model
retrieval_resource_list: list[RetrievalSourceMetadata] = []
if dataset.indexing_technique == "economy":
# use keyword table query

View File

@ -1,4 +1,5 @@
import re
from collections.abc import Mapping
from json import dumps as json_dumps
from json import loads as json_loads
from json.decoder import JSONDecodeError
@ -20,10 +21,18 @@ class InterfaceDict(TypedDict):
operation: dict[str, Any]
class OpenAPISpecDict(TypedDict):
openapi: str
info: dict[str, str]
servers: list[dict[str, Any]]
paths: dict[str, Any]
components: dict[str, Any]
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
openapi: Mapping[str, Any], extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
@ -277,7 +286,7 @@ class ApiBasedToolSchemaParser:
@staticmethod
def parse_swagger_to_openapi(
swagger: dict, extra_info: dict | None = None, warning: dict | None = None
) -> dict[str, Any]:
) -> OpenAPISpecDict:
warning = warning or {}
"""
parse swagger to openapi
@ -293,7 +302,7 @@ class ApiBasedToolSchemaParser:
if len(servers) == 0:
raise ToolApiSchemaError("No server found in the swagger yaml.")
converted_openapi: dict[str, Any] = {
converted_openapi: OpenAPISpecDict = {
"openapi": "3.0.0",
"info": {
"title": info.get("title", "Swagger"),

View File

@ -3,7 +3,7 @@ from typing import Any
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from dify_graph.enums import NodeType
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.nodes.base.entities import OutputVariableEntity
from dify_graph.variables.input_entities import VariableEntity
@ -51,7 +51,7 @@ class WorkflowToolConfigurationUtils:
def ensure_no_human_input_nodes(cls, graph: Mapping[str, Any]) -> None:
nodes = graph.get("nodes", [])
for node in nodes:
if node.get("data", {}).get("type") == NodeType.HUMAN_INPUT:
if node.get("data", {}).get("type") == BuiltinNodeTypes.HUMAN_INPUT:
raise WorkflowToolHumanInputNotSupportedError()
@classmethod

View File

@ -0,0 +1,18 @@
from typing import Final
TRIGGER_WEBHOOK_NODE_TYPE: Final[str] = "trigger-webhook"
TRIGGER_SCHEDULE_NODE_TYPE: Final[str] = "trigger-schedule"
TRIGGER_PLUGIN_NODE_TYPE: Final[str] = "trigger-plugin"
TRIGGER_INFO_METADATA_KEY: Final[str] = "trigger_info"
TRIGGER_NODE_TYPES: Final[frozenset[str]] = frozenset(
{
TRIGGER_WEBHOOK_NODE_TYPE,
TRIGGER_SCHEDULE_NODE_TYPE,
TRIGGER_PLUGIN_NODE_TYPE,
}
)
def is_trigger_node_type(node_type: str) -> bool:
return node_type in TRIGGER_NODE_TYPES

Some files were not shown because too many files have changed in this diff Show More