mirror of https://github.com/langgenius/dify.git
Merge remote-tracking branch 'origin/main' into review-myscale-sqli
This commit is contained in:
commit
f8ad55212e
|
|
@ -187,53 +187,12 @@ const Template = useMemo(() => {
|
|||
|
||||
**When**: Component directly handles API calls, data transformation, or complex async operations.
|
||||
|
||||
**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: API logic in component
|
||||
const MCPServiceCard = () => {
|
||||
const [basicAppConfig, setBasicAppConfig] = useState({})
|
||||
|
||||
useEffect(() => {
|
||||
if (isBasicApp && appId) {
|
||||
(async () => {
|
||||
const res = await fetchAppDetail({ url: '/apps', id: appId })
|
||||
setBasicAppConfig(res?.model_config || {})
|
||||
})()
|
||||
}
|
||||
}, [appId, isBasicApp])
|
||||
|
||||
// More API-related logic...
|
||||
}
|
||||
|
||||
// ✅ After: Extract to data hook using React Query
|
||||
// use-app-config.ts
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { get } from '@/service/base'
|
||||
|
||||
const NAME_SPACE = 'appConfig'
|
||||
|
||||
export const useAppConfig = (appId: string, isBasicApp: boolean) => {
|
||||
return useQuery({
|
||||
enabled: isBasicApp && !!appId,
|
||||
queryKey: [NAME_SPACE, 'detail', appId],
|
||||
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
|
||||
select: data => data?.model_config || {},
|
||||
})
|
||||
}
|
||||
|
||||
// Component becomes cleaner
|
||||
const MCPServiceCard = () => {
|
||||
const { data: config, isLoading } = useAppConfig(appId, isBasicApp)
|
||||
// UI only
|
||||
}
|
||||
```
|
||||
|
||||
**React Query Best Practices in Dify**:
|
||||
- Define `NAME_SPACE` for query key organization
|
||||
- Use `enabled` option for conditional fetching
|
||||
- Use `select` for data transformation
|
||||
- Export invalidation hooks: `useInvalidXxx`
|
||||
**Dify Convention**:
|
||||
- This skill is for component decomposition, not query/mutation design.
|
||||
- When refactoring data fetching, follow `web/AGENTS.md`.
|
||||
- Use `frontend-query-mutation` for contracts, query shape, data-fetching wrappers, query/mutation call-site patterns, conditional queries, invalidation, and mutation error handling.
|
||||
- Do not introduce deprecated `useInvalid` / `useReset`.
|
||||
- Do not add thin passthrough `useQuery` wrappers during refactoring; only extract a custom hook when it truly orchestrates multiple queries/mutations or shared derived state.
|
||||
|
||||
**Dify Examples**:
|
||||
- `web/service/use-workflow.ts`
|
||||
|
|
|
|||
|
|
@ -155,48 +155,14 @@ const Configuration: FC = () => {
|
|||
|
||||
## Common Hook Patterns in Dify
|
||||
|
||||
### 1. Data Fetching Hook (React Query)
|
||||
### 1. Data Fetching / Mutation Hooks
|
||||
|
||||
```typescript
|
||||
// Pattern: Use @tanstack/react-query for data fetching
|
||||
import { useQuery, useQueryClient } from '@tanstack/react-query'
|
||||
import { get } from '@/service/base'
|
||||
import { useInvalid } from '@/service/use-base'
|
||||
When hook extraction touches query or mutation code, do not use this reference as the source of truth for data-layer patterns.
|
||||
|
||||
const NAME_SPACE = 'appConfig'
|
||||
|
||||
// Query keys for cache management
|
||||
export const appConfigQueryKeys = {
|
||||
detail: (appId: string) => [NAME_SPACE, 'detail', appId] as const,
|
||||
}
|
||||
|
||||
// Main data hook
|
||||
export const useAppConfig = (appId: string) => {
|
||||
return useQuery({
|
||||
enabled: !!appId,
|
||||
queryKey: appConfigQueryKeys.detail(appId),
|
||||
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
|
||||
select: data => data?.model_config || null,
|
||||
})
|
||||
}
|
||||
|
||||
// Invalidation hook for refreshing data
|
||||
export const useInvalidAppConfig = () => {
|
||||
return useInvalid([NAME_SPACE])
|
||||
}
|
||||
|
||||
// Usage in component
|
||||
const Component = () => {
|
||||
const { data: config, isLoading, error, refetch } = useAppConfig(appId)
|
||||
const invalidAppConfig = useInvalidAppConfig()
|
||||
|
||||
const handleRefresh = () => {
|
||||
invalidAppConfig() // Invalidates cache and triggers refetch
|
||||
}
|
||||
|
||||
return <div>...</div>
|
||||
}
|
||||
```
|
||||
- Follow `web/AGENTS.md` first.
|
||||
- Use `frontend-query-mutation` for contracts, query shape, data-fetching wrappers, query/mutation call-site patterns, conditional queries, invalidation, and mutation error handling.
|
||||
- Do not introduce deprecated `useInvalid` / `useReset`.
|
||||
- Do not extract thin passthrough `useQuery` hooks; only extract orchestration hooks.
|
||||
|
||||
### 2. Form State Hook
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
---
|
||||
name: frontend-query-mutation
|
||||
description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions() directly or extract a helper or use-* hook, handling conditional queries, cache invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers.
|
||||
---
|
||||
|
||||
# Frontend Query & Mutation
|
||||
|
||||
## Intent
|
||||
|
||||
- Keep contract as the single source of truth in `web/contract/*`.
|
||||
- Prefer contract-shaped `queryOptions()` and `mutationOptions()`.
|
||||
- Keep invalidation and mutation flow knowledge in the service layer.
|
||||
- Keep abstractions minimal to preserve TypeScript inference.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Identify the change surface.
|
||||
- Read `references/contract-patterns.md` for contract files, router composition, client helpers, and query or mutation call-site shape.
|
||||
- Read `references/runtime-rules.md` for conditional queries, invalidation, error handling, and legacy migrations.
|
||||
- Read both references when a task spans contract shape and runtime behavior.
|
||||
2. Implement the smallest abstraction that fits the task.
|
||||
- Default to direct `useQuery(...)` or `useMutation(...)` calls with oRPC helpers at the call site.
|
||||
- Extract a small shared query helper only when multiple call sites share the same extra options.
|
||||
- Create `web/service/use-{domain}.ts` only for orchestration or shared domain behavior.
|
||||
3. Preserve Dify conventions.
|
||||
- Keep contract inputs in `{ params, query?, body? }` shape.
|
||||
- Bind invalidation in the service-layer mutation definition.
|
||||
- Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required.
|
||||
|
||||
## Files Commonly Touched
|
||||
|
||||
- `web/contract/console/*.ts`
|
||||
- `web/contract/marketplace.ts`
|
||||
- `web/contract/router.ts`
|
||||
- `web/service/client.ts`
|
||||
- `web/service/use-*.ts`
|
||||
- component and hook call sites using `consoleQuery` or `marketplaceQuery`
|
||||
|
||||
## References
|
||||
|
||||
- Use `references/contract-patterns.md` for contract shape, router registration, query and mutation helpers, and anti-patterns that degrade inference.
|
||||
- Use `references/runtime-rules.md` for conditional queries, invalidation, `mutate` versus `mutateAsync`, and legacy migration rules.
|
||||
|
||||
Treat this skill as the single query and mutation entry point for Dify frontend work. Keep detailed rules in the reference files instead of duplicating them in project docs.
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
interface:
|
||||
display_name: "Frontend Query & Mutation"
|
||||
short_description: "Dify TanStack Query and oRPC patterns"
|
||||
default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, conditional queries, invalidation, or legacy query/mutation migrations."
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
# Contract Patterns
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- Intent
|
||||
- Minimal structure
|
||||
- Core workflow
|
||||
- Query usage decision rule
|
||||
- Mutation usage decision rule
|
||||
- Anti-patterns
|
||||
- Contract rules
|
||||
- Type export
|
||||
|
||||
## Intent
|
||||
|
||||
- Keep contract as the single source of truth in `web/contract/*`.
|
||||
- Default query usage to call-site `useQuery(consoleQuery|marketplaceQuery.xxx.queryOptions(...))` when endpoint behavior maps 1:1 to the contract.
|
||||
- Keep abstractions minimal and preserve TypeScript inference.
|
||||
|
||||
## Minimal Structure
|
||||
|
||||
```text
|
||||
web/contract/
|
||||
├── base.ts
|
||||
├── router.ts
|
||||
├── marketplace.ts
|
||||
└── console/
|
||||
├── billing.ts
|
||||
└── ...other domains
|
||||
web/service/client.ts
|
||||
```
|
||||
|
||||
## Core Workflow
|
||||
|
||||
1. Define contract in `web/contract/console/{domain}.ts` or `web/contract/marketplace.ts`.
|
||||
- Use `base.route({...}).output(type<...>())` as the baseline.
|
||||
- Add `.input(type<...>())` only when the request has `params`, `query`, or `body`.
|
||||
- For `GET` without input, omit `.input(...)`; do not use `.input(type<unknown>())`.
|
||||
2. Register contract in `web/contract/router.ts`.
|
||||
- Import directly from domain files and nest by API prefix.
|
||||
3. Consume from UI call sites via oRPC query utilities.
|
||||
|
||||
```typescript
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
|
||||
const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
|
||||
staleTime: 5 * 60 * 1000,
|
||||
throwOnError: true,
|
||||
select: invoice => invoice.url,
|
||||
}))
|
||||
```
|
||||
|
||||
## Query Usage Decision Rule
|
||||
|
||||
1. Default to direct `*.queryOptions(...)` usage at the call site.
|
||||
2. If 3 or more call sites share the same extra options, extract a small query helper, not a `use-*` passthrough hook.
|
||||
3. Create `web/service/use-{domain}.ts` only for orchestration.
|
||||
- Combine multiple queries or mutations.
|
||||
- Share domain-level derived state or invalidation helpers.
|
||||
|
||||
```typescript
|
||||
const invoicesBaseQueryOptions = () =>
|
||||
consoleQuery.billing.invoices.queryOptions({ retry: false })
|
||||
|
||||
const invoiceQuery = useQuery({
|
||||
...invoicesBaseQueryOptions(),
|
||||
throwOnError: true,
|
||||
})
|
||||
```
|
||||
|
||||
## Mutation Usage Decision Rule
|
||||
|
||||
1. Default to mutation helpers from `consoleQuery` or `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
|
||||
2. If the mutation flow is heavily custom, use oRPC clients as `mutationFn`, for example `consoleClient.xxx` or `marketplaceClient.xxx`, instead of handwritten non-oRPC mutation logic.
|
||||
|
||||
## Anti-Patterns
|
||||
|
||||
- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>`.
|
||||
- Do not split local `queryKey` and `queryFn` when oRPC `queryOptions` already exists and fits the use case.
|
||||
- Do not create thin `use-*` passthrough hooks for a single endpoint.
|
||||
- These patterns can degrade inference, especially around `throwOnError` and `select`, and add unnecessary indirection.
|
||||
|
||||
## Contract Rules
|
||||
|
||||
- Input structure: always use `{ params, query?, body? }`.
|
||||
- No-input `GET`: omit `.input(...)`; do not use `.input(type<unknown>())`.
|
||||
- Path params: use `{paramName}` in the path and match it in the `params` object.
|
||||
- Router nesting: group by API prefix, for example `/billing/*` becomes `billing: {}`.
|
||||
- No barrel files: import directly from specific files.
|
||||
- Types: import from `@/types/` and use the `type<T>()` helper.
|
||||
- Mutations: prefer `mutationOptions`; use explicit `mutationKey` mainly for defaults, filtering, and devtools.
|
||||
|
||||
## Type Export
|
||||
|
||||
```typescript
|
||||
export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract>
|
||||
```
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
# Runtime Rules
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- Conditional queries
|
||||
- Cache invalidation
|
||||
- Key API guide
|
||||
- `mutate` vs `mutateAsync`
|
||||
- Legacy migration
|
||||
|
||||
## Conditional Queries
|
||||
|
||||
Prefer contract-shaped `queryOptions(...)`.
|
||||
When required input is missing, prefer `input: skipToken` instead of placeholder params or non-null assertions.
|
||||
Use `enabled` only for extra business gating after the input itself is already valid.
|
||||
|
||||
```typescript
|
||||
import { skipToken, useQuery } from '@tanstack/react-query'
|
||||
|
||||
// Disable the query by skipping input construction.
|
||||
function useAccessMode(appId: string | undefined) {
|
||||
return useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({
|
||||
input: appId
|
||||
? { params: { appId } }
|
||||
: skipToken,
|
||||
}))
|
||||
}
|
||||
|
||||
// Avoid runtime-only guards that bypass type checking.
|
||||
function useBadAccessMode(appId: string | undefined) {
|
||||
return useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({
|
||||
input: { params: { appId: appId! } },
|
||||
enabled: !!appId,
|
||||
}))
|
||||
}
|
||||
```
|
||||
|
||||
## Cache Invalidation
|
||||
|
||||
Bind invalidation in the service-layer mutation definition.
|
||||
Components may add UI feedback in call-site callbacks, but they should not decide which queries to invalidate.
|
||||
|
||||
Use:
|
||||
|
||||
- `.key()` for namespace or prefix invalidation
|
||||
- `.queryKey(...)` only for exact cache reads or writes such as `getQueryData` and `setQueryData`
|
||||
- `queryClient.invalidateQueries(...)` in mutation `onSuccess`
|
||||
|
||||
Do not use deprecated `useInvalid` from `use-base.ts`.
|
||||
|
||||
```typescript
|
||||
// Service layer owns cache invalidation.
|
||||
export const useUpdateAccessMode = () => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
return useMutation(consoleQuery.accessControl.updateAccessMode.mutationOptions({
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: consoleQuery.accessControl.appWhitelistSubjects.key(),
|
||||
})
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// Component only adds UI behavior.
|
||||
updateAccessMode({ appId, mode }, {
|
||||
onSuccess: () => Toast.notify({ type: 'success', message: '...' }),
|
||||
})
|
||||
|
||||
// Avoid putting invalidation knowledge in the component.
|
||||
mutate({ appId, mode }, {
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: consoleQuery.accessControl.appWhitelistSubjects.key(),
|
||||
})
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
## Key API Guide
|
||||
|
||||
- `.key(...)`
|
||||
- Use for partial matching operations.
|
||||
- Prefer it for invalidation, refetch, and cancel patterns.
|
||||
- Example: `queryClient.invalidateQueries({ queryKey: consoleQuery.billing.key() })`
|
||||
- `.queryKey(...)`
|
||||
- Use for a specific query's full key.
|
||||
- Prefer it for exact cache addressing and direct reads or writes.
|
||||
- `.mutationKey(...)`
|
||||
- Use for a specific mutation's full key.
|
||||
- Prefer it for mutation defaults registration, mutation-status filtering, and devtools grouping.
|
||||
|
||||
## `mutate` vs `mutateAsync`
|
||||
|
||||
Prefer `mutate` by default.
|
||||
Use `mutateAsync` only when Promise semantics are truly required, such as parallel mutations or sequential steps with result dependencies.
|
||||
|
||||
Rules:
|
||||
|
||||
- Event handlers should usually call `mutate(...)` with `onSuccess` or `onError`.
|
||||
- Every `await mutateAsync(...)` must be wrapped in `try/catch`.
|
||||
- Do not use `mutateAsync` when callbacks already express the flow clearly.
|
||||
|
||||
```typescript
|
||||
// Default case.
|
||||
mutation.mutate(data, {
|
||||
onSuccess: result => router.push(result.url),
|
||||
})
|
||||
|
||||
// Promise semantics are required.
|
||||
try {
|
||||
const order = await createOrder.mutateAsync(orderData)
|
||||
await confirmPayment.mutateAsync({ orderId: order.id, token })
|
||||
router.push(`/orders/${order.id}`)
|
||||
}
|
||||
catch (error) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: error instanceof Error ? error.message : 'Unknown error',
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
## Legacy Migration
|
||||
|
||||
When touching old code, migrate it toward these rules:
|
||||
|
||||
| Old pattern | New pattern |
|
||||
|---|---|
|
||||
| `useInvalid(key)` in service layer | `queryClient.invalidateQueries(...)` inside mutation `onSuccess` |
|
||||
| component-triggered invalidation after mutation | move invalidation into the service-layer mutation definition |
|
||||
| imperative fetch plus manual invalidation | wrap it in `useMutation(...mutationOptions(...))` |
|
||||
| `await mutateAsync()` without `try/catch` | switch to `mutate(...)` or add `try/catch` |
|
||||
|
|
@ -63,7 +63,8 @@ pnpm analyze-component <path> --review
|
|||
|
||||
### File Naming
|
||||
|
||||
- Test files: `ComponentName.spec.tsx` (same directory as component)
|
||||
- Test files: `ComponentName.spec.tsx` inside a same-level `__tests__/` directory
|
||||
- Placement rule: Component, hook, and utility tests must live in a sibling `__tests__/` folder at the same level as the source under test. For example, `foo/index.tsx` maps to `foo/__tests__/index.spec.tsx`, and `foo/bar.ts` maps to `foo/__tests__/bar.spec.ts`.
|
||||
- Integration tests: `web/__tests__/` directory
|
||||
|
||||
## Test Structure Template
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ import userEvent from '@testing-library/user-event'
|
|||
// Router (if component uses useRouter, usePathname, useSearchParams)
|
||||
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
|
||||
// const mockPush = vi.fn()
|
||||
// vi.mock('next/navigation', () => ({
|
||||
// vi.mock('@/next/navigation', () => ({
|
||||
// useRouter: () => ({ push: mockPush }),
|
||||
// usePathname: () => '/test-path',
|
||||
// }))
|
||||
|
|
|
|||
|
|
@ -1,103 +0,0 @@
|
|||
---
|
||||
name: orpc-contract-first
|
||||
description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Trigger when creating or updating contracts in web/contract, wiring router composition, integrating TanStack Query with typed contracts, migrating legacy service calls to oRPC, or deciding whether to call queryOptions directly vs extracting a helper or use-* hook in web/service.
|
||||
---
|
||||
|
||||
# oRPC Contract-First Development
|
||||
|
||||
## Intent
|
||||
|
||||
- Keep contract as single source of truth in `web/contract/*`.
|
||||
- Default query usage: call-site `useQuery(consoleQuery|marketplaceQuery.xxx.queryOptions(...))` when endpoint behavior maps 1:1 to the contract.
|
||||
- Keep abstractions minimal and preserve TypeScript inference.
|
||||
|
||||
## Minimal Structure
|
||||
|
||||
```text
|
||||
web/contract/
|
||||
├── base.ts
|
||||
├── router.ts
|
||||
├── marketplace.ts
|
||||
└── console/
|
||||
├── billing.ts
|
||||
└── ...other domains
|
||||
web/service/client.ts
|
||||
```
|
||||
|
||||
## Core Workflow
|
||||
|
||||
1. Define contract in `web/contract/console/{domain}.ts` or `web/contract/marketplace.ts`
|
||||
- Use `base.route({...}).output(type<...>())` as baseline.
|
||||
- Add `.input(type<...>())` only when request has `params/query/body`.
|
||||
- For `GET` without input, omit `.input(...)` (do not use `.input(type<unknown>())`).
|
||||
2. Register contract in `web/contract/router.ts`
|
||||
- Import directly from domain files and nest by API prefix.
|
||||
3. Consume from UI call sites via oRPC query utils.
|
||||
|
||||
```typescript
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
|
||||
const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
|
||||
staleTime: 5 * 60 * 1000,
|
||||
throwOnError: true,
|
||||
select: invoice => invoice.url,
|
||||
}))
|
||||
```
|
||||
|
||||
## Query Usage Decision Rule
|
||||
|
||||
1. Default: call site directly uses `*.queryOptions(...)`.
|
||||
2. If 3+ call sites share the same extra options (for example `retry: false`), extract a small queryOptions helper, not a `use-*` passthrough hook.
|
||||
3. Create `web/service/use-{domain}.ts` only for orchestration:
|
||||
- Combine multiple queries/mutations.
|
||||
- Share domain-level derived state or invalidation helpers.
|
||||
|
||||
```typescript
|
||||
const invoicesBaseQueryOptions = () =>
|
||||
consoleQuery.billing.invoices.queryOptions({ retry: false })
|
||||
|
||||
const invoiceQuery = useQuery({
|
||||
...invoicesBaseQueryOptions(),
|
||||
throwOnError: true,
|
||||
})
|
||||
```
|
||||
|
||||
## Mutation Usage Decision Rule
|
||||
|
||||
1. Default: call mutation helpers from `consoleQuery` / `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
|
||||
2. If mutation flow is heavily custom, use oRPC clients as `mutationFn` (for example `consoleClient.xxx` / `marketplaceClient.xxx`), instead of generic handwritten non-oRPC mutation logic.
|
||||
|
||||
## Key API Guide (`.key` vs `.queryKey` vs `.mutationKey`)
|
||||
|
||||
- `.key(...)`:
|
||||
- Use for partial matching operations (recommended for invalidation/refetch/cancel patterns).
|
||||
- Example: `queryClient.invalidateQueries({ queryKey: consoleQuery.billing.key() })`
|
||||
- `.queryKey(...)`:
|
||||
- Use for a specific query's full key (exact query identity / direct cache addressing).
|
||||
- `.mutationKey(...)`:
|
||||
- Use for a specific mutation's full key.
|
||||
- Typical use cases: mutation defaults registration, mutation-status filtering (`useIsMutating`, `queryClient.isMutating`), or explicit devtools grouping.
|
||||
|
||||
## Anti-Patterns
|
||||
|
||||
- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>`.
|
||||
- Do not split local `queryKey/queryFn` when oRPC `queryOptions` already exists and fits the use case.
|
||||
- Do not create thin `use-*` passthrough hooks for a single endpoint.
|
||||
- Reason: these patterns can degrade inference (`data` may become `unknown`, especially around `throwOnError`/`select`) and add unnecessary indirection.
|
||||
|
||||
## Contract Rules
|
||||
|
||||
- **Input structure**: Always use `{ params, query?, body? }` format
|
||||
- **No-input GET**: Omit `.input(...)`; do not use `.input(type<unknown>())`
|
||||
- **Path params**: Use `{paramName}` in path, match in `params` object
|
||||
- **Router nesting**: Group by API prefix (e.g., `/billing/*` -> `billing: {}`)
|
||||
- **No barrel files**: Import directly from specific files
|
||||
- **Types**: Import from `@/types/`, use `type<T>()` helper
|
||||
- **Mutations**: Prefer `mutationOptions`; use explicit `mutationKey` mainly for defaults/filtering/devtools
|
||||
|
||||
## Type Export
|
||||
|
||||
```typescript
|
||||
export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract>
|
||||
```
|
||||
|
|
@ -0,0 +1 @@
|
|||
../../.agents/skills/frontend-query-mutation
|
||||
|
|
@ -1 +0,0 @@
|
|||
../../.agents/skills/orpc-contract-first
|
||||
|
|
@ -27,7 +27,7 @@ jobs:
|
|||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ jobs:
|
|||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
- uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
|
||||
- name: Generate Docker Compose
|
||||
if: steps.docker-compose-changes.outputs.any_changed == 'true'
|
||||
|
|
|
|||
|
|
@ -113,7 +113,7 @@ jobs:
|
|||
context: "web"
|
||||
steps:
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||
with:
|
||||
path: /tmp/digests
|
||||
pattern: digests-${{ matrix.context }}-*
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ jobs:
|
|||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
|
|
@ -69,7 +69,7 @@ jobs:
|
|||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ jobs:
|
|||
migration-changed: ${{ steps.changes.outputs.migration }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
|
||||
- uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
|
||||
id: changes
|
||||
with:
|
||||
filters: |
|
||||
|
|
@ -63,8 +63,9 @@ jobs:
|
|||
if: needs.check-changes.outputs.web-changed == 'true'
|
||||
uses: ./.github/workflows/web-tests.yml
|
||||
with:
|
||||
base_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.base.sha || github.event.before }}
|
||||
head_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
base_sha: ${{ github.event.before || github.event.pull_request.base.sha }}
|
||||
diff_range_mode: ${{ github.event.before && 'exact' || 'merge-base' }}
|
||||
head_sha: ${{ github.event.after || github.event.pull_request.head.sha || github.sha }}
|
||||
|
||||
style-check:
|
||||
name: Style Check
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ jobs:
|
|||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ jobs:
|
|||
|
||||
- name: Setup UV and Python
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
with:
|
||||
enable-cache: false
|
||||
python-version: "3.12"
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ jobs:
|
|||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.detect_changes.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@26ec041249acb0a944c0a47b6c0c13f05dbc5b44 # v1.0.70
|
||||
uses: anthropics/claude-code-action@cd77b50d2b0808657f8e6774085c8bf54484351c # v1.0.72
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ jobs:
|
|||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,9 @@ on:
|
|||
base_sha:
|
||||
required: false
|
||||
type: string
|
||||
diff_range_mode:
|
||||
required: false
|
||||
type: string
|
||||
head_sha:
|
||||
required: false
|
||||
type: string
|
||||
|
|
@ -26,8 +29,8 @@ jobs:
|
|||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
shardIndex: [1, 2, 3, 4]
|
||||
shardTotal: [4]
|
||||
shardIndex: [1, 2, 3, 4, 5, 6]
|
||||
shardTotal: [6]
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
|
@ -77,7 +80,7 @@ jobs:
|
|||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Download blob reports
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||
with:
|
||||
path: web/.vitest-reports
|
||||
pattern: blob-report-*
|
||||
|
|
@ -86,13 +89,24 @@ jobs:
|
|||
- name: Merge reports
|
||||
run: vp test --merge-reports --reporter=json --reporter=agent --coverage
|
||||
|
||||
- name: Check app/components diff coverage
|
||||
- name: Report app/components baseline coverage
|
||||
run: node ./scripts/report-components-coverage-baseline.mjs
|
||||
|
||||
- name: Report app/components test touch
|
||||
env:
|
||||
BASE_SHA: ${{ inputs.base_sha }}
|
||||
DIFF_RANGE_MODE: ${{ inputs.diff_range_mode }}
|
||||
HEAD_SHA: ${{ inputs.head_sha }}
|
||||
run: node ./scripts/report-components-test-touch.mjs
|
||||
|
||||
- name: Check app/components pure diff coverage
|
||||
env:
|
||||
BASE_SHA: ${{ inputs.base_sha }}
|
||||
DIFF_RANGE_MODE: ${{ inputs.diff_range_mode }}
|
||||
HEAD_SHA: ${{ inputs.head_sha }}
|
||||
run: node ./scripts/check-components-diff-coverage.mjs
|
||||
|
||||
- name: Coverage Summary
|
||||
- name: Check Coverage Summary
|
||||
if: always()
|
||||
id: coverage-summary
|
||||
run: |
|
||||
|
|
@ -101,313 +115,15 @@ jobs:
|
|||
COVERAGE_FILE="coverage/coverage-final.json"
|
||||
COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
|
||||
|
||||
if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
|
||||
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
|
||||
echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "Coverage data not found. Ensure Vitest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
|
||||
if [ -f "$COVERAGE_FILE" ] || [ -f "$COVERAGE_SUMMARY_FILE" ]; then
|
||||
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
|
||||
|
||||
node <<'NODE' >> "$GITHUB_STEP_SUMMARY"
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
let libCoverage = null;
|
||||
|
||||
try {
|
||||
libCoverage = require('istanbul-lib-coverage');
|
||||
} catch (error) {
|
||||
libCoverage = null;
|
||||
}
|
||||
|
||||
const summaryPath = path.join('coverage', 'coverage-summary.json');
|
||||
const finalPath = path.join('coverage', 'coverage-final.json');
|
||||
|
||||
const hasSummary = fs.existsSync(summaryPath);
|
||||
const hasFinal = fs.existsSync(finalPath);
|
||||
|
||||
if (!hasSummary && !hasFinal) {
|
||||
console.log('### Test Coverage Summary :test_tube:');
|
||||
console.log('');
|
||||
console.log('No coverage data found.');
|
||||
process.exit(0);
|
||||
}
|
||||
|
||||
const summary = hasSummary
|
||||
? JSON.parse(fs.readFileSync(summaryPath, 'utf8'))
|
||||
: null;
|
||||
const coverage = hasFinal
|
||||
? JSON.parse(fs.readFileSync(finalPath, 'utf8'))
|
||||
: null;
|
||||
|
||||
const getLineCoverageFromStatements = (statementMap, statementHits) => {
|
||||
const lineHits = {};
|
||||
|
||||
if (!statementMap || !statementHits) {
|
||||
return lineHits;
|
||||
}
|
||||
|
||||
Object.entries(statementMap).forEach(([key, statement]) => {
|
||||
const line = statement?.start?.line;
|
||||
if (!line) {
|
||||
return;
|
||||
}
|
||||
const hits = statementHits[key] ?? 0;
|
||||
const previous = lineHits[line];
|
||||
lineHits[line] = previous === undefined ? hits : Math.max(previous, hits);
|
||||
});
|
||||
|
||||
return lineHits;
|
||||
};
|
||||
|
||||
const getFileCoverage = (entry) => (
|
||||
libCoverage ? libCoverage.createFileCoverage(entry) : null
|
||||
);
|
||||
|
||||
const getLineHits = (entry, fileCoverage) => {
|
||||
const lineHits = entry.l ?? {};
|
||||
if (Object.keys(lineHits).length > 0) {
|
||||
return lineHits;
|
||||
}
|
||||
if (fileCoverage) {
|
||||
return fileCoverage.getLineCoverage();
|
||||
}
|
||||
return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {});
|
||||
};
|
||||
|
||||
const getUncoveredLines = (entry, fileCoverage, lineHits) => {
|
||||
if (lineHits && Object.keys(lineHits).length > 0) {
|
||||
return Object.entries(lineHits)
|
||||
.filter(([, count]) => count === 0)
|
||||
.map(([line]) => Number(line))
|
||||
.sort((a, b) => a - b);
|
||||
}
|
||||
if (fileCoverage) {
|
||||
return fileCoverage.getUncoveredLines();
|
||||
}
|
||||
return [];
|
||||
};
|
||||
|
||||
const totals = {
|
||||
lines: { covered: 0, total: 0 },
|
||||
statements: { covered: 0, total: 0 },
|
||||
branches: { covered: 0, total: 0 },
|
||||
functions: { covered: 0, total: 0 },
|
||||
};
|
||||
const fileSummaries = [];
|
||||
|
||||
if (summary) {
|
||||
const totalEntry = summary.total ?? {};
|
||||
['lines', 'statements', 'branches', 'functions'].forEach((key) => {
|
||||
if (totalEntry[key]) {
|
||||
totals[key].covered = totalEntry[key].covered ?? 0;
|
||||
totals[key].total = totalEntry[key].total ?? 0;
|
||||
}
|
||||
});
|
||||
|
||||
Object.entries(summary)
|
||||
.filter(([file]) => file !== 'total')
|
||||
.forEach(([file, data]) => {
|
||||
fileSummaries.push({
|
||||
file,
|
||||
pct: data.lines?.pct ?? data.statements?.pct ?? 0,
|
||||
lines: {
|
||||
covered: data.lines?.covered ?? 0,
|
||||
total: data.lines?.total ?? 0,
|
||||
},
|
||||
});
|
||||
});
|
||||
} else if (coverage) {
|
||||
Object.entries(coverage).forEach(([file, entry]) => {
|
||||
const fileCoverage = getFileCoverage(entry);
|
||||
const lineHits = getLineHits(entry, fileCoverage);
|
||||
const statementHits = entry.s ?? {};
|
||||
const branchHits = entry.b ?? {};
|
||||
const functionHits = entry.f ?? {};
|
||||
|
||||
const lineTotal = Object.keys(lineHits).length;
|
||||
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
|
||||
|
||||
const statementTotal = Object.keys(statementHits).length;
|
||||
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
|
||||
|
||||
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
|
||||
const branchCovered = Object.values(branchHits).reduce(
|
||||
(acc, branches) => acc + branches.filter((n) => n > 0).length,
|
||||
0,
|
||||
);
|
||||
|
||||
const functionTotal = Object.keys(functionHits).length;
|
||||
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
|
||||
|
||||
totals.lines.total += lineTotal;
|
||||
totals.lines.covered += lineCovered;
|
||||
totals.statements.total += statementTotal;
|
||||
totals.statements.covered += statementCovered;
|
||||
totals.branches.total += branchTotal;
|
||||
totals.branches.covered += branchCovered;
|
||||
totals.functions.total += functionTotal;
|
||||
totals.functions.covered += functionCovered;
|
||||
|
||||
const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0);
|
||||
|
||||
fileSummaries.push({
|
||||
file,
|
||||
pct: pct(lineCovered || statementCovered, lineTotal || statementTotal),
|
||||
lines: {
|
||||
covered: lineCovered || statementCovered,
|
||||
total: lineTotal || statementTotal,
|
||||
},
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00');
|
||||
|
||||
console.log('### Test Coverage Summary :test_tube:');
|
||||
console.log('');
|
||||
console.log('| Metric | Coverage | Covered / Total |');
|
||||
console.log('|--------|----------|-----------------|');
|
||||
console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`);
|
||||
console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`);
|
||||
console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`);
|
||||
console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`);
|
||||
|
||||
console.log('');
|
||||
console.log('<details><summary>File coverage (lowest lines first)</summary>');
|
||||
console.log('');
|
||||
console.log('```');
|
||||
fileSummaries
|
||||
.sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total))
|
||||
.slice(0, 25)
|
||||
.forEach(({ file, pct, lines }) => {
|
||||
console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`);
|
||||
});
|
||||
console.log('```');
|
||||
console.log('</details>');
|
||||
|
||||
if (coverage) {
|
||||
const pctValue = (covered, tot) => {
|
||||
if (tot === 0) {
|
||||
return '0';
|
||||
}
|
||||
return ((covered / tot) * 100)
|
||||
.toFixed(2)
|
||||
.replace(/\.?0+$/, '');
|
||||
};
|
||||
|
||||
const formatLineRanges = (lines) => {
|
||||
if (lines.length === 0) {
|
||||
return '';
|
||||
}
|
||||
const ranges = [];
|
||||
let start = lines[0];
|
||||
let end = lines[0];
|
||||
|
||||
for (let i = 1; i < lines.length; i += 1) {
|
||||
const current = lines[i];
|
||||
if (current === end + 1) {
|
||||
end = current;
|
||||
continue;
|
||||
}
|
||||
ranges.push(start === end ? `${start}` : `${start}-${end}`);
|
||||
start = current;
|
||||
end = current;
|
||||
}
|
||||
ranges.push(start === end ? `${start}` : `${start}-${end}`);
|
||||
return ranges.join(',');
|
||||
};
|
||||
|
||||
const tableTotals = {
|
||||
statements: { covered: 0, total: 0 },
|
||||
branches: { covered: 0, total: 0 },
|
||||
functions: { covered: 0, total: 0 },
|
||||
lines: { covered: 0, total: 0 },
|
||||
};
|
||||
const tableRows = Object.entries(coverage)
|
||||
.map(([file, entry]) => {
|
||||
const fileCoverage = getFileCoverage(entry);
|
||||
const lineHits = getLineHits(entry, fileCoverage);
|
||||
const statementHits = entry.s ?? {};
|
||||
const branchHits = entry.b ?? {};
|
||||
const functionHits = entry.f ?? {};
|
||||
|
||||
const lineTotal = Object.keys(lineHits).length;
|
||||
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
|
||||
const statementTotal = Object.keys(statementHits).length;
|
||||
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
|
||||
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
|
||||
const branchCovered = Object.values(branchHits).reduce(
|
||||
(acc, branches) => acc + branches.filter((n) => n > 0).length,
|
||||
0,
|
||||
);
|
||||
const functionTotal = Object.keys(functionHits).length;
|
||||
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
|
||||
|
||||
tableTotals.lines.total += lineTotal;
|
||||
tableTotals.lines.covered += lineCovered;
|
||||
tableTotals.statements.total += statementTotal;
|
||||
tableTotals.statements.covered += statementCovered;
|
||||
tableTotals.branches.total += branchTotal;
|
||||
tableTotals.branches.covered += branchCovered;
|
||||
tableTotals.functions.total += functionTotal;
|
||||
tableTotals.functions.covered += functionCovered;
|
||||
|
||||
const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits);
|
||||
|
||||
const filePath = entry.path ?? file;
|
||||
const relativePath = path.isAbsolute(filePath)
|
||||
? path.relative(process.cwd(), filePath)
|
||||
: filePath;
|
||||
|
||||
return {
|
||||
file: relativePath || file,
|
||||
statements: pctValue(statementCovered, statementTotal),
|
||||
branches: pctValue(branchCovered, branchTotal),
|
||||
functions: pctValue(functionCovered, functionTotal),
|
||||
lines: pctValue(lineCovered, lineTotal),
|
||||
uncovered: formatLineRanges(uncoveredLines),
|
||||
};
|
||||
})
|
||||
.sort((a, b) => a.file.localeCompare(b.file));
|
||||
|
||||
const columns = [
|
||||
{ key: 'file', header: 'File', align: 'left' },
|
||||
{ key: 'statements', header: '% Stmts', align: 'right' },
|
||||
{ key: 'branches', header: '% Branch', align: 'right' },
|
||||
{ key: 'functions', header: '% Funcs', align: 'right' },
|
||||
{ key: 'lines', header: '% Lines', align: 'right' },
|
||||
{ key: 'uncovered', header: 'Uncovered Line #s', align: 'left' },
|
||||
];
|
||||
|
||||
const allFilesRow = {
|
||||
file: 'All files',
|
||||
statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total),
|
||||
branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total),
|
||||
functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total),
|
||||
lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total),
|
||||
uncovered: '',
|
||||
};
|
||||
|
||||
const rowsForOutput = [allFilesRow, ...tableRows];
|
||||
const formatRow = (row) => `| ${columns
|
||||
.map(({ key }) => String(row[key] ?? ''))
|
||||
.join(' | ')} |`;
|
||||
const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`;
|
||||
const dividerRow = `| ${columns
|
||||
.map(({ align }) => (align === 'right' ? '---:' : ':---'))
|
||||
.join(' | ')} |`;
|
||||
|
||||
console.log('');
|
||||
console.log('<details><summary>Vitest coverage table</summary>');
|
||||
console.log('');
|
||||
console.log(headerRow);
|
||||
console.log(dividerRow);
|
||||
rowsForOutput.forEach((row) => console.log(formatRow(row)));
|
||||
console.log('</details>');
|
||||
}
|
||||
NODE
|
||||
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
|
||||
echo "### 🚨 app/components Diff Coverage" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "Coverage artifacts not found. Ensure Vitest merge reports ran with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
|
||||
|
||||
- name: Upload Coverage Artifact
|
||||
if: steps.coverage-summary.outputs.has_coverage == 'true'
|
||||
|
|
|
|||
|
|
@ -237,3 +237,6 @@ scripts/stress-test/reports/
|
|||
# settings
|
||||
*.local.json
|
||||
*.local.md
|
||||
|
||||
# Code Agent Folder
|
||||
.qoder/*
|
||||
|
|
@ -22,10 +22,10 @@ APP_WEB_URL=http://localhost:3000
|
|||
# Files URL
|
||||
FILES_URL=http://localhost:5001
|
||||
|
||||
# INTERNAL_FILES_URL is used for plugin daemon communication within Docker network.
|
||||
# Set this to the internal Docker service URL for proper plugin file access.
|
||||
# Example: INTERNAL_FILES_URL=http://api:5001
|
||||
INTERNAL_FILES_URL=http://127.0.0.1:5001
|
||||
# INTERNAL_FILES_URL is used by services running in Docker to reach the API file endpoints.
|
||||
# For Docker Desktop (Mac/Windows), use http://host.docker.internal:5001 when the API runs on the host.
|
||||
# For Docker Compose on Linux, use http://api:5001 when the API runs inside the Docker network.
|
||||
INTERNAL_FILES_URL=http://host.docker.internal:5001
|
||||
|
||||
# TRIGGER URL
|
||||
TRIGGER_URL=http://localhost:5001
|
||||
|
|
@ -180,7 +180,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
|||
COOKIE_DOMAIN=
|
||||
|
||||
# Vector database configuration
|
||||
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
|
||||
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `hologres`.
|
||||
VECTOR_STORE=weaviate
|
||||
# Prefix used to create collection name in vector database
|
||||
VECTOR_INDEX_NAME_PREFIX=Vector_index
|
||||
|
|
@ -217,6 +217,20 @@ COUCHBASE_PASSWORD=password
|
|||
COUCHBASE_BUCKET_NAME=Embeddings
|
||||
COUCHBASE_SCOPE_NAME=_default
|
||||
|
||||
# Hologres configuration
|
||||
# access_key_id is used as the PG username, access_key_secret is used as the PG password
|
||||
HOLOGRES_HOST=
|
||||
HOLOGRES_PORT=80
|
||||
HOLOGRES_DATABASE=
|
||||
HOLOGRES_ACCESS_KEY_ID=
|
||||
HOLOGRES_ACCESS_KEY_SECRET=
|
||||
HOLOGRES_SCHEMA=public
|
||||
HOLOGRES_TOKENIZER=jieba
|
||||
HOLOGRES_DISTANCE_METHOD=Cosine
|
||||
HOLOGRES_BASE_QUANTIZATION_TYPE=rabitq
|
||||
HOLOGRES_MAX_DEGREE=64
|
||||
HOLOGRES_EF_CONSTRUCTION=400
|
||||
|
||||
# Milvus configuration
|
||||
MILVUS_URI=http://127.0.0.1:19530
|
||||
MILVUS_TOKEN=
|
||||
|
|
@ -723,24 +737,25 @@ SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
|
|||
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000
|
||||
|
||||
|
||||
# Redis URL used for PubSub between API and
|
||||
# Redis URL used for event bus between API and
|
||||
# celery worker
|
||||
# defaults to url constructed from `REDIS_*`
|
||||
# configurations
|
||||
PUBSUB_REDIS_URL=
|
||||
# Pub/sub channel type for streaming events.
|
||||
# valid options are:
|
||||
EVENT_BUS_REDIS_URL=
|
||||
# Event transport type. Options are:
|
||||
#
|
||||
# - pubsub: for normal Pub/Sub
|
||||
# - sharded: for sharded Pub/Sub
|
||||
# - pubsub: normal Pub/Sub (at-most-once)
|
||||
# - sharded: sharded Pub/Sub (at-most-once)
|
||||
# - streams: Redis Streams (at-least-once, recommended to avoid subscriber races)
|
||||
#
|
||||
# It's highly recommended to use sharded Pub/Sub AND redis cluster
|
||||
# for large deployments.
|
||||
PUBSUB_REDIS_CHANNEL_TYPE=pubsub
|
||||
# Whether to use Redis cluster mode while running
|
||||
# PubSub.
|
||||
# Note: Before enabling 'streams' in production, estimate your expected event volume and retention needs.
|
||||
# Configure Redis memory limits and stream trimming appropriately (e.g., MAXLEN and key expiry) to reduce
|
||||
# the risk of data loss from Redis auto-eviction under memory pressure.
|
||||
# Also accepts ENV: EVENT_BUS_REDIS_CHANNEL_TYPE.
|
||||
EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub
|
||||
# Whether to use Redis cluster mode while use redis as event bus.
|
||||
# It's highly recommended to enable this for large deployments.
|
||||
PUBSUB_REDIS_USE_CLUSTERS=false
|
||||
EVENT_BUS_REDIS_USE_CLUSTERS=false
|
||||
|
||||
# Whether to Enable human input timeout check task
|
||||
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true
|
||||
|
|
|
|||
|
|
@ -96,7 +96,6 @@ ignore_imports =
|
|||
dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.tool_engine
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.tool_manager
|
||||
dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> dify_graph.model_runtime.model_providers.__base.large_language_model
|
||||
|
|
@ -104,7 +103,6 @@ ignore_imports =
|
|||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||
dify_graph.nodes.llm.node -> core.helper.code_executor
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
|
||||
dify_graph.nodes.llm.node -> core.model_manager
|
||||
|
|
@ -116,7 +114,6 @@ ignore_imports =
|
|||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
|
||||
dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
|
||||
dify_graph.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
|
||||
dify_graph.nodes.llm.node -> models.dataset
|
||||
dify_graph.nodes.llm.file_saver -> core.tools.signature
|
||||
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ class UserProfile(TypedDict):
|
|||
nickname: NotRequired[str]
|
||||
```
|
||||
|
||||
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
|
||||
- For classes, declare all member variables explicitly with types at the top of the class body (before `__init__`), even when the class is not a dataclass or Pydantic model, so the class shape is obvious at a glance:
|
||||
|
||||
```python
|
||||
from datetime import datetime
|
||||
|
|
|
|||
|
|
@ -97,7 +97,7 @@ ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
|
|||
|
||||
# Download nltk data
|
||||
RUN mkdir -p /usr/local/share/nltk_data \
|
||||
&& NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; from unstructured.nlp.tokenize import download_nltk_packages; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords'); download_nltk_packages()" \
|
||||
&& NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \
|
||||
&& chmod -R 755 /usr/local/share/nltk_data
|
||||
|
||||
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
|
||||
|
|
|
|||
|
|
@ -1,16 +1,45 @@
|
|||
import logging
|
||||
import time
|
||||
|
||||
from flask import request
|
||||
from opentelemetry.trace import get_current_span
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||
|
||||
from configs import dify_config
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from controllers.console.error import UnauthorizedAndForceLogout
|
||||
from core.logging.context import init_request_context
|
||||
from dify_app import DifyApp
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Console bootstrap APIs exempt from license check.
|
||||
# Defined at module level to avoid per-request tuple construction.
|
||||
# - system-features: license status for expiry UI (GlobalPublicStoreProvider)
|
||||
# - setup: install/setup status check (AppInitializer)
|
||||
# - init: init password validation for fresh install (InitPasswordPopup)
|
||||
# - login: auto-login after setup completion (InstallForm)
|
||||
# - features: billing/plan features (ProviderContextProvider)
|
||||
# - account/profile: login check + user profile (AppContextProvider, useIsLogin)
|
||||
# - workspaces/current: workspace + model providers (AppContextProvider)
|
||||
# - version: version check (AppContextProvider)
|
||||
# - activate/check: invitation link validation (signin page)
|
||||
# Without these exemptions, the signin page triggers location.reload()
|
||||
# on unauthorized_and_force_logout, causing an infinite loop.
|
||||
_CONSOLE_EXEMPT_PREFIXES = (
|
||||
"/console/api/system-features",
|
||||
"/console/api/setup",
|
||||
"/console/api/init",
|
||||
"/console/api/login",
|
||||
"/console/api/features",
|
||||
"/console/api/account/profile",
|
||||
"/console/api/workspaces/current",
|
||||
"/console/api/version",
|
||||
"/console/api/activate/check",
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Application Factory Function
|
||||
|
|
@ -31,6 +60,39 @@ def create_flask_app_with_configs() -> DifyApp:
|
|||
init_request_context()
|
||||
RecyclableContextVar.increment_thread_recycles()
|
||||
|
||||
# Enterprise license validation for API endpoints (both console and webapp)
|
||||
# When license expires, block all API access except bootstrap endpoints needed
|
||||
# for the frontend to load the license expiration page without infinite reloads.
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
is_console_api = request.path.startswith("/console/api/")
|
||||
is_webapp_api = request.path.startswith("/api/")
|
||||
|
||||
if is_console_api or is_webapp_api:
|
||||
if is_console_api:
|
||||
is_exempt = any(request.path.startswith(p) for p in _CONSOLE_EXEMPT_PREFIXES)
|
||||
else: # webapp API
|
||||
is_exempt = request.path.startswith("/api/system-features")
|
||||
|
||||
if not is_exempt:
|
||||
try:
|
||||
# Check license status (cached — see EnterpriseService for TTL details)
|
||||
license_status = EnterpriseService.get_cached_license_status()
|
||||
if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST):
|
||||
raise UnauthorizedAndForceLogout(
|
||||
f"Enterprise license is {license_status}. Please contact your administrator."
|
||||
)
|
||||
if license_status is None:
|
||||
raise UnauthorizedAndForceLogout(
|
||||
"Unable to verify enterprise license. Please contact your administrator."
|
||||
)
|
||||
except UnauthorizedAndForceLogout:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to check enterprise license status")
|
||||
raise UnauthorizedAndForceLogout(
|
||||
"Unable to verify enterprise license. Please contact your administrator."
|
||||
)
|
||||
|
||||
# add after request hook for injecting trace headers from OpenTelemetry span context
|
||||
# Only adds headers when OTEL is enabled and has valid context
|
||||
@dify_app.after_request
|
||||
|
|
|
|||
|
|
@ -88,6 +88,8 @@ def clean_workflow_runs(
|
|||
"""
|
||||
Clean workflow runs and related workflow data for free tenants.
|
||||
"""
|
||||
from extensions.otel.runtime import flush_telemetry
|
||||
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before must be provided together.")
|
||||
|
||||
|
|
@ -104,16 +106,27 @@ def clean_workflow_runs(
|
|||
end_before = now - datetime.timedelta(days=to_days_ago)
|
||||
before_days = 0
|
||||
|
||||
if from_days_ago is not None and to_days_ago is not None:
|
||||
task_label = f"{from_days_ago}to{to_days_ago}"
|
||||
elif start_from is None:
|
||||
task_label = f"before-{before_days}"
|
||||
else:
|
||||
task_label = "custom"
|
||||
|
||||
start_time = datetime.datetime.now(datetime.UTC)
|
||||
click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
|
||||
|
||||
WorkflowRunCleanup(
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
dry_run=dry_run,
|
||||
).run()
|
||||
try:
|
||||
WorkflowRunCleanup(
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
).run()
|
||||
finally:
|
||||
flush_telemetry()
|
||||
|
||||
end_time = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = end_time - start_time
|
||||
|
|
@ -659,6 +672,8 @@ def clean_expired_messages(
|
|||
"""
|
||||
Clean expired messages and related data for tenants based on clean policy.
|
||||
"""
|
||||
from extensions.otel.runtime import flush_telemetry
|
||||
|
||||
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
|
||||
|
||||
start_at = time.perf_counter()
|
||||
|
|
@ -698,6 +713,13 @@ def clean_expired_messages(
|
|||
# NOTE: graceful_period will be ignored when billing is disabled.
|
||||
policy = create_message_clean_policy(graceful_period_days=graceful_period)
|
||||
|
||||
if from_days_ago is not None and before_days is not None:
|
||||
task_label = f"{from_days_ago}to{before_days}"
|
||||
elif start_from is None and before_days is not None:
|
||||
task_label = f"before-{before_days}"
|
||||
else:
|
||||
task_label = "custom"
|
||||
|
||||
# Create and run the cleanup service
|
||||
if abs_mode:
|
||||
assert start_from is not None
|
||||
|
|
@ -708,6 +730,7 @@ def clean_expired_messages(
|
|||
end_before=end_before,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
elif from_days_ago is None:
|
||||
assert before_days is not None
|
||||
|
|
@ -716,6 +739,7 @@ def clean_expired_messages(
|
|||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
else:
|
||||
assert before_days is not None
|
||||
|
|
@ -727,6 +751,7 @@ def clean_expired_messages(
|
|||
end_before=now - datetime.timedelta(days=before_days),
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
stats = service.run()
|
||||
|
||||
|
|
@ -752,6 +777,8 @@ def clean_expired_messages(
|
|||
)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
flush_telemetry()
|
||||
|
||||
click.echo(click.style("messages cleanup completed.", fg="green"))
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from core.rag.models.document import ChildDocument, Document
|
|||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import DatasetMetadataType, IndexingStatus, SegmentStatus
|
||||
from models.model import App, AppAnnotationSetting, MessageAnnotation
|
||||
|
||||
|
||||
|
|
@ -160,6 +161,7 @@ def migrate_knowledge_vector_database():
|
|||
}
|
||||
lower_collection_vector_types = {
|
||||
VectorType.ANALYTICDB,
|
||||
VectorType.HOLOGRES,
|
||||
VectorType.CHROMA,
|
||||
VectorType.MYSCALE,
|
||||
VectorType.PGVECTO_RS,
|
||||
|
|
@ -241,7 +243,7 @@ def migrate_knowledge_vector_database():
|
|||
dataset_documents = db.session.scalars(
|
||||
select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id == dataset.id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.indexing_status == IndexingStatus.COMPLETED,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
|
|
@ -253,7 +255,7 @@ def migrate_knowledge_vector_database():
|
|||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.status == SegmentStatus.COMPLETED,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
).all()
|
||||
|
|
@ -429,7 +431,7 @@ def old_metadata_migration():
|
|||
tenant_id=document.tenant_id,
|
||||
dataset_id=document.dataset_id,
|
||||
name=key,
|
||||
type="string",
|
||||
type=DatasetMetadataType.STRING,
|
||||
created_by=document.created_by,
|
||||
)
|
||||
db.session.add(dataset_metadata)
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from .vdb.chroma_config import ChromaConfig
|
|||
from .vdb.clickzetta_config import ClickzettaConfig
|
||||
from .vdb.couchbase_config import CouchbaseConfig
|
||||
from .vdb.elasticsearch_config import ElasticsearchConfig
|
||||
from .vdb.hologres_config import HologresConfig
|
||||
from .vdb.huawei_cloud_config import HuaweiCloudConfig
|
||||
from .vdb.iris_config import IrisVectorConfig
|
||||
from .vdb.lindorm_config import LindormConfig
|
||||
|
|
@ -347,6 +348,7 @@ class MiddlewareConfig(
|
|||
AnalyticdbConfig,
|
||||
ChromaConfig,
|
||||
ClickzettaConfig,
|
||||
HologresConfig,
|
||||
HuaweiCloudConfig,
|
||||
IrisVectorConfig,
|
||||
MilvusConfig,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
|
||||
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
|
|
@ -116,3 +116,13 @@ class RedisConfig(BaseSettings):
|
|||
description="Maximum connections in the Redis connection pool (unset for library default)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@field_validator("REDIS_MAX_CONNECTIONS", mode="before")
|
||||
@classmethod
|
||||
def _empty_string_to_none_for_max_conns(cls, v):
|
||||
"""Allow empty string in env/.env to mean 'unset' (None)."""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str) and v.strip() == "":
|
||||
return None
|
||||
return v
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Literal, Protocol
|
||||
from typing import Literal, Protocol, cast
|
||||
from urllib.parse import quote_plus, urlunparse
|
||||
|
||||
from pydantic import AliasChoices, Field
|
||||
|
|
@ -12,16 +12,13 @@ class RedisConfigDefaults(Protocol):
|
|||
REDIS_PASSWORD: str | None
|
||||
REDIS_DB: int
|
||||
REDIS_USE_SSL: bool
|
||||
REDIS_USE_SENTINEL: bool | None
|
||||
REDIS_USE_CLUSTERS: bool
|
||||
|
||||
|
||||
class RedisConfigDefaultsMixin:
|
||||
def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
|
||||
return self
|
||||
def _redis_defaults(config: object) -> RedisConfigDefaults:
|
||||
return cast(RedisConfigDefaults, config)
|
||||
|
||||
|
||||
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
||||
class RedisPubSubConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for event transport between API and workers.
|
||||
|
||||
|
|
@ -41,10 +38,10 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
|||
)
|
||||
|
||||
PUBSUB_REDIS_USE_CLUSTERS: bool = Field(
|
||||
validation_alias=AliasChoices("EVENT_BUS_REDIS_CLUSTERS", "PUBSUB_REDIS_USE_CLUSTERS"),
|
||||
validation_alias=AliasChoices("EVENT_BUS_REDIS_USE_CLUSTERS", "PUBSUB_REDIS_USE_CLUSTERS"),
|
||||
description=(
|
||||
"Enable Redis Cluster mode for pub/sub or streams transport. Recommended for large deployments. "
|
||||
"Also accepts ENV: EVENT_BUS_REDIS_CLUSTERS."
|
||||
"Also accepts ENV: EVENT_BUS_REDIS_USE_CLUSTERS."
|
||||
),
|
||||
default=False,
|
||||
)
|
||||
|
|
@ -74,7 +71,7 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
|||
)
|
||||
|
||||
def _build_default_pubsub_url(self) -> str:
|
||||
defaults = self._redis_defaults()
|
||||
defaults = _redis_defaults(self)
|
||||
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
||||
raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
|
||||
|
||||
|
|
@ -91,11 +88,9 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
|||
if userinfo:
|
||||
userinfo = f"{userinfo}@"
|
||||
|
||||
host = defaults.REDIS_HOST
|
||||
port = defaults.REDIS_PORT
|
||||
db = defaults.REDIS_DB
|
||||
|
||||
netloc = f"{userinfo}{host}:{port}"
|
||||
netloc = f"{userinfo}{defaults.REDIS_HOST}:{defaults.REDIS_PORT}"
|
||||
return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -0,0 +1,68 @@
|
|||
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class HologresConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for Hologres vector database.
|
||||
|
||||
Hologres is compatible with PostgreSQL protocol.
|
||||
access_key_id is used as the PostgreSQL username,
|
||||
and access_key_secret is used as the PostgreSQL password.
|
||||
"""
|
||||
|
||||
HOLOGRES_HOST: str | None = Field(
|
||||
description="Hostname or IP address of the Hologres instance.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOLOGRES_PORT: int = Field(
|
||||
description="Port number for connecting to the Hologres instance.",
|
||||
default=80,
|
||||
)
|
||||
|
||||
HOLOGRES_DATABASE: str | None = Field(
|
||||
description="Name of the Hologres database to connect to.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOLOGRES_ACCESS_KEY_ID: str | None = Field(
|
||||
description="Alibaba Cloud AccessKey ID, also used as the PostgreSQL username.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOLOGRES_ACCESS_KEY_SECRET: str | None = Field(
|
||||
description="Alibaba Cloud AccessKey Secret, also used as the PostgreSQL password.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOLOGRES_SCHEMA: str = Field(
|
||||
description="Schema name in the Hologres database.",
|
||||
default="public",
|
||||
)
|
||||
|
||||
HOLOGRES_TOKENIZER: TokenizerType = Field(
|
||||
description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
|
||||
default="jieba",
|
||||
)
|
||||
|
||||
HOLOGRES_DISTANCE_METHOD: DistanceType = Field(
|
||||
description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
|
||||
default="Cosine",
|
||||
)
|
||||
|
||||
HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field(
|
||||
description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').",
|
||||
default="rabitq",
|
||||
)
|
||||
|
||||
HOLOGRES_MAX_DEGREE: int = Field(
|
||||
description="Max degree (M) parameter for HNSW vector index.",
|
||||
default=64,
|
||||
)
|
||||
|
||||
HOLOGRES_EF_CONSTRUCTION: int = Field(
|
||||
description="ef_construction parameter for HNSW vector index.",
|
||||
default=400,
|
||||
)
|
||||
|
|
@ -25,7 +25,8 @@ from controllers.console.wraps import (
|
|||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from dify_graph.enums import NodeType, WorkflowExecutionStatus
|
||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
|
|
@ -508,11 +509,7 @@ class AppListApi(Resource):
|
|||
.scalars()
|
||||
.all()
|
||||
)
|
||||
trigger_node_types = {
|
||||
NodeType.TRIGGER_WEBHOOK,
|
||||
NodeType.TRIGGER_SCHEDULE,
|
||||
NodeType.TRIGGER_PLUGIN,
|
||||
}
|
||||
trigger_node_types = TRIGGER_NODE_TYPES
|
||||
for workflow in draft_workflows:
|
||||
node_id = None
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
|
|||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.plugin.impl.exc import PluginInvokeError
|
||||
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
|
||||
from core.trigger.debug.event_selectors import (
|
||||
TriggerDebugEvent,
|
||||
TriggerDebugEventPoller,
|
||||
|
|
@ -1209,7 +1210,7 @@ class DraftWorkflowTriggerNodeApi(Resource):
|
|||
node_type: NodeType = draft_workflow.get_node_type_from_node_config(node_config)
|
||||
event: TriggerDebugEvent | None = None
|
||||
# for schedule trigger, when run single node, just execute directly
|
||||
if node_type == NodeType.TRIGGER_SCHEDULE:
|
||||
if node_type == TRIGGER_SCHEDULE_NODE_TYPE:
|
||||
event = TriggerDebugEvent(
|
||||
workflow_args={},
|
||||
node_id=node_id,
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from dify_graph.variables.types import SegmentType
|
|||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
|
|
@ -100,6 +100,18 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
|
|||
}
|
||||
|
||||
|
||||
def _ensure_variable_access(
|
||||
variable: WorkflowDraftVariable | None,
|
||||
app_id: str,
|
||||
variable_id: str,
|
||||
) -> WorkflowDraftVariable:
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_id or variable.user_id != current_user.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
return variable
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||
"id": fields.String,
|
||||
"type": fields.String(attribute=lambda model: model.get_variable_type()),
|
||||
|
|
@ -238,6 +250,7 @@ class WorkflowVariableCollectionApi(Resource):
|
|||
app_id=app_model.id,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return workflow_vars
|
||||
|
|
@ -250,7 +263,7 @@ class WorkflowVariableCollectionApi(Resource):
|
|||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
draft_var_srv.delete_workflow_variables(app_model.id)
|
||||
draft_var_srv.delete_user_workflow_variables(app_model.id, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
|
@ -287,7 +300,7 @@ class NodeVariableCollectionApi(Resource):
|
|||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)
|
||||
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id, user_id=current_user.id)
|
||||
|
||||
return node_vars
|
||||
|
||||
|
|
@ -298,7 +311,7 @@ class NodeVariableCollectionApi(Resource):
|
|||
def delete(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(app_model.id, node_id)
|
||||
srv.delete_node_variables(app_model.id, node_id, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
|
@ -319,11 +332,11 @@ class VariableApi(Resource):
|
|||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
return variable
|
||||
|
||||
@console_ns.doc("update_variable")
|
||||
|
|
@ -360,11 +373,11 @@ class VariableApi(Resource):
|
|||
)
|
||||
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
|
||||
new_name = args_model.name
|
||||
raw_value = args_model.value
|
||||
|
|
@ -397,11 +410,11 @@ class VariableApi(Resource):
|
|||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
|
@ -427,11 +440,11 @@ class VariableResetApi(Resource):
|
|||
raise NotFoundError(
|
||||
f"Draft workflow not found, app_id={app_model.id}",
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
db.session.commit()
|
||||
|
|
@ -447,11 +460,15 @@ def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
|||
session=session,
|
||||
)
|
||||
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
|
||||
draft_vars = draft_var_srv.list_conversation_variables(app_model.id, user_id=current_user.id)
|
||||
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_system_variables(app_model.id)
|
||||
draft_vars = draft_var_srv.list_system_variables(app_model.id, user_id=current_user.id)
|
||||
else:
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
|
||||
draft_vars = draft_var_srv.list_node_variables(
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
return draft_vars
|
||||
|
||||
|
||||
|
|
@ -472,7 +489,7 @@ class ConversationVariableCollectionApi(Resource):
|
|||
if draft_workflow is None:
|
||||
raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ from fields.document_fields import document_status_fields
|
|||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermission, DatasetPermissionEnum
|
||||
from models.enums import SegmentStatus
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.api_token_service import ApiTokenCache
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
|
@ -263,6 +264,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
|
|||
VectorType.BAIDU,
|
||||
VectorType.ALIBABACLOUD_MYSQL,
|
||||
VectorType.IRIS,
|
||||
VectorType.HOLOGRES,
|
||||
}
|
||||
|
||||
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
|
|
@ -740,13 +742,15 @@ class DatasetIndexingStatusApi(Resource):
|
|||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
|
||||
)
|
||||
.count()
|
||||
)
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ from libs.datetime_utils import naive_utc_now
|
|||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from models.enums import IndexingStatus, SegmentStatus
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
|
||||
from services.file_service import FileService
|
||||
|
|
@ -332,13 +333,16 @@ class DatasetDocumentListApi(Resource):
|
|||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
document.completed_segments = completed_segments
|
||||
|
|
@ -503,7 +507,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
|||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
|
||||
data_process_rule = document.dataset_process_rule
|
||||
|
|
@ -573,7 +577,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
|
||||
extract_settings = []
|
||||
for document in documents:
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
data_source_info = document.data_source_info_dict
|
||||
match document.data_source_type:
|
||||
|
|
@ -671,19 +675,21 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
|||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
|
||||
)
|
||||
.count()
|
||||
)
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
document_dict = {
|
||||
"id": document.id,
|
||||
"indexing_status": "paused" if document.is_paused else document.indexing_status,
|
||||
"indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status,
|
||||
"processing_started_at": document.processing_started_at,
|
||||
"parsing_completed_at": document.parsing_completed_at,
|
||||
"cleaning_completed_at": document.cleaning_completed_at,
|
||||
|
|
@ -720,20 +726,20 @@ class DocumentIndexingStatusApi(DocumentResource):
|
|||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
|
||||
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT)
|
||||
.count()
|
||||
)
|
||||
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
document_dict = {
|
||||
"id": document.id,
|
||||
"indexing_status": "paused" if document.is_paused else document.indexing_status,
|
||||
"indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status,
|
||||
"processing_started_at": document.processing_started_at,
|
||||
"parsing_completed_at": document.parsing_completed_at,
|
||||
"cleaning_completed_at": document.cleaning_completed_at,
|
||||
|
|
@ -955,7 +961,7 @@ class DocumentProcessingApi(DocumentResource):
|
|||
|
||||
match action:
|
||||
case "pause":
|
||||
if document.indexing_status != "indexing":
|
||||
if document.indexing_status != IndexingStatus.INDEXING:
|
||||
raise InvalidActionError("Document not in indexing state.")
|
||||
|
||||
document.paused_by = current_user.id
|
||||
|
|
@ -964,7 +970,7 @@ class DocumentProcessingApi(DocumentResource):
|
|||
db.session.commit()
|
||||
|
||||
case "resume":
|
||||
if document.indexing_status not in {"paused", "error"}:
|
||||
if document.indexing_status not in {IndexingStatus.PAUSED, IndexingStatus.ERROR}:
|
||||
raise InvalidActionError("Document not in paused or error state.")
|
||||
|
||||
document.paused_by = None
|
||||
|
|
@ -1169,7 +1175,7 @@ class DocumentRetryApi(DocumentResource):
|
|||
raise ArchivedDocumentImmutableError()
|
||||
|
||||
# 400 if document is completed
|
||||
if document.indexing_status == "completed":
|
||||
if document.indexing_status == IndexingStatus.COMPLETED:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
retry_documents.append(document)
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -46,6 +46,8 @@ class PipelineTemplateDetailApi(Resource):
|
|||
type = request.args.get("type", default="built-in", type=str)
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
|
||||
if pipeline_template is None:
|
||||
return {"error": "Pipeline template not found from upstream service."}, 404
|
||||
return pipeline_template, 200
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -102,6 +102,7 @@ class RagPipelineVariableCollectionApi(Resource):
|
|||
app_id=pipeline.id,
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return workflow_vars
|
||||
|
|
@ -111,7 +112,7 @@ class RagPipelineVariableCollectionApi(Resource):
|
|||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
draft_var_srv.delete_workflow_variables(pipeline.id)
|
||||
draft_var_srv.delete_user_workflow_variables(pipeline.id, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
|
@ -144,7 +145,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
|
|||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
node_vars = draft_var_srv.list_node_variables(pipeline.id, node_id)
|
||||
node_vars = draft_var_srv.list_node_variables(pipeline.id, node_id, user_id=current_user.id)
|
||||
|
||||
return node_vars
|
||||
|
||||
|
|
@ -152,7 +153,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
|
|||
def delete(self, pipeline: Pipeline, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(pipeline.id, node_id)
|
||||
srv.delete_node_variables(pipeline.id, node_id, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
|
@ -283,11 +284,11 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList
|
|||
session=session,
|
||||
)
|
||||
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_conversation_variables(pipeline.id)
|
||||
draft_vars = draft_var_srv.list_conversation_variables(pipeline.id, user_id=current_user.id)
|
||||
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_system_variables(pipeline.id)
|
||||
draft_vars = draft_var_srv.list_system_variables(pipeline.id, user_id=current_user.id)
|
||||
else:
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id)
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id, user_id=current_user.id)
|
||||
return draft_vars
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ from extensions.ext_database import db
|
|||
from fields.document_fields import document_fields, document_status_fields
|
||||
from libs.login import current_user
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import SegmentStatus
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
KnowledgeConfig,
|
||||
|
|
@ -622,13 +623,15 @@ class DocumentIndexingStatusApi(DatasetApiResource):
|
|||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
|
||||
)
|
||||
.count()
|
||||
)
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import time
|
|||
from collections.abc import Callable
|
||||
from enum import StrEnum, auto
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar, cast
|
||||
from typing import Concatenate, ParamSpec, TypeVar, cast, overload
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
|
|
@ -44,10 +44,22 @@ class FetchUserArg(BaseModel):
|
|||
required: bool = False
|
||||
|
||||
|
||||
def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@overload
|
||||
def validate_app_token(view: Callable[P, R]) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def validate_app_token(
|
||||
view: None = None, *, fetch_user_arg: FetchUserArg | None = None
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
||||
|
||||
|
||||
def validate_app_token(
|
||||
view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
api_token = validate_and_get_api_token("app")
|
||||
|
||||
app_model = db.session.query(App).where(App.id == api_token.app_id).first()
|
||||
|
|
@ -213,10 +225,20 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
|
|||
return interceptor
|
||||
|
||||
|
||||
def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[T, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
@overload
|
||||
def validate_dataset_token(view: Callable[Concatenate[T, P], R]) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def validate_dataset_token(view: None = None) -> Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]: ...
|
||||
|
||||
|
||||
def validate_dataset_token(
|
||||
view: Callable[Concatenate[T, P], R] | None = None,
|
||||
) -> Callable[P, R] | Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[Concatenate[T, P], R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
api_token = validate_and_get_api_token("dataset")
|
||||
|
||||
# get url path dataset_id from positional args or kwargs
|
||||
|
|
@ -287,7 +309,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
|
|||
raise Unauthorized("Tenant owner account does not exist.")
|
||||
else:
|
||||
raise Unauthorized("Tenant does not exist.")
|
||||
return view(api_token.tenant_id, *args, **kwargs)
|
||||
return view_func(api_token.tenant_id, *args, **kwargs) # type: ignore[arg-type]
|
||||
|
||||
return decorated
|
||||
|
||||
|
|
|
|||
|
|
@ -70,7 +70,14 @@ def handle_webhook(webhook_id: str):
|
|||
|
||||
@bp.route("/webhook-debug/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||
def handle_webhook_debug(webhook_id: str):
|
||||
"""Handle webhook debug calls without triggering production workflow execution."""
|
||||
"""Handle webhook debug calls without triggering production workflow execution.
|
||||
|
||||
The debug webhook endpoint is only for draft inspection flows. It never enqueues
|
||||
Celery work for the published workflow; instead it dispatches an in-memory debug
|
||||
event to an active Variable Inspector listener. Returning a clear error when no
|
||||
listener is registered prevents a misleading 200 response for requests that are
|
||||
effectively dropped.
|
||||
"""
|
||||
try:
|
||||
webhook_trigger, _, node_config, webhook_data, error = _prepare_webhook_execution(webhook_id, is_debug=True)
|
||||
if error:
|
||||
|
|
@ -94,11 +101,32 @@ def handle_webhook_debug(webhook_id: str):
|
|||
"method": webhook_data.get("method"),
|
||||
},
|
||||
)
|
||||
TriggerDebugEventBus.dispatch(
|
||||
dispatch_count = TriggerDebugEventBus.dispatch(
|
||||
tenant_id=webhook_trigger.tenant_id,
|
||||
event=event,
|
||||
pool_key=pool_key,
|
||||
)
|
||||
if dispatch_count == 0:
|
||||
logger.warning(
|
||||
"Webhook debug request dropped without an active listener for webhook %s (tenant=%s, app=%s, node=%s)",
|
||||
webhook_trigger.webhook_id,
|
||||
webhook_trigger.tenant_id,
|
||||
webhook_trigger.app_id,
|
||||
webhook_trigger.node_id,
|
||||
)
|
||||
return (
|
||||
jsonify(
|
||||
{
|
||||
"error": "No active debug listener",
|
||||
"message": (
|
||||
"The webhook debug URL only works while the Variable Inspector is listening. "
|
||||
"Use the published webhook URL to execute the workflow in Celery."
|
||||
),
|
||||
"execution_url": webhook_trigger.webhook_url,
|
||||
}
|
||||
),
|
||||
409,
|
||||
)
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
return jsonify(response_data), status_code
|
||||
|
||||
|
|
|
|||
|
|
@ -441,7 +441,7 @@ class BaseAgentRunner(AppRunner):
|
|||
continue
|
||||
|
||||
result.append(self.organize_agent_user_prompt(message))
|
||||
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
|
||||
agent_thoughts = message.agent_thoughts
|
||||
if agent_thoughts:
|
||||
for agent_thought in agent_thoughts:
|
||||
tool_names_raw = agent_thought.tool
|
||||
|
|
|
|||
|
|
@ -1,13 +1,36 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
|
||||
|
||||
class SystemParametersDict(TypedDict):
|
||||
image_file_size_limit: int
|
||||
video_file_size_limit: int
|
||||
audio_file_size_limit: int
|
||||
file_size_limit: int
|
||||
workflow_file_upload_limit: int
|
||||
|
||||
|
||||
class AppParametersDict(TypedDict):
|
||||
opening_statement: str | None
|
||||
suggested_questions: list[str]
|
||||
suggested_questions_after_answer: dict[str, Any]
|
||||
speech_to_text: dict[str, Any]
|
||||
text_to_speech: dict[str, Any]
|
||||
retriever_resource: dict[str, Any]
|
||||
annotation_reply: dict[str, Any]
|
||||
more_like_this: dict[str, Any]
|
||||
user_input_form: list[dict[str, Any]]
|
||||
sensitive_word_avoidance: dict[str, Any]
|
||||
file_upload: dict[str, Any]
|
||||
system_parameters: SystemParametersDict
|
||||
|
||||
|
||||
def get_parameters_from_feature_dict(
|
||||
*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]
|
||||
) -> Mapping[str, Any]:
|
||||
) -> AppParametersDict:
|
||||
"""
|
||||
Mapping from feature dict to webapp parameters
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from core.app.app_config.entities import (
|
|||
ModelConfig,
|
||||
)
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||
from models.model import AppMode, AppModelConfigDict
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
|
|
@ -117,8 +118,10 @@ class DatasetConfigManager:
|
|||
score_threshold=float(score_threshold_val)
|
||||
if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None
|
||||
else None,
|
||||
reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None,
|
||||
weights=weights_val if isinstance(weights_val, dict) else None,
|
||||
reranking_model=cast(RerankingModelDict, reranking_model_val)
|
||||
if isinstance(reranking_model_val, dict)
|
||||
else None,
|
||||
weights=cast(WeightsDict, weights_val) if isinstance(weights_val, dict) else None,
|
||||
reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)),
|
||||
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
|
||||
metadata_filtering_mode=cast(
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from typing import Any, Literal
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||
from dify_graph.file import FileUploadConfig
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMMode
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
|
||||
|
|
@ -194,8 +195,8 @@ class DatasetRetrieveConfigEntity(BaseModel):
|
|||
top_k: int | None = None
|
||||
score_threshold: float | None = 0.0
|
||||
rerank_mode: str | None = "reranking_model"
|
||||
reranking_model: dict | None = None
|
||||
weights: dict | None = None
|
||||
reranking_model: RerankingModelDict | None = None
|
||||
weights: WeightsDict | None = None
|
||||
reranking_enabled: bool | None = True
|
||||
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
|
||||
metadata_model_config: ModelConfig | None = None
|
||||
|
|
|
|||
|
|
@ -330,9 +330,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
engine=db.engine,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
user_id=user.id,
|
||||
)
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
|
||||
|
||||
return self._generate(
|
||||
workflow=workflow,
|
||||
|
|
@ -413,9 +414,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
engine=db.engine,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
user_id=user.id,
|
||||
)
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
|
||||
|
||||
return self._generate(
|
||||
workflow=workflow,
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ from dify_graph.entities.pause_reason import HumanInputRequired
|
|||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.nodes import NodeType
|
||||
from dify_graph.nodes import BuiltinNodeTypes
|
||||
from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
|
|
@ -357,7 +357,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle node succeeded events."""
|
||||
# Record files if it's an answer node or end node
|
||||
if event.node_type in [NodeType.ANSWER, NodeType.END, NodeType.LLM]:
|
||||
if event.node_type in [BuiltinNodeTypes.ANSWER, BuiltinNodeTypes.END, BuiltinNodeTypes.LLM]:
|
||||
self._recorded_files.extend(
|
||||
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import time
|
|||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, NewType, Union
|
||||
from typing import Any, NewType, TypedDict, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -48,12 +48,13 @@ from core.app.entities.task_entities import (
|
|||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
BuiltinNodeTypes,
|
||||
SystemVariableKey,
|
||||
WorkflowExecutionStatus,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
|
|
@ -75,6 +76,20 @@ NodeExecutionId = NewType("NodeExecutionId", str)
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AccountCreatedByDict(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
|
||||
|
||||
class EndUserCreatedByDict(TypedDict):
|
||||
id: str
|
||||
user: str
|
||||
|
||||
|
||||
CreatedByDict = AccountCreatedByDict | EndUserCreatedByDict
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _NodeSnapshot:
|
||||
"""In-memory cache for node metadata between start and completion events."""
|
||||
|
|
@ -248,19 +263,19 @@ class WorkflowResponseConverter:
|
|||
outputs_mapping = graph_runtime_state.outputs or {}
|
||||
encoded_outputs = WorkflowRuntimeTypeConverter().to_json_encodable(outputs_mapping)
|
||||
|
||||
created_by: Mapping[str, object] | None
|
||||
created_by: CreatedByDict | dict[str, object] = {}
|
||||
user = self._user
|
||||
if isinstance(user, Account):
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
}
|
||||
else:
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"user": user.session_id,
|
||||
}
|
||||
created_by = AccountCreatedByDict(
|
||||
id=user.id,
|
||||
name=user.name,
|
||||
email=user.email,
|
||||
)
|
||||
elif isinstance(user, EndUser):
|
||||
created_by = EndUserCreatedByDict(
|
||||
id=user.id,
|
||||
user=user.session_id,
|
||||
)
|
||||
|
||||
return WorkflowFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
|
|
@ -442,7 +457,7 @@ class WorkflowResponseConverter:
|
|||
event: QueueNodeStartedEvent,
|
||||
task_id: str,
|
||||
) -> NodeStartStreamResponse | None:
|
||||
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}:
|
||||
return None
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
snapshot = self._store_snapshot(event)
|
||||
|
|
@ -464,13 +479,13 @@ class WorkflowResponseConverter:
|
|||
)
|
||||
|
||||
try:
|
||||
if event.node_type == NodeType.TOOL:
|
||||
if event.node_type == BuiltinNodeTypes.TOOL:
|
||||
response.data.extras["icon"] = ToolManager.get_tool_icon(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
provider_type=ToolProviderType(event.provider_type),
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
elif event.node_type == NodeType.DATASOURCE:
|
||||
elif event.node_type == BuiltinNodeTypes.DATASOURCE:
|
||||
manager = PluginDatasourceManager()
|
||||
provider_entity = manager.fetch_datasource_provider(
|
||||
self._application_generate_entity.app_config.tenant_id,
|
||||
|
|
@ -479,7 +494,7 @@ class WorkflowResponseConverter:
|
|||
response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
|
||||
self._application_generate_entity.app_config.tenant_id
|
||||
)
|
||||
elif event.node_type == NodeType.TRIGGER_PLUGIN:
|
||||
elif event.node_type == TRIGGER_PLUGIN_NODE_TYPE:
|
||||
response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon(
|
||||
self._application_generate_entity.app_config.tenant_id,
|
||||
event.provider_id,
|
||||
|
|
@ -496,7 +511,7 @@ class WorkflowResponseConverter:
|
|||
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
|
||||
task_id: str,
|
||||
) -> NodeFinishStreamResponse | None:
|
||||
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}:
|
||||
return None
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
snapshot = self._pop_snapshot(event.node_execution_id)
|
||||
|
|
@ -554,7 +569,7 @@ class WorkflowResponseConverter:
|
|||
event: QueueNodeRetryEvent,
|
||||
task_id: str,
|
||||
) -> NodeRetryStreamResponse | None:
|
||||
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}:
|
||||
return None
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
|
||||
|
|
@ -612,7 +627,7 @@ class WorkflowResponseConverter:
|
|||
data=IterationNodeStartStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
node_type=event.node_type,
|
||||
title=event.node_title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
|
|
@ -635,7 +650,7 @@ class WorkflowResponseConverter:
|
|||
data=IterationNodeNextStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
node_type=event.node_type,
|
||||
title=event.node_title,
|
||||
index=event.index,
|
||||
created_at=int(time.time()),
|
||||
|
|
@ -662,7 +677,7 @@ class WorkflowResponseConverter:
|
|||
data=IterationNodeCompletedStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
node_type=event.node_type,
|
||||
title=event.node_title,
|
||||
outputs=new_outputs,
|
||||
outputs_truncated=outputs_truncated,
|
||||
|
|
@ -692,7 +707,7 @@ class WorkflowResponseConverter:
|
|||
data=LoopNodeStartStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
node_type=event.node_type,
|
||||
title=event.node_title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
|
|
@ -715,7 +730,7 @@ class WorkflowResponseConverter:
|
|||
data=LoopNodeNextStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
node_type=event.node_type,
|
||||
title=event.node_title,
|
||||
index=event.index,
|
||||
# The `pre_loop_output` field is not utilized by the frontend.
|
||||
|
|
@ -744,7 +759,7 @@ class WorkflowResponseConverter:
|
|||
data=LoopNodeCompletedStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
node_type=event.node_type,
|
||||
title=event.node_title,
|
||||
outputs=new_outputs,
|
||||
outputs_truncated=outputs_truncated,
|
||||
|
|
|
|||
|
|
@ -419,11 +419,12 @@ class PipelineGenerator(BaseAppGenerator):
|
|||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
user_id=user.id,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
|
|
@ -514,11 +515,12 @@ class PipelineGenerator(BaseAppGenerator):
|
|||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
user_id=user.id,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from core.app.entities.app_invoke_entities import (
|
|||
build_dify_run_context,
|
||||
)
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities.graph_init_params import GraphInitParams
|
||||
from dify_graph.enums import WorkflowType
|
||||
|
|
@ -274,6 +274,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
|||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
if start_node_id is None:
|
||||
start_node_id = get_default_root_node_id(graph_config)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id)
|
||||
|
||||
if not graph:
|
||||
|
|
|
|||
|
|
@ -414,11 +414,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
user_id=user.id,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
|
|
@ -497,11 +498,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
user_id=user.id,
|
||||
)
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
|
|
|
|||
|
|
@ -32,8 +32,8 @@ from core.app.entities.queue_entities import (
|
|||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.node_resolution import resolve_workflow_node_class
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
|
|
@ -140,6 +140,9 @@ class WorkflowBasedAppRunner:
|
|||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
if root_node_id is None:
|
||||
root_node_id = get_default_root_node_id(graph_config)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id)
|
||||
|
||||
|
|
@ -505,7 +508,9 @@ class WorkflowBasedAppRunner:
|
|||
elif isinstance(event, NodeRunRetrieverResourceEvent):
|
||||
self._publish_event(
|
||||
QueueRetrieverResourcesEvent(
|
||||
retriever_resources=event.retriever_resources,
|
||||
retriever_resources=[
|
||||
RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources
|
||||
],
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,9 +9,8 @@ from core.app.entities.agent_strategy import AgentStrategyInfo
|
|||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from dify_graph.entities.pause_reason import PauseReason
|
||||
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from dify_graph.nodes import NodeType
|
||||
|
||||
|
||||
class QueueEvent(StrEnum):
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.enums import CollectionBindingType
|
||||
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
|
@ -43,7 +44,7 @@ class AnnotationReplyFeature:
|
|||
embedding_model_name = collection_binding_detail.model_name
|
||||
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_provider_name, embedding_model_name, "annotation"
|
||||
embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION
|
||||
)
|
||||
|
||||
dataset = Dataset(
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import logging
|
|||
|
||||
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from dify_graph.conversation_variable_updater import ConversationVariableUpdater
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.enums import BuiltinNodeTypes
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_events import GraphEngineEvent, NodeRunSucceededEvent
|
||||
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
|
||||
|
|
@ -22,7 +22,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
|
|||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
if not isinstance(event, NodeRunSucceededEvent):
|
||||
return
|
||||
if event.node_type != NodeType.VARIABLE_ASSIGNER:
|
||||
if event.node_type != BuiltinNodeTypes.VARIABLE_ASSIGNER:
|
||||
return
|
||||
if self.graph_runtime_state is None:
|
||||
return
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing import TypedDict
|
||||
|
||||
from core.tools.signature import sign_tool_file
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from dify_graph.file.enums import FileTransferMethod
|
||||
|
|
@ -6,7 +8,20 @@ from models.model import MessageFile, UploadFile
|
|||
MAX_TOOL_FILE_EXTENSION_LENGTH = 10
|
||||
|
||||
|
||||
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict:
|
||||
class MessageFileInfoDict(TypedDict):
|
||||
related_id: str
|
||||
extension: str
|
||||
filename: str
|
||||
size: int
|
||||
mime_type: str
|
||||
transfer_method: str
|
||||
type: str
|
||||
url: str
|
||||
upload_file_id: str
|
||||
remote_url: str | None
|
||||
|
||||
|
||||
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> MessageFileInfoDict:
|
||||
"""
|
||||
Prepare file dictionary for message end stream response.
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from typing_extensions import override
|
|||
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_manager import ModelInstance
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.enums import BuiltinNodeTypes
|
||||
from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase
|
||||
|
|
@ -113,11 +113,11 @@ class LLMQuotaLayer(GraphEngineLayer):
|
|||
def _extract_model_instance(node: Node) -> ModelInstance | None:
|
||||
try:
|
||||
match node.node_type:
|
||||
case NodeType.LLM:
|
||||
case BuiltinNodeTypes.LLM:
|
||||
return cast("LLMNode", node).model_instance
|
||||
case NodeType.PARAMETER_EXTRACTOR:
|
||||
case BuiltinNodeTypes.PARAMETER_EXTRACTOR:
|
||||
return cast("ParameterExtractorNode", node).model_instance
|
||||
case NodeType.QUESTION_CLASSIFIER:
|
||||
case BuiltinNodeTypes.QUESTION_CLASSIFIER:
|
||||
return cast("QuestionClassifierNode", node).model_instance
|
||||
case _:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_
|
|||
from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_events import GraphNodeEventBase
|
||||
from dify_graph.nodes.base.node import Node
|
||||
|
|
@ -74,16 +74,13 @@ class ObservabilityLayer(GraphEngineLayer):
|
|||
def _build_parser_registry(self) -> None:
|
||||
"""Initialize parser registry for node types."""
|
||||
self._parsers = {
|
||||
NodeType.TOOL: ToolNodeOTelParser(),
|
||||
NodeType.LLM: LLMNodeOTelParser(),
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: RetrievalNodeOTelParser(),
|
||||
BuiltinNodeTypes.TOOL: ToolNodeOTelParser(),
|
||||
BuiltinNodeTypes.LLM: LLMNodeOTelParser(),
|
||||
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: RetrievalNodeOTelParser(),
|
||||
}
|
||||
|
||||
def _get_parser(self, node: Node) -> NodeOTelParser:
|
||||
node_type = getattr(node, "node_type", None)
|
||||
if isinstance(node_type, NodeType):
|
||||
return self._parsers.get(node_type, self._default_parser)
|
||||
return self._default_parser
|
||||
return self._parsers.get(node.node_type, self._default_parser)
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from core.rag.models.document import Document
|
|||
from extensions.ext_database import db
|
||||
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, DatasetQuerySource
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -36,7 +36,7 @@ class DatasetIndexToolCallbackHandler:
|
|||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset_id,
|
||||
content=query,
|
||||
source="app",
|
||||
source=DatasetQuerySource.APP,
|
||||
source_app_id=self._app_id,
|
||||
created_by_role=(
|
||||
CreatorUserRole.ACCOUNT
|
||||
|
|
|
|||
|
|
@ -24,12 +24,12 @@ from core.datasource.utils.message_transformer import DatasourceFileMessageTrans
|
|||
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
|
||||
from core.db.session_factory import session_factory
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.file import File
|
||||
from dify_graph.file.enums import FileTransferMethod, FileType
|
||||
from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from dify_graph.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam
|
||||
from factories import file_factory
|
||||
from models.model import UploadFile
|
||||
from models.tools import ToolFile
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
|
|||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.engine import db
|
||||
from models.enums import CredentialSourceType
|
||||
from models.provider import (
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
|
|
@ -473,9 +474,21 @@ class ProviderConfiguration(BaseModel):
|
|||
|
||||
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
|
||||
else:
|
||||
# some historical data may have a provider record but not be set as valid
|
||||
provider_record.is_valid = True
|
||||
|
||||
if provider_record.credential_id is None:
|
||||
provider_record.credential_id = new_record.id
|
||||
provider_record.updated_at = naive_utc_now()
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=provider_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
provider_model_credentials_cache.delete()
|
||||
|
||||
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
|
||||
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
|
|
@ -534,7 +547,7 @@ class ProviderConfiguration(BaseModel):
|
|||
self._update_load_balancing_configs_with_credential(
|
||||
credential_id=credential_id,
|
||||
credential_record=credential_record,
|
||||
credential_source="provider",
|
||||
credential_source=CredentialSourceType.PROVIDER,
|
||||
session=session,
|
||||
)
|
||||
except Exception:
|
||||
|
|
@ -611,7 +624,7 @@ class ProviderConfiguration(BaseModel):
|
|||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == "provider",
|
||||
LoadBalancingModelConfig.credential_source_type == CredentialSourceType.PROVIDER,
|
||||
)
|
||||
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
|
||||
try:
|
||||
|
|
@ -1031,7 +1044,7 @@ class ProviderConfiguration(BaseModel):
|
|||
self._update_load_balancing_configs_with_credential(
|
||||
credential_id=credential_id,
|
||||
credential_record=credential_record,
|
||||
credential_source="custom_model",
|
||||
credential_source=CredentialSourceType.CUSTOM_MODEL,
|
||||
session=session,
|
||||
)
|
||||
except Exception:
|
||||
|
|
@ -1061,7 +1074,7 @@ class ProviderConfiguration(BaseModel):
|
|||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == "custom_model",
|
||||
LoadBalancingModelConfig.credential_source_type == CredentialSourceType.CUSTOM_MODEL,
|
||||
)
|
||||
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
|
||||
|
||||
|
|
@ -1699,7 +1712,7 @@ class ProviderConfiguration(BaseModel):
|
|||
provider_model_lb_configs = [
|
||||
config
|
||||
for config in model_setting.load_balancing_configs
|
||||
if config.credential_source_type != "custom_model"
|
||||
if config.credential_source_type != CredentialSourceType.CUSTOM_MODEL
|
||||
]
|
||||
|
||||
load_balancing_enabled = model_setting.load_balancing_enabled
|
||||
|
|
@ -1757,7 +1770,7 @@ class ProviderConfiguration(BaseModel):
|
|||
custom_model_lb_configs = [
|
||||
config
|
||||
for config in model_setting.load_balancing_configs
|
||||
if config.credential_source_type != "provider"
|
||||
if config.credential_source_type != CredentialSourceType.PROVIDER
|
||||
]
|
||||
|
||||
load_balancing_enabled = model_setting.load_balancing_enabled
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import re
|
|||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
|
@ -37,8 +38,9 @@ from extensions.ext_storage import storage
|
|||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
||||
from models.dataset import AutomaticRulesConfig, ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import DataSourceType, IndexingStatus, ProcessRuleMode, SegmentStatus
|
||||
from models.model import UploadFile
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
|
@ -55,7 +57,7 @@ class IndexingRunner:
|
|||
logger.exception("consume document failed")
|
||||
document = db.session.get(DatasetDocument, document_id)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
error_message = getattr(error, "description", str(error))
|
||||
document.error = str(error_message)
|
||||
document.stopped_at = naive_utc_now()
|
||||
|
|
@ -218,7 +220,7 @@ class IndexingRunner:
|
|||
if document_segments:
|
||||
for document_segment in document_segments:
|
||||
# transform segment to node
|
||||
if document_segment.status != "completed":
|
||||
if document_segment.status != SegmentStatus.COMPLETED:
|
||||
document = Document(
|
||||
page_content=document_segment.content,
|
||||
metadata={
|
||||
|
|
@ -265,7 +267,7 @@ class IndexingRunner:
|
|||
self,
|
||||
tenant_id: str,
|
||||
extract_settings: list[ExtractSetting],
|
||||
tmp_processing_rule: dict,
|
||||
tmp_processing_rule: Mapping[str, Any],
|
||||
doc_form: str | None = None,
|
||||
doc_language: str = "English",
|
||||
dataset_id: str | None = None,
|
||||
|
|
@ -376,12 +378,12 @@ class IndexingRunner:
|
|||
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
|
||||
|
||||
def _extract(
|
||||
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
|
||||
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: Mapping[str, Any]
|
||||
) -> list[Document]:
|
||||
data_source_info = dataset_document.data_source_info_dict
|
||||
text_docs = []
|
||||
match dataset_document.data_source_type:
|
||||
case "upload_file":
|
||||
case DataSourceType.UPLOAD_FILE:
|
||||
if not data_source_info or "upload_file_id" not in data_source_info:
|
||||
raise ValueError("no upload file found")
|
||||
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
|
||||
|
|
@ -394,7 +396,7 @@ class IndexingRunner:
|
|||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
case "notion_import":
|
||||
case DataSourceType.NOTION_IMPORT:
|
||||
if (
|
||||
not data_source_info
|
||||
or "notion_workspace_id" not in data_source_info
|
||||
|
|
@ -416,7 +418,7 @@ class IndexingRunner:
|
|||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||
case "website_crawl":
|
||||
case DataSourceType.WEBSITE_CRAWL:
|
||||
if (
|
||||
not data_source_info
|
||||
or "provider" not in data_source_info
|
||||
|
|
@ -444,7 +446,7 @@ class IndexingRunner:
|
|||
# update document status to splitting
|
||||
self._update_document_index_status(
|
||||
document_id=dataset_document.id,
|
||||
after_indexing_status="splitting",
|
||||
after_indexing_status=IndexingStatus.SPLITTING,
|
||||
extra_update_params={
|
||||
DatasetDocument.parsing_completed_at: naive_utc_now(),
|
||||
},
|
||||
|
|
@ -543,7 +545,8 @@ class IndexingRunner:
|
|||
"""
|
||||
Clean the document text according to the processing rules.
|
||||
"""
|
||||
if processing_rule.mode == "automatic":
|
||||
rules: AutomaticRulesConfig | dict[str, Any]
|
||||
if processing_rule.mode == ProcessRuleMode.AUTOMATIC:
|
||||
rules = DatasetProcessRule.AUTOMATIC_RULES
|
||||
else:
|
||||
rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
|
||||
|
|
@ -634,7 +637,7 @@ class IndexingRunner:
|
|||
# update document status to completed
|
||||
self._update_document_index_status(
|
||||
document_id=dataset_document.id,
|
||||
after_indexing_status="completed",
|
||||
after_indexing_status=IndexingStatus.COMPLETED,
|
||||
extra_update_params={
|
||||
DatasetDocument.tokens: tokens,
|
||||
DatasetDocument.completed_at: naive_utc_now(),
|
||||
|
|
@ -657,10 +660,10 @@ class IndexingRunner:
|
|||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.index_node_id.in_(document_ids),
|
||||
DocumentSegment.status == "indexing",
|
||||
DocumentSegment.status == SegmentStatus.INDEXING,
|
||||
).update(
|
||||
{
|
||||
DocumentSegment.status: "completed",
|
||||
DocumentSegment.status: SegmentStatus.COMPLETED,
|
||||
DocumentSegment.enabled: True,
|
||||
DocumentSegment.completed_at: naive_utc_now(),
|
||||
}
|
||||
|
|
@ -701,10 +704,10 @@ class IndexingRunner:
|
|||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(document_ids),
|
||||
DocumentSegment.status == "indexing",
|
||||
DocumentSegment.status == SegmentStatus.INDEXING,
|
||||
).update(
|
||||
{
|
||||
DocumentSegment.status: "completed",
|
||||
DocumentSegment.status: SegmentStatus.COMPLETED,
|
||||
DocumentSegment.enabled: True,
|
||||
DocumentSegment.completed_at: naive_utc_now(),
|
||||
}
|
||||
|
|
@ -723,7 +726,7 @@ class IndexingRunner:
|
|||
|
||||
@staticmethod
|
||||
def _update_document_index_status(
|
||||
document_id: str, after_indexing_status: str, extra_update_params: dict | None = None
|
||||
document_id: str, after_indexing_status: IndexingStatus, extra_update_params: dict | None = None
|
||||
):
|
||||
"""
|
||||
Update the document indexing status.
|
||||
|
|
@ -756,7 +759,7 @@ class IndexingRunner:
|
|||
dataset: Dataset,
|
||||
text_docs: list[Document],
|
||||
doc_language: str,
|
||||
process_rule: dict,
|
||||
process_rule: Mapping[str, Any],
|
||||
current_user: Account | None = None,
|
||||
) -> list[Document]:
|
||||
# get embedding model instance
|
||||
|
|
@ -801,7 +804,7 @@ class IndexingRunner:
|
|||
cur_time = naive_utc_now()
|
||||
self._update_document_index_status(
|
||||
document_id=dataset_document.id,
|
||||
after_indexing_status="indexing",
|
||||
after_indexing_status=IndexingStatus.INDEXING,
|
||||
extra_update_params={
|
||||
DatasetDocument.cleaning_completed_at: cur_time,
|
||||
DatasetDocument.splitting_completed_at: cur_time,
|
||||
|
|
@ -813,7 +816,7 @@ class IndexingRunner:
|
|||
self._update_segments_by_document(
|
||||
dataset_document_id=dataset_document.id,
|
||||
update_params={
|
||||
DocumentSegment.status: "indexing",
|
||||
DocumentSegment.status: SegmentStatus.INDEXING,
|
||||
DocumentSegment.indexing_at: naive_utc_now(),
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -55,15 +55,31 @@ def build_protected_resource_metadata_discovery_urls(
|
|||
"""
|
||||
urls = []
|
||||
|
||||
parsed_server_url = urlparse(server_url)
|
||||
base_url = f"{parsed_server_url.scheme}://{parsed_server_url.netloc}"
|
||||
path = parsed_server_url.path.rstrip("/")
|
||||
|
||||
# First priority: URL from WWW-Authenticate header
|
||||
if www_auth_resource_metadata_url:
|
||||
urls.append(www_auth_resource_metadata_url)
|
||||
parsed_metadata_url = urlparse(www_auth_resource_metadata_url)
|
||||
normalized_metadata_url = None
|
||||
if parsed_metadata_url.scheme and parsed_metadata_url.netloc:
|
||||
normalized_metadata_url = www_auth_resource_metadata_url
|
||||
elif not parsed_metadata_url.scheme and parsed_metadata_url.netloc:
|
||||
normalized_metadata_url = f"{parsed_server_url.scheme}:{www_auth_resource_metadata_url}"
|
||||
elif (
|
||||
not parsed_metadata_url.scheme
|
||||
and not parsed_metadata_url.netloc
|
||||
and parsed_metadata_url.path.startswith("/")
|
||||
):
|
||||
first_segment = parsed_metadata_url.path.lstrip("/").split("/", 1)[0]
|
||||
if first_segment == ".well-known" or "." not in first_segment:
|
||||
normalized_metadata_url = urljoin(base_url, parsed_metadata_url.path)
|
||||
|
||||
if normalized_metadata_url:
|
||||
urls.append(normalized_metadata_url)
|
||||
|
||||
# Fallback: construct from server URL
|
||||
parsed = urlparse(server_url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
path = parsed.path.rstrip("/")
|
||||
|
||||
# Priority 2: With path insertion (e.g., /.well-known/oauth-protected-resource/public/mcp)
|
||||
if path:
|
||||
path_url = f"{base_url}/.well-known/oauth-protected-resource{path}"
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ from core.ops.entities.trace_entity import (
|
|||
)
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from dify_graph.entities import WorkflowNodeExecution
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from models import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
|
@ -302,11 +302,11 @@ class AliyunDataTrace(BaseTraceInstance):
|
|||
self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata
|
||||
):
|
||||
try:
|
||||
if node_execution.node_type == NodeType.LLM:
|
||||
if node_execution.node_type == BuiltinNodeTypes.LLM:
|
||||
node_span = self.build_workflow_llm_span(trace_info, node_execution, trace_metadata)
|
||||
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
elif node_execution.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
|
||||
node_span = self.build_workflow_retrieval_span(trace_info, node_execution, trace_metadata)
|
||||
elif node_execution.node_type == NodeType.TOOL:
|
||||
elif node_execution.node_type == BuiltinNodeTypes.TOOL:
|
||||
node_span = self.build_workflow_tool_span(trace_info, node_execution, trace_metadata)
|
||||
else:
|
||||
node_span = self.build_workflow_task_span(trace_info, node_execution, trace_metadata)
|
||||
|
|
|
|||
|
|
@ -155,8 +155,8 @@ def wrap_span_metadata(metadata, **kwargs):
|
|||
return metadata
|
||||
|
||||
|
||||
# Mapping from NodeType string values to OpenInference span kinds.
|
||||
# NodeType values not listed here default to CHAIN.
|
||||
# Mapping from built-in node type strings to OpenInference span kinds.
|
||||
# Node types not listed here default to CHAIN.
|
||||
_NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = {
|
||||
"llm": OpenInferenceSpanKindValues.LLM,
|
||||
"knowledge-retrieval": OpenInferenceSpanKindValues.RETRIEVER,
|
||||
|
|
@ -168,7 +168,7 @@ _NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = {
|
|||
def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues:
|
||||
"""Return the OpenInference span kind for a given workflow node type.
|
||||
|
||||
Covers every ``NodeType`` enum value. Nodes that do not have a
|
||||
Covers every built-in node type string. Nodes that do not have a
|
||||
specialised span kind (e.g. ``start``, ``end``, ``if-else``,
|
||||
``code``, ``loop``, ``iteration``, etc.) are mapped to ``CHAIN``.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
|
|||
)
|
||||
from core.ops.utils import filter_none_values
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.enums import BuiltinNodeTypes
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import MessageStatus
|
||||
|
|
@ -141,7 +141,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
node_name = node_execution.title
|
||||
node_type = node_execution.node_type
|
||||
status = node_execution.status
|
||||
if node_type == NodeType.LLM:
|
||||
if node_type == BuiltinNodeTypes.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = node_execution.inputs or {}
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
|
|||
)
|
||||
from core.ops.utils import filter_none_values, generate_dotted_order
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
|
@ -163,7 +163,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||
node_name = node_execution.title
|
||||
node_type = node_execution.node_type
|
||||
status = node_execution.status
|
||||
if node_type == NodeType.LLM:
|
||||
if node_type == BuiltinNodeTypes.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = node_execution.inputs or {}
|
||||
|
|
@ -197,7 +197,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||
"ls_model_name": process_data.get("model_name", ""),
|
||||
}
|
||||
)
|
||||
elif node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
elif node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
|
||||
run_type = LangSmithRunType.retriever
|
||||
else:
|
||||
run_type = LangSmithRunType.tool
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from core.ops.entities.trace_entity import (
|
|||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.enums import BuiltinNodeTypes
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
|
@ -145,10 +145,10 @@ class MLflowDataTrace(BaseTraceInstance):
|
|||
"app_name": node.title,
|
||||
}
|
||||
|
||||
if node.node_type in (NodeType.LLM, NodeType.QUESTION_CLASSIFIER):
|
||||
if node.node_type in (BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER):
|
||||
inputs, llm_attributes = self._parse_llm_inputs_and_attributes(node)
|
||||
attributes.update(llm_attributes)
|
||||
elif node.node_type == NodeType.HTTP_REQUEST:
|
||||
elif node.node_type == BuiltinNodeTypes.HTTP_REQUEST:
|
||||
inputs = node.process_data # contains request URL
|
||||
|
||||
if not inputs:
|
||||
|
|
@ -180,9 +180,9 @@ class MLflowDataTrace(BaseTraceInstance):
|
|||
# End node span
|
||||
finished_at = node.created_at + timedelta(seconds=node.elapsed_time)
|
||||
outputs = json.loads(node.outputs) if node.outputs else {}
|
||||
if node.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
if node.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
|
||||
outputs = self._parse_knowledge_retrieval_outputs(outputs)
|
||||
elif node.node_type == NodeType.LLM:
|
||||
elif node.node_type == BuiltinNodeTypes.LLM:
|
||||
outputs = outputs.get("text", outputs)
|
||||
node_span.end(
|
||||
outputs=outputs,
|
||||
|
|
@ -471,13 +471,13 @@ class MLflowDataTrace(BaseTraceInstance):
|
|||
def _get_node_span_type(self, node_type: str) -> str:
|
||||
"""Map Dify node types to MLflow span types"""
|
||||
node_type_mapping = {
|
||||
NodeType.LLM: SpanType.LLM,
|
||||
NodeType.QUESTION_CLASSIFIER: SpanType.LLM,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: SpanType.RETRIEVER,
|
||||
NodeType.TOOL: SpanType.TOOL,
|
||||
NodeType.CODE: SpanType.TOOL,
|
||||
NodeType.HTTP_REQUEST: SpanType.TOOL,
|
||||
NodeType.AGENT: SpanType.AGENT,
|
||||
BuiltinNodeTypes.LLM: SpanType.LLM,
|
||||
BuiltinNodeTypes.QUESTION_CLASSIFIER: SpanType.LLM,
|
||||
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: SpanType.RETRIEVER,
|
||||
BuiltinNodeTypes.TOOL: SpanType.TOOL,
|
||||
BuiltinNodeTypes.CODE: SpanType.TOOL,
|
||||
BuiltinNodeTypes.HTTP_REQUEST: SpanType.TOOL,
|
||||
BuiltinNodeTypes.AGENT: SpanType.AGENT,
|
||||
}
|
||||
return node_type_mapping.get(node_type, "CHAIN") # type: ignore[arg-type,call-overload]
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from core.ops.entities.trace_entity import (
|
|||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
|
@ -187,7 +187,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
node_name = node_execution.title
|
||||
node_type = node_execution.node_type
|
||||
status = node_execution.status
|
||||
if node_type == NodeType.LLM:
|
||||
if node_type == BuiltinNodeTypes.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = node_execution.inputs or {}
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
|||
from dify_graph.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
)
|
||||
from dify_graph.nodes import NodeType
|
||||
from dify_graph.nodes import BuiltinNodeTypes
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
|
@ -179,7 +179,7 @@ class TencentDataTrace(BaseTraceInstance):
|
|||
if node_span:
|
||||
self.trace_client.add_span(node_span)
|
||||
|
||||
if node_execution.node_type == NodeType.LLM:
|
||||
if node_execution.node_type == BuiltinNodeTypes.LLM:
|
||||
self._record_llm_metrics(node_execution)
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to process node execution: %s", node_execution.id)
|
||||
|
|
@ -192,15 +192,15 @@ class TencentDataTrace(BaseTraceInstance):
|
|||
) -> SpanData | None:
|
||||
"""Build span for different node types"""
|
||||
try:
|
||||
if node_execution.node_type == NodeType.LLM:
|
||||
if node_execution.node_type == BuiltinNodeTypes.LLM:
|
||||
return TencentSpanBuilder.build_workflow_llm_span(
|
||||
trace_id, workflow_span_id, trace_info, node_execution
|
||||
)
|
||||
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
elif node_execution.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
|
||||
return TencentSpanBuilder.build_workflow_retrieval_span(
|
||||
trace_id, workflow_span_id, trace_info, node_execution
|
||||
)
|
||||
elif node_execution.node_type == NodeType.TOOL:
|
||||
elif node_execution.node_type == BuiltinNodeTypes.TOOL:
|
||||
return TencentSpanBuilder.build_workflow_tool_span(
|
||||
trace_id, workflow_span_id, trace_info, node_execution
|
||||
)
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ from core.ops.entities.trace_entity import (
|
|||
)
|
||||
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
|
@ -175,7 +175,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
node_name = node_execution.title
|
||||
node_type = node_execution.node_type
|
||||
status = node_execution.status
|
||||
if node_type == NodeType.LLM:
|
||||
if node_type == BuiltinNodeTypes.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = node_execution.inputs or {}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.enums import BuiltinNodeTypes
|
||||
from dify_graph.nodes.parameter_extractor.entities import (
|
||||
ModelConfig as ParameterExtractorModelConfig,
|
||||
)
|
||||
|
|
@ -52,7 +52,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
|
|||
instruction=instruction, # instruct with variables are not supported
|
||||
)
|
||||
node_data_dict = node_data.model_dump()
|
||||
node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR
|
||||
node_data_dict["type"] = BuiltinNodeTypes.PARAMETER_EXTRACTOR
|
||||
execution = workflow_service.run_free_workflow_node(
|
||||
node_data_dict,
|
||||
tenant_id=tenant_id,
|
||||
|
|
|
|||
|
|
@ -196,6 +196,8 @@ class ProviderManager:
|
|||
|
||||
if preferred_provider_type_record:
|
||||
preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type)
|
||||
elif dify_config.EDITION == "CLOUD" and system_configuration.enabled:
|
||||
preferred_provider_type = ProviderType.SYSTEM
|
||||
elif custom_configuration.provider or custom_configuration.models:
|
||||
preferred_provider_type = ProviderType.CUSTOM
|
||||
elif system_configuration.enabled:
|
||||
|
|
@ -305,9 +307,7 @@ class ProviderManager:
|
|||
available_models = provider_configurations.get_models(model_type=model_type, only_active=True)
|
||||
|
||||
if available_models:
|
||||
available_model = next(
|
||||
(model for model in available_models if model.model == "gpt-4"), available_models[0]
|
||||
)
|
||||
available_model = available_models[0]
|
||||
|
||||
default_model = TenantDefaultModel(
|
||||
tenant_id=tenant_id,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing_extensions import TypedDict
|
||||
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.rag.data_post_processor.reorder import ReorderRunner
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
|
|
@ -10,6 +12,26 @@ from dify_graph.model_runtime.entities.model_entities import ModelType
|
|||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
|
||||
class RerankingModelDict(TypedDict):
|
||||
reranking_provider_name: str
|
||||
reranking_model_name: str
|
||||
|
||||
|
||||
class VectorSettingDict(TypedDict):
|
||||
vector_weight: float
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class KeywordSettingDict(TypedDict):
|
||||
keyword_weight: float
|
||||
|
||||
|
||||
class WeightsDict(TypedDict):
|
||||
vector_setting: VectorSettingDict
|
||||
keyword_setting: KeywordSettingDict
|
||||
|
||||
|
||||
class DataPostProcessor:
|
||||
"""Interface for data post-processing document."""
|
||||
|
||||
|
|
@ -17,8 +39,8 @@ class DataPostProcessor:
|
|||
self,
|
||||
tenant_id: str,
|
||||
reranking_mode: str,
|
||||
reranking_model: dict | None = None,
|
||||
weights: dict | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
reorder_enabled: bool = False,
|
||||
):
|
||||
self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights)
|
||||
|
|
@ -45,8 +67,8 @@ class DataPostProcessor:
|
|||
self,
|
||||
reranking_mode: str,
|
||||
tenant_id: str,
|
||||
reranking_model: dict | None = None,
|
||||
weights: dict | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
) -> BaseRerankRunner | None:
|
||||
if reranking_mode == RerankMode.WEIGHTED_SCORE and weights:
|
||||
runner = RerankRunnerFactory.create_rerank_runner(
|
||||
|
|
@ -79,12 +101,14 @@ class DataPostProcessor:
|
|||
return ReorderRunner()
|
||||
return None
|
||||
|
||||
def _get_rerank_model_instance(self, tenant_id: str, reranking_model: dict | None) -> ModelInstance | None:
|
||||
def _get_rerank_model_instance(
|
||||
self, tenant_id: str, reranking_model: RerankingModelDict | None
|
||||
) -> ModelInstance | None:
|
||||
if reranking_model:
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
reranking_provider_name = reranking_model.get("reranking_provider_name")
|
||||
reranking_model_name = reranking_model.get("reranking_model_name")
|
||||
reranking_provider_name = reranking_model["reranking_provider_name"]
|
||||
reranking_model_name = reranking_model["reranking_model_name"]
|
||||
if not reranking_provider_name or not reranking_model_name:
|
||||
return None
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
|
|
|
|||
|
|
@ -1,19 +1,20 @@
|
|||
import concurrent.futures
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
from typing import Any, NotRequired
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, load_only
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
|
||||
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
|
||||
from core.rag.entities.metadata_entities import MetadataCondition
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
|
|
@ -35,7 +36,46 @@ from models.dataset import Document as DatasetDocument
|
|||
from models.model import UploadFile
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model = {
|
||||
|
||||
class SegmentAttachmentResult(TypedDict):
|
||||
attachment_info: AttachmentInfoDict
|
||||
segment_id: str
|
||||
|
||||
|
||||
class SegmentAttachmentInfoResult(TypedDict):
|
||||
attachment_id: str
|
||||
attachment_info: AttachmentInfoDict
|
||||
segment_id: str
|
||||
|
||||
|
||||
class ChildChunkDetail(TypedDict):
|
||||
id: str
|
||||
content: str
|
||||
position: int
|
||||
score: float
|
||||
|
||||
|
||||
class SegmentChildMapDetail(TypedDict):
|
||||
max_score: float
|
||||
child_chunks: list[ChildChunkDetail]
|
||||
|
||||
|
||||
class SegmentRecord(TypedDict):
|
||||
segment: DocumentSegment
|
||||
score: NotRequired[float]
|
||||
child_chunks: NotRequired[list[ChildChunkDetail]]
|
||||
files: NotRequired[list[AttachmentInfoDict]]
|
||||
|
||||
|
||||
class DefaultRetrievalModelDict(TypedDict):
|
||||
search_method: RetrievalMethod | str
|
||||
reranking_enable: bool
|
||||
reranking_model: RerankingModelDict
|
||||
top_k: int
|
||||
score_threshold_enabled: bool
|
||||
|
||||
|
||||
default_retrieval_model: DefaultRetrievalModelDict = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
|
|
@ -56,9 +96,9 @@ class RetrievalService:
|
|||
query: str,
|
||||
top_k: int = 4,
|
||||
score_threshold: float | None = 0.0,
|
||||
reranking_model: dict | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
reranking_mode: str = "reranking_model",
|
||||
weights: dict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
attachment_ids: list | None = None,
|
||||
):
|
||||
|
|
@ -235,7 +275,7 @@ class RetrievalService:
|
|||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: float | None,
|
||||
reranking_model: dict | None,
|
||||
reranking_model: RerankingModelDict | None,
|
||||
all_documents: list,
|
||||
retrieval_method: RetrievalMethod,
|
||||
exceptions: list,
|
||||
|
|
@ -277,8 +317,8 @@ class RetrievalService:
|
|||
if documents:
|
||||
if (
|
||||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and reranking_model["reranking_model_name"]
|
||||
and reranking_model["reranking_provider_name"]
|
||||
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
|
|
@ -288,8 +328,8 @@ class RetrievalService:
|
|||
model_manager = ModelManager()
|
||||
is_support_vision = model_manager.check_model_support_vision(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=reranking_model.get("reranking_provider_name") or "",
|
||||
model=reranking_model.get("reranking_model_name") or "",
|
||||
provider=reranking_model["reranking_provider_name"],
|
||||
model=reranking_model["reranking_model_name"],
|
||||
model_type=ModelType.RERANK,
|
||||
)
|
||||
if is_support_vision:
|
||||
|
|
@ -329,7 +369,7 @@ class RetrievalService:
|
|||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: float | None,
|
||||
reranking_model: dict | None,
|
||||
reranking_model: RerankingModelDict | None,
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
exceptions: list,
|
||||
|
|
@ -349,8 +389,8 @@ class RetrievalService:
|
|||
if documents:
|
||||
if (
|
||||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and reranking_model["reranking_model_name"]
|
||||
and reranking_model["reranking_provider_name"]
|
||||
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
|
|
@ -459,7 +499,7 @@ class RetrievalService:
|
|||
segment_ids: list[str] = []
|
||||
index_node_segments: list[DocumentSegment] = []
|
||||
segments: list[DocumentSegment] = []
|
||||
attachment_map: dict[str, list[dict[str, Any]]] = {}
|
||||
attachment_map: dict[str, list[AttachmentInfoDict]] = {}
|
||||
child_chunk_map: dict[str, list[ChildChunk]] = {}
|
||||
doc_segment_map: dict[str, list[str]] = {}
|
||||
segment_summary_map: dict[str, str] = {} # Map segment_id to summary content
|
||||
|
|
@ -544,12 +584,12 @@ class RetrievalService:
|
|||
segment_summary_map[summary.chunk_id] = summary.summary_content
|
||||
|
||||
include_segment_ids = set()
|
||||
segment_child_map: dict[str, dict[str, Any]] = {}
|
||||
records: list[dict[str, Any]] = []
|
||||
segment_child_map: dict[str, SegmentChildMapDetail] = {}
|
||||
records: list[SegmentRecord] = []
|
||||
|
||||
for segment in segments:
|
||||
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
|
||||
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
|
||||
attachment_infos: list[AttachmentInfoDict] = attachment_map.get(segment.id, [])
|
||||
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
|
||||
|
||||
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
|
|
@ -560,14 +600,14 @@ class RetrievalService:
|
|||
max_score = summary_score_map.get(segment.id, 0.0)
|
||||
|
||||
if child_chunks or attachment_infos:
|
||||
child_chunk_details = []
|
||||
child_chunk_details: list[ChildChunkDetail] = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id)
|
||||
if child_document:
|
||||
child_score = child_document.metadata.get("score", 0.0)
|
||||
else:
|
||||
child_score = 0.0
|
||||
child_chunk_detail = {
|
||||
child_chunk_detail: ChildChunkDetail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
|
|
@ -580,7 +620,7 @@ class RetrievalService:
|
|||
if file_document:
|
||||
max_score = max(max_score, file_document.metadata.get("score", 0.0))
|
||||
|
||||
map_detail = {
|
||||
map_detail: SegmentChildMapDetail = {
|
||||
"max_score": max_score,
|
||||
"child_chunks": child_chunk_details,
|
||||
}
|
||||
|
|
@ -593,7 +633,7 @@ class RetrievalService:
|
|||
"max_score": summary_score,
|
||||
"child_chunks": [],
|
||||
}
|
||||
record: dict[str, Any] = {
|
||||
record: SegmentRecord = {
|
||||
"segment": segment,
|
||||
}
|
||||
records.append(record)
|
||||
|
|
@ -617,19 +657,19 @@ class RetrievalService:
|
|||
if file_doc:
|
||||
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
|
||||
|
||||
record = {
|
||||
another_record: SegmentRecord = {
|
||||
"segment": segment,
|
||||
"score": max_score,
|
||||
}
|
||||
records.append(record)
|
||||
records.append(another_record)
|
||||
|
||||
# Add child chunks information to records
|
||||
for record in records:
|
||||
if record["segment"].id in segment_child_map:
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id]["child_chunks"]
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"]
|
||||
if record["segment"].id in attachment_map:
|
||||
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
|
||||
record["files"] = attachment_map[record["segment"].id]
|
||||
|
||||
result: list[RetrievalSegments] = []
|
||||
for record in records:
|
||||
|
|
@ -693,9 +733,9 @@ class RetrievalService:
|
|||
query: str | None = None,
|
||||
top_k: int = 4,
|
||||
score_threshold: float | None = 0.0,
|
||||
reranking_model: dict | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
reranking_mode: str = "reranking_model",
|
||||
weights: dict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
attachment_id: str | None = None,
|
||||
):
|
||||
|
|
@ -807,7 +847,7 @@ class RetrievalService:
|
|||
@classmethod
|
||||
def get_segment_attachment_info(
|
||||
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
|
||||
) -> dict[str, Any] | None:
|
||||
) -> SegmentAttachmentResult | None:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
|
||||
if upload_file:
|
||||
attachment_binding = (
|
||||
|
|
@ -816,7 +856,7 @@ class RetrievalService:
|
|||
.first()
|
||||
)
|
||||
if attachment_binding:
|
||||
attachment_info = {
|
||||
attachment_info: AttachmentInfoDict = {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"extension": "." + upload_file.extension,
|
||||
|
|
@ -828,8 +868,10 @@ class RetrievalService:
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
|
||||
attachment_infos = []
|
||||
def get_segment_attachment_infos(
|
||||
cls, attachment_ids: list[str], session: Session
|
||||
) -> list[SegmentAttachmentInfoResult]:
|
||||
attachment_infos: list[SegmentAttachmentInfoResult] = []
|
||||
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
|
||||
if upload_files:
|
||||
upload_file_ids = [upload_file.id for upload_file in upload_files]
|
||||
|
|
@ -843,7 +885,7 @@ class RetrievalService:
|
|||
if attachment_bindings:
|
||||
for upload_file in upload_files:
|
||||
attachment_binding = attachment_binding_map.get(upload_file.id)
|
||||
attachment_info = {
|
||||
info: AttachmentInfoDict = {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"extension": "." + upload_file.extension,
|
||||
|
|
@ -855,7 +897,7 @@ class RetrievalService:
|
|||
attachment_infos.append(
|
||||
{
|
||||
"attachment_id": attachment_binding.attachment_id,
|
||||
"attachment_info": attachment_info,
|
||||
"attachment_info": info,
|
||||
"segment_id": attachment_binding.segment_id,
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,361 @@
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import holo_search_sdk as holo # type: ignore
|
||||
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
|
||||
from psycopg import sql as psql
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HologresVectorConfig(BaseModel):
|
||||
"""
|
||||
Configuration for Hologres vector database connection.
|
||||
|
||||
In Hologres, access_key_id is used as the PostgreSQL username,
|
||||
and access_key_secret is used as the PostgreSQL password.
|
||||
"""
|
||||
|
||||
host: str
|
||||
port: int = 80
|
||||
database: str
|
||||
access_key_id: str
|
||||
access_key_secret: str
|
||||
schema_name: str = "public"
|
||||
tokenizer: TokenizerType = "jieba"
|
||||
distance_method: DistanceType = "Cosine"
|
||||
base_quantization_type: BaseQuantizationType = "rabitq"
|
||||
max_degree: int = 64
|
||||
ef_construction: int = 400
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict):
|
||||
if not values.get("host"):
|
||||
raise ValueError("config HOLOGRES_HOST is required")
|
||||
if not values.get("database"):
|
||||
raise ValueError("config HOLOGRES_DATABASE is required")
|
||||
if not values.get("access_key_id"):
|
||||
raise ValueError("config HOLOGRES_ACCESS_KEY_ID is required")
|
||||
if not values.get("access_key_secret"):
|
||||
raise ValueError("config HOLOGRES_ACCESS_KEY_SECRET is required")
|
||||
return values
|
||||
|
||||
|
||||
class HologresVector(BaseVector):
|
||||
"""
|
||||
Hologres vector storage implementation using holo-search-sdk.
|
||||
|
||||
Supports semantic search (vector), full-text search, and hybrid search.
|
||||
"""
|
||||
|
||||
def __init__(self, collection_name: str, config: HologresVectorConfig):
|
||||
super().__init__(collection_name)
|
||||
self._config = config
|
||||
self._client = self._init_client(config)
|
||||
self.table_name = f"embedding_{collection_name}".lower()
|
||||
|
||||
def _init_client(self, config: HologresVectorConfig):
|
||||
"""Initialize and return a holo-search-sdk client."""
|
||||
client = holo.connect(
|
||||
host=config.host,
|
||||
port=config.port,
|
||||
database=config.database,
|
||||
access_key_id=config.access_key_id,
|
||||
access_key_secret=config.access_key_secret,
|
||||
schema=config.schema_name,
|
||||
)
|
||||
client.connect()
|
||||
return client
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.HOLOGRES
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
"""Create collection table with vector and full-text indexes, then add texts."""
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
"""Add texts with embeddings to the collection using batch upsert."""
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
pks: list[str] = []
|
||||
batch_size = 100
|
||||
for i in range(0, len(documents), batch_size):
|
||||
batch_docs = documents[i : i + batch_size]
|
||||
batch_embeddings = embeddings[i : i + batch_size]
|
||||
|
||||
values = []
|
||||
column_names = ["id", "text", "meta", "embedding"]
|
||||
|
||||
for j, doc in enumerate(batch_docs):
|
||||
doc_id = doc.metadata.get("doc_id", "") if doc.metadata else ""
|
||||
pks.append(doc_id)
|
||||
values.append(
|
||||
[
|
||||
doc_id,
|
||||
doc.page_content,
|
||||
json.dumps(doc.metadata or {}),
|
||||
batch_embeddings[j],
|
||||
]
|
||||
)
|
||||
|
||||
table = self._client.open_table(self.table_name)
|
||||
table.upsert_multi(
|
||||
index_column="id",
|
||||
values=values,
|
||||
column_names=column_names,
|
||||
update=True,
|
||||
update_columns=["text", "meta", "embedding"],
|
||||
)
|
||||
|
||||
return pks
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
"""Check if a text with the given doc_id exists in the collection."""
|
||||
if not self._client.check_table_exist(self.table_name):
|
||||
return False
|
||||
|
||||
result = self._client.execute(
|
||||
psql.SQL("SELECT 1 FROM {} WHERE id = {} LIMIT 1").format(
|
||||
psql.Identifier(self.table_name), psql.Literal(id)
|
||||
),
|
||||
fetch_result=True,
|
||||
)
|
||||
return bool(result)
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str] | None:
|
||||
"""Get document IDs by metadata field key and value."""
|
||||
result = self._client.execute(
|
||||
psql.SQL("SELECT id FROM {} WHERE meta->>{} = {}").format(
|
||||
psql.Identifier(self.table_name), psql.Literal(key), psql.Literal(value)
|
||||
),
|
||||
fetch_result=True,
|
||||
)
|
||||
if result:
|
||||
return [row[0] for row in result]
|
||||
return None
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
"""Delete documents by their doc_id list."""
|
||||
if not ids:
|
||||
return
|
||||
if not self._client.check_table_exist(self.table_name):
|
||||
return
|
||||
|
||||
self._client.execute(
|
||||
psql.SQL("DELETE FROM {} WHERE id IN ({})").format(
|
||||
psql.Identifier(self.table_name),
|
||||
psql.SQL(", ").join(psql.Literal(id) for id in ids),
|
||||
)
|
||||
)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
"""Delete documents by metadata field key and value."""
|
||||
if not self._client.check_table_exist(self.table_name):
|
||||
return
|
||||
|
||||
self._client.execute(
|
||||
psql.SQL("DELETE FROM {} WHERE meta->>{} = {}").format(
|
||||
psql.Identifier(self.table_name), psql.Literal(key), psql.Literal(value)
|
||||
)
|
||||
)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""Search for documents by vector similarity."""
|
||||
if not self._client.check_table_exist(self.table_name):
|
||||
return []
|
||||
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
|
||||
table = self._client.open_table(self.table_name)
|
||||
query = (
|
||||
table.search_vector(
|
||||
vector=query_vector,
|
||||
column="embedding",
|
||||
distance_method=self._config.distance_method,
|
||||
output_name="distance",
|
||||
)
|
||||
.select(["id", "text", "meta"])
|
||||
.limit(top_k)
|
||||
)
|
||||
|
||||
# Apply document_ids_filter if provided
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
filter_sql = psql.SQL("meta->>'document_id' IN ({})").format(
|
||||
psql.SQL(", ").join(psql.Literal(id) for id in document_ids_filter)
|
||||
)
|
||||
query = query.where(filter_sql)
|
||||
|
||||
results = query.fetchall()
|
||||
return self._process_vector_results(results, score_threshold)
|
||||
|
||||
def _process_vector_results(self, results: list, score_threshold: float) -> list[Document]:
|
||||
"""Process vector search results into Document objects."""
|
||||
docs = []
|
||||
for row in results:
|
||||
# row format: (distance, id, text, meta)
|
||||
# distance is first because search_vector() adds the computed column before selected columns
|
||||
distance = row[0]
|
||||
text = row[2]
|
||||
meta = row[3]
|
||||
|
||||
if isinstance(meta, str):
|
||||
meta = json.loads(meta)
|
||||
|
||||
# Convert distance to similarity score (consistent with pgvector)
|
||||
score = 1 - distance
|
||||
meta["score"] = score
|
||||
|
||||
if score >= score_threshold:
|
||||
docs.append(Document(page_content=text, metadata=meta))
|
||||
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
"""Search for documents by full-text search."""
|
||||
if not self._client.check_table_exist(self.table_name):
|
||||
return []
|
||||
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
|
||||
table = self._client.open_table(self.table_name)
|
||||
search_query = table.search_text(
|
||||
column="text",
|
||||
expression=query,
|
||||
return_score=True,
|
||||
return_score_name="score",
|
||||
return_all_columns=True,
|
||||
).limit(top_k)
|
||||
|
||||
# Apply document_ids_filter if provided
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
filter_sql = psql.SQL("meta->>'document_id' IN ({})").format(
|
||||
psql.SQL(", ").join(psql.Literal(id) for id in document_ids_filter)
|
||||
)
|
||||
search_query = search_query.where(filter_sql)
|
||||
|
||||
results = search_query.fetchall()
|
||||
return self._process_full_text_results(results)
|
||||
|
||||
def _process_full_text_results(self, results: list) -> list[Document]:
|
||||
"""Process full-text search results into Document objects."""
|
||||
docs = []
|
||||
for row in results:
|
||||
# row format: (id, text, meta, embedding, score)
|
||||
text = row[1]
|
||||
meta = row[2]
|
||||
score = row[-1] # score is the last column from return_score
|
||||
|
||||
if isinstance(meta, str):
|
||||
meta = json.loads(meta)
|
||||
|
||||
meta["score"] = score
|
||||
docs.append(Document(page_content=text, metadata=meta))
|
||||
|
||||
return docs
|
||||
|
||||
def delete(self):
|
||||
"""Delete the entire collection table."""
|
||||
if self._client.check_table_exist(self.table_name):
|
||||
self._client.drop_table(self.table_name)
|
||||
|
||||
def _create_collection(self, dimension: int):
|
||||
"""Create the collection table with vector and full-text indexes."""
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
|
||||
if not self._client.check_table_exist(self.table_name):
|
||||
# Create table via SQL with CHECK constraint for vector dimension
|
||||
create_table_sql = psql.SQL("""
|
||||
CREATE TABLE IF NOT EXISTS {} (
|
||||
id TEXT PRIMARY KEY,
|
||||
text TEXT NOT NULL,
|
||||
meta JSONB NOT NULL,
|
||||
embedding float4[] NOT NULL
|
||||
CHECK (array_ndims(embedding) = 1
|
||||
AND array_length(embedding, 1) = {})
|
||||
);
|
||||
""").format(psql.Identifier(self.table_name), psql.Literal(dimension))
|
||||
self._client.execute(create_table_sql)
|
||||
|
||||
# Wait for table to be fully ready before creating indexes
|
||||
max_wait_seconds = 30
|
||||
poll_interval = 2
|
||||
for _ in range(max_wait_seconds // poll_interval):
|
||||
if self._client.check_table_exist(self.table_name):
|
||||
break
|
||||
time.sleep(poll_interval)
|
||||
else:
|
||||
raise RuntimeError(f"Table {self.table_name} was not ready after {max_wait_seconds}s")
|
||||
|
||||
# Open table and set vector index
|
||||
table = self._client.open_table(self.table_name)
|
||||
table.set_vector_index(
|
||||
column="embedding",
|
||||
distance_method=self._config.distance_method,
|
||||
base_quantization_type=self._config.base_quantization_type,
|
||||
max_degree=self._config.max_degree,
|
||||
ef_construction=self._config.ef_construction,
|
||||
use_reorder=self._config.base_quantization_type == "rabitq",
|
||||
)
|
||||
|
||||
# Create full-text search index
|
||||
table.create_text_index(
|
||||
index_name=f"ft_idx_{self._collection_name}",
|
||||
column="text",
|
||||
tokenizer=self._config.tokenizer,
|
||||
)
|
||||
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
|
||||
class HologresVectorFactory(AbstractVectorFactory):
|
||||
"""Factory class for creating HologresVector instances."""
|
||||
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> HologresVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.HOLOGRES, collection_name))
|
||||
|
||||
return HologresVector(
|
||||
collection_name=collection_name,
|
||||
config=HologresVectorConfig(
|
||||
host=dify_config.HOLOGRES_HOST or "",
|
||||
port=dify_config.HOLOGRES_PORT,
|
||||
database=dify_config.HOLOGRES_DATABASE or "",
|
||||
access_key_id=dify_config.HOLOGRES_ACCESS_KEY_ID or "",
|
||||
access_key_secret=dify_config.HOLOGRES_ACCESS_KEY_SECRET or "",
|
||||
schema_name=dify_config.HOLOGRES_SCHEMA,
|
||||
tokenizer=dify_config.HOLOGRES_TOKENIZER,
|
||||
distance_method=dify_config.HOLOGRES_DISTANCE_METHOD,
|
||||
base_quantization_type=dify_config.HOLOGRES_BASE_QUANTIZATION_TYPE,
|
||||
max_degree=dify_config.HOLOGRES_MAX_DEGREE,
|
||||
ef_construction=dify_config.HOLOGRES_EF_CONSTRUCTION,
|
||||
),
|
||||
)
|
||||
|
|
@ -135,8 +135,8 @@ class PGVectoRS(BaseVector):
|
|||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
result = None
|
||||
with Session(self._client) as session:
|
||||
select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; ")
|
||||
result = session.execute(select_statement).fetchall()
|
||||
select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>:key = :value")
|
||||
result = session.execute(select_statement, {"key": key, "value": value}).fetchall()
|
||||
if result:
|
||||
return [item[0] for item in result]
|
||||
else:
|
||||
|
|
@ -172,9 +172,9 @@ class PGVectoRS(BaseVector):
|
|||
def text_exists(self, id: str) -> bool:
|
||||
with Session(self._client) as session:
|
||||
select_statement = sql_text(
|
||||
f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = '{id}' limit 1; "
|
||||
f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = :doc_id limit 1"
|
||||
)
|
||||
result = session.execute(select_statement).fetchall()
|
||||
result = session.execute(select_statement, {"doc_id": id}).fetchall()
|
||||
return len(result) > 0
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
|
|
|
|||
|
|
@ -154,10 +154,8 @@ class RelytVector(BaseVector):
|
|||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
result = None
|
||||
with Session(self.client) as session:
|
||||
select_statement = sql_text(
|
||||
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'{key}' = '{value}'; """
|
||||
)
|
||||
result = session.execute(select_statement).fetchall()
|
||||
select_statement = sql_text(f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>:key = :value""")
|
||||
result = session.execute(select_statement, {"key": key, "value": value}).fetchall()
|
||||
if result:
|
||||
return [item[0] for item in result]
|
||||
else:
|
||||
|
|
@ -201,11 +199,10 @@ class RelytVector(BaseVector):
|
|||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
with Session(self.client) as session:
|
||||
ids_str = ",".join(f"'{doc_id}'" for doc_id in ids)
|
||||
select_statement = sql_text(
|
||||
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
|
||||
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = ANY(:doc_ids)"""
|
||||
)
|
||||
result = session.execute(select_statement).fetchall()
|
||||
result = session.execute(select_statement, {"doc_ids": ids}).fetchall()
|
||||
if result:
|
||||
ids = [item[0] for item in result]
|
||||
self.delete_by_uuids(ids)
|
||||
|
|
@ -218,9 +215,9 @@ class RelytVector(BaseVector):
|
|||
def text_exists(self, id: str) -> bool:
|
||||
with Session(self.client) as session:
|
||||
select_statement = sql_text(
|
||||
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = '{id}' limit 1; """
|
||||
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = :doc_id limit 1"""
|
||||
)
|
||||
result = session.execute(select_statement).fetchall()
|
||||
result = session.execute(select_statement, {"doc_id": id}).fetchall()
|
||||
return len(result) > 0
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ class AbstractVectorFactory(ABC):
|
|||
class Vector:
|
||||
def __init__(self, dataset: Dataset, attributes: list | None = None):
|
||||
if attributes is None:
|
||||
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
|
||||
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
|
||||
self._dataset = dataset
|
||||
self._embeddings = self._get_embeddings()
|
||||
self._attributes = attributes
|
||||
|
|
@ -191,6 +191,10 @@ class Vector:
|
|||
from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory
|
||||
|
||||
return IrisVectorFactory
|
||||
case VectorType.HOLOGRES:
|
||||
from core.rag.datasource.vdb.hologres.hologres_vector import HologresVectorFactory
|
||||
|
||||
return HologresVectorFactory
|
||||
case _:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
|
|
|
|||
|
|
@ -34,3 +34,4 @@ class VectorType(StrEnum):
|
|||
MATRIXONE = "matrixone"
|
||||
CLICKZETTA = "clickzetta"
|
||||
IRIS = "iris"
|
||||
HOLOGRES = "hologres"
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ document embeddings used in retrieval-augmented generation workflows.
|
|||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import uuid as _uuid
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
|
@ -32,6 +33,9 @@ from models.dataset import Dataset
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_weaviate_client: weaviate.WeaviateClient | None = None
|
||||
_weaviate_client_lock = threading.Lock()
|
||||
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
"""
|
||||
|
|
@ -99,43 +103,52 @@ class WeaviateVector(BaseVector):
|
|||
|
||||
Configures both HTTP and gRPC connections with proper authentication.
|
||||
"""
|
||||
p = urlparse(config.endpoint)
|
||||
host = p.hostname or config.endpoint.replace("https://", "").replace("http://", "")
|
||||
http_secure = p.scheme == "https"
|
||||
http_port = p.port or (443 if http_secure else 80)
|
||||
global _weaviate_client
|
||||
if _weaviate_client and _weaviate_client.is_ready():
|
||||
return _weaviate_client
|
||||
|
||||
# Parse gRPC configuration
|
||||
if config.grpc_endpoint:
|
||||
# Urls without scheme won't be parsed correctly in some python versions,
|
||||
# see https://bugs.python.org/issue27657
|
||||
grpc_endpoint_with_scheme = (
|
||||
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
|
||||
with _weaviate_client_lock:
|
||||
if _weaviate_client and _weaviate_client.is_ready():
|
||||
return _weaviate_client
|
||||
|
||||
p = urlparse(config.endpoint)
|
||||
host = p.hostname or config.endpoint.replace("https://", "").replace("http://", "")
|
||||
http_secure = p.scheme == "https"
|
||||
http_port = p.port or (443 if http_secure else 80)
|
||||
|
||||
# Parse gRPC configuration
|
||||
if config.grpc_endpoint:
|
||||
# Urls without scheme won't be parsed correctly in some python versions,
|
||||
# see https://bugs.python.org/issue27657
|
||||
grpc_endpoint_with_scheme = (
|
||||
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
|
||||
)
|
||||
grpc_p = urlparse(grpc_endpoint_with_scheme)
|
||||
grpc_host = grpc_p.hostname or "localhost"
|
||||
grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051)
|
||||
grpc_secure = grpc_p.scheme == "grpcs"
|
||||
else:
|
||||
# Infer from HTTP endpoint as fallback
|
||||
grpc_host = host
|
||||
grpc_secure = http_secure
|
||||
grpc_port = 443 if grpc_secure else 50051
|
||||
|
||||
client = weaviate.connect_to_custom(
|
||||
http_host=host,
|
||||
http_port=http_port,
|
||||
http_secure=http_secure,
|
||||
grpc_host=grpc_host,
|
||||
grpc_port=grpc_port,
|
||||
grpc_secure=grpc_secure,
|
||||
auth_credentials=Auth.api_key(config.api_key) if config.api_key else None,
|
||||
skip_init_checks=True, # Skip PyPI version check to avoid unnecessary HTTP requests
|
||||
)
|
||||
grpc_p = urlparse(grpc_endpoint_with_scheme)
|
||||
grpc_host = grpc_p.hostname or "localhost"
|
||||
grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051)
|
||||
grpc_secure = grpc_p.scheme == "grpcs"
|
||||
else:
|
||||
# Infer from HTTP endpoint as fallback
|
||||
grpc_host = host
|
||||
grpc_secure = http_secure
|
||||
grpc_port = 443 if grpc_secure else 50051
|
||||
|
||||
client = weaviate.connect_to_custom(
|
||||
http_host=host,
|
||||
http_port=http_port,
|
||||
http_secure=http_secure,
|
||||
grpc_host=grpc_host,
|
||||
grpc_port=grpc_port,
|
||||
grpc_secure=grpc_secure,
|
||||
auth_credentials=Auth.api_key(config.api_key) if config.api_key else None,
|
||||
skip_init_checks=True, # Skip PyPI version check to avoid unnecessary HTTP requests
|
||||
)
|
||||
if not client.is_ready():
|
||||
raise ConnectionError("Vector database is not ready")
|
||||
|
||||
if not client.is_ready():
|
||||
raise ConnectionError("Vector database is not ready")
|
||||
|
||||
return client
|
||||
_weaviate_client = client
|
||||
return client
|
||||
|
||||
def get_type(self) -> str:
|
||||
"""Returns the vector database type identifier."""
|
||||
|
|
@ -196,6 +209,7 @@ class WeaviateVector(BaseVector):
|
|||
),
|
||||
wc.Property(name="document_id", data_type=wc.DataType.TEXT),
|
||||
wc.Property(name="doc_id", data_type=wc.DataType.TEXT),
|
||||
wc.Property(name="doc_type", data_type=wc.DataType.TEXT),
|
||||
wc.Property(name="chunk_index", data_type=wc.DataType.INT),
|
||||
],
|
||||
vector_config=wc.Configure.Vectors.self_provided(),
|
||||
|
|
@ -225,6 +239,8 @@ class WeaviateVector(BaseVector):
|
|||
to_add.append(wc.Property(name="document_id", data_type=wc.DataType.TEXT))
|
||||
if "doc_id" not in existing:
|
||||
to_add.append(wc.Property(name="doc_id", data_type=wc.DataType.TEXT))
|
||||
if "doc_type" not in existing:
|
||||
to_add.append(wc.Property(name="doc_type", data_type=wc.DataType.TEXT))
|
||||
if "chunk_index" not in existing:
|
||||
to_add.append(wc.Property(name="chunk_index", data_type=wc.DataType.INT))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,18 @@
|
|||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from models.dataset import DocumentSegment
|
||||
|
||||
|
||||
class AttachmentInfoDict(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
extension: str
|
||||
mime_type: str
|
||||
source_url: str
|
||||
size: int
|
||||
|
||||
|
||||
class RetrievalChildChunk(BaseModel):
|
||||
"""Retrieval segments."""
|
||||
|
||||
|
|
@ -19,5 +29,5 @@ class RetrievalSegments(BaseModel):
|
|||
segment: DocumentSegment
|
||||
child_chunks: list[RetrievalChildChunk] | None = None
|
||||
score: float | None = None
|
||||
files: list[dict[str, str | int]] | None = None
|
||||
files: list[AttachmentInfoDict] | None = None
|
||||
summary: str | None = None # Summary content if retrieved via summary index
|
||||
|
|
|
|||
|
|
@ -9,8 +9,9 @@ from flask import current_app
|
|||
from sqlalchemy import delete, func, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from dify_graph.nodes.knowledge_index.exc import KnowledgeIndexNodeError
|
||||
from dify_graph.repositories.index_processor_protocol import Preview, PreviewItem, QaPreview
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
|
||||
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
from .index_processor_factory import IndexProcessorFactory
|
||||
|
|
@ -51,7 +52,7 @@ class IndexProcessor:
|
|||
original_document_id: str,
|
||||
chunks: Mapping[str, Any],
|
||||
batch: Any,
|
||||
summary_index_setting: dict | None = None,
|
||||
summary_index_setting: SummaryIndexSettingDict | None = None,
|
||||
):
|
||||
with session_factory.create_session() as session:
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
|
|
@ -131,7 +132,12 @@ class IndexProcessor:
|
|||
}
|
||||
|
||||
def get_preview_output(
|
||||
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
|
||||
self,
|
||||
chunks: Any,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
chunk_structure: str,
|
||||
summary_index_setting: SummaryIndexSettingDict | None,
|
||||
) -> Preview:
|
||||
doc_language = None
|
||||
with session_factory.create_session() as session:
|
||||
|
|
|
|||
|
|
@ -7,14 +7,16 @@ import os
|
|||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, NotRequired, Optional
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
import httpx
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.helper import ssrf_proxy
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.models.document import AttachmentDocument, Document
|
||||
|
|
@ -35,6 +37,13 @@ if TYPE_CHECKING:
|
|||
from core.model_manager import ModelInstance
|
||||
|
||||
|
||||
class SummaryIndexSettingDict(TypedDict):
|
||||
enable: bool
|
||||
model_name: NotRequired[str]
|
||||
model_provider_name: NotRequired[str]
|
||||
summary_prompt: NotRequired[str]
|
||||
|
||||
|
||||
class BaseIndexProcessor(ABC):
|
||||
"""Interface for extract files."""
|
||||
|
||||
|
|
@ -51,7 +60,7 @@ class BaseIndexProcessor(ABC):
|
|||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
|
|
@ -98,7 +107,7 @@ class BaseIndexProcessor(ABC):
|
|||
dataset: Dataset,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_model: dict,
|
||||
reranking_model: RerankingModelDict,
|
||||
) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -294,7 +303,7 @@ class BaseIndexProcessor(ABC):
|
|||
logging.warning("Error downloading image from %s: %s", image_url, str(e))
|
||||
return None
|
||||
except Exception:
|
||||
logging.exception("Unexpected error downloading image from %s", image_url)
|
||||
logging.warning("Unexpected error downloading image from %s", image_url, exc_info=True)
|
||||
return None
|
||||
|
||||
def _download_tool_file(self, tool_file_id: str, current_user: Account) -> str | None:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
|
|||
from core.model_manager import ModelInstance
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
|
|
@ -22,7 +23,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
|
|||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
|
||||
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
|
|
@ -175,7 +176,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||
dataset: Dataset,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_model: dict,
|
||||
reranking_model: RerankingModelDict,
|
||||
) -> list[Document]:
|
||||
# Set search parameters.
|
||||
results = RetrievalService.retrieve(
|
||||
|
|
@ -278,7 +279,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
|
|
@ -362,7 +363,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||
def generate_summary(
|
||||
tenant_id: str,
|
||||
text: str,
|
||||
summary_index_setting: dict | None = None,
|
||||
summary_index_setting: SummaryIndexSettingDict | None = None,
|
||||
segment_id: str | None = None,
|
||||
document_language: str | None = None,
|
||||
) -> tuple[str, LLMUsage]:
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from core.db.session_factory import session_factory
|
|||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
|
|
@ -18,7 +19,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
|
|||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
|
||||
from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -215,7 +216,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||
dataset: Dataset,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_model: dict,
|
||||
reranking_model: RerankingModelDict,
|
||||
) -> list[Document]:
|
||||
# Set search parameters.
|
||||
results = RetrievalService.retrieve(
|
||||
|
|
@ -361,7 +362,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -15,13 +15,14 @@ from core.db.session_factory import session_factory
|
|||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
|
||||
from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
|
|
@ -185,7 +186,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
|||
dataset: Dataset,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_model: dict,
|
||||
reranking_model: RerankingModelDict,
|
||||
):
|
||||
# Set search parameters.
|
||||
results = RetrievalService.retrieve(
|
||||
|
|
@ -244,7 +245,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
|||
self,
|
||||
tenant_id: str,
|
||||
preview_texts: list[PreviewDetail],
|
||||
summary_index_setting: dict,
|
||||
summary_index_setting: SummaryIndexSettingDict,
|
||||
doc_language: str | None = None,
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ from core.ops.utils import measure_time
|
|||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
|
|
@ -56,18 +56,18 @@ from core.rag.retrieval.template_prompts import (
|
|||
)
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from dify_graph.file import File, FileTransferMethod, FileType
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from dify_graph.nodes.knowledge_retrieval import exc
|
||||
from dify_graph.repositories.rag_retrieval_protocol import (
|
||||
from core.workflow.nodes.knowledge_retrieval import exc
|
||||
from core.workflow.nodes.knowledge_retrieval.retrieval import (
|
||||
KnowledgeRetrievalRequest,
|
||||
Source,
|
||||
SourceChildChunk,
|
||||
SourceMetadata,
|
||||
)
|
||||
from dify_graph.file import File, FileTransferMethod, FileType
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
|
|
@ -83,7 +83,7 @@ from models.dataset import (
|
|||
)
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.dataset import Document as DocumentModel
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, DatasetQuerySource
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
|
@ -727,8 +727,8 @@ class DatasetRetrieval:
|
|||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_mode: str,
|
||||
reranking_model: dict | None = None,
|
||||
weights: dict[str, Any] | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
reranking_enable: bool = True,
|
||||
message_id: str | None = None,
|
||||
metadata_filter_document_ids: dict[str, list[str]] | None = None,
|
||||
|
|
@ -1008,7 +1008,7 @@ class DatasetRetrieval:
|
|||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset_id,
|
||||
content=json.dumps(contents),
|
||||
source="app",
|
||||
source=DatasetQuerySource.APP,
|
||||
source_app_id=app_id,
|
||||
created_by_role=CreatorUserRole(user_from),
|
||||
created_by=user_id,
|
||||
|
|
@ -1181,8 +1181,8 @@ class DatasetRetrieval:
|
|||
hit_callbacks=[hit_callback],
|
||||
return_resource=return_resource,
|
||||
retriever_from=invoke_from.to_source(),
|
||||
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
|
||||
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
|
||||
reranking_provider_name=retrieve_config.reranking_model["reranking_provider_name"],
|
||||
reranking_model_name=retrieve_config.reranking_model["reranking_model_name"],
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
|
|
@ -1685,8 +1685,8 @@ class DatasetRetrieval:
|
|||
tenant_id: str,
|
||||
reranking_enable: bool,
|
||||
reranking_mode: str,
|
||||
reranking_model: dict | None,
|
||||
weights: dict[str, Any] | None,
|
||||
reranking_model: RerankingModelDict | None,
|
||||
weights: WeightsDict | None,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
query: str | None,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import concurrent.futures
|
|||
import logging
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
|
|
@ -11,7 +12,11 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class SummaryIndex:
|
||||
def generate_and_vectorize_summary(
|
||||
self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
|
||||
self,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
is_preview: bool,
|
||||
summary_index_setting: SummaryIndexSettingDict | None = None,
|
||||
) -> None:
|
||||
if is_preview:
|
||||
with session_factory.create_session() as session:
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_att
|
|||
|
||||
from configs import dify_config
|
||||
from dify_graph.entities import WorkflowNodeExecution
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
|
|
@ -146,7 +146,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
index=db_model.index,
|
||||
predecessor_node_id=db_model.predecessor_node_id,
|
||||
node_id=db_model.node_id,
|
||||
node_type=NodeType(db_model.node_type),
|
||||
node_type=db_model.node_type,
|
||||
title=db_model.title,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
|
|
|
|||
|
|
@ -72,6 +72,11 @@ class ApiProviderControllerItem(TypedDict):
|
|||
controller: ApiToolProviderController
|
||||
|
||||
|
||||
class EmojiIconDict(TypedDict):
|
||||
background: str
|
||||
content: str
|
||||
|
||||
|
||||
class ToolManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
|
||||
|
|
@ -916,7 +921,7 @@ class ToolManager:
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]:
|
||||
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
|
||||
try:
|
||||
workflow_provider: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
|
|
@ -933,7 +938,7 @@ class ToolManager:
|
|||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@classmethod
|
||||
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]:
|
||||
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
|
||||
try:
|
||||
api_provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
|
|
@ -950,7 +955,7 @@ class ToolManager:
|
|||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@classmethod
|
||||
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str:
|
||||
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str:
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
|
|
@ -970,7 +975,7 @@ class ToolManager:
|
|||
tenant_id: str,
|
||||
provider_type: ToolProviderType,
|
||||
provider_id: str,
|
||||
) -> str | Mapping[str, str]:
|
||||
) -> str | EmojiIconDict | dict[str, str]:
|
||||
"""
|
||||
get the tool icon
|
||||
|
||||
|
|
|
|||
|
|
@ -116,6 +116,7 @@ class ToolParameterConfigurationManager:
|
|||
|
||||
return a deep copy of parameters with decrypted values
|
||||
"""
|
||||
parameters = self._deep_copy(parameters)
|
||||
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import threading
|
||||
from typing import Any
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import BaseModel, Field
|
||||
|
|
@ -13,11 +12,12 @@ from core.rag.models.document import Document as RagDocument
|
|||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
default_retrieval_model: dict[str, Any] = {
|
||||
default_retrieval_model: DefaultRetrievalModelDict = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Any, cast
|
||||
from typing import NotRequired, TypedDict, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
|
|
@ -16,7 +17,19 @@ from models.dataset import Dataset
|
|||
from models.dataset import Document as DatasetDocument
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model: dict[str, Any] = {
|
||||
|
||||
class DefaultRetrievalModelDict(TypedDict):
|
||||
search_method: RetrievalMethod
|
||||
reranking_enable: bool
|
||||
reranking_model: RerankingModelDict
|
||||
reranking_mode: NotRequired[str]
|
||||
weights: NotRequired[WeightsDict | None]
|
||||
score_threshold: NotRequired[float]
|
||||
top_k: int
|
||||
score_threshold_enabled: bool
|
||||
|
||||
|
||||
default_retrieval_model: DefaultRetrievalModelDict = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
|
|
@ -125,7 +138,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
if metadata_condition and not document_ids_filter:
|
||||
return ""
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
||||
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
||||
retrieval_resource_list: list[RetrievalSourceMetadata] = []
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import re
|
||||
from collections.abc import Mapping
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
from json.decoder import JSONDecodeError
|
||||
|
|
@ -20,10 +21,18 @@ class InterfaceDict(TypedDict):
|
|||
operation: dict[str, Any]
|
||||
|
||||
|
||||
class OpenAPISpecDict(TypedDict):
|
||||
openapi: str
|
||||
info: dict[str, str]
|
||||
servers: list[dict[str, Any]]
|
||||
paths: dict[str, Any]
|
||||
components: dict[str, Any]
|
||||
|
||||
|
||||
class ApiBasedToolSchemaParser:
|
||||
@staticmethod
|
||||
def parse_openapi_to_tool_bundle(
|
||||
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||
openapi: Mapping[str, Any], extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
|
@ -277,7 +286,7 @@ class ApiBasedToolSchemaParser:
|
|||
@staticmethod
|
||||
def parse_swagger_to_openapi(
|
||||
swagger: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> dict[str, Any]:
|
||||
) -> OpenAPISpecDict:
|
||||
warning = warning or {}
|
||||
"""
|
||||
parse swagger to openapi
|
||||
|
|
@ -293,7 +302,7 @@ class ApiBasedToolSchemaParser:
|
|||
if len(servers) == 0:
|
||||
raise ToolApiSchemaError("No server found in the swagger yaml.")
|
||||
|
||||
converted_openapi: dict[str, Any] = {
|
||||
converted_openapi: OpenAPISpecDict = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": info.get("title", "Swagger"),
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Any
|
|||
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.enums import BuiltinNodeTypes
|
||||
from dify_graph.nodes.base.entities import OutputVariableEntity
|
||||
from dify_graph.variables.input_entities import VariableEntity
|
||||
|
||||
|
|
@ -51,7 +51,7 @@ class WorkflowToolConfigurationUtils:
|
|||
def ensure_no_human_input_nodes(cls, graph: Mapping[str, Any]) -> None:
|
||||
nodes = graph.get("nodes", [])
|
||||
for node in nodes:
|
||||
if node.get("data", {}).get("type") == NodeType.HUMAN_INPUT:
|
||||
if node.get("data", {}).get("type") == BuiltinNodeTypes.HUMAN_INPUT:
|
||||
raise WorkflowToolHumanInputNotSupportedError()
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -0,0 +1,18 @@
|
|||
from typing import Final
|
||||
|
||||
TRIGGER_WEBHOOK_NODE_TYPE: Final[str] = "trigger-webhook"
|
||||
TRIGGER_SCHEDULE_NODE_TYPE: Final[str] = "trigger-schedule"
|
||||
TRIGGER_PLUGIN_NODE_TYPE: Final[str] = "trigger-plugin"
|
||||
TRIGGER_INFO_METADATA_KEY: Final[str] = "trigger_info"
|
||||
|
||||
TRIGGER_NODE_TYPES: Final[frozenset[str]] = frozenset(
|
||||
{
|
||||
TRIGGER_WEBHOOK_NODE_TYPE,
|
||||
TRIGGER_SCHEDULE_NODE_TYPE,
|
||||
TRIGGER_PLUGIN_NODE_TYPE,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def is_trigger_node_type(node_type: str) -> bool:
|
||||
return node_type in TRIGGER_NODE_TYPES
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue